From 8160109e6dc4717efc4d5467b0b83618b0ce29ae Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 5 Jan 2023 17:00:47 +0000 Subject: [PATCH 001/725] GeNN now has a top-level namespace --- include/genn/backends/cuda/backend.h | 9 +- include/genn/backends/cuda/optimiser.h | 16 ++- include/genn/backends/opencl/backend.h | 9 +- include/genn/backends/opencl/optimiser.h | 12 +-- include/genn/backends/opencl/utils.h | 12 +-- .../backends/single_threaded_cpu/backend.h | 9 +- .../backends/single_threaded_cpu/optimiser.h | 12 +-- include/genn/genn/binomial.h | 5 +- .../genn/genn/code_generator/backendBase.h | 12 ++- .../genn/genn/code_generator/backendSIMT.h | 8 +- .../genn/genn/code_generator/codeGenUtils.h | 6 +- include/genn/genn/code_generator/codeStream.h | 6 +- .../customConnectivityUpdateGroupMerged.h | 10 +- .../code_generator/customUpdateGroupMerged.h | 18 ++-- .../genn/code_generator/generateMSBuild.h | 7 +- .../genn/code_generator/generateMakefile.h | 7 +- .../genn/code_generator/generateModules.h | 5 +- .../genn/genn/code_generator/generateRunner.h | 6 +- .../genn/code_generator/generateSupportCode.h | 6 +- .../genn/genn/code_generator/groupMerged.h | 20 ++-- .../genn/code_generator/initGroupMerged.h | 28 ++--- .../genn/code_generator/modelSpecMerged.h | 8 +- .../code_generator/neuronUpdateGroupMerged.h | 6 +- .../presynapticUpdateStrategySIMT.h | 22 ++-- .../genn/genn/code_generator/substitutions.h | 6 +- .../genn/code_generator/supportCodeMerged.h | 6 +- .../code_generator/synapseUpdateGroupMerged.h | 16 ++- include/genn/genn/code_generator/teeStream.h | 8 +- include/genn/genn/currentSource.h | 8 +- include/genn/genn/currentSourceInternal.h | 5 +- include/genn/genn/currentSourceModels.h | 6 +- include/genn/genn/customConnectivityUpdate.h | 5 +- .../genn/customConnectivityUpdateInternal.h | 3 + .../genn/customConnectivityUpdateModels.h | 6 +- include/genn/genn/customUpdate.h | 5 +- include/genn/genn/customUpdateInternal.h | 5 +- include/genn/genn/customUpdateModels.h | 6 +- include/genn/genn/gennUtils.h | 8 +- .../genn/genn/initSparseConnectivitySnippet.h | 6 +- .../genn/initToeplitzConnectivitySnippet.h | 12 +-- include/genn/genn/initVarSnippet.h | 28 ++--- include/genn/genn/logging.h | 4 +- include/genn/genn/modelSpec.h | 8 +- include/genn/genn/modelSpecInternal.h | 5 +- include/genn/genn/models.h | 18 ++-- include/genn/genn/neuronGroup.h | 8 +- include/genn/genn/neuronGroupInternal.h | 6 +- include/genn/genn/neuronModels.h | 32 +++--- include/genn/genn/postsynapticModels.h | 12 +-- include/genn/genn/snippet.h | 6 +- include/genn/genn/synapseGroup.h | 8 +- include/genn/genn/synapseGroupInternal.h | 5 +- include/genn/genn/synapseMatrixType.h | 3 + include/genn/genn/varAccess.h | 4 +- include/genn/genn/variableMode.h | 5 +- include/genn/genn/weightUpdateModels.h | 14 +-- src/genn/backends/cuda/backend.cc | 13 +-- src/genn/backends/cuda/optimiser.cc | 18 ++-- src/genn/backends/opencl/backend.cc | 12 +-- src/genn/backends/opencl/optimiser.cc | 12 +-- .../backends/single_threaded_cpu/backend.cc | 15 ++- .../backends/single_threaded_cpu/optimiser.cc | 12 +-- src/genn/generator/generator.cc | 3 +- src/genn/genn/binomial.cc | 2 +- src/genn/genn/code_generator/backendBase.cc | 6 +- src/genn/genn/code_generator/backendSIMT.cc | 14 ++- src/genn/genn/code_generator/codeGenUtils.cc | 75 +++---------- src/genn/genn/code_generator/codeStream.cc | 12 +-- .../customConnectivityUpdateGroupMerged.cc | 3 +- .../code_generator/customUpdateGroupMerged.cc | 9 +- 
.../genn/code_generator/generateMSBuild.cc | 8 +- .../genn/code_generator/generateMakefile.cc | 6 +- .../genn/code_generator/generateModules.cc | 30 +++--- .../genn/code_generator/generateRunner.cc | 9 +- .../code_generator/generateSupportCode.cc | 6 +- src/genn/genn/code_generator/groupMerged.cc | 13 +-- .../genn/code_generator/initGroupMerged.cc | 17 +-- .../genn/code_generator/modelSpecMerged.cc | 5 +- .../code_generator/neuronUpdateGroupMerged.cc | 5 +- .../presynapticUpdateStrategySIMT.cc | 31 +++--- src/genn/genn/code_generator/substitutions.cc | 29 ++--- .../synapseUpdateGroupMerged.cc | 10 +- src/genn/genn/currentSource.cc | 5 +- src/genn/genn/currentSourceModels.cc | 21 ++-- src/genn/genn/customConnectivityUpdate.cc | 2 + .../genn/customConnectivityUpdateModels.cc | 23 ++-- src/genn/genn/customUpdate.cc | 7 +- src/genn/genn/customUpdateModels.cc | 11 +- src/genn/genn/gennUtils.cc | 6 +- .../genn/initSparseConnectivitySnippet.cc | 27 +++-- .../genn/initToeplitzConnectivitySnippet.cc | 17 +-- src/genn/genn/initVarSnippet.cc | 34 +++--- src/genn/genn/logging.cc | 6 +- src/genn/genn/modelSpec.cc | 102 ++++++++---------- src/genn/genn/models.cc | 18 ++-- src/genn/genn/neuronGroup.cc | 11 +- src/genn/genn/neuronModels.cc | 39 ++++--- src/genn/genn/postsynapticModels.cc | 22 ++-- src/genn/genn/snippet.cc | 11 +- src/genn/genn/synapseGroup.cc | 13 ++- src/genn/genn/weightUpdateModels.cc | 27 +++-- 101 files changed, 671 insertions(+), 642 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index a8d93d53b4..30a91ec87b 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -30,11 +30,9 @@ namespace filesystem } //-------------------------------------------------------------------------- -// CodeGenerator::CUDA::DeviceSelectMethod +// GeNN::CodeGenerator::CUDA::DeviceSelectMethod //-------------------------------------------------------------------------- -namespace CodeGenerator -{ -namespace CUDA +namespace GeNN::CodeGenerator::CUDA { //! 
Methods for selecting CUDA device enum class DeviceSelect @@ -354,5 +352,4 @@ class BACKEND_EXPORT Backend : public BackendSIMT cudaDeviceProp m_ChosenDevice; int m_RuntimeVersion; }; -} // CUDA -} // CodeGenerator +} // GeNN::CodeGenerator::CUDA diff --git a/include/genn/backends/cuda/optimiser.h b/include/genn/backends/cuda/optimiser.h index bd0aad4442..d6ab6e49db 100644 --- a/include/genn/backends/cuda/optimiser.h +++ b/include/genn/backends/cuda/optimiser.h @@ -7,7 +7,11 @@ #include "backend.h" // Forward declarations +namespace GeNN +{ class ModelSpecInternal; +} + namespace plog { class IAppender; @@ -15,17 +19,11 @@ class IAppender; //-------------------------------------------------------------------------- -// CodeGenerator::CUDA::Optimiser +// GeNN::CodeGenerator::CUDA::Optimiser //-------------------------------------------------------------------------- -namespace CodeGenerator -{ -namespace CUDA -{ -namespace Optimiser +namespace GeNN::CodeGenerator::CUDA::Optimiser { BACKEND_EXPORT Backend createBackend(const ModelSpecInternal &model, const filesystem::path &outputPath, plog::Severity backendLevel, plog::IAppender *backendAppender, const Preferences &preferences); -} // namespace Optimiser -} // namespace CUDA -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator::CUDA::Optimiser diff --git a/include/genn/backends/opencl/backend.h b/include/genn/backends/opencl/backend.h index cf44995dac..3a8265d359 100644 --- a/include/genn/backends/opencl/backend.h +++ b/include/genn/backends/opencl/backend.h @@ -28,11 +28,9 @@ namespace filesystem } //-------------------------------------------------------------------------- -// CodeGenerator::OpenCL::DeviceSelectMethod +// GeNN::CodeGenerator::OpenCL::DeviceSelectMethod //-------------------------------------------------------------------------- -namespace CodeGenerator -{ -namespace OpenCL +namespace GeNN::CodeGenerator::OpenCL { //! 
Methods for selecting OpenCL platform enum class PlatformSelect @@ -456,5 +454,4 @@ class BACKEND_EXPORT Backend : public BackendSIMT cl::Device m_ChosenDevice; cl::Platform m_ChosenPlatform; }; -} // OpenCL -} // CodeGenerator +} // GeNN::CodeGenerator::OpenCL diff --git a/include/genn/backends/opencl/optimiser.h b/include/genn/backends/opencl/optimiser.h index aa97508e5d..b5611c35f4 100644 --- a/include/genn/backends/opencl/optimiser.h +++ b/include/genn/backends/opencl/optimiser.h @@ -15,17 +15,11 @@ class IAppender; //-------------------------------------------------------------------------- -// CodeGenerator::OpenCL::Optimiser +// GeNN::CodeGenerator::OpenCL::Optimiser //-------------------------------------------------------------------------- -namespace CodeGenerator -{ -namespace OpenCL -{ -namespace Optimiser +namespace GeNN::CodeGenerator::OpenCL::Optimiser { BACKEND_EXPORT Backend createBackend(const ModelSpecInternal &model, const filesystem::path &outputPath, plog::Severity backendLevel, plog::IAppender *backendAppender, const Preferences &preferences); -} // namespace Optimiser -} // namespace CUDA -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator::OpenCL::Optimiser diff --git a/include/genn/backends/opencl/utils.h b/include/genn/backends/opencl/utils.h index c03a100935..127be27bab 100644 --- a/include/genn/backends/opencl/utils.h +++ b/include/genn/backends/opencl/utils.h @@ -20,13 +20,9 @@ } //-------------------------------------------------------------------------- -// CodeGenerator::OpenCL::Utils +// GeNN::CodeGenerator::OpenCL::Utils //-------------------------------------------------------------------------- -namespace CodeGenerator -{ -namespace OpenCL -{ -namespace Utils +namespace GeNN::CodeGenerator::OpenCL::Utils { // OpenCL error string const char *clGetErrorString(cl_int error) @@ -100,6 +96,4 @@ const char *clGetErrorString(cl_int error) } #undef GEN_CL_ERROR_CASE } -} // namespace Utils -} // namespace OpenCL -} // namespace CodeGenerator \ No newline at end of file +} // namespace GeNN::CodeGenerator::OpenCL::Utils diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 25bff17dc1..0190308a19 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -19,11 +19,9 @@ namespace filesystem } //-------------------------------------------------------------------------- -// CodeGenerator::SingleThreadedCPU::Preferences +// GeNN::CodeGenerator::SingleThreadedCPU::Preferences //-------------------------------------------------------------------------- -namespace CodeGenerator -{ -namespace SingleThreadedCPU +namespace GeNN::CodeGenerator::SingleThreadedCPU { struct Preferences : public PreferencesBase { @@ -228,5 +226,4 @@ class BACKEND_EXPORT Backend : public BackendBase } } }; -} // namespace SingleThreadedCPU -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator::SingleThreadedCPU diff --git a/include/genn/backends/single_threaded_cpu/optimiser.h b/include/genn/backends/single_threaded_cpu/optimiser.h index 4b157def0f..119eeb1a36 100644 --- a/include/genn/backends/single_threaded_cpu/optimiser.h +++ b/include/genn/backends/single_threaded_cpu/optimiser.h @@ -17,17 +17,11 @@ class IAppender; } //-------------------------------------------------------------------------- -// CodeGenerator::SingleThreadedCPU::Optimiser +// GeNN::CodeGenerator::SingleThreadedCPU::Optimiser 
//-------------------------------------------------------------------------- -namespace CodeGenerator -{ -namespace SingleThreadedCPU -{ -namespace Optimiser +namespace GeNN::CodeGenerator::SingleThreadedCPU::Optimiser { BACKEND_EXPORT Backend createBackend(const ModelSpecInternal &model, const filesystem::path &outputPath, plog::Severity backendLevel, plog::IAppender *backendAppender, const Preferences &preferences); -} -} // namespace SingleThreadedCPU -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator::SingleThreadedCPU::Optimiser diff --git a/include/genn/genn/binomial.h b/include/genn/genn/binomial.h index b69212067c..36e2cf2968 100644 --- a/include/genn/genn/binomial.h +++ b/include/genn/genn/binomial.h @@ -3,4 +3,7 @@ // GeNN includes #include "gennExport.h" -GENN_EXPORT unsigned int binomialInverseCDF(double cdf, unsigned int n, double p); \ No newline at end of file +namespace GeNN +{ +GENN_EXPORT unsigned int binomialInverseCDF(double cdf, unsigned int n, double p); +} diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 29b38ed1a6..e775d27c3d 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -22,6 +22,8 @@ #include "variableMode.h" // Forward declarations +namespace GeNN +{ class CustomUpdateInternal; class CustomUpdateWUInternal; class NeuronGroupInternal; @@ -47,13 +49,13 @@ class CustomWUUpdateSparseInitGroupMerged; class SynapseConnectivityInitGroupMerged; class SynapseInitGroupMerged; class SynapseSparseInitGroupMerged; - +} } //-------------------------------------------------------------------------- -// CodeGenerator::PreferencesBase +// GeNN::CodeGenerator::PreferencesBase //-------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { //! Base class for backend preferences - can be accessed via a global in 'classic' C++ code generator struct PreferencesBase @@ -99,7 +101,7 @@ struct PreferencesBase }; //-------------------------------------------------------------------------- -// CodeGenerator::MemAlloc +// GeNN::CodeGenerator::MemAlloc //-------------------------------------------------------------------------- class MemAlloc { @@ -546,4 +548,4 @@ class GENN_EXPORT BackendBase //! Preferences const PreferencesBase &m_Preferences; }; -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 33ff58a44c..d7502c7bc0 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -16,9 +16,9 @@ #include "code_generator/substitutions.h" //-------------------------------------------------------------------------- -// CodeGenerator::Kernel +// GeNN::CodeGenerator::Kernel //-------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { //! Kernels generated by SIMT backends enum Kernel @@ -44,7 +44,7 @@ enum Kernel using KernelBlockSize = std::array; //-------------------------------------------------------------------------- -// CodeGenerator::BackendSIMT +// GeNN::CodeGenerator::BackendSIMT //-------------------------------------------------------------------------- //! Base class for Single Instruction Multiple Thread style backends /*! CUDA terminology is used throughout i.e. 
thread blocks and shared memory */ @@ -484,4 +484,4 @@ class GENN_EXPORT BackendSIMT : public BackendBase static std::vector s_PresynapticUpdateStrategies; }; -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 60711f87f3..22f3e1c702 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -21,9 +21,9 @@ #include "teeStream.h" //-------------------------------------------------------------------------- -// CodeGenerator +// GeNN::CodeGenerator //-------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { //-------------------------------------------------------------------------- //! \brief Tool for substituting strings in the neuron code strings or other templates @@ -196,4 +196,4 @@ void genKernelIndex(const G *group, std::ostream &os, const CodeGenerator::Subst } } } -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/codeStream.h b/include/genn/genn/code_generator/codeStream.h index 5676af1eb7..14a9c1eecd 100644 --- a/include/genn/genn/code_generator/codeStream.h +++ b/include/genn/genn/code_generator/codeStream.h @@ -12,11 +12,11 @@ #include "gennExport.h" //---------------------------------------------------------------------------- -// CodeGenerator::CodeStream +// GeNN::CodeGenerator::CodeStream //---------------------------------------------------------------------------- //! Helper class for generating code - automatically inserts brackets, indents etc /*! Based heavily on: https://stackoverflow.com/questions/15053753/writing-a-manipulator-for-a-custom-stream-class */ -namespace CodeGenerator +namespace GeNN::CodeGenerator { class GENN_EXPORT CodeStream : public std::ostream { @@ -160,6 +160,6 @@ class GENN_EXPORT CodeStream : public std::ostream //------------------------------------------------------------------------ GENN_EXPORT std::ostream& operator << (std::ostream& s, const CodeStream::OB &ob); GENN_EXPORT std::ostream& operator << (std::ostream& s, const CodeStream::CB &cb); -} // namespace CodeGenerator; +} // namespace GeNN::CodeGenerator; diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index bdc91819d3..d66e26abcd 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -8,9 +8,9 @@ #include "code_generator/groupMerged.h" //---------------------------------------------------------------------------- -// CodeGenerator::CustomConnectivityUpdateGroupMergedBase +// GeNN::CodeGenerator::CustomConnectivityUpdateGroupMergedBase //---------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { class GENN_EXPORT CustomConnectivityUpdateGroupMergedBase : public GroupMerged { @@ -27,7 +27,7 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMergedBase : public GroupMerged { @@ -43,7 +43,7 @@ class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged { @@ -89,7 +89,7 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged class CustomUpdateHostReductionGroupMergedBase : public GroupMerged @@ -180,7 +180,7 @@ class CustomUpdateHostReductionGroupMergedBase : 
public GroupMerged }; // ---------------------------------------------------------------------------- -// CustomUpdateHostReductionGroupMerged +// GeNN::CodeGenerator::CustomUpdateHostReductionGroupMerged //---------------------------------------------------------------------------- class GENN_EXPORT CustomUpdateHostReductionGroupMerged : public CustomUpdateHostReductionGroupMergedBase { @@ -206,7 +206,7 @@ class GENN_EXPORT CustomUpdateHostReductionGroupMerged : public CustomUpdateHost }; // ---------------------------------------------------------------------------- -// CustomWUUpdateHostReductionGroupMerged +// GeNN::CodeGenerator::CustomWUUpdateHostReductionGroupMerged //---------------------------------------------------------------------------- class GENN_EXPORT CustomWUUpdateHostReductionGroupMerged : public CustomUpdateHostReductionGroupMergedBase { @@ -230,4 +230,4 @@ class GENN_EXPORT CustomWUUpdateHostReductionGroupMerged : public CustomUpdateHo //---------------------------------------------------------------------------- static const std::string name; }; -} // namespace CodeGenerator \ No newline at end of file +} // namespace GeNN::CodeGenerator \ No newline at end of file diff --git a/include/genn/genn/code_generator/generateMSBuild.h b/include/genn/genn/code_generator/generateMSBuild.h index 98379cc704..71a0fe7b86 100644 --- a/include/genn/genn/code_generator/generateMSBuild.h +++ b/include/genn/genn/code_generator/generateMSBuild.h @@ -6,17 +6,20 @@ #include "gennExport.h" // Forward declarations +namespace GeNN +{ class ModelSpecInternal; namespace CodeGenerator { class BackendBase; } +} //-------------------------------------------------------------------------- -// CodeGenerator +// GeNN::CodeGenerator //-------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { GENN_EXPORT void generateMSBuild(std::ostream &os, const ModelSpecInternal &model, const BackendBase &backend, const std::string &projectGUID, const std::vector &moduleNames); diff --git a/include/genn/genn/code_generator/generateMakefile.h b/include/genn/genn/code_generator/generateMakefile.h index 4dd91fda89..e75bfa98ba 100644 --- a/include/genn/genn/code_generator/generateMakefile.h +++ b/include/genn/genn/code_generator/generateMakefile.h @@ -8,17 +8,20 @@ #include "gennExport.h" // Forward declarations +namespace GeNN +{ class ModelSpecInternal; namespace CodeGenerator { class BackendBase; } +} //-------------------------------------------------------------------------- -// CodeGenerator +// GeNN::CodeGenerator //-------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { GENN_EXPORT void generateMakefile(std::ostream &os, const BackendBase &backend, const std::vector &moduleNames); diff --git a/include/genn/genn/code_generator/generateModules.h b/include/genn/genn/code_generator/generateModules.h index 80334b318f..c9ec3bfb61 100644 --- a/include/genn/genn/code_generator/generateModules.h +++ b/include/genn/genn/code_generator/generateModules.h @@ -11,7 +11,10 @@ #include "backendBase.h" // Forward declarations +namespace GeNN +{ class ModelSpecInternal; +} namespace filesystem { @@ -21,7 +24,7 @@ namespace filesystem //-------------------------------------------------------------------------- // CodeGenerator //-------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { GENN_EXPORT 
std::pair<std::vector<std::string>, MemAlloc> generateAll(const ModelSpecInternal &model, const BackendBase &backend, const filesystem::path &sharePath, const filesystem::path &outputPath, diff --git a/include/genn/genn/code_generator/generateRunner.h b/include/genn/genn/code_generator/generateRunner.h index d974b3d53b..549ac67f0a 100644 --- a/include/genn/genn/code_generator/generateRunner.h +++ b/include/genn/genn/code_generator/generateRunner.h @@ -10,7 +10,7 @@ #include "code_generator/backendBase.h" // Forward declarations -namespace CodeGenerator +namespace GeNN::CodeGenerator { class ModelSpecMerged; } @@ -21,9 +21,9 @@ class path; } //-------------------------------------------------------------------------- -// CodeGenerator +// GeNN::CodeGenerator //-------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { GENN_EXPORT MemAlloc generateRunner(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, const BackendBase &backend, const std::string &suffix = ""); diff --git a/include/genn/genn/code_generator/generateSupportCode.h b/include/genn/genn/code_generator/generateSupportCode.h index 41ab13a27a..2a5fc3f8f4 100644 --- a/include/genn/genn/code_generator/generateSupportCode.h +++ b/include/genn/genn/code_generator/generateSupportCode.h @@ -7,7 +7,7 @@ #include "gennExport.h" // Forward declarations -namespace CodeGenerator +namespace GeNN::CodeGenerator { class ModelSpecMerged; } @@ -19,9 +19,9 @@ class path; //-------------------------------------------------------------------------- -// CodeGenerator +// GeNN::CodeGenerator //-------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { GENN_EXPORT void generateSupportCode(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, const std::string &suffix = ""); diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 5cf9f95a19..0dd3e89cea 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -19,18 +19,18 @@ #include "code_generator/codeGenUtils.h" // Forward declarations -namespace CodeGenerator +namespace GeNN::CodeGenerator { class CodeStream; } //------------------------------------------------------------------------ -// GroupMergedFieldType +// GeNN::CodeGenerator::GroupMergedFieldType //------------------------------------------------------------------------ //! Enumeration of field types /*! The only reason this is not a child of GroupMerged is to prevent the template nightmare that would otherwise ensue when declaring operators on it */ -namespace CodeGenerator +namespace GeNN::CodeGenerator { enum class GroupMergedFieldType : unsigned int { @@ -50,7 +50,7 @@ inline bool operator & (GroupMergedFieldType typeA, GroupMergedFieldType typeB) } //---------------------------------------------------------------------------- -// CodeGenerator::GroupMerged +// GeNN::CodeGenerator::GroupMerged //---------------------------------------------------------------------------- //! Very thin wrapper around a number of groups which have been merged together template<typename G> class GroupMerged { @@ -114,7 +114,7 @@ class GroupMerged // If field is a pointer and not marked as being a host field // (in which case the backend should leave its type alone!) 
const std::string &type = std::get<0>(f); - if(::Utils::isTypePointer(type) && !(std::get<3>(f) & GroupMergedFieldType::HOST)) { + if(Utils::isTypePointer(type) && !(std::get<3>(f) & GroupMergedFieldType::HOST)) { // If we are generating a host structure, allow the backend to override the type if(host) { os << backend.getMergedGroupFieldHostType(type); @@ -532,7 +532,7 @@ class GroupMerged }; //---------------------------------------------------------------------------- -// CodeGenerator::NeuronSpikeQueueUpdateGroupMerged +// GeNN::CodeGenerator::NeuronSpikeQueueUpdateGroupMerged //---------------------------------------------------------------------------- class GENN_EXPORT NeuronSpikeQueueUpdateGroupMerged : public GroupMerged { @@ -560,7 +560,7 @@ class GENN_EXPORT NeuronSpikeQueueUpdateGroupMerged : public GroupMerged { @@ -586,7 +586,7 @@ class GENN_EXPORT NeuronPrevSpikeTimeUpdateGroupMerged : public GroupMerged { @@ -932,7 +932,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged { @@ -1108,4 +1108,4 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged { @@ -237,7 +237,7 @@ class GENN_EXPORT SynapseConnectivityHostInitGroupMerged : public GroupMerged @@ -330,7 +330,7 @@ class CustomUpdateInitGroupMergedBase : public GroupMerged }; // ---------------------------------------------------------------------------- -// CodeGenerator::CustomUpdateInitGroupMerged +// GeNN::CodeGenerator::CustomUpdateInitGroupMerged //---------------------------------------------------------------------------- class GENN_EXPORT CustomUpdateInitGroupMerged : public CustomUpdateInitGroupMergedBase { @@ -361,7 +361,7 @@ class GENN_EXPORT CustomUpdateInitGroupMerged : public CustomUpdateInitGroupMerg // ---------------------------------------------------------------------------- -// CodeGenerator::CustomWUUpdateInitGroupMerged +// GeNN::CodeGenerator::CustomWUUpdateInitGroupMerged //---------------------------------------------------------------------------- class GENN_EXPORT CustomWUUpdateInitGroupMerged : public CustomUpdateInitGroupMergedBase @@ -419,7 +419,7 @@ class GENN_EXPORT CustomWUUpdateInitGroupMerged : public CustomUpdateInitGroupMe }; // ---------------------------------------------------------------------------- -// CodeGenerator::CustomWUUpdateSparseInitGroupMerged +// GeNN::CodeGenerator::CustomWUUpdateSparseInitGroupMerged //---------------------------------------------------------------------------- class GENN_EXPORT CustomWUUpdateSparseInitGroupMerged : public CustomUpdateInitGroupMergedBase @@ -450,7 +450,7 @@ class GENN_EXPORT CustomWUUpdateSparseInitGroupMerged : public CustomUpdateInitG }; //---------------------------------------------------------------------------- -// CustomConnectivityUpdatePreInitGroupMerged +// GeNN::CodeGenerator::CustomConnectivityUpdatePreInitGroupMerged //---------------------------------------------------------------------------- class GENN_EXPORT CustomConnectivityUpdatePreInitGroupMerged : public CustomUpdateInitGroupMergedBase @@ -481,7 +481,7 @@ class GENN_EXPORT CustomConnectivityUpdatePreInitGroupMerged : public CustomUpda }; //---------------------------------------------------------------------------- -// CustomConnectivityUpdatePostInitGroupMerged +// GeNN::CodeGenerator::CustomConnectivityUpdatePostInitGroupMerged //---------------------------------------------------------------------------- class GENN_EXPORT CustomConnectivityUpdatePostInitGroupMerged : public CustomUpdateInitGroupMergedBase @@ -512,7 +512,7 @@ class 
GENN_EXPORT CustomConnectivityUpdatePostInitGroupMerged : public CustomUpd }; //---------------------------------------------------------------------------- -// CustomConnectivityUpdateSparseInitGroupMerged +// GeNN::CodeGenerator::CustomConnectivityUpdateSparseInitGroupMerged //---------------------------------------------------------------------------- class GENN_EXPORT CustomConnectivityUpdateSparseInitGroupMerged : public CustomUpdateInitGroupMergedBase @@ -541,4 +541,4 @@ class GENN_EXPORT CustomConnectivityUpdateSparseInitGroupMerged : public CustomU //---------------------------------------------------------------------------- static const std::string name; }; -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index f3ad6f9444..3619863464 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -19,15 +19,15 @@ #include "code_generator/supportCodeMerged.h" // Forward declarations -namespace CodeGenerator +namespace GeNN::CodeGenerator { class BackendBase; } //-------------------------------------------------------------------------- -// CodeGenerator::ModelSpecMerged +// GeNN::CodeGenerator::ModelSpecMerged //-------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { class GENN_EXPORT ModelSpecMerged { @@ -470,4 +470,4 @@ class GENN_EXPORT ModelSpecMerged MergedEGPMap m_MergedEGPs; }; -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index e93a64cd34..5725414baf 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -4,9 +4,9 @@ #include "code_generator/groupMerged.h" //---------------------------------------------------------------------------- -// CodeGenerator::NeuronUpdateGroupMerged +// GeNN::CodeGenerator::NeuronUpdateGroupMerged //---------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase { @@ -97,4 +97,4 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase std::vector> m_SortedInSynWithPostCode; std::vector> m_SortedOutSynWithPreCode; }; -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/presynapticUpdateStrategySIMT.h b/include/genn/genn/code_generator/presynapticUpdateStrategySIMT.h index b06d637e13..a8f1c110d2 100644 --- a/include/genn/genn/code_generator/presynapticUpdateStrategySIMT.h +++ b/include/genn/genn/code_generator/presynapticUpdateStrategySIMT.h @@ -4,6 +4,8 @@ #include "code_generator/backendBase.h" // Forward declarations +namespace GeNN +{ class SynapseGroupInternal; namespace CodeGenerator @@ -11,13 +13,12 @@ namespace CodeGenerator class BackendSIMT; class ModelSpecMerged; } +} //-------------------------------------------------------------------------- -// CodeGenerator::PresynapticUpdateStrategySIMT::Base +// GeNN::CodeGenerator::PresynapticUpdateStrategySIMT::Base //-------------------------------------------------------------------------- -namespace CodeGenerator -{ -namespace PresynapticUpdateStrategySIMT +namespace 
GeNN::CodeGenerator::PresynapticUpdateStrategySIMT { class Base { public: @@ -49,7 +50,7 @@ class Base }; //-------------------------------------------------------------------------- -// CodeGenerator::PresynapticUpdateStrategySIMT::PreSpan +// GeNN::CodeGenerator::PresynapticUpdateStrategySIMT::PreSpan //-------------------------------------------------------------------------- //! Presynaptic parallelism class PreSpan : public Base @@ -82,7 +83,7 @@ class PreSpan : public Base }; //-------------------------------------------------------------------------- -// CodeGenerator::PresynapticUpdateStrategySIMT::PostSpan +// GeNN::CodeGenerator::PresynapticUpdateStrategySIMT::PostSpan //-------------------------------------------------------------------------- //! Postsynaptic parallelism class PostSpan : public Base @@ -125,7 +126,7 @@ class PostSpan : public Base //-------------------------------------------------------------------------- -// CodeGenerator::PresynapticUpdateStrategySIMT::PostSpanBitmask +// GeNN::CodeGenerator::PresynapticUpdateStrategySIMT::PostSpanBitmask //-------------------------------------------------------------------------- //! Postsynaptic parallelism class PostSpanBitmask : public Base { public: @@ -156,7 +157,7 @@ class PostSpanBitmask : public Base }; //-------------------------------------------------------------------------- -// CodeGenerator::PresynapticUpdateStrategySIMT::PreSpanProcedural +// GeNN::CodeGenerator::PresynapticUpdateStrategySIMT::PreSpanProcedural //-------------------------------------------------------------------------- //! Presynaptic parallelism with procedural connectivity class PreSpanProcedural : public Base @@ -189,7 +190,7 @@ class PreSpanProcedural : public Base }; //-------------------------------------------------------------------------- -// CodeGenerator::PresynapticUpdateStrategySIMT::PostSpanToeplitz +// GeNN::CodeGenerator::PresynapticUpdateStrategySIMT::PostSpanToeplitz //-------------------------------------------------------------------------- //! 
Postsynaptic parallelism for Toeplitz connectivity class PostSpanToeplitz : public Base @@ -220,5 +221,4 @@ class PostSpanToeplitz : public Base virtual void genPostamble(CodeStream &os, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, const Substitutions &popSubs, const BackendSIMT &backend) const override; }; -} // namespace PresynapticUpdateStrategySIMT -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator::PresynapticUpdateStrategySIMT diff --git a/include/genn/genn/code_generator/substitutions.h b/include/genn/genn/code_generator/substitutions.h index 29c70898fe..b119c743f8 100644 --- a/include/genn/genn/code_generator/substitutions.h +++ b/include/genn/genn/code_generator/substitutions.h @@ -14,9 +14,9 @@ #include "logging.h" //-------------------------------------------------------------------------- -// Substitutions +// GeNN::CodeGenerator::Substitutions //-------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { class GENN_EXPORT Substitutions { @@ -197,4 +197,4 @@ class GENN_EXPORT Substitutions std::map> m_FuncSubstitutions; const Substitutions *m_Parent; }; -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/supportCodeMerged.h b/include/genn/genn/code_generator/supportCodeMerged.h index 8878dcaace..3a61f3d5c9 100644 --- a/include/genn/genn/code_generator/supportCodeMerged.h +++ b/include/genn/genn/code_generator/supportCodeMerged.h @@ -8,9 +8,9 @@ #include "code_generator/codeStream.h" //-------------------------------------------------------------------------- -// CodeGenerator::SupportCodeMerged +// GeNN::CodeGenerator::SupportCodeMerged //-------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { class SupportCodeMerged { @@ -78,4 +78,4 @@ class SupportCodeMerged // Prefix const std::string m_NamespacePrefix; }; -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h index 1c18edf907..823abf7ec9 100644 --- a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h @@ -4,14 +4,10 @@ #include "code_generator/groupMerged.h" //---------------------------------------------------------------------------- -// CodeGenerator:: +// GeNN::CodeGenerator::PresynapticUpdateGroupMerged //---------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { - -//---------------------------------------------------------------------------- -// CodeGenerator::PresynapticUpdateGroupMerged -//---------------------------------------------------------------------------- class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase { public: @@ -47,7 +43,7 @@ class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase }; //---------------------------------------------------------------------------- -// CodeGenerator::PostsynapticUpdateGroupMerged +// GeNN::CodeGenerator::PostsynapticUpdateGroupMerged //---------------------------------------------------------------------------- class GENN_EXPORT PostsynapticUpdateGroupMerged : public SynapseGroupMergedBase { @@ -80,7 +76,7 @@ class GENN_EXPORT PostsynapticUpdateGroupMerged : public 
SynapseGroupMergedBase }; //---------------------------------------------------------------------------- -// CodeGenerator::SynapseDynamicsGroupMerged +// GeNN::CodeGenerator::SynapseDynamicsGroupMerged //---------------------------------------------------------------------------- class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase { @@ -113,7 +109,7 @@ class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase }; //---------------------------------------------------------------------------- -// CodeGenerator::SynapseDendriticDelayUpdateGroupMerged +// GeNN::CodeGenerator::SynapseDendriticDelayUpdateGroupMerged //---------------------------------------------------------------------------- class GENN_EXPORT SynapseDendriticDelayUpdateGroupMerged : public GroupMerged { @@ -137,4 +133,4 @@ class GENN_EXPORT SynapseDendriticDelayUpdateGroupMerged : public GroupMerged //-------------------------------------------------------------------------- -// CodeGenerator::TeeBuf +// GeNN::CodeGenerator::TeeBuf //-------------------------------------------------------------------------- // A stream buffer to support 'Teeing' streams - curtesy of http://wordaligned.org/articles/cpp-streambufs -namespace CodeGenerator +namespace GeNN::CodeGenerator { class TeeBuf: public std::streambuf { @@ -60,7 +60,7 @@ class TeeBuf: public std::streambuf }; //-------------------------------------------------------------------------- -// CodeGenerator::TeeStream +// GeNN::CodeGenerator::TeeStream //-------------------------------------------------------------------------- class TeeStream : public std::ostream { @@ -77,4 +77,4 @@ class TeeStream : public std::ostream //-------------------------------------------------------------------------- TeeBuf m_TeeBuf; }; -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/currentSource.h b/include/genn/genn/currentSource.h index 150be4707e..bd664546bd 100644 --- a/include/genn/genn/currentSource.h +++ b/include/genn/genn/currentSource.h @@ -13,11 +13,16 @@ #include "variableMode.h" // Forward declarations +namespace GeNN +{ class NeuronGroupInternal; +} //------------------------------------------------------------------------ -// CurrentSource +// GeNN::CurrentSource //------------------------------------------------------------------------ +namespace GeNN +{ class GENN_EXPORT CurrentSource { public: @@ -115,3 +120,4 @@ class GENN_EXPORT CurrentSource //! 
Location of extra global parameters std::vector m_ExtraGlobalParamLocation; }; +} // namespace GeNN diff --git a/include/genn/genn/currentSourceInternal.h b/include/genn/genn/currentSourceInternal.h index 34b9ac199a..ca6e7abe48 100644 --- a/include/genn/genn/currentSourceInternal.h +++ b/include/genn/genn/currentSourceInternal.h @@ -4,8 +4,10 @@ #include "currentSource.h" //------------------------------------------------------------------------ -// CurrentSourceInternal +// GeNN::CurrentSourceInternal //------------------------------------------------------------------------ +namespace GeNN +{ class CurrentSourceInternal : public CurrentSource { public: @@ -76,3 +78,4 @@ class CurrentSourceEGPAdapter //---------------------------------------------------------------------------- const CurrentSourceInternal &m_CS; }; +} // namespace GeNN diff --git a/include/genn/genn/currentSourceModels.h b/include/genn/genn/currentSourceModels.h index 826b945c18..1974b46b9f 100644 --- a/include/genn/genn/currentSourceModels.h +++ b/include/genn/genn/currentSourceModels.h @@ -18,9 +18,9 @@ #define SET_INJECTION_CODE(INJECTION_CODE) virtual std::string getInjectionCode() const override{ return INJECTION_CODE; } //---------------------------------------------------------------------------- -// CurrentSourceModels::Base +// GeNN::CurrentSourceModels::Base //---------------------------------------------------------------------------- -namespace CurrentSourceModels +namespace GeNN::CurrentSourceModels { //! Base class for all current source models class GENN_EXPORT Base : public Models::Base @@ -110,4 +110,4 @@ class PoissonExp : public Base {"Init", [](const std::unordered_map &pars, double dt){ return pars.at("weight") * (1.0 - std::exp(-dt / pars.at("tauSyn"))) * (pars.at("tauSyn") / dt); }}, {"ExpMinusLambda", [](const std::unordered_map &pars, double dt){ return std::exp(-(pars.at("rate") / 1000.0) * dt); }}}); }; -} // CurrentSourceModels +} // GeNN::CurrentSourceModels diff --git a/include/genn/genn/customConnectivityUpdate.h b/include/genn/genn/customConnectivityUpdate.h index 3cae60c1f8..5a1e000370 100644 --- a/include/genn/genn/customConnectivityUpdate.h +++ b/include/genn/genn/customConnectivityUpdate.h @@ -11,8 +11,10 @@ #include "variableMode.h" //------------------------------------------------------------------------ -// CustomConnectivityUpdate +// GeNN::CustomConnectivityUpdate //------------------------------------------------------------------------ +namespace GeNN +{ class GENN_EXPORT CustomConnectivityUpdate { public: @@ -165,3 +167,4 @@ class GENN_EXPORT CustomConnectivityUpdate const NeuronGroup *m_PreDelayNeuronGroup; const NeuronGroup *m_PostDelayNeuronGroup; }; +} // namespace GeNN \ No newline at end of file diff --git a/include/genn/genn/customConnectivityUpdateInternal.h b/include/genn/genn/customConnectivityUpdateInternal.h index a41ccb46f3..54edd43732 100644 --- a/include/genn/genn/customConnectivityUpdateInternal.h +++ b/include/genn/genn/customConnectivityUpdateInternal.h @@ -7,6 +7,8 @@ //------------------------------------------------------------------------ // CustomUpdateInternal //------------------------------------------------------------------------ +namespace GeNN +{ class CustomConnectivityUpdateInternal : public CustomConnectivityUpdate { public: @@ -137,3 +139,4 @@ class CustomConnectivityUpdateEGPAdapter //---------------------------------------------------------------------------- const CustomConnectivityUpdateInternal &m_CU; }; +} // namespace GeNN \ No newline 
at end of file diff --git a/include/genn/genn/customConnectivityUpdateModels.h b/include/genn/genn/customConnectivityUpdateModels.h index c0151ed81c..7c9c82a872 100644 --- a/include/genn/genn/customConnectivityUpdateModels.h +++ b/include/genn/genn/customConnectivityUpdateModels.h @@ -18,9 +18,9 @@ #define SET_HOST_UPDATE_CODE(HOST_UPDATE_CODE) virtual std::string getHostUpdateCode() const override{ return HOST_UPDATE_CODE; } //---------------------------------------------------------------------------- -// CustomConnectivityUpdateModels::Base +// GeNN::CustomConnectivityUpdateModels::Base //---------------------------------------------------------------------------- -namespace CustomConnectivityUpdateModels +namespace GeNN::CustomConnectivityUpdateModels { //! Base class for all current source models class GENN_EXPORT Base : public Models::Base @@ -80,4 +80,4 @@ class GENN_EXPORT Base : public Models::Base const std::unordered_map &postVarRefTargets, const std::string &description) const; }; -} // CustomConnectivityUpdateModels +} // GeNN::CustomConnectivityUpdateModels diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index c52b82f279..40e4464b09 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -12,8 +12,10 @@ #include "variableMode.h" //------------------------------------------------------------------------ -// CustomUpdateBase +// GeNN::CustomUpdateBase //------------------------------------------------------------------------ +namespace GeNN +{ class GENN_EXPORT CustomUpdateBase { public: @@ -308,3 +310,4 @@ class GENN_EXPORT CustomUpdateWU : public CustomUpdateBase const std::unordered_map m_VarReferences; SynapseGroupInternal *m_SynapseGroup; }; +} // namespace GeNN diff --git a/include/genn/genn/customUpdateInternal.h b/include/genn/genn/customUpdateInternal.h index 2f958d41e0..5093e7a99d 100644 --- a/include/genn/genn/customUpdateInternal.h +++ b/include/genn/genn/customUpdateInternal.h @@ -5,8 +5,10 @@ #include "synapseGroupInternal.h" //------------------------------------------------------------------------ -// CustomUpdateInternal +// GeNN::CustomUpdateInternal //------------------------------------------------------------------------ +namespace GeNN +{ class CustomUpdateInternal : public CustomUpdate { public: @@ -65,3 +67,4 @@ class CustomUpdateWUInternal : public CustomUpdateWU using CustomUpdateWU::isBatchReduction; using CustomUpdateWU::isTransposeOperation; }; +} // namespace GeNN diff --git a/include/genn/genn/customUpdateModels.h b/include/genn/genn/customUpdateModels.h index 95e636848a..bfc4fb6d1b 100644 --- a/include/genn/genn/customUpdateModels.h +++ b/include/genn/genn/customUpdateModels.h @@ -12,9 +12,9 @@ //---------------------------------------------------------------------------- -// CustomUpdateModels::Base +// GeNN::CustomUpdateModels::Base //---------------------------------------------------------------------------- -namespace CustomUpdateModels +namespace GeNN::CustomUpdateModels { //! 
Base class for all current source models class GENN_EXPORT Base : public Models::Base @@ -63,5 +63,5 @@ class Transpose : public Base SET_VAR_REFS({{"variable", "scalar", VarAccessMode::READ_WRITE}}); }; -} // CustomUpdateModels +} // GeNN::CustomUpdateModels diff --git a/include/genn/genn/gennUtils.h b/include/genn/genn/gennUtils.h index 42d8e1208e..fd610bffe1 100644 --- a/include/genn/genn/gennUtils.h +++ b/include/genn/genn/gennUtils.h @@ -20,15 +20,15 @@ #include "gennExport.h" // Forward declarations -namespace Models +namespace GeNN::Models { class VarInit; } //-------------------------------------------------------------------------- -// Utils +// GeNN::Utils //-------------------------------------------------------------------------- -namespace Utils +namespace GeNN::Utils { //-------------------------------------------------------------------------- //! \brief Does the code string contain any functions requiring random number generator @@ -209,4 +209,4 @@ struct SHA1Hash return hash; }; }; -} // namespace Utils +} // namespace GeNN::Utils diff --git a/include/genn/genn/initSparseConnectivitySnippet.h b/include/genn/genn/initSparseConnectivitySnippet.h index e5e9fc4d01..a84e9fce07 100644 --- a/include/genn/genn/initSparseConnectivitySnippet.h +++ b/include/genn/genn/initSparseConnectivitySnippet.h @@ -31,10 +31,10 @@ #define SET_MAX_COL_LENGTH(MAX_COL_LENGTH) virtual CalcMaxLengthFunc getCalcMaxColLengthFunc() const override{ return [](unsigned int, unsigned int, const std::unordered_map<std::string, double> &){ return MAX_COL_LENGTH; }; } //---------------------------------------------------------------------------- -// InitSparseConnectivitySnippet::Base +// GeNN::InitSparseConnectivitySnippet::Base //---------------------------------------------------------------------------- //! Base class for all sparse connectivity initialisation snippets -namespace InitSparseConnectivitySnippet +namespace GeNN::InitSparseConnectivitySnippet { class GENN_EXPORT Base : public Snippet::Base { @@ -440,4 +440,4 @@ class Conv2D : public Base (unsigned int)pars.at("conv_ic"), (unsigned int)pars.at("conv_oc")}; }); }; -} // namespace InitSparseConnectivitySnippet +} // namespace GeNN::InitSparseConnectivitySnippet diff --git a/include/genn/genn/initToeplitzConnectivitySnippet.h b/include/genn/genn/initToeplitzConnectivitySnippet.h index 06b9971e66..e3d6638444 100644 --- a/include/genn/genn/initToeplitzConnectivitySnippet.h +++ b/include/genn/genn/initToeplitzConnectivitySnippet.h @@ -25,10 +25,10 @@ #define SET_MAX_ROW_LENGTH(MAX_ROW_LENGTH) virtual CalcMaxLengthFunc getCalcMaxRowLengthFunc() const override{ return [](unsigned int, unsigned int, const std::unordered_map<std::string, double> &){ return MAX_ROW_LENGTH; }; } //---------------------------------------------------------------------------- -// InitToeplitzConnectivitySnippet::Base +// GeNN::InitToeplitzConnectivitySnippet::Base //---------------------------------------------------------------------------- //! Base class for all toeplitz connectivity initialisation snippets -namespace InitToeplitzConnectivitySnippet +namespace GeNN::InitToeplitzConnectivitySnippet { class GENN_EXPORT Base : public Snippet::Base { @@ -61,7 +61,7 @@ class GENN_EXPORT Base : public Snippet::Base using Init = Snippet::Init<Base>; //---------------------------------------------------------------------------- -// InitToeplitzConnectivitySnippet::Uninitialised +// GeNN::InitToeplitzConnectivitySnippet::Uninitialised //---------------------------------------------------------------------------- //! 
Used to mark connectivity as uninitialised - no initialisation code will be run class Uninitialised : public Base @@ -71,7 +71,7 @@ class Uninitialised : public Base }; //---------------------------------------------------------------------------- -// InitToeplitzConnectivitySnippet::Conv2D +// GeNN::InitToeplitzConnectivitySnippet::Conv2D //---------------------------------------------------------------------------- //! Initialises convolutional connectivity //! Row build state variables are used to convert presynaptic neuron index to rows, columns and channels and, @@ -123,7 +123,7 @@ class Conv2D : public Base }; //---------------------------------------------------------------------------- -// InitToeplitzConnectivitySnippet::AvgPoolConv2D +// GeNN::InitToeplitzConnectivitySnippet::AvgPoolConv2D //---------------------------------------------------------------------------- //! Initialises convolutional connectivity preceded by averaging pooling //! Row build state variables are used to convert presynaptic neuron index to rows, columns and channels and, @@ -183,4 +183,4 @@ class AvgPoolConv2D : public Base (unsigned int)pars.at("pool_ic"), (unsigned int)pars.at("conv_oc")}; }); }; -} // namespace InitToeplitzConnectivitySnippet +} // namespace GeNN::InitToeplitzConnectivitySnippet diff --git a/include/genn/genn/initVarSnippet.h b/include/genn/genn/initVarSnippet.h index 20faa4e331..c56e2b6749 100644 --- a/include/genn/genn/initVarSnippet.h +++ b/include/genn/genn/initVarSnippet.h @@ -9,10 +9,10 @@ #define SET_CODE(CODE) virtual std::string getCode() const override{ return CODE; } //---------------------------------------------------------------------------- -// InitVarSnippet::Base +// GeNN::InitVarSnippet::Base //---------------------------------------------------------------------------- //! Base class for all value initialisation snippets -namespace InitVarSnippet +namespace GeNN::InitVarSnippet { class GENN_EXPORT Base : public Snippet::Base { @@ -36,7 +36,7 @@ class GENN_EXPORT Base : public Snippet::Base }; //---------------------------------------------------------------------------- -// InitVarSnippet::Uninitialised +// GeNN::InitVarSnippet::Uninitialised //---------------------------------------------------------------------------- //! Used to mark variables as uninitialised - no initialisation code will be run class Uninitialised : public Base @@ -46,7 +46,7 @@ class Uninitialised : public Base }; //---------------------------------------------------------------------------- -// InitVarSnippet::Constant +// GeNN::InitVarSnippet::Constant //---------------------------------------------------------------------------- //! Initialises variable to a constant value /*! This snippet takes 1 parameter: @@ -66,7 +66,7 @@ class Constant : public Base }; //---------------------------------------------------------------------------- -// InitVarSnippet::Kernel +// GeNN::InitVarSnippet::Kernel //---------------------------------------------------------------------------- //! Used to initialise synapse variables from a kernel class Kernel : public Base @@ -79,7 +79,7 @@ class Kernel : public Base }; //---------------------------------------------------------------------------- -// InitVarSnippet::Uniform +// GeNN::InitVarSnippet::Uniform //---------------------------------------------------------------------------- //! Initialises variable by sampling from the uniform distribution /*! 
This snippet takes 2 parameters: @@ -99,7 +99,7 @@ class Uniform : public Base }; //---------------------------------------------------------------------------- -// InitVarSnippet::Normal +// GeNN::InitVarSnippet::Normal //---------------------------------------------------------------------------- //! Initialises variable by sampling from the normal distribution /*! This snippet takes 2 parameters: @@ -117,14 +117,14 @@ class Normal : public Base }; //---------------------------------------------------------------------------- -// InitVarSnippet::NormalClipped +// GeNN::InitVarSnippet::NormalClipped //---------------------------------------------------------------------------- //! Initialises variable by sampling from the normal distribution, //! Resamples value if out of range specified my min and max /*! This snippet takes 2 parameters: * - \c mean - The mean - \c sd - The standard deviation - \c min - The minimum value - \c max - The maximum value*/ class NormalClipped : public Base @@ -144,7 +144,7 @@ class NormalClipped : public Base }; //---------------------------------------------------------------------------- -// InitVarSnippet::NormalClippedDelay +// GeNN::InitVarSnippet::NormalClippedDelay //---------------------------------------------------------------------------- //! Initialises variable by sampling from the normal distribution, //! Resamples value of out of range specified my min and max. @@ -178,7 +178,7 @@ class NormalClippedDelay : public Base }; //---------------------------------------------------------------------------- -// InitVarSnippet::Exponential +// GeNN::InitVarSnippet::Exponential //---------------------------------------------------------------------------- //! Initialises variable by sampling from the exponential distribution /*! This snippet takes 1 parameter: @@ -195,7 +195,7 @@ class Exponential : public Base }; //---------------------------------------------------------------------------- -// InitVarSnippet::Gamma +// GeNN::InitVarSnippet::Gamma //---------------------------------------------------------------------------- //! Initialises variable by sampling from the gamma distribution /*! This snippet takes 2 parameters: @@ -213,7 +213,7 @@ class Gamma : public Base }; //---------------------------------------------------------------------------- -// InitVarSnippet::Binomial +// GeNN::InitVarSnippet::Binomial //---------------------------------------------------------------------------- //! Initialises variable by sampling from the binomial distribution /*! This snippet takes 2 parameters: @@ -229,4 +229,4 @@ class Binomial : public Base SET_PARAM_NAMES({"n", "p"}); }; -} // namespace InitVarSnippet +} // namespace GeNN::InitVarSnippet diff --git a/include/genn/genn/logging.h b/include/genn/genn/logging.h index 646eefe242..9e70b24221 100644 --- a/include/genn/genn/logging.h +++ b/include/genn/genn/logging.h @@ -42,9 +42,9 @@ class IAppender; //---------------------------------------------------------------------------- -// Logging +// GeNN::Logging //---------------------------------------------------------------------------- -namespace Logging +namespace GeNN::Logging { enum Channel { diff --git a/include/genn/genn/modelSpec.h b/include/genn/genn/modelSpec.h index 721a99474a..fda88f416b 100644 --- a/include/genn/genn/modelSpec.h +++ b/include/genn/genn/modelSpec.h @@ -40,6 +40,8 @@ Part of the code generation and generated code sections. 
#define NO_DELAY 0 //!< Macro used to indicate no synapse delay for the group (only one queue slot will be generated) +namespace GeNN +{ using ParamValues = std::unordered_map; using VarValues = std::unordered_map; using VarReferences = std::unordered_map; @@ -203,7 +205,7 @@ inline Models::WUVarReference createWUVarRef(CustomConnectivityUpdate *cu, const } //---------------------------------------------------------------------------- -// ModelSpec +// GeNN::ModelSpec //---------------------------------------------------------------------------- //! Object used for specifying a neuronal network model class GENN_EXPORT ModelSpec @@ -766,6 +768,4 @@ class GENN_EXPORT ModelSpec //! Batch size of this model - efficiently duplicates model unsigned int m_BatchSize; }; - -// Typedefine NNmodel for backward compatibility -typedef ModelSpec NNmodel; +} // namespace GeNN diff --git a/include/genn/genn/modelSpecInternal.h b/include/genn/genn/modelSpecInternal.h index 93938719ee..7fc79c3327 100644 --- a/include/genn/genn/modelSpecInternal.h +++ b/include/genn/genn/modelSpecInternal.h @@ -5,8 +5,10 @@ #include "modelSpec.h" //------------------------------------------------------------------------ -// ModelSpecInternal +// GeNN::ModelSpecInternal //------------------------------------------------------------------------ +namespace GeNN +{ class ModelSpecInternal : public ModelSpec { public: @@ -28,3 +30,4 @@ class ModelSpecInternal : public ModelSpec using ModelSpec::isRecordingInUse; using ModelSpec::getHashDigest; }; +} // namespace GeNN diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 4254aafab2..95d13fb8e3 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -13,6 +13,8 @@ #include "varAccess.h" // Forward declarations +namespace GeNN +{ class CustomConnectivityUpdate; class CustomUpdate; class CustomUpdateWU; @@ -26,6 +28,7 @@ namespace CodeGenerator { class BackendBase; } +} //---------------------------------------------------------------------------- // Macros @@ -34,10 +37,10 @@ class BackendBase; //---------------------------------------------------------------------------- -// Models::Base +// GeNN::Models::Base //---------------------------------------------------------------------------- //! Base class for all models - in addition to the parameters snippets have, models can have state variables -namespace Models +namespace GeNN::Models { class GENN_EXPORT Base : public Snippet::Base { @@ -119,7 +122,7 @@ class GENN_EXPORT Base : public Snippet::Base //---------------------------------------------------------------------------- -// Models::VarInit +// GeNN::Models::VarInit //---------------------------------------------------------------------------- //! Class used to bind together everything required to initialise a variable: //! 1. 
A pointer to a variable initialisation snippet @@ -139,7 +142,7 @@ class VarInit : public Snippet::Init }; //---------------------------------------------------------------------------- -// Models::VarReferenceBase +// GeNN::Models::VarReferenceBase //---------------------------------------------------------------------------- class GENN_EXPORT VarReferenceBase { @@ -187,7 +190,7 @@ class GENN_EXPORT VarReferenceBase }; //---------------------------------------------------------------------------- -// Models::VarReference +// GeNN::Models::VarReference //---------------------------------------------------------------------------- class GENN_EXPORT VarReference : public VarReferenceBase { @@ -231,7 +234,7 @@ class GENN_EXPORT VarReference : public VarReferenceBase }; //---------------------------------------------------------------------------- -// Models::WUVarReference +// GeNN::Models::WUVarReference //---------------------------------------------------------------------------- class GENN_EXPORT WUVarReference : public VarReferenceBase { @@ -300,7 +303,6 @@ GENN_EXPORT void updateHash(const Base::VarRef &v, boost::uuids::detail::sha1 &h GENN_EXPORT void updateHash(const VarReference &v, boost::uuids::detail::sha1 &hash); GENN_EXPORT void updateHash(const WUVarReference &v, boost::uuids::detail::sha1 &hash); - //! Helper function to check if variable reference types match those specified in model template void checkVarReferences(const std::unordered_map &varRefs, const Base::VarRefVec &modelVarRefs) @@ -323,4 +325,4 @@ void checkVarReferences(const std::unordered_map &varRefs, const } } } -} // Models +} // GeNN::Models diff --git a/include/genn/genn/neuronGroup.h b/include/genn/genn/neuronGroup.h index a38be41b5b..30a17fd695 100644 --- a/include/genn/genn/neuronGroup.h +++ b/include/genn/genn/neuronGroup.h @@ -13,12 +13,17 @@ #include "variableMode.h" // Forward declarations +namespace GeNN +{ class CurrentSourceInternal; class SynapseGroupInternal; +} //------------------------------------------------------------------------ -// NeuronGroup +// GeNN::NeuronGroup //------------------------------------------------------------------------ +namespace GeNN +{ class GENN_EXPORT NeuronGroup { public: @@ -320,3 +325,4 @@ class GENN_EXPORT NeuronGroup //! Is spike event recording enabled? 
bool m_SpikeEventRecordingEnabled; }; +} // namespace GeNN diff --git a/include/genn/genn/neuronGroupInternal.h b/include/genn/genn/neuronGroupInternal.h index 3d0746af1f..8e1b23a30f 100644 --- a/include/genn/genn/neuronGroupInternal.h +++ b/include/genn/genn/neuronGroupInternal.h @@ -4,8 +4,10 @@ #include "neuronGroup.h" //------------------------------------------------------------------------ -// NeuronGroupInternal +// GeNN::NeuronGroupInternal //------------------------------------------------------------------------ +namespace GeNN +{ class NeuronGroupInternal : public NeuronGroup { public: @@ -46,7 +48,6 @@ class NeuronGroupInternal : public NeuronGroup using NeuronGroup::getVarLocationHashDigest; }; - //---------------------------------------------------------------------------- // NeuronVarAdapter //---------------------------------------------------------------------------- @@ -94,3 +95,4 @@ class NeuronEGPAdapter //---------------------------------------------------------------------------- const NeuronGroupInternal &m_NG; }; +} // namespace GeNN \ No newline at end of file diff --git a/include/genn/genn/neuronModels.h b/include/genn/genn/neuronModels.h index e4ca532d84..9c44f4a810 100644 --- a/include/genn/genn/neuronModels.h +++ b/include/genn/genn/neuronModels.h @@ -24,9 +24,9 @@ #define SET_NEEDS_AUTO_REFRACTORY(AUTO_REFRACTORY_REQUIRED) virtual bool isAutoRefractoryRequired() const override{ return AUTO_REFRACTORY_REQUIRED; } //---------------------------------------------------------------------------- -// NeuronModels::Base +// GeNN::NeuronModels::Base //---------------------------------------------------------------------------- -namespace NeuronModels +namespace GeNN::NeuronModels { //! Base class for all neuron models class GENN_EXPORT Base : public Models::Base @@ -58,7 +58,7 @@ class GENN_EXPORT Base : public Models::Base virtual Models::Base::ParamValVec getAdditionalInputVars() const{ return {}; } //! Does this model require auto-refractory logic? - virtual bool isAutoRefractoryRequired() const{ return true; } + virtual bool isAutoRefractoryRequired() const{ return false; } //---------------------------------------------------------------------------- // Public API @@ -73,7 +73,7 @@ class GENN_EXPORT Base : public Models::Base }; //---------------------------------------------------------------------------- -// NeuronModels::RulkovMap +// GeNN::NeuronModels::RulkovMap //---------------------------------------------------------------------------- //! Rulkov Map neuron /*! The RulkovMap type is a map based neuron model based on \cite Rulkov2002 but in @@ -135,7 +135,7 @@ class RulkovMap : public Base }; //---------------------------------------------------------------------------- -// NeuronModels::Izhikevich +// GeNN::NeuronModels::Izhikevich //---------------------------------------------------------------------------- //! Izhikevich neuron with fixed parameters \cite izhikevich2003simple. /*! It is usually described as @@ -181,7 +181,7 @@ class Izhikevich : public Base }; //---------------------------------------------------------------------------- -// NeuronModels::IzhikevichVariable +// GeNN::NeuronModels::IzhikevichVariable //---------------------------------------------------------------------------- //! Izhikevich neuron with variable parameters \cite izhikevich2003simple. /*! 
This is the same model as NeuronModels::Izhikevich but parameters are defined as
@@ -209,7 +209,7 @@ class IzhikevichVariable : public Izhikevich
};
//----------------------------------------------------------------------------
-// NeuronModels::LIF
+// GeNN::NeuronModels::LIF
//----------------------------------------------------------------------------
class LIF : public Base
{
@@ -251,7 +251,7 @@ class LIF : public Base
};
//----------------------------------------------------------------------------
-// NeuronModels::SpikeSource
+// GeNN::NeuronModels::SpikeSource
//----------------------------------------------------------------------------
//! Empty neuron which allows setting spikes from external sources
/*! This model does not contain any update code and can be used to implement
@@ -266,7 +266,7 @@ class SpikeSource : public Base
};
//----------------------------------------------------------------------------
-// NeuronModels::SpikeSourceArray
+// GeNN::NeuronModels::SpikeSourceArray
//----------------------------------------------------------------------------
//! Spike source array
/*! A neuron which reads spike times from a global spikes array.
@@ -295,7 +295,7 @@ class SpikeSourceArray : public Base
};
//----------------------------------------------------------------------------
-// NeuronModels::Poisson
+// GeNN::NeuronModels::Poisson
//----------------------------------------------------------------------------
//! Poisson neurons
/*! Poisson neurons have constant membrane potential (\c Vrest) unless they are
@@ -354,7 +354,7 @@ class Poisson : public Base
};
//----------------------------------------------------------------------------
-// NeuronModels::PoissonNew
+// GeNN::NeuronModels::PoissonNew
//----------------------------------------------------------------------------
//! Poisson neurons
/*! This neuron model emits spikes according to the Poisson distribution with a mean firing
@@ -391,7 +391,7 @@ class PoissonNew : public Base
};
//----------------------------------------------------------------------------
-// NeuronModels::TraubMiles
+// GeNN::NeuronModels::TraubMiles
//----------------------------------------------------------------------------
//! Hodgkin-Huxley neurons with Traub & Miles algorithm.
/*! This conductance based model has been taken from \cite Traub1991 and can be described by the equations:
@@ -487,7 +487,7 @@ class TraubMiles : public Base
};
//----------------------------------------------------------------------------
-// NeuronModels::TraubMilesFast
+// GeNN::NeuronModels::TraubMilesFast
//----------------------------------------------------------------------------
//! Hodgkin-Huxley neurons with Traub & Miles algorithm: Original fast implementation, using 25 inner iterations.
/*! There are singularities in this model, which can be easily hit in float precision
@@ -520,7 +520,7 @@ class TraubMilesFast : public TraubMiles
};
//----------------------------------------------------------------------------
-// NeuronModels::TraubMilesAlt
+// GeNN::NeuronModels::TraubMilesAlt
//----------------------------------------------------------------------------
//! Hodgkin-Huxley neurons with Traub & Miles algorithm
/*! Using a workaround to avoid singularity: adding the minimum numerical value of the floating point precision used.
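A knock-on effect of the renames above is that user-defined models must derive from the namespaced base classes. A hypothetical custom neuron model sketched against the new layout (the class, its dynamics and its parameters are invented for illustration; the SET_* macros are the ones defined in these headers, and any registration boilerplate such as DECLARE_MODEL is omitted):

// Hypothetical user-defined neuron model deriving from the namespaced base
#include "neuronModels.h"

class ExampleLeak : public GeNN::NeuronModels::Base
{
public:
    // Exponential leak towards Vrest, driven by synaptic input Isyn
    SET_SIM_CODE("$(V) += (($(Vrest) - $(V)) / $(TauM) + $(Isyn)) * DT;\n");
    // Fire when V crosses Vthresh, then reset back to Vrest
    SET_THRESHOLD_CONDITION_CODE("$(V) >= $(Vthresh)");
    SET_RESET_CODE("$(V) = $(Vrest);\n");
    SET_PARAM_NAMES({"Vrest", "Vthresh", "TauM"});
    SET_VARS({{"V", "scalar"}});
};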
@@ -556,7 +556,7 @@ class TraubMilesAlt : public TraubMiles }; //---------------------------------------------------------------------------- -// NeuronModels::TraubMilesNStep +// GeNN::NeuronModels::TraubMilesNStep //---------------------------------------------------------------------------- //! Hodgkin-Huxley neurons with Traub & Miles algorithm. /*! Same as standard TraubMiles model but number of inner loops can be set using a parameter @@ -606,4 +606,4 @@ class TraubMilesNStep : public TraubMiles SET_PARAM_NAMES({"gNa", "ENa", "gK", "EK", "gl", "El", "C", "ntimes"}); }; -} // NeuronModels +} // GeNN::NeuronModels diff --git a/include/genn/genn/postsynapticModels.h b/include/genn/genn/postsynapticModels.h index 87aa1e294b..715cf5ec04 100644 --- a/include/genn/genn/postsynapticModels.h +++ b/include/genn/genn/postsynapticModels.h @@ -15,9 +15,9 @@ #define SET_SUPPORT_CODE(SUPPORT_CODE) virtual std::string getSupportCode() const override{ return SUPPORT_CODE; } //---------------------------------------------------------------------------- -// PostsynapticModels::Base +// GeNN::PostsynapticModels::Base //---------------------------------------------------------------------------- -namespace PostsynapticModels +namespace GeNN::PostsynapticModels { //! Base class for all postsynaptic models class GENN_EXPORT Base : public Models::Base @@ -43,7 +43,7 @@ class GENN_EXPORT Base : public Models::Base }; //---------------------------------------------------------------------------- -// PostsynapticModels::ExpCurr +// GeNN::PostsynapticModels::ExpCurr //---------------------------------------------------------------------------- //! Exponential decay with synaptic input treated as a current value. /*! This model has no variables and a single parameter: @@ -65,7 +65,7 @@ class ExpCurr : public Base }; //---------------------------------------------------------------------------- -// PostsynapticModels::ExpCond +// GeNN::PostsynapticModels::ExpCond //---------------------------------------------------------------------------- //! Exponential decay with synaptic input treated as a conductance value. /*! This model has no variables and two parameters: @@ -88,7 +88,7 @@ class ExpCond : public Base }; //---------------------------------------------------------------------------- -// PostsynapticModels::DeltaCurr +// GeNN::PostsynapticModels::DeltaCurr //---------------------------------------------------------------------------- //! Simple delta current synapse. /*! Synaptic input provides a direct inject of instantaneous current*/ @@ -99,4 +99,4 @@ class DeltaCurr : public Base SET_CURRENT_CONVERTER_CODE("$(inSyn); $(inSyn) = 0"); }; -} +} // namespace GeNN::PostsynapticModels diff --git a/include/genn/genn/snippet.h b/include/genn/genn/snippet.h index de559de139..598fe18ebb 100644 --- a/include/genn/genn/snippet.h +++ b/include/genn/genn/snippet.h @@ -38,10 +38,10 @@ public: \ #define SET_EXTRA_GLOBAL_PARAMS(...) virtual EGPVec getExtraGlobalParams() const override{ return __VA_ARGS__; } //---------------------------------------------------------------------------- -// Snippet::Base +// GeNN::Snippet::Base //---------------------------------------------------------------------------- //! 
Base class for all code snippets
-namespace Snippet
+namespace GeNN::Snippet
{
class GENN_EXPORT Base
{
@@ -224,4 +224,4 @@ inline void updateHash(const Base::DerivedParam &d, boost::uuids::detail::sha1 &
{
Utils::updateHash(d.name, hash);
}
-} // namespace Snippet
+} // namespace GeNN::Snippet
diff --git a/include/genn/genn/synapseGroup.h b/include/genn/genn/synapseGroup.h
index 581cc7a340..1a358305ee 100644
--- a/include/genn/genn/synapseGroup.h
+++ b/include/genn/genn/synapseGroup.h
@@ -15,14 +15,19 @@
#include "variableMode.h"
// Forward declarations
+namespace GeNN
+{
class CustomConnectivityUpdateInternal;
class CustomUpdateWUInternal;
class NeuronGroupInternal;
class SynapseGroupInternal;
+}
//------------------------------------------------------------------------
-// SynapseGroup
+// GeNN::SynapseGroup
//------------------------------------------------------------------------
+namespace GeNN
+{
class GENN_EXPORT SynapseGroup
{
public:
@@ -525,3 +530,4 @@ class GENN_EXPORT SynapseGroup
/*! Because, if connectivity is sparse, all groups share connectivity, this is required if connectivity changes. */
std::vector m_CustomUpdateReferences;
};
+} // namespace GeNN
diff --git a/include/genn/genn/synapseGroupInternal.h b/include/genn/genn/synapseGroupInternal.h
index 88ea8f2142..2307e889dd 100644
--- a/include/genn/genn/synapseGroupInternal.h
+++ b/include/genn/genn/synapseGroupInternal.h
@@ -5,8 +5,10 @@
#include "synapseGroup.h"
//------------------------------------------------------------------------
-// SynapseGroupInternal
+// GeNN::SynapseGroupInternal
//------------------------------------------------------------------------
+namespace GeNN
+{
class SynapseGroupInternal : public SynapseGroup
{
public:
@@ -204,3 +206,4 @@ class SynapseWUEGPAdapter
//----------------------------------------------------------------------------
const SynapseGroupInternal &m_SG;
};
+} // namespace GeNN
diff --git a/include/genn/genn/synapseMatrixType.h b/include/genn/genn/synapseMatrixType.h
index b0437872cf..720000d70e 100644
--- a/include/genn/genn/synapseMatrixType.h
+++ b/include/genn/genn/synapseMatrixType.h
@@ -3,6 +3,8 @@
//----------------------------------------------------------------------------
// Enumerations
//----------------------------------------------------------------------------
+namespace GeNN
+{
//! Flags defining different types of synaptic matrix connectivity
enum class SynapseMatrixConnectivity : unsigned int
{
@@ -68,3 +70,4 @@ inline SynapseMatrixWeight getSynapseMatrixWeight(SynapseMatrixType type)
{
return static_cast<SynapseMatrixWeight>(static_cast<unsigned int>(type) & ~0x1F);
}
+} // namespace GeNN
diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h
index 87297e19e7..5fabb846e3 100644
--- a/include/genn/genn/varAccess.h
+++ b/include/genn/genn/varAccess.h
@@ -1,9 +1,10 @@
#pragma once
-
//----------------------------------------------------------------------------
// Enumerations
//----------------------------------------------------------------------------
+namespace GeNN
+{
//! Flags defining attributes of var access models
//!
**NOTE** Read-only and read-write are separate flags rather than read and write so you can test mode & VarAccessMode::READ_ONLY
enum class VarAccessModeAttribute : unsigned int
{
@@ -80,3 +81,4 @@ inline VarAccessDuplication getVarAccessDuplication(VarAccess type)
{
return static_cast<VarAccessDuplication>(static_cast<unsigned int>(type) & ~0x1F);
}
+} // namespace GeNN
diff --git a/include/genn/genn/variableMode.h b/include/genn/genn/variableMode.h
index 64ea266179..f37edc5d7c 100644
--- a/include/genn/genn/variableMode.h
+++ b/include/genn/genn/variableMode.h
@@ -6,6 +6,8 @@
//----------------------------------------------------------------------------
// Enumerations
//----------------------------------------------------------------------------
+namespace GeNN
+{
//!< Flags defining which memory space variables should be allocated in
enum class VarLocation : uint8_t
{
@@ -23,4 +25,5 @@ enum class VarLocation : uint8_t
inline bool operator & (VarLocation locA, VarLocation locB)
{
return (static_cast<uint8_t>(locA) & static_cast<uint8_t>(locB)) != 0;
-}
\ No newline at end of file
+}
+} // namespace GeNN
diff --git a/include/genn/genn/weightUpdateModels.h b/include/genn/genn/weightUpdateModels.h
index 80cfd94214..41b8e53bc6 100644
--- a/include/genn/genn/weightUpdateModels.h
+++ b/include/genn/genn/weightUpdateModels.h
@@ -32,9 +32,9 @@
#define SET_NEEDS_PREV_PRE_SPIKE_EVENT_TIME(PREV_PRE_SPIKE_EVENT_TIME_REQUIRED) virtual bool isPrevPreSpikeEventTimeRequired() const override{ return PREV_PRE_SPIKE_EVENT_TIME_REQUIRED; }
//----------------------------------------------------------------------------
-// WeightUpdateModels::Base
+// GeNN::WeightUpdateModels::Base
//----------------------------------------------------------------------------
-namespace WeightUpdateModels
+namespace GeNN::WeightUpdateModels
{
//! Base class for all weight update models
class GENN_EXPORT Base : public Models::Base
@@ -155,7 +155,7 @@ class GENN_EXPORT Base : public Models::Base
};
//----------------------------------------------------------------------------
-// WeightUpdateModels::StaticPulse
+// GeNN::WeightUpdateModels::StaticPulse
//----------------------------------------------------------------------------
//! Pulse-coupled, static synapse.
/*! No learning rule is applied to the synapse and for each pre-synaptic spike,
@@ -180,7 +180,7 @@ class StaticPulse : public Base
};
//----------------------------------------------------------------------------
-// WeightUpdateModels::StaticPulseDendriticDelay
+// GeNN::WeightUpdateModels::StaticPulseDendriticDelay
//----------------------------------------------------------------------------
//! Pulse-coupled, static synapse with heterogeneous dendritic delays
/*! No learning rule is applied to the synapse and for each pre-synaptic spike,
@@ -206,7 +206,7 @@ class StaticPulseDendriticDelay : public Base
};
//----------------------------------------------------------------------------
-// WeightUpdateModels::StaticGraded
+// GeNN::WeightUpdateModels::StaticGraded
//----------------------------------------------------------------------------
//! Graded-potential, static synapse
/*! In a graded synapse, the conductance is updated gradually with the rule:
@@ -245,7 +245,7 @@ class StaticGraded : public Base
};
//----------------------------------------------------------------------------
-// PiecewiseSTDP
+// GeNN::WeightUpdateModels::PiecewiseSTDP
//----------------------------------------------------------------------------
//! This is a simple STDP rule including a time delay for the finite transmission speed of the synapse.
/*!
The STDP window is defined as a piecewise function:
@@ -348,4 +348,4 @@ class PiecewiseSTDP : public Base
SET_NEEDS_PRE_SPIKE_TIME(true);
SET_NEEDS_POST_SPIKE_TIME(true);
};
-} // WeightUpdateModels
+} // namespace GeNN::WeightUpdateModels
diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc
index dc894e1fa0..e205802a6e 100644
--- a/src/genn/backends/cuda/backend.cc
+++ b/src/genn/backends/cuda/backend.cc
@@ -17,7 +17,8 @@
// CUDA backend includes
#include "utils.h"
-using namespace CodeGenerator;
+using namespace GeNN;
+using namespace GeNN::CodeGenerator;
//--------------------------------------------------------------------------
// Anonymous namespace
//--------------------------------------------------------------------------
@@ -277,13 +278,10 @@ void genNCCLReduction(CodeStream &os, const G &cg, const std::string &precision)
} // Anonymous namespace
//--------------------------------------------------------------------------
-// CodeGenerator::CUDA::Backend
+// GeNN::CodeGenerator::CUDA::Backend
//--------------------------------------------------------------------------
-namespace CodeGenerator
+namespace GeNN::CodeGenerator::CUDA
{
-namespace CUDA
-{
-//--------------------------------------------------------------------------
Backend::Backend(const KernelBlockSize &kernelBlockSizes, const Preferences &preferences,
const std::string &scalarType, int device)
: BackendSIMT(kernelBlockSizes, preferences, scalarType), m_ChosenDeviceID(device)
@@ -2264,5 +2262,4 @@ void Backend::genKernelDimensions(CodeStream &os, Kernel kernel, size_t numThrea
os << "const dim3 grid(" << gridSize << ", " << batchSize << ");" << std::endl;
}
}
-} // namespace CUDA
-} // namespace CodeGenerator
+} // namespace GeNN::CodeGenerator::CUDA
diff --git a/src/genn/backends/cuda/optimiser.cc b/src/genn/backends/cuda/optimiser.cc
index a26c3a1938..0917d5743c 100644
--- a/src/genn/backends/cuda/optimiser.cc
+++ b/src/genn/backends/cuda/optimiser.cc
@@ -37,7 +37,8 @@
// CUDA backend includes
#include "utils.h"
-using namespace CodeGenerator;
+using namespace GeNN;
+using namespace GeNN::CodeGenerator;
using namespace CUDA;
//--------------------------------------------------------------------------
@@ -731,13 +732,12 @@ int chooseDeviceWithMostGlobalMemory()
LOGI_BACKEND << "Using device " << bestDevice << " which has " << mostGlobalMemory << " bytes of global memory";
return bestDevice;
}
-}
+} // anonymous namespace
+
+//--------------------------------------------------------------------------
+// GeNN::CodeGenerator::CUDA::Optimiser
+//--------------------------------------------------------------------------
-// CodeGenerator::Backends::Optimiser
-namespace CodeGenerator
-{
-namespace CUDA
-{
-namespace Optimiser
+namespace GeNN::CodeGenerator::CUDA::Optimiser
{
Backend createBackend(const ModelSpecInternal &model, const filesystem::path &outputPath,
plog::Severity backendLevel, plog::IAppender *backendAppender,
@@ -791,6 +791,4 @@ Backend createBackend(const ModelSpecInternal &model, const filesystem::path &ou
}
}
-} // namespace Optimiser
-} // namespace CUDA
-} // namespace CodeGenerator
+} // namespace GeNN::CodeGenerator::CUDA::Optimiser
diff --git a/src/genn/backends/opencl/backend.cc b/src/genn/backends/opencl/backend.cc
index 1d8bfcc595..1287af6493 100644
--- a/src/genn/backends/opencl/backend.cc
+++ b/src/genn/backends/opencl/backend.cc
@@ -17,7 +17,7 @@
// OpenCL backend includes
#include "utils.h"
-using namespace CodeGenerator;
+using namespace GeNN::CodeGenerator;
//--------------------------------------------------------------------------
// Anonymous namespace
//--------------------------------------------------------------------------
@@ -148,13 +148,10 @@
void genReadEventTiming(CodeStream &os, const std::string &name) } //-------------------------------------------------------------------------- -// CodeGenerator::OpenCL::Backend +// GeNN::CodeGenerator::OpenCL::Backend //-------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator::OpenCL { -namespace OpenCL -{ -//-------------------------------------------------------------------------- Backend::Backend(const KernelBlockSize &kernelBlockSizes, const Preferences &preferences, const std::string &scalarType, unsigned int platformIndex, unsigned int deviceIndex) : BackendSIMT(kernelBlockSizes, preferences, scalarType), m_ChosenPlatformIndex(platformIndex), m_ChosenDeviceIndex(deviceIndex) @@ -2804,5 +2801,4 @@ bool Backend::shouldUseSubBufferAllocations() const { return isChosenDeviceAMD(); } -} // namespace OpenCL -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator::OpenCL diff --git a/src/genn/backends/opencl/optimiser.cc b/src/genn/backends/opencl/optimiser.cc index 60739b5609..811b8164c2 100644 --- a/src/genn/backends/opencl/optimiser.cc +++ b/src/genn/backends/opencl/optimiser.cc @@ -37,13 +37,9 @@ unsigned int getDeviceWithMostGlobalMemory(unsigned int platformID) } } //-------------------------------------------------------------------------- -// CodeGenerator::OpenCL::Optimiser +// GeNN::CodeGenerator::OpenCL::Optimiser //-------------------------------------------------------------------------- -namespace CodeGenerator -{ -namespace OpenCL -{ -namespace Optimiser +namespace GeNN::CodeGenerator::OpenCL::Optimiser { Backend createBackend(const ModelSpecInternal &model, const filesystem::path&, plog::Severity backendLevel, plog::IAppender *backendAppender, @@ -65,6 +61,4 @@ Backend createBackend(const ModelSpecInternal &model, const filesystem::path&, return Backend(preferences.manualWorkGroupSizes, preferences, model.getPrecision(), platformID, deviceID); } -} // namespace Optimiser -} // namespace OpenCL -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator::OpenCL::Optimiser diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 839668196f..cdb3498520 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -9,7 +9,7 @@ #include "code_generator/modelSpecMerged.h" #include "code_generator/substitutions.h" -using namespace CodeGenerator; +using namespace GeNN::CodeGenerator; //-------------------------------------------------------------------------- // Anonymous namespace @@ -117,11 +117,9 @@ void genKernelIteration(CodeStream &os, const G &g, size_t numKernelDims, const } //-------------------------------------------------------------------------- -// CodeGenerator::SingleThreadedCPU::Backend +// GeNN::CodeGenerator::SingleThreadedCPU::Backend //-------------------------------------------------------------------------- -namespace CodeGenerator -{ -namespace SingleThreadedCPU +namespace GeNN::CodeGenerator::SingleThreadedCPU { void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler, HostHandler pushEGPHandler) const @@ -1342,8 +1340,8 @@ void Backend::genExtraGlobalParamAllocation(CodeStream &os, const std::string &t VarLocation, const std::string &countVarName, const std::string &prefix) const { // Get underlying type - const std::string underlyingType = ::Utils::getUnderlyingType(type); - const bool 
pointerToPointer = ::Utils::isTypePointerToPointer(type); + const std::string underlyingType = Utils::getUnderlyingType(type); + const bool pointerToPointer = Utils::isTypePointerToPointer(type); const std::string pointer = pointerToPointer ? ("*" + prefix + name) : (prefix + name); @@ -1977,5 +1975,4 @@ void Backend::genWriteBackReductions(CodeStream &os, const CustomUpdateWUGroupMe index); }); } -} // namespace SingleThreadedCPU -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator::SingleThreadedCPU diff --git a/src/genn/backends/single_threaded_cpu/optimiser.cc b/src/genn/backends/single_threaded_cpu/optimiser.cc index 41614846cf..3f904d1049 100644 --- a/src/genn/backends/single_threaded_cpu/optimiser.cc +++ b/src/genn/backends/single_threaded_cpu/optimiser.cc @@ -4,13 +4,9 @@ #include "modelSpecInternal.h" //-------------------------------------------------------------------------- -// CodeGenerator::SingleThreadedCPU::Optimiser +// GeNN::CodeGenerator::SingleThreadedCPU::Optimiser //-------------------------------------------------------------------------- -namespace CodeGenerator -{ -namespace SingleThreadedCPU -{ -namespace Optimiser +namespace GeNN::CodeGenerator::SingleThreadedCPU::Optimiser { Backend createBackend(const ModelSpecInternal &model, const filesystem::path&, plog::Severity backendLevel, plog::IAppender *backendAppender, @@ -27,6 +23,4 @@ Backend createBackend(const ModelSpecInternal &model, const filesystem::path&, return Backend(model.getPrecision(), preferences); } -} // namespace Optimiser -} // namespace CUDA -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator::SingleThreadedCPU::Optimiser diff --git a/src/genn/generator/generator.cc b/src/genn/generator/generator.cc index ba1ae7c562..f47041cbd4 100644 --- a/src/genn/generator/generator.cc +++ b/src/genn/generator/generator.cc @@ -24,7 +24,8 @@ #include "optimiser.h" // Declare global GeNN preferences -using namespace CodeGenerator::BACKEND_NAMESPACE; +using namespace GeNN; +using namespace GeNN::CodeGenerator::BACKEND_NAMESPACE; Preferences GENN_PREFERENCES; // Include model diff --git a/src/genn/genn/binomial.cc b/src/genn/genn/binomial.cc index d302d29cda..2be7981fa4 100644 --- a/src/genn/genn/binomial.cc +++ b/src/genn/genn/binomial.cc @@ -35,7 +35,7 @@ double logPMFBinomial(unsigned int n, unsigned int k, double logP, double logOne //! Evaluates the inverse CDF of the binomial distribution directly from the definition //! The calculation is done mostly in the log domain except for the final //! 
accumulation of the probabilities -unsigned int binomialInverseCDF(double cdf, unsigned int n, double p) +unsigned int GeNN::binomialInverseCDF(double cdf, unsigned int n, double p) { // Validate cdf and p parameters if(cdf < 0.0 || cdf > 1.0) { diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 83e1ed06b9..9c40a8ee8a 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -16,9 +16,9 @@ #define FLOAT_TYPE(T) {#T, {sizeof(T), Utils::writePreciseString(std::numeric_limits::lowest())}} //-------------------------------------------------------------------------- -// CodeGenerator::BackendBase +// GeNN::CodeGenerator::BackendBase //-------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { BackendBase::BackendBase(const std::string &scalarType, const PreferencesBase &preferences) : m_PointerBytes(sizeof(char*)), m_Types{{TYPE(char), TYPE(wchar_t), TYPE(signed char), TYPE(short), @@ -269,4 +269,4 @@ std::vector BackendBase::genInitReductionTargets(C index); }); } -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator \ No newline at end of file diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index b901832fc2..35df3ca9d2 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -6,8 +6,6 @@ // GeNN code generator includes #include "code_generator/modelSpecMerged.h" -using namespace CodeGenerator; - //----------------------------------------------------------------------- // Anonymous namespace //----------------------------------------------------------------------- @@ -31,9 +29,9 @@ size_t getNumMergedGroupThreads(const std::vector &groups, G getNumThreads) } // Anonymous namespace //-------------------------------------------------------------------------- -// CodeGenerator::BackendSIMT +// GeNN::CodeGenerator::BackendSIMT //-------------------------------------------------------------------------- -namespace CodeGenerator +namespace GeNN::CodeGenerator { const char *BackendSIMT::KernelNames[KernelMax] = { "updateNeuronsKernel", @@ -1601,7 +1599,7 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne // If this connectivity requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id - if(::Utils::isRNGRequired(snippet->getRowBuildCode())) { + if(Utils::isRNGRequired(snippet->getRowBuildCode())) { genGlobalRNGSkipAhead(os, popSubs, "id"); } @@ -1613,7 +1611,7 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne // If this connectivity requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id - if(::Utils::isRNGRequired(snippet->getColBuildCode())) { + if(Utils::isRNGRequired(snippet->getColBuildCode())) { genGlobalRNGSkipAhead(os, popSubs, "id"); } @@ -1715,7 +1713,7 @@ void BackendSIMT::addDeviceType(const std::string &type, size_t size, const std: bool BackendSIMT::isDeviceType(const std::string &type) const { // Get underlying type - const std::string underlyingType = ::Utils::isTypePointer(type) ? ::Utils::getUnderlyingType(type) : type; + const std::string underlyingType = Utils::isTypePointer(type) ? 
Utils::getUnderlyingType(type) : type; // Return true if it is in device types set return (m_DeviceTypes.find(underlyingType) != m_DeviceTypes.cend()); @@ -1787,4 +1785,4 @@ const PresynapticUpdateStrategySIMT::Base *BackendSIMT::getPresynapticUpdateStra throw std::runtime_error("Unable to find a suitable presynaptic update strategy for synapse group '" + sg.getName() + "'"); return nullptr; } -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index 3da847bcad..2f358781fe 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -101,26 +101,15 @@ const char *mathsFuncs[][MathsFuncMax] = { {"fabs", "fabsf"}, {"fma", "fmaf"} }; - -//-------------------------------------------------------------------------- -/*! \brief This function removes explicit single precision function calls as - single-threaded CPU and CUDA kernels both support C++ i.e. overloads - and, while OpenCL kernels aren't in C++, OpenCL doesn't provide explicit - single precision maths functions, instead having some weird special case - */ //-------------------------------------------------------------------------- void ensureMathFunctionFtype(std::string &code) { // Replace any outstanding explicit single-precision maths functions // with C++ versions where overloads should work the same for(const auto &m : mathsFuncs) { - CodeGenerator::regexFuncSubstitute(code, m[MathsFuncSingle], m[MathsFuncCPP]); + GeNN::CodeGenerator::regexFuncSubstitute(code, m[MathsFuncSingle], m[MathsFuncCPP]); } } - -//-------------------------------------------------------------------------- -/*! \brief This function is part of the parser that converts any floating point constant in a code snippet to a floating point constant with an explicit precision (by appending "f" or removing it). - */ //-------------------------------------------------------------------------- void doFinal(std::string &code, unsigned int i, const std::string &type, unsigned int &state) { @@ -143,7 +132,7 @@ void doFinal(std::string &code, unsigned int i, const std::string &type, unsigne } } } - +//-------------------------------------------------------------------------- bool regexSubstitute(std::string &s, const std::regex ®ex, const std::string &format) { // **NOTE** the following code performs the same function as std::regex_replace @@ -204,14 +193,11 @@ std::string trimWhitespace(const std::string& str) } } // Anonymous namespace -//-------------------------------------------------------------------------- -// CodeGenerator -//-------------------------------------------------------------------------- -namespace CodeGenerator +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator +//---------------------------------------------------------------------------- +namespace GeNN::CodeGenerator { -//-------------------------------------------------------------------------- -//! \brief Tool for substituting strings in the neuron code strings or other templates -//-------------------------------------------------------------------------- void substitute(std::string &s, const std::string &trg, const std::string &rep) { size_t found= s.find(trg); @@ -220,10 +206,7 @@ void substitute(std::string &s, const std::string &trg, const std::string &rep) found= s.find(trg); } } - -//-------------------------------------------------------------------------- -//! 
\brief Tool for substituting variable names in the neuron code strings or other templates using regular expressions -//-------------------------------------------------------------------------- +//---------------------------------------------------------------------------- bool regexVarSubstitute(std::string &s, const std::string &trg, const std::string &rep) { // Build a regex to match variable name with at least one @@ -238,9 +221,7 @@ bool regexVarSubstitute(std::string &s, const std::string &trg, const std::strin return regexSubstitute(s, regex, format); } -//-------------------------------------------------------------------------- -//! \brief Tool for substituting function names in the neuron code strings or other templates using regular expressions -//-------------------------------------------------------------------------- +//---------------------------------------------------------------------------- bool regexFuncSubstitute(std::string &s, const std::string &trg, const std::string &rep) { // Build a regex to match function name with at least one @@ -254,18 +235,7 @@ bool regexFuncSubstitute(std::string &s, const std::string &trg, const std::stri return regexSubstitute(s, regex, format); } - -//-------------------------------------------------------------------------- -/*! \brief This function substitutes function calls in the form: - * - * $(functionName, parameter1, param2Function(0.12, "string")) - * - * with replacement templates in the form: - * - * actualFunction(CONSTANT, $(0), $(1)) - * - */ -//-------------------------------------------------------------------------- +//---------------------------------------------------------------------------- void functionSubstitute(std::string &code, const std::string &funcName, unsigned int numParams, const std::string &replaceFuncTemplate) { @@ -355,7 +325,7 @@ void functionSubstitute(std::string &code, const std::string &funcName, } } } - +//---------------------------------------------------------------------------- void genTypeRange(CodeStream &os, const std::string &precision, const std::string &prefix) { os << "#define " << prefix << "_MIN "; @@ -379,12 +349,7 @@ void genTypeRange(CodeStream &os, const std::string &precision, const std::strin } os << std::endl; } - -//-------------------------------------------------------------------------- -/*! \brief This function implements a parser that converts any floating point constant in a code snippet to a floating point constant with an explicit precision (by appending "f" or removing it). 
- */ -//-------------------------------------------------------------------------- - +//---------------------------------------------------------------------------- std::string ensureFtype(const std::string &oldcode, const std::string &type) { // cerr << "entering ensure" << endl; @@ -481,7 +446,7 @@ std::string ensureFtype(const std::string &oldcode, const std::string &type) ensureMathFunctionFtype(code); return code; } - +//---------------------------------------------------------------------------- std::string getReductionInitialValue(const BackendBase &backend, VarAccessMode access, const std::string &type) { // If reduction is a sum, initialise to zero @@ -497,7 +462,7 @@ std::string getReductionInitialValue(const BackendBase &backend, VarAccessMode a return ""; } } - +//---------------------------------------------------------------------------- std::string getReductionOperation(const std::string &reduction, const std::string &value, VarAccessMode access, const std::string &type) { // If operation is sum, add output of custom update to sum @@ -520,11 +485,7 @@ std::string getReductionOperation(const std::string &reduction, const std::strin return ""; } } -//-------------------------------------------------------------------------- -/*! \brief This function checks for unknown variable definitions and returns a gennError if any are found - */ -//-------------------------------------------------------------------------- - +//---------------------------------------------------------------------------- void checkUnreplacedVariables(const std::string &code, const std::string &codeName) { std::regex rgx("\\$\\([\\w]+\\)"); @@ -540,11 +501,7 @@ void checkUnreplacedVariables(const std::string &code, const std::string &codeNa throw std::runtime_error("The "+vars+"undefined in code "+codeName+"."); } } - -//-------------------------------------------------------------------------- -/*! \brief This function substitutes function names in a code with namespace as prefix of the function name for backends that do not support namespaces by checking that the function indeed exists in the support code and returns the substituted code. 
- */ - //-------------------------------------------------------------------------- +//---------------------------------------------------------------------------- std::string disambiguateNamespaceFunction(const std::string supportCode, const std::string code, std::string namespaceName) { // Regex for function call - looks for words with succeeding parentheses with or without any data inside the parentheses (arguments) std::regex funcCallRegex(R"(\w+(?=\(.*\)))"); @@ -569,4 +526,4 @@ std::string disambiguateNamespaceFunction(const std::string supportCode, const s } return newCode; } -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator diff --git a/src/genn/genn/code_generator/codeStream.cc b/src/genn/genn/code_generator/codeStream.cc index 5607b65149..33e257b30b 100644 --- a/src/genn/genn/code_generator/codeStream.cc +++ b/src/genn/genn/code_generator/codeStream.cc @@ -3,10 +3,10 @@ // Standard C++ includes #include -//---------------------------------------------------------------------------- -// CodeGenerator::CodeStream::IndentBuffer -//---------------------------------------------------------------------------- -namespace CodeGenerator +//------------------------------------------------------------------------ +// GeNN::CodeGenerator::CodeStream::IndentBuffer +//------------------------------------------------------------------------ +namespace GeNN::CodeGenerator { int CodeStream::IndentBuffer::overflow(int c) { @@ -35,7 +35,7 @@ int CodeStream::IndentBuffer::overflow(int c) } //------------------------------------------------------------------------ -// CodeGenerator::CodeStream::Scope +// GeNN::CodeGenerator::CodeStream::Scope //------------------------------------------------------------------------ unsigned int CodeStream::Scope::s_NextLevel = 0; @@ -76,4 +76,4 @@ std::ostream& operator << (std::ostream& s, const CodeStream::CB &cb) return c; } -} // namspace CodeGenerator \ No newline at end of file +} diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index bacccd8868..e11ad99c27 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -6,7 +6,8 @@ // GeNN code generator includes #include "code_generator/modelSpecMerged.h" -using namespace CodeGenerator; +using namespace GeNN; +using namespace GeNN::CodeGenerator; //---------------------------------------------------------------------------- // CodeGenerator::CustomConnectivityUpdateGroupMergedBase diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index f98ac18d03..a5a42caa3e 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -3,7 +3,8 @@ // GeNN code generator includes #include "code_generator/modelSpecMerged.h" -using namespace CodeGenerator; +using namespace GeNN; +using namespace GeNN::CodeGenerator; //-------------------------------------------------------------------------- // Anonymous namespace @@ -97,7 +98,7 @@ void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, } // Anonymous namespace //---------------------------------------------------------------------------- -// CodeGenerator::CustomUpdateGroupMerged +// GeNN::CodeGenerator::CustomUpdateGroupMerged 
//----------------------------------------------------------------------------
const std::string CustomUpdateGroupMerged::name = "CustomUpdate";
//----------------------------------------------------------------------------
@@ -218,7 +219,7 @@ std::string CustomUpdateGroupMerged::getVarRefIndex(bool delay, VarAccessDuplica
}
// ----------------------------------------------------------------------------
-// CodeGenerator::CustomUpdateWUGroupMergedBase
+// GeNN::CodeGenerator::CustomUpdateWUGroupMergedBase
//----------------------------------------------------------------------------
bool CustomUpdateWUGroupMergedBase::isParamHeterogeneous(const std::string &paramName) const
{
@@ -364,7 +365,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const
}
// ----------------------------------------------------------------------------
-// CustomUpdateWUGroupMerged
+// GeNN::CodeGenerator::CustomUpdateWUGroupMerged
//----------------------------------------------------------------------------
const std::string CustomUpdateWUGroupMerged::name = "CustomUpdateWU";
//----------------------------------------------------------------------------
diff --git a/src/genn/genn/code_generator/generateMSBuild.cc b/src/genn/genn/code_generator/generateMSBuild.cc
index 4998340b72..26f3fa3f53 100644
--- a/src/genn/genn/code_generator/generateMSBuild.cc
+++ b/src/genn/genn/code_generator/generateMSBuild.cc
@@ -10,10 +10,10 @@
#include "code_generator/backendBase.h"
//--------------------------------------------------------------------------
-// CodeGenerator
+// GeNN::CodeGenerator
//--------------------------------------------------------------------------
-void CodeGenerator::generateMSBuild(std::ostream &os, const ModelSpecInternal &model, const BackendBase &backend,
- const std::string &projectGUID, const std::vector<std::string> &moduleNames)
+void GeNN::CodeGenerator::generateMSBuild(std::ostream &os, const ModelSpecInternal &model, const BackendBase &backend,
+ const std::string &projectGUID, const std::vector<std::string> &moduleNames)
{
// Generate header and targets for release and debug builds
os << "" << std::endl;
@@ -82,4 +82,4 @@ void CodeGenerator::generateMSBuild(std::ostream &os, const ModelSpecInternal &m
os << "" << std::endl;
backend.genMSBuildImportTarget(os);
os << "" << std::endl;
-}
\ No newline at end of file
+}
diff --git a/src/genn/genn/code_generator/generateMakefile.cc b/src/genn/genn/code_generator/generateMakefile.cc
index 86767c97dc..66eda3b76b 100644
--- a/src/genn/genn/code_generator/generateMakefile.cc
+++ b/src/genn/genn/code_generator/generateMakefile.cc
@@ -10,10 +10,10 @@
#include "code_generator/backendBase.h"
//--------------------------------------------------------------------------
-// CodeGenerator
+// GeNN::CodeGenerator
//--------------------------------------------------------------------------
-void CodeGenerator::generateMakefile(std::ostream &os, const BackendBase &backend,
- const std::vector<std::string> &moduleNames)
+void GeNN::CodeGenerator::generateMakefile(std::ostream &os, const BackendBase &backend,
+ const std::vector<std::string> &moduleNames)
{
//**TODO** deal with standard include paths e.g.
MPI here diff --git a/src/genn/genn/code_generator/generateModules.cc b/src/genn/genn/code_generator/generateModules.cc index 2c9abc2bdb..4805727513 100644 --- a/src/genn/genn/code_generator/generateModules.cc +++ b/src/genn/genn/code_generator/generateModules.cc @@ -20,7 +20,8 @@ #include "code_generator/generateRunner.h" #include "code_generator/modelSpecMerged.h" -using namespace CodeGenerator; +using namespace GeNN; +using namespace GeNN::CodeGenerator; //-------------------------------------------------------------------------- // Anonymous namespace @@ -88,11 +89,13 @@ bool shouldRebuildModel(const filesystem::path &outputPath, const boost::uuids:: } // Anonymous namespace //-------------------------------------------------------------------------- -// CodeGenerator +// GeNN::CodeGenerator //-------------------------------------------------------------------------- -std::pair, MemAlloc> CodeGenerator::generateAll(const ModelSpecInternal &model, const BackendBase &backend, - const filesystem::path &sharePath, const filesystem::path &outputPath, - bool forceRebuild) +namespace GeNN::CodeGenerator +{ +std::pair, MemAlloc> generateAll(const ModelSpecInternal &model, const BackendBase &backend, + const filesystem::path &sharePath, const filesystem::path &outputPath, + bool forceRebuild) { // Create directory for generated code filesystem::create_directory(outputPath); @@ -179,8 +182,8 @@ std::pair, MemAlloc> CodeGenerator::generateAll(const M return std::make_pair(modules, mem); } //-------------------------------------------------------------------------- -void CodeGenerator::generateNeuronUpdate(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, - const BackendBase &backend, const std::string &suffix) +void generateNeuronUpdate(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, + const BackendBase &backend, const std::string &suffix) { // Create output stream to write to file and wrap in CodeStream std::ofstream neuronUpdateStream((outputPath / ("neuronUpdate" + suffix + ".cc")).str()); @@ -208,8 +211,8 @@ void CodeGenerator::generateNeuronUpdate(const filesystem::path &outputPath, con }); } //-------------------------------------------------------------------------- -void CodeGenerator::generateCustomUpdate(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, - const BackendBase &backend, const std::string &suffix) +void generateCustomUpdate(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, + const BackendBase &backend, const std::string &suffix) { // Create output stream to write to file and wrap in CodeStream std::ofstream customUpdateStream((outputPath / ("customUpdate" + suffix + ".cc")).str()); @@ -240,8 +243,8 @@ void CodeGenerator::generateCustomUpdate(const filesystem::path &outputPath, con }); } //-------------------------------------------------------------------------- -void CodeGenerator::generateSynapseUpdate(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, - const BackendBase &backend, const std::string &suffix) +void generateSynapseUpdate(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, + const BackendBase &backend, const std::string &suffix) { // Create output stream to write to file and wrap in CodeStream std::ofstream synapseUpdateStream((outputPath / ("synapseUpdate" + suffix + ".cc")).str()); @@ -272,8 +275,8 @@ void CodeGenerator::generateSynapseUpdate(const filesystem::path &outputPath, co }); } 
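Because generateAll() and the per-module generators are now defined inside namespace GeNN::CodeGenerator rather than with a CodeGenerator:: prefix on each definition, external tools call them fully qualified. A rough sketch of a driver under the signatures visible in this file (includes, backend construction and error handling are elided; the function name is a placeholder):

// Hypothetical driver, loosely mirroring src/genn/generator/generator.cc
void buildModelCode(const GeNN::ModelSpecInternal &model,
                    const GeNN::CodeGenerator::BackendBase &backend,
                    const filesystem::path &sharePath,
                    const filesystem::path &outputPath)
{
    // generateAll() writes every module, returning their names and a memory
    // estimate; the final argument forces a rebuild even if the model hash
    // is unchanged
    const auto output = GeNN::CodeGenerator::generateAll(model, backend,
                                                         sharePath, outputPath,
                                                         false);

    // The module names then feed the build-file generators seen earlier
    // in this patch
    std::ofstream makefile((outputPath / "Makefile").str());
    GeNN::CodeGenerator::generateMakefile(makefile, backend, output.first);
}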
//-------------------------------------------------------------------------- -void CodeGenerator::generateInit(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, - const BackendBase &backend, const std::string &suffix) +void generateInit(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, + const BackendBase &backend, const std::string &suffix) { // Create output stream to write to file and wrap in CodeStream std::ofstream initStream((outputPath / ("init" + suffix + ".cc")).str()); @@ -320,3 +323,4 @@ void CodeGenerator::generateInit(const filesystem::path &outputPath, const Model modelMerged.genScalarEGPPush(os, backend); }); } +} // namespace GeNN::CodeGenerator \ No newline at end of file diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 97cd9b6667..8671b6193f 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -20,7 +20,8 @@ #include "code_generator/backendBase.h" #include "code_generator/modelSpecMerged.h" -using namespace CodeGenerator; +using namespace GeNN; +using namespace GeNN::CodeGenerator; //-------------------------------------------------------------------------- // Anonymous namespace @@ -509,10 +510,10 @@ void genCustomUpdate(const ModelSpecMerged &modelMerged, const BackendBase &back } // Anonymous namespace //-------------------------------------------------------------------------- -// CodeGenerator +// GeNN::CodeGenerator //-------------------------------------------------------------------------- -MemAlloc CodeGenerator::generateRunner(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, - const BackendBase &backend, const std::string &suffix) +MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, + const BackendBase &backend, const std::string &suffix) { // Create output streams to write to file and wrap in CodeStreams std::ofstream definitionsStream((outputPath / ("definitions" + suffix + ".h")).str()); diff --git a/src/genn/genn/code_generator/generateSupportCode.cc b/src/genn/genn/code_generator/generateSupportCode.cc index a90036ee91..4a617803d4 100644 --- a/src/genn/genn/code_generator/generateSupportCode.cc +++ b/src/genn/genn/code_generator/generateSupportCode.cc @@ -10,10 +10,10 @@ #include "code_generator/modelSpecMerged.h" //-------------------------------------------------------------------------- -// CodeGenerator +// GeNN::CodeGenerator //-------------------------------------------------------------------------- -void CodeGenerator::generateSupportCode(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, - const std::string &suffix) +void GeNN::CodeGenerator::generateSupportCode(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, + const std::string &suffix) { std::ofstream supportCodeStream((outputPath / ("supportCode" + suffix + ".h")).str()); CodeStream supportCode(supportCodeStream); diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index 6a3cc1b9d0..7168cf9eff 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -11,10 +11,11 @@ #include "code_generator/codeGenUtils.h" #include "code_generator/codeStream.h" -using namespace CodeGenerator; +using namespace GeNN; +using namespace GeNN::CodeGenerator; 
//----------------------------------------------------------------------------
-// CodeGenerator::NeuronSpikeQueueUpdateGroupMerged
+// GeNN::CodeGenerator::NeuronSpikeQueueUpdateGroupMerged
//----------------------------------------------------------------------------
const std::string NeuronSpikeQueueUpdateGroupMerged::name = "NeuronSpikeQueueUpdate";
//----------------------------------------------------------------------------
@@ -61,7 +62,7 @@ void NeuronSpikeQueueUpdateGroupMerged::genMergedGroupSpikeCountReset(CodeStream
}
//----------------------------------------------------------------------------
-// CodeGenerator::NeuronPrevSpikeTimeUpdateGroupMerged
+// GeNN::CodeGenerator::NeuronPrevSpikeTimeUpdateGroupMerged
//----------------------------------------------------------------------------
const std::string NeuronPrevSpikeTimeUpdateGroupMerged::name = "NeuronPrevSpikeTimeUpdate";
//----------------------------------------------------------------------------
@@ -95,7 +96,7 @@ NeuronPrevSpikeTimeUpdateGroupMerged::NeuronPrevSpikeTimeUpdateGroupMerged(size_
}
//----------------------------------------------------------------------------
-// CodeGenerator::NeuronGroupMergedBase
+// GeNN::CodeGenerator::NeuronGroupMergedBase
//----------------------------------------------------------------------------
bool NeuronGroupMergedBase::isParamHeterogeneous(const std::string &paramName) const
{
@@ -473,7 +474,7 @@ bool NeuronGroupMergedBase::isPSMVarInitParamReferenced(size_t childIndex, const
const auto *varInitSnippet = getSortedArchetypeMergedInSyns().at(childIndex)->getPSVarInitialisers().at(varName).getSnippet();
return isParamReferenced({varInitSnippet->getCode()}, paramName);
}
//----------------------------------------------------------------------------
void NeuronGroupMergedBase::addMergedInSynPointerField(const std::string &type, const std::string &name,
size_t archetypeIndex, const std::string &prefix)
{
@@ -497,7 +498,7 @@
}
void NeuronGroupMergedBase::addMergedPreOutputOutSynPointerField(const std::stri
}
//----------------------------------------------------------------------------
-// CodeGenerator::SynapseGroupMergedBase
+// GeNN::CodeGenerator::SynapseGroupMergedBase
//----------------------------------------------------------------------------
bool SynapseGroupMergedBase::isWUParamHeterogeneous(const std::string &paramName) const
{
diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc
index ea137b2d31..9b780ae12b 100644
--- a/src/genn/genn/code_generator/initGroupMerged.cc
+++ b/src/genn/genn/code_generator/initGroupMerged.cc
@@ -3,7 +3,8 @@
// GeNN code generator includes
#include "code_generator/modelSpecMerged.h"
-using namespace CodeGenerator;
+using namespace GeNN;
+using namespace GeNN::CodeGenerator;
//--------------------------------------------------------------------------
// Anonymous namespace
//--------------------------------------------------------------------------
@@ -180,7 +181,7 @@ void genInitWUVarCode(CodeStream &os, const Substitutions &popSubs,
} // Anonymous namespace
//----------------------------------------------------------------------------
-// CodeGenerator::NeuronInitGroupMerged
+// GeNN::CodeGenerator::NeuronInitGroupMerged
//----------------------------------------------------------------------------
const std::string NeuronInitGroupMerged::name = "NeuronInit";
//----------------------------------------------------------------------------
@@ -540,7 +541,7 @@ void
NeuronInitGroupMerged::genInitSpikeTime(CodeStream &os, const BackendBase & } //---------------------------------------------------------------------------- -// CodeGenerator::SynapseInitGroupMerged +// GeNN::CodeGenerator::SynapseInitGroupMerged //---------------------------------------------------------------------------- const std::string SynapseInitGroupMerged::name = "SynapseInit"; //---------------------------------------------------------------------------- @@ -594,7 +595,7 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, CodeStream } //---------------------------------------------------------------------------- -// CodeGenerator::SynapseSparseInitGroupMerged +// GeNN::CodeGenerator::SynapseSparseInitGroupMerged //---------------------------------------------------------------------------- const std::string SynapseSparseInitGroupMerged::name = "SynapseSparseInit"; //---------------------------------------------------------------------------- @@ -612,7 +613,7 @@ void SynapseSparseInitGroupMerged::generateInit(const BackendBase &backend, Code } // ---------------------------------------------------------------------------- -// CodeGenerator::SynapseConnectivityInitGroupMerged +// GeNN::CodeGenerator::SynapseConnectivityInitGroupMerged //---------------------------------------------------------------------------- const std::string SynapseConnectivityInitGroupMerged::name = "SynapseConnectivityInit"; //---------------------------------------------------------------------------- @@ -848,7 +849,7 @@ bool SynapseConnectivityHostInitGroupMerged::isSparseConnectivityInitParamRefere } // ---------------------------------------------------------------------------- -// CustomUpdateInitGroupMerged +// GeNN::CodeGenerator::CustomUpdateInitGroupMerged //---------------------------------------------------------------------------- const std::string CustomUpdateInitGroupMerged::name = "CustomUpdateInit"; //---------------------------------------------------------------------------- @@ -883,7 +884,7 @@ void CustomUpdateInitGroupMerged::generateInit(const BackendBase &backend, CodeS } // ---------------------------------------------------------------------------- -// CustomWUUpdateInitGroupMerged +// GeNN::CodeGenerator::CustomWUUpdateInitGroupMerged //---------------------------------------------------------------------------- const std::string CustomWUUpdateInitGroupMerged::name = "CustomWUUpdateInit"; //---------------------------------------------------------------------------- @@ -991,7 +992,7 @@ void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, Cod } // ---------------------------------------------------------------------------- -// CustomWUUpdateSparseInitGroupMerged +// GeNN::CodeGenerator::CustomWUUpdateSparseInitGroupMerged //---------------------------------------------------------------------------- const std::string CustomWUUpdateSparseInitGroupMerged::name = "CustomWUUpdateSparseInit"; //---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index 7c77195541..63358d44b3 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -7,7 +7,8 @@ // GeNN code generator includes #include "code_generator/backendBase.h" -using namespace CodeGenerator; +using namespace GeNN; +using namespace GeNN::CodeGenerator; 
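A detail worth noting in these .cc hunks: a single using namespace GeNN::CodeGenerator; would not be enough on its own, because a using-directive only makes the named namespace's members visible, not those of its enclosing namespace. A small sketch, with simplified hypothetical declarations standing in for the real GeNN headers, shows why both directives are added:

    namespace GeNN { class SynapseGroupInternal; }
    namespace GeNN::CodeGenerator { class CodeStream; }

    using namespace GeNN::CodeGenerator;   // makes CodeStream visible unqualified
    using namespace GeNN;                  // still required for SynapseGroupInternal

    // Compiles only when both directives are in effect:
    void dump(SynapseGroupInternal &sg, CodeStream &os);

A file that only ever names code-generator types would need only the first directive, but applying both uniformly keeps the change mechanical.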
 //----------------------------------------------------------------------------
 // Anonymous namespace
@@ -25,7 +26,7 @@ void assignGroups(const BackendBase &backend, std::vector &groups, BackendBas
 }
 //----------------------------------------------------------------------------
-// CodeGenerator::ModelSpecMerged
+// GeNN::CodeGenerator::ModelSpecMerged
 //----------------------------------------------------------------------------
 ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBase &backend)
 :   m_Model(model), m_NeuronUpdateSupportCode("NeuronUpdateSupportCode"), m_PostsynapticDynamicsSupportCode("PostsynapticDynamicsSupportCode"),
diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
index f1807d8072..eca0d61207 100644
--- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
@@ -3,10 +3,11 @@
 // GeNN code generator includes
 #include "code_generator/modelSpecMerged.h"
 
-using namespace CodeGenerator;
+using namespace GeNN;
+using namespace GeNN::CodeGenerator;
 
 //----------------------------------------------------------------------------
-// CodeGenerator::NeuronUpdateGroupMerged
+// GeNN::CodeGenerator::NeuronUpdateGroupMerged
 //----------------------------------------------------------------------------
 const std::string NeuronUpdateGroupMerged::name = "NeuronUpdate";
 //----------------------------------------------------------------------------
diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc
index 85abb4e37c..ac5cf8396a 100644
--- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc
+++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc
@@ -14,18 +14,16 @@
 #include "code_generator/modelSpecMerged.h"
 #include "code_generator/substitutions.h"
-
-using namespace CodeGenerator;
-
 //----------------------------------------------------------------------------
 // Anonymous namespace
 //----------------------------------------------------------------------------
 namespace
 {
-bool isSmallSharedMemoryPop(const PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend)
+bool isSmallSharedMemoryPop(const GeNN::CodeGenerator::PresynapticUpdateGroupMerged &sg,
+                            const GeNN::CodeGenerator::BackendSIMT &backend)
 {
     // If shared memory atomics are slow
-    const size_t blockSize = backend.getKernelBlockSize(CodeGenerator::KernelPresynapticUpdate);
+    const size_t blockSize = backend.getKernelBlockSize(GeNN::CodeGenerator::KernelPresynapticUpdate);
     if(backend.areSharedMemAtomicsSlow()) {
         return false;
     }
@@ -36,7 +34,7 @@ bool isSmallSharedMemoryPop(const PresynapticUpdateGroupMerged &sg, const Backen
     // Otherwise, we should accumulate each postsynaptic neuron's input in shared memory if all neuron groups targeted
    // by synapse groups within the merged group are small enough that input to them can be stored in a shared memory array
     else if(std::all_of(sg.getGroups().cbegin(), sg.getGroups().cend(),
-                        [blockSize](const SynapseGroupInternal &sg)
+                        [blockSize](const GeNN::SynapseGroupInternal &sg)
                         {
                             return (sg.getTrgNeuronGroup()->getNumNeurons() <= blockSize);
                         }))
@@ -50,11 +48,9 @@
 }   // Anonymous namespace
 //----------------------------------------------------------------------------
-// CodeGenerator::PresynapticUpdateStrategySIMT::PreSpan
+// 
GeNN::CodeGenerator::PresynapticUpdateStrategySIMT::PreSpan //---------------------------------------------------------------------------- -namespace CodeGenerator -{ -namespace PresynapticUpdateStrategySIMT +namespace GeNN::CodeGenerator::PresynapticUpdateStrategySIMT { size_t PreSpan::getNumThreads(const SynapseGroupInternal &sg) const { @@ -200,7 +196,7 @@ void PreSpan::genPostamble(CodeStream&, const ModelSpecMerged&, const Presynapti } //---------------------------------------------------------------------------- -// CodeGenerator::CUDA::PresynapticUpdateStrategy::PostSpan +// GeNN::CodeGenerator::PresynapticUpdateStrategySIMT::PostSpan //---------------------------------------------------------------------------- size_t PostSpan::getNumThreads(const SynapseGroupInternal &sg) const { @@ -435,7 +431,7 @@ bool PostSpan::shouldAccumulateInRegister(const PresynapticUpdateGroupMerged &sg } //-------------------------------------------------------------------------- -// CodeGenerator::CUDA::PresynapticUpdateStrategy::PreSpanProcedural +// GeNN::CodeGenerator::PresynapticUpdateStrategySIMT::PreSpanProcedural //-------------------------------------------------------------------------- size_t PreSpanProcedural::getNumThreads(const SynapseGroupInternal &sg) const { @@ -532,8 +528,8 @@ void PreSpanProcedural::genUpdate(CodeStream &os, const ModelSpecMerged &modelMe // If this connectivity requires an RNG for initialisation, // make copy of connect Phillox RNG and skip ahead to id that would have been used to initialize any variables associated with it - if(::Utils::isRNGRequired(sg.getArchetype().getConnectivityInitialiser().getSnippet()->getRowBuildCode()) - || ((sg.getArchetype().getMatrixType() & SynapseMatrixWeight::PROCEDURAL) && ::Utils::isRNGRequired(sg.getArchetype().getWUVarInitialisers()))) + if(Utils::isRNGRequired(sg.getArchetype().getConnectivityInitialiser().getSnippet()->getRowBuildCode()) + || ((sg.getArchetype().getMatrixType() & SynapseMatrixWeight::PROCEDURAL) && Utils::isRNGRequired(sg.getArchetype().getWUVarInitialisers()))) { std::stringstream skipAhead; if(numThreadsPerSpike > 1) { @@ -628,7 +624,7 @@ void PreSpanProcedural::genPostamble(CodeStream&, const ModelSpecMerged&, const } //---------------------------------------------------------------------------- -// CodeGenerator::CUDA::PresynapticUpdateStrategy::PostSpanBitmask +// GeNN::CodeGenerator::PresynapticUpdateStrategySIMT::PostSpanBitmask //---------------------------------------------------------------------------- size_t PostSpanBitmask::getNumThreads(const SynapseGroupInternal &sg) const { @@ -807,7 +803,7 @@ void PostSpanBitmask::genPostamble(CodeStream &os, const ModelSpecMerged &modelM } //-------------------------------------------------------------------------- -// CodeGenerator::PresynapticUpdateStrategySIMT::PostSpanToeplitz +// GeNN::CodeGenerator::PresynapticUpdateStrategySIMT::PostSpanToeplitz //-------------------------------------------------------------------------- size_t PostSpanToeplitz::getNumThreads(const SynapseGroupInternal &sg) const { @@ -999,5 +995,4 @@ void PostSpanToeplitz::genPostamble(CodeStream &os, const ModelSpecMerged &model } } } -} // namespace PresynapticUpdateStrategySIMT -} // namespace CodeGenerator +} // namespace GeNN::CodeGenerator::PresynapticUpdateStrategySIMT diff --git a/src/genn/genn/code_generator/substitutions.cc b/src/genn/genn/code_generator/substitutions.cc index 927f9e5535..d31792466a 100644 --- a/src/genn/genn/code_generator/substitutions.cc +++ 
b/src/genn/genn/code_generator/substitutions.cc
@@ -4,10 +4,12 @@
 #include "code_generator/codeGenUtils.h"
 
 //--------------------------------------------------------------------------
-// CodeGenerator::Substitutions
+// GeNN::CodeGenerator::Substitutions
 //--------------------------------------------------------------------------
-void CodeGenerator::Substitutions::addParamValueSubstitution(const std::vector &paramNames, const std::unordered_map &values,
-                                                             const std::string &sourceSuffix)
+namespace GeNN::CodeGenerator
+{
+void Substitutions::addParamValueSubstitution(const std::vector &paramNames, const std::unordered_map &values,
+                                              const std::string &sourceSuffix)
 {
     if(paramNames.size() != values.size()) {
         throw std::runtime_error("Number of parameters does not match number of values");
@@ -19,7 +21,7 @@ void CodeGenerator::Substitutions::addParamValueSubstitution(const std::vectorapplyVars(code);
 }
-}
\ No newline at end of file
+}
+} // namespace GeNN::CodeGenerator
\ No newline at end of file
diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc
index e0f63b8649..936f004b4c 100644
--- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc
@@ -3,8 +3,8 @@
 // GeNN code generator includes
 #include "code_generator/modelSpecMerged.h"
 
-using namespace CodeGenerator;
-
+using namespace GeNN;
+using namespace GeNN::CodeGenerator;
 
 //--------------------------------------------------------------------------
 // Anonymous namespace
@@ -173,7 +173,7 @@ void applySynapseSubstitutions(CodeStream &os, std::string code, const std::stri
 }   // Anonymous namespace
 //----------------------------------------------------------------------------
-// CodeGenerator::PresynapticUpdateGroupMerged
+// GeNN::CodeGenerator::PresynapticUpdateGroupMerged
 //----------------------------------------------------------------------------
 const std::string PresynapticUpdateGroupMerged::name = "PresynapticUpdate";
 //----------------------------------------------------------------------------
@@ -284,7 +284,7 @@ void PresynapticUpdateGroupMerged::generateToeplitzConnectivity(const BackendBas
 }
 //----------------------------------------------------------------------------
-// CodeGenerator::PostsynapticUpdateGroupMerged
+// GeNN::CodeGenerator::PostsynapticUpdateGroupMerged
 //----------------------------------------------------------------------------
 const std::string PostsynapticUpdateGroupMerged::name = "PostsynapticUpdate";
 //----------------------------------------------------------------------------
@@ -300,7 +300,7 @@ void PostsynapticUpdateGroupMerged::generateSynapseUpdate(const BackendBase &bac
 }
 //----------------------------------------------------------------------------
-// CodeGenerator::SynapseDynamicsGroupMerged
+// GeNN::CodeGenerator::SynapseDynamicsGroupMerged
 //----------------------------------------------------------------------------
 const std::string SynapseDynamicsGroupMerged::name = "SynapseDynamics";
 //----------------------------------------------------------------------------
diff --git a/src/genn/genn/currentSource.cc b/src/genn/genn/currentSource.cc
index 5c1dbf6a79..0be714adb0 100644
--- a/src/genn/genn/currentSource.cc
+++ b/src/genn/genn/currentSource.cc
@@ -8,8 +8,10 @@
 #include "gennUtils.h"
 
 //------------------------------------------------------------------------
-// CurrentSource
+// GeNN::CurrentSource
 //------------------------------------------------------------------------
+namespace GeNN
+{
 void CurrentSource::setVarLocation(const std::string &varName, VarLocation loc)
 {
     m_VarLocation[getCurrentSourceModel()->getVarIndex(varName)] = loc;
 }
@@ -114,3 +116,4 @@ boost::uuids::detail::sha1::digest_type CurrentSource::getVarLocationHashDigest(
     Utils::updateHash(m_ExtraGlobalParamLocation, hash);
     return hash.get_digest();
 }
+} // namespace GeNN
\ No newline at end of file
diff --git a/src/genn/genn/currentSourceModels.cc b/src/genn/genn/currentSourceModels.cc
index c50cd7ae60..a277079083 100644
--- a/src/genn/genn/currentSourceModels.cc
+++ b/src/genn/genn/currentSourceModels.cc
@@ -1,14 +1,18 @@
 #include "currentSourceModels.h"
 
+using namespace GeNN;
+
+namespace GeNN::CurrentSourceModels
+{
 // Implement models
-IMPLEMENT_SNIPPET(CurrentSourceModels::DC);
-IMPLEMENT_SNIPPET(CurrentSourceModels::GaussianNoise);
-IMPLEMENT_SNIPPET(CurrentSourceModels::PoissonExp);
+IMPLEMENT_SNIPPET(DC);
+IMPLEMENT_SNIPPET(GaussianNoise);
+IMPLEMENT_SNIPPET(PoissonExp);
 
 //----------------------------------------------------------------------------
-// CurrentSourceModels::Base
+// GeNN::CurrentSourceModels::Base
 //----------------------------------------------------------------------------
-boost::uuids::detail::sha1::digest_type CurrentSourceModels::Base::getHashDigest() const
+boost::uuids::detail::sha1::digest_type Base::getHashDigest() const
 {
     // Superclass
     boost::uuids::detail::sha1 hash;
@@ -18,9 +22,9 @@ boost::uuids::detail::sha1::digest_type CurrentSourceModels::Base::getHashDigest
     return hash.get_digest();
 }
 //----------------------------------------------------------------------------
-void CurrentSourceModels::Base::validate(const std::unordered_map &paramValues,
-                                         const std::unordered_map &varValues,
-                                         const std::string &description) const
+void Base::validate(const std::unordered_map &paramValues,
+                    const std::unordered_map &varValues,
+                    const std::string &description) const
 {
     // Superclass
     Models::Base::validate(paramValues, varValues, description);
@@ -33,3 +37,4 @@ void CurrentSourceModels &paramValues,
-                                               const std::unordered_map &varValues,
-                                               const std::unordered_map &preVarValues,
-                                               const std::unordered_map &postVarValues,
-                                               const std::unordered_map &varRefTargets,
-                                               const std::unordered_map &preVarRefTargets,
-                                               const std::unordered_map &postVarRefTargets,
-                                               const std::string &description) const
+void Base::validate(const std::unordered_map &paramValues,
+                    const std::unordered_map &varValues,
+                    const std::unordered_map &preVarValues,
+                    const std::unordered_map &postVarValues,
+                    const std::unordered_map &varRefTargets,
+                    const std::unordered_map &preVarRefTargets,
+                    const std::unordered_map &postVarRefTargets,
+                    const std::string &description) const
 {
     // Superclass
     Models::Base::validate(paramValues, varValues, description);
@@ -73,3 +75,4 @@ void CustomConnectivityUpdateModels::Base::validate(const std::unordered_mapgetVarIndex(varName)] = loc;
@@ -173,7 +175,7 @@
 }
 //----------------------------------------------------------------------------
-// CustomUpdateWU
+// GeNN::CustomUpdateWU
 //----------------------------------------------------------------------------
 CustomUpdateWU::CustomUpdateWU(const std::string &name, const std::string &updateGroupName,
                                const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map &params,
@@ -288,3 +290,4 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getInitHashDigest() cons
     Utils::updateHash(getSynapseGroup()->getSparseIndType(), hash);
     return hash.get_digest();
 }
+} // namespace GeNN
\ No newline at end of file
diff --git a/src/genn/genn/customUpdateModels.cc b/src/genn/genn/customUpdateModels.cc
index de178878f9..7422261cb4 100644
--- a/src/genn/genn/customUpdateModels.cc
+++ b/src/genn/genn/customUpdateModels.cc
@@ -1,12 +1,16 @@
 #include "customUpdateModels.h"
 
+using namespace GeNN;
+
+namespace GeNN::CustomUpdateModels
+{
 // Implement models
-IMPLEMENT_SNIPPET(CustomUpdateModels::Transpose);
+IMPLEMENT_SNIPPET(Transpose);
 
 //----------------------------------------------------------------------------
-// CustomUpdateModels::Base
+// GeNN::CustomUpdateModels::Base
 //----------------------------------------------------------------------------
-boost::uuids::detail::sha1::digest_type CustomUpdateModels::Base::getHashDigest() const
+boost::uuids::detail::sha1::digest_type Base::getHashDigest() const
 {
     // Superclass
     boost::uuids::detail::sha1 hash;
@@ -16,3 +20,4 @@ boost::uuids::detail::sha1::digest_type CustomUpdateModels::Base::getHashDigest(
     Utils::updateHash(getVarRefs(), hash);
     return hash.get_digest();
 }
+}
diff --git a/src/genn/genn/gennUtils.cc b/src/genn/genn/gennUtils.cc
index 8cd49a8899..830d78969a 100644
--- a/src/genn/genn/gennUtils.cc
+++ b/src/genn/genn/gennUtils.cc
@@ -37,9 +37,9 @@ GenericFunction randomFuncs[] = {
 }
 
 //--------------------------------------------------------------------------
-// Utils
+// GeNN::Utils
 //--------------------------------------------------------------------------
-namespace Utils
+namespace GeNN::Utils
 {
 bool isRNGRequired(const std::string &code)
 {
@@ -145,4 +145,4 @@ void validateParamNames(const std::vector &paramNames)
         validateVarName(p, "Parameter");
     }
 }
-} // namespace utils
+} // namespace GeNN::Utils
diff --git a/src/genn/genn/initSparseConnectivitySnippet.cc b/src/genn/genn/initSparseConnectivitySnippet.cc
index 6bfe27ae09..9a83167ad6 100644
--- a/src/genn/genn/initSparseConnectivitySnippet.cc
+++ b/src/genn/genn/initSparseConnectivitySnippet.cc
@@ -1,19 +1,23 @@
 #include "initSparseConnectivitySnippet.h"
 
+using namespace GeNN;
+
+namespace GeNN::InitSparseConnectivitySnippet
+{
 // Implement sparse connectivity initialization snippets
-IMPLEMENT_SNIPPET(InitSparseConnectivitySnippet::Uninitialised);
-IMPLEMENT_SNIPPET(InitSparseConnectivitySnippet::OneToOne);
-IMPLEMENT_SNIPPET(InitSparseConnectivitySnippet::FixedProbability);
-IMPLEMENT_SNIPPET(InitSparseConnectivitySnippet::FixedProbabilityNoAutapse);
-IMPLEMENT_SNIPPET(InitSparseConnectivitySnippet::FixedNumberPostWithReplacement);
-IMPLEMENT_SNIPPET(InitSparseConnectivitySnippet::FixedNumberTotalWithReplacement);
-IMPLEMENT_SNIPPET(InitSparseConnectivitySnippet::FixedNumberPreWithReplacement);
-IMPLEMENT_SNIPPET(InitSparseConnectivitySnippet::Conv2D);
+IMPLEMENT_SNIPPET(Uninitialised);
+IMPLEMENT_SNIPPET(OneToOne);
+IMPLEMENT_SNIPPET(FixedProbability);
+IMPLEMENT_SNIPPET(FixedProbabilityNoAutapse);
+IMPLEMENT_SNIPPET(FixedNumberPostWithReplacement);
+IMPLEMENT_SNIPPET(FixedNumberTotalWithReplacement);
+IMPLEMENT_SNIPPET(FixedNumberPreWithReplacement);
+IMPLEMENT_SNIPPET(Conv2D);
 
 //----------------------------------------------------------------------------
-// InitSparseConnectivitySnippet::Base
+// GeNN::InitSparseConnectivitySnippet::Base
 //----------------------------------------------------------------------------
-boost::uuids::detail::sha1::digest_type InitSparseConnectivitySnippet::Base::getHashDigest() const
+boost::uuids::detail::sha1::digest_type Base::getHashDigest() const
 {
     // Superclass
     boost::uuids::detail::sha1 hash;
@@ -27,10 +31,11 @@ boost::uuids::detail::sha1::digest_type InitSparseConnectivitySnippet::Base::get
     return hash.get_digest();
 }
 //----------------------------------------------------------------------------
-void InitSparseConnectivitySnippet::Base::validate(const std::unordered_map &paramValues) const
+void Base::validate(const std::unordered_map &paramValues) const
 {
     // Superclass
     Snippet::Base::validate(paramValues, "Sparse connectivity initialiser ");
     Utils::validateVecNames(getRowBuildStateVars(), "Row building state variable");
     Utils::validateVecNames(getColBuildStateVars(), "Column building state variable");
 }
+} // namespace GeNN::InitSparseConnectivitySnippet
diff --git a/src/genn/genn/initToeplitzConnectivitySnippet.cc b/src/genn/genn/initToeplitzConnectivitySnippet.cc
index e0dd415b95..c07f73766f 100644
--- a/src/genn/genn/initToeplitzConnectivitySnippet.cc
+++ b/src/genn/genn/initToeplitzConnectivitySnippet.cc
@@ -1,14 +1,18 @@
 #include "initToeplitzConnectivitySnippet.h"
 
+using namespace GeNN;
+
+namespace GeNN::InitToeplitzConnectivitySnippet
+{
 // Implement sparse connectivity initialization snippets
-IMPLEMENT_SNIPPET(InitToeplitzConnectivitySnippet::Uninitialised);
-IMPLEMENT_SNIPPET(InitToeplitzConnectivitySnippet::Conv2D);
-IMPLEMENT_SNIPPET(InitToeplitzConnectivitySnippet::AvgPoolConv2D);
+IMPLEMENT_SNIPPET(Uninitialised);
+IMPLEMENT_SNIPPET(Conv2D);
+IMPLEMENT_SNIPPET(AvgPoolConv2D);
 
 //----------------------------------------------------------------------------
-// InitToeplitzConnectivitySnippet::Base
+// GeNN::InitToeplitzConnectivitySnippet::Base
 //----------------------------------------------------------------------------
-boost::uuids::detail::sha1::digest_type InitToeplitzConnectivitySnippet::Base::getHashDigest() const
+boost::uuids::detail::sha1::digest_type Base::getHashDigest() const
 {
     // Superclass
     boost::uuids::detail::sha1 hash;
@@ -19,9 +23,10 @@ boost::uuids::detail::sha1::digest_type InitToeplitzConnectivitySnippet::Base::g
     return hash.get_digest();
 }
 //----------------------------------------------------------------------------
-void InitToeplitzConnectivitySnippet::Base::validate(const std::unordered_map &paramValues) const
+void Base::validate(const std::unordered_map &paramValues) const
 {
     // Superclass
     Snippet::Base::validate(paramValues, "Toeplitz connectivity initialiser ");
     Utils::validateVecNames(getDiagonalBuildStateVars(), "Row building state variable");
 }
+} // namespace GeNN::InitToeplitzConnectivitySnippet
\ No newline at end of file
diff --git a/src/genn/genn/initVarSnippet.cc b/src/genn/genn/initVarSnippet.cc
index 9ceb30595d..3db718a5a7 100644
--- a/src/genn/genn/initVarSnippet.cc
+++ b/src/genn/genn/initVarSnippet.cc
@@ -1,21 +1,25 @@
 #include "initVarSnippet.h"
 
+using namespace GeNN;
+
+namespace GeNN::InitVarSnippet
+{
 // Implement value initialization snippets
-IMPLEMENT_SNIPPET(InitVarSnippet::Uninitialised);
-IMPLEMENT_SNIPPET(InitVarSnippet::Constant);
-IMPLEMENT_SNIPPET(InitVarSnippet::Kernel);
-IMPLEMENT_SNIPPET(InitVarSnippet::Uniform);
-IMPLEMENT_SNIPPET(InitVarSnippet::Normal);
-IMPLEMENT_SNIPPET(InitVarSnippet::NormalClipped);
-IMPLEMENT_SNIPPET(InitVarSnippet::NormalClippedDelay);
-IMPLEMENT_SNIPPET(InitVarSnippet::Exponential);
-IMPLEMENT_SNIPPET(InitVarSnippet::Gamma);
-IMPLEMENT_SNIPPET(InitVarSnippet::Binomial);
+IMPLEMENT_SNIPPET(Uninitialised);
+IMPLEMENT_SNIPPET(Constant);
+IMPLEMENT_SNIPPET(Kernel);
+IMPLEMENT_SNIPPET(Uniform);
+IMPLEMENT_SNIPPET(Normal);
+IMPLEMENT_SNIPPET(NormalClipped);
+IMPLEMENT_SNIPPET(NormalClippedDelay);
+IMPLEMENT_SNIPPET(Exponential);
+IMPLEMENT_SNIPPET(Gamma);
+IMPLEMENT_SNIPPET(Binomial);
 
 //----------------------------------------------------------------------------
-// InitVarSnippet::Base
+// GeNN::InitVarSnippet::Base
 //----------------------------------------------------------------------------
-boost::uuids::detail::sha1::digest_type InitVarSnippet::Base::getHashDigest() const
+boost::uuids::detail::sha1::digest_type Base::getHashDigest() const
 {
     // Superclass
     boost::uuids::detail::sha1 hash;
@@ -25,14 +29,14 @@ boost::uuids::detail::sha1::digest_type InitVarSnippet::Base::getHashDigest() co
     return hash.get_digest();
 }
 //----------------------------------------------------------------------------
-void InitVarSnippet::Base::validate(const std::unordered_map &paramValues) const
+void Base::validate(const std::unordered_map &paramValues) const
 {
     // Superclass
     Snippet::Base::validate(paramValues, "Variable initialiser ");
 }
 //----------------------------------------------------------------------------
-bool InitVarSnippet::Base::requiresKernel() const
+bool Base::requiresKernel() const
 {
     return (getCode().find("$(id_kernel)") != std::string::npos);
 }
-
+} // namespace GeNN::InitVarSnippet
\ No newline at end of file
diff --git a/src/genn/genn/logging.cc b/src/genn/genn/logging.cc
index c2e9d267fb..027ccf9816 100644
--- a/src/genn/genn/logging.cc
+++ b/src/genn/genn/logging.cc
@@ -1,10 +1,10 @@
 #include "logging.h"
 
 //----------------------------------------------------------------------------
-// Logging
+// GeNN::Logging
 //----------------------------------------------------------------------------
-void Logging::init(plog::Severity gennLevel, plog::Severity codeGeneratorLevel,
-                   plog::IAppender *gennAppender, plog::IAppender *codeGeneratorAppender)
+void GeNN::Logging::init(plog::Severity gennLevel, plog::Severity codeGeneratorLevel,
+                         plog::IAppender *gennAppender, plog::IAppender *codeGeneratorAppender)
 {
     // If there isn't already a plog instance, initialise one
     if(plog::get() == nullptr) {
diff --git a/src/genn/genn/modelSpec.cc b/src/genn/genn/modelSpec.cc
index 35d92a179b..295f653cbd 100644
--- a/src/genn/genn/modelSpec.cc
+++ b/src/genn/genn/modelSpec.cc
@@ -30,10 +30,11 @@
 #include "code_generator/codeGenUtils.h"
 #include "code_generator/substitutions.h"
 
-// ------------------------------------------------------------------------
-// ModelSpec
-// ------------------------------------------------------------------------
-// class ModelSpec for specifying a neuronal network model
+// ---------------------------------------------------------------------------
+// GeNN::ModelSpec
+// ---------------------------------------------------------------------------
+namespace GeNN
+{
 ModelSpec::ModelSpec()
 :   m_TimePrecision(TimePrecision::DEFAULT), m_DT(0.5), m_TimingEnabled(false), m_Seed(0),
     m_DefaultVarLocation(VarLocation::HOST_DEVICE), m_DefaultExtraGlobalParamLocation(VarLocation::HOST_DEVICE),
@@ -42,11 +43,11 @@ ModelSpec::ModelSpec()
 {
     setPrecision(ScalarPrecision::FLOAT);
 }
-
+// ---------------------------------------------------------------------------
 ModelSpec::~ModelSpec()
 {
 }
-
+// ---------------------------------------------------------------------------
 std::string ModelSpec::getTimePrecision() const
 {
     // If time precision is set to match model precision
@@ -61,7 +62,7 @@ std::string ModelSpec::getTimePrecision() const
         return "double";
     }
 }
-
+// ---------------------------------------------------------------------------
 unsigned int ModelSpec::getNumNeurons() const
 {
     // Return sum of local neuron group sizes
@@ -71,7 +72,7 @@ unsigned int ModelSpec::getNumNeurons() const
         return total + n.second.getNumNeurons();
     });
 }
-
+// ---------------------------------------------------------------------------
 NeuronGroup *ModelSpec::addNeuronPopulation(const std::string &name, unsigned int size, const NeuronModels::Base *model,
                                             const ParamValues &paramValues, const VarValues &varInitialisers)
 {
@@ -89,7 +90,7 @@ NeuronGroup *ModelSpec::addNeuronPopulation(const std::string &name, unsigned in
         return &result.first->second;
     }
 }
-
+// ---------------------------------------------------------------------------
 SynapseGroup *ModelSpec::findSynapseGroup(const std::string &name)
 {
     // If a matching local synapse group is found, return it
@@ -102,10 +103,7 @@ SynapseGroup *ModelSpec::findSynapseGroup(const std::string &name)
         throw std::runtime_error("synapse group " + name + " not found, aborting ...");
     }
 }
-
-//--------------------------------------------------------------------------
-/*! \brief This function attempts to find an existing current source */
-//--------------------------------------------------------------------------
+// ---------------------------------------------------------------------------
 CurrentSource *ModelSpec::findCurrentSource(const std::string &name)
 {
     // If a matching local current source is found, return it
@@ -118,7 +116,7 @@ CurrentSource *ModelSpec::findCurrentSource(const std::string &name)
         throw std::runtime_error("current source " + name + " not found, aborting ...");
     }
 }
-
+// ---------------------------------------------------------------------------
 CurrentSource *ModelSpec::addCurrentSource(const std::string &currentSourceName, const CurrentSourceModels::Base *model,
                                            const std::string &targetNeuronGroupName, const ParamValues &paramValues, const VarValues &varInitialisers)
 {
@@ -139,8 +137,7 @@ CurrentSource *ModelSpec::addCurrentSource(const std::string &currentSourceName,
         return &result.first->second;
     }
 }
-
-
+// ---------------------------------------------------------------------------
 CustomUpdate *ModelSpec::addCustomUpdate(const std::string &name, const std::string &updateGroupName,
                                          const CustomUpdateModels::Base *model, const ParamValues &paramValues,
                                          const VarValues &varInitialisers, const VarReferences &varReferences)
 {
@@ -159,7 +156,7 @@ CustomUpdate *ModelSpec::addCustomUpdate(const std::string &name, const std::str
         return &result.first->second;
     }
 }
-
+// ---------------------------------------------------------------------------
 CustomConnectivityUpdate *ModelSpec::addCustomConnectivityUpdate(const std::string &name, const std::string &updateGroupName,
                                                                  const std::string &targetSynapseGroupName, const CustomConnectivityUpdateModels::Base *model,
                                                                  const ParamValues &paramValues, const VarValues &varInitialisers,
@@ -185,7 +182,7 @@ CustomConnectivityUpdate *ModelSpec::addCustomConnectivityUpdate(const std::stri
         return &result.first->second;
     }
 }
-
+// ---------------------------------------------------------------------------
 CustomUpdateWU *ModelSpec::addCustomUpdate(const std::string &name, const std::string &updateGroupName,
                                            const CustomUpdateModels::Base *model, const ParamValues &paramValues,
                                            const VarValues &varInitialisers, const WUVarReferences &varReferences)
 {
@@ -204,11 +201,7 @@ CustomUpdateWU *ModelSpec::addCustomUpdate(const std::string &name, const std::s
return &result.first->second; } } - -//-------------------------------------------------------------------------- -/*! \brief This function sets the numerical precision of floating type variables. By default, it is ScalarPrecision::FLOAT - */ -//-------------------------------------------------------------------------- +// --------------------------------------------------------------------------- void ModelSpec::setPrecision(ScalarPrecision scalarPrecision) { switch (scalarPrecision) { @@ -223,8 +216,7 @@ void ModelSpec::setPrecision(ScalarPrecision scalarPrecision) break; } } - - +// --------------------------------------------------------------------------- void ModelSpec::finalize() { // NEURON GROUPS @@ -329,7 +321,7 @@ void ModelSpec::finalize() } } } - +// --------------------------------------------------------------------------- std::string ModelSpec::scalarExpr(double val) const { if (m_Precision == "float") { @@ -342,8 +334,7 @@ std::string ModelSpec::scalarExpr(double val) const throw std::runtime_error("Unrecognised floating-point type."); } } - - +// --------------------------------------------------------------------------- bool ModelSpec::zeroCopyInUse() const { // If any neuron groups use zero copy return true @@ -390,13 +381,13 @@ bool ModelSpec::zeroCopyInUse() const return false; } - +// --------------------------------------------------------------------------- bool ModelSpec::isRecordingInUse() const { return std::any_of(m_LocalNeuronGroups.cbegin(), m_LocalNeuronGroups.cend(), [](const NeuronGroupValueType &n) { return n.second.isRecordingEnabled(); }); } - +// --------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type ModelSpec::getHashDigest() const { boost::uuids::detail::sha1 hash; @@ -411,7 +402,7 @@ boost::uuids::detail::sha1::digest_type ModelSpec::getHashDigest() const return hash.get_digest(); } - +// --------------------------------------------------------------------------- NeuronGroupInternal *ModelSpec::findNeuronGroupInternal(const std::string &name) { // If a matching local neuron group is found, return it @@ -424,7 +415,7 @@ NeuronGroupInternal *ModelSpec::findNeuronGroupInternal(const std::string &name) throw std::runtime_error("neuron group " + name + " not found, aborting ..."); } } - +// --------------------------------------------------------------------------- SynapseGroupInternal *ModelSpec::findSynapseGroupInternal(const std::string &name) { // If a matching local synapse group is found, return it @@ -437,32 +428,33 @@ SynapseGroupInternal *ModelSpec::findSynapseGroupInternal(const std::string &nam throw std::runtime_error("synapse group " + name + " not found, aborting ..."); } } - +// --------------------------------------------------------------------------- SynapseGroup *ModelSpec::addSynapsePopulation(const std::string &name, SynapseMatrixType mtype, unsigned int delaySteps, const std::string& src, const std::string& trg, const WeightUpdateModels::Base *wum, const ParamValues &weightParamValues, const VarValues &weightVarInitialisers, const VarValues &weightPreVarInitialisers, const VarValues &weightPostVarInitialisers, const PostsynapticModels::Base *psm, const ParamValues &postsynapticParamValues, const VarValues &postsynapticVarInitialisers, const InitSparseConnectivitySnippet::Init &connectivityInitialiser, const InitToeplitzConnectivitySnippet::Init &toeplitzConnectivityInitialiser) { - // Get source and target neuron groups - auto srcNeuronGrp = findNeuronGroupInternal(src); - 
auto trgNeuronGrp = findNeuronGroupInternal(trg); - - // Add synapse group to map - auto result = m_LocalSynapseGroups.emplace( - std::piecewise_construct, - std::forward_as_tuple(name), - std::forward_as_tuple(name, mtype, delaySteps, - wum, weightParamValues, weightVarInitialisers, weightPreVarInitialisers, weightPostVarInitialisers, - psm, postsynapticParamValues, postsynapticVarInitialisers, - srcNeuronGrp, trgNeuronGrp, - connectivityInitialiser, toeplitzConnectivityInitialiser, - m_DefaultVarLocation, m_DefaultExtraGlobalParamLocation, - m_DefaultSparseConnectivityLocation, m_DefaultNarrowSparseIndEnabled)); - - if(!result.second) { - throw std::runtime_error("Cannot add a synapse population with duplicate name:" + name); - } - else { - return &result.first->second; - } + // Get source and target neuron groups + auto srcNeuronGrp = findNeuronGroupInternal(src); + auto trgNeuronGrp = findNeuronGroupInternal(trg); + + // Add synapse group to map + auto result = m_LocalSynapseGroups.emplace( + std::piecewise_construct, + std::forward_as_tuple(name), + std::forward_as_tuple(name, mtype, delaySteps, + wum, weightParamValues, weightVarInitialisers, weightPreVarInitialisers, weightPostVarInitialisers, + psm, postsynapticParamValues, postsynapticVarInitialisers, + srcNeuronGrp, trgNeuronGrp, + connectivityInitialiser, toeplitzConnectivityInitialiser, + m_DefaultVarLocation, m_DefaultExtraGlobalParamLocation, + m_DefaultSparseConnectivityLocation, m_DefaultNarrowSparseIndEnabled)); + + if(!result.second) { + throw std::runtime_error("Cannot add a synapse population with duplicate name:" + name); } + else { + return &result.first->second; + } +} +} // namespace GeNN \ No newline at end of file diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 7fd1d4d88f..6ad949f3c6 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -7,11 +7,11 @@ #include "neuronGroupInternal.h" #include "synapseGroupInternal.h" -using namespace Models; - //---------------------------------------------------------------------------- -// Models::Base +// GeNN::Models::Base //---------------------------------------------------------------------------- +namespace GeNN::Models +{ void Base::updateHash(boost::uuids::detail::sha1 &hash) const { // Superclass @@ -203,28 +203,31 @@ SynapseGroup *WUVarReference::getTransposeSynapseGroup() const { return m_TransposeSG; } + +//---------------------------------------------------------------------------- +// Free functions //---------------------------------------------------------------------------- -void Models::updateHash(const Base::Var &v, boost::uuids::detail::sha1 &hash) +void updateHash(const Base::Var &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.name, hash); Utils::updateHash(v.type, hash); Utils::updateHash(v.access, hash); } //---------------------------------------------------------------------------- -void Models::updateHash(const Base::VarRef &v, boost::uuids::detail::sha1 &hash) +void updateHash(const Base::VarRef &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.name, hash); Utils::updateHash(v.type, hash); Utils::updateHash(v.access, hash); } //---------------------------------------------------------------------------- -void Models::updateHash(const VarReference &v, boost::uuids::detail::sha1 &hash) +void updateHash(const VarReference &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.getTargetName(), hash); Utils::updateHash(v.getVarIndex(), hash); } 
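The anonymous-namespace helpers in this commit are handled in two different ways, both visible in this patch: presynapticUpdateStrategySIMT.cc earlier dropped its using-directives and fully qualifies every GeNN type instead, while neuronGroup.cc, in the hunks just below, keeps a file-level using namespace GeNN; above its helpers. A sketch of the two styles, with hypothetical helper names and a forward declaration standing in for the real headers:

    namespace GeNN { class SynapseGroupInternal; }

    // Style 1: no using-directive, every type fully qualified
    namespace
    {
    bool helperA(const GeNN::SynapseGroupInternal &sg);
    }

    // Style 2: a file-level using-directive, then plain names
    using namespace GeNN;
    namespace
    {
    bool helperB(const SynapseGroupInternal &sg);
    }

Style 1 avoids any reliance on using-directives at the cost of longer names; style 2 keeps the resulting diff small, which is presumably why most files here use it.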
//---------------------------------------------------------------------------- -void Models::updateHash(const WUVarReference &v, boost::uuids::detail::sha1 &hash) +void updateHash(const WUVarReference &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.getTargetName(), hash); Utils::updateHash(v.getVarIndex(), hash); @@ -234,3 +237,4 @@ void Models::updateHash(const WUVarReference &v, boost::uuids::detail::sha1 &has Utils::updateHash(v.getTransposeVarIndex(), hash); } } +} // namespace GeNN::Models \ No newline at end of file diff --git a/src/genn/genn/neuronGroup.cc b/src/genn/genn/neuronGroup.cc index b54e35d26f..98c2b65e23 100644 --- a/src/genn/genn/neuronGroup.cc +++ b/src/genn/genn/neuronGroup.cc @@ -3,6 +3,7 @@ // Standard includes #include #include +#include // GeNN includes #include "currentSourceInternal.h" @@ -11,6 +12,8 @@ #include "synapseGroupInternal.h" #include "gennUtils.h" +using namespace GeNN; + // ------------------------------------------------------------------------ // Anonymous namespace // ------------------------------------------------------------------------ @@ -45,7 +48,7 @@ void fuseSynapseGroups(const std::vector &unmergedSyn, bo // Loop through un-merged synapse groups for(unsigned int i = 0; !syn.empty(); i++) { // Remove last element from vector - SynapseGroupInternal *a = syn.back(); + auto *a = syn.back(); syn.pop_back(); // Add A to vector of merged groups @@ -99,8 +102,10 @@ void fuseSynapseGroups(const std::vector &unmergedSyn, bo } // Anonymous namespace // ------------------------------------------------------------------------ -// NeuronGroup +// GeNN::NeuronGroup // ------------------------------------------------------------------------ +namespace GeNN +{ void NeuronGroup::setVarLocation(const std::string &varName, VarLocation loc) { m_VarLocation.at(getNeuronModel()->getVarIndex(varName)) = loc; @@ -589,4 +594,4 @@ void NeuronGroup::updateVarQueues(const std::string &code, const std::string &su } } } - +} // namespace GeNN diff --git a/src/genn/genn/neuronModels.cc b/src/genn/genn/neuronModels.cc index d11a01ace5..2a187e8395 100644 --- a/src/genn/genn/neuronModels.cc +++ b/src/genn/genn/neuronModels.cc @@ -1,23 +1,27 @@ #include "neuronModels.h" +using namespace GeNN; + +namespace GeNN::NeuronModels +{ // Implement models -IMPLEMENT_SNIPPET(NeuronModels::RulkovMap); -IMPLEMENT_SNIPPET(NeuronModels::Izhikevich); -IMPLEMENT_SNIPPET(NeuronModels::IzhikevichVariable); -IMPLEMENT_SNIPPET(NeuronModels::LIF); -IMPLEMENT_SNIPPET(NeuronModels::SpikeSource); -IMPLEMENT_SNIPPET(NeuronModels::SpikeSourceArray); -IMPLEMENT_SNIPPET(NeuronModels::Poisson); -IMPLEMENT_SNIPPET(NeuronModels::PoissonNew); -IMPLEMENT_SNIPPET(NeuronModels::TraubMiles); -IMPLEMENT_SNIPPET(NeuronModels::TraubMilesFast); -IMPLEMENT_SNIPPET(NeuronModels::TraubMilesAlt); -IMPLEMENT_SNIPPET(NeuronModels::TraubMilesNStep); +IMPLEMENT_SNIPPET(RulkovMap); +IMPLEMENT_SNIPPET(Izhikevich); +IMPLEMENT_SNIPPET(IzhikevichVariable); +IMPLEMENT_SNIPPET(LIF); +IMPLEMENT_SNIPPET(SpikeSource); +IMPLEMENT_SNIPPET(SpikeSourceArray); +IMPLEMENT_SNIPPET(Poisson); +IMPLEMENT_SNIPPET(PoissonNew); +IMPLEMENT_SNIPPET(TraubMiles); +IMPLEMENT_SNIPPET(TraubMilesFast); +IMPLEMENT_SNIPPET(TraubMilesAlt); +IMPLEMENT_SNIPPET(TraubMilesNStep); //---------------------------------------------------------------------------- -// NeuronModels::Base +// GeNN::NeuronModels::Base //---------------------------------------------------------------------------- -boost::uuids::detail::sha1::digest_type 
NeuronModels::Base::getHashDigest() const
+boost::uuids::detail::sha1::digest_type Base::getHashDigest() const
 {
     // Superclass
     boost::uuids::detail::sha1 hash;
@@ -32,9 +36,9 @@ boost::uuids::detail::sha1::digest_type NeuronModels::Base::getHashDigest() cons
     return hash.get_digest();
 }
 //----------------------------------------------------------------------------
-void NeuronModels::Base::validate(const std::unordered_map &paramValues,
-                                  const std::unordered_map &varValues,
-                                  const std::string &description) const
+void Base::validate(const std::unordered_map &paramValues,
+                    const std::unordered_map &varValues,
+                    const std::string &description) const
 {
     // Superclass
     Models::Base::validate(paramValues, varValues, description);
@@ -49,3 +53,4 @@ void NeuronModels::Base::validate(const std::unordered_map
         throw std::runtime_error("Neuron models cannot include variables with REDUCE access modes - they are only supported by custom update models");
     }
 }
+} // namespace GeNN::NeuronModels
diff --git a/src/genn/genn/postsynapticModels.cc b/src/genn/genn/postsynapticModels.cc
index 8e85bfe666..7f54f1a280 100644
--- a/src/genn/genn/postsynapticModels.cc
+++ b/src/genn/genn/postsynapticModels.cc
@@ -1,15 +1,18 @@
 #include "postsynapticModels.h"
 
-// Implement models
-IMPLEMENT_SNIPPET(PostsynapticModels::ExpCurr);
-IMPLEMENT_SNIPPET(PostsynapticModels::ExpCond);
-IMPLEMENT_SNIPPET(PostsynapticModels::DeltaCurr);
+using namespace GeNN;
 
+namespace GeNN::PostsynapticModels
+{
+// Implement models
+IMPLEMENT_SNIPPET(ExpCurr);
+IMPLEMENT_SNIPPET(ExpCond);
+IMPLEMENT_SNIPPET(DeltaCurr);
 //----------------------------------------------------------------------------
-// PostsynapticModels::Base
+// GeNN::PostsynapticModels::Base
 //----------------------------------------------------------------------------
-boost::uuids::detail::sha1::digest_type PostsynapticModels::Base::getHashDigest() const
+boost::uuids::detail::sha1::digest_type Base::getHashDigest() const
 {
     // Superclass
     boost::uuids::detail::sha1 hash;
@@ -21,9 +24,9 @@ boost::uuids::detail::sha1::digest_type PostsynapticModels::Base::getHashDigest(
     return hash.get_digest();
 }
 //----------------------------------------------------------------------------
-void PostsynapticModels::Base::validate(const std::unordered_map &paramValues,
-                                        const std::unordered_map &varValues,
-                                        const std::string &description) const
+void Base::validate(const std::unordered_map &paramValues,
+                    const std::unordered_map &varValues,
+                    const std::string &description) const
 {
     // Superclass
     Models::Base::validate(paramValues, varValues, description);
@@ -35,3 +38,4 @@ void PostsynapticModels::Base::validate(const std::unordered_map &paramValues, const std::string &description) const
+void Base::validate(const std::unordered_map &paramValues, const std::string &description) const
 {
     const auto paramNames = getParamNames();
     Utils::validateParamNames(paramNames);
@@ -29,4 +31,5 @@ void Snippet::Base::validate(const std::unordered_map &para
         throw std::runtime_error(description + " missing value for parameter: '" + n + "'");
     }
 }
-}
\ No newline at end of file
+}
+} // namespace GeNN::Snippet
\ No newline at end of file
diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc
index 8cbf73bf98..02d3a56c6a 100644
--- a/src/genn/genn/synapseGroup.cc
+++ b/src/genn/genn/synapseGroup.cc
@@ -16,7 +16,7 @@
 //----------------------------------------------------------------------------
 namespace
 {
-std::unordered_map getConstInitVals(const std::unordered_map &varInitialisers)
 {
     // Reserve initial values to match initialisers
     std::unordered_map initVals;
@@ -26,7 +26,7 @@ std::unordered_map getConstInitVals(const std::unordered_ma
                   [](const auto &v)
                   {
                       // Check
-                      if(dynamic_cast(v.second.getSnippet()) == nullptr) {
+                      if(dynamic_cast(v.second.getSnippet()) == nullptr) {
                           throw std::runtime_error("Only 'Constant' variable initialisation snippets can be used to initialise state variables of synapse groups using GLOBALG");
                       }
@@ -39,8 +39,10 @@ std::unordered_map getConstInitVals(const std::unordered_ma
 }   // Anonymous namespace
 
 // ------------------------------------------------------------------------
-// SynapseGroup
+// GeNN::SynapseGroup
 // ------------------------------------------------------------------------
+namespace GeNN
+{
 void SynapseGroup::setWUVarLocation(const std::string &varName, VarLocation loc)
 {
     m_WUVarLocation[getWUModel()->getVarIndex(varName)] = loc;
 }
@@ -489,7 +491,7 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType
     }
     // Otherwise, if WEIGHTS are procedural e.g. in the case of DENSE_PROCEDURALG, give error if RNG is required for weights
     else if(m_MatrixType & SynapseMatrixWeight::PROCEDURAL) {
-        if(::Utils::isRNGRequired(m_WUVarInitialisers)) {
+        if(Utils::isRNGRequired(m_WUVarInitialisers)) {
             throw std::runtime_error("Procedural weights used without procedural connectivity cannot currently access RNG.");
         }
     }
@@ -507,7 +509,7 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType
     }
 
     // Give an error if connectivity initialisation snippet uses RNG
-    if(::Utils::isRNGRequired(m_ToeplitzConnectivityInitialiser.getSnippet()->getDiagonalBuildCode())) {
+    if(Utils::isRNGRequired(m_ToeplitzConnectivityInitialiser.getSnippet()->getDiagonalBuildCode())) {
         throw std::runtime_error("TOEPLITZ connectivity cannot currently access RNG.");
     }
@@ -990,3 +992,4 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getVarLocationHashDigest()
     Utils::updateHash(m_PSExtraGlobalParamLocation, hash);
     return hash.get_digest();
 }
+} // namespace GeNN
diff --git a/src/genn/genn/weightUpdateModels.cc b/src/genn/genn/weightUpdateModels.cc
index b622369fd3..619fec2a67 100644
--- a/src/genn/genn/weightUpdateModels.cc
+++ b/src/genn/genn/weightUpdateModels.cc
@@ -1,14 +1,18 @@
 #include "weightUpdateModels.h"
 
-IMPLEMENT_SNIPPET(WeightUpdateModels::StaticPulse);
-IMPLEMENT_SNIPPET(WeightUpdateModels::StaticPulseDendriticDelay);
-IMPLEMENT_SNIPPET(WeightUpdateModels::StaticGraded);
-IMPLEMENT_SNIPPET(WeightUpdateModels::PiecewiseSTDP);
+using namespace GeNN;
+
+namespace GeNN::WeightUpdateModels
+{
+IMPLEMENT_SNIPPET(StaticPulse);
+IMPLEMENT_SNIPPET(StaticPulseDendriticDelay);
+IMPLEMENT_SNIPPET(StaticGraded);
+IMPLEMENT_SNIPPET(PiecewiseSTDP);
 
 //----------------------------------------------------------------------------
-// WeightUpdateModels::Base
+// GeNN::WeightUpdateModels::Base
 //----------------------------------------------------------------------------
-boost::uuids::detail::sha1::digest_type WeightUpdateModels::Base::getHashDigest() const
+boost::uuids::detail::sha1::digest_type Base::getHashDigest() const
 {
     // Superclass
     boost::uuids::detail::sha1 hash;
@@ -37,11 +41,11 @@ boost::uuids::detail::sha1::digest_type WeightUpdateModels::Base::getHashDigest(
     return hash.get_digest();
 }
 //----------------------------------------------------------------------------
-void WeightUpdateModels::Base::validate(const std::unordered_map &paramValues,
-                                        const std::unordered_map &varValues,
-                                        const std::unordered_map &preVarValues,
-                                        const std::unordered_map &postVarValues,
-                                        const std::string &description) const
+void Base::validate(const std::unordered_map &paramValues,
+                    const std::unordered_map &varValues,
+                    const std::unordered_map &preVarValues,
+                    const std::unordered_map &postVarValues,
+                    const std::string &description) const
 {
     // Superclass
     Models::Base::validate(paramValues, varValues, description);
@@ -76,3 +80,4 @@ void WeightUpdateModels::Base::validate(const std::unordered_map Date: Thu, 5 Jan 2023 17:14:42 +0000
Subject: [PATCH 002/725] fixed backend issues

---
 include/genn/backends/opencl/backend.h       |  4 +--
 include/genn/backends/opencl/optimiser.h     |  4 +++
 .../backends/single_threaded_cpu/optimiser.h |  4 +++
 src/genn/backends/opencl/backend.cc          | 30 +++++++++----------
 src/genn/backends/opencl/optimiser.cc        |  2 ++
 5 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/include/genn/backends/opencl/backend.h b/include/genn/backends/opencl/backend.h
index 3a8265d359..ad52959179 100644
--- a/include/genn/backends/opencl/backend.h
+++ b/include/genn/backends/opencl/backend.h
@@ -365,7 +365,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT
             const auto sortedFields = g.getSortedFields(*this);
             for(size_t fieldIndex = 0; fieldIndex < sortedFields.size(); fieldIndex++) {
                 const auto &f = sortedFields[fieldIndex];
-                if(::Utils::isTypePointer(std::get<0>(f))) {
+                if(GeNN::Utils::isTypePointer(std::get<0>(f))) {
                     os << "__global ";
                 }
                 os << std::get<0>(f) << " " << std::get<1>(f);
@@ -394,7 +394,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT
             os << "__kernel void setMerged" << T::name << f.mergedGroupIndex << f.fieldName << "Kernel(";
             os << "__global struct Merged" << T::name << "Group" << f.mergedGroupIndex << " *group, unsigned int idx, ";
-            if(::Utils::isTypePointer(f.type)) {
+            if(GeNN::Utils::isTypePointer(f.type)) {
                 os << "__global ";
             }
             os << f.type << " " << f.fieldName << ")";
diff --git a/include/genn/backends/opencl/optimiser.h b/include/genn/backends/opencl/optimiser.h
index b5611c35f4..7168cf722e 100644
--- a/include/genn/backends/opencl/optimiser.h
+++ b/include/genn/backends/opencl/optimiser.h
@@ -7,7 +7,11 @@
 #include "backend.h"
 
 // Forward declarations
+namespace GeNN
+{
 class ModelSpecInternal;
+}
+
 namespace plog
 {
 class IAppender;
diff --git a/include/genn/backends/single_threaded_cpu/optimiser.h b/include/genn/backends/single_threaded_cpu/optimiser.h
index 119eeb1a36..570642a3f5 100644
--- a/include/genn/backends/single_threaded_cpu/optimiser.h
+++ b/include/genn/backends/single_threaded_cpu/optimiser.h
@@ -10,7 +10,11 @@
 #include "backend.h"
 
 // Forward declarations
+namespace GeNN
+{
 class ModelSpecInternal;
+}
+
 namespace plog
 {
 class IAppender;
diff --git a/src/genn/backends/opencl/backend.cc b/src/genn/backends/opencl/backend.cc
index 1287af6493..d383f4da60 100644
--- a/src/genn/backends/opencl/backend.cc
+++ b/src/genn/backends/opencl/backend.cc
@@ -1934,7 +1934,7 @@ void Backend::genExtraGlobalParamDefinition(CodeStream &definitions, CodeStream
         definitions << "EXPORT_VAR " << type << " " << name << ";" << std::endl;
         definitionsInternal << "EXPORT_VAR cl::Buffer h_" << name << ";" << std::endl;
     }
-    if (loc & VarLocation::DEVICE && ::Utils::isTypePointer(type)) {
+    if (loc & VarLocation::DEVICE && GeNN::Utils::isTypePointer(type)) {
        definitionsInternal << "EXPORT_VAR cl::Buffer d_" << name << ";" << std::endl;
     }
 }
@@ -1945,7 +1945,8 @@ void Backend::genExtraGlobalParamImplementation(CodeStream &os, const
std::strin os << type << " " << name << ";" << std::endl; os << "cl::Buffer h_" << name << ";" << std::endl; } - if (loc & VarLocation::DEVICE && ::Utils::isTypePointer(type)) { + if (loc & VarLocation::DEVICE && GeNN::Utils::isTypePointer(type)) { os << "cl::Buffer d_" << name << ";" << std::endl; } } @@ -1954,8 +1954,8 @@ void Backend::genExtraGlobalParamAllocation(CodeStream &os, const std::string &t VarLocation loc, const std::string &countVarName, const std::string &prefix) const { // Get underlying type - const std::string underlyingType = ::Utils::getUnderlyingType(type); - const bool pointerToPointer = ::Utils::isTypePointerToPointer(type); + const std::string underlyingType = GeNN::Utils::getUnderlyingType(type); + const bool pointerToPointer = GeNN::Utils::isTypePointerToPointer(type); const std::string hostPointer = pointerToPointer ? ("*" + prefix + name) : (prefix + name); const std::string deviceBuffer = pointerToPointer ? ("*" + prefix + "d_" + name) : (prefix + "d_" + name); @@ -2008,8 +2008,8 @@ void Backend::genExtraGlobalParamPush(CodeStream &os, const std::string &type, c assert(!getPreferences().automaticCopy); // Get underlying type - const std::string underlyingType = ::Utils::getUnderlyingType(type); - const bool pointerToPointer = ::Utils::isTypePointerToPointer(type); + const std::string underlyingType = GeNN::Utils::getUnderlyingType(type); + const bool pointerToPointer = GeNN::Utils::isTypePointerToPointer(type); const std::string hostPointer = pointerToPointer ? ("*" + prefix + name) : (prefix + name); const std::string devicePointer = pointerToPointer ? ("*" + prefix + "d_" + name) : (prefix + "d_" + name); @@ -2029,8 +2029,8 @@ void Backend::genExtraGlobalParamPull(CodeStream &os, const std::string &type, c assert(!getPreferences().automaticCopy); // Get underlying type - const std::string underlyingType = ::Utils::getUnderlyingType(type); - const bool pointerToPointer = ::Utils::isTypePointerToPointer(type); + const std::string underlyingType = GeNN::Utils::getUnderlyingType(type); + const bool pointerToPointer = GeNN::Utils::isTypePointerToPointer(type); const std::string hostPointer = pointerToPointer ? ("*" + prefix + name) : (prefix + name); const std::string devicePointer = pointerToPointer ? 
("*" + prefix + "d_" + name) : (prefix + "d_" + name); @@ -2060,10 +2060,10 @@ void Backend::genMergedExtraGlobalParamPush(CodeStream &os, const std::string &s std::string Backend::getMergedGroupFieldHostType(const std::string &type) const { // If type is a pointer, on the host it is represented by an OpenCL buffer - if(::Utils::isTypePointerToPointer(type)) { + if(GeNN::Utils::isTypePointerToPointer(type)) { return "cl::Buffer*"; } - if(::Utils::isTypePointer(type)) { + else if(GeNN::Utils::isTypePointer(type)) { return "cl::Buffer"; } // Otherwise, type remains the same @@ -2407,13 +2407,13 @@ boost::uuids::detail::sha1::digest_type Backend::getHashDigest() const boost::uuids::detail::sha1 hash; // Update hash was name of backend - ::Utils::updateHash("OpenCL", hash); + GeNN::Utils::updateHash("OpenCL", hash); // Update hash with chosen device ID and kernel block sizes - ::Utils::updateHash(m_ChosenPlatformIndex, hash); - ::Utils::updateHash(m_ChosenDeviceIndex, hash); - ::Utils::updateHash(m_AllocationAlignementBytes, hash); - ::Utils::updateHash(getKernelBlockSize(), hash); + GeNN::Utils::updateHash(m_ChosenPlatformIndex, hash); + GeNN::Utils::updateHash(m_ChosenDeviceIndex, hash); + GeNN::Utils::updateHash(m_AllocationAlignementBytes, hash); + GeNN::Utils::updateHash(getKernelBlockSize(), hash); // Update hash with preferences getPreferences().updateHash(hash); diff --git a/src/genn/backends/opencl/optimiser.cc b/src/genn/backends/opencl/optimiser.cc index 811b8164c2..9ab7a4add0 100644 --- a/src/genn/backends/opencl/optimiser.cc +++ b/src/genn/backends/opencl/optimiser.cc @@ -7,6 +7,8 @@ #include "logging.h" #include "modelSpecInternal.h" +using namespace GeNN; + //-------------------------------------------------------------------------- // Anonymous namespace //-------------------------------------------------------------------------- From ce17b3938cdd2a25b341422e0967e0cd1c7d209d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 5 Jan 2023 17:14:48 +0000 Subject: [PATCH 003/725] fixed pygenn --- pygenn/src/cudaBackend.cc | 4 ++-- pygenn/src/currentSourceModels.cc | 2 +- pygenn/src/customUpdateModels.cc | 2 +- pygenn/src/genn.cc | 1 + pygenn/src/initSparseConnectivitySnippets.cc | 2 +- pygenn/src/initToeplitzConnectivitySnippets.cc | 2 +- pygenn/src/initVarSnippets.cc | 2 +- pygenn/src/neuronModels.cc | 2 +- pygenn/src/openclBackend.cc | 4 ++-- pygenn/src/postsynapticModels.cc | 2 +- pygenn/src/singleThreadedCPUBackend.cc | 4 ++-- pygenn/src/weightUpdateModels.cc | 2 +- 12 files changed, 15 insertions(+), 14 deletions(-) diff --git a/pygenn/src/cudaBackend.cc b/pygenn/src/cudaBackend.cc index 80567f2182..0e720b166d 100644 --- a/pygenn/src/cudaBackend.cc +++ b/pygenn/src/cudaBackend.cc @@ -11,8 +11,8 @@ // CUDA backend includes #include "optimiser.h" - -using namespace CodeGenerator::CUDA; +using namespace GeNN; +using namespace GeNN::CodeGenerator::CUDA; //---------------------------------------------------------------------------- // Anonymous namespace diff --git a/pygenn/src/currentSourceModels.cc b/pygenn/src/currentSourceModels.cc index 4a117ea61a..fee17cf636 100644 --- a/pygenn/src/currentSourceModels.cc +++ b/pygenn/src/currentSourceModels.cc @@ -5,7 +5,7 @@ // GeNN includes #include "currentSourceModels.h" -using namespace CurrentSourceModels; +using namespace GeNN::CurrentSourceModels; namespace { diff --git a/pygenn/src/customUpdateModels.cc b/pygenn/src/customUpdateModels.cc index eb0602adef..345fd7bffc 100644 --- a/pygenn/src/customUpdateModels.cc +++ 
b/pygenn/src/customUpdateModels.cc @@ -5,7 +5,7 @@ // GeNN includes #include "customUpdateModels.h" -using namespace CustomUpdateModels; +using namespace GeNN::CustomUpdateModels; namespace { diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc index bf36d19988..7615690dce 100644 --- a/pygenn/src/genn.cc +++ b/pygenn/src/genn.cc @@ -28,6 +28,7 @@ #include "code_generator/generateModules.h" #include "code_generator/generateMSBuild.h" +using namespace GeNN; using namespace pybind11::literals; //---------------------------------------------------------------------------- diff --git a/pygenn/src/initSparseConnectivitySnippets.cc b/pygenn/src/initSparseConnectivitySnippets.cc index 49e5800b2f..d3e4427a16 100644 --- a/pygenn/src/initSparseConnectivitySnippets.cc +++ b/pygenn/src/initSparseConnectivitySnippets.cc @@ -5,7 +5,7 @@ // GeNN includes #include "initSparseConnectivitySnippet.h" -using namespace InitSparseConnectivitySnippet; +using namespace GeNN::InitSparseConnectivitySnippet; namespace { diff --git a/pygenn/src/initToeplitzConnectivitySnippets.cc b/pygenn/src/initToeplitzConnectivitySnippets.cc index f55d092866..d6b9d291c3 100644 --- a/pygenn/src/initToeplitzConnectivitySnippets.cc +++ b/pygenn/src/initToeplitzConnectivitySnippets.cc @@ -5,7 +5,7 @@ // GeNN includes #include "initToeplitzConnectivitySnippet.h" -using namespace InitToeplitzConnectivitySnippet; +using namespace GeNN::InitToeplitzConnectivitySnippet; namespace { diff --git a/pygenn/src/initVarSnippets.cc b/pygenn/src/initVarSnippets.cc index 61c602e9ee..b923108ada 100644 --- a/pygenn/src/initVarSnippets.cc +++ b/pygenn/src/initVarSnippets.cc @@ -5,7 +5,7 @@ // GeNN includes #include "initVarSnippet.h" -using namespace InitVarSnippet; +using namespace GeNN::InitVarSnippet; namespace { diff --git a/pygenn/src/neuronModels.cc b/pygenn/src/neuronModels.cc index ebd862779b..d3f66260b1 100644 --- a/pygenn/src/neuronModels.cc +++ b/pygenn/src/neuronModels.cc @@ -5,7 +5,7 @@ // GeNN includes #include "neuronModels.h" -using namespace NeuronModels; +using namespace GeNN::NeuronModels; namespace { diff --git a/pygenn/src/openclBackend.cc b/pygenn/src/openclBackend.cc index 48fd7c3ebc..b6ea55d089 100644 --- a/pygenn/src/openclBackend.cc +++ b/pygenn/src/openclBackend.cc @@ -11,8 +11,8 @@ // CUDA backend includes #include "optimiser.h" - -using namespace CodeGenerator::OpenCL; +using namespace GeNN; +using namespace GeNN::CodeGenerator::OpenCL; //---------------------------------------------------------------------------- // Anonymous namespace diff --git a/pygenn/src/postsynapticModels.cc b/pygenn/src/postsynapticModels.cc index e2cbbfa582..e318ff1e30 100644 --- a/pygenn/src/postsynapticModels.cc +++ b/pygenn/src/postsynapticModels.cc @@ -5,7 +5,7 @@ // GeNN includes #include "postsynapticModels.h" -using namespace PostsynapticModels; +using namespace GeNN::PostsynapticModels; namespace { diff --git a/pygenn/src/singleThreadedCPUBackend.cc b/pygenn/src/singleThreadedCPUBackend.cc index 42b7af1d88..9afcd81adc 100644 --- a/pygenn/src/singleThreadedCPUBackend.cc +++ b/pygenn/src/singleThreadedCPUBackend.cc @@ -11,8 +11,8 @@ // CUDA backend includes #include "optimiser.h" - -using namespace CodeGenerator::SingleThreadedCPU; +using namespace GeNN; +using namespace GeNN::CodeGenerator::SingleThreadedCPU; //---------------------------------------------------------------------------- // Anonymous namespace diff --git a/pygenn/src/weightUpdateModels.cc b/pygenn/src/weightUpdateModels.cc index dc9b0becf6..e823e879d8 100644 --- 
a/pygenn/src/weightUpdateModels.cc +++ b/pygenn/src/weightUpdateModels.cc @@ -5,7 +5,7 @@ // GeNN includes #include "weightUpdateModels.h" -using namespace WeightUpdateModels; +using namespace GeNN::WeightUpdateModels; namespace { From 024ea95dd5801a30bcc55e36c76923f1f864cf00 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 6 Jan 2023 10:30:36 +0000 Subject: [PATCH 004/725] updated unit tests --- tests/unit/binomial.cc | 2 ++ tests/unit/codeGenUtils.cc | 3 ++- tests/unit/currentSource.cc | 2 ++ tests/unit/currentSourceModels.cc | 2 ++ tests/unit/customUpdate.cc | 2 ++ tests/unit/gennUtils.cc | 2 ++ tests/unit/initSparseConnectivitySnippet.cc | 2 ++ tests/unit/initVarSnippet.cc | 2 ++ tests/unit/modelSpec.cc | 2 ++ tests/unit/modelSpecMerged.cc | 2 ++ tests/unit/models.cc | 2 ++ tests/unit/neuronGroup.cc | 2 ++ tests/unit/neuronModels.cc | 2 ++ tests/unit/postsynapticModels.cc | 2 ++ tests/unit/synapseGroup.cc | 2 ++ tests/unit/weightUpdateModels.cc | 2 ++ 16 files changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/unit/binomial.cc b/tests/unit/binomial.cc index b762280d82..fcb2415c24 100644 --- a/tests/unit/binomial.cc +++ b/tests/unit/binomial.cc @@ -7,6 +7,8 @@ // GeNN code generator includes #include "binomial.h" +using namespace GeNN; + //-------------------------------------------------------------------------- // Tests //-------------------------------------------------------------------------- diff --git a/tests/unit/codeGenUtils.cc b/tests/unit/codeGenUtils.cc index b4223063ab..a935829b21 100644 --- a/tests/unit/codeGenUtils.cc +++ b/tests/unit/codeGenUtils.cc @@ -12,7 +12,8 @@ #include "code_generator/codeGenUtils.h" #include "code_generator/substitutions.h" -using namespace CodeGenerator; +using namespace GeNN; +using namespace GeNN::CodeGenerator; // Test based on original issue found in https://github.com/brian-team/brian2genn/pull/60 to make sure that ensureFtype doesn't break functions it shouldn't TEST(CodeGenUtils, ISinF) { diff --git a/tests/unit/currentSource.cc b/tests/unit/currentSource.cc index 8ebb3840b7..1af2790d43 100644 --- a/tests/unit/currentSource.cc +++ b/tests/unit/currentSource.cc @@ -4,6 +4,8 @@ // GeNN includes #include "modelSpecInternal.h" +using namespace GeNN; + //-------------------------------------------------------------------------- // Tests //-------------------------------------------------------------------------- diff --git a/tests/unit/currentSourceModels.cc b/tests/unit/currentSourceModels.cc index 56a0fabc26..5a46bd9df9 100644 --- a/tests/unit/currentSourceModels.cc +++ b/tests/unit/currentSourceModels.cc @@ -4,6 +4,8 @@ // GeNN includes #include "currentSourceModels.h" +using namespace GeNN; + //-------------------------------------------------------------------------- // GaussianNoiseCopy //-------------------------------------------------------------------------- diff --git a/tests/unit/customUpdate.cc b/tests/unit/customUpdate.cc index d9404c787e..9c5d9c17a3 100644 --- a/tests/unit/customUpdate.cc +++ b/tests/unit/customUpdate.cc @@ -10,6 +10,8 @@ // (Single-threaded CPU) backend includes #include "backend.h" +using namespace GeNN; + //-------------------------------------------------------------------------- // Anonymous namespace //-------------------------------------------------------------------------- diff --git a/tests/unit/gennUtils.cc b/tests/unit/gennUtils.cc index ba12ad1e16..e22309c4cd 100644 --- a/tests/unit/gennUtils.cc +++ b/tests/unit/gennUtils.cc @@ -5,6 +5,8 @@ #include "gennUtils.h" #include 
"snippet.h" +using namespace GeNN; + //-------------------------------------------------------------------------- // Anonymous namespace //-------------------------------------------------------------------------- diff --git a/tests/unit/initSparseConnectivitySnippet.cc b/tests/unit/initSparseConnectivitySnippet.cc index f80be5dfa8..0d2aa5a23c 100644 --- a/tests/unit/initSparseConnectivitySnippet.cc +++ b/tests/unit/initSparseConnectivitySnippet.cc @@ -4,6 +4,8 @@ // GeNN includes #include "modelSpec.h" +using namespace GeNN; + //-------------------------------------------------------------------------- // OneToOneCopy //-------------------------------------------------------------------------- diff --git a/tests/unit/initVarSnippet.cc b/tests/unit/initVarSnippet.cc index a3cac9128b..895bf988a3 100644 --- a/tests/unit/initVarSnippet.cc +++ b/tests/unit/initVarSnippet.cc @@ -4,6 +4,8 @@ // GeNN includes #include "modelSpec.h" +using namespace GeNN; + //-------------------------------------------------------------------------- // UniformCopy //-------------------------------------------------------------------------- diff --git a/tests/unit/modelSpec.cc b/tests/unit/modelSpec.cc index bfa5805c4b..0f378f7e7e 100644 --- a/tests/unit/modelSpec.cc +++ b/tests/unit/modelSpec.cc @@ -4,6 +4,8 @@ // GeNN includes #include "modelSpecInternal.h" +using namespace GeNN; + //-------------------------------------------------------------------------- // Anonymous namespace //-------------------------------------------------------------------------- diff --git a/tests/unit/modelSpecMerged.cc b/tests/unit/modelSpecMerged.cc index 699e199ca9..ed02938353 100644 --- a/tests/unit/modelSpecMerged.cc +++ b/tests/unit/modelSpecMerged.cc @@ -12,6 +12,8 @@ // (Single-threaded CPU) backend includes #include "backend.h" +using namespace GeNN; + //-------------------------------------------------------------------------- // Anonyous namespace //-------------------------------------------------------------------------- diff --git a/tests/unit/models.cc b/tests/unit/models.cc index 92c553987c..5558fda620 100644 --- a/tests/unit/models.cc +++ b/tests/unit/models.cc @@ -4,6 +4,8 @@ // GeNN includes #include "modelSpecInternal.h" +using namespace GeNN; + //-------------------------------------------------------------------------- // Anonymous namespace //-------------------------------------------------------------------------- diff --git a/tests/unit/neuronGroup.cc b/tests/unit/neuronGroup.cc index 5dad0e7153..f42b834ce2 100644 --- a/tests/unit/neuronGroup.cc +++ b/tests/unit/neuronGroup.cc @@ -10,6 +10,8 @@ // (Single-threaded CPU) backend includes #include "backend.h" +using namespace GeNN; + namespace { class StaticPulseBack : public WeightUpdateModels::Base diff --git a/tests/unit/neuronModels.cc b/tests/unit/neuronModels.cc index b1ced92d70..af6912cda5 100644 --- a/tests/unit/neuronModels.cc +++ b/tests/unit/neuronModels.cc @@ -5,6 +5,8 @@ #include "modelSpec.h" #include "neuronModels.h" +using namespace GeNN; + //-------------------------------------------------------------------------- // LIFCopy //-------------------------------------------------------------------------- diff --git a/tests/unit/postsynapticModels.cc b/tests/unit/postsynapticModels.cc index 3050e02637..1516281b2e 100644 --- a/tests/unit/postsynapticModels.cc +++ b/tests/unit/postsynapticModels.cc @@ -5,6 +5,8 @@ #include "modelSpec.h" #include "postsynapticModels.h" +using namespace GeNN; + 
//-------------------------------------------------------------------------- // ExpCurrCopy //-------------------------------------------------------------------------- diff --git a/tests/unit/synapseGroup.cc b/tests/unit/synapseGroup.cc index 5640309970..373f102dba 100644 --- a/tests/unit/synapseGroup.cc +++ b/tests/unit/synapseGroup.cc @@ -10,6 +10,8 @@ // (Single-threaded CPU) backend includes #include "backend.h" +using namespace GeNN; + //-------------------------------------------------------------------------- // Anonymous namespace //-------------------------------------------------------------------------- diff --git a/tests/unit/weightUpdateModels.cc b/tests/unit/weightUpdateModels.cc index d50ff2fecb..6817197e40 100644 --- a/tests/unit/weightUpdateModels.cc +++ b/tests/unit/weightUpdateModels.cc @@ -5,6 +5,8 @@ #include "modelSpec.h" #include "weightUpdateModels.h" +using namespace GeNN; + namespace { //-------------------------------------------------------------------------- From 15a519da04317af1b57b5f505323578c28c9af74 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 10 Jan 2023 13:24:06 +0000 Subject: [PATCH 005/725] fixed bad merge --- src/genn/genn/customUpdateModels.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/customUpdateModels.cc b/src/genn/genn/customUpdateModels.cc index 7422261cb4..d6915807b7 100644 --- a/src/genn/genn/customUpdateModels.cc +++ b/src/genn/genn/customUpdateModels.cc @@ -20,4 +20,4 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const Utils::updateHash(getVarRefs(), hash); return hash.get_digest(); } -} +} // namespace GeNN::CustomUpdateModels \ No newline at end of file From bc623586e0d85e94e939700923c49a13347a65e3 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 12 Jan 2023 14:02:14 +0000 Subject: [PATCH 006/725] fixed test --- tests/unit/customConnectivityUpdate.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/customConnectivityUpdate.cc b/tests/unit/customConnectivityUpdate.cc index 418005566a..fad4251cd3 100644 --- a/tests/unit/customConnectivityUpdate.cc +++ b/tests/unit/customConnectivityUpdate.cc @@ -10,6 +10,8 @@ // (Single-threaded CPU) backend includes #include "backend.h" +using namespace GeNN; + namespace { class StaticPulseDendriticDelayReverse : public WeightUpdateModels::Base From be06da5b21fa6d9e0ea7e092a2e9eae4090a58b2 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 6 Jan 2023 09:40:13 +0000 Subject: [PATCH 007/725] initial add of files --- include/genn/genn/transpiler/errorHandler.h | 20 + include/genn/genn/transpiler/expression.h | 324 +++++++ include/genn/genn/transpiler/parser.h | 26 + include/genn/genn/transpiler/prettyPrinter.h | 15 + include/genn/genn/transpiler/scanner.h | 26 + include/genn/genn/transpiler/statement.h | 302 ++++++ include/genn/genn/transpiler/token.h | 62 ++ .../genn/genn/transpiler/transpilerUtils.h | 33 + include/genn/genn/transpiler/type.h | 300 ++++++ include/genn/genn/transpiler/typeChecker.h | 64 ++ src/genn/genn/genn.vcxproj | 17 + src/genn/genn/transpiler/expression.cc | 22 + src/genn/genn/transpiler/parser.cc | 857 ++++++++++++++++++ src/genn/genn/transpiler/prettyPrinter.cc | 272 ++++++ src/genn/genn/transpiler/scanner.cc | 486 ++++++++++ src/genn/genn/transpiler/statement.cc | 21 + src/genn/genn/transpiler/type.cc | 139 +++ src/genn/genn/transpiler/typeChecker.cc | 673 ++++++++++++++ 18 files changed, 3659 insertions(+) create mode 100644 include/genn/genn/transpiler/errorHandler.h create mode 100644 
include/genn/genn/transpiler/expression.h create mode 100644 include/genn/genn/transpiler/parser.h create mode 100644 include/genn/genn/transpiler/prettyPrinter.h create mode 100644 include/genn/genn/transpiler/scanner.h create mode 100644 include/genn/genn/transpiler/statement.h create mode 100644 include/genn/genn/transpiler/token.h create mode 100644 include/genn/genn/transpiler/transpilerUtils.h create mode 100644 include/genn/genn/transpiler/type.h create mode 100644 include/genn/genn/transpiler/typeChecker.h create mode 100644 src/genn/genn/transpiler/expression.cc create mode 100644 src/genn/genn/transpiler/parser.cc create mode 100644 src/genn/genn/transpiler/prettyPrinter.cc create mode 100644 src/genn/genn/transpiler/scanner.cc create mode 100644 src/genn/genn/transpiler/statement.cc create mode 100644 src/genn/genn/transpiler/type.cc create mode 100644 src/genn/genn/transpiler/typeChecker.cc diff --git a/include/genn/genn/transpiler/errorHandler.h b/include/genn/genn/transpiler/errorHandler.h new file mode 100644 index 0000000000..ad34193da9 --- /dev/null +++ b/include/genn/genn/transpiler/errorHandler.h @@ -0,0 +1,20 @@ +#pragma once + +// Standard C++ includes +#include + +// Mini-parse includes +#include "token.h" + +//--------------------------------------------------------------------------- +// MiniParse::ErrorHandler +//--------------------------------------------------------------------------- +namespace MiniParse +{ +class ErrorHandler +{ +public: + virtual void error(size_t line, std::string_view message) = 0; + virtual void error(const Token &token, std::string_view message) = 0; +}; +} diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h new file mode 100644 index 0000000000..ce4493e01c --- /dev/null +++ b/include/genn/genn/transpiler/expression.h @@ -0,0 +1,324 @@ +#pragma once + +// Standard C++ includes +#include +#include + +// Mini-parse includes +#include "token.h" + +// Forward declarations +namespace MiniParse::Expression +{ +class Visitor; +} +namespace Type +{ +class Base; +} + +//--------------------------------------------------------------------------- +// MiniParse::Expression::Base +//--------------------------------------------------------------------------- +namespace MiniParse::Expression +{ +class Base +{ +public: + virtual void accept(Visitor &visitor) const = 0; +}; + +typedef std::unique_ptr ExpressionPtr; +typedef std::vector ExpressionList; + +//--------------------------------------------------------------------------- +// MiniParse::Expression::ArraySubscript +//--------------------------------------------------------------------------- +class ArraySubscript : public Base +{ +public: + ArraySubscript(Token pointerName, ExpressionPtr index) + : m_PointerName(pointerName), m_Index(std::move(index)) + {} + + virtual void accept(Visitor &visitor) const final; + + const Token &getPointerName() const { return m_PointerName; } + const ExpressionPtr &getIndex() const { return m_Index; } + +private: + const Token m_PointerName; + const ExpressionPtr m_Index; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Expression::Assignment +//--------------------------------------------------------------------------- +class Assignment : public Base +{ +public: + Assignment(Token varName, Token op, ExpressionPtr value) + : m_VarName(varName), m_Operator(op), m_Value(std::move(value)) + {} + + virtual void accept(Visitor &visitor) const final; + + const Token 
&getVarName() const { return m_VarName; } + const Token &getOperator() const { return m_Operator; } + const Base *getValue() const { return m_Value.get(); } + +private: + const Token m_VarName; + const Token m_Operator; + const ExpressionPtr m_Value; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Expression::Binary +//--------------------------------------------------------------------------- +class Binary : public Base +{ +public: + Binary(ExpressionPtr left, Token op, ExpressionPtr right) + : m_Left(std::move(left)), m_Operator(op), m_Right(std::move(right)) + {} + + virtual void accept(Visitor &visitor) const final; + + const Base *getLeft() const { return m_Left.get(); } + const Token &getOperator() const { return m_Operator; } + const Base *getRight() const { return m_Right.get(); } + +private: + const ExpressionPtr m_Left; + const Token m_Operator; + const ExpressionPtr m_Right; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Expression::Call +//--------------------------------------------------------------------------- +class Call : public Base +{ +public: + Call(ExpressionPtr callee, Token closingParen, ExpressionList arguments) + : m_Callee(std::move(callee)), m_ClosingParen(closingParen), m_Arguments(std::move(arguments)) + {} + + virtual void accept(Visitor &visitor) const final; + + const Base *getCallee() const { return m_Callee.get(); } + const Token &getClosingParen() const { return m_ClosingParen; } + const ExpressionList &getArguments() const { return m_Arguments; } + +private: + const ExpressionPtr m_Callee; + const Token m_ClosingParen; + const ExpressionList m_Arguments; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Expression::Cast +//--------------------------------------------------------------------------- +class Cast : public Base +{ +public: + Cast(const Type::Base *type, bool isConst, ExpressionPtr expression) + : m_Type(type), m_Const(isConst), m_Expression(std::move(expression)) + {} + + virtual void accept(Visitor &visitor) const final; + + const Base *getExpression() const { return m_Expression.get(); } + + const Type::Base *getType() const { return m_Type; } + bool isConst() const { return m_Const; } + +private: + const Type::Base *m_Type; + bool m_Const; + const ExpressionPtr m_Expression; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Expression::Conditional +//--------------------------------------------------------------------------- +class Conditional : public Base +{ +public: + Conditional(ExpressionPtr condition, Token question, ExpressionPtr trueExpression, ExpressionPtr falseExpression) + : m_Condition(std::move(condition)), m_Question(question), m_True(std::move(trueExpression)), m_False(std::move(falseExpression)) + {} + + virtual void accept(Visitor &visitor) const final; + + const Base *getCondition() const { return m_Condition.get(); } + const Token &getQuestion() const { return m_Question; } + const Base *getTrue() const { return m_True.get(); } + const Base *getFalse() const { return m_False.get(); } + +private: + const ExpressionPtr m_Condition; + const Token m_Question; + const ExpressionPtr m_True; + const ExpressionPtr m_False; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Expression::Grouping +//--------------------------------------------------------------------------- +class 
Grouping : public Base +{ +public: + Grouping(ExpressionPtr expression) + : m_Expression(std::move(expression)) + {} + + virtual void accept(Visitor &visitor) const final; + + const Base *getExpression() const { return m_Expression.get(); } + +private: + const ExpressionPtr m_Expression; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Expression::Literal +//--------------------------------------------------------------------------- +class Literal : public Base +{ +public: + Literal(Token::LiteralValue value) + : m_Value(value) + {} + + virtual void accept(Visitor &visitor) const final; + + Token::LiteralValue getValue() const { return m_Value; } + +private: + const Token::LiteralValue m_Value; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Expression::Logical +//--------------------------------------------------------------------------- +class Logical : public Base +{ +public: + Logical(ExpressionPtr left, Token op, ExpressionPtr right) + : m_Left(std::move(left)), m_Operator(op), m_Right(std::move(right)) + {} + + virtual void accept(Visitor &visitor) const final; + + const Base *getLeft() const { return m_Left.get(); } + const Token &getOperator() const { return m_Operator; } + const Base *getRight() const { return m_Right.get(); } + +private: + const ExpressionPtr m_Left; + const Token m_Operator; + const ExpressionPtr m_Right; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Expression::PostfixIncDec +//--------------------------------------------------------------------------- +class PostfixIncDec : public Base +{ +public: + PostfixIncDec(Token varName, Token op) + : m_VarName(varName), m_Operator(op) + {} + + virtual void accept(Visitor &visitor) const final; + + const Token &getVarName() const { return m_VarName; } + const Token &getOperator() const { return m_Operator; } + +private: + const Token m_VarName; + const Token m_Operator; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Expression::PrefixIncDec +//--------------------------------------------------------------------------- +class PrefixIncDec : public Base +{ +public: + PrefixIncDec(Token varName, Token op) + : m_VarName(varName), m_Operator(op) + {} + + virtual void accept(Visitor &visitor) const final; + + const Token &getVarName() const { return m_VarName; } + const Token &getOperator() const { return m_Operator; } + +private: + const Token m_VarName; + const Token m_Operator; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Expression::Variable +//--------------------------------------------------------------------------- +class Variable : public Base +{ +public: + Variable(Token name) + : m_Name(name) + {} + + virtual void accept(Visitor &visitor) const final; + + const Token &getName() const { return m_Name; } + +private: + const Token m_Name; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Expression::Unary +//--------------------------------------------------------------------------- +class Unary : public Base +{ +public: + Unary(Token op, ExpressionPtr right) + : m_Operator(op), m_Right(std::move(right)) + {} + + virtual void accept(Visitor &visitor) const final; + + const Token &getOperator() const { return m_Operator; } + const Base *getRight() const { return m_Right.get(); } + +private: + const Token 
m_Operator; + const ExpressionPtr m_Right; +}; + + +//--------------------------------------------------------------------------- +// MiniParse::Expression::Visitor +//--------------------------------------------------------------------------- +class Visitor +{ +public: + virtual void visit(const ArraySubscript &arraySubscript) = 0; + virtual void visit(const Assignment &assignment) = 0; + virtual void visit(const Binary &binary) = 0; + virtual void visit(const Call &call) = 0; + virtual void visit(const Cast &cast) = 0; + virtual void visit(const Conditional &conditional) = 0; + virtual void visit(const Grouping &grouping) = 0; + virtual void visit(const Literal &literal) = 0; + virtual void visit(const Logical &logical) = 0; + virtual void visit(const PostfixIncDec &postfixIncDec) = 0; + virtual void visit(const PrefixIncDec &prefixIncDec) = 0; + virtual void visit(const Variable &variable) = 0; + virtual void visit(const Unary &unary) = 0; +}; +} // namespace MiniParse::Expression \ No newline at end of file diff --git a/include/genn/genn/transpiler/parser.h b/include/genn/genn/transpiler/parser.h new file mode 100644 index 0000000000..8b063c2ccf --- /dev/null +++ b/include/genn/genn/transpiler/parser.h @@ -0,0 +1,26 @@ +#pragma once + +// Standard C++ includes +#include +#include + +// Mini-parse includes +#include "expression.h" +#include "statement.h" +#include "token.h" + +// Forward declarations +namespace MiniParse +{ +class ErrorHandler; +} + +//--------------------------------------------------------------------------- +// MiniParse::Parser +//--------------------------------------------------------------------------- +namespace MiniParse::Parser +{ +Expression::ExpressionPtr parseExpression(const std::vector<Token> &tokens, ErrorHandler &errorHandler); + +Statement::StatementList parseBlockItemList(const std::vector<Token> &tokens, ErrorHandler &errorHandler); +} // namespace MiniParse::Parser \ No newline at end of file diff --git a/include/genn/genn/transpiler/prettyPrinter.h b/include/genn/genn/transpiler/prettyPrinter.h new file mode 100644 index 0000000000..cf4f7949a9 --- /dev/null +++ b/include/genn/genn/transpiler/prettyPrinter.h @@ -0,0 +1,15 @@ +#pragma once + +// Standard C++ includes +#include + +// Mini-parse includes +#include "statement.h" + +//--------------------------------------------------------------------------- +// MiniParse::PrettyPrinter +//--------------------------------------------------------------------------- +namespace MiniParse::PrettyPrinter +{ +std::string print(const Statement::StatementList &statements); +} \ No newline at end of file diff --git a/include/genn/genn/transpiler/scanner.h b/include/genn/genn/transpiler/scanner.h new file mode 100644 index 0000000000..07c53dbe9c --- /dev/null +++ b/include/genn/genn/transpiler/scanner.h @@ -0,0 +1,26 @@ +#pragma once + +// Standard C++ includes +#include +#include +#include +#include +#include + +// Mini-parse includes +#include "token.h" + +// Forward declarations +namespace MiniParse +{ +class ErrorHandler; +} + +//--------------------------------------------------------------------------- +// MiniParse::Scanner +//--------------------------------------------------------------------------- +namespace MiniParse::Scanner +{ +std::vector<Token> scanSource(const std::string_view &source, ErrorHandler &errorHandler); + +} // namespace MiniParse::Scanner \ No newline at end of file diff --git a/include/genn/genn/transpiler/statement.h b/include/genn/genn/transpiler/statement.h new file mode 100644 index
0000000000..4ee43e1f9d --- /dev/null +++ b/include/genn/genn/transpiler/statement.h @@ -0,0 +1,302 @@ +#pragma once + +// Standard C++ includes +#include +#include + +// Mini-parse includes +#include "expression.h" + +// Forward declarations +namespace MiniParse::Statement +{ +class Visitor; +} +namespace Type +{ +class Base; +} + +//--------------------------------------------------------------------------- +// MiniParse::Statement::Base +//--------------------------------------------------------------------------- +namespace MiniParse::Statement +{ +class Base +{ +public: + virtual void accept(Visitor &visitor) const = 0; +}; + +typedef std::unique_ptr StatementPtr; +typedef std::vector StatementList; + +//--------------------------------------------------------------------------- +// MiniParse::Statement::Break +//--------------------------------------------------------------------------- +class Break : public Base +{ +public: + Break(Token token) + : m_Token(token) + {} + + virtual void accept(Visitor &visitor) const override; + + const Token &getToken() const { return m_Token; } + +private: + const Token m_Token; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Statement::Compound +//--------------------------------------------------------------------------- +class Compound : public Base +{ +public: + Compound(StatementList statements) + : m_Statements(std::move(statements)) + {} + + virtual void accept(Visitor &visitor) const override; + + const StatementList &getStatements() const { return m_Statements; } + +private: + const StatementList m_Statements; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Statement::Continue +//--------------------------------------------------------------------------- +class Continue : public Base +{ +public: + Continue(Token token) + : m_Token(token) + {} + + virtual void accept(Visitor &visitor) const override; + + const Token &getToken() const { return m_Token; } + +private: + const Token m_Token; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Statement::Do +//--------------------------------------------------------------------------- +class Do : public Base +{ +public: + Do(MiniParse::Expression::ExpressionPtr condition, StatementPtr body) + : m_Condition(std::move(condition)), m_Body(std::move(body)) + {} + + virtual void accept(Visitor &visitor) const override; + + const MiniParse::Expression::Base *getCondition() const { return m_Condition.get(); } + const Base *getBody() const { return m_Body.get(); } + +private: + const MiniParse::Expression::ExpressionPtr m_Condition; + const StatementPtr m_Body; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Statement::Expression +//--------------------------------------------------------------------------- +class Expression : public Base +{ +public: + Expression(MiniParse::Expression::ExpressionPtr expression) + : m_Expression(std::move(expression)) + {} + + virtual void accept(Visitor &visitor) const override; + + const MiniParse::Expression::Base *getExpression() const { return m_Expression.get(); } + +private: + const MiniParse::Expression::ExpressionPtr m_Expression; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Statement::For +//--------------------------------------------------------------------------- +class For : public Base +{ +public: + 
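+ //! Statement node representing a C-style "for(initialiser; condition; increment) body" loop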
For(StatementPtr initialiser, MiniParse::Expression::ExpressionPtr condition, MiniParse::Expression::ExpressionPtr increment, StatementPtr body) + : m_Initialiser(std::move(initialiser)), m_Condition(std::move(condition)), m_Increment(std::move(increment)), m_Body(std::move(body)) + {} + + virtual void accept(Visitor &visitor) const override; + + const Base *getInitialiser() const { return m_Initialiser.get(); } + const MiniParse::Expression::Base *getCondition() const { return m_Condition.get(); } + const MiniParse::Expression::Base *getIncrement() const { return m_Increment.get(); } + const Base *getBody() const { return m_Body.get(); } + +private: + const StatementPtr m_Initialiser; + const MiniParse::Expression::ExpressionPtr m_Condition; + const MiniParse::Expression::ExpressionPtr m_Increment; + const StatementPtr m_Body; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Statement::If +//--------------------------------------------------------------------------- +class If : public Base +{ +public: + If(MiniParse::Expression::ExpressionPtr condition, StatementPtr thenBranch, StatementPtr elseBranch) + : m_Condition(std::move(condition)), m_ThenBranch(std::move(thenBranch)), m_ElseBranch(std::move(elseBranch)) + {} + + virtual void accept(Visitor &visitor) const override; + + const MiniParse::Expression::Base *getCondition() const { return m_Condition.get(); } + const Base *getThenBranch() const { return m_ThenBranch.get(); } + const Base *getElseBranch() const { return m_ElseBranch.get(); } + +private: + const MiniParse::Expression::ExpressionPtr m_Condition; + const StatementPtr m_ThenBranch; + const StatementPtr m_ElseBranch; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Statement::Labelled +//--------------------------------------------------------------------------- +class Labelled : public Base +{ +public: + Labelled(Token keyword, MiniParse::Expression::ExpressionPtr value, StatementPtr body) + : m_Keyword(keyword), m_Value(std::move(value)), m_Body(std::move(body)) + {} + + virtual void accept(Visitor &visitor) const override; + + const Token &getKeyword() const { return m_Keyword; } + const MiniParse::Expression::Base *getValue() const { return m_Value.get(); } + const Base *getBody() const { return m_Body.get(); } + +private: + const Token m_Keyword; + const MiniParse::Expression::ExpressionPtr m_Value; + const StatementPtr m_Body; +}; + + +//--------------------------------------------------------------------------- +// MiniParse::Statement::Switch +//--------------------------------------------------------------------------- +class Switch : public Base +{ +public: + Switch(Token switchToken, MiniParse::Expression::ExpressionPtr condition, StatementPtr body) + : m_Switch(switchToken), m_Condition(std::move(condition)), m_Body(std::move(body)) + {} + + virtual void accept(Visitor &visitor) const override; + + const Token &getSwitch() const { return m_Switch; } + const MiniParse::Expression::Base *getCondition() const { return m_Condition.get(); } + const Base *getBody() const { return m_Body.get(); } + +private: + const Token m_Switch; + const MiniParse::Expression::ExpressionPtr m_Condition; + const StatementPtr m_Body; +}; + + +//--------------------------------------------------------------------------- +// MiniParse::Statement::VarDeclaration +//--------------------------------------------------------------------------- +class VarDeclaration : public Base +{ +public: + 
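+ //! Declaration statement: declaration specifiers (a type, optionally const-qualified) followed by one or more declarators, each pairing a name token with an optional initialiser expression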
typedef std::vector> InitDeclaratorList; + + VarDeclaration(const Type::Base *type, bool isConst, InitDeclaratorList initDeclaratorList) + : m_Type(type), m_Const(isConst), m_InitDeclaratorList(std::move(initDeclaratorList)) + {} + + virtual void accept(Visitor &visitor) const override; + + const Type::Base *getType() const { return m_Type; } + bool isConst() const { return m_Const; } + + const InitDeclaratorList &getInitDeclaratorList() const { return m_InitDeclaratorList; } + +private: + const Type::Base *m_Type; + const bool m_Const; + const std::vector m_DeclarationSpecifiers; + const InitDeclaratorList m_InitDeclaratorList; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Statement::If +//--------------------------------------------------------------------------- +class While : public Base +{ +public: + While(MiniParse::Expression::ExpressionPtr condition, StatementPtr body) + : m_Condition(std::move(condition)), m_Body(std::move(body)) + {} + + virtual void accept(Visitor &visitor) const override; + + const MiniParse::Expression::Base *getCondition() const { return m_Condition.get(); } + const Base *getBody() const { return m_Body.get(); } + +private: + const MiniParse::Expression::ExpressionPtr m_Condition; + const StatementPtr m_Body; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Statement::Print +//--------------------------------------------------------------------------- +// **HACK** temporary until function calling is working +class Print : public Base +{ +public: + Print(MiniParse::Expression::ExpressionPtr expression) + : m_Expression(std::move(expression)) + {} + + virtual void accept(Visitor &visitor) const override; + + const MiniParse::Expression::Base *getExpression() const { return m_Expression.get(); } + +private: + const MiniParse::Expression::ExpressionPtr m_Expression; +}; + +//--------------------------------------------------------------------------- +// MiniParse::Statement::Visitor +//--------------------------------------------------------------------------- +class Visitor +{ +public: + virtual void visit(const Break &breakStatement) = 0; + virtual void visit(const Compound &compound) = 0; + virtual void visit(const Continue &continueStatement) = 0; + virtual void visit(const Do &doStatement) = 0; + virtual void visit(const Expression &expression) = 0; + virtual void visit(const For &forStatement) = 0; + virtual void visit(const If &ifStatement) = 0; + virtual void visit(const Labelled &labelled) = 0; + virtual void visit(const Switch &switchStatement) = 0; + virtual void visit(const VarDeclaration &varDeclaration) = 0; + virtual void visit(const While &whileStatement) = 0; + virtual void visit(const Print &print) = 0; +}; +} // namespace MiniParse::Statement \ No newline at end of file diff --git a/include/genn/genn/transpiler/token.h b/include/genn/genn/transpiler/token.h new file mode 100644 index 0000000000..be480fa67f --- /dev/null +++ b/include/genn/genn/transpiler/token.h @@ -0,0 +1,62 @@ +#pragma once + +// Standard C++ includes +#include +#include + +// Standard C includes +#include + +//--------------------------------------------------------------------------- +// MiniParse::Token +//--------------------------------------------------------------------------- +namespace MiniParse +{ +struct Token +{ + typedef std::variant LiteralValue; + + enum class Type + { + // Single-character tokens + LEFT_PAREN, RIGHT_PAREN, LEFT_BRACE, RIGHT_BRACE, 
LEFT_SQUARE_BRACKET, RIGHT_SQUARE_BRACKET, + COMMA, PIPE, CARET, DOT, MINUS, PERCENT, PLUS, COLON, SEMICOLON, SLASH, STAR, TILDA, AMPERSAND, QUESTION, + + // One or two character tokens + NOT, NOT_EQUAL, + EQUAL_EQUAL, + GREATER, GREATER_EQUAL, + LESS, LESS_EQUAL, + EQUAL, STAR_EQUAL, SLASH_EQUAL, PERCENT_EQUAL, PLUS_EQUAL, + MINUS_EQUAL, AMPERSAND_EQUAL, CARET_EQUAL, PIPE_EQUAL, + PIPE_PIPE, AMPERSAND_AMPERSAND, PLUS_PLUS, MINUS_MINUS, + SHIFT_LEFT, SHIFT_RIGHT, + + // Three character tokens + SHIFT_LEFT_EQUAL, SHIFT_RIGHT_EQUAL, + + // Literals + IDENTIFIER, NUMBER, + + // Types + TYPE_SPECIFIER, + TYPE_QUALIFIER, + + // Keywords + DO, ELSE, FALSE, FOR, IF, TRUE, WHILE, PRINT, SWITCH, CONTINUE, BREAK, CASE, DEFAULT, + + END_OF_FILE, + }; + + Token(Type type, std::string_view lexeme, size_t line, LiteralValue literalValue = LiteralValue()) + : type(type), lexeme(lexeme), line(line), literalValue(literalValue) + { + } + + const Type type; + const std::string_view lexeme; + const size_t line; + const LiteralValue literalValue; +}; + +} diff --git a/include/genn/genn/transpiler/transpilerUtils.h b/include/genn/genn/transpiler/transpilerUtils.h new file mode 100644 index 0000000000..be9514005a --- /dev/null +++ b/include/genn/genn/transpiler/transpilerUtils.h @@ -0,0 +1,33 @@ +#pragma once + +// Standard C++ includes +#include +#include + +namespace MiniParse::Utils +{ + template<class... Ts> struct Overload : Ts... { using Ts::operator()...; }; + template<class... Ts> Overload(Ts...) -> Overload<Ts...>; // explicit deduction guide (not needed as of C++20) + + template<typename T> + T toCharsThrow(std::string_view input, int base = 10) + { + T out; + std::from_chars_result result; + if constexpr (std::is_floating_point_v<T>) { + result = std::from_chars(input.data(), input.data() + input.size(), out, + (base == 10) ? std::chars_format::general : std::chars_format::hex); + } + else { + result = std::from_chars(input.data(), input.data() + input.size(), out, base); + } + + if(result.ec == std::errc::invalid_argument) { + throw std::invalid_argument("Unable to convert chars '" + std::string{input} + "'"); + } + else if(result.ec == std::errc::result_out_of_range) { + throw std::out_of_range("Unable to convert chars '" + std::string{input} + "'"); + } + return out; + } +} diff --git a/include/genn/genn/transpiler/type.h b/include/genn/genn/transpiler/type.h new file mode 100644 index 0000000000..8c93ee6abf --- /dev/null +++ b/include/genn/genn/transpiler/type.h @@ -0,0 +1,300 @@ +#pragma once + +// Standard C includes +#include + +// Standard C++ includes +#include +#include +#include +#include +#include +#include +#include + +//---------------------------------------------------------------------------- +// Macros +//---------------------------------------------------------------------------- +#define DECLARE_TYPE(TYPE) \ + private: \ + /*GENN_EXPORT*/ static TYPE *s_Instance; \ + public: \ + static const TYPE *getInstance() \ + { \ + if(s_Instance == NULL) \ + { \ + s_Instance = new TYPE; \ + } \ + return s_Instance; \ + } + +#define DECLARE_NUMERIC_TYPE(TYPE, UNDERLYING_TYPE, RANK) \ + class TYPE : public Numeric<UNDERLYING_TYPE, RANK> \ + { \ + DECLARE_TYPE(TYPE) \ + virtual std::string getTypeName() const { return #UNDERLYING_TYPE; } \ + }; \ + class TYPE##Ptr : public NumericPtr<TYPE> \ + { \ + DECLARE_TYPE(TYPE##Ptr) \ + }; \ + template<> \ + struct TypeTraits<UNDERLYING_TYPE> \ + { \ + using NumericType = TYPE; \ + }; \ + template<> \ + struct TypeTraits<std::add_pointer_t<UNDERLYING_TYPE>> \ + { \ + using NumericPtrType = TYPE##Ptr; \ + } + +#define DECLARE_FOREIGN_FUNCTION_TYPE(TYPE, RETURN_TYPE, ...) 
\ + class TYPE : public ForeignFunction \ + { \ + DECLARE_TYPE(TYPE) \ + } + +#define IMPLEMENT_TYPE(TYPE) TYPE *TYPE::s_Instance = NULL +#define IMPLEMENT_NUMERIC_TYPE(TYPE) IMPLEMENT_TYPE(TYPE); IMPLEMENT_TYPE(TYPE##Ptr) + +//---------------------------------------------------------------------------- +// Type::TypeTraits +//---------------------------------------------------------------------------- +namespace Type +{ +//! Empty type trait structure +template +struct TypeTraits +{ +}; + +//---------------------------------------------------------------------------- +// Type::Base +//---------------------------------------------------------------------------- +//! Base class for all types +class Base +{ +public: + //------------------------------------------------------------------------ + // Declared virtuals + //------------------------------------------------------------------------ + virtual std::string getTypeName() const = 0; + virtual size_t getTypeHash() const = 0; +}; + +//---------------------------------------------------------------------------- +// Type::NumericBase +//---------------------------------------------------------------------------- +class NumericBase : public Base +{ +public: + //------------------------------------------------------------------------ + // Declared virtuals + //------------------------------------------------------------------------ + virtual int getRank() const = 0; + virtual double getMin() const = 0; + virtual double getMax() const = 0; + virtual double getLowest() const = 0; + virtual bool isSigned() const = 0; + virtual bool isIntegral() const = 0; + + virtual const class NumericPtrBase *getPointerType() const = 0; +}; + +//---------------------------------------------------------------------------- +// NumericPtrBase +//---------------------------------------------------------------------------- +class NumericPtrBase : public Base +{ +public: + //------------------------------------------------------------------------ + // Declared virtuals + //------------------------------------------------------------------------ + virtual const NumericBase *getValueType() const = 0; +}; + +//---------------------------------------------------------------------------- +// Type::Numeric +//---------------------------------------------------------------------------- +template +class Numeric : public NumericBase +{ +public: + //------------------------------------------------------------------------ + // Typedefines + //------------------------------------------------------------------------ + typedef T UnderlyingType; + + //------------------------------------------------------------------------ + // Base virtuals + //------------------------------------------------------------------------ + virtual size_t getTypeHash() const final { return typeid(T).hash_code(); } + + //------------------------------------------------------------------------ + // NumericBase virtuals + //------------------------------------------------------------------------ + virtual int getRank() const final { return Rank; } + virtual double getMin() const final { return std::numeric_limits::min(); } + virtual double getMax() const final { return std::numeric_limits::max(); } + virtual double getLowest() const final { return std::numeric_limits::lowest(); } + virtual bool isSigned() const final { return std::is_signed::value; } + virtual bool isIntegral() const final { return std::is_integral::value; } + + virtual const NumericPtrBase *getPointerType() const + { + return 
TypeTraits>::NumericPtrType::getInstance(); + } +}; + +//---------------------------------------------------------------------------- +// NumericPtr +//---------------------------------------------------------------------------- +template +class NumericPtr : public NumericPtrBase +{ +public: + //------------------------------------------------------------------------ + // Base virtuals + //------------------------------------------------------------------------ + virtual std::string getTypeName() const final { return T::getInstance()->getTypeName() + "*"; } + virtual size_t getTypeHash() const final { return typeid(std::add_pointer_t).hash_code(); } + + //------------------------------------------------------------------------ + // NumericArrayBase virtuals + //------------------------------------------------------------------------ + virtual const NumericBase *getValueType() const final { return T::getInstance(); } +}; + +//---------------------------------------------------------------------------- +// Type::ForeignFunctionBase +//---------------------------------------------------------------------------- +class ForeignFunctionBase : public Base +{ +public: + //------------------------------------------------------------------------ + // Base virtuals + //------------------------------------------------------------------------ + virtual std::string getTypeName() const = 0; + virtual size_t getTypeHash() const = 0; + + //------------------------------------------------------------------------ + // Declared virtuals + //------------------------------------------------------------------------ + virtual const NumericBase *getReturnType() const = 0; + virtual std::vector getArgumentTypes() const = 0; +}; + +//---------------------------------------------------------------------------- +// Type::ForeignFunction +//---------------------------------------------------------------------------- +template +class ForeignFunction : public ForeignFunctionBase +{ +public: + //------------------------------------------------------------------------ + // Base virtuals + //------------------------------------------------------------------------ + virtual std::string getTypeName() const final + { + std::string typeName = getReturnType()->getTypeName() + "("; + updateTypeName(typeName); + typeName += ")"; + return typeName; + } + + virtual size_t getTypeHash() const final + { + // Start with seed of return type hash + size_t seed = getReturnType()->getTypeHash(); + updateTypeHash(seed); + return seed; + } + + //------------------------------------------------------------------------ + // ForeignFunctionBase virtuals + //------------------------------------------------------------------------ + virtual const NumericBase *getReturnType() const final + { + return ReturnType::getInstance(); + } + + virtual std::vector getArgumentTypes() const final + { + std::vector args; + args.reserve(sizeof...(ArgTypes)); + updateArgumentTypes(args); + return args; + } + +private: + //------------------------------------------------------------------------ + // Private methods + //------------------------------------------------------------------------ + template + static void updateTypeHash(size_t &seed) + { + // Combine hashes with argument type + // **NOTE** this is the boost::hash_combine algorithm + seed ^= T::getInstance()->getTypeHash() + 0x9e3779b9 + (seed << 6) + (seed >> 2); + + // If there are more arguments left in pack, recurse + if constexpr (sizeof...(Args)) { + updateTypeHash(seed); + } + } + + template + 
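+ // Append this argument's type name to the signature string then, if more arguments remain in the pack, add a comma and recurse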
static void updateTypeName(std::string &typeName) + { + // Add argument typename to string + typeName += T::getInstance()->getTypeName(); + + // If there are more arguments left in pack, add comma and recurse + if constexpr (sizeof...(Args)) { + typeName += ", "; + updateTypeName(typeName); + } + } + + template + static void updateArgumentTypes(std::vector &args) + { + // Add argument typename to string + args.push_back(T::getInstance()); + + // If there are more arguments left in pack, recurse + if constexpr (sizeof...(Args)) { + updateArgumentTypes(args); + } + } + +}; + +//---------------------------------------------------------------------------- +// Declare numeric types +//---------------------------------------------------------------------------- +DECLARE_NUMERIC_TYPE(Bool, bool, 0); +DECLARE_NUMERIC_TYPE(Int8, int8_t, 10); +DECLARE_NUMERIC_TYPE(Int16, int16_t, 20); +DECLARE_NUMERIC_TYPE(Int32, int32_t, 30); +//DECLARE_NUMERIC_TYPE(Int64, int64_t, 40); +DECLARE_NUMERIC_TYPE(Uint8, uint8_t, 10); +DECLARE_NUMERIC_TYPE(Uint16, uint16_t, 20); +DECLARE_NUMERIC_TYPE(Uint32, uint32_t, 30); +//DECLARE_NUMERIC_TYPE(Uint64, uint64_t, 40); +DECLARE_NUMERIC_TYPE(Float, float, 50); +DECLARE_NUMERIC_TYPE(Double, double, 60); + +//---------------------------------------------------------------------------- +// Declare standard library foreign function types +//---------------------------------------------------------------------------- +DECLARE_FOREIGN_FUNCTION_TYPE(Exp, Double, Double); +DECLARE_FOREIGN_FUNCTION_TYPE(Sqrt, Double, Double); + +//! Look up type based on set of type specifiers +const NumericBase *getNumericType(const std::set &typeSpecifiers); +const NumericPtrBase *getNumericPtrType(const std::set &typeSpecifiers); +const NumericBase *getPromotedType(const NumericBase *type); +const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b); +} // namespace Type diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h new file mode 100644 index 0000000000..b9978abd91 --- /dev/null +++ b/include/genn/genn/transpiler/typeChecker.h @@ -0,0 +1,64 @@ +#pragma once + +// Standard C++ includes +#include +#include +#include + +// Mini-parse includes +#include "statement.h" + +// Forward declarations +namespace MiniParse +{ +class ErrorHandler; +struct Token; +} +namespace Type +{ +class Base; +} + +//--------------------------------------------------------------------------- +// MiniParse::TypeChecker::Environment +//--------------------------------------------------------------------------- +namespace MiniParse::TypeChecker +{ +class Environment +{ +public: + Environment(Environment *enclosing = nullptr) + : m_Enclosing(enclosing) + { + } + + //--------------------------------------------------------------------------- + // Public API + //--------------------------------------------------------------------------- + template + void define(std::string_view name, bool isConst = false) + { + if(!m_Types.try_emplace(name, T::getInstance(), isConst).second) { + throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); + } + } + void define(const Token &name, const Type::Base *type, bool isConst, ErrorHandler &errorHandler); + const Type::Base *assign(const Token &name, const Type::Base *assignedType, bool assignedConst, + Token::Type op, ErrorHandler &errorHandler); + const Type::Base *incDec(const Token &name, const Token &op, ErrorHandler &errorHandler); + std::tuple getType(const Token &name, ErrorHandler 
&errorHandler) const; + +private: + //--------------------------------------------------------------------------- + // Members + //--------------------------------------------------------------------------- + Environment *m_Enclosing; + std::unordered_map<std::string_view, std::tuple<const Type::Base *, bool>> m_Types; +}; + +//--------------------------------------------------------------------------- +// Free functions +//--------------------------------------------------------------------------- +void typeCheck(const Statement::StatementList &statements, Environment &environment, + ErrorHandler &errorHandler); +} // namespace MiniParse::TypeChecker \ No newline at end of file diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj index 62331c7698..901a1c5e69 100644 --- a/src/genn/genn/genn.vcxproj +++ b/src/genn/genn/genn.vcxproj @@ -55,6 +55,13 @@ + + + + + + + @@ -110,6 +117,16 @@ + + + + + + + + + + diff --git a/src/genn/genn/transpiler/expression.cc b/src/genn/genn/transpiler/expression.cc new file mode 100644 index 0000000000..f85eb8f49e --- /dev/null +++ b/src/genn/genn/transpiler/expression.cc @@ -0,0 +1,22 @@ +#include "expression.h" + +#define IMPLEMENT_ACCEPT(CLASS_NAME) \ + void MiniParse::Expression::CLASS_NAME::accept(Visitor &visitor) const \ + { \ + visitor.visit(*this); \ + } + + +IMPLEMENT_ACCEPT(ArraySubscript) +IMPLEMENT_ACCEPT(Assignment) +IMPLEMENT_ACCEPT(Binary) +IMPLEMENT_ACCEPT(Call) +IMPLEMENT_ACCEPT(Cast) +IMPLEMENT_ACCEPT(Conditional) +IMPLEMENT_ACCEPT(Grouping) +IMPLEMENT_ACCEPT(Literal) +IMPLEMENT_ACCEPT(Logical) +IMPLEMENT_ACCEPT(PrefixIncDec) +IMPLEMENT_ACCEPT(PostfixIncDec) +IMPLEMENT_ACCEPT(Variable) +IMPLEMENT_ACCEPT(Unary) \ No newline at end of file diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc new file mode 100644 index 0000000000..e8f63e7066 --- /dev/null +++ b/src/genn/genn/transpiler/parser.cc @@ -0,0 +1,857 @@ +#include "parser.h" + +// Standard C++ includes +#include +#include +#include +#include +#include + +// Standard C includes +#include + +// GeNN includes +#include "type.h" + +// Mini-parse includes +#include "errorHandler.h" + +using namespace MiniParse; + +//--------------------------------------------------------------------------- +// Anonymous namespace +//--------------------------------------------------------------------------- +namespace +{ +//--------------------------------------------------------------------------- +// ParseError +//--------------------------------------------------------------------------- +class ParseError +{ +}; + +//--------------------------------------------------------------------------- +// ParserState +//--------------------------------------------------------------------------- +//! 
Class encapsulating the logic used to navigate through tokens +class ParserState +{ +public: + ParserState(const std::vector<Token> &tokens, ErrorHandler &errorHandler) + : m_Current(0), m_Tokens(tokens), m_ErrorHandler(errorHandler) + {} + + //--------------------------------------------------------------------------- + // Public API + //--------------------------------------------------------------------------- + bool match(Token::Type t) + { + if(check(t)) { + advance(); + return true; + } + else { + return false; + } + } + + bool match(std::initializer_list<Token::Type> types) + { + // Loop through types + for(auto t : types) { + if(match(t)) { + return true; + } + } + return false; + } + + Token advance() + { + if(!isAtEnd()) { + m_Current++; + } + + return previous(); + } + + Token rewind() + { + if(m_Current > 0) { + m_Current--; + } + + return peek(); + } + + Token peek() const + { + return m_Tokens.at(m_Current); + } + + Token previous() const + { + assert(m_Current > 0); + return m_Tokens.at(m_Current - 1); + } + + void error(std::string_view message) const + { + m_ErrorHandler.error(peek(), message); + } + + void error(Token token, std::string_view message) const + { + m_ErrorHandler.error(token, message); + } + + Token consume(Token::Type type, std::string_view message) + { + if(check(type)) { + return advance(); + } + + error(message); + throw ParseError(); + } + + bool check(Token::Type type) const + { + if(isAtEnd()) { + return false; + } + else { + return (peek().type == type); + } + } + + bool isAtEnd() const { return (peek().type == Token::Type::END_OF_FILE); } + +private: + //--------------------------------------------------------------------------- + // Members + //--------------------------------------------------------------------------- + size_t m_Current; + + const std::vector<Token> &m_Tokens; + + ErrorHandler &m_ErrorHandler; +}; + + +void synchronise(ParserState &parserState) +{ + parserState.advance(); + while(!parserState.isAtEnd()) { + if(parserState.previous().type == Token::Type::SEMICOLON) { + return; + } + + const auto nextTokenType = parserState.peek().type; + if(nextTokenType == Token::Type::FOR + || nextTokenType == Token::Type::IF + || nextTokenType == Token::Type::WHILE + || nextTokenType == Token::Type::TYPE_SPECIFIER) + { + return; + } + + parserState.advance(); + } +} + +// Forward declarations +Expression::ExpressionPtr parseCast(ParserState &parserState); +Expression::ExpressionPtr parseAssignment(ParserState &parserState); +Expression::ExpressionPtr parseExpression(ParserState &parserState); +Statement::StatementPtr parseBlockItem(ParserState &parserState); +Statement::StatementPtr parseDeclaration(ParserState &parserState); +Statement::StatementPtr parseStatement(ParserState &parserState); + +// Helper to parse binary expressions +// **THINK** this could be made variadic but it's not clear that would be an improvement +template<typename T, typename N> +Expression::ExpressionPtr parseBinary(ParserState &parserState, N nonTerminal, std::initializer_list<Token::Type> types) +{ + auto expression = nonTerminal(parserState); + while(parserState.match(types)) { + Token op = parserState.previous(); + expression = std::make_unique<T>(std::move(expression), op, nonTerminal(parserState)); + } + + return expression; +} + +std::tuple<const Type::Base *, bool> parseDeclarationSpecifiers(ParserState &parserState) +{ + // Loop through type qualifier and specifier tokens + std::set<std::string_view> typeQualifiers{}; + std::set<std::string_view> typeSpecifiers{}; + do { + // Add token lexeme to appropriate set, giving error if duplicate + if(parserState.previous().type == Token::Type::TYPE_QUALIFIER) 
+
+std::tuple<const Type::Base*, bool> parseDeclarationSpecifiers(ParserState &parserState)
+{
+    // Loop through type qualifier and specifier tokens
+    std::set<std::string_view> typeQualifiers{};
+    std::set<std::string_view> typeSpecifiers{};
+    do {
+        // Add token lexeme to appropriate set, giving error if duplicate
+        if(parserState.previous().type == Token::Type::TYPE_QUALIFIER) {
+            if(!typeQualifiers.insert(parserState.previous().lexeme).second) {
+                parserState.error(parserState.previous(), "duplicate type qualifier");
+            }
+        }
+        else {
+            if(!typeSpecifiers.insert(parserState.previous().lexeme).second) {
+                parserState.error(parserState.previous(), "duplicate type specifier");
+            }
+        }
+    } while(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER}));
+
+    // Lookup type
+    const Type::Base *type = (parserState.match({Token::Type::STAR})
+                              ? static_cast<const Type::Base*>(Type::getNumericPtrType(typeSpecifiers))
+                              : static_cast<const Type::Base*>(Type::getNumericType(typeSpecifiers)));
+    if(!type) {
+        parserState.error("Unknown type specifier");
+    }
+
+    // Determine constness
+    // **NOTE** this only works as const is the ONLY supported qualifier
+    return std::make_tuple(type, !typeQualifiers.empty());
+}
+
+Expression::ExpressionPtr parsePrimary(ParserState &parserState)
+{
+    // primary-expression ::=
+    //      identifier
+    //      constant
+    //      "(" expression ")"
+    if(parserState.match(Token::Type::FALSE)) {
+        return std::make_unique<Expression::Literal>(false);
+    }
+    else if(parserState.match(Token::Type::TRUE)) {
+        return std::make_unique<Expression::Literal>(true);
+    }
+    else if(parserState.match(Token::Type::NUMBER)) {
+        return std::make_unique<Expression::Literal>(parserState.previous().literalValue);
+    }
+    else if(parserState.match(Token::Type::IDENTIFIER)) {
+        return std::make_unique<Expression::Variable>(parserState.previous());
+    }
+    else if(parserState.match(Token::Type::LEFT_PAREN)) {
+        auto expression = parseExpression(parserState);
+
+        parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after expression");
+        return std::make_unique<Expression::Grouping>(std::move(expression));
+    }
+
+    parserState.error("Expect expression");
+    throw ParseError();
+}
+
+Expression::ExpressionPtr parsePostfix(ParserState &parserState)
+{
+    // postfix-expression ::=
+    //      primary-expression
+    //      postfix-expression "[" expression "]"
+    //      postfix-expression "(" argument-expression-list? ")"
+    //      postfix-expression "++"
+    //      postfix-expression "--"
+
+    // argument-expression-list ::=
+    //      assignment-expression
+    //      argument-expression-list "," assignment-expression
+
+    auto expression = parsePrimary(parserState);
+
+    while(true) {
+        // If this is a function call
+        if(parserState.match(Token::Type::LEFT_PAREN)) {
+            // Build list of arguments
+            // **NOTE** match rather than check so the separating commas are consumed
+            Expression::ExpressionList arguments;
+            if(!parserState.check(Token::Type::RIGHT_PAREN)) {
+                do {
+                    arguments.emplace_back(parseAssignment(parserState));
+                } while(parserState.match(Token::Type::COMMA));
+            }
+
+            Token closingParen = parserState.consume(Token::Type::RIGHT_PAREN,
+                                                     "Expect ')' after arguments.");
+
+            // Create call expression
+            expression = std::make_unique<Expression::Call>(std::move(expression),
+                                                            closingParen,
+                                                            std::move(arguments));
+        }
+        // Otherwise, if this is an array index
+        else if(parserState.match(Token::Type::LEFT_SQUARE_BRACKET)) {
+            auto index = parseExpression(parserState);
+            Token closingSquareBracket = parserState.consume(Token::Type::RIGHT_SQUARE_BRACKET,
+                                                             "Expect ']' after index.");
+
+            // **TODO** everything all the way up(?) from unary are l-value so can be used - not just variable
+            auto expressionVariable = dynamic_cast<const Expression::Variable*>(expression.get());
+            if(expressionVariable) {
+                expression = std::make_unique<Expression::ArraySubscript>(expressionVariable->getName(),
+                                                                          std::move(index));
+            }
+            else {
+                parserState.error(closingSquareBracket, "Invalid subscript target");
+            }
+        }
+        // Otherwise, if this is an increment or decrement
+        else if(parserState.match({Token::Type::PLUS_PLUS, Token::Type::MINUS_MINUS})) {
+            Token op = parserState.previous();
+
+            // **TODO** everything all the way up(?) from unary are l-value so can be used - not just variable
+            auto expressionVariable = dynamic_cast<const Expression::Variable*>(expression.get());
+            if(expressionVariable) {
+                return std::make_unique<Expression::PostfixIncDec>(expressionVariable->getName(), op);
+            }
+            else {
+                parserState.error(op, "Invalid postfix target");
+            }
+        }
+        else {
+            break;
+        }
+    }
+
+    return expression;
+}
+
+
+Expression::ExpressionPtr parseUnary(ParserState &parserState)
+{
+    // unary-expression ::=
+    //      postfix-expression
+    //      "++" unary-expression
+    //      "--" unary-expression
+    //      "&" cast-expression
+    //      "*" cast-expression
+    //      "+" cast-expression
+    //      "-" cast-expression
+    //      "~" cast-expression
+    //      "!" cast-expression
+    //      "sizeof" unary-expression       **TODO**
+    //      "sizeof" "(" type-name ")"      **TODO**
+    if(parserState.match({Token::Type::AMPERSAND, Token::Type::STAR, Token::Type::PLUS,
+                          Token::Type::MINUS, Token::Type::TILDA, Token::Type::NOT})) {
+        Token op = parserState.previous();
+        return std::make_unique<Expression::Unary>(op, parseCast(parserState));
+    }
+    else if(parserState.match({Token::Type::PLUS_PLUS, Token::Type::MINUS_MINUS})) {
+        Token op = parserState.previous();
+        auto expression = parseUnary(parserState);
+
+        // **TODO** everything all the way up(?) from unary are l-value so can be used - not just variable
+        auto expressionVariable = dynamic_cast<const Expression::Variable*>(expression.get());
+        if(expressionVariable) {
+            return std::make_unique<Expression::PrefixIncDec>(expressionVariable->getName(), op);
+        }
+        else {
+            parserState.error(op, "Invalid prefix target");
+        }
+    }
+
+    return parsePostfix(parserState);
+}
+
+Expression::ExpressionPtr parseCast(ParserState &parserState)
+{
+    // cast-expression ::=
+    //      unary-expression
+    //      "(" type-name ")" cast-expression
+
+    // If next token is a left parenthesis
+    if(parserState.match(Token::Type::LEFT_PAREN)) {
+        // If this is followed by some part of a type declarator
+        if(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER})) {
+            // Parse declaration specifiers
+            const auto [type, isConst] = parseDeclarationSpecifiers(parserState);
+
+            parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after cast type.");
+
+            return std::make_unique<Expression::Cast>(type, isConst, parseCast(parserState));
+        }
+        // Otherwise, rewind parser state so left parenthesis can be parsed again
+        // **YUCK**
+        else {
+            parserState.rewind();
+        }
+    }
+
+    return parseUnary(parserState);
+}
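+
+// Illustrative note: "(" is ambiguous between a cast and a parenthesised
+// expression, and parseCast resolves it with one token of backtracking.
+// For "(int)x" the TYPE_SPECIFIER after "(" selects the cast branch; for
+// "(x) + 1" no type token follows, so the parser rewinds and re-parses the
+// "(" as the start of a primary expression instead.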
+
+Expression::ExpressionPtr parseMultiplicative(ParserState &parserState)
+{
+    // multiplicative-expression ::=
+    //      cast-expression
+    //      multiplicative-expression "*" cast-expression
+    //      multiplicative-expression "/" cast-expression
+    //      multiplicative-expression "%" cast-expression
+    return parseBinary<Expression::Binary>(parserState, parseCast,
+                                           {Token::Type::STAR, Token::Type::SLASH, Token::Type::PERCENT});
+}
+
+Expression::ExpressionPtr parseAdditive(ParserState &parserState)
+{
+    // additive-expression ::=
+    //      multiplicative-expression
+    //      additive-expression "+" multiplicative-expression
+    //      additive-expression "-" multiplicative-expression
+    return parseBinary<Expression::Binary>(parserState, parseMultiplicative,
+                                           {Token::Type::MINUS, Token::Type::PLUS});
+}
+
+Expression::ExpressionPtr parseShift(ParserState &parserState)
+{
+    // shift-expression ::=
+    //      additive-expression
+    //      shift-expression "<<" additive-expression
+    //      shift-expression ">>" additive-expression
+    return parseBinary<Expression::Binary>(parserState, parseAdditive,
+                                           {Token::Type::SHIFT_LEFT, Token::Type::SHIFT_RIGHT});
+}
+
+Expression::ExpressionPtr parseRelational(ParserState &parserState)
+{
+    // relational-expression ::=
+    //      shift-expression
+    //      relational-expression "<" shift-expression
+    //      relational-expression ">" shift-expression
+    //      relational-expression "<=" shift-expression
+    //      relational-expression ">=" shift-expression
+    return parseBinary<Expression::Binary>(parserState, parseShift,
+                                           {Token::Type::GREATER, Token::Type::GREATER_EQUAL,
+                                            Token::Type::LESS, Token::Type::LESS_EQUAL});
+}
+
+Expression::ExpressionPtr parseEquality(ParserState &parserState)
+{
+    // equality-expression ::=
+    //      relational-expression
+    //      equality-expression "==" relational-expression
+    //      equality-expression "!=" relational-expression
+    return parseBinary<Expression::Binary>(parserState, parseRelational,
+                                           {Token::Type::NOT_EQUAL, Token::Type::EQUAL_EQUAL});
+}
+
+Expression::ExpressionPtr parseAnd(ParserState &parserState)
+{
+    // AND-expression ::=
+    //      equality-expression
+    //      AND-expression "&" equality-expression
+    return parseBinary<Expression::Binary>(parserState, parseEquality, {Token::Type::AMPERSAND});
+}
+
+Expression::ExpressionPtr parseXor(ParserState &parserState)
+{
+    // exclusive-OR-expression ::=
+    //      AND-expression
+    //      exclusive-OR-expression "^" AND-expression
+    return parseBinary<Expression::Binary>(parserState, parseAnd, {Token::Type::CARET});
+}
+
+Expression::ExpressionPtr parseOr(ParserState &parserState)
+{
+    // inclusive-OR-expression ::=
+    //      exclusive-OR-expression
+    //      inclusive-OR-expression "|" exclusive-OR-expression
+    return parseBinary<Expression::Binary>(parserState, parseXor, {Token::Type::PIPE});
+}
+
+Expression::ExpressionPtr parseLogicalAnd(ParserState &parserState)
+{
+    // logical-AND-expression ::=
+    //      inclusive-OR-expression
+    //      logical-AND-expression "&&" inclusive-OR-expression
+    // **THINK** calling parseLogicalAnd here (obviously) stack-overflows - why is this the grammar?
+    auto expression = parseOr(parserState);
+
+    while(parserState.match(Token::Type::AMPERSAND_AMPERSAND)) {
+        Token op = parserState.previous();
+        auto right = parseOr(parserState);
+        expression = std::make_unique<Expression::Logical>(std::move(expression), op, std::move(right));
+    }
+    return expression;
+}
+
+Expression::ExpressionPtr parseLogicalOr(ParserState &parserState)
+{
+    // logical-OR-expression ::=
+    //      logical-AND-expression
+    //      logical-OR-expression "||" logical-AND-expression
+    // **THINK** calling parseLogicalOr here (obviously) stack-overflows - why is this the grammar?
+    auto expression = parseLogicalAnd(parserState);
+
+    while(parserState.match(Token::Type::PIPE_PIPE)) {
+        Token op = parserState.previous();
+        auto right = parseLogicalAnd(parserState);
+        expression = std::make_unique<Expression::Logical>(std::move(expression), op, std::move(right));
+    }
+    return expression;
+}
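+
+// Illustrative note on the **THINK** comments above: the C grammar's
+// left-recursive productions (e.g. logical-AND-expression "&&"
+// inclusive-OR-expression) cannot be transcribed directly into a
+// recursive-descent parser - calling parseLogicalAnd as the first step of
+// parseLogicalAnd recurses before consuming any input, hence the stack
+// overflow. Rewriting left recursion as iteration is the standard fix:
+//
+//     A ::= B | A op B    becomes    A ::= B (op B)*
+//
+// which is exactly the while-loop shape used here and in parseBinary.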
+
+Expression::ExpressionPtr parseConditional(ParserState &parserState)
+{
+    // conditional-expression ::=
+    //      logical-OR-expression
+    //      logical-OR-expression "?" expression ":" conditional-expression
+    auto cond = parseLogicalOr(parserState);
+    if(parserState.match(Token::Type::QUESTION)) {
+        Token question = parserState.previous();
+        auto trueExpression = parseExpression(parserState);
+        parserState.consume(Token::Type::COLON, "Expect ':' in conditional expression.");
+        auto falseExpression = parseConditional(parserState);
+        return std::make_unique<Expression::Conditional>(std::move(cond), question, std::move(trueExpression),
+                                                         std::move(falseExpression));
+    }
+
+    return cond;
+}
+
+Expression::ExpressionPtr parseAssignment(ParserState &parserState)
+{
+    // assignment-expression ::=
+    //      conditional-expression
+    //      unary-expression assignment-operator assignment-expression
+    auto expression = parseConditional(parserState);
+    if(parserState.match({Token::Type::EQUAL, Token::Type::STAR_EQUAL, Token::Type::SLASH_EQUAL,
+                          Token::Type::PERCENT_EQUAL, Token::Type::PLUS_EQUAL, Token::Type::MINUS_EQUAL,
+                          Token::Type::AMPERSAND_EQUAL, Token::Type::CARET_EQUAL, Token::Type::PIPE_EQUAL,
+                          Token::Type::SHIFT_LEFT_EQUAL, Token::Type::SHIFT_RIGHT_EQUAL}))
+    {
+        Token op = parserState.previous();
+        auto value = parseAssignment(parserState);
+
+        // **TODO** everything all the way up(?) from unary are l-value so can be used - not just variable
+        auto expressionVariable = dynamic_cast<const Expression::Variable*>(expression.get());
+        if(expressionVariable) {
+            return std::make_unique<Expression::Assignment>(expressionVariable->getName(), op, std::move(value));
+        }
+        else {
+            parserState.error(op, "Invalid assignment target");
+        }
+    }
+
+    return expression;
+}
+
+Expression::ExpressionPtr parseExpression(ParserState &parserState)
+{
+    // expression ::=
+    //      assignment-expression
+    //      expression "," assignment-expression
+    return parseBinary<Expression::Binary>(parserState, parseAssignment,
+                                           {Token::Type::COMMA});
+}
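+
+// Illustrative note: parseAssignment recurses into itself for the
+// right-hand side, so assignment is right-associative - "a = b = 1" parses
+// as Assignment(a, '=', Assignment(b, '=', 1)). The comma operator, by
+// contrast, goes through parseBinary and is therefore left-associative.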
";" + auto expression = parseExpression(parserState); + + parserState.consume(Token::Type::SEMICOLON, "Expect ';' after expression"); + return std::make_unique(std::move(expression)); +} + +Statement::StatementPtr parsePrintStatement(ParserState &parserState) +{ + auto expression = parseExpression(parserState); + + parserState.consume(Token::Type::SEMICOLON, "Expect ';' after expression"); + return std::make_unique(std::move(expression)); +} + +Statement::StatementPtr parseSelectionStatement(ParserState &parserState) +{ + // selection-statement ::= + // "if" "(" expression ")" statement + // "if" "(" expression ")" statement "else" statement + // "switch" "(" expression ")" statement + const auto keyword = parserState.previous(); + parserState.consume(Token::Type::LEFT_PAREN, "Expect '(' after '" + std::string{keyword.lexeme} + "'"); + auto condition = parseExpression(parserState); + parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after '" + std::string{keyword.lexeme} + "'"); + + // If this is an if statement + if(keyword.type == Token::Type::IF) { + auto thenBranch = parseStatement(parserState); + Statement::StatementPtr elseBranch; + if(parserState.match(Token::Type::ELSE)) { + elseBranch = parseStatement(parserState); + } + + return std::make_unique(std::move(condition), + std::move(thenBranch), + std::move(elseBranch)); + } + // Otherwise (switch statement) + else { + return std::make_unique(keyword, std::move(condition), + parseStatement(parserState)); + } +} + +Statement::StatementPtr parseIterationStatement(ParserState &parserState) +{ + // iteration-statement ::= + // "while" "(" expression ")" statement + // "do" statement "while" "(" expression ")" ";" + // "for" "(" expression? ";" expression? ";" expression? ")" statement + // "for" "(" declaration expression? ";" expression? 
")" statement + + // If this is a while statement + if(parserState.previous().type == Token::Type::WHILE) { + parserState.consume(Token::Type::LEFT_PAREN, "Expect '(' after 'while'"); + auto condition = parseExpression(parserState); + parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after 'while'"); + auto body = parseStatement(parserState); + + return std::make_unique(std::move(condition), + std::move(body)); + } + // Otherwise, if this is a do statement + else if(parserState.previous().type == Token::Type::DO) { + auto body = parseStatement(parserState); + parserState.consume(Token::Type::WHILE, "Expected 'while' after 'do' statement body"); + parserState.consume(Token::Type::LEFT_PAREN, "Expect '(' after 'while'"); + auto condition = parseExpression(parserState); + parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after 'while'"); + parserState.consume(Token::Type::SEMICOLON, "Expect ';' after while"); + return std::make_unique(std::move(condition), + std::move(body)); + } + // Otherwise, it's a for statement + else { + parserState.consume(Token::Type::LEFT_PAREN, "Expect '(' after 'for'"); + + // If statement starts with a semicolon - no initialiser + Statement::StatementPtr initialiser; + if(parserState.match(Token::Type::SEMICOLON)) { + initialiser = nullptr; + } + // Otherwise, if it starts with a declaration + else if(parserState.match({Token::Type::TYPE_SPECIFIER, Token::Type::TYPE_QUALIFIER})) { + initialiser = parseDeclaration(parserState); + } + // Otherwise, must be expression (statement consumes semicolon) + else { + initialiser = parseExpressionStatement(parserState); + } + + // Parse condition + Expression::ExpressionPtr condition = nullptr; + if(!parserState.check(Token::Type::SEMICOLON)) { + condition = parseExpression(parserState); + } + parserState.consume(Token::Type::SEMICOLON, "Expect ';' after loop condition"); + + // Parse increment + Expression::ExpressionPtr increment = nullptr; + if(!parserState.check(Token::Type::RIGHT_PAREN)) { + increment = parseExpression(parserState); + } + parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after for clauses"); + + Statement::StatementPtr body = parseStatement(parserState); + + // Return for statement + // **NOTE** we could "de-sugar" into a while statement but this makes pretty-printing easier + return std::make_unique(std::move(initialiser), + std::move(condition), + std::move(increment), + std::move(body)); + } +} + +Statement::StatementPtr parseJumpStatement(ParserState &parserState) +{ + // jump-statement ::= + // "continue" ";" + // "break" ";" + // "return" expression? 
";" // **TODO** + const Token token = parserState.previous(); + if(token.type == Token::Type::CONTINUE) { + parserState.consume(Token::Type::SEMICOLON, "Expect ';' after continue"); + return std::make_unique(token); + } + else if(token.type == Token::Type::BREAK) { + parserState.consume(Token::Type::SEMICOLON, "Expect ';' after break"); + return std::make_unique(token); + } + // Otherwise (return statement) + else { + assert(false); + return nullptr; + } +} + +Statement::StatementPtr parseStatement(ParserState &parserState) +{ + // statement ::= + // labeled-statement + // compound-statement + // expression-statement + // print-statement // **TEMP** + // selection-statement + // iteration-statement + // jump-statement + if(parserState.match(Token::Type::PRINT)) { + return parsePrintStatement(parserState); + } + else if(parserState.match({Token::Type::CASE, Token::Type::DEFAULT})) { + return parseLabelledStatement(parserState); + } + else if(parserState.match({Token::Type::IF, Token::Type::SWITCH})) { + return parseSelectionStatement(parserState); + } + else if(parserState.match({Token::Type::FOR, Token::Type::WHILE, Token::Type::DO})) { + return parseIterationStatement(parserState); + } + else if(parserState.match({Token::Type::CONTINUE, Token::Type::BREAK})) { + return parseJumpStatement(parserState); + } + else if(parserState.match(Token::Type::LEFT_BRACE)) { + return parseCompoundStatement(parserState); + } + else { + return parseExpressionStatement(parserState); + } +} + +Statement::StatementPtr parseDeclaration(ParserState &parserState) +{ + // declaration ::= + // declaration-specifiers init-declarator-list? ";" + + // declaration-specifiers ::= + // declaration-specifiers? + // type-specifier declaration-specifiers? + // type-qualifier declaration-specifiers? 
+
+Statement::StatementPtr parseDeclaration(ParserState &parserState)
+{
+    // declaration ::=
+    //      declaration-specifiers init-declarator-list? ";"
+
+    // declaration-specifiers ::=
+    //      type-specifier declaration-specifiers?
+    //      type-qualifier declaration-specifiers?
+
+    // type-specifier ::=
+    //      "char"
+    //      "short"
+    //      "int"
+    //      "long"
+    //      "float"
+    //      "double"
+    //      "signed"
+    //      "unsigned"
+    //      "bool"
+    //      typedef-name    // **TODO** not sure how to address ambiguity with subsequent identifier
+
+    // type-qualifier ::=
+    //      "const"
+
+    // Parse declaration specifiers
+    const auto [type, isConst] = parseDeclarationSpecifiers(parserState);
+
+    // Read init declarator list
+    std::vector<std::tuple<Token, Expression::ExpressionPtr>> initDeclaratorList;
+    do {
+        // init-declarator-list ::=
+        //      init-declarator
+        //      init-declarator-list "," init-declarator
+
+        // init-declarator ::=
+        //      declarator
+        //      declarator "=" assignment-expression
+
+        // declarator ::=
+        //      identifier
+        Token identifier = parserState.consume(Token::Type::IDENTIFIER, "Expect variable name");
+        Expression::ExpressionPtr initialiser;
+        if(parserState.match(Token::Type::EQUAL)) {
+            initialiser = parseAssignment(parserState);
+        }
+        initDeclaratorList.emplace_back(identifier, std::move(initialiser));
+    } while(!parserState.isAtEnd() && parserState.match(Token::Type::COMMA));
+
+    parserState.consume(Token::Type::SEMICOLON, "Expect ';' after variable declaration");
+    return std::make_unique<Statement::VarDeclaration>(type, isConst, std::move(initDeclaratorList));
+}
+
+Statement::StatementPtr parseBlockItem(ParserState &parserState)
+{
+    // block-item ::=
+    //      declaration
+    //      statement
+    try {
+        if(parserState.match({Token::Type::TYPE_SPECIFIER, Token::Type::TYPE_QUALIFIER})) {
+            return parseDeclaration(parserState);
+        }
+        else {
+            return parseStatement(parserState);
+        }
+    }
+    catch(ParseError &) {
+        synchronise(parserState);
+        return nullptr;
+    }
+}
+}   // Anonymous namespace
+
+
+//---------------------------------------------------------------------------
+// MiniParse::Parser
+//---------------------------------------------------------------------------
+namespace MiniParse::Parser
+{
+Expression::ExpressionPtr parseExpression(const std::vector<Token> &tokens, ErrorHandler &errorHandler)
+{
+    ParserState parserState(tokens, errorHandler);
+
+    try {
+        return parseExpression(parserState);
+    }
+    catch(ParseError &) {
+        return nullptr;
+    }
+}
+
+Statement::StatementList parseBlockItemList(const std::vector<Token> &tokens, ErrorHandler &errorHandler)
+{
+    ParserState parserState(tokens, errorHandler);
+    Statement::StatementList statements;
+
+    while(!parserState.isAtEnd()) {
+        statements.emplace_back(parseBlockItem(parserState));
+    }
+    return statements;
+}
+}
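+
+// Usage sketch (illustrative only): scanning a source string and parsing the
+// resulting tokens. This assumes a concrete ErrorHandler implementation -
+// "StderrErrorHandler" here is hypothetical; only the ErrorHandler interface
+// in error_handler.h is part of this patch.
+//
+//     StderrErrorHandler errorHandler;
+//     const auto tokens = MiniParse::Scanner::scanSource("int x = 0; x += 2;", errorHandler);
+//     auto statements = MiniParse::Parser::parseBlockItemList(tokens, errorHandler);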
diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc
new file mode 100644
index 0000000000..0e21c638d1
--- /dev/null
+++ b/src/genn/genn/transpiler/prettyPrinter.cc
@@ -0,0 +1,272 @@
+#include "pretty_printer.h"
+
+// Standard C++ includes
+#include <sstream>
+#include <string>
+#include <variant>
+
+// Mini-parse includes
+#include "type.h"
+#include "utils.h"
+
+
+using namespace MiniParse;
+using namespace MiniParse::PrettyPrinter;
+
+//---------------------------------------------------------------------------
+// Anonymous namespace
+//---------------------------------------------------------------------------
+namespace
+{
+//---------------------------------------------------------------------------
+// Visitor
+//---------------------------------------------------------------------------
+class Visitor : public Expression::Visitor, public Statement::Visitor
+{
+public:
+    //---------------------------------------------------------------------------
+    // Public API
+    //---------------------------------------------------------------------------
+    std::string print(const Statement::StatementList &statements)
+    {
+        // Clear string stream
+        m_StringStream.str("");
+
+        for(auto &s : statements) {
+            s->accept(*this);
+            m_StringStream << std::endl;
+        }
+
+        // Return string stream contents
+        return m_StringStream.str();
+    }
+
+    //---------------------------------------------------------------------------
+    // Expression::Visitor virtuals
+    //---------------------------------------------------------------------------
+    virtual void visit(const Expression::ArraySubscript &arraySubscript) final
+    {
+        m_StringStream << arraySubscript.getPointerName().lexeme << "[";
+        arraySubscript.getIndex()->accept(*this);
+        m_StringStream << "]";
+    }
+
+    virtual void visit(const Expression::Assignment &assignement) final
+    {
+        m_StringStream << assignement.getVarName().lexeme << " " << assignement.getOperator().lexeme << " ";
+        assignement.getValue()->accept(*this);
+    }
+
+    virtual void visit(const Expression::Binary &binary) final
+    {
+        binary.getLeft()->accept(*this);
+        m_StringStream << " " << binary.getOperator().lexeme << " ";
+        binary.getRight()->accept(*this);
+    }
+
+    virtual void visit(const Expression::Call &call) final
+    {
+        call.getCallee()->accept(*this);
+        m_StringStream << "(";
+        // Print comma-separated argument list
+        bool first = true;
+        for(const auto &a : call.getArguments()) {
+            if(!first) {
+                m_StringStream << ", ";
+            }
+            a->accept(*this);
+            first = false;
+        }
+        m_StringStream << ")";
+    }
+
+    virtual void visit(const Expression::Cast &cast) final
+    {
+        m_StringStream << "(" << cast.getType()->getTypeName() << ")";
+        cast.getExpression()->accept(*this);
+    }
+
+    virtual void visit(const Expression::Conditional &conditional) final
+    {
+        conditional.getCondition()->accept(*this);
+        m_StringStream << " ? ";
+        conditional.getTrue()->accept(*this);
+        m_StringStream << " : ";
+        conditional.getFalse()->accept(*this);
+    }
+
+    virtual void visit(const Expression::Grouping &grouping) final
+    {
+        m_StringStream << "(";
+        grouping.getExpression()->accept(*this);
+        m_StringStream << ")";
+    }
+
+    virtual void visit(const Expression::Literal &literal) final
+    {
+        std::visit(
+            Utils::Overload{
+                [this](auto x) { m_StringStream << x; },
+                [this](std::monostate) { m_StringStream << "invalid"; }},
+            literal.getValue());
+    }
+
+    virtual void visit(const Expression::Logical &logical) final
+    {
+        logical.getLeft()->accept(*this);
+        m_StringStream << " " << logical.getOperator().lexeme << " ";
+        logical.getRight()->accept(*this);
+    }
+
+    virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final
+    {
+        m_StringStream << postfixIncDec.getVarName().lexeme << postfixIncDec.getOperator().lexeme;
+    }
+
+    virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final
+    {
+        m_StringStream << prefixIncDec.getOperator().lexeme << prefixIncDec.getVarName().lexeme;
+    }
+
+    virtual void visit(const Expression::Variable &variable) final
+    {
+        m_StringStream << variable.getName().lexeme;
+    }
+
+    virtual void visit(const Expression::Unary &unary) final
+    {
+        m_StringStream << unary.getOperator().lexeme;
+        unary.getRight()->accept(*this);
+    }
+
+    //---------------------------------------------------------------------------
+    // Statement::Visitor virtuals
+    //---------------------------------------------------------------------------
+    virtual void visit(const Statement::Break&) final
+    {
+        m_StringStream << "break;";
+    }
+
+    virtual void visit(const Statement::Compound &compound) final
+    {
+        m_StringStream << "{" << std::endl;
+        for(auto &s : compound.getStatements()) {
+            s->accept(*this);
+            m_StringStream << std::endl;
+        }
+        m_StringStream << "}" << std::endl;
+    }
+
+    virtual void visit(const Statement::Continue&) final
+    {
+        m_StringStream << "continue;";
+    }
+
+    virtual void visit(const Statement::Do &doStatement) final
+    {
+        m_StringStream << "do";
+        doStatement.getBody()->accept(*this);
+        m_StringStream << "while(";
+        doStatement.getCondition()->accept(*this);
+        m_StringStream << ");" << std::endl;
+    }
+
+    virtual void visit(const Statement::Expression &expression) final
+    {
+        expression.getExpression()->accept(*this);
+        m_StringStream << ";";
+    }
+
+    virtual void visit(const Statement::For &forStatement) final
+    {
+        m_StringStream << "for(";
+        if(forStatement.getInitialiser()) {
+            forStatement.getInitialiser()->accept(*this);
+        }
+        else {
+            m_StringStream << ";";
+        }
+        m_StringStream << " ";
+
+        if(forStatement.getCondition()) {
+            forStatement.getCondition()->accept(*this);
+        }
+
+        m_StringStream << "; ";
+        if(forStatement.getIncrement()) {
+            forStatement.getIncrement()->accept(*this);
+        }
+        m_StringStream << ")";
+        forStatement.getBody()->accept(*this);
+    }
+
+    virtual void visit(const Statement::If &ifStatement) final
+    {
+        m_StringStream << "if(";
+        ifStatement.getCondition()->accept(*this);
+        m_StringStream << ")" << std::endl;
+        ifStatement.getThenBranch()->accept(*this);
+        if(ifStatement.getElseBranch()) {
+            m_StringStream << "else" << std::endl;
+            ifStatement.getElseBranch()->accept(*this);
+        }
+    }
+
+    virtual void visit(const Statement::Labelled &labelled) final
+    {
+        m_StringStream << labelled.getKeyword().lexeme << " ";
+        if(labelled.getValue()) {
+            labelled.getValue()->accept(*this);
+        }
+        m_StringStream << " : ";
+        labelled.getBody()->accept(*this);
+    }
+
+    virtual void visit(const Statement::Switch &switchStatement) final
+    {
+        m_StringStream << "switch(";
+        switchStatement.getCondition()->accept(*this);
+        m_StringStream << ")" << std::endl;
+        switchStatement.getBody()->accept(*this);
+    }
+
+    virtual void visit(const Statement::VarDeclaration &varDeclaration) final
+    {
+        if(varDeclaration.isConst()) {
+            m_StringStream << "const ";
+        }
+        m_StringStream << varDeclaration.getType()->getTypeName() << " ";
+
+        // Print comma-separated init declarator list
+        bool first = true;
+        for(const auto &var : varDeclaration.getInitDeclaratorList()) {
+            if(!first) {
+                m_StringStream << ", ";
+            }
+            m_StringStream << std::get<0>(var).lexeme;
+            if(std::get<1>(var)) {
+                m_StringStream << " = ";
+                std::get<1>(var)->accept(*this);
+            }
+            first = false;
+        }
+        m_StringStream << ";";
+    }
+
+    virtual void visit(const Statement::While &whileStatement) final
+    {
+        m_StringStream << "while(";
+        whileStatement.getCondition()->accept(*this);
+        m_StringStream << ")" << std::endl;
+        whileStatement.getBody()->accept(*this);
+    }
+
+    virtual void visit(const Statement::Print &print) final
+    {
+        m_StringStream << "print ";
+        print.getExpression()->accept(*this);
+        m_StringStream << ";";
+    }
+
+private:
+    //---------------------------------------------------------------------------
+    // Members
+    //---------------------------------------------------------------------------
+    std::ostringstream m_StringStream;
+};
+}   // Anonymous namespace
+
+std::string MiniParse::PrettyPrinter::print(const Statement::StatementList &statements)
+{
+    Visitor visitor;
+    return visitor.print(statements);
+}
diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc
new file mode 100644
index 0000000000..abb9baefc6
--- /dev/null
+++ b/src/genn/genn/transpiler/scanner.cc
@@ -0,0 +1,486 @@
+#include "scanner.h"
+
+// Standard C++ includes
+#include <functional>
+#include <map>
+#include <set>
+#include <string_view>
+#include <unordered_map>
+
+// Standard C includes
+#include <cctype>
+
+// Mini-parse includes
+#include "error_handler.h"
+#include "utils.h"
+
+using namespace MiniParse;
+using namespace MiniParse::Scanner;
+
+//---------------------------------------------------------------------------
+// Anonymous namespace
+//---------------------------------------------------------------------------
+namespace
+{
+const std::unordered_map<std::string_view, Token::Type> keywords{
+    {"const", Token::Type::TYPE_QUALIFIER},
+    {"do", Token::Type::DO},
+    {"else", Token::Type::ELSE},
+    {"false", Token::Type::FALSE},
+    {"for", Token::Type::FOR},
+    {"if", Token::Type::IF},
+    {"true", Token::Type::TRUE},
+    {"while", Token::Type::WHILE},
+    {"switch", Token::Type::SWITCH},
+    {"break", Token::Type::BREAK},
+    {"continue", Token::Type::CONTINUE},
+    {"case", Token::Type::CASE},
+    {"default", Token::Type::DEFAULT},
+    {"print", Token::Type::PRINT},  // **HACK**
+    {"char", Token::Type::TYPE_SPECIFIER},
+    {"short", Token::Type::TYPE_SPECIFIER},
+    {"int", Token::Type::TYPE_SPECIFIER},
+    {"long", Token::Type::TYPE_SPECIFIER},
+    {"float", Token::Type::TYPE_SPECIFIER},
+    {"double", Token::Type::TYPE_SPECIFIER},
+    {"signed", Token::Type::TYPE_SPECIFIER},
+    {"unsigned", Token::Type::TYPE_SPECIFIER},
+    {"bool", Token::Type::TYPE_SPECIFIER}};
+//---------------------------------------------------------------------------
+const std::map<std::set<char>, std::function<Token::LiteralValue(std::string_view, int)>> integerLiteralSuffixParsers{
+    {{}, [](std::string_view input, int base) { return Utils::toCharsThrow<int32_t>(input, base); }},
+    {{'U'}, [](std::string_view input, int base) { return Utils::toCharsThrow<uint32_t>(input, base); }},
+};
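+
+// Illustrative note: scanNumber (below) looks scanned integer suffixes up in
+// this map, so "42" uses the {} entry (int32_t) and "42U" the {'U'} entry
+// (uint32_t). Because std::map::at throws for missing keys, suffixes with no
+// entry - e.g. the C "L"/"UL" forms - currently surface as an exception
+// rather than a scan error.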
+//---------------------------------------------------------------------------
+// ScanState
+//---------------------------------------------------------------------------
+//! Class encapsulating logic to navigate through source characters
+class ScanState
+{
+public:
+    ScanState(std::string_view source, ErrorHandler &errorHandler)
+    :   m_Start(0), m_Current(0), m_Line(1), m_Source(source), m_ErrorHandler(errorHandler)
+    {}
+
+    //---------------------------------------------------------------------------
+    // Public API
+    //---------------------------------------------------------------------------
+    char advance()
+    {
+        m_Current++;
+        return m_Source.at(m_Current - 1);
+    }
+
+    bool match(char expected)
+    {
+        if(isAtEnd()) {
+            return false;
+        }
+        if(m_Source.at(m_Current) != expected) {
+            return false;
+        }
+
+        m_Current++;
+        return true;
+    }
+
+    void resetLexeme()
+    {
+        m_Start = m_Current;
+    }
+
+    char peek() const
+    {
+        if(isAtEnd()) {
+            return '\0';
+        }
+        return m_Source.at(m_Current);
+    }
+
+    char peekNext() const
+    {
+        if((m_Current + 1) >= m_Source.length()) {
+            return '\0';
+        }
+        else {
+            return m_Source.at(m_Current + 1);
+        }
+    }
+
+    std::string_view getLexeme() const
+    {
+        return m_Source.substr(m_Start, m_Current - m_Start);
+    }
+
+    size_t getLine() const { return m_Line; }
+
+    bool isAtEnd() const { return m_Current >= m_Source.length(); }
+
+    void nextLine() { m_Line++; }
+
+    void error(std::string_view message)
+    {
+        m_ErrorHandler.error(getLine(), message);
+    }
+
+private:
+    //---------------------------------------------------------------------------
+    // Members
+    //---------------------------------------------------------------------------
+    size_t m_Start;
+    size_t m_Current;
+    size_t m_Line;
+
+    const std::string_view m_Source;
+
+    ErrorHandler &m_ErrorHandler;
+};
+
+bool isodigit(char c)
+{
+    return (c >= '0' && c <= '7');
+}
+
+//---------------------------------------------------------------------------
+void emplaceToken(std::vector<Token> &tokens, Token::Type type, const ScanState &scanState, Token::LiteralValue literalValue = Token::LiteralValue())
+{
+    tokens.emplace_back(type, scanState.getLexeme(), scanState.getLine(), literalValue);
+}
+//---------------------------------------------------------------------------
+std::set<char> scanIntegerSuffix(ScanState &scanState)
+{
+    // Read suffix
+    std::set<char> suffix;
+    while(std::toupper(scanState.peek()) == 'U' || std::toupper(scanState.peek()) == 'L') {
+        suffix.insert(std::toupper(scanState.advance()));
+    }
+    return suffix;
+}
+//---------------------------------------------------------------------------
+void scanNumber(char c, ScanState &scanState, std::vector<Token> &tokens)
+{
+    // If this is a hexadecimal literal
+    if(c == '0' && (scanState.match('x') || scanState.match('X'))) {
+        // Read hexadecimal digits
+        while(std::isxdigit(scanState.peek())) {
+            scanState.advance();
+        }
+
+        // Read decimal place
+        const bool isFloat = scanState.match('.');
+
+        // Read hexadecimal digits
+        while(std::isxdigit(scanState.peek())) {
+            scanState.advance();
+        }
+
+        // If number is float
+        if(isFloat) {
+            // Check there's an exponent as these are REQUIRED for floating point literals
+            if(scanState.peek() != 'p') {
+                scanState.error("Hexadecimal floating point literal missing exponent.");
+            }
+            else {
+                // Read p
+                scanState.advance();
+
+                // Read sign
+                if(scanState.peek() == '-' || scanState.peek() == '+') {
+                    scanState.advance();
+                }
+
+                // Read DECIMAL digits
+                while(std::isdigit(scanState.peek())) {
+                    scanState.advance();
+                }
+
+                // If literal has floating point suffix
+                if(std::tolower(scanState.peek()) == 'f') {
+                    // Add single-precision token
+                    // **NOTE** skip 0x prefix
+                    emplaceToken(tokens, Token::Type::NUMBER, scanState,
+                                 Utils::toCharsThrow<float>(scanState.getLexeme().substr(2), 16));
+
+                    // Advance
+                    // **NOTE** we do this AFTER parsing float as std::to_chars doesn't deal with suffixes
+                    scanState.advance();
+                }
+                // Otherwise, add double-precision token
+                // **NOTE** skip 0x prefix
+                else {
+                    emplaceToken(tokens, Token::Type::NUMBER, scanState,
+                                 Utils::toCharsThrow<double>(scanState.getLexeme().substr(2), 16));
+                }
+            }
+        }
+        // Otherwise, number is hexadecimal integer
+        else {
+            // Add integer token
+            // **NOTE** skip 0x prefix
+            const auto suffix = scanIntegerSuffix(scanState);
+            emplaceToken(tokens, Token::Type::NUMBER, scanState,
+                         integerLiteralSuffixParsers.at(suffix)(scanState.getLexeme().substr(2), 16));
+        }
+    }
+    // Otherwise, if this is an octal integer
+    else if(c == '0' && isodigit(scanState.peek())) {
+        scanState.error("Octal literals unsupported.");
+    }
+    // Otherwise, it's decimal
+    else {
+        // Read digits
+        while(std::isdigit(scanState.peek())) {
+            scanState.advance();
+        }
+
+        // Read decimal place
+        const bool isFloat = scanState.match('.');
+
+        // Read digits
+        while(std::isdigit(scanState.peek())) {
+            scanState.advance();
+        }
+
+        // If it's float
+        if(isFloat) {
+            // If there's an exponent
+            if(scanState.match('e')) {
+                // Read sign
+                if(scanState.peek() == '-' || scanState.peek() == '+') {
+                    scanState.advance();
+                }
+
+                // Read digits
+                while(std::isdigit(scanState.peek())) {
+                    scanState.advance();
+                }
+            }
+
+            // If literal has floating point suffix
+            if(std::tolower(scanState.peek()) == 'f') {
+                // Add single-precision token
+                emplaceToken(tokens, Token::Type::NUMBER, scanState,
+                             Utils::toCharsThrow<float>(scanState.getLexeme()));
+
+                // Advance
+                // **NOTE** we do this AFTER parsing float as std::to_chars doesn't deal with suffixes
+                scanState.advance();
+            }
+            // Otherwise, add double-precision token
+            else {
+                emplaceToken(tokens, Token::Type::NUMBER, scanState,
+                             Utils::toCharsThrow<double>(scanState.getLexeme()));
+            }
+        }
+        // Otherwise, number is integer
+        else {
+            // Add integer token
+            const auto suffix = scanIntegerSuffix(scanState);
+            emplaceToken(tokens, Token::Type::NUMBER, scanState,
+                         integerLiteralSuffixParsers.at(suffix)(scanState.getLexeme(), 10));
+        }
+    }
+}
+//---------------------------------------------------------------------------
+void scanIdentifier(ScanState &scanState, std::vector<Token> &tokens)
+{
+    // Read subsequent alphanumeric characters and underscores
+    while(std::isalnum(scanState.peek()) || scanState.peek() == '_') {
+        scanState.advance();
+    }
+
+    // If identifier is a keyword, add appropriate token
+    const auto k = keywords.find(scanState.getLexeme());
+    if(k != keywords.cend()) {
+        emplaceToken(tokens, k->second, scanState);
+    }
+    // Otherwise, add identifier token
+    else {
+        emplaceToken(tokens, Token::Type::IDENTIFIER, scanState);
+    }
+}
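+
+// Illustrative note: scanToken (below) resolves multi-character operators
+// greedily via nested ScanState::match calls - for "<" it first tries "<=",
+// then "<<=", then "<<", only falling back to plain "<" - so the longest
+// valid operator always wins (maximal munch).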
+//---------------------------------------------------------------------------
+void scanToken(ScanState &scanState, std::vector<Token> &tokens)
+{
+    using namespace MiniParse;
+
+    char c = scanState.advance();
+    switch(c) {
+        // Single character tokens
+        case '(': emplaceToken(tokens, Token::Type::LEFT_PAREN, scanState); break;
+        case ')': emplaceToken(tokens, Token::Type::RIGHT_PAREN, scanState); break;
+        case '{': emplaceToken(tokens, Token::Type::LEFT_BRACE, scanState); break;
+        case '}': emplaceToken(tokens, Token::Type::RIGHT_BRACE, scanState); break;
+        case '[': emplaceToken(tokens, Token::Type::LEFT_SQUARE_BRACKET, scanState); break;
+        case ']': emplaceToken(tokens, Token::Type::RIGHT_SQUARE_BRACKET, scanState); break;
+        case ',': emplaceToken(tokens, Token::Type::COMMA, scanState); break;
+        case '.': emplaceToken(tokens, Token::Type::DOT, scanState); break;
+        case ':': emplaceToken(tokens, Token::Type::COLON, scanState); break;
+        case ';': emplaceToken(tokens, Token::Type::SEMICOLON, scanState); break;
+        case '~': emplaceToken(tokens, Token::Type::TILDA, scanState); break;
+        case '?': emplaceToken(tokens, Token::Type::QUESTION, scanState); break;
+
+        // Operators
+        case '!': emplaceToken(tokens, scanState.match('=') ? Token::Type::NOT_EQUAL : Token::Type::NOT, scanState); break;
+        case '=': emplaceToken(tokens, scanState.match('=') ? Token::Type::EQUAL_EQUAL : Token::Type::EQUAL, scanState); break;
+
+        // Assignment operators
+        case '*': emplaceToken(tokens, scanState.match('=') ? Token::Type::STAR_EQUAL : Token::Type::STAR, scanState); break;
+        //case '/': emplaceToken(tokens, scanState.match('=') ? Token::Type::SLASH_EQUAL : Token::Type::SLASH, scanState); break;
+        case '%': emplaceToken(tokens, scanState.match('=') ? Token::Type::PERCENT_EQUAL : Token::Type::PERCENT, scanState); break;
+        case '^': emplaceToken(tokens, scanState.match('=') ? Token::Type::CARET_EQUAL : Token::Type::CARET, scanState); break;
+
+        case '<':
+        {
+            if(scanState.match('=')) {
+                emplaceToken(tokens, Token::Type::LESS_EQUAL, scanState);
+            }
+            else if(scanState.match('<')) {
+                if(scanState.match('=')) {
+                    emplaceToken(tokens, Token::Type::SHIFT_LEFT_EQUAL, scanState);
+                }
+                else {
+                    emplaceToken(tokens, Token::Type::SHIFT_LEFT, scanState);
+                }
+            }
+            else {
+                emplaceToken(tokens, Token::Type::LESS, scanState);
+            }
+            break;
+        }
+
+        case '>':
+        {
+            if(scanState.match('=')) {
+                emplaceToken(tokens, Token::Type::GREATER_EQUAL, scanState);
+            }
+            else if(scanState.match('>')) {
+                if(scanState.match('=')) {
+                    emplaceToken(tokens, Token::Type::SHIFT_RIGHT_EQUAL, scanState);
+                }
+                else {
+                    emplaceToken(tokens, Token::Type::SHIFT_RIGHT, scanState);
+                }
+            }
+            else {
+                emplaceToken(tokens, Token::Type::GREATER, scanState);
+            }
+            break;
+        }
+
+        case '+':
+        {
+            if(scanState.match('=')) {
+                emplaceToken(tokens, Token::Type::PLUS_EQUAL, scanState);
+            }
+            else if(scanState.match('+')) {
+                emplaceToken(tokens, Token::Type::PLUS_PLUS, scanState);
+            }
+            else {
+                emplaceToken(tokens, Token::Type::PLUS, scanState);
+            }
+            break;
+        }
+
+        case '-':
+        {
+            if(scanState.match('=')) {
+                emplaceToken(tokens, Token::Type::MINUS_EQUAL, scanState);
+            }
+            else if(scanState.match('-')) {
+                emplaceToken(tokens, Token::Type::MINUS_MINUS, scanState);
+            }
+            else {
+                emplaceToken(tokens, Token::Type::MINUS, scanState);
+            }
+            break;
+        }
+
+        case '&':
+        {
+            if(scanState.match('=')) {
+                emplaceToken(tokens, Token::Type::AMPERSAND_EQUAL, scanState);
+            }
+            else if(scanState.match('&')) {
+                emplaceToken(tokens, Token::Type::AMPERSAND_AMPERSAND, scanState);
+            }
+            else {
+                emplaceToken(tokens, Token::Type::AMPERSAND, scanState);
+            }
+            break;
+        }
+
+        case '|':
+        {
+            if(scanState.match('=')) {
+                emplaceToken(tokens, Token::Type::PIPE_EQUAL, scanState);
+            }
+            else if(scanState.match('|')) {
+                emplaceToken(tokens, Token::Type::PIPE_PIPE, scanState);
+            }
+            else {
+                emplaceToken(tokens, Token::Type::PIPE, scanState);
+            }
+            break;
+        }
+
+        case '/':
+        {
+            // Line comment
+            if(scanState.match('/')) {
+                while(scanState.peek() != '\n' && !scanState.isAtEnd()) {
+                    scanState.advance();
+                }
+            }
+            // **NOTE** divide-assign is handled here rather than with the other
+            // assignment operators as '/' also introduces comments
+            else if(scanState.match('=')) {
+                emplaceToken(tokens, Token::Type::SLASH_EQUAL, scanState);
+            }
+            else {
+                emplaceToken(tokens, Token::Type::SLASH, scanState);
+            }
+            break;
+        }
+
+        // Whitespace
+        case ' ':
+        case '\r':
+        case '\t':
+            break;
+
+        // New line
+        case '\n': scanState.nextLine(); break;
+
+        default:
+        {
+            // If we have a digit or a period, scan number
+            if(std::isdigit(c) || c == '.') {
+                scanNumber(c, scanState, tokens);
+            }
+            // Otherwise, scan identifier
+            else if(std::isalpha(c) || c == '_') {
+                scanIdentifier(scanState, tokens);
+            }
+            else {
+                scanState.error("Unexpected character.");
+            }
+        }
+    }
+}
+}   // Anonymous namespace
+
+//---------------------------------------------------------------------------
+// MiniParse::Scanner
+//---------------------------------------------------------------------------
+namespace MiniParse::Scanner
+{
+std::vector<Token> scanSource(const std::string_view &source, ErrorHandler &errorHandler)
+{
+    std::vector<Token> tokens;
+
+    ScanState scanState(source, errorHandler);
+
+    // Scan tokens
+    while(!scanState.isAtEnd()) {
+        scanState.resetLexeme();
+        scanToken(scanState, tokens);
+    }
+
+    emplaceToken(tokens, Token::Type::END_OF_FILE, scanState);
+    return tokens;
+}
+}
diff --git a/src/genn/genn/transpiler/statement.cc b/src/genn/genn/transpiler/statement.cc
new file mode 100644
index 0000000000..eb142d178d
--- /dev/null
+++ b/src/genn/genn/transpiler/statement.cc
@@ -0,0 +1,21 @@
+#include "statement.h"
+
+#define IMPLEMENT_ACCEPT(CLASS_NAME)                                        \
+    void MiniParse::Statement::CLASS_NAME::accept(Visitor &visitor) const   \
+    {                                                                       \
+        visitor.visit(*this);                                               \
+    }
+
+// Implement accept methods
+IMPLEMENT_ACCEPT(Break)
+IMPLEMENT_ACCEPT(Compound)
+IMPLEMENT_ACCEPT(Continue)
+IMPLEMENT_ACCEPT(Do)
+IMPLEMENT_ACCEPT(Expression)
+IMPLEMENT_ACCEPT(For)
+IMPLEMENT_ACCEPT(If)
+IMPLEMENT_ACCEPT(Labelled)
+IMPLEMENT_ACCEPT(Switch)
+IMPLEMENT_ACCEPT(VarDeclaration)
+IMPLEMENT_ACCEPT(While)
+IMPLEMENT_ACCEPT(Print)
\ No newline at end of file
diff --git a/src/genn/genn/transpiler/type.cc b/src/genn/genn/transpiler/type.cc
new file mode 100644
index 0000000000..f8423b9ba2
--- /dev/null
+++ b/src/genn/genn/transpiler/type.cc
@@ -0,0 +1,139 @@
+#include "type.h"
+
+// Standard C++ includes
+#include <map>
+#include <unordered_map>
+
+// Anonymous namespace
+namespace
+{
+const std::map<std::set<std::string_view>, const Type::NumericBase*> numericTypes{
+    {{"char"}, Type::Int8::getInstance()},
+
+    {{"unsigned", "char"}, Type::Uint8::getInstance()},
+
+    {{"short"}, Type::Int16::getInstance()},
+    {{"short", "int"}, Type::Int16::getInstance()},
+    {{"signed", "short"}, Type::Int16::getInstance()},
+    {{"signed", "short", "int"}, Type::Int16::getInstance()},
+
+    {{"unsigned", "short"}, Type::Uint16::getInstance()},
+    {{"unsigned", "short", "int"}, Type::Uint16::getInstance()},
+
+    {{"int"}, Type::Int32::getInstance()},
+    {{"signed"}, Type::Int32::getInstance()},
+    {{"signed", "int"}, Type::Int32::getInstance()},
+
+    {{"unsigned"}, Type::Uint32::getInstance()},
+    {{"unsigned", "int"}, Type::Uint32::getInstance()},
+
+    {{"float"}, Type::Float::getInstance()},
+    {{"double"}, Type::Double::getInstance()},
+};
+//----------------------------------------------------------------------------
+// Mapping of signed integer types to their unsigned equivalents
+const std::unordered_map<const Type::NumericBase*, const Type::NumericBase*> unsignedType{
+    {Type::Int8::getInstance(), Type::Uint8::getInstance()},
+    {Type::Int16::getInstance(), Type::Uint16::getInstance()},
+    {Type::Int32::getInstance(), Type::Uint32::getInstance()}
+};
+}   // Anonymous namespace
+
+//----------------------------------------------------------------------------
+// Type
+//----------------------------------------------------------------------------
+namespace Type
+{
+// Implement numeric types
+IMPLEMENT_NUMERIC_TYPE(Bool);
+IMPLEMENT_NUMERIC_TYPE(Int8);
+IMPLEMENT_NUMERIC_TYPE(Int16);
+IMPLEMENT_NUMERIC_TYPE(Int32);
+IMPLEMENT_NUMERIC_TYPE(Uint8);
+IMPLEMENT_NUMERIC_TYPE(Uint16);
+IMPLEMENT_NUMERIC_TYPE(Uint32);
+IMPLEMENT_NUMERIC_TYPE(Float);
+IMPLEMENT_NUMERIC_TYPE(Double);
+
+// Implement foreign function types
+IMPLEMENT_TYPE(Exp);
+IMPLEMENT_TYPE(Sqrt);
+
+//----------------------------------------------------------------------------
+// Free functions
+//----------------------------------------------------------------------------
+const NumericBase *getNumericType(const std::set<std::string_view> &typeSpecifiers)
+{
+    const auto type = numericTypes.find(typeSpecifiers);
+    return (type == numericTypes.cend()) ? nullptr : type->second;
+}
+//----------------------------------------------------------------------------
+const NumericPtrBase *getNumericPtrType(const std::set<std::string_view> &typeSpecifiers)
+{
+    const auto type = numericTypes.find(typeSpecifiers);
+    return (type == numericTypes.cend()) ? nullptr : type->second->getPointerType();
+}
+//----------------------------------------------------------------------------
+const NumericBase *getPromotedType(const NumericBase *type)
+{
+    // If a small integer type is used in an expression, it is implicitly converted to int which is always signed.
+    // This is known as the integer promotions or the integer promotion rule
+    // **NOTE** this is true because in our type system unsigned short is uint16 which can be represented in int32
+    if(type->getRank() < Int32::getInstance()->getRank()) {
+        return Int32::getInstance();
+    }
+    else {
+        return type;
+    }
+}
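+
+// Illustrative examples of the conversion rules implemented here and in
+// getCommonType below (following the usual C arithmetic conversions):
+//
+//     getPromotedType(Int16)        -> Int32   (integer promotion)
+//     getCommonType(Int32, Float)   -> Float   (floating-point wins)
+//     getCommonType(Int32, Uint32)  -> Uint32  (same rank, unsigned wins)
+//     getCommonType(Uint16, Int32)  -> Int32   (int32 can hold all uint16)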
+//----------------------------------------------------------------------------
+const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b)
+{
+    // If either type is double, common type is double
+    const auto aTypeHash = a->getTypeHash();
+    const auto bTypeHash = b->getTypeHash();
+    if(aTypeHash == Double::getInstance()->getTypeHash() || bTypeHash == Double::getInstance()->getTypeHash()) {
+        return Double::getInstance();
+    }
+    // Otherwise, if either type is float, common type is float
+    if(aTypeHash == Float::getInstance()->getTypeHash() || bTypeHash == Float::getInstance()->getTypeHash()) {
+        return Float::getInstance();
+    }
+    // Otherwise, must be an integer type
+    else {
+        // Promote both types
+        const auto *aPromoted = getPromotedType(a);
+        const auto *bPromoted = getPromotedType(b);
+
+        // If both promoted operands have the same type, then no further conversion is needed.
+        if(aPromoted->getTypeHash() == bPromoted->getTypeHash()) {
+            return aPromoted;
+        }
+        // Otherwise, if both promoted operands have signed integer types or both have unsigned integer types,
+        // the operand with the type of lesser integer conversion rank is converted to the type of the operand with greater rank.
+        else if(aPromoted->isSigned() == bPromoted->isSigned()) {
+            return (aPromoted->getRank() > bPromoted->getRank()) ? aPromoted : bPromoted;
+        }
+        // Otherwise, if signedness of promoted operands differ
+        else {
+            const auto *signedOp = aPromoted->isSigned() ? aPromoted : bPromoted;
+            const auto *unsignedOp = aPromoted->isSigned() ? bPromoted : aPromoted;
+
+            // If the operand that has unsigned integer type has rank greater or equal to the rank of the type of the other operand,
+            // then the operand with signed integer type is converted to the type of the operand with unsigned integer type.
+            if(unsignedOp->getRank() >= signedOp->getRank()) {
+                return unsignedOp;
+            }
+            // Otherwise, if the type of the operand with signed integer type can represent all of the values of the type of the operand with unsigned integer type,
+            // then the operand with unsigned integer type is converted to the type of the operand with signed integer type.
+            else if(signedOp->getMin() <= unsignedOp->getMin() && signedOp->getMax() >= unsignedOp->getMax()) {
+                return signedOp;
+            }
+            // Otherwise, both operands are converted to the unsigned integer type corresponding to the type of the operand with signed integer type.
+            else {
+                return unsignedType.at(signedOp);
+            }
+        }
+    }
+}
+}
\ No newline at end of file
diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
new file mode 100644
index 0000000000..db4a24626d
--- /dev/null
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -0,0 +1,673 @@
+#include "type_checker.h"
+
+// Standard C++ includes
+#include <string>
+
+// Standard C includes
+#include <cassert>
+
+// GeNN includes
+#include "type.h"
+
+// Mini-parse includes
+#include "error_handler.h"
+#include "expression.h"
+#include "utils.h"
+
+using namespace MiniParse;
+using namespace MiniParse::TypeChecker;
+
+//---------------------------------------------------------------------------
+// Anonymous namespace
+//---------------------------------------------------------------------------
+namespace
+{
+//---------------------------------------------------------------------------
+// TypeCheckError
+//---------------------------------------------------------------------------
+class TypeCheckError
+{
+};
+
+//---------------------------------------------------------------------------
+// Visitor
+//---------------------------------------------------------------------------
+class Visitor : public Expression::Visitor, public Statement::Visitor
+{
+public:
+    Visitor(ErrorHandler &errorHandler)
+    :   m_Environment(nullptr), m_Type(nullptr), m_Const(false),
+        m_ErrorHandler(errorHandler), m_InLoop(false), m_InSwitch(false)
+    {
+    }
+
+    //---------------------------------------------------------------------------
+    // Public API
+    //---------------------------------------------------------------------------
+    void typeCheck(const Statement::StatementList &statements, Environment &environment)
+    {
+        Environment *previous = m_Environment;
+        m_Environment = &environment;
+        for (auto &s : statements) {
+            s->accept(*this);
+        }
+        m_Environment = previous;
+    }
+
+    //---------------------------------------------------------------------------
+    // Expression::Visitor virtuals
+    //---------------------------------------------------------------------------
+    virtual void visit(const Expression::ArraySubscript &arraySubscript) final
+    {
+        // Get pointer type
+        auto pointerType = dynamic_cast<const Type::NumericPtrBase*>(
+            std::get<0>(m_Environment->getType(arraySubscript.getPointerName(), m_ErrorHandler)));
+
+        // If pointer is indeed a pointer
+        if (pointerType) {
+            // Evaluate index type
+            auto indexType = evaluateType(arraySubscript.getIndex().get());
+            auto indexNumericType = dynamic_cast<const Type::NumericBase*>(indexType);
+            if (!indexNumericType || !indexNumericType->isIntegral()) {
+                m_ErrorHandler.error(arraySubscript.getPointerName(),
+                                     "Invalid subscript index type '" + indexType->getTypeName() + "'");
+                throw TypeCheckError();
+            }
+
+            // Use value type of array
+            m_Type = pointerType->getValueType();
+            m_Const = false;
+        }
+        // Otherwise
+        else {
+            m_ErrorHandler.error(arraySubscript.getPointerName(), "Subscripted object is not a pointer");
+            throw TypeCheckError();
+        }
+    }
+
+    virtual void visit(const Expression::Assignment &assignment) final
+    {
+        const auto [rhsType, rhsConst] = evaluateTypeConst(assignment.getValue());
+        m_Type = m_Environment->assign(assignment.getVarName(), rhsType, rhsConst,
+                                       assignment.getOperator().type, m_ErrorHandler);
+        m_Const = false;
+    }
+
+    virtual void visit(const Expression::Binary &binary) final
+    {
+        const auto opType = binary.getOperator().type;
+        const auto [rightType, rightConst] = evaluateTypeConst(binary.getRight());
+        if (opType == Token::Type::COMMA) {
+            m_Type = rightType;
+            m_Const = rightConst;
+        }
+        else {
+            const auto [leftType, leftConst] = evaluateTypeConst(binary.getLeft());
+            auto leftNumericType = dynamic_cast<const Type::NumericBase*>(leftType);
+            auto rightNumericType = dynamic_cast<const Type::NumericBase*>(rightType);
+            auto leftNumericPtrType = dynamic_cast<const Type::NumericPtrBase*>(leftType);
+            auto rightNumericPtrType = dynamic_cast<const Type::NumericPtrBase*>(rightType);
+
+            // If we're subtracting two pointers
+            if (leftNumericPtrType && rightNumericPtrType && opType == Token::Type::MINUS) {
+                // Check pointers are compatible
+                if (leftNumericPtrType->getTypeHash() != rightNumericPtrType->getTypeHash()) {
+                    m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getTypeName() + "' and '" + rightType->getTypeName() + "'");
+                    throw TypeCheckError();
+                }
+
+                // **TODO** should be std::ptrdiff/Int64
+                m_Type = Type::Int32::getInstance();
+                m_Const = false;
+            }
+            // Otherwise, if we're adding to or subtracting from pointers
+            else if (leftNumericPtrType && rightNumericType && (opType == Token::Type::PLUS || opType == Token::Type::MINUS))   // P + n or P - n
+            {
+                // Check that numeric operand is integer
+                if (!rightNumericType->isIntegral()) {
+                    m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getTypeName() + "' and '" + rightType->getTypeName() + "'");
+                    throw TypeCheckError();
+                }
+
+                // Use pointer type
+                m_Type = leftNumericPtrType;
+                m_Const = leftConst;
+            }
+            // Otherwise, if we're adding a number to a pointer
+            else if (leftNumericType && rightNumericPtrType && opType == Token::Type::PLUS)     // n + P
+            {
+                // Check that numeric operand is integer
+                if (!leftNumericType->isIntegral()) {
+                    m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getTypeName() + "' and '" + rightType->getTypeName() + "'");
+                    throw TypeCheckError();
+                }
+
+                // Use pointer type
+                m_Type = rightNumericPtrType;
+                m_Const = rightConst;
+            }
+            // Otherwise, if both operands are numeric
+            else if (leftNumericType && rightNumericType) {
+                // If operator requires integer operands
+                if (opType == Token::Type::PERCENT || opType == Token::Type::SHIFT_LEFT
+                    || opType == Token::Type::SHIFT_RIGHT || opType == Token::Type::CARET
+                    || opType == Token::Type::AMPERSAND || opType == Token::Type::PIPE)
+                {
+                    // Check that operands are integers
+                    if (!leftNumericType->isIntegral() || !rightNumericType->isIntegral()) {
+                        m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getTypeName() + "' and '" + rightType->getTypeName() + "'");
+                        throw TypeCheckError();
+                    }
+
+                    // If operator is a shift, promote left type
+                    if (opType == Token::Type::SHIFT_LEFT || opType == Token::Type::SHIFT_RIGHT) {
+                        m_Type = Type::getPromotedType(leftNumericType);
+                        m_Const = false;
+                    }
+                    // Otherwise, take common type
+                    else {
+                        m_Type = Type::getCommonType(leftNumericType, rightNumericType);
+                        m_Const = false;
+                    }
+                }
+                // Otherwise, any numeric type will do, take common type
+                else {
+                    m_Type = Type::getCommonType(leftNumericType, rightNumericType);
+                    m_Const = false;
+                }
+            }
+            else {
+                m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getTypeName() + "' and '" + rightType->getTypeName() + "'");
+                throw TypeCheckError();
+            }
+        }
+    }
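+
+    // Illustrative note: the pointer cases above mirror C's rules - for a
+    // NumericPtr P and integral n, "P - P" yields an integer (Int32 here,
+    // pending a proper ptrdiff type per the **TODO**), "P + n", "P - n" and
+    // "n + P" all yield P's own type, and any other operand mix is rejected.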
+
+    virtual void visit(const Expression::Call &call) final
+    {
+        // Evaluate callee type
+        auto calleeType = evaluateType(call.getCallee());
+        auto calleeFunctionType = dynamic_cast<const Type::FunctionBase*>(calleeType);
+
+        // If callee's a function
+        if (calleeFunctionType) {
+            // If argument count doesn't match
+            const auto argTypes = calleeFunctionType->getArgumentTypes();
+            if (call.getArguments().size() < argTypes.size()) {
+                m_ErrorHandler.error(call.getClosingParen(), "Too few arguments to function");
+                throw TypeCheckError();
+            }
+            else if (call.getArguments().size() > argTypes.size()) {
+                m_ErrorHandler.error(call.getClosingParen(), "Too many arguments to function");
+                throw TypeCheckError();
+            }
+            else {
+                // Loop through arguments
+                // **TODO** check
+                /*for(size_t i = 0; i < argTypes.size(); i++) {
+                    // Evaluate argument type
+                    auto callArgType = evaluateType(call.getArguments().at(i).get());
+                }*/
+                // Type is return type of function
+                m_Type = calleeFunctionType->getReturnType();
+                m_Const = false;
+            }
+        }
+        // Otherwise
+        else {
+            m_ErrorHandler.error(call.getClosingParen(), "Called object is not a function");
+            throw TypeCheckError();
+        }
+    }
+
+    virtual void visit(const Expression::Cast &cast) final
+    {
+        // **TODO** any numeric can be cast to any numeric and any pointer to pointer but no intermixing
+        // **TODO** const cannot be removed like this
+        m_Type = cast.getType();
+        m_Const = cast.isConst();
+    }
+
+    virtual void visit(const Expression::Conditional &conditional) final
+    {
+        const auto [trueType, trueConst] = evaluateTypeConst(conditional.getTrue());
+        const auto [falseType, falseConst] = evaluateTypeConst(conditional.getFalse());
+        auto trueNumericType = dynamic_cast<const Type::NumericBase*>(trueType);
+        auto falseNumericType = dynamic_cast<const Type::NumericBase*>(falseType);
+        if (trueNumericType && falseNumericType) {
+            m_Type = Type::getCommonType(trueNumericType, falseNumericType);
+            m_Const = trueConst || falseConst;
+        }
+        else {
+            m_ErrorHandler.error(conditional.getQuestion(),
+                                 "Invalid operand types '" + trueType->getTypeName() + "' and '" + falseType->getTypeName() + "' to conditional");
+            throw TypeCheckError();
+        }
+    }
+
+    virtual void visit(const Expression::Grouping &grouping) final
+    {
+        std::tie(m_Type, m_Const) = evaluateTypeConst(grouping.getExpression());
+    }
+
+    virtual void visit(const Expression::Literal &literal) final
+    {
+        m_Type = std::visit(
+            MiniParse::Utils::Overload{
+                [](auto v)->const Type::NumericBase *{ return Type::TypeTraits<decltype(v)>::NumericType::getInstance(); },
+                [](std::monostate)->const Type::NumericBase *{ return nullptr; }},
+            literal.getValue());
+        m_Const = false;
+    }
+
+    virtual void visit(const Expression::Logical &logical) final
+    {
+        logical.getLeft()->accept(*this);
+        logical.getRight()->accept(*this);
+        m_Type = Type::Int32::getInstance();
+        m_Const = false;
+    }
+
+    virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final
+    {
+        m_Type = m_Environment->incDec(postfixIncDec.getVarName(),
+                                       postfixIncDec.getOperator(), m_ErrorHandler);
+        m_Const = false;
+    }
+
+    virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final
+    {
+        m_Type = m_Environment->incDec(prefixIncDec.getVarName(),
+                                       prefixIncDec.getOperator(), m_ErrorHandler);
+        m_Const = false;
+    }
+
+    virtual void visit(const Expression::Variable &variable) final
+    {
+        std::tie(m_Type, m_Const) = m_Environment->getType(variable.getName(), m_ErrorHandler);
+    }
+
+            // **THINK**
+            m_Const = false;
+        }
+        // Otherwise
+        else {
+            auto rightNumericType = dynamic_cast<const Type::NumericBase *>(rightType);
+            if (rightNumericType) {
+                // If operator is arithmetic, return promoted type
+                if (unary.getOperator().type == Token::Type::PLUS || unary.getOperator().type == Token::Type::MINUS) {
+                    m_Type = Type::getPromotedType(rightNumericType);
+                    m_Const = false;
+                }
+                // Otherwise, if operator is bitwise
+                else if (unary.getOperator().type == Token::Type::TILDA) {
+                    // If type is integer, return promoted type
+                    if (rightNumericType->isIntegral()) {
+                        m_Type = Type::getPromotedType(rightNumericType);
+                        m_Const = false;
+                    }
+                    else {
+                        m_ErrorHandler.error(unary.getOperator(),
+                                             "Invalid operand type '" + rightType->getTypeName() + "'");
+                        throw TypeCheckError();
+                    }
+                }
+                // Otherwise, if operator is logical
+                else if (unary.getOperator().type == Token::Type::NOT) {
+                    m_Type = Type::Int32::getInstance();
+                    m_Const = false;
+                }
+                // Otherwise, if operator is address of, return pointer type
+                else if (unary.getOperator().type == Token::Type::AMPERSAND) {
+                    m_Type = rightNumericType->getPointerType();
+                    m_Const = rightConst;
+                }
+            }
+            else {
+                m_ErrorHandler.error(unary.getOperator(),
+                                     "Invalid operand type '" + rightType->getTypeName() + "'");
+                throw TypeCheckError();
+            }
+        }
+    }
+
+    //---------------------------------------------------------------------------
+    // Statement::Visitor virtuals
+    //---------------------------------------------------------------------------
+    virtual void visit(const Statement::Break &breakStatement) final
+    {
+        if (!m_InLoop && !m_InSwitch) {
+            m_ErrorHandler.error(breakStatement.getToken(), "Statement not within loop or switch");
+        }
+    }
+
+    virtual void visit(const Statement::Compound &compound) final
+    {
+        Environment environment(m_Environment);
+        typeCheck(compound.getStatements(), environment);
+    }
+
+    virtual void visit(const Statement::Continue &continueStatement) final
+    {
+        if (!m_InLoop) {
+            m_ErrorHandler.error(continueStatement.getToken(), "Statement not within loop");
+        }
+    }
+
+    virtual void visit(const Statement::Do &doStatement) final
+    {
+        m_InLoop = true;
+        doStatement.getBody()->accept(*this);
+        m_InLoop = false;
+        doStatement.getCondition()->accept(*this);
+    }
+
+    virtual void visit(const Statement::Expression &expression) final
+    {
+        expression.getExpression()->accept(*this);
+    }
+
+    virtual void visit(const Statement::For &forStatement) final
+    {
+        // Create new environment for loop initialisation
+        Environment *previous = m_Environment;
+        Environment environment(m_Environment);
+        m_Environment = &environment;
+
+        // Type-check initialiser if statement present
+        if (forStatement.getInitialiser()) {
+            forStatement.getInitialiser()->accept(*this);
+        }
+
+        if (forStatement.getCondition()) {
+            forStatement.getCondition()->accept(*this);
+        }
+
+        if (forStatement.getIncrement()) {
+            forStatement.getIncrement()->accept(*this);
+        }
+
+        m_InLoop = true;
+        forStatement.getBody()->accept(*this);
+        m_InLoop = false;
+
+        // Restore environment
+        m_Environment = previous;
+    }
+
+    virtual void visit(const Statement::If &ifStatement) final
+    {
+        ifStatement.getCondition()->accept(*this);
+        ifStatement.getThenBranch()->accept(*this);
+        if (ifStatement.getElseBranch()) {
+            ifStatement.getElseBranch()->accept(*this);
+        }
+    }
+
+    virtual void visit(const Statement::Labelled &labelled) final
+    {
+        if (!m_InSwitch) {
+            m_ErrorHandler.error(labelled.getKeyword(), "Statement not within switch statement");
+        }
+
+        if (labelled.getValue()) {
+            auto valType = evaluateType(labelled.getValue());
+            auto valNumericType = dynamic_cast<const Type::NumericBase *>(valType);
+            if (!valNumericType || !valNumericType->isIntegral()) {
+                m_ErrorHandler.error(labelled.getKeyword(),
+                                     "Invalid case value '" + valType->getTypeName() + "'");
+                throw TypeCheckError();
+            }
+        }
+
+        labelled.getBody()->accept(*this);
+    }
+
+    virtual void visit(const Statement::Switch &switchStatement) final
+    {
+        auto condType = evaluateType(switchStatement.getCondition());
+        auto condNumericType = dynamic_cast<const Type::NumericBase *>(condType);
+        if (!condNumericType || !condNumericType->isIntegral()) {
+            m_ErrorHandler.error(switchStatement.getSwitch(),
+                                 "Invalid condition '" + condType->getTypeName() + "'");
+            throw TypeCheckError();
+        }
+
+        m_InSwitch = true;
+        switchStatement.getBody()->accept(*this);
+        m_InSwitch = false;
+    }
+
+    virtual void visit(const Statement::VarDeclaration &varDeclaration) final
+    {
+        for (const auto &var : varDeclaration.getInitDeclaratorList()) {
+            m_Environment->define(std::get<0>(var), varDeclaration.getType(),
+                                  varDeclaration.isConst(), m_ErrorHandler);
+
+            // If variable has an initialiser expression
+            if (std::get<1>(var)) {
+                // Evaluate type
+                const auto [initialiserType, initialiserConst] = evaluateTypeConst(std::get<1>(var).get());
+
+                // Assign initialiser expression to variable
+                m_Environment->assign(std::get<0>(var), initialiserType, initialiserConst, Token::Type::EQUAL, m_ErrorHandler);
+            }
+        }
+    }
+
+    virtual void visit(const Statement::While &whileStatement) final
+    {
+        whileStatement.getCondition()->accept(*this);
+        m_InLoop = true;
+        whileStatement.getBody()->accept(*this);
+        m_InLoop = false;
+    }
+
+    virtual void visit(const Statement::Print &print) final
+    {
+        print.getExpression()->accept(*this);
+    }
+
+private:
+    //---------------------------------------------------------------------------
+    // Private methods
+    //---------------------------------------------------------------------------
+    std::tuple<const Type::Base *, bool> evaluateTypeConst(const Expression::Base *expression)
+    {
+        expression->accept(*this);
+        return std::make_tuple(m_Type, m_Const);
+    }
+
+    const Type::Base *evaluateType(const Expression::Base *expression)
+    {
+        return std::get<0>(evaluateTypeConst(expression));
+    }
+
+    //---------------------------------------------------------------------------
+    // Members
+    //---------------------------------------------------------------------------
+    Environment *m_Environment;
+    const Type::Base *m_Type;
+    bool m_Const;
+
+    ErrorHandler &m_ErrorHandler;
+    bool m_InLoop;
+    bool m_InSwitch;
+};
+}
+
+//---------------------------------------------------------------------------
+// MiniParse::TypeChecker::Environment
+//---------------------------------------------------------------------------
+void Environment::define(const Token &name, const Type::Base *type, bool isConst, ErrorHandler &errorHandler)
+{
+    if(!m_Types.try_emplace(name.lexeme, type, isConst).second) {
+        errorHandler.error(name, "Redeclaration of variable");
+        throw TypeCheckError();
+    }
+}
+//---------------------------------------------------------------------------
+const Type::Base *Environment::assign(const Token &name, const Type::Base *assignedType, bool assignedConst,
+                                      Token::Type op, ErrorHandler &errorHandler)
+{
+    // If type isn't found
+    auto existingType = m_Types.find(name.lexeme);
+    if(existingType == m_Types.end()) {
+        if(m_Enclosing) {
+            return m_Enclosing->assign(name, assignedType,
+                                       assignedConst, op, errorHandler);
+        }
+        else {
+            errorHandler.error(name, "Undefined variable");
variable"); + throw TypeCheckError(); + } + } + // Otherwise, if type is found and it's const, give error + else if(std::get<1>(existingType->second)) { + errorHandler.error(name, "Assignment of read-only variable"); + throw TypeCheckError(); + } + + auto numericExistingType = dynamic_cast(std::get<0>(existingType->second)); + auto numericAssignedType = dynamic_cast(assignedType); + + auto numericPtrExistingType = dynamic_cast(std::get<0>(existingType->second)); + auto numericPtrAssignedType = dynamic_cast(assignedType); + + // If assignment operation is plain equals, any type is fine so return + // **TODO** pointer type check + if(op == Token::Type::EQUAL) { + // If we're initialising a pointer with another pointer + if (numericPtrAssignedType && numericPtrExistingType) { + // If variable is non-const but initialiser is const + /*if (!varDeclaration.isConst() && intialiserConst) { + m_ErrorHandler.error(std::get<0>(var), + "Invalid operand types '" + initialiserType->getTypeName() + "'"); + throw TypeCheckError(); + }*/ + + // If pointer types aren't compatible + if (numericPtrExistingType->getTypeHash() != numericPtrAssignedType->getTypeHash()) { + errorHandler.error(name, "Invalid operand types '" + numericPtrExistingType->getTypeName() + "' and '" + numericPtrAssignedType->getTypeName()); + throw TypeCheckError(); + } + } + // Otherwise, if we're trying to initialise a pointer with a non-pointer or vice-versa + else if (numericPtrAssignedType || numericPtrExistingType) { + errorHandler.error(name, "Invalid operand types '" + std::get<0>(existingType->second)->getTypeName() + "' and '" + assignedType->getTypeName()); + throw TypeCheckError(); + } + } + // Otherwise, if operation is += or -- + else if (op == Token::Type::PLUS_EQUAL || op == Token::Type::MINUS_EQUAL) { + // If the operand being added isn't numeric or the type being added to is neither numeric or a pointer + if (!numericAssignedType || (!numericPtrExistingType && !numericExistingType)) + { + errorHandler.error(name, "Invalid operand types '" + std::get<0>(existingType->second)->getTypeName() + "' and '" + assignedType->getTypeName() + "'"); + throw TypeCheckError(); + } + + // If we're adding a numeric type to a pointer, check it's an integer + if (numericPtrExistingType && numericAssignedType->isIntegral()) { + errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getTypeName() + "'"); + throw TypeCheckError(); + } + } + // Otherwise, numeric types are required + else { + // If either type is non-numeric, give error + if(!numericAssignedType) { + errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getTypeName() + "'"); + throw TypeCheckError(); + } + if(!numericExistingType) { + errorHandler.error(name, "Invalid operand types '" + std::get<0>(existingType->second)->getTypeName() + "'"); + throw TypeCheckError(); + } + + // If operand isn't one that takes any numeric type, check both operands are integral + if (op != Token::Type::STAR_EQUAL && op != Token::Type::SLASH_EQUAL) { + if(!numericAssignedType->isIntegral()) { + errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getTypeName() + "'"); + throw TypeCheckError(); + } + if(!numericExistingType->isIntegral()) { + errorHandler.error(name, "Invalid operand types '" + numericExistingType->getTypeName() + "'"); + throw TypeCheckError(); + } + } + } + + // Return existing type + return std::get<0>(existingType->second); +} +//--------------------------------------------------------------------------- +const 
Type::Base *Environment::incDec(const Token &name, const Token &op, ErrorHandler &errorHandler) +{ + auto existingType = m_Types.find(name.lexeme); + if(existingType == m_Types.end()) { + if(m_Enclosing) { + return m_Enclosing->incDec(name, op, errorHandler); + } + else { + errorHandler.error(name, "Undefined variable"); + throw TypeCheckError(); + } + } + // Otherwise, if type is found and it's const, give error + else if(std::get<1>(existingType->second)) { + errorHandler.error(name, "Increment/decrement of read-only variable"); + throw TypeCheckError(); + } + // Otherwise, return type + // **TODO** pointer + else { + auto numericExistingType = dynamic_cast(std::get<0>(existingType->second)); + if(numericExistingType == nullptr) { + errorHandler.error(op, "Invalid operand types '" + std::get<0>(existingType->second)->getTypeName() + "'"); + throw TypeCheckError(); + } + else { + return std::get<0>(existingType->second); + } + } +} +//--------------------------------------------------------------------------- +std::tuple Environment::getType(const Token &name, ErrorHandler &errorHandler) const +{ + auto type = m_Types.find(std::string{name.lexeme}); + if(type == m_Types.end()) { + if(m_Enclosing) { + return m_Enclosing->getType(name, errorHandler); + } + else { + errorHandler.error(name, "Undefined variable"); + throw TypeCheckError(); + } + } + else { + return type->second; + } +} +//--------------------------------------------------------------------------- +void MiniParse::TypeChecker::typeCheck(const Statement::StatementList &statements, Environment &environment, + ErrorHandler &errorHandler) +{ + Visitor visitor(errorHandler); + visitor.typeCheck(statements, environment); +} \ No newline at end of file From 2523e61e5b33484161d2c6a0c85238e1ea22c5ad Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 6 Jan 2023 09:41:10 +0000 Subject: [PATCH 008/725] moved type into GeNN --- include/genn/genn/{transpiler => }/type.h | 0 src/genn/genn/{transpiler => }/type.cc | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename include/genn/genn/{transpiler => }/type.h (100%) rename src/genn/genn/{transpiler => }/type.cc (100%) diff --git a/include/genn/genn/transpiler/type.h b/include/genn/genn/type.h similarity index 100% rename from include/genn/genn/transpiler/type.h rename to include/genn/genn/type.h diff --git a/src/genn/genn/transpiler/type.cc b/src/genn/genn/type.cc similarity index 100% rename from src/genn/genn/transpiler/type.cc rename to src/genn/genn/type.cc From ca93282dcfbccf87ae8852424cb2b4e0e99a1fb9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 6 Jan 2023 10:13:00 +0000 Subject: [PATCH 009/725] renamed namespaces etc --- include/genn/genn/transpiler/errorHandler.h | 8 +- include/genn/genn/transpiler/expression.h | 42 ++++---- include/genn/genn/transpiler/parser.h | 14 +-- include/genn/genn/transpiler/prettyPrinter.h | 8 +- include/genn/genn/transpiler/scanner.h | 10 +- include/genn/genn/transpiler/statement.h | 102 ++++++++++-------- include/genn/genn/transpiler/token.h | 7 +- .../genn/genn/transpiler/transpilerUtils.h | 3 +- include/genn/genn/transpiler/typeChecker.h | 14 +-- include/genn/genn/type.h | 20 ++-- src/genn/genn/genn.vcxproj | 4 +- src/genn/genn/transpiler/expression.cc | 10 +- src/genn/genn/transpiler/parser.cc | 20 ++-- src/genn/genn/transpiler/prettyPrinter.cc | 16 +-- src/genn/genn/transpiler/scanner.cc | 18 ++-- src/genn/genn/transpiler/statement.cc | 10 +- src/genn/genn/transpiler/typeChecker.cc | 21 ++-- src/genn/genn/type.cc | 8 +- 18 files 
changed, 174 insertions(+), 161 deletions(-) diff --git a/include/genn/genn/transpiler/errorHandler.h b/include/genn/genn/transpiler/errorHandler.h index ad34193da9..ceb88fa66f 100644 --- a/include/genn/genn/transpiler/errorHandler.h +++ b/include/genn/genn/transpiler/errorHandler.h @@ -3,13 +3,13 @@ // Standard C++ includes #include -// Mini-parse includes -#include "token.h" +// Transpiler includes +#include "transpiler/token.h" //--------------------------------------------------------------------------- -// MiniParse::ErrorHandler +// GeNN::Transpiler::ErrorHandler //--------------------------------------------------------------------------- -namespace MiniParse +namespace GeNN::Transpiler { class ErrorHandler { diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h index ce4493e01c..44721c319a 100644 --- a/include/genn/genn/transpiler/expression.h +++ b/include/genn/genn/transpiler/expression.h @@ -4,23 +4,23 @@ #include #include -// Mini-parse includes -#include "token.h" +// Transpiler includes +#include "transpiler/token.h" // Forward declarations -namespace MiniParse::Expression +namespace GeNN::Transpiler::Expression { class Visitor; } -namespace Type +namespace GeNN::Type { class Base; } //--------------------------------------------------------------------------- -// MiniParse::Expression::Base +// GeNN::Transpiler::Expression::Base //--------------------------------------------------------------------------- -namespace MiniParse::Expression +namespace GeNN::Transpiler::Expression { class Base { @@ -32,7 +32,7 @@ typedef std::unique_ptr ExpressionPtr; typedef std::vector ExpressionList; //--------------------------------------------------------------------------- -// MiniParse::Expression::ArraySubscript +// GeNN::Transpiler::Expression::ArraySubscript //--------------------------------------------------------------------------- class ArraySubscript : public Base { @@ -52,7 +52,7 @@ class ArraySubscript : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Expression::Assignment +// GeNN::Transpiler::Expression::Assignment //--------------------------------------------------------------------------- class Assignment : public Base { @@ -74,7 +74,7 @@ class Assignment : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Expression::Binary +// GeNN::Transpiler::Expression::Binary //--------------------------------------------------------------------------- class Binary : public Base { @@ -96,7 +96,7 @@ class Binary : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Expression::Call +// GeNN::Transpiler::Expression::Call //--------------------------------------------------------------------------- class Call : public Base { @@ -118,7 +118,7 @@ class Call : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Expression::Cast +// GeNN::Transpiler::Expression::Cast //--------------------------------------------------------------------------- class Cast : public Base { @@ -141,7 +141,7 @@ class Cast : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Expression::Conditional +// GeNN::Transpiler::Expression::Conditional //--------------------------------------------------------------------------- class Conditional : public Base { @@ -165,7 +165,7 @@ class 
Conditional : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Expression::Grouping +// GeNN::Transpiler::Expression::Grouping //--------------------------------------------------------------------------- class Grouping : public Base { @@ -183,7 +183,7 @@ class Grouping : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Expression::Literal +// GeNN::Transpiler::Expression::Literal //--------------------------------------------------------------------------- class Literal : public Base { @@ -201,7 +201,7 @@ class Literal : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Expression::Logical +// GeNN::Transpiler::Expression::Logical //--------------------------------------------------------------------------- class Logical : public Base { @@ -223,7 +223,7 @@ class Logical : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Expression::PostfixIncDec +// GeNN::Transpiler::Expression::PostfixIncDec //--------------------------------------------------------------------------- class PostfixIncDec : public Base { @@ -243,7 +243,7 @@ class PostfixIncDec : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Expression::PrefixIncDec +// GeNN::Transpiler::Expression::PrefixIncDec //--------------------------------------------------------------------------- class PrefixIncDec : public Base { @@ -263,7 +263,7 @@ class PrefixIncDec : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Expression::Variable +// GeNN::Transpiler::Expression::Variable //--------------------------------------------------------------------------- class Variable : public Base { @@ -281,7 +281,7 @@ class Variable : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Expression::Unary +// GeNN::Transpiler::Expression::Unary //--------------------------------------------------------------------------- class Unary : public Base { @@ -302,7 +302,7 @@ class Unary : public Base //--------------------------------------------------------------------------- -// MiniParse::Expression::Visitor +// GeNN::Transpiler::Expression::Visitor //--------------------------------------------------------------------------- class Visitor { @@ -321,4 +321,4 @@ class Visitor virtual void visit(const Variable &variable) = 0; virtual void visit(const Unary &unary) = 0; }; -} // namespace MiniParse::Expression \ No newline at end of file +} // namespace GeNN::Transpiler::Expression \ No newline at end of file diff --git a/include/genn/genn/transpiler/parser.h b/include/genn/genn/transpiler/parser.h index 8b063c2ccf..405246410f 100644 --- a/include/genn/genn/transpiler/parser.h +++ b/include/genn/genn/transpiler/parser.h @@ -4,21 +4,21 @@ #include #include -// Mini-parse includes -#include "expression.h" -#include "statement.h" -#include "token.h" +// Transpiler includes +#include "transpiler/expression.h" +#include "transpiler/statement.h" +#include "transpiler/token.h" // Forward declarations -namespace MiniParse +namespace GeNN::Transpiler { class ErrorHandler; } //--------------------------------------------------------------------------- -// MiniParse::Scanner::Parser +// GeNN::Transpiler::Parser 
//--------------------------------------------------------------------------- -namespace MiniParse::Parser +namespace GeNN::Transpiler::Parser { Expression::ExpressionPtr parseExpression(const std::vector &tokens, ErrorHandler &errorHandler); diff --git a/include/genn/genn/transpiler/prettyPrinter.h b/include/genn/genn/transpiler/prettyPrinter.h index cf4f7949a9..ddbb2af9e2 100644 --- a/include/genn/genn/transpiler/prettyPrinter.h +++ b/include/genn/genn/transpiler/prettyPrinter.h @@ -3,13 +3,13 @@ // Standard C++ includes #include -// Mini-parse includes -#include "statement.h" +// Transpiler includes +#include "transpiler/statement.h" //--------------------------------------------------------------------------- -// MiniParse::PrettyPrinter +// GeNN::Transpiler::PrettyPrinter //--------------------------------------------------------------------------- -namespace MiniParse::PrettyPrinter +namespace GeNN::Transpiler::PrettyPrinter { std::string print(const Statement::StatementList &statements); } \ No newline at end of file diff --git a/include/genn/genn/transpiler/scanner.h b/include/genn/genn/transpiler/scanner.h index 07c53dbe9c..abf3fbd45c 100644 --- a/include/genn/genn/transpiler/scanner.h +++ b/include/genn/genn/transpiler/scanner.h @@ -7,19 +7,19 @@ #include #include -// Mini-parse includes -#include "token.h" +// Transpiler includes +#include "transpiler/token.h" // Forward declarations -namespace MiniParse +namespace GeNN::Transpiler { class ErrorHandler; } //--------------------------------------------------------------------------- -// MiniParse::Scanner::Error +// GeNN::Transpiler::Scanner::Error //--------------------------------------------------------------------------- -namespace MiniParse::Scanner +namespace GeNN::Transpiler::Scanner { std::vector scanSource(const std::string_view &source, ErrorHandler &errorHandler); diff --git a/include/genn/genn/transpiler/statement.h b/include/genn/genn/transpiler/statement.h index 4ee43e1f9d..cff48c81e1 100644 --- a/include/genn/genn/transpiler/statement.h +++ b/include/genn/genn/transpiler/statement.h @@ -4,23 +4,23 @@ #include #include -// Mini-parse includes -#include "expression.h" +// Transpiler includes +#include "transpiler/expression.h" // Forward declarations -namespace MiniParse::Statement +namespace GeNN::Transpiler::Statement { class Visitor; } -namespace Type +namespace GeNN::Type { class Base; } //--------------------------------------------------------------------------- -// MiniParse::Statement::Base +// GeNN::Transpiler::Statement::Base //--------------------------------------------------------------------------- -namespace MiniParse::Statement +namespace GeNN::Transpiler::Statement { class Base { @@ -32,7 +32,7 @@ typedef std::unique_ptr StatementPtr; typedef std::vector StatementList; //--------------------------------------------------------------------------- -// MiniParse::Statement::Break +// GeNN::Transpiler::Statement::Break //--------------------------------------------------------------------------- class Break : public Base { @@ -50,7 +50,7 @@ class Break : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Statement::Compound +// GeNN::Transpiler::Statement::Compound //--------------------------------------------------------------------------- class Compound : public Base { @@ -68,7 +68,7 @@ class Compound : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Statement::Continue +// 
GeNN::Transpiler::Statement::Continue //--------------------------------------------------------------------------- class Continue : public Base { @@ -86,142 +86,148 @@ class Continue : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Statement::Do +// GeNN::Transpiler::Statement::Do //--------------------------------------------------------------------------- class Do : public Base { + using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: - Do(MiniParse::Expression::ExpressionPtr condition, StatementPtr body) + Do(ExpressionPtr condition, StatementPtr body) : m_Condition(std::move(condition)), m_Body(std::move(body)) {} virtual void accept(Visitor &visitor) const override; - const MiniParse::Expression::Base *getCondition() const { return m_Condition.get(); } + const ExpressionPtr::element_type *getCondition() const { return m_Condition.get(); } const Base *getBody() const { return m_Body.get(); } private: - const MiniParse::Expression::ExpressionPtr m_Condition; + const ExpressionPtr m_Condition; const StatementPtr m_Body; }; //--------------------------------------------------------------------------- -// MiniParse::Statement::Expression +// GeNN::Transpiler::Statement::Expression //--------------------------------------------------------------------------- class Expression : public Base { + using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: - Expression(MiniParse::Expression::ExpressionPtr expression) + Expression(ExpressionPtr expression) : m_Expression(std::move(expression)) {} virtual void accept(Visitor &visitor) const override; - const MiniParse::Expression::Base *getExpression() const { return m_Expression.get(); } + const ExpressionPtr::element_type *getExpression() const { return m_Expression.get(); } private: - const MiniParse::Expression::ExpressionPtr m_Expression; + const ExpressionPtr m_Expression; }; //--------------------------------------------------------------------------- -// MiniParse::Statement::For +// GeNN::Transpiler::Statement::For //--------------------------------------------------------------------------- class For : public Base { + using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: - For(StatementPtr initialiser, MiniParse::Expression::ExpressionPtr condition, MiniParse::Expression::ExpressionPtr increment, StatementPtr body) + For(StatementPtr initialiser, ExpressionPtr condition, ExpressionPtr increment, StatementPtr body) : m_Initialiser(std::move(initialiser)), m_Condition(std::move(condition)), m_Increment(std::move(increment)), m_Body(std::move(body)) {} virtual void accept(Visitor &visitor) const override; const Base *getInitialiser() const { return m_Initialiser.get(); } - const MiniParse::Expression::Base *getCondition() const { return m_Condition.get(); } - const MiniParse::Expression::Base *getIncrement() const { return m_Increment.get(); } + const ExpressionPtr::element_type *getCondition() const { return m_Condition.get(); } + const ExpressionPtr::element_type *getIncrement() const { return m_Increment.get(); } const Base *getBody() const { return m_Body.get(); } private: const StatementPtr m_Initialiser; - const MiniParse::Expression::ExpressionPtr m_Condition; - const MiniParse::Expression::ExpressionPtr m_Increment; + const ExpressionPtr m_Condition; + const ExpressionPtr m_Increment; const StatementPtr m_Body; }; //--------------------------------------------------------------------------- -// 
MiniParse::Statement::If +// GeNN::Transpiler::Statement::If //--------------------------------------------------------------------------- class If : public Base { + using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: - If(MiniParse::Expression::ExpressionPtr condition, StatementPtr thenBranch, StatementPtr elseBranch) + If(ExpressionPtr condition, StatementPtr thenBranch, StatementPtr elseBranch) : m_Condition(std::move(condition)), m_ThenBranch(std::move(thenBranch)), m_ElseBranch(std::move(elseBranch)) {} virtual void accept(Visitor &visitor) const override; - const MiniParse::Expression::Base *getCondition() const { return m_Condition.get(); } + const ExpressionPtr::element_type *getCondition() const { return m_Condition.get(); } const Base *getThenBranch() const { return m_ThenBranch.get(); } const Base *getElseBranch() const { return m_ElseBranch.get(); } private: - const MiniParse::Expression::ExpressionPtr m_Condition; + const ExpressionPtr m_Condition; const StatementPtr m_ThenBranch; const StatementPtr m_ElseBranch; }; //--------------------------------------------------------------------------- -// MiniParse::Statement::Labelled +// GeNN::Transpiler::Statement::Labelled //--------------------------------------------------------------------------- class Labelled : public Base { + using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: - Labelled(Token keyword, MiniParse::Expression::ExpressionPtr value, StatementPtr body) + Labelled(Token keyword, ExpressionPtr value, StatementPtr body) : m_Keyword(keyword), m_Value(std::move(value)), m_Body(std::move(body)) {} virtual void accept(Visitor &visitor) const override; const Token &getKeyword() const { return m_Keyword; } - const MiniParse::Expression::Base *getValue() const { return m_Value.get(); } + const ExpressionPtr::element_type *getValue() const { return m_Value.get(); } const Base *getBody() const { return m_Body.get(); } private: const Token m_Keyword; - const MiniParse::Expression::ExpressionPtr m_Value; + const ExpressionPtr m_Value; const StatementPtr m_Body; }; //--------------------------------------------------------------------------- -// MiniParse::Statement::Switch +// GeNN::Transpiler::Statement::Switch //--------------------------------------------------------------------------- class Switch : public Base { + using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: - Switch(Token switchToken, MiniParse::Expression::ExpressionPtr condition, StatementPtr body) + Switch(Token switchToken, ExpressionPtr condition, StatementPtr body) : m_Switch(switchToken), m_Condition(std::move(condition)), m_Body(std::move(body)) {} virtual void accept(Visitor &visitor) const override; const Token &getSwitch() const { return m_Switch; } - const MiniParse::Expression::Base *getCondition() const { return m_Condition.get(); } + const ExpressionPtr::element_type *getCondition() const { return m_Condition.get(); } const Base *getBody() const { return m_Body.get(); } private: const Token m_Switch; - const MiniParse::Expression::ExpressionPtr m_Condition; + const ExpressionPtr m_Condition; const StatementPtr m_Body; }; //--------------------------------------------------------------------------- -// MiniParse::Statement::VarDeclaration +// GeNN::Transpiler::Statement::VarDeclaration //--------------------------------------------------------------------------- class VarDeclaration : public Base { public: - typedef std::vector> InitDeclaratorList; + typedef std::vector> 
InitDeclaratorList; VarDeclaration(const Type::Base *type, bool isConst, InitDeclaratorList initDeclaratorList) : m_Type(type), m_Const(isConst), m_InitDeclaratorList(std::move(initDeclaratorList)) @@ -242,46 +248,48 @@ class VarDeclaration : public Base }; //--------------------------------------------------------------------------- -// MiniParse::Statement::If +// GeNN::Transpiler::Statement::If //--------------------------------------------------------------------------- class While : public Base { + using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: - While(MiniParse::Expression::ExpressionPtr condition, StatementPtr body) + While(ExpressionPtr condition, StatementPtr body) : m_Condition(std::move(condition)), m_Body(std::move(body)) {} virtual void accept(Visitor &visitor) const override; - const MiniParse::Expression::Base *getCondition() const { return m_Condition.get(); } + const ExpressionPtr::element_type *getCondition() const { return m_Condition.get(); } const Base *getBody() const { return m_Body.get(); } private: - const MiniParse::Expression::ExpressionPtr m_Condition; + const ExpressionPtr m_Condition; const StatementPtr m_Body; }; //--------------------------------------------------------------------------- -// MiniParse::Statement::Print +// GeNN::Transpiler::Statement::Print //--------------------------------------------------------------------------- // **HACK** temporary until function calling is working class Print : public Base { + using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: - Print(MiniParse::Expression::ExpressionPtr expression) + Print(ExpressionPtr expression) : m_Expression(std::move(expression)) {} virtual void accept(Visitor &visitor) const override; - const MiniParse::Expression::Base *getExpression() const { return m_Expression.get(); } + const ExpressionPtr::element_type *getExpression() const { return m_Expression.get(); } private: - const MiniParse::Expression::ExpressionPtr m_Expression; + const ExpressionPtr m_Expression; }; //--------------------------------------------------------------------------- -// MiniParse::Statement::Visitor +// GeNN::Transpiler::Statement::Visitor //--------------------------------------------------------------------------- class Visitor { @@ -299,4 +307,4 @@ class Visitor virtual void visit(const While &whileStatement) = 0; virtual void visit(const Print &print) = 0; }; -} // namespace MiniParse::Statement \ No newline at end of file +} // namespace GeNN::Transpiler::Statement \ No newline at end of file diff --git a/include/genn/genn/transpiler/token.h b/include/genn/genn/transpiler/token.h index be480fa67f..99dadfcbea 100644 --- a/include/genn/genn/transpiler/token.h +++ b/include/genn/genn/transpiler/token.h @@ -8,9 +8,9 @@ #include //--------------------------------------------------------------------------- -// MiniParse::Token +// GeNN::Transpiler::Token //--------------------------------------------------------------------------- -namespace MiniParse +namespace GeNN::Transpiler { struct Token { @@ -58,5 +58,4 @@ struct Token const size_t line; const LiteralValue literalValue; }; - -} +} // namespace GeNN::Transpiler diff --git a/include/genn/genn/transpiler/transpilerUtils.h b/include/genn/genn/transpiler/transpilerUtils.h index be9514005a..f062fc8106 100644 --- a/include/genn/genn/transpiler/transpilerUtils.h +++ b/include/genn/genn/transpiler/transpilerUtils.h @@ -2,9 +2,10 @@ // Standard C++ includes #include +#include #include -namespace MiniParse::Utils 
+namespace GeNN::Transpiler::Utils { template struct Overload : Ts... { using Ts::operator()...; }; template Overload(Ts...) -> Overload; // line not needed in diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index b9978abd91..ab75bf241a 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -5,24 +5,24 @@ #include #include -// Mini-parse includes -#include "statement.h" +// Transpiler includes +#include "transpiler/statement.h" // Forward declarations -namespace MiniParse +namespace GeNN::Transpiler { class ErrorHandler; struct Token; } -namespace Type +namespace GeNN::Type { class Base; } //--------------------------------------------------------------------------- -// MiniParse::TypeChecker::Environment +// GeNN::Transpiler::TypeChecker::Environment //--------------------------------------------------------------------------- -namespace MiniParse::TypeChecker +namespace GeNN::Transpiler::TypeChecker { class Environment { @@ -61,4 +61,4 @@ class Environment //--------------------------------------------------------------------------- void typeCheck(const Statement::StatementList &statements, Environment &environment, ErrorHandler &errorHandler); -} // namespace MiniParse::TypeChecker \ No newline at end of file +} // namespace MiniParse::GeNN::Transpiler \ No newline at end of file diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 8c93ee6abf..0e69abb537 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -59,9 +59,9 @@ #define IMPLEMENT_NUMERIC_TYPE(TYPE) IMPLEMENT_TYPE(TYPE); IMPLEMENT_TYPE(TYPE##Ptr) //---------------------------------------------------------------------------- -// Type::TypeTraits +// GeNN::Type::TypeTraits //---------------------------------------------------------------------------- -namespace Type +namespace GeNN::Type { //! Empty type trait structure template @@ -70,7 +70,7 @@ struct TypeTraits }; //---------------------------------------------------------------------------- -// Type::Base +// GeNN::Type::Base //---------------------------------------------------------------------------- //! 
Base class for all types class Base @@ -84,7 +84,7 @@ class Base }; //---------------------------------------------------------------------------- -// Type::NumericBase +// GeNN::Type::NumericBase //---------------------------------------------------------------------------- class NumericBase : public Base { @@ -103,7 +103,7 @@ class NumericBase : public Base }; //---------------------------------------------------------------------------- -// NumericPtrBase +// GeNN::NumericPtrBase //---------------------------------------------------------------------------- class NumericPtrBase : public Base { @@ -115,7 +115,7 @@ class NumericPtrBase : public Base }; //---------------------------------------------------------------------------- -// Type::Numeric +// GeNN::Type::Numeric //---------------------------------------------------------------------------- template class Numeric : public NumericBase @@ -148,7 +148,7 @@ class Numeric : public NumericBase }; //---------------------------------------------------------------------------- -// NumericPtr +// GeNN::NumericPtr //---------------------------------------------------------------------------- template class NumericPtr : public NumericPtrBase @@ -167,7 +167,7 @@ class NumericPtr : public NumericPtrBase }; //---------------------------------------------------------------------------- -// Type::ForeignFunctionBase +// GeNN::Type::ForeignFunctionBase //---------------------------------------------------------------------------- class ForeignFunctionBase : public Base { @@ -186,7 +186,7 @@ class ForeignFunctionBase : public Base }; //---------------------------------------------------------------------------- -// Type::ForeignFunction +// GeNN::Type::ForeignFunction //---------------------------------------------------------------------------- template class ForeignFunction : public ForeignFunctionBase @@ -297,4 +297,4 @@ const NumericBase *getNumericType(const std::set &typeSpecifie const NumericPtrBase *getNumericPtrType(const std::set &typeSpecifiers); const NumericBase *getPromotedType(const NumericBase *type); const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b); -} // namespace Type +} // namespace GeNN::Type diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj index 901a1c5e69..fc8fb92f8e 100644 --- a/src/genn/genn/genn.vcxproj +++ b/src/genn/genn/genn.vcxproj @@ -60,8 +60,8 @@ - + @@ -125,8 +125,8 @@ - + diff --git a/src/genn/genn/transpiler/expression.cc b/src/genn/genn/transpiler/expression.cc index f85eb8f49e..e3ea06e688 100644 --- a/src/genn/genn/transpiler/expression.cc +++ b/src/genn/genn/transpiler/expression.cc @@ -1,9 +1,9 @@ -#include "expression.h" +#include "transpiler/expression.h" -#define IMPLEMENT_ACCEPT(CLASS_NAME) \ - void MiniParse::Expression::CLASS_NAME::accept(Visitor &visitor) const \ - { \ - visitor.visit(*this); \ +#define IMPLEMENT_ACCEPT(CLASS_NAME) \ + void GeNN::Transpiler::Expression::CLASS_NAME::accept(Visitor &visitor) const \ + { \ + visitor.visit(*this); \ } diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index e8f63e7066..44931ee92e 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -1,4 +1,4 @@ -#include "parser.h" +#include "transpiler/parser.h" // Standard C++ includes #include @@ -13,10 +13,10 @@ // GeNN includes #include "type.h" -// Mini-parse includes -#include "error_handler.h" +// Transpiler includes +#include "transpiler/errorHandler.h" -using namespace MiniParse; +using 
namespace GeNN::Transpiler; //--------------------------------------------------------------------------- // Anonymous namespace @@ -182,7 +182,7 @@ Expression::ExpressionPtr parseBinary(ParserState &parserState, N nonTerminal, s return expression; } -std::tuple parseDeclarationSpecifiers(ParserState &parserState) +std::tuple parseDeclarationSpecifiers(ParserState &parserState) { // Loop through type qualifier and specifier tokens std::set typeQualifiers{}; @@ -202,9 +202,9 @@ std::tuple parseDeclarationSpecifiers(ParserState &pars } while(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER})); // Lookup type - const Type::Base *type = (parserState.match({Token::Type::STAR}) - ? static_cast(Type::getNumericPtrType(typeSpecifiers)) - : static_cast(Type::getNumericType(typeSpecifiers))); + const GeNN::Type::Base *type = (parserState.match({Token::Type::STAR}) + ? static_cast(GeNN::Type::getNumericPtrType(typeSpecifiers)) + : static_cast(GeNN::Type::getNumericType(typeSpecifiers))); if(!type) { parserState.error("Unknown type specifier"); } @@ -828,9 +828,9 @@ std::unique_ptr parseBlockItem(ParserState &parserState) //--------------------------------------------------------------------------- -// MiniParse::Parser +// GeNN::Transpiler::Parser //--------------------------------------------------------------------------- -namespace MiniParse::Parser +namespace GeNN::Transpiler::Parser { Expression::ExpressionPtr parseExpression(const std::vector &tokens, ErrorHandler &errorHandler) { diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index 0e21c638d1..98dfc5f959 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -1,17 +1,18 @@ -#include "pretty_printer.h" +#include "transpiler/prettyPrinter.h" // Standard C++ includes #include #include #include -// Mini-parse includes +// GeNN includes #include "type.h" -#include "utils.h" +// Transpiler includes +#include "transpiler/transpilerUtils.h" -using namespace MiniParse; -using namespace MiniParse::PrettyPrinter; +using namespace GeNN::Transpiler; +using namespace GeNN::Transpiler::PrettyPrinter; //--------------------------------------------------------------------------- // Anonymous namespace @@ -265,7 +266,10 @@ class Visitor : public Expression::Visitor, public Statement::Visitor }; } // Anonymous namespace -std::string MiniParse::PrettyPrinter::print(const Statement::StatementList &statements) +//--------------------------------------------------------------------------- +// GeNN::Transpiler::PrettyPrinter +//--------------------------------------------------------------------------- +std::string GeNN::Transpiler::PrettyPrinter::print(const Statement::StatementList &statements) { Visitor visitor; return visitor.print(statements); diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc index abb9baefc6..cbd3e43e86 100644 --- a/src/genn/genn/transpiler/scanner.cc +++ b/src/genn/genn/transpiler/scanner.cc @@ -1,4 +1,4 @@ -#include "scanner.h" +#include "transpiler/scanner.h" // Standard C++ includes #include @@ -10,12 +10,12 @@ // Standard C includes #include -// Mini-parse includes -#include "error_handler.h" -#include "utils.h" +// Transpiler includes +#include "transpiler/errorHandler.h" +#include "transpiler/transpilerUtils.h" -using namespace MiniParse; -using namespace MiniParse::Scanner; +using namespace GeNN::Transpiler; +using namespace GeNN::Transpiler::Scanner; 
//--------------------------------------------------------------------------- // Anonymous namespace @@ -300,8 +300,6 @@ void scanIdentifier(ScanState &scanState, std::vector &tokens) //--------------------------------------------------------------------------- void scanToken(ScanState &scanState, std::vector &tokens) { - using namespace MiniParse; - char c = scanState.advance(); switch(c) { // Single character tokens @@ -464,9 +462,9 @@ void scanToken(ScanState &scanState, std::vector &tokens) } //--------------------------------------------------------------------------- -// MiniParse::Scanner +// GeNN::Transpiler::Scanner //--------------------------------------------------------------------------- -namespace MiniParse::Scanner +namespace GeNN::Transpiler::Scanner { std::vector scanSource(const std::string_view &source, ErrorHandler &errorHandler) { diff --git a/src/genn/genn/transpiler/statement.cc b/src/genn/genn/transpiler/statement.cc index eb142d178d..19ca9459c2 100644 --- a/src/genn/genn/transpiler/statement.cc +++ b/src/genn/genn/transpiler/statement.cc @@ -1,9 +1,9 @@ -#include "statement.h" +#include "transpiler/statement.h" -#define IMPLEMENT_ACCEPT(CLASS_NAME) \ - void MiniParse::Statement::CLASS_NAME::accept(Visitor &visitor) const \ - { \ - visitor.visit(*this); \ +#define IMPLEMENT_ACCEPT(CLASS_NAME) \ + void GeNN::Transpiler::Statement::CLASS_NAME::accept(Visitor &visitor) const \ + { \ + visitor.visit(*this); \ } // Implement accept methods diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index db4a24626d..cada02ea49 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -1,4 +1,4 @@ -#include "type_checker.h" +#include "transpiler/typeChecker.h" // Standard C++ includes #include @@ -9,13 +9,14 @@ // GeNN includes #include "type.h" -// Mini-parse includes -#include "error_handler.h" -#include "expression.h" -#include "utils.h" +// Transpiler includes +#include "transpiler/errorHandler.h" +#include "transpiler/expression.h" +#include "transpiler/transpilerUtils.h" -using namespace MiniParse; -using namespace MiniParse::TypeChecker; +using namespace GeNN::Transpiler; +using namespace GeNN::Transpiler::TypeChecker; +namespace Type = GeNN::Type; //--------------------------------------------------------------------------- // Anonymous namespace @@ -252,7 +253,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Literal &literal) final { m_Type = std::visit( - MiniParse::Utils::Overload{ + Utils::Overload{ [](auto v)->const Type::NumericBase *{ return Type::TypeTraits::NumericType::getInstance(); }, [](std::monostate)->const Type::NumericBase *{ return nullptr; }}, literal.getValue()); @@ -665,8 +666,8 @@ std::tuple Environment::getType(const Token &name, Err } } //--------------------------------------------------------------------------- -void MiniParse::TypeChecker::typeCheck(const Statement::StatementList &statements, Environment &environment, - ErrorHandler &errorHandler) +void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &statements, Environment &environment, + ErrorHandler &errorHandler) { Visitor visitor(errorHandler); visitor.typeCheck(statements, environment); diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index f8423b9ba2..8dd8b4a6b2 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -4,6 +4,8 @@ #include #include +using namespace GeNN; + // Anonymous namespace 
namespace { @@ -40,9 +42,9 @@ const std::unordered_map uns } // Anonymous namespace //---------------------------------------------------------------------------- -// Type +// GeNN::Type //---------------------------------------------------------------------------- -namespace Type +namespace GeNN::Type { // Implement numeric types IMPLEMENT_NUMERIC_TYPE(Bool); @@ -136,4 +138,4 @@ const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b) } } } -} \ No newline at end of file +} // namespace GeNN::Type \ No newline at end of file From 8d11729150ff092dde28cd31547281786e824c18 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 6 Jan 2023 10:32:51 +0000 Subject: [PATCH 010/725] makefile update --- src/genn/genn/Makefile | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/genn/genn/Makefile b/src/genn/genn/Makefile index 066d772919..257c52bd50 100644 --- a/src/genn/genn/Makefile +++ b/src/genn/genn/Makefile @@ -2,9 +2,10 @@ include ../MakefileCommon # Find source files -FRONTEND_SOURCES :=$(wildcard *.cc) -BACKEND_SOURCES :=$(wildcard code_generator/*.cc) -SOURCES :=$(FRONTEND_SOURCES) $(BACKEND_SOURCES) +FRONTEND_SOURCES :=$(wildcard *.cc) +CODE_GENERATOR_SOURCES :=$(wildcard code_generator/*.cc) +TRANSPILER_SOURCES :=$(wildcard transpiler/*.cc) +SOURCES :=$(FRONTEND_SOURCES) $(CODE_GENERATOR_SOURCES) $(TRANSPILER_SOURCES) # Build objecs in sub-directory OBJECT_DIRECTORY :=$(OBJECT_DIRECTORY)/genn/genn From f224705c4f200b0e8e955d8ff1556f29b9560169 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 6 Jan 2023 10:57:31 +0000 Subject: [PATCH 011/725] added some tests of number scanner --- tests/unit/scanner.cc | 87 +++++++++++++++++++++++++++++++++++++++++ tests/unit/unit.vcxproj | 1 + 2 files changed, 88 insertions(+) create mode 100644 tests/unit/scanner.cc diff --git a/tests/unit/scanner.cc b/tests/unit/scanner.cc new file mode 100644 index 0000000000..41fe05c884 --- /dev/null +++ b/tests/unit/scanner.cc @@ -0,0 +1,87 @@ +// Google test includes +#include "gtest/gtest.h" + +// GeNN transpiler includes +#include "transpiler/errorHandler.h" +#include "transpiler/scanner.h" + +using namespace GeNN::Transpiler; + + +class TestErrorHandler : public ErrorHandler +{ +public: + TestErrorHandler() : m_Error(false) + {} + + bool hasError() const { return m_Error; } + + virtual void error(size_t line, std::string_view message) override + { + report(line, "", message); + } + + virtual void error(const Token &token, std::string_view message) override + { + if(token.type == Token::Type::END_OF_FILE) { + report(token.line, " at end", message); + } + else { + report(token.line, " at '" + std::string{token.lexeme} + "'", message); + } + } + +private: + void report(size_t line, std::string_view where, std::string_view message) + { + std::cerr << "[line " << line << "] Error" << where << ": " << message << std::endl; + m_Error = true; + } + + bool m_Error; +}; + +//-------------------------------------------------------------------------- +// Tests +//-------------------------------------------------------------------------- +TEST(Scanner, DecimalInt) +{ + TestErrorHandler errorHandler; + const auto positiveTokens = Scanner::scanSource("1234 4294967295U", errorHandler); + ASSERT_FALSE(errorHandler.hasError()); + ASSERT_EQ(positiveTokens.size(), 3); + ASSERT_EQ(positiveTokens[0].type, Token::Type::NUMBER); + ASSERT_EQ(positiveTokens[1].type, Token::Type::NUMBER); + ASSERT_EQ(std::get(positiveTokens[0].literalValue), 1234); + 
ASSERT_EQ(std::get(positiveTokens[1].literalValue), 4294967295U); + + //const auto negativeTokens = Scanner::scanSource("-1234 -2147483648", errorHandler); +} +//-------------------------------------------------------------------------- +TEST(Scanner, HexInt) +{ + TestErrorHandler errorHandler; + const auto positiveTokens = Scanner::scanSource("0x1234 0xFFFFFFFFU", errorHandler); + ASSERT_FALSE(errorHandler.hasError()); + ASSERT_EQ(positiveTokens.size(), 3); + ASSERT_EQ(positiveTokens[0].type, Token::Type::NUMBER); + ASSERT_EQ(positiveTokens[1].type, Token::Type::NUMBER); + ASSERT_EQ(std::get(positiveTokens[0].literalValue), 0x1234); + ASSERT_EQ(std::get(positiveTokens[1].literalValue), 0xFFFFFFFFU); +} +//-------------------------------------------------------------------------- +TEST(Scanner, DecimalFloat) +{ + TestErrorHandler errorHandler; + const auto positiveTokens = Scanner::scanSource("1.0 0.2 100.0f 0.2f", errorHandler); + ASSERT_FALSE(errorHandler.hasError()); + ASSERT_EQ(positiveTokens.size(), 5); + ASSERT_EQ(positiveTokens[0].type, Token::Type::NUMBER); + ASSERT_EQ(positiveTokens[1].type, Token::Type::NUMBER); + ASSERT_EQ(positiveTokens[2].type, Token::Type::NUMBER); + ASSERT_EQ(positiveTokens[3].type, Token::Type::NUMBER); + ASSERT_EQ(std::get(positiveTokens[0].literalValue), 1.0); + ASSERT_EQ(std::get(positiveTokens[1].literalValue), 0.2); + ASSERT_EQ(std::get(positiveTokens[2].literalValue), 100.0f); + ASSERT_EQ(std::get(positiveTokens[3].literalValue), 0.2f); +} \ No newline at end of file diff --git a/tests/unit/unit.vcxproj b/tests/unit/unit.vcxproj index 83dce27646..f5caf0d574 100644 --- a/tests/unit/unit.vcxproj +++ b/tests/unit/unit.vcxproj @@ -28,6 +28,7 @@ + From a590ef9c0092f823d025e869426e6b7481089adb Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 6 Jan 2023 11:05:58 +0000 Subject: [PATCH 012/725] include some negation --- tests/unit/scanner.cc | 71 +++++++++++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 23 deletions(-) diff --git a/tests/unit/scanner.cc b/tests/unit/scanner.cc index 41fe05c884..a3c5bf4911 100644 --- a/tests/unit/scanner.cc +++ b/tests/unit/scanner.cc @@ -47,41 +47,66 @@ class TestErrorHandler : public ErrorHandler TEST(Scanner, DecimalInt) { TestErrorHandler errorHandler; - const auto positiveTokens = Scanner::scanSource("1234 4294967295U", errorHandler); + const auto tokens = Scanner::scanSource("1234 4294967295U -2345 -2147483647", errorHandler); ASSERT_FALSE(errorHandler.hasError()); - ASSERT_EQ(positiveTokens.size(), 3); - ASSERT_EQ(positiveTokens[0].type, Token::Type::NUMBER); - ASSERT_EQ(positiveTokens[1].type, Token::Type::NUMBER); - ASSERT_EQ(std::get(positiveTokens[0].literalValue), 1234); - ASSERT_EQ(std::get(positiveTokens[1].literalValue), 4294967295U); - //const auto negativeTokens = Scanner::scanSource("-1234 -2147483648", errorHandler); + ASSERT_EQ(tokens.size(), 7); + ASSERT_EQ(tokens[0].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[1].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[2].type, Token::Type::MINUS); + ASSERT_EQ(tokens[3].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[4].type, Token::Type::MINUS); + ASSERT_EQ(tokens[5].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[6].type, Token::Type::END_OF_FILE); + + ASSERT_EQ(std::get(tokens[0].literalValue), 1234); + ASSERT_EQ(std::get(tokens[1].literalValue), 4294967295U); + ASSERT_EQ(std::get(tokens[3].literalValue), 2345); + ASSERT_EQ(std::get(tokens[5].literalValue), 2147483647); } 
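
(Aside: each assertion above pulls the literal back out of the token's literal-value variant with std::get, so the template argument must name the exact alternative the scanner stored: int32_t for a plain decimal, uint32_t for a 'U' suffix, float for an 'f' suffix and double otherwise. A minimal self-contained sketch of that behaviour follows; the LiteralValue alias here is a stand-in for illustration, not the real Token member:

    // std::get<T> throws std::bad_variant_access unless T names the alternative
    // actually stored, which is why each assertion matches the literal's suffix.
    #include <cstdint>
    #include <iostream>
    #include <variant>

    // Stand-in for Token::LiteralValue (assumed alternatives, illustration only)
    using LiteralValue = std::variant<std::monostate, float, double, std::uint32_t, std::int32_t>;

    int main()
    {
        LiteralValue v = 4294967295U;                     // 'U' suffix scans as uint32_t
        std::cout << std::get<std::uint32_t>(v) << '\n';  // OK: matching alternative
        v = 100.0f;                                       // 'f' suffix scans as float
        std::cout << std::get<float>(v) << '\n';          // OK
        // std::get<double>(v);                           // would throw std::bad_variant_access
        return 0;
    }
)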
//-------------------------------------------------------------------------- TEST(Scanner, HexInt) { TestErrorHandler errorHandler; - const auto positiveTokens = Scanner::scanSource("0x1234 0xFFFFFFFFU", errorHandler); + const auto tokens = Scanner::scanSource("0x1234 0xFFFFFFFFU -0x1234 -0x7FFFFFFF", errorHandler); ASSERT_FALSE(errorHandler.hasError()); - ASSERT_EQ(positiveTokens.size(), 3); - ASSERT_EQ(positiveTokens[0].type, Token::Type::NUMBER); - ASSERT_EQ(positiveTokens[1].type, Token::Type::NUMBER); - ASSERT_EQ(std::get(positiveTokens[0].literalValue), 0x1234); - ASSERT_EQ(std::get(positiveTokens[1].literalValue), 0xFFFFFFFFU); + + ASSERT_EQ(tokens.size(), 7); + ASSERT_EQ(tokens[0].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[1].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[2].type, Token::Type::MINUS); + ASSERT_EQ(tokens[3].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[4].type, Token::Type::MINUS); + ASSERT_EQ(tokens[5].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[6].type, Token::Type::END_OF_FILE); + + ASSERT_EQ(std::get(tokens[0].literalValue), 0x1234); + ASSERT_EQ(std::get(tokens[1].literalValue), 0xFFFFFFFFU); + ASSERT_EQ(std::get(tokens[3].literalValue), 0x1234); + ASSERT_EQ(std::get(tokens[5].literalValue), 0x7FFFFFFF); } //-------------------------------------------------------------------------- TEST(Scanner, DecimalFloat) { TestErrorHandler errorHandler; - const auto positiveTokens = Scanner::scanSource("1.0 0.2 100.0f 0.2f", errorHandler); + const auto tokens = Scanner::scanSource("1.0 0.2 100.0f 0.2f -12.0 -0.0004f", errorHandler); ASSERT_FALSE(errorHandler.hasError()); - ASSERT_EQ(positiveTokens.size(), 5); - ASSERT_EQ(positiveTokens[0].type, Token::Type::NUMBER); - ASSERT_EQ(positiveTokens[1].type, Token::Type::NUMBER); - ASSERT_EQ(positiveTokens[2].type, Token::Type::NUMBER); - ASSERT_EQ(positiveTokens[3].type, Token::Type::NUMBER); - ASSERT_EQ(std::get(positiveTokens[0].literalValue), 1.0); - ASSERT_EQ(std::get(positiveTokens[1].literalValue), 0.2); - ASSERT_EQ(std::get(positiveTokens[2].literalValue), 100.0f); - ASSERT_EQ(std::get(positiveTokens[3].literalValue), 0.2f); + + ASSERT_EQ(tokens.size(), 9); + ASSERT_EQ(tokens[0].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[1].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[2].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[3].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[4].type, Token::Type::MINUS); + ASSERT_EQ(tokens[5].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[6].type, Token::Type::MINUS); + ASSERT_EQ(tokens[7].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[8].type, Token::Type::END_OF_FILE); + + ASSERT_EQ(std::get(tokens[0].literalValue), 1.0); + ASSERT_EQ(std::get(tokens[1].literalValue), 0.2); + ASSERT_EQ(std::get(tokens[2].literalValue), 100.0f); + ASSERT_EQ(std::get(tokens[3].literalValue), 0.2f); + ASSERT_EQ(std::get(tokens[5].literalValue), 12.0); + ASSERT_EQ(std::get(tokens[7].literalValue), 0.0004f); } \ No newline at end of file From d73ad099ed7dd4a4e97cd134579277754cccd07a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 6 Jan 2023 13:46:33 +0000 Subject: [PATCH 013/725] back-ported parsing of switch statements --- src/genn/genn/transpiler/parser.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 44931ee92e..9cb7e6236f 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -601,7 +601,7 @@ Statement::StatementPtr parseSelectionStatement(ParserState 
&parserState) // selection-statement ::= // "if" "(" expression ")" statement // "if" "(" expression ")" statement "else" statement - // "switch" "(" expression ")" statement + // "switch" "(" expression ")" compound-statement const auto keyword = parserState.previous(); parserState.consume(Token::Type::LEFT_PAREN, "Expect '(' after '" + std::string{keyword.lexeme} + "'"); auto condition = parseExpression(parserState); @@ -621,8 +621,10 @@ Statement::StatementPtr parseSelectionStatement(ParserState &parserState) } // Otherwise (switch statement) else { - return std::make_unique(keyword, std::move(condition), - parseStatement(parserState)); + // **NOTE** this is a slight simplification of the C standard where any type of statement can be used as the body of the switch + parserState.consume(Token::Type::LEFT_BRACE, "Expect '{' after switch statement."); + auto body = parseCompoundStatement(parserState); + return std::make_unique(keyword, std::move(condition), std::move(body)); } } From 6f79567bcfdada572e30c6216bc3625f74b299b5 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 6 Jan 2023 16:22:19 +0000 Subject: [PATCH 014/725] added type checker tests --- include/genn/genn/transpiler/typeChecker.h | 15 ++- src/genn/genn/transpiler/typeChecker.cc | 9 +- tests/unit/scanner.cc | 9 +- tests/unit/typeChecker.cc | 117 +++++++++++++++++++++ 4 files changed, 138 insertions(+), 12 deletions(-) create mode 100644 tests/unit/typeChecker.cc diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index ab75bf241a..3cc16bb670 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -20,10 +20,21 @@ class Base; } //--------------------------------------------------------------------------- -// GeNN::Transpiler::TypeChecker::Environment +// GeNN::Transpiler::TypeChecker::TypeCheckError //--------------------------------------------------------------------------- namespace GeNN::Transpiler::TypeChecker { +class TypeCheckError : public std::runtime_error +{ +public: + TypeCheckError() : std::runtime_error("") + { + } +}; + +//--------------------------------------------------------------------------- +// GeNN::Transpiler::TypeChecker::Environment +//--------------------------------------------------------------------------- class Environment { public: @@ -61,4 +72,4 @@ class Environment //--------------------------------------------------------------------------- void typeCheck(const Statement::StatementList &statements, Environment &environment, ErrorHandler &errorHandler); -} // namespace MiniParse::GeNN::Transpiler \ No newline at end of file +} // namespace MiniParse::GeNN::Transpiler diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index cada02ea49..02e94a7b94 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -23,13 +23,6 @@ namespace Type = GeNN::Type; //--------------------------------------------------------------------------- namespace { -//--------------------------------------------------------------------------- -// TypeCheckError -//--------------------------------------------------------------------------- -class TypeCheckError -{ -}; - //--------------------------------------------------------------------------- // Vistor //--------------------------------------------------------------------------- @@ -671,4 +664,4 @@ void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &st 
{ Visitor visitor(errorHandler); visitor.typeCheck(statements, environment); -} \ No newline at end of file +} diff --git a/tests/unit/scanner.cc b/tests/unit/scanner.cc index a3c5bf4911..5c24f9fa5e 100644 --- a/tests/unit/scanner.cc +++ b/tests/unit/scanner.cc @@ -7,7 +7,11 @@ using namespace GeNN::Transpiler; - +//-------------------------------------------------------------------------- +// Anonymous namespace +//-------------------------------------------------------------------------- +namespace +{ class TestErrorHandler : public ErrorHandler { public: @@ -40,6 +44,7 @@ class TestErrorHandler : public ErrorHandler bool m_Error; }; +} // Anonymous namespace //-------------------------------------------------------------------------- // Tests @@ -109,4 +114,4 @@ TEST(Scanner, DecimalFloat) ASSERT_EQ(std::get(tokens[3].literalValue), 0.2f); ASSERT_EQ(std::get(tokens[5].literalValue), 12.0); ASSERT_EQ(std::get(tokens[7].literalValue), 0.0004f); -} \ No newline at end of file +} diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc new file mode 100644 index 0000000000..aa97d3a74e --- /dev/null +++ b/tests/unit/typeChecker.cc @@ -0,0 +1,117 @@ +// Google test includes +#include "gtest/gtest.h" + +// GeNN includes +#include "type.h" + +// GeNN transpiler includes +#include "transpiler/errorHandler.h" +#include "transpiler/parser.h" +#include "transpiler/scanner.h" +#include "transpiler/typeChecker.h" + +using namespace GeNN; +using namespace GeNN::Transpiler; + +//-------------------------------------------------------------------------- +// Anonymous namespace +//-------------------------------------------------------------------------- +namespace +{ +class TestErrorHandler : public ErrorHandler +{ +public: + TestErrorHandler() : m_Error(false) + {} + + bool hasError() const { return m_Error; } + + virtual void error(size_t line, std::string_view message) override + { + report(line, "", message); + } + + virtual void error(const Token &token, std::string_view message) override + { + if(token.type == Token::Type::END_OF_FILE) { + report(token.line, " at end", message); + } + else { + report(token.line, " at '" + std::string{token.lexeme} + "'", message); + } + } + +private: + void report(size_t line, std::string_view where, std::string_view message) + { + std::cerr << "[line " << line << "] Error" << where << ": " << message << std::endl; + m_Error = true; + } + + bool m_Error; +}; + +void typeCheckCode(std::string_view code, TypeChecker::Environment &typeEnvironment) +{ + // Scan + TestErrorHandler errorHandler; + const auto tokens = Scanner::scanSource(code, errorHandler); + ASSERT_FALSE(errorHandler.hasError()); + + // Parse + const auto statements = Parser::parseBlockItemList(tokens, errorHandler); + ASSERT_FALSE(errorHandler.hasError()); + + // Typecheck + TypeChecker::typeCheck(statements, typeEnvironment, errorHandler); + ASSERT_FALSE(errorHandler.hasError()); +} +} // Anonymous namespace + +//-------------------------------------------------------------------------- +// Tests +//-------------------------------------------------------------------------- +TEST(TypeChecker, ArraySubscript) +{ + // Integer array indexing + { + TypeChecker::Environment typeEnvironment; + typeEnvironment.define("intArray"); + typeCheckCode("int x = intArray[4];", typeEnvironment); + } + + try { + TypeChecker::Environment typeEnvironment; + typeEnvironment.define("intArray"); + typeCheckCode("int x = intArray[4.0f];", typeEnvironment); + FAIL(); + } + catch(const TypeChecker::TypeCheckError&) 
{ + } +} +//-------------------------------------------------------------------------- +TEST(TypeChecker, Assignment) +{ +} +//-------------------------------------------------------------------------- +TEST(TypeChecker, Binary) +{ +} +//-------------------------------------------------------------------------- +TEST(TypeChecker, Call) +{ +} +//-------------------------------------------------------------------------- +TEST(TypeChecker, Cast) +{ +} +//-------------------------------------------------------------------------- +TEST(TypeChecker, Conditional) +{ +} +//-------------------------------------------------------------------------- +TEST(TypeChecker, Literal) +{ + + +} From a714dd4dca8b3a419f8b9c26abd634818a2adf5e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 9 Jan 2023 11:02:47 +0000 Subject: [PATCH 015/725] used mildly newer form of exception test and added type checker to unit test project --- tests/unit/typeChecker.cc | 9 +++------ tests/unit/unit.vcxproj | 1 + 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index aa97d3a74e..85f0bf2654 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -80,14 +80,11 @@ TEST(TypeChecker, ArraySubscript) typeCheckCode("int x = intArray[4];", typeEnvironment); } - try { + EXPECT_THROW({ TypeChecker::Environment typeEnvironment; typeEnvironment.define("intArray"); - typeCheckCode("int x = intArray[4.0f];", typeEnvironment); - FAIL(); - } - catch(const TypeChecker::TypeCheckError&) { - } + typeCheckCode("int x = intArray[4.0f];", typeEnvironment);}, + TypeChecker::TypeCheckError); } //-------------------------------------------------------------------------- TEST(TypeChecker, Assignment) diff --git a/tests/unit/unit.vcxproj b/tests/unit/unit.vcxproj index f5caf0d574..87b2db7cbb 100644 --- a/tests/unit/unit.vcxproj +++ b/tests/unit/unit.vcxproj @@ -31,6 +31,7 @@ + From 941a40fa2b4b55da285e63495f5ee0edd3897ad9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 9 Jan 2023 11:45:15 +0000 Subject: [PATCH 016/725] * Exposed ability to typecheck an expression * More type-checking unit tests --- include/genn/genn/transpiler/typeChecker.h | 3 + src/genn/genn/transpiler/typeChecker.cc | 53 ++++--- tests/unit/typeChecker.cc | 152 ++++++++++++++++++++- 3 files changed, 184 insertions(+), 24 deletions(-) diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index 3cc16bb670..91f914ac45 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -72,4 +72,7 @@ class Environment //--------------------------------------------------------------------------- void typeCheck(const Statement::StatementList &statements, Environment &environment, ErrorHandler &errorHandler); + +std::tuple typeCheck(const Expression::Base *expression, Environment &environment, + ErrorHandler &errorHandler); } // namespace MiniParse::GeNN::Transpiler diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 02e94a7b94..5bac4a8491 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -38,6 +38,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- // Public API //--------------------------------------------------------------------------- + // **THINK** make constructors? 
void typeCheck(const Statement::StatementList &statements, Environment &environment) { Environment *previous = m_Environment; @@ -48,6 +49,17 @@ class Visitor : public Expression::Visitor, public Statement::Visitor m_Environment = previous; } + std::tuple typeCheck(const Expression::Base *expression, Environment &environment) + { + Environment *previous = m_Environment; + m_Environment = &environment; + + const auto type = evaluateType(expression); + + m_Environment = previous; + return type; + } + //--------------------------------------------------------------------------- // Expression::Visitor virtuals //--------------------------------------------------------------------------- @@ -60,7 +72,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If pointer is indeed a pointer if (pointerType) { // Evaluate pointer type - auto indexType = evaluateType(arraySubscript.getIndex().get()); + auto indexType = std::get<0>(evaluateType(arraySubscript.getIndex().get())); auto indexNumericType = dynamic_cast(indexType); if (!indexNumericType || !indexNumericType->isIntegral()) { m_ErrorHandler.error(arraySubscript.getPointerName(), @@ -81,7 +93,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Assignment &assignment) final { - const auto [rhsType, rhsConst] = evaluateTypeConst(assignment.getValue()); + const auto [rhsType, rhsConst] = evaluateType(assignment.getValue()); m_Type = m_Environment->assign(assignment.getVarName(), rhsType, rhsConst, assignment.getOperator().type, m_ErrorHandler); m_Const = false; @@ -90,14 +102,14 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Binary &binary) final { const auto opType = binary.getOperator().type; - const auto [rightType, rightConst] = evaluateTypeConst(binary.getRight()); + const auto [rightType, rightConst] = evaluateType(binary.getRight()); if (opType == Token::Type::COMMA) { m_Type = rightType; m_Const = rightConst; } else { // If we're subtracting two pointers - const auto [leftType, leftConst] = evaluateTypeConst(binary.getLeft()); + const auto [leftType, leftConst] = evaluateType(binary.getLeft()); auto leftNumericType = dynamic_cast(leftType); auto rightNumericType = dynamic_cast(rightType); auto leftNumericPtrType = dynamic_cast(leftType); @@ -179,7 +191,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Call &call) final { // Evaluate callee type - auto calleeType = evaluateType(call.getCallee()); + auto calleeType = std::get<0>(evaluateType(call.getCallee())); auto calleeFunctionType = dynamic_cast(calleeType); // If callee's a function @@ -223,8 +235,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Conditional &conditional) final { - const auto [trueType, trueConst] = evaluateTypeConst(conditional.getTrue()); - const auto [falseType, falseConst] = evaluateTypeConst(conditional.getFalse()); + const auto [trueType, trueConst] = evaluateType(conditional.getTrue()); + const auto [falseType, falseConst] = evaluateType(conditional.getFalse()); auto trueNumericType = dynamic_cast(trueType); auto falseNumericType = dynamic_cast(falseType); if (trueNumericType && falseNumericType) { @@ -240,7 +252,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Grouping &grouping) final { - std::tie(m_Type, m_Const) = 
evaluateTypeConst(grouping.getExpression()); + std::tie(m_Type, m_Const) = evaluateType(grouping.getExpression()); } virtual void visit(const Expression::Literal &literal) final @@ -250,7 +262,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor [](auto v)->const Type::NumericBase *{ return Type::TypeTraits::NumericType::getInstance(); }, [](std::monostate)->const Type::NumericBase *{ return nullptr; }}, literal.getValue()); - m_Const = false; + m_Const = true; } virtual void visit(const Expression::Logical &logical) final @@ -282,7 +294,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Unary &unary) final { - const auto [rightType, rightConst] = evaluateTypeConst(unary.getRight()); + const auto [rightType, rightConst] = evaluateType(unary.getRight()); // If operator is pointer de-reference if (unary.getOperator().type == Token::Type::STAR) { @@ -420,7 +432,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } if (labelled.getValue()) { - auto valType = evaluateType(labelled.getValue()); + auto valType = std::get<0>(evaluateType(labelled.getValue())); auto valNumericType = dynamic_cast(valType); if (!valNumericType || !valNumericType->isIntegral()) { m_ErrorHandler.error(labelled.getKeyword(), @@ -434,7 +446,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::Switch &switchStatement) final { - auto condType = evaluateType(switchStatement.getCondition()); + auto condType = std::get<0>(evaluateType(switchStatement.getCondition())); auto condNumericType = dynamic_cast(condType); if (!condNumericType || !condNumericType->isIntegral()) { m_ErrorHandler.error(switchStatement.getSwitch(), @@ -456,7 +468,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If variable has an initialiser expression if (std::get<1>(var)) { // Evaluate type - const auto [initialiserType, initialiserConst] = evaluateTypeConst(std::get<1>(var).get()); + const auto [initialiserType, initialiserConst] = evaluateType(std::get<1>(var).get()); // Assign initialiser expression to variable m_Environment->assign(std::get<0>(var), initialiserType, initialiserConst, Token::Type::EQUAL, m_ErrorHandler); @@ -481,17 +493,12 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- // Private methods //--------------------------------------------------------------------------- - std::tuple evaluateTypeConst(const Expression::Base *expression) + std::tuple evaluateType(const Expression::Base *expression) { expression->accept(*this); return std::make_tuple(m_Type, m_Const); } - const Type::Base *evaluateType(const Expression::Base *expression) - { - return std::get<0>(evaluateTypeConst(expression)); - } - //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- @@ -665,3 +672,11 @@ void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &st Visitor visitor(errorHandler); visitor.typeCheck(statements, environment); } +//--------------------------------------------------------------------------- +std::tuple GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, + Environment &environment, + ErrorHandler &errorHandler) +{ + Visitor visitor(errorHandler); + return visitor.typeCheck(expression, environment); +} 
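The new overload means a single expression can now be type-checked against an Environment without wrapping it in a statement list. A minimal sketch of driving it end-to-end, mirroring the scan/parse/check pipeline the unit tests below use (the TestErrorHandler class and the Int32Ptr binding for "intArray" are illustrative assumptions, not part of this patch):

    // Scan and parse a lone expression (TestErrorHandler as defined in the unit tests)
    TestErrorHandler errorHandler;
    const auto tokens = Scanner::scanSource("intArray[2] + 1", errorHandler);
    const auto expression = Parser::parseExpression(tokens, errorHandler);

    // Bind "intArray" in a fresh environment (assuming an Int32Ptr binding here)
    TypeChecker::Environment environment;
    environment.define<Type::Int32Ptr>("intArray");

    // Evaluates to Int32; throws TypeChecker::TypeCheckError on a type error
    const auto [type, isConst] = TypeChecker::typeCheck(expression.get(), environment, errorHandler);

At this stage the function returns the deduced type together with a const flag; PATCH 017 below reworks this return value into Type::QualifiedType.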
diff --git a/tests/unit/typeChecker.cc index 85f0bf2654..63b8cba4a9 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -51,7 +51,7 @@ class TestErrorHandler : public ErrorHandler bool m_Error; }; -void typeCheckCode(std::string_view code, TypeChecker::Environment &typeEnvironment) +void typeCheckStatements(std::string_view code, TypeChecker::Environment &typeEnvironment) { // Scan TestErrorHandler errorHandler; @@ -66,6 +66,23 @@ void typeCheckCode(std::string_view code, TypeChecker::Environment &typeEnvironm TypeChecker::typeCheck(statements, typeEnvironment, errorHandler); ASSERT_FALSE(errorHandler.hasError()); } + +std::tuple typeCheckExpression(std::string_view code, TypeChecker::Environment &typeEnvironment) +{ + // Scan + TestErrorHandler errorHandler; + const auto tokens = Scanner::scanSource(code, errorHandler); + EXPECT_FALSE(errorHandler.hasError()); + + // Parse + const auto expression = Parser::parseExpression(tokens, errorHandler); + EXPECT_FALSE(errorHandler.hasError()); + + // Typecheck + const auto type = TypeChecker::typeCheck(expression.get(), typeEnvironment, errorHandler); + EXPECT_FALSE(errorHandler.hasError()); + return type; +} } // Anonymous namespace //-------------------------------------------------------------------------- @@ -77,18 +94,69 @@ TEST(TypeChecker, ArraySubscript) { TypeChecker::Environment typeEnvironment; typeEnvironment.define("intArray"); - typeCheckCode("int x = intArray[4];", typeEnvironment); + typeCheckStatements("int x = intArray[4];", typeEnvironment); } + // Float array indexing EXPECT_THROW({ TypeChecker::Environment typeEnvironment; typeEnvironment.define("intArray"); - typeCheckCode("int x = intArray[4.0f];", typeEnvironment);}, + typeCheckStatements("int x = intArray[4.0f];", typeEnvironment);}, + TypeChecker::TypeCheckError); + + // Pointer indexing + EXPECT_THROW({ + TypeChecker::Environment typeEnvironment; + typeEnvironment.define("intArray"); + typeEnvironment.define("indexArray"); + typeCheckStatements("int x = intArray[indexArray];", typeEnvironment);}, TypeChecker::TypeCheckError); } //-------------------------------------------------------------------------- TEST(TypeChecker, Assignment) { + // Numeric assignment + { + TypeChecker::Environment typeEnvironment; + typeEnvironment.define("intVal"); + typeEnvironment.define("floatVal"); + typeEnvironment.define("intValConst", true); + typeCheckStatements( + "int w = intVal;\n" + "float x = floatVal;\n" + "int y = floatVal;\n" + "float z = intVal;\n" + "int wc = intValConst;\n" + "const int cw = intVal;\n" + "const int cwc = intValConst;\n", + typeEnvironment); + } + + // Pointer assignment + { + TypeChecker::Environment typeEnvironment; + typeEnvironment.define("intArray"); + typeEnvironment.define("intArrayConst", true); + typeCheckStatements( + "int *x = intArray;\n" + "const *y = intArray;\n" + "const *z = intArrayConst;\n", + typeEnvironment); + } + + // Pointer assignment, attempt to remove const + EXPECT_THROW({ + TypeChecker::Environment typeEnvironment; + typeEnvironment.define("intArray", true); + typeCheckStatements("int *x = intArray;", typeEnvironment);}, + TypeChecker::TypeCheckError); + + // Pointer assignment without explicit cast + EXPECT_THROW({ + TypeChecker::Environment typeEnvironment; + typeEnvironment.define("intArray"); + typeCheckStatements("float *x = intArray;", typeEnvironment);}, + TypeChecker::TypeCheckError); } //--------------------------------------------------------------------------
TEST(TypeChecker, Binary) @@ -101,6 +169,21 @@ TEST(TypeChecker, Call) { } //-------------------------------------------------------------------------- TEST(TypeChecker, Cast) { + // Numeric cast + { + TypeChecker::Environment typeEnvironment; + typeEnvironment.define("intVal"); + const auto type = typeCheckExpression("(float)intVal", typeEnvironment); + EXPECT_EQ(std::get<0>(type)->getTypeHash(), Type::Float::getInstance()->getTypeHash()); + EXPECT_FALSE(std::get<1>(type)); + } + + // Pointer cast can't reinterpret + EXPECT_THROW({ + TypeChecker::Environment typeEnvironment; + typeEnvironment.define("intArray"); + typeCheckExpression("(float*)intArray", typeEnvironment);}, + TypeChecker::TypeCheckError); } //-------------------------------------------------------------------------- TEST(TypeChecker, Conditional) { @@ -109,6 +192,65 @@ TEST(TypeChecker, Literal) { - - + // Float + { + TypeChecker::Environment typeEnvironment; + const auto type = typeCheckExpression("1.0f", typeEnvironment); + EXPECT_EQ(std::get<0>(type)->getTypeHash(), Type::Float::getInstance()->getTypeHash()); + EXPECT_TRUE(std::get<1>(type)); + } + + // Double + { + TypeChecker::Environment typeEnvironment; + const auto type = typeCheckExpression("1.0", typeEnvironment); + EXPECT_EQ(std::get<0>(type)->getTypeHash(), Type::Double::getInstance()->getTypeHash()); + EXPECT_TRUE(std::get<1>(type)); + } + + // Integer + { + TypeChecker::Environment typeEnvironment; + const auto type = typeCheckExpression("100", typeEnvironment); + EXPECT_EQ(std::get<0>(type)->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); + EXPECT_TRUE(std::get<1>(type)); + } + + // Unsigned integer + { + TypeChecker::Environment typeEnvironment; + const auto type = typeCheckExpression("100U", typeEnvironment); + EXPECT_EQ(std::get<0>(type)->getTypeHash(), Type::Uint32::getInstance()->getTypeHash()); + EXPECT_TRUE(std::get<1>(type)); + } } +//-------------------------------------------------------------------------- +TEST(TypeChecker, Unary) +{ + // Dereference pointer + // **TODO** const + { + TypeChecker::Environment typeEnvironment; + typeEnvironment.define("intArray"); + const auto type = typeCheckExpression("*intArray", typeEnvironment); + EXPECT_EQ(std::get<0>(type)->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); + EXPECT_FALSE(std::get<1>(type)); + } + + // Dereference numeric + EXPECT_THROW({ + TypeChecker::Environment typeEnvironment; + typeEnvironment.define("intVal"); + typeCheckExpression("*intVal", typeEnvironment);}, + TypeChecker::TypeCheckError); + + // Address of numeric + // **TODO** const + { + TypeChecker::Environment typeEnvironment; + typeEnvironment.define("intVal"); + const auto type = typeCheckExpression("&intVal", typeEnvironment); + EXPECT_EQ(std::get<0>(type)->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash()); + EXPECT_FALSE(std::get<1>(type)); + } +} \ No newline at end of file From e2e534a8422c513bfa039734e18dd7abba42fd3b Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 9 Jan 2023 18:03:15 +0000 Subject: [PATCH 017/725] First go at a type system that can understand the difference between const pointers and pointers to const values * Type::QualifiedType encapsulates type, constValue and constPointer * Slightly smarter parser of declarations * Some improvements to type-checker * Correct pretty printing of types --- include/genn/genn/transpiler/expression.h | 20 +-
include/genn/genn/transpiler/statement.h | 24 +- include/genn/genn/transpiler/typeChecker.h | 29 ++- include/genn/genn/type.h | 16 ++ src/genn/genn/transpiler/parser.cc | 57 +++-- src/genn/genn/transpiler/prettyPrinter.cc | 20 +- src/genn/genn/transpiler/typeChecker.cc | 259 ++++++++++----------- tests/unit/typeChecker.cc | 70 ++++-- 8 files changed, 271 insertions(+), 224 deletions(-) diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h index 44721c319a..083fc56780 100644 --- a/include/genn/genn/transpiler/expression.h +++ b/include/genn/genn/transpiler/expression.h @@ -4,6 +4,9 @@ #include #include +// GeNN includes +#include "type.h" + // Transpiler includes #include "transpiler/token.h" @@ -12,10 +15,6 @@ namespace GeNN::Transpiler::Expression { class Visitor; } -namespace GeNN::Type -{ -class Base; -} //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::Base @@ -123,20 +122,17 @@ class Call : public Base class Cast : public Base { public: - Cast(const Type::Base *type, bool isConst, ExpressionPtr expression) - : m_Type(type), m_Const(isConst), m_Expression(std::move(expression)) + Cast(const Type::QualifiedType &qualifiedType, ExpressionPtr expression) + : m_QualifiedType(qualifiedType), m_Expression(std::move(expression)) {} virtual void accept(Visitor &visitor) const final; const Base *getExpression() const { return m_Expression.get(); } - - const Type::Base *getType() const { return m_Type; } - bool isConst() const { return m_Const; } + const Type::QualifiedType &getQualifiedType() const{ return m_QualifiedType; } private: - const Type::Base *m_Type; - bool m_Const; + const Type::QualifiedType m_QualifiedType; const ExpressionPtr m_Expression; }; @@ -321,4 +317,4 @@ class Visitor virtual void visit(const Variable &variable) = 0; virtual void visit(const Unary &unary) = 0; }; -} // namespace GeNN::Transpiler::Expression \ No newline at end of file +} // namespace GeNN::Transpiler::Expression diff --git a/include/genn/genn/transpiler/statement.h b/include/genn/genn/transpiler/statement.h index cff48c81e1..af0c4f49f2 100644 --- a/include/genn/genn/transpiler/statement.h +++ b/include/genn/genn/transpiler/statement.h @@ -4,6 +4,9 @@ #include #include +// GeNN includes +#include "type.h" + // Transpiler includes #include "transpiler/expression.h" @@ -12,10 +15,6 @@ namespace GeNN::Transpiler::Statement { class Visitor; } -namespace GeNN::Type -{ -class Base; -} //--------------------------------------------------------------------------- // GeNN::Transpiler::Statement::Base @@ -229,20 +228,17 @@ class VarDeclaration : public Base public: typedef std::vector> InitDeclaratorList; - VarDeclaration(const Type::Base *type, bool isConst, InitDeclaratorList initDeclaratorList) - : m_Type(type), m_Const(isConst), m_InitDeclaratorList(std::move(initDeclaratorList)) + VarDeclaration(const Type::QualifiedType &qualifiedType, InitDeclaratorList initDeclaratorList) + : m_QualifiedType(qualifiedType), m_InitDeclaratorList(std::move(initDeclaratorList)) {} virtual void accept(Visitor &visitor) const override; - const Type::Base *getType() const { return m_Type; } - bool isConst() const { return m_Const; } - - const InitDeclaratorList &getInitDeclaratorList() const { return m_InitDeclaratorList; } - + const Type::QualifiedType &getQualifiedType() const{ return m_QualifiedType; } + const InitDeclaratorList &getInitDeclaratorList() const { return m_InitDeclaratorList; } + private: - const Type::Base 
*m_Type; - const bool m_Const; + const Type::QualifiedType m_QualifiedType; const std::vector m_DeclarationSpecifiers; const InitDeclaratorList m_InitDeclaratorList; }; @@ -307,4 +303,4 @@ class Visitor virtual void visit(const While &whileStatement) = 0; virtual void visit(const Print &print) = 0; }; -} // namespace GeNN::Transpiler::Statement \ No newline at end of file +} // namespace GeNN::Transpiler::Statement diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index 91f914ac45..89a5fbbda2 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -5,6 +5,9 @@ #include #include +// GeNN includes +#include "type.h" + // Transpiler includes #include "transpiler/statement.h" @@ -14,10 +17,6 @@ namespace GeNN::Transpiler class ErrorHandler; struct Token; } -namespace GeNN::Type -{ -class Base; -} //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker::TypeCheckError @@ -39,7 +38,7 @@ class Environment { public: Environment(Environment *enclosing = nullptr) - : m_Enclosing(enclosing) + : m_Enclosing(enclosing) { } @@ -47,24 +46,24 @@ class Environment // Public API //--------------------------------------------------------------------------- template - void define(std::string_view name, bool isConst = false) + void define(std::string_view name, bool isConstValue = false, bool isConstPointer = false) { - if(!m_Types.try_emplace(name, T::getInstance(), isConst).second) { + if(!m_Types.try_emplace(name, T::getInstance(), isConstValue, isConstPointer).second) { throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); } } - void define(const Token &name, const Type::Base *type, bool isConst, ErrorHandler &errorHandler); - const Type::Base *assign(const Token &name, const Type::Base *assignedType, bool assignedConst, - Token::Type op, ErrorHandler &errorHandler); - const Type::Base *incDec(const Token &name, const Token &op, ErrorHandler &errorHandler); - std::tuple getType(const Token &name, ErrorHandler &errorHandler) const; + void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandler &errorHandler); + const Type::QualifiedType &assign(const Token &name, const Type::QualifiedType &assignedType, + Token::Type op, ErrorHandler &errorHandler); + const Type::QualifiedType &incDec(const Token &name, const Token &op, ErrorHandler &errorHandler); + const Type::QualifiedType &getType(const Token &name, ErrorHandler &errorHandler) const; private: //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- Environment *m_Enclosing; - std::unordered_map> m_Types; + std::unordered_map m_Types; }; //--------------------------------------------------------------------------- @@ -73,6 +72,6 @@ class Environment void typeCheck(const Statement::StatementList &statements, Environment &environment, ErrorHandler &errorHandler); -std::tuple typeCheck(const Expression::Base *expression, Environment &environment, - ErrorHandler &errorHandler); +Type::QualifiedType typeCheck(const Expression::Base *expression, Environment &environment, + ErrorHandler &errorHandler); } // namespace MiniParse::GeNN::Transpiler diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 0e69abb537..0a9baf13da 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -83,6 +83,22 @@ class Base virtual size_t 
getTypeHash() const = 0; }; +//---------------------------------------------------------------------------- +// GeNN::Type::QualifiedType +//---------------------------------------------------------------------------- +//! A type with qualifiers attached +struct QualifiedType +{ + QualifiedType(const Base *t, bool v, bool p) + : type(t), constValue(v), constPointer(p) + { + } + + const Base *type; + bool constValue; + bool constPointer; +}; + //---------------------------------------------------------------------------- // GeNN::Type::NumericBase //---------------------------------------------------------------------------- diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 9cb7e6236f..74f871c882 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -1,6 +1,7 @@ #include "transpiler/parser.h" // Standard C++ includes +#include #include #include #include @@ -182,36 +183,44 @@ Expression::ExpressionPtr parseBinary(ParserState &parserState, N nonTerminal, s return expression; } -std::tuple parseDeclarationSpecifiers(ParserState &parserState) +GeNN::Type::QualifiedType parseDeclarationSpecifiers(ParserState &parserState) { - // Loop through type qualifier and specifier tokens - std::set typeQualifiers{}; - std::set typeSpecifiers{}; + bool pointerFound = false; + std::set typeSpecifiers; + std::set valueTypeQualifiers; + std::set pointerTypeQualifiers; do { - // Add token lexeme to appropriate set, giving error if duplicate - if(parserState.previous().type == Token::Type::TYPE_QUALIFIER) { + // If token is a star, set pointer found flag + if(parserState.previous().type == Token::Type::STAR) { + pointerFound = true; + } + // Otherwise, if type is a qualifier + else if(parserState.previous().type == Token::Type::TYPE_QUALIFIER) { + // Add qualifier lexeme to correct list + auto &typeQualifiers = pointerFound ? pointerTypeQualifiers : valueTypeQualifiers; if(!typeQualifiers.insert(parserState.previous().lexeme).second) { parserState.error(parserState.previous(), "duplicate type qualifier"); } } - else { - if(!typeSpecifiers.insert(parserState.previous().lexeme).second) { + else if(parserState.previous().type == Token::Type::TYPE_SPECIFIER) { + if(pointerFound) { + parserState.error(parserState.previous(), "invalid type specifier"); + } + else if(!typeSpecifiers.insert(parserState.previous().lexeme).second) { parserState.error(parserState.previous(), "duplicate type specifier"); } } - } while(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER})); + } while(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER, Token::Type::STAR})); - // Lookup type - const GeNN::Type::Base *type = (parserState.match({Token::Type::STAR}) - ? static_cast(GeNN::Type::getNumericPtrType(typeSpecifiers)) - : static_cast(GeNN::Type::getNumericType(typeSpecifiers))); - if(!type) { - parserState.error("Unknown type specifier"); - } - - // Determine constness - // **NOTE** this only works as const is the ONLY supported qualifier - return std::make_tuple(type, !typeQualifiers.empty()); + // Lookup type based on whether token was found + const GeNN::Type::Base *type = (pointerFound + ? 
static_cast(GeNN::Type::getNumericPtrType(typeSpecifiers)) : static_cast(GeNN::Type::getNumericType(typeSpecifiers))); + + // Return qualified type + // **THINK** this relies on const being only qualifier + // **TODO** warn of duplicate type qualifiers + return GeNN::Type::QualifiedType{type, !valueTypeQualifiers.empty(), !pointerTypeQualifiers.empty()}; } Expression::ExpressionPtr parsePrimary(ParserState &parserState) @@ -362,11 +371,11 @@ Expression::ExpressionPtr parseCast(ParserState &parserState) // If this is followed by some part of a type declarator if(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER})) { // Parse declaration specifiers - const auto [type, isConst] = parseDeclarationSpecifiers(parserState); + const auto qualifiedType = parseDeclarationSpecifiers(parserState); parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after cast type."); - return std::make_unique(type, isConst, parseCast(parserState)); + return std::make_unique(qualifiedType, parseCast(parserState)); } // Otherwise, rewind parser state so left parenthesis can be parsed again // **YUCK** @@ -781,7 +790,7 @@ Statement::StatementPtr parseDeclaration(ParserState &parserState) // "const" // Parse declaration specifiers - const auto [type, isConst] = parseDeclarationSpecifiers(parserState); + const auto qualifiedType = parseDeclarationSpecifiers(parserState); // Read init declarator list std::vector> initDeclaratorList; @@ -805,7 +814,7 @@ Statement::StatementPtr parseDeclaration(ParserState &parserState) } while(!parserState.isAtEnd() && parserState.match(Token::Type::COMMA)); parserState.consume(Token::Type::SEMICOLON, "Expect ';' after variable declaration"); - return std::make_unique(type, isConst, std::move(initDeclaratorList)); + return std::make_unique(qualifiedType, std::move(initDeclaratorList)); } std::unique_ptr parseBlockItem(ParserState &parserState) diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index 98dfc5f959..4d58023a90 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -77,7 +77,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Cast &cast) final { - m_StringStream << "(" << cast.getType()->getTypeName() << ")"; + m_StringStream << "("; + printQualifiedType(cast.getQualifiedType()); + m_StringStream << ")"; cast.getExpression()->accept(*this); } @@ -227,10 +229,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::VarDeclaration &varDeclaration) final { - if(varDeclaration.isConst()) { - m_StringStream << "const "; - } - m_StringStream << varDeclaration.getType()->getTypeName() << " "; + printQualifiedType(varDeclaration.getQualifiedType()); for(const auto &var : varDeclaration.getInitDeclaratorList()) { m_StringStream << std::get<0>(var).lexeme; @@ -259,6 +258,17 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } private: + void printQualifiedType(const GeNN::Type::QualifiedType &qualifiedType) + { + if(qualifiedType.constValue) { + m_StringStream << "const "; + } + m_StringStream << qualifiedType.type->getTypeName() << " "; + + if(qualifiedType.constPointer) { + m_StringStream << "const "; + } + } //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- diff --git
a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 5bac4a8491..98e985d849 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -30,8 +30,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { public: Visitor(ErrorHandler &errorHandler) - : m_Environment(nullptr), m_Type(nullptr), m_Const(false), - m_ErrorHandler(errorHandler), m_InLoop(false), m_InSwitch(false) + : m_Environment(nullptr), m_QualifiedType{nullptr, false, false}, m_ErrorHandler(errorHandler), + m_InLoop(false), m_InSwitch(false) { } @@ -49,7 +49,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor m_Environment = previous; } - std::tuple typeCheck(const Expression::Base *expression, Environment &environment) + const Type::QualifiedType typeCheck(const Expression::Base *expression, Environment &environment) { Environment *previous = m_Environment; m_Environment = &environment; @@ -66,23 +66,22 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::ArraySubscript &arraySubscript) final { // Get pointer type - auto pointerType = dynamic_cast( - std::get<0>(m_Environment->getType(arraySubscript.getPointerName(), m_ErrorHandler))); + auto arrayType = m_Environment->getType(arraySubscript.getPointerName(), m_ErrorHandler); + auto pointerType = dynamic_cast(arrayType.type); // If pointer is indeed a pointer if (pointerType) { // Evaluate pointer type - auto indexType = std::get<0>(evaluateType(arraySubscript.getIndex().get())); - auto indexNumericType = dynamic_cast(indexType); + auto indexType = evaluateType(arraySubscript.getIndex().get()); + auto indexNumericType = dynamic_cast(indexType.type); if (!indexNumericType || !indexNumericType->isIntegral()) { m_ErrorHandler.error(arraySubscript.getPointerName(), - "Invalid subscript index type '" + indexType->getTypeName() + "'"); + "Invalid subscript index type '" + indexType.type->getTypeName() + "'"); throw TypeCheckError(); } // Use value type of array - m_Type = pointerType->getValueType(); - m_Const = false; + m_QualifiedType = Type::QualifiedType{pointerType->getValueType(), arrayType.constValue, false}; } // Otherwise else { @@ -93,63 +92,58 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Assignment &assignment) final { - const auto [rhsType, rhsConst] = evaluateType(assignment.getValue()); - m_Type = m_Environment->assign(assignment.getVarName(), rhsType, rhsConst, - assignment.getOperator().type, m_ErrorHandler); - m_Const = false; + const auto rhsType = evaluateType(assignment.getValue()); + m_QualifiedType = m_Environment->assign(assignment.getVarName(), rhsType, + assignment.getOperator().type, m_ErrorHandler); } virtual void visit(const Expression::Binary &binary) final { const auto opType = binary.getOperator().type; - const auto [rightType, rightConst] = evaluateType(binary.getRight()); + const auto rightType = evaluateType(binary.getRight()); if (opType == Token::Type::COMMA) { - m_Type = rightType; - m_Const = rightConst; + m_QualifiedType = rightType; } else { // If we're subtracting two pointers - const auto [leftType, leftConst] = evaluateType(binary.getLeft()); - auto leftNumericType = dynamic_cast(leftType); - auto rightNumericType = dynamic_cast(rightType); - auto leftNumericPtrType = dynamic_cast(leftType); - auto rightNumericPtrType = dynamic_cast(rightType); + const auto leftType = evaluateType(binary.getLeft()); 
+ auto leftNumericType = dynamic_cast(leftType.type); + auto rightNumericType = dynamic_cast(rightType.type); + auto leftNumericPtrType = dynamic_cast(leftType.type); + auto rightNumericPtrType = dynamic_cast(rightType.type); if (leftNumericPtrType && rightNumericPtrType && opType == Token::Type::MINUS) { // Check pointers are compatible if (leftNumericPtrType->getTypeHash() != rightNumericPtrType->getTypeHash()) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getTypeName() + "' and '" + rightType->getTypeName()); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getTypeName() + "' and '" + rightType.type->getTypeName()); throw TypeCheckError(); } // **TODO** should be std::ptrdiff/Int64 - m_Type = Type::Int32::getInstance(); - m_Const = false; + m_QualifiedType = Type::QualifiedType{Type::Int32::getInstance(), false, false}; } // Otherwise, if we're adding to or subtracting from pointers else if (leftNumericPtrType && rightNumericType && (opType == Token::Type::PLUS || opType == Token::Type::MINUS)) // P + n or P - n { // Check that numeric operand is integer if (!rightNumericType->isIntegral()) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getTypeName() + "' and '" + rightType->getTypeName()); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getTypeName() + "' and '" + rightType.type->getTypeName()); throw TypeCheckError(); } - // Use pointer type - m_Type = leftNumericPtrType; - m_Const = leftConst; + // Use left type + m_QualifiedType = leftType; } // Otherwise, if we're adding a number to a pointer else if (leftNumericType && rightNumericPtrType && opType == Token::Type::PLUS) // n + P { // Check that numeric operand is integer if (!leftNumericType->isIntegral()) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getTypeName() + "' and '" + rightType->getTypeName()); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getTypeName() + "' and '" + rightType.type->getTypeName()); throw TypeCheckError(); } - // Use pointer type - m_Type = rightNumericPtrType; - m_Const = rightConst; + // Use right type + m_QualifiedType = rightType; } // Otherwise, if both operands are numeric else if (leftNumericType && rightNumericType) {
rightNumericType); - m_Const = false; + m_QualifiedType = Type::QualifiedType{Type::getCommonType(leftNumericType, rightNumericType), false, false}; } } else { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getTypeName() + "' and '" + rightType->getTypeName()); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getTypeName() + "' and '" + rightType.type->getTypeName()); throw TypeCheckError(); } } @@ -191,8 +183,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Call &call) final { // Evaluate callee type - auto calleeType = std::get<0>(evaluateType(call.getCallee())); - auto calleeFunctionType = dynamic_cast(calleeType); + auto calleeType = evaluateType(call.getCallee()); + auto calleeFunctionType = dynamic_cast(calleeType.type); // If callee's a function if (calleeFunctionType) { @@ -214,8 +206,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto callArgType = evaluateType(call.getArguments().at(i).get()); }*/ // Type is return type of function - m_Type = calleeFunctionType->getReturnType(); - m_Const = false; + m_QualifiedType = Type::QualifiedType{calleeFunctionType->getReturnType(), false, false}; } } // Otherwise @@ -229,124 +220,120 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { // **TODO** any numeric can be cast to any numeric and any pointer to pointer but no intermixing // **TODO** const cannot be removed like this - m_Type = cast.getType(); - m_Const = cast.isConst(); + m_QualifiedType = cast.getQualifiedType(); } virtual void visit(const Expression::Conditional &conditional) final { - const auto [trueType, trueConst] = evaluateType(conditional.getTrue()); - const auto [falseType, falseConst] = evaluateType(conditional.getFalse()); - auto trueNumericType = dynamic_cast(trueType); - auto falseNumericType = dynamic_cast(falseType); + const auto trueType = evaluateType(conditional.getTrue()); + const auto falseType = evaluateType(conditional.getFalse()); + auto trueNumericType = dynamic_cast(trueType.type); + auto falseNumericType = dynamic_cast(falseType.type); if (trueNumericType && falseNumericType) { - m_Type = Type::getCommonType(trueNumericType, falseNumericType); - m_Const = trueConst || falseConst; + // **TODO** check behaviour + m_QualifiedType = Type::QualifiedType{Type::getCommonType(trueNumericType, falseNumericType), + trueType.constValue || falseType.constValue, + trueType.constPointer || falseType.constPointer}; } else { m_ErrorHandler.error(conditional.getQuestion(), - "Invalid operand types '" + trueType->getTypeName() + "' and '" + std::string{falseType->getTypeName()} + "' to conditional"); + "Invalid operand types '" + trueType.type->getTypeName() + "' and '" + falseType.type->getTypeName() + "' to conditional"); throw TypeCheckError(); } } virtual void visit(const Expression::Grouping &grouping) final { - std::tie(m_Type, m_Const) = evaluateType(grouping.getExpression()); + m_QualifiedType = evaluateType(grouping.getExpression()); } virtual void visit(const Expression::Literal &literal) final { - m_Type = std::visit( - Utils::Overload{ - [](auto v)->const Type::NumericBase *{ return Type::TypeTraits::NumericType::getInstance(); }, - [](std::monostate)->const Type::NumericBase *{ return nullptr; }}, - literal.getValue()); - m_Const = true; + m_QualifiedType = Type::QualifiedType{ + std::visit(Utils::Overload{ + [](auto v)->const Type::NumericBase *{ return 
Type::TypeTraits::NumericType::getInstance(); }, + [](std::monostate)->const Type::NumericBase *{ return nullptr; }}, + literal.getValue()), + true, + false}; } virtual void visit(const Expression::Logical &logical) final { logical.getLeft()->accept(*this); logical.getRight()->accept(*this); - m_Type = Type::Int32::getInstance(); - m_Const = false; + m_QualifiedType = Type::QualifiedType{Type::Int32::getInstance(), false, false}; } virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final { - m_Type = m_Environment->incDec(postfixIncDec.getVarName(), - postfixIncDec.getOperator(), m_ErrorHandler); - m_Const = false; + m_QualifiedType = m_Environment->incDec(postfixIncDec.getVarName(), + postfixIncDec.getOperator(), m_ErrorHandler); } virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final { - m_Type = m_Environment->incDec(prefixIncDec.getVarName(), - prefixIncDec.getOperator(), m_ErrorHandler); - m_Const = false; + m_QualifiedType = m_Environment->incDec(prefixIncDec.getVarName(), + prefixIncDec.getOperator(), m_ErrorHandler); } virtual void visit(const Expression::Variable &variable) { - std::tie(m_Type, m_Const) = m_Environment->getType(variable.getName(), m_ErrorHandler); + m_QualifiedType = m_Environment->getType(variable.getName(), m_ErrorHandler); } virtual void visit(const Expression::Unary &unary) final { - const auto [rightType, rightConst] = evaluateType(unary.getRight()); + const auto rightType = evaluateType(unary.getRight()); // If operator is pointer de-reference if (unary.getOperator().type == Token::Type::STAR) { - auto rightNumericPtrType = dynamic_cast(rightType); + auto rightNumericPtrType = dynamic_cast(rightType.type); if (!rightNumericPtrType) { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getTypeName() + "'"); + "Invalid operand type '" + rightType.type->getTypeName() + "'"); throw TypeCheckError(); } // Return value type - m_Type = rightNumericPtrType->getValueType(); - - // **THINK** - m_Const = false; + m_QualifiedType = Type::QualifiedType{rightNumericPtrType->getValueType(), rightType.constValue, false}; } // Otherwise else { - auto rightNumericType = dynamic_cast(rightType); + auto rightNumericType = dynamic_cast(rightType.type); if (rightNumericType) { // If operator is arithmetic, return promoted type if (unary.getOperator().type == Token::Type::PLUS || unary.getOperator().type == Token::Type::MINUS) { - m_Type = Type::getPromotedType(rightNumericType); - m_Const = false; + m_QualifiedType = Type::QualifiedType{Type::getPromotedType(rightNumericType), + rightType.constValue, false}; } // Otherwise, if operator is bitwise else if (unary.getOperator().type == Token::Type::TILDA) { // If type is integer, return promoted type if (rightNumericType->isIntegral()) { - m_Type = Type::getPromotedType(rightNumericType); - m_Const = false; + m_QualifiedType = Type::QualifiedType{Type::getPromotedType(rightNumericType), + rightType.constValue, false}; } else { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getTypeName() + "'"); + "Invalid operand type '" + rightType.type->getTypeName() + "'"); throw TypeCheckError(); } } // Otherwise, if operator is logical else if (unary.getOperator().type == Token::Type::NOT) { - m_Type = Type::Int32::getInstance(); - m_Const = false; + m_QualifiedType = Type::QualifiedType{Type::Int32::getInstance(), + rightType.constValue, false}; } // Otherwise, if operator is address of, return pointer type else if (unary.getOperator().type == 
Token::Type::AMPERSAND) { - m_Type = rightNumericType->getPointerType(); - m_Const = rightConst; + m_QualifiedType = Type::QualifiedType{rightNumericType->getPointerType(), + rightType.constValue, false}; } } else { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getTypeName() + "'"); + "Invalid operand type '" + rightType.type->getTypeName() + "'"); throw TypeCheckError(); } } @@ -432,11 +419,11 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } if (labelled.getValue()) { - auto valType = std::get<0>(evaluateType(labelled.getValue())); - auto valNumericType = dynamic_cast(valType); + auto valType = evaluateType(labelled.getValue()); + auto valNumericType = dynamic_cast(valType.type); if (!valNumericType || !valNumericType->isIntegral()) { m_ErrorHandler.error(labelled.getKeyword(), - "Invalid case value '" + valType->getTypeName() + "'"); + "Invalid case value '" + valType.type->getTypeName() + "'"); throw TypeCheckError(); } } @@ -446,11 +433,11 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::Switch &switchStatement) final { - auto condType = std::get<0>(evaluateType(switchStatement.getCondition())); - auto condNumericType = dynamic_cast(condType); + auto condType = evaluateType(switchStatement.getCondition()); + auto condNumericType = dynamic_cast(condType.type); if (!condNumericType || !condNumericType->isIntegral()) { m_ErrorHandler.error(switchStatement.getSwitch(), - "Invalid condition '" + condType->getTypeName() + "'"); + "Invalid condition '" + condType.type->getTypeName() + "'"); throw TypeCheckError(); } @@ -462,16 +449,16 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::VarDeclaration &varDeclaration) final { for (const auto &var : varDeclaration.getInitDeclaratorList()) { - m_Environment->define(std::get<0>(var), varDeclaration.getType(), - varDeclaration.isConst(), m_ErrorHandler); + m_Environment->define(std::get<0>(var), varDeclaration.getQualifiedType(), m_ErrorHandler); // If variable has an initialiser expression if (std::get<1>(var)) { // Evaluate type - const auto [initialiserType, initialiserConst] = evaluateType(std::get<1>(var).get()); + const auto initialiserType = evaluateType(std::get<1>(var).get()); // Assign initialiser expression to variable - m_Environment->assign(std::get<0>(var), initialiserType, initialiserConst, Token::Type::EQUAL, m_ErrorHandler); + // **TODO** flag to signify this is an initialiser + m_Environment->assign(std::get<0>(var), initialiserType, Token::Type::EQUAL, m_ErrorHandler); } } } @@ -493,18 +480,17 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- // Private methods //--------------------------------------------------------------------------- - std::tuple evaluateType(const Expression::Base *expression) + const Type::QualifiedType &evaluateType(const Expression::Base *expression) { expression->accept(*this); - return std::make_tuple(m_Type, m_Const); + return m_QualifiedType; } //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- Environment *m_Environment; - const Type::Base *m_Type; - bool m_Const; + Type::QualifiedType m_QualifiedType; ErrorHandler &m_ErrorHandler; bool m_InLoop; @@ -515,42 +501,43 @@ class Visitor : public Expression::Visitor, 
public Statement::Visitor
 //---------------------------------------------------------------------------
 // MiniParse::TypeChecker::Environment
 //---------------------------------------------------------------------------
-void Environment::define(const Token &name, const Type::Base *type, bool isConst, ErrorHandler &errorHandler)
+void Environment::define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandler &errorHandler)
 {
-    if(!m_Types.try_emplace(name.lexeme, type, isConst).second) {
+    if(!m_Types.try_emplace(name.lexeme, qualifiedType).second) {
         errorHandler.error(name, "Redeclaration of variable");
         throw TypeCheckError();
     }
 }
 //---------------------------------------------------------------------------
-const Type::Base *Environment::assign(const Token &name, const Type::Base *assignedType, bool assignedConst,
-                                      Token::Type op, ErrorHandler &errorHandler)
+const Type::QualifiedType &Environment::assign(const Token &name, const Type::QualifiedType &assignedType,
+                                               Token::Type op, ErrorHandler &errorHandler)
 {
     // If type isn't found
     auto existingType = m_Types.find(name.lexeme);
     if(existingType == m_Types.end()) {
         if(m_Enclosing) {
-            return m_Enclosing->assign(name, assignedType,
-                                       assignedConst, op, errorHandler);
+            return m_Enclosing->assign(name, assignedType,
+                                       op, errorHandler);
         }
         else {
             errorHandler.error(name, "Undefined variable");
             throw TypeCheckError();
         }
     }
-    // Otherwise, if type is found and it's const, give error
-    else if(std::get<1>(existingType->second)) {
+
+    // If existing type is a constant numeric value or if it's a constant pointer give errors
+    auto numericExistingType = dynamic_cast<const Type::NumericBase *>(existingType->second.type);
+    auto numericPtrExistingType = dynamic_cast<const Type::NumericPtrBase *>(existingType->second.type);
+    if((numericExistingType && existingType->second.constValue)
+       || (numericPtrExistingType && existingType->second.constPointer))
+    {
         errorHandler.error(name, "Assignment of read-only variable");
         throw TypeCheckError();
     }
-
-    auto numericExistingType = dynamic_cast<const Type::NumericBase *>(std::get<0>(existingType->second));
-    auto numericAssignedType = dynamic_cast<const Type::NumericBase *>(assignedType);
-
-    auto numericPtrExistingType = dynamic_cast<const Type::NumericPtrBase *>(std::get<0>(existingType->second));
-    auto numericPtrAssignedType = dynamic_cast<const Type::NumericPtrBase *>(assignedType);
-
+
+    // If assignment operation is plain equals, any type is fine so return
+    auto numericAssignedType = dynamic_cast<const Type::NumericBase *>(assignedType.type);
+    auto numericPtrAssignedType = dynamic_cast<const Type::NumericPtrBase *>(assignedType.type);
     // **TODO** pointer type check
     if(op == Token::Type::EQUAL) {
         // If we're initialising a pointer with another pointer
@@ -570,7 +557,7 @@ const Type::Base *Environment::assign(const Token &name, const Type::Base *assig
         }
         // Otherwise, if we're trying to initialise a pointer with a non-pointer or vice-versa
         else if (numericPtrAssignedType || numericPtrExistingType) {
-            errorHandler.error(name, "Invalid operand types '" + std::get<0>(existingType->second)->getTypeName() + "' and '" + assignedType->getTypeName());
+            errorHandler.error(name, "Invalid operand types '" + existingType->second.type->getTypeName() + "' and '" + assignedType.type->getTypeName());
             throw TypeCheckError();
         }
     }
@@ -579,7 +566,7 @@ const Type::Base *Environment::assign(const Token &name, const Type::Base *assig
 
         // If the operand being added isn't numeric or the type being added to is neither numeric or a pointer
         if (!numericAssignedType || (!numericPtrExistingType && !numericExistingType)) {
-            errorHandler.error(name, "Invalid operand types '" + std::get<0>(existingType->second)->getTypeName() + "' and '" + assignedType->getTypeName() + "'");
+            errorHandler.error(name, "Invalid operand types '" + existingType->second.type->getTypeName() + "' and '" + assignedType.type->getTypeName() + "'");
             throw TypeCheckError();
         }
 
@@ -597,7 +584,7 @@ const Type::Base *Environment::assign(const Token &name, const Type::Base *assig
             throw TypeCheckError();
         }
         if(!numericExistingType) {
-            errorHandler.error(name, "Invalid operand types '" + std::get<0>(existingType->second)->getTypeName() + "'");
+            errorHandler.error(name, "Invalid operand types '" + existingType->second.type->getTypeName() + "'");
             throw TypeCheckError();
         }
 
@@ -615,10 +602,11 @@ const Type::Base *Environment::assign(const Token &name, const Type::Base *assig
     }
 
     // Return existing type
-    return std::get<0>(existingType->second);
+    // **THINK**
+    return existingType->second;
 }
 //---------------------------------------------------------------------------
-const Type::Base *Environment::incDec(const Token &name, const Token &op, ErrorHandler &errorHandler)
+const Type::QualifiedType &Environment::incDec(const Token &name, const Token &op, ErrorHandler &errorHandler)
 {
     auto existingType = m_Types.find(name.lexeme);
     if(existingType == m_Types.end()) {
@@ -630,26 +618,23 @@ const Type::Base *Environment::incDec(const Token &name, const Token &op, ErrorH
             throw TypeCheckError();
         }
     }
-    // Otherwise, if type is found and it's const, give error
-    else if(std::get<1>(existingType->second)) {
+
+    // If existing type is a constant numeric value or if it's a constant pointer give errors
+    auto numericExistingType = dynamic_cast<const Type::NumericBase *>(existingType->second.type);
+    auto numericPtrExistingType = dynamic_cast<const Type::NumericPtrBase *>(existingType->second.type);
+    if((numericExistingType && existingType->second.constValue)
+       || (numericPtrExistingType && existingType->second.constPointer))
+    {
         errorHandler.error(name, "Increment/decrement of read-only variable");
         throw TypeCheckError();
     }
     // Otherwise, return type
-    // **TODO** pointer
     else {
-        auto numericExistingType = dynamic_cast<const Type::NumericBase *>(std::get<0>(existingType->second));
-        if(numericExistingType == nullptr) {
-            errorHandler.error(op, "Invalid operand types '" + std::get<0>(existingType->second)->getTypeName() + "'");
-            throw TypeCheckError();
-        }
-        else {
-            return std::get<0>(existingType->second);
-        }
+        return existingType->second;
     }
 }
 //---------------------------------------------------------------------------
-std::tuple<const Type::Base *, bool> Environment::getType(const Token &name, ErrorHandler &errorHandler) const
+const Type::QualifiedType &Environment::getType(const Token &name, ErrorHandler &errorHandler) const
 {
     auto type = m_Types.find(std::string{name.lexeme});
     if(type == m_Types.end()) {
@@ -673,9 +658,9 @@ void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &st
     visitor.typeCheck(statements, environment);
 }
 //---------------------------------------------------------------------------
-std::tuple<const Type::Base *, bool> GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression,
-                                                                              Environment &environment,
-                                                                              ErrorHandler &errorHandler)
+Type::QualifiedType GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression,
+                                                             Environment &environment,
+                                                             ErrorHandler &errorHandler)
 {
     Visitor visitor(errorHandler);
     return visitor.typeCheck(expression, environment);
diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index 63b8cba4a9..5d18413610 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -67,7 +67,7 @@ void typeCheckStatements(std::string_view code, TypeChecker::Environment &typeEn
     ASSERT_FALSE(errorHandler.hasError());
 }
 
-std::tuple<const Type::Base *, bool> typeCheckExpression(std::string_view code, TypeChecker::Environment &typeEnvironment)
+Type::QualifiedType typeCheckExpression(std::string_view code, TypeChecker::Environment &typeEnvironment)
 {
     // Scan
     TestErrorHandler errorHandler;
@@ -174,8 +174,9 @@ TEST(TypeChecker, Cast)
         TypeChecker::Environment typeEnvironment;
         typeEnvironment.define<Type::Int32>("intVal");
         const auto type = typeCheckExpression("(float)intArray", typeEnvironment);
-        EXPECT_EQ(std::get<0>(type)->getTypeHash(), Type::Float::getInstance()->getTypeHash());
-        EXPECT_FALSE(std::get<1>(type));
+        EXPECT_EQ(type.type->getTypeHash(), Type::Float::getInstance()->getTypeHash());
+        EXPECT_FALSE(type.constValue);
+        EXPECT_FALSE(type.constPointer);
     }
 
     // Pointer cast can't reinterpret
@@ -196,45 +197,79 @@ TEST(TypeChecker, Literal)
     {
         TypeChecker::Environment typeEnvironment;
         const auto type = typeCheckExpression("1.0f", typeEnvironment);
-        EXPECT_EQ(std::get<0>(type)->getTypeHash(), Type::Float::getInstance()->getTypeHash());
-        EXPECT_TRUE(std::get<1>(type));
+        EXPECT_EQ(type.type->getTypeHash(), Type::Float::getInstance()->getTypeHash());
+        EXPECT_TRUE(type.constValue);
+        EXPECT_FALSE(type.constPointer);
     }
 
     // Double
     {
         TypeChecker::Environment typeEnvironment;
         const auto type = typeCheckExpression("1.0", typeEnvironment);
-        EXPECT_EQ(std::get<0>(type)->getTypeHash(), Type::Double::getInstance()->getTypeHash());
-        EXPECT_TRUE(std::get<1>(type));
+        EXPECT_EQ(type.type->getTypeHash(), Type::Double::getInstance()->getTypeHash());
+        EXPECT_TRUE(type.constValue);
+        EXPECT_FALSE(type.constPointer);
     }
 
     // Integer
     {
         TypeChecker::Environment typeEnvironment;
        const auto type = typeCheckExpression("100", typeEnvironment);
-        EXPECT_EQ(std::get<0>(type)->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
-        EXPECT_TRUE(std::get<1>(type));
+        EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
+        EXPECT_TRUE(type.constValue);
+        EXPECT_FALSE(type.constPointer);
     }
 
     // Unsigned integer
     {
         TypeChecker::Environment typeEnvironment;
         const auto type = typeCheckExpression("100U", typeEnvironment);
-        EXPECT_EQ(std::get<0>(type)->getTypeHash(), Type::Uint32::getInstance()->getTypeHash());
-        EXPECT_TRUE(std::get<1>(type));
+        EXPECT_EQ(type.type->getTypeHash(), Type::Uint32::getInstance()->getTypeHash());
+        EXPECT_TRUE(type.constValue);
+        EXPECT_FALSE(type.constPointer);
     }
 }
 //--------------------------------------------------------------------------
 TEST(TypeChecker, Unary)
 {
     // Dereference pointer
-    // **TODO** const
     {
         TypeChecker::Environment typeEnvironment;
         typeEnvironment.define<Type::Int32Ptr>("intArray");
         const auto type = typeCheckExpression("*intArray", typeEnvironment);
-        EXPECT_EQ(std::get<0>(type)->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
-        EXPECT_FALSE(std::get<1>(type));
+        EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
+        EXPECT_FALSE(type.constValue);
+        EXPECT_FALSE(type.constPointer);
+    }
+
+    // Dereference pointer to const
+    {
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32Ptr>("intArray", true);
+        const auto type = typeCheckExpression("*intArray", typeEnvironment);
+        EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
+        EXPECT_TRUE(type.constValue);
+        EXPECT_FALSE(type.constPointer);
+    }
+
+    // Dereference const pointer
+    {
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32Ptr>("intArray", false, true);
+        const auto type = typeCheckExpression("*intArray", typeEnvironment);
+        EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
+        EXPECT_FALSE(type.constValue);
+        EXPECT_TRUE(type.constPointer);
+    }
+
+    // Dereference const pointer to const
+    {
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32Ptr>("intArray", true, true);
+        const auto type = typeCheckExpression("*intArray", typeEnvironment);
+        EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
+        EXPECT_TRUE(type.constValue);
+        EXPECT_TRUE(type.constPointer);
     }
 
     // Dereference numeric
     EXPECT_THROW({
@@ -250,7 +285,8 @@ TEST(TypeChecker, Unary)
         TypeChecker::Environment typeEnvironment;
         typeEnvironment.define<Type::Int32>("intVal");
         const auto type = typeCheckExpression("&intVal", typeEnvironment);
-        EXPECT_EQ(std::get<0>(type)->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash());
-        EXPECT_FALSE(std::get<1>(type));
+        EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
+        EXPECT_FALSE(type.constValue);
+        EXPECT_FALSE(type.constPointer);
     }
-}
\ No newline at end of file
+}
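The patch above replaces the (const Type::Base *, bool) tuples previously threaded through the type-checker with a single Type::QualifiedType aggregate. Its full definition lies outside this series, but from the members the diffs use (type, constValue and constPointer) it presumably looks like the sketch below; the stand-in Base class is an assumption, included only so the sketch compiles on its own:

    #include <string>

    // Stand-in for GeNN::Type::Base - just enough for a self-contained sketch
    struct Base
    {
        virtual ~Base() = default;
        virtual std::string getTypeName() const = 0;
    };

    // Hypothetical shape of Type::QualifiedType: the underlying type plus the
    // two const qualifiers the checker now tracks separately
    struct QualifiedType
    {
        const Base *type;   // underlying numeric or numeric-pointer type
        bool constValue;    // 'const int' / 'const int *' - value or pointee is read-only
        bool constPointer;  // 'int * const' - the pointer itself is read-only
    };

Tracking the two qualifiers separately is what lets the following patches distinguish a pointer to const (re-assignable, but not writable through) from a const pointer (writable through, but not re-assignable).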
From cabd83c6d4d8c71fd8b7b5b210980a6d9debd3e4 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 10 Jan 2023 09:56:14 +0000
Subject: [PATCH 018/725] fixed typo in declaration parser

---
 src/genn/genn/transpiler/parser.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc
index 74f871c882..876e43bd8f 100644
--- a/src/genn/genn/transpiler/parser.cc
+++ b/src/genn/genn/transpiler/parser.cc
@@ -214,8 +214,8 @@ GeNN::Type::QualifiedType parseDeclarationSpecifiers(ParserState &parserState)
 
     // Lookup type based on whether token was found
     const GeNN::Type::Base *type = (pointerFound
-                                    ? static_cast<const GeNN::Type::Base *>(GeNN::Type::getNumericType(typeSpecifiers))
-                                    : static_cast<const GeNN::Type::Base *>(GeNN::Type::getNumericPtrType(typeSpecifiers)));
+                                    ? static_cast<const GeNN::Type::Base *>(GeNN::Type::getNumericPtrType(typeSpecifiers))
+                                    : static_cast<const GeNN::Type::Base *>(GeNN::Type::getNumericType(typeSpecifiers)));
 
     // Return qualified type
     // **THINK** this relies on const being the only qualifier

From e0d83f6ab056e1b006214825a0491ff2748a543c Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 10 Jan 2023 09:56:34 +0000
Subject: [PATCH 019/725] fixed typos in type-checker tests

---
 tests/unit/typeChecker.cc | 28 +++++++++++++++++-----------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index 5d18413610..fdad5d3b85 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -128,7 +128,7 @@ TEST(TypeChecker, Assignment)
         "float z = intVal;\n"
         "int wc = intValConst;\n"
         "const int cw = intVal;\n"
-        "const int cwc = invValConst;\n",
+        "const int cwc = intValConst;\n",
         typeEnvironment);
 }
 
@@ -139,8 +139,8 @@ TEST(TypeChecker, Assignment)
     typeEnvironment.define<Type::Int32Ptr>("intArrayConst", true);
     typeCheckStatements(
         "int *x = intArray;\n"
-        "const *y = intArray;\n"
-        "const *z = intArrayConst;\n",
+        "const int *y = intArray;\n"
+        "const int *z = intArrayConst;\n",
         typeEnvironment);
 }
 
@@ -241,7 +241,7 @@ TEST(TypeChecker, Unary)
         EXPECT_FALSE(type.constValue);
         EXPECT_FALSE(type.constPointer);
     }
-    
+
     // Dereference pointer to const
     {
         TypeChecker::Environment typeEnvironment;
         typeEnvironment.define<Type::Int32Ptr>("intArray", true);
         const auto type = typeCheckExpression("*intArray", typeEnvironment);
         EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
         EXPECT_TRUE(type.constValue);
         EXPECT_FALSE(type.constPointer);
     }
-    
+
     // Dereference const pointer
     {
         TypeChecker::Environment typeEnvironment;
         typeEnvironment.define<Type::Int32Ptr>("intArray", false, true);
         const auto type = typeCheckExpression("*intArray", typeEnvironment);
         EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
         EXPECT_FALSE(type.constValue);
-        EXPECT_TRUE(type.constPointer);
+        EXPECT_FALSE(type.constPointer);
     }
-    
+
     // Dereference const pointer to const
     {
         TypeChecker::Environment typeEnvironment;
         typeEnvironment.define<Type::Int32Ptr>("intArray", true, true);
         const auto type = typeCheckExpression("*intArray", typeEnvironment);
         EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
         EXPECT_TRUE(type.constValue);
-        EXPECT_TRUE(type.constPointer);
+        EXPECT_FALSE(type.constPointer);
     }
 
     // Dereference numeric
     EXPECT_THROW({
         TypeChecker::Environment typeEnvironment;
         typeEnvironment.define<Type::Int32>("intVal");
-        typeCheckExpression("*intVal", typeEnvironment);},
+        typeCheckExpression("*intVal", typeEnvironment); },
         TypeChecker::TypeCheckError);
 
     // Address of numeric
-    // **TODO** const
     {
         TypeChecker::Environment typeEnvironment;
         typeEnvironment.define<Type::Int32>("intVal");
         const auto type = typeCheckExpression("&intVal", typeEnvironment);
-        EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
+        EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash());
         EXPECT_FALSE(type.constValue);
         EXPECT_FALSE(type.constPointer);
     }
+
+    // Address of pointer
+    EXPECT_THROW({
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32Ptr>("intArray");
+        typeCheckExpression("&intArray", typeEnvironment);},
+        TypeChecker::TypeCheckError);
 }
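The next patch makes assignment const-correct. The rule it implements mirrors C and can be stated compactly; the sketch below is a minimal, self-contained restatement (not GeNN's actual code), using the QualifiedType shape assumed earlier, with an isPointer flag standing in for the dynamic_casts the real checker performs:

    // Hypothetical distillation of the const-correctness rule for assignment
    struct QualifiedType
    {
        bool isPointer;     // stands in for the NumericPtrBase dynamic_cast
        bool constValue;    // value (or pointee) is read-only
        bool constPointer;  // the pointer itself is read-only
    };

    bool assignmentAllowed(const QualifiedType &target, const QualifiedType &assigned, bool initializer)
    {
        // A read-only target may only be written by its own initializer -
        // otherwise 'const int x = 1;' could never type-check
        if(!initializer && ((!target.isPointer && target.constValue)
                            || (target.isPointer && target.constPointer))) {
            return false;
        }
        // A pointer to const may not be assigned to a plain pointer - that
        // would silently drop const from the pointee
        if(target.isPointer && assigned.isPointer
           && assigned.constValue && !target.constValue) {
            return false;
        }
        return true;
    }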
From 73c4a691ea974b0be9a606f81c66ca21e1f32d7f Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 10 Jan 2023 09:57:04 +0000
Subject: [PATCH 020/725] fixed const-correctness checking in typechecker
 assignment

---
 include/genn/genn/transpiler/typeChecker.h |  2 +-
 src/genn/genn/transpiler/typeChecker.cc    | 21 +++++++++------------
 2 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h
index 89a5fbbda2..de033ad47f 100644
--- a/include/genn/genn/transpiler/typeChecker.h
+++ b/include/genn/genn/transpiler/typeChecker.h
@@ -54,7 +54,7 @@ class Environment
     }
     void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandler &errorHandler);
     const Type::QualifiedType &assign(const Token &name, const Type::QualifiedType &assignedType,
-                                      Token::Type op, ErrorHandler &errorHandler);
+                                      Token::Type op, ErrorHandler &errorHandler, bool initializer = false);
     const Type::QualifiedType &incDec(const Token &name, const Token &op, ErrorHandler &errorHandler);
     const Type::QualifiedType &getType(const Token &name, ErrorHandler &errorHandler) const;
 
diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index 98e985d849..4b8555206b 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -457,8 +457,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
             const auto initialiserType = evaluateType(std::get<1>(var).get());
 
             // Assign initialiser expression to variable
-            // **TODO** flag to signify this is an initialiser
-            m_Environment->assign(std::get<0>(var), initialiserType, Token::Type::EQUAL, m_ErrorHandler);
+            m_Environment->assign(std::get<0>(var), initialiserType, Token::Type::EQUAL, m_ErrorHandler, true);
         }
     }
 }
@@ -510,14 +509,14 @@ void Environment::define(const Token &name, const Type::QualifiedType &qualified
 }
 //---------------------------------------------------------------------------
 const Type::QualifiedType &Environment::assign(const Token &name, const Type::QualifiedType &assignedType,
-                                               Token::Type op, ErrorHandler &errorHandler)
+                                               Token::Type op, ErrorHandler &errorHandler, bool initializer)
 {
     // If type isn't found
     auto existingType = m_Types.find(name.lexeme);
     if(existingType == m_Types.end()) {
         if(m_Enclosing) {
             return m_Enclosing->assign(name, assignedType,
-                                       op, errorHandler);
+                                       op, errorHandler, initializer);
         }
         else {
             errorHandler.error(name, "Undefined variable");
@@ -528,8 +527,8 @@ const Type::QualifiedType &Environment::assign(const Token &name, const Type::Qu
     // If existing type is a constant numeric value or if it's a constant pointer give errors
     auto numericExistingType = dynamic_cast<const Type::NumericBase *>(existingType->second.type);
     auto numericPtrExistingType = dynamic_cast<const Type::NumericPtrBase *>(existingType->second.type);
-    if((numericExistingType && existingType->second.constValue)
-       || (numericPtrExistingType && existingType->second.constPointer))
+    if(!initializer && ((numericExistingType && existingType->second.constValue)
+                        || (numericPtrExistingType && existingType->second.constPointer)))
     {
         errorHandler.error(name, "Assignment of read-only variable");
         throw TypeCheckError();
@@ -538,16 +537,14 @@ const Type::QualifiedType &Environment::assign(const Token &name, const Type::Qu
     // If assignment operation is plain equals, any type is fine so return
     auto numericAssignedType = dynamic_cast<const Type::NumericBase *>(assignedType.type);
     auto numericPtrAssignedType = dynamic_cast<const Type::NumericPtrBase *>(assignedType.type);
-    // **TODO** pointer type check
     if(op == Token::Type::EQUAL) {
         // If we're initialising a pointer with another pointer
         if (numericPtrAssignedType && numericPtrExistingType) {
-            // If variable is non-const but initialiser is const
-            /*if (!varDeclaration.isConst() && intialiserConst) {
-                m_ErrorHandler.error(std::get<0>(var),
-                                     "Invalid operand types '" + initialiserType->getTypeName() + "'");
+            // If we're trying to assign a pointer to a const value to a pointer
+            if (assignedType.constValue && !existingType->second.constValue) {
+                errorHandler.error(name, "Invalid operand types '" + numericPtrExistingType->getTypeName() + "' and '" + numericPtrAssignedType->getTypeName());
                 throw TypeCheckError();
-            }*/
+            }
 
             // If pointer types aren't compatible
             if (numericPtrExistingType->getTypeHash() != numericPtrAssignedType->getTypeHash()) {

From f67bf1b159110e977d87440637943d24dcac887f Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 10 Jan 2023 10:21:52 +0000
Subject: [PATCH 021/725] include closing parenthesis token in cast expression
 for error handling

---
 include/genn/genn/transpiler/expression.h | 10 ++++++----
 src/genn/genn/transpiler/parser.cc        |  4 ++--
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h
index 083fc56780..f936bf6b59 100644
--- a/include/genn/genn/transpiler/expression.h
+++ b/include/genn/genn/transpiler/expression.h
@@ -122,18 +122,20 @@ class Call : public Base
 class Cast : public Base
 {
 public:
-    Cast(const Type::QualifiedType &qualifiedType, ExpressionPtr expression)
-    :   m_QualifiedType(qualifiedType), m_Expression(std::move(expression))
+    Cast(const Type::QualifiedType &qualifiedType, ExpressionPtr expression, Token closingParen)
+    :   m_QualifiedType(qualifiedType), m_Expression(std::move(expression)), m_ClosingParen(closingParen)
     {}
 
     virtual void accept(Visitor &visitor) const final;
 
-    const Base *getExpression() const { return m_Expression.get(); }
     const Type::QualifiedType &getQualifiedType() const{ return m_QualifiedType; }
-    
+    const Base *getExpression() const { return m_Expression.get(); }
+    const Token &getClosingParen() const { return m_ClosingParen; }
+
 private:
     const Type::QualifiedType m_QualifiedType;
     const ExpressionPtr m_Expression;
+    const Token m_ClosingParen;
 };
 
 //---------------------------------------------------------------------------
diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc
index 876e43bd8f..f92166d26b 100644
--- a/src/genn/genn/transpiler/parser.cc
+++ b/src/genn/genn/transpiler/parser.cc
@@ -373,9 +373,9 @@ Expression::ExpressionPtr parseCast(ParserState &parserState)
         // Parse declaration specifiers
         const auto qualifiedType = parseDeclarationSpecifiers(parserState);
 
-        parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after cast type.");
+        const auto closingParen = parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after cast type.");
 
-        return std::make_unique<Expression::Cast>(qualifiedType, parseCast(parserState));
+        return std::make_unique<Expression::Cast>(qualifiedType, parseCast(parserState), closingParen);
     }
     // Otherwise, rewind parser state so left parenthesis can be parsed again
     // **YUCK**
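Recording the ')' token on the Cast node looks cosmetic, but it gives the cast type checks added in the next patch a concrete token to attach diagnostics to. A hypothetical diagnostic position (the caret rendering is illustrative, not GeNN's actual output):

    // (float*)intArray
    //        ^ error reported here: Invalid operand types 'float*' and 'int*'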
From 8b94c92315e9564f225f1b5e0555d3e8ffd78728 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 10 Jan 2023 10:22:03 +0000
Subject: [PATCH 022/725] type checks in cast

---
 src/genn/genn/transpiler/typeChecker.cc | 33 +++++++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index 4b8555206b..26506384a7 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -218,8 +218,37 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
 
     virtual void visit(const Expression::Cast &cast) final
     {
-        // **TODO** any numeric can be cast to any numeric and any pointer to pointer but no intermixing
-        // **TODO** const cannot be removed like this
+        // Evaluate type of expression we're casting
+        const auto rightType = evaluateType(cast.getExpression());
+
+        // If value const is being removed
+        if (rightType.constValue && !cast.getQualifiedType().constValue) {
+            m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getTypeName() + "' and '" + rightType.type->getTypeName());
+            throw TypeCheckError();
+        }
+        // Otherwise, if pointer const is being removed
+        else if (rightType.constPointer && !cast.getQualifiedType().constPointer) {
+            m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getTypeName() + "' and '" + rightType.type->getTypeName());
+            throw TypeCheckError();
+        }
+
+        // If we're trying to cast pointer to pointer
+        auto rightNumericType = dynamic_cast<const Type::NumericBase *>(rightType.type);
+        auto rightNumericPtrType = dynamic_cast<const Type::NumericPtrBase *>(rightType.type);
+        auto leftNumericType = dynamic_cast<const Type::NumericBase *>(cast.getQualifiedType().type);
+        auto leftNumericPtrType = dynamic_cast<const Type::NumericPtrBase *>(cast.getQualifiedType().type);
+        if (rightNumericPtrType && leftNumericPtrType) {
+            if (rightNumericPtrType->getTypeHash() != leftNumericPtrType->getTypeHash()) {
+                m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getTypeName() + "' and '" + rightType.type->getTypeName());
+                throw TypeCheckError();
+            }
+        }
+        // Otherwise, if either operand isn't numeric
+        else if(!leftNumericType || !rightNumericType) {
+            m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getTypeName() + "' and '" + rightType.type->getTypeName());
+            throw TypeCheckError();
+        }
+
         m_QualifiedType = cast.getQualifiedType();
     }
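With the declarations the unit tests use, the checks above accept and reject the following cast expressions (each rejected line throws TypeCheckError; the next patch adds tests for exactly these cases):

    int intVal;
    int *intArray;
    const int *constIntArray;

    (float)intVal;         // OK: numeric to numeric
    (const int)intVal;     // OK: adds value const
    (const int*)intArray;  // OK: adds pointee const
    (int*)constIntArray;   // rejected: removes const from the pointee
    (float*)intArray;      // rejected: pointers cannot be reinterpreted
    (int)intArray;         // rejected: pointer and numeric may not be intermixed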
From e1a843bb77149ae5bfc5b4dc86c2018d4fbd7b0b Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 10 Jan 2023 10:22:15 +0000
Subject: [PATCH 023/725] more casting type checks

---
 tests/unit/typeChecker.cc | 68 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 67 insertions(+), 1 deletion(-)

diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index fdad5d3b85..c5f006321c 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -173,18 +173,84 @@ TEST(TypeChecker, Cast)
     {
         TypeChecker::Environment typeEnvironment;
         typeEnvironment.define<Type::Int32>("intVal");
-        const auto type = typeCheckExpression("(float)intArray", typeEnvironment);
+        const auto type = typeCheckExpression("(float)intVal", typeEnvironment);
         EXPECT_EQ(type.type->getTypeHash(), Type::Float::getInstance()->getTypeHash());
         EXPECT_FALSE(type.constValue);
         EXPECT_FALSE(type.constPointer);
     }
 
+    // Numeric cast to const
+    {
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32>("intVal");
+        const auto type = typeCheckExpression("(const int)intVal", typeEnvironment);
+        EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
+        EXPECT_TRUE(type.constValue);
+        EXPECT_FALSE(type.constPointer);
+    }
+
+    // Pointer cast to value const
+    {
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32Ptr>("intArray");
+        const auto type = typeCheckExpression("(const int*)intArray", typeEnvironment);
+        EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash());
+        EXPECT_TRUE(type.constValue);
+        EXPECT_FALSE(type.constPointer);
+    }
+
+    // Pointer cast to pointer const
+    {
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32Ptr>("intArray");
+        const auto type = typeCheckExpression("(int * const)intArray", typeEnvironment);
+        EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash());
+        EXPECT_FALSE(type.constValue);
+        EXPECT_TRUE(type.constPointer);
+    }
+
+    // Can't remove value const from numeric
+    EXPECT_THROW({
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32>("intVal", true);
+        typeCheckExpression("(int)intVal", typeEnvironment);},
+        TypeChecker::TypeCheckError);
+
+    // Can't remove value const from pointer
+    EXPECT_THROW({
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32Ptr>("intArray", true);
+        typeCheckExpression("(int*)intArray", typeEnvironment);},
+        TypeChecker::TypeCheckError);
+
+    // Can't remove pointer const from pointer
+    EXPECT_THROW({
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32Ptr>("intArray", false, true);
+        typeCheckExpression("(int*)intArray", typeEnvironment);},
+        TypeChecker::TypeCheckError);
+
     // Pointer cast can't reinterpret
     EXPECT_THROW({
         TypeChecker::Environment typeEnvironment;
         typeEnvironment.define<Type::Int32Ptr>("intArray");
         typeCheckExpression("(float*)intArray", typeEnvironment);},
         TypeChecker::TypeCheckError);
+
+    // Pointer can't be cast to numeric
+    EXPECT_THROW({
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32Ptr>("intArray");
+        typeCheckExpression("(int)intArray", typeEnvironment);},
+        TypeChecker::TypeCheckError);
+
+    // Numeric can't be cast to pointer
+    EXPECT_THROW({
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32>("intVal");
+        typeCheckExpression("(int*)intVal", typeEnvironment);},
+        TypeChecker::TypeCheckError);
 }
 //--------------------------------------------------------------------------
 TEST(TypeChecker, Conditional)
From 7c4d24498ba73d22036f41becfa36dc87c524e08 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 10 Jan 2023 10:29:54 +0000
Subject: [PATCH 024/725] incdec expression tests

---
 tests/unit/typeChecker.cc | 47 +++++++++++++++++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index c5f006321c..8046673ff3 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -257,6 +257,53 @@ TEST(TypeChecker, Conditional)
 {
 }
 //--------------------------------------------------------------------------
+TEST(TypeChecker, IncDec)
+{
+    // Can increment numeric
+    {
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32>("intVal");
+        const auto type = typeCheckExpression("intVal++", typeEnvironment);
+        EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
+        EXPECT_FALSE(type.constValue);
+        EXPECT_FALSE(type.constPointer);
+    }
+
+    // Can increment pointer
+    {
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32Ptr>("intArray");
+        const auto type = typeCheckExpression("intArray++", typeEnvironment);
+        EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash());
+        EXPECT_FALSE(type.constValue);
+        EXPECT_FALSE(type.constPointer);
+    }
+
+    // Can increment pointer to const
+    {
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32Ptr>("intArray", true);
+        const auto type = typeCheckExpression("intArray++", typeEnvironment);
+        EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash());
+        EXPECT_TRUE(type.constValue);
+        EXPECT_FALSE(type.constPointer);
+    }
+
+    // Can't increment const number
+    EXPECT_THROW({
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32>("intVal", true);
+        typeCheckExpression("intVal++", typeEnvironment);},
+        TypeChecker::TypeCheckError);
+
+    // Can't increment const pointer
+    EXPECT_THROW({
+        TypeChecker::Environment typeEnvironment;
+        typeEnvironment.define<Type::Int32Ptr>("intArray", false, true);
+        typeCheckExpression("intArray++", typeEnvironment);},
+        TypeChecker::TypeCheckError);
+}
+//--------------------------------------------------------------------------
 TEST(TypeChecker, Literal)
 {
     // Float
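The distinction these tests encode matches standard C const semantics: const-ness of the pointee never prevents the pointer itself from being incremented; only a const pointer does. The same rules in plain C++:

    const int values[] = {1, 2, 3};
    const int *p = values;  // pointer to const - the pointee is read-only
    p++;                    // OK: the pointer itself is mutable

    int x = 0;
    int *const q = &x;      // const pointer - the pointer itself is read-only
    // q++;                 // error: increment of read-only variable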
From 8a857e5a684d463b28dc5983f8c6605f7217c662 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 10 Jan 2023 10:33:35 +0000
Subject: [PATCH 025/725] improved array subscript tests

---
 tests/unit/typeChecker.cc | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index 8046673ff3..1649dca430 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -94,14 +94,17 @@ TEST(TypeChecker, ArraySubscript)
     {
         TypeChecker::Environment typeEnvironment;
         typeEnvironment.define<Type::Int32Ptr>("intArray");
-        typeCheckStatements("int x = intArray[4];", typeEnvironment);
+        const auto type = typeCheckExpression("intArray[4]", typeEnvironment);
+        EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
+        EXPECT_FALSE(type.constValue);
+        EXPECT_FALSE(type.constPointer);
     }
 
     // Float array indexing
     EXPECT_THROW({
         TypeChecker::Environment typeEnvironment;
         typeEnvironment.define<Type::Int32Ptr>("intArray");
-        typeCheckStatements("int x = intArray[4.0f];", typeEnvironment);},
+        const auto type = typeCheckExpression("intArray[4.0f]", typeEnvironment);},
         TypeChecker::TypeCheckError);
 
     // Pointer indexing
@@ -109,7 +112,7 @@ TEST(TypeChecker, ArraySubscript)
         TypeChecker::Environment typeEnvironment;
         typeEnvironment.define<Type::Int32Ptr>("intArray");
         typeEnvironment.define<Type::Int32Ptr>("indexArray");
-        typeCheckStatements("int x = intArray[indexArray];", typeEnvironment);},
+        const auto type = typeCheckExpression("intArray[indexArray]", typeEnvironment);},
         TypeChecker::TypeCheckError);
 }
 //--------------------------------------------------------------------------
@@ -157,6 +160,8 @@ TEST(TypeChecker, Assignment)
         typeEnvironment.define<Type::Int32Ptr>("intArray");
         typeCheckStatements("float *x = intArray;", typeEnvironment);},
         TypeChecker::TypeCheckError);
+
+    // **TODO** other assignments, i.e. += -= %=
 }
 //--------------------------------------------------------------------------
 TEST(TypeChecker, Binary)
@@ -209,7 +214,6 @@ TEST(TypeChecker, Cast)
         EXPECT_TRUE(type.constPointer);
     }
 
-
     // Can't remove value const from numeric
     EXPECT_THROW({
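As with casts, the subscript rule mirrors C: the subscripted operand must be a pointer and the index must have integer type. With the declarations the tests use:

    int *intArray;
    int *indexArray;

    intArray[4];           // OK: integer index, yields int
    intArray[4.0f];        // rejected: floating-point index
    intArray[indexArray];  // rejected: a pointer is not a valid index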
From 820cf4d518ada922d685a91e512e7d592a23594f Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 10 Jan 2023 11:37:59 +0000
Subject: [PATCH 026/725] Switched Models::Var to using Type::NumericBase *

* Exposed simple type parser
* Added Type::parseNumeric to parse direct from string to (unqualified) type

---
 include/genn/genn/logging.h           | 15 +++++++--
 include/genn/genn/snippet.h           |  1 +
 include/genn/genn/transpiler/parser.h |  8 +++++
 include/genn/genn/type.h              | 11 ++++++-
 src/genn/genn/logging.cc              | 13 ++++++--
 src/genn/genn/transpiler/parser.cc    | 32 +++++++++++++++++-
 src/genn/genn/type.cc                 | 47 +++++++++++++++++++++++++++
 7 files changed, 120 insertions(+), 7 deletions(-)

diff --git a/include/genn/genn/logging.h b/include/genn/genn/logging.h
index 9e70b24221..4849d7c982 100644
--- a/include/genn/genn/logging.h
+++ b/include/genn/genn/logging.h
@@ -32,6 +32,14 @@ class IAppender;
 #define LOGE_CODE_GEN LOGE_(Logging::CHANNEL_CODE_GEN)
 #define LOGF_CODE_GEN LOGF_(Logging::CHANNEL_CODE_GEN)
 
+// Shorthand macros for logging to 'transpiler' channel
+#define LOGV_TRANSPILER LOGV_(Logging::CHANNEL_TRANSPILER)
+#define LOGD_TRANSPILER LOGD_(Logging::CHANNEL_TRANSPILER)
+#define LOGI_TRANSPILER LOGI_(Logging::CHANNEL_TRANSPILER)
+#define LOGW_TRANSPILER LOGW_(Logging::CHANNEL_TRANSPILER)
+#define LOGE_TRANSPILER LOGE_(Logging::CHANNEL_TRANSPILER)
+#define LOGF_TRANSPILER LOGF_(Logging::CHANNEL_TRANSPILER)
+
 // Shorthand macros for logging to 'backend' channel
 #define LOGV_BACKEND LOGV_(Logging::CHANNEL_BACKEND)
 #define LOGD_BACKEND LOGD_(Logging::CHANNEL_BACKEND)
@@ -50,10 +58,11 @@ enum Channel
 {
     CHANNEL_GENN       = 0,
     CHANNEL_CODE_GEN   = 1,
-    CHANNEL_BACKEND    = 2,
+    CHANNEL_TRANSPILER = 2,
+    CHANNEL_BACKEND    = 3,
     CHANNEL_MAX
 };
 
-GENN_EXPORT void init(plog::Severity gennLevel, plog::Severity codeGeneratorLevel,
-                      plog::IAppender *gennAppender, plog::IAppender *codeGeneratorAppender);
+GENN_EXPORT void init(plog::Severity gennLevel, plog::Severity codeGeneratorLevel, plog::Severity transpilerLevel,
+                      plog::IAppender *gennAppender, plog::IAppender *codeGeneratorAppender, plog::IAppender *transpilerAppender);
 }
diff --git a/include/genn/genn/snippet.h b/include/genn/genn/snippet.h
index 598fe18ebb..f6ff4ab67f 100644
--- a/include/genn/genn/snippet.h
+++ b/include/genn/genn/snippet.h
@@ -13,6 +13,7 @@
 // GeNN includes
 #include "gennExport.h"
 #include "gennUtils.h"
+#include "type.h"
 
 //----------------------------------------------------------------------------
 // Macros
diff --git a/include/genn/genn/transpiler/parser.h b/include/genn/genn/transpiler/parser.h
index 405246410f..4ca4cd523f 100644
--- a/include/genn/genn/transpiler/parser.h
+++ b/include/genn/genn/transpiler/parser.h
@@ -2,6 +2,7 @@
 
 // Standard C++ includes
 #include
+#include
 #include
 
 // Transpiler includes
@@ -20,7 +21,14 @@ class ErrorHandler;
 //---------------------------------------------------------------------------
 namespace GeNN::Transpiler::Parser
 {
+//! Parse expression from tokens
 Expression::ExpressionPtr parseExpression(const std::vector<Token> &tokens, ErrorHandler &errorHandler);
 
+//! Parse block item list from tokens
+/*! Block item lists are the lists of statements that make up function bodies */
 Statement::StatementList parseBlockItemList(const std::vector<Token> &tokens, ErrorHandler &errorHandler);
+
+//! Parse type from tokens
+const GeNN::Type::Base *parseType(const std::vector<Token> &tokens, bool allowPointers, ErrorHandler &errorHandler);
+
 }   // MiniParse::MiniParse
\ No newline at end of file
diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h
index 0a9baf13da..6541293a87 100644
--- a/include/genn/genn/type.h
+++ b/include/genn/genn/type.h
@@ -308,9 +308,18 @@ DECLARE_NUMERIC_TYPE(Double, double, 60);
 DECLARE_FOREIGN_FUNCTION_TYPE(Exp, Double, Double);
 DECLARE_FOREIGN_FUNCTION_TYPE(Sqrt, Double, Double);
 
-//! Look up type based on set of type specifiers
+//! Parse a numeric type
+const NumericBase *parseNumeric(std::string_view typeString);
+
+//! Look up numeric type based on set of type specifiers
 const NumericBase *getNumericType(const std::set<std::string_view> &typeSpecifiers);
+
+//! Look up numeric pointer type based on set of type specifiers
 const NumericPtrBase *getNumericPtrType(const std::set<std::string_view> &typeSpecifiers);
+
+//! Apply C type promotion rules to numeric type
 const NumericBase *getPromotedType(const NumericBase *type);
+
+//! Apply C rules to get common type between numeric types a and b
 const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b);
 }   // namespace GeNN::Type
diff --git a/src/genn/genn/logging.cc b/src/genn/genn/logging.cc
index 027ccf9816..c1f87eda37 100644
--- a/src/genn/genn/logging.cc
+++ b/src/genn/genn/logging.cc
@@ -3,8 +3,8 @@
 //----------------------------------------------------------------------------
 // GeNN::Logging
 //----------------------------------------------------------------------------
-void GeNN::Logging::init(plog::Severity gennLevel, plog::Severity codeGeneratorLevel,
-                         plog::IAppender *gennAppender, plog::IAppender *codeGeneratorAppender)
+void GeNN::Logging::init(plog::Severity gennLevel, plog::Severity codeGeneratorLevel, plog::Severity transpilerLevel,
+                         plog::IAppender *gennAppender, plog::IAppender *codeGeneratorAppender, plog::IAppender *transpilerAppender)
 {
     // If there isn't already a plog instance, initialise one
     if(plog::get<CHANNEL_GENN>() == nullptr) {
@@ -23,4 +23,13 @@ void GeNN::Logging::init(plog::Severity gennLevel, plog::Severity codeGeneratorL
     else {
         plog::get<CHANNEL_CODE_GEN>()->setMaxSeverity(codeGeneratorLevel);
     }
+
+    // If there isn't already a plog instance, initialise one
+    if(plog::get<CHANNEL_TRANSPILER>() == nullptr) {
+        plog::init<CHANNEL_TRANSPILER>(transpilerLevel, transpilerAppender);
+    }
+    // Otherwise, set its max severity from GeNN preferences
+    else {
+        plog::get<CHANNEL_TRANSPILER>()->setMaxSeverity(transpilerLevel);
+    }
 }
diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc
index f92166d26b..c4a9577ff6 100644
--- a/src/genn/genn/transpiler/parser.cc
+++ b/src/genn/genn/transpiler/parser.cc
@@ -854,7 +854,7 @@ Expression::ExpressionPtr parseExpression(const std::vector<Token> &tokens, Erro
         return nullptr;
     }
 }
-
+//---------------------------------------------------------------------------
 Statement::StatementList parseBlockItemList(const std::vector<Token> &tokens, ErrorHandler &errorHandler)
 {
     ParserState parserState(tokens, errorHandler);
@@ -865,4 +865,34 @@ Statement::StatementList parseBlockItemList(const std::vector<Token> &tokens, Er
     }
     return statements;
 }
+//---------------------------------------------------------------------------
+const GeNN::Type::Base *parseType(const std::vector<Token> &tokens, bool allowPointers, ErrorHandler &errorHandler)
+{
+    ParserState parserState(tokens, errorHandler);
+    bool pointerFound = false;
+    std::set<std::string_view> typeSpecifiers;
+    do {
+        // If token is a star, set pointer found flag
+        if(parserState.previous().type == Token::Type::STAR) {
+            if (!allowPointers) {
+                parserState.error(parserState.previous(), "pointer type not valid in this context");
+            }
+            pointerFound = true;
+        }
+        // Otherwise, if token is type specifier
+        else if(parserState.previous().type == Token::Type::TYPE_SPECIFIER) {
+            if(pointerFound) {
+                parserState.error(parserState.previous(), "invalid type specifier");
+            }
+            else if(!typeSpecifiers.insert(parserState.previous().lexeme).second) {
+                parserState.error(parserState.previous(), "duplicate type specifier");
+            }
+        }
+    } while(parserState.match({Token::Type::TYPE_SPECIFIER, Token::Type::STAR}));
+
+    // Lookup type based on whether token was found
+    return (pointerFound
+            ? static_cast<const GeNN::Type::Base *>(GeNN::Type::getNumericPtrType(typeSpecifiers))
+            : static_cast<const GeNN::Type::Base *>(GeNN::Type::getNumericType(typeSpecifiers)));
+}
 }
diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc
index 8dd8b4a6b2..233d9919e8 100644
--- a/src/genn/genn/type.cc
+++ b/src/genn/genn/type.cc
@@ -4,6 +4,14 @@
 #include
 #include
 
+// GeNN includes
+#include "logging.h"
+
+// Transpiler includes
+#include "transpiler/errorHandler.h"
+#include "transpiler/parser.h"
+#include "transpiler/scanner.h"
+
 using namespace GeNN;
 
 // Anonymous namespace
@@ -39,6 +47,29 @@ const std::unordered_map<const Type::NumericBase *, const Type::NumericBase *> unsignedType
     {Type::Int16::getInstance(), Type::Uint16::getInstance()},
     {Type::Int32::getInstance(), Type::Uint32::getInstance()}
 };
+
+//----------------------------------------------------------------------------
+// SimpleErrorHandler
+//----------------------------------------------------------------------------
+//! Simple error handler used for type parsing - just logs to transpiler log channel
+class SimpleErrorHandler : public Transpiler::ErrorHandler
+{
+public:
+    virtual void error(size_t line, std::string_view message) final
+    {
+        LOGE_TRANSPILER << "Error: " << message;
+    }
+
+    virtual void error(const Transpiler::Token &token, std::string_view message) final
+    {
+        if(token.type == Transpiler::Token::Type::END_OF_FILE) {
+            LOGE_TRANSPILER << "Error at end: " << message;
+        }
+        else {
+            LOGE_TRANSPILER << "Error at '" << token.lexeme << "': " << message;
+        }
+    }
+};
 }   // Anonymous namespace
 
@@ -64,6 +95,22 @@ IMPLEMENT_TYPE(Sqrt);
 //----------------------------------------------------------------------------
 // Free functions
 //----------------------------------------------------------------------------
+const NumericBase *parseNumeric(std::string_view typeString)
+{
+    using namespace Transpiler;
+
+    // Scan type
+    SimpleErrorHandler errorHandler;
+    const auto tokens = Scanner::scanSource(typeString, errorHandler);
+
+    // Parse type, cast to numeric and return
+    const auto *type = dynamic_cast<const NumericBase *>(Parser::parseType(tokens, false, errorHandler));
+    if (!type) {
+        throw std::runtime_error("Unable to parse type");
+    }
+    return type;
+}
+//----------------------------------------------------------------------------
 const NumericBase *getNumericType(const std::set<std::string_view> &typeSpecifiers)
 {
     const auto type = numericTypes.find(typeSpecifiers);
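The new Type::parseNumeric helper turns a type string directly into an (unqualified) numeric type, which is what allows Models::Var to switch from storing type strings to Type::NumericBase pointers. A hypothetical usage sketch, assuming only the declarations added in this patch (parseNumeric throws std::runtime_error if the string does not name a numeric type):

    #include <cassert>

    #include "type.h"

    int main()
    {
        // "unsigned int" scans to two type-specifier tokens, which resolve to
        // the Uint32 singleton via getNumericType()
        const GeNN::Type::NumericBase *type = GeNN::Type::parseNumeric("unsigned int");
        assert(type->getTypeHash() == GeNN::Type::Uint32::getInstance()->getTypeHash());
        return 0;
    }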
type-checking --- include/genn/genn/transpiler/typeChecker.h | 52 +++-- src/genn/genn/transpiler/typeChecker.cc | 248 ++++++++++++++------- tests/unit/typeChecker.cc | 74 +++--- 3 files changed, 241 insertions(+), 133 deletions(-) diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index de033ad47f..9f2d31681e 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -32,16 +32,37 @@ class TypeCheckError : public std::runtime_error }; //--------------------------------------------------------------------------- -// GeNN::Transpiler::TypeChecker::Environment +// GeNN::Transpiler::TypeChecker::EnvironmentBase //--------------------------------------------------------------------------- -class Environment +class EnvironmentBase { public: - Environment(Environment *enclosing = nullptr) - : m_Enclosing(enclosing) - { - } + //------------------------------------------------------------------------ + // Declared virtuals + //------------------------------------------------------------------------ + virtual void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandler &errorHandler) = 0; + virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, + ErrorHandler &errorHandler, bool initializer = false) = 0; + virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandler &errorHandler) = 0; + virtual const Type::QualifiedType &getType(const Token &name, ErrorHandler &errorHandler) = 0; +protected: + //--------------------------------------------------------------------------- + // Protected API + //--------------------------------------------------------------------------- + const Type::QualifiedType &assign(const Token &name, Token::Type op, + const Type::QualifiedType &existingType, const Type::QualifiedType &assignedType, + ErrorHandler &errorHandler, bool initializer = false) const; + const Type::QualifiedType &incDec(const Token &name, Token::Type op, + const Type::QualifiedType &existingType, ErrorHandler &errorHandler) const; +}; + +//--------------------------------------------------------------------------- +// GeNN::Transpiler::TypeChecker::EnvironmentExternal +//--------------------------------------------------------------------------- +class EnvironmentExternal : public EnvironmentBase +{ +public: //--------------------------------------------------------------------------- // Public API //--------------------------------------------------------------------------- @@ -52,26 +73,29 @@ class Environment throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); } } - void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandler &errorHandler); - const Type::QualifiedType &assign(const Token &name, const Type::QualifiedType &assignedType, - Token::Type op, ErrorHandler &errorHandler, bool initializer = false); - const Type::QualifiedType &incDec(const Token &name, const Token &op, ErrorHandler &errorHandler); - const Type::QualifiedType &getType(const Token &name, ErrorHandler &errorHandler) const; + + //--------------------------------------------------------------------------- + // EnvironmentBase virtuals + //--------------------------------------------------------------------------- + virtual void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandler &errorHandler) final; + virtual const Type::QualifiedType 
&assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, + ErrorHandler &errorHandler, bool initializer = false) final; + virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandler &errorHandler) final; + virtual const Type::QualifiedType &getType(const Token &name, ErrorHandler &errorHandler) final; private: //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- - Environment *m_Enclosing; std::unordered_map m_Types; }; //--------------------------------------------------------------------------- // Free functions //--------------------------------------------------------------------------- -void typeCheck(const Statement::StatementList &statements, Environment &environment, +void typeCheck(const Statement::StatementList &statements, EnvironmentExternal &environment, ErrorHandler &errorHandler); -Type::QualifiedType typeCheck(const Expression::Base *expression, Environment &environment, +Type::QualifiedType typeCheck(const Expression::Base *expression, EnvironmentExternal &environment, ErrorHandler &errorHandler); } // namespace MiniParse::GeNN::Transpiler diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 26506384a7..3ca2b1b6ec 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -24,7 +24,92 @@ namespace Type = GeNN::Type; namespace { //--------------------------------------------------------------------------- -// Vistor +// EnvironmentInternal +//--------------------------------------------------------------------------- +class EnvironmentInternal : public EnvironmentBase +{ +public: + EnvironmentInternal(EnvironmentBase *enclosing = nullptr) + : m_Enclosing(enclosing) + { + } + + //--------------------------------------------------------------------------- + // EnvironmentBase virtuals + //--------------------------------------------------------------------------- + virtual void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandler &errorHandler) final + { + if(!m_Types.try_emplace(name.lexeme, qualifiedType).second) { + errorHandler.error(name, "Redeclaration of variable"); + throw TypeCheckError(); + } + } + + virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, + ErrorHandler &errorHandler, bool initializer = false) final + { + // If type isn't found + auto existingType = m_Types.find(name.lexeme); + if(existingType == m_Types.end()) { + if(m_Enclosing) { + return m_Enclosing->assign(name, op, assignedType, + errorHandler, initializer); + } + else { + errorHandler.error(name, "Undefined variable"); + throw TypeCheckError(); + } + } + + // Perform standard type-checking logic + return EnvironmentBase::assign(name, op, existingType->second, assignedType, errorHandler, initializer); + } + + virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandler &errorHandler) final + { + // If type isn't found + auto existingType = m_Types.find(name.lexeme); + if(existingType == m_Types.end()) { + if(m_Enclosing) { + return m_Enclosing->incDec(name, op, errorHandler); + } + else { + errorHandler.error(name, "Undefined variable"); + throw TypeCheckError(); + } + } + + // Perform standard type-checking logic + return EnvironmentBase::incDec(name, op, existingType->second, errorHandler); + } + + virtual const 
Type::QualifiedType &getType(const Token &name, ErrorHandler &errorHandler) final + { + auto type = m_Types.find(std::string{name.lexeme}); + if(type == m_Types.end()) { + if(m_Enclosing) { + return m_Enclosing->getType(name, errorHandler); + } + else { + errorHandler.error(name, "Undefined variable"); + throw TypeCheckError(); + } + } + else { + return type->second; + } + } + +private: + //--------------------------------------------------------------------------- + // Members + //--------------------------------------------------------------------------- + EnvironmentBase *m_Enclosing; + std::unordered_map m_Types; +}; + +//--------------------------------------------------------------------------- +// Visitor //--------------------------------------------------------------------------- class Visitor : public Expression::Visitor, public Statement::Visitor { @@ -39,25 +124,19 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // Public API //--------------------------------------------------------------------------- // **THINK** make constructors? - void typeCheck(const Statement::StatementList &statements, Environment &environment) + void typeCheck(const Statement::StatementList &statements, EnvironmentInternal &environment) { - Environment *previous = m_Environment; m_Environment = &environment; for (auto &s : statements) { s.get()->accept(*this); } - m_Environment = previous; } - const Type::QualifiedType typeCheck(const Expression::Base *expression, Environment &environment) + const Type::QualifiedType typeCheck(const Expression::Base *expression, EnvironmentInternal &environment) { - Environment *previous = m_Environment; - m_Environment = &environment; - - const auto type = evaluateType(expression); - m_Environment = previous; - return type; + m_Environment = &environment; + return evaluateType(expression); } //--------------------------------------------------------------------------- @@ -93,8 +172,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Assignment &assignment) final { const auto rhsType = evaluateType(assignment.getValue()); - m_QualifiedType = m_Environment->assign(assignment.getVarName(), rhsType, - assignment.getOperator().type, m_ErrorHandler); + m_QualifiedType = m_Environment->assign(assignment.getVarName(), assignment.getOperator().type, rhsType, m_ErrorHandler); } virtual void visit(const Expression::Binary &binary) final @@ -297,13 +375,13 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final { m_QualifiedType = m_Environment->incDec(postfixIncDec.getVarName(), - postfixIncDec.getOperator(), m_ErrorHandler); + postfixIncDec.getOperator().type, m_ErrorHandler); } virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final { m_QualifiedType = m_Environment->incDec(prefixIncDec.getVarName(), - prefixIncDec.getOperator(), m_ErrorHandler); + prefixIncDec.getOperator().type, m_ErrorHandler); } virtual void visit(const Expression::Variable &variable) @@ -380,7 +458,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::Compound &compound) final { - Environment environment(m_Environment); + EnvironmentInternal environment(m_Environment); typeCheck(compound.getStatements(), environment); } @@ -407,8 +485,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::For &forStatement) 
final { // Create new environment for loop initialisation - Environment *previous = m_Environment; - Environment environment(m_Environment); + EnvironmentInternal *previous = m_Environment; + EnvironmentInternal environment(m_Environment); m_Environment = &environment; // Interpret initialiser if statement present @@ -486,7 +564,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const auto initialiserType = evaluateType(std::get<1>(var).get()); // Assign initialiser expression to variable - m_Environment->assign(std::get<0>(var), initialiserType, Token::Type::EQUAL, m_ErrorHandler, true); + m_Environment->assign(std::get<0>(var), Token::Type::EQUAL, initialiserType, m_ErrorHandler, true); } } } @@ -517,7 +595,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- - Environment *m_Environment; + EnvironmentInternal *m_Environment; Type::QualifiedType m_QualifiedType; ErrorHandler &m_ErrorHandler; @@ -527,37 +605,17 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } //--------------------------------------------------------------------------- -// MiniParse::TypeChecker::Environment +// GeNN::Transpiler::TypeChecker::EnvironmentBase //--------------------------------------------------------------------------- -void Environment::define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandler &errorHandler) +const Type::QualifiedType &EnvironmentBase::assign(const Token &name, Token::Type op, + const Type::QualifiedType &existingType, const Type::QualifiedType &assignedType, + ErrorHandler &errorHandler, bool initializer) const { - if(!m_Types.try_emplace(name.lexeme, qualifiedType).second) { - errorHandler.error(name, "Redeclaration of variable"); - throw TypeCheckError(); - } -} -//--------------------------------------------------------------------------- -const Type::QualifiedType &Environment::assign(const Token &name, const Type::QualifiedType &assignedType, - Token::Type op, ErrorHandler &errorHandler, bool initializer) -{ - // If type isn't found - auto existingType = m_Types.find(name.lexeme); - if(existingType == m_Types.end()) { - if(m_Enclosing) { - return m_Enclosing->assign(name, assignedType, - op, errorHandler, initializer); - } - else { - errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); - } - } - // If existing type is a constant numeric value or if it's a constant pointer give errors - auto numericExistingType = dynamic_cast(existingType->second.type); - auto numericPtrExistingType = dynamic_cast(existingType->second.type); - if(!initializer && ((numericExistingType && existingType->second.constValue) - || (numericPtrExistingType && existingType->second.constPointer))) + auto numericExistingType = dynamic_cast(existingType.type); + auto numericPtrExistingType = dynamic_cast(existingType.type); + if(!initializer && ((numericExistingType && existingType.constValue) + || (numericPtrExistingType && existingType.constPointer))) { errorHandler.error(name, "Assignment of read-only variable"); throw TypeCheckError(); @@ -570,7 +628,7 @@ const Type::QualifiedType &Environment::assign(const Token &name, const Type::Qu // If we're initialising a pointer with another pointer if (numericPtrAssignedType && numericPtrExistingType) { // If we're trying to assign a pointer to a const value to a pointer - if 
(assignedType.constValue && !existingType->second.constValue) { + if (assignedType.constValue && !existingType.constValue) { errorHandler.error(name, "Invalid operand types '" + numericPtrExistingType->getTypeName() + "' and '" + numericPtrAssignedType->getTypeName()); throw TypeCheckError(); } @@ -583,7 +641,7 @@ const Type::QualifiedType &Environment::assign(const Token &name, const Type::Qu } // Otherwise, if we're trying to initialise a pointer with a non-pointer or vice-versa else if (numericPtrAssignedType || numericPtrExistingType) { - errorHandler.error(name, "Invalid operand types '" + existingType->second.type->getTypeName() + "' and '" + assignedType.type->getTypeName()); + errorHandler.error(name, "Invalid operand types '" + existingType.type->getTypeName() + "' and '" + assignedType.type->getTypeName()); throw TypeCheckError(); } } @@ -592,7 +650,7 @@ const Type::QualifiedType &Environment::assign(const Token &name, const Type::Qu // If the operand being added isn't numeric or the type being added to is neither numeric or a pointer if (!numericAssignedType || (!numericPtrExistingType && !numericExistingType)) { - errorHandler.error(name, "Invalid operand types '" + existingType->second.type->getTypeName() + "' and '" + assignedType.type->getTypeName() + "'"); + errorHandler.error(name, "Invalid operand types '" + existingType.type->getTypeName() + "' and '" + assignedType.type->getTypeName() + "'"); throw TypeCheckError(); } @@ -610,7 +668,7 @@ const Type::QualifiedType &Environment::assign(const Token &name, const Type::Qu throw TypeCheckError(); } if(!numericExistingType) { - errorHandler.error(name, "Invalid operand types '" + existingType->second.type->getTypeName() + "'"); + errorHandler.error(name, "Invalid operand types '" + existingType.type->getTypeName() + "'"); throw TypeCheckError(); } @@ -629,65 +687,91 @@ const Type::QualifiedType &Environment::assign(const Token &name, const Type::Qu // Return existing type // **THINK** - return existingType->second; + return existingType; } //--------------------------------------------------------------------------- -const Type::QualifiedType &Environment::incDec(const Token &name, const Token &op, ErrorHandler &errorHandler) +const Type::QualifiedType &EnvironmentBase::incDec(const Token &name, Token::Type, + const Type::QualifiedType &existingType, ErrorHandler &errorHandler) const { - auto existingType = m_Types.find(name.lexeme); - if(existingType == m_Types.end()) { - if(m_Enclosing) { - return m_Enclosing->incDec(name, op, errorHandler); - } - else { - errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); - } - } - // If existing type is a constant numeric value or if it's a constant pointer give errors - auto numericExistingType = dynamic_cast(existingType->second.type); - auto numericPtrExistingType = dynamic_cast(existingType->second.type); - if((numericExistingType && existingType->second.constValue) - || (numericPtrExistingType && existingType->second.constPointer)) + auto numericExistingType = dynamic_cast(existingType.type); + auto numericPtrExistingType = dynamic_cast(existingType.type); + if((numericExistingType && existingType.constValue) + || (numericPtrExistingType && existingType.constPointer)) { errorHandler.error(name, "Increment/decrement of read-only variable"); throw TypeCheckError(); } // Otherwise, return type else { - return existingType->second; + return existingType; + } +} + +//--------------------------------------------------------------------------- +// 
GeNN::Transpiler::TypeChecker::EnvironmentExternal +//--------------------------------------------------------------------------- +void EnvironmentExternal::define(const Token &name, const Type::QualifiedType &, ErrorHandler &errorHandler) +{ + errorHandler.error(name, "Cannot declare variable in external environment"); + throw TypeCheckError(); +} +//--------------------------------------------------------------------------- +const Type::QualifiedType &EnvironmentExternal::assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, + ErrorHandler &errorHandler, bool initializer) +{ + // If type isn't found + auto existingType = m_Types.find(name.lexeme); + if(existingType == m_Types.end()) { + errorHandler.error(name, "Undefined variable"); + throw TypeCheckError(); + } + + // Perform standard type-checking logic + return EnvironmentBase::assign(name, op, existingType->second, assignedType, errorHandler, initializer); +} +//--------------------------------------------------------------------------- +const Type::QualifiedType &EnvironmentExternal::incDec(const Token &name, Token::Type op, ErrorHandler &errorHandler) +{ + auto existingType = m_Types.find(name.lexeme); + if(existingType == m_Types.end()) { + errorHandler.error(name, "Undefined variable"); + throw TypeCheckError(); } + + // Perform standard type-checking logic + return EnvironmentBase::incDec(name, op, existingType->second, errorHandler); + } //--------------------------------------------------------------------------- -const Type::QualifiedType &Environment::getType(const Token &name, ErrorHandler &errorHandler) const +const Type::QualifiedType &EnvironmentExternal::getType(const Token &name, ErrorHandler &errorHandler) { auto type = m_Types.find(std::string{name.lexeme}); if(type == m_Types.end()) { - if(m_Enclosing) { - return m_Enclosing->getType(name, errorHandler); - } - else { - errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); - } + errorHandler.error(name, "Undefined variable"); + throw TypeCheckError(); } else { return type->second; } } + +//--------------------------------------------------------------------------- +// GeNN::Transpiler::TypeChecker //--------------------------------------------------------------------------- -void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &statements, Environment &environment, +void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &statements, EnvironmentExternal &environment, ErrorHandler &errorHandler) { Visitor visitor(errorHandler); - visitor.typeCheck(statements, environment); + EnvironmentInternal internalEnvironment(&environment); + visitor.typeCheck(statements, internalEnvironment); } //--------------------------------------------------------------------------- Type::QualifiedType GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, - Environment &environment, + EnvironmentExternal &environment, ErrorHandler &errorHandler) { Visitor visitor(errorHandler); - return visitor.typeCheck(expression, environment); + EnvironmentInternal internalEnvironment(&environment); + return visitor.typeCheck(expression, internalEnvironment); } diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index 1649dca430..47a0b62d95 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -51,7 +51,7 @@ class TestErrorHandler : public ErrorHandler bool m_Error; }; -void typeCheckStatements(std::string_view code, TypeChecker::Environment 
&typeEnvironment) { // Scan TestErrorHandler errorHandler; @@ -67,7 +67,7 @@ void typeCheckStatements(std::string_view code, TypeChecker::Environment &typeEn ASSERT_FALSE(errorHandler.hasError()); } -Type::QualifiedType typeCheckExpression(std::string_view code, TypeChecker::Environment &typeEnvironment) +Type::QualifiedType typeCheckExpression(std::string_view code, TypeChecker::EnvironmentExternal &typeEnvironment) { // Scan TestErrorHandler errorHandler; @@ -92,7 +92,7 @@ TEST(TypeChecker, ArraySubscript) { // Integer array indexing { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray"); const auto type = typeCheckExpression("intArray[4]", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); @@ -102,17 +102,17 @@ TEST(TypeChecker, ArraySubscript) // Float array indexing EXPECT_THROW({ - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray"); - const auto type = typeCheckExpression("intArray[4.0f]", typeEnvironment);}, + typeCheckExpression("intArray[4.0f]", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer indexing EXPECT_THROW({ - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray"); typeEnvironment.define("indexArray"); - const auto type = typeCheckExpression("intArray[indexArray]", typeEnvironment);}, + typeCheckExpression("intArray[indexArray]", typeEnvironment);}, TypeChecker::TypeCheckError); } //-------------------------------------------------------------------------- @@ -120,7 +120,7 @@ TEST(TypeChecker, Assignment) { // Numeric assignment { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intVal"); typeEnvironment.define("floatVal"); typeEnvironment.define("intValConst", true); @@ -137,7 +137,7 @@ TEST(TypeChecker, Assignment) // Pointer assignment { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray"); typeEnvironment.define("intArrayConst", true); typeCheckStatements( @@ -149,14 +149,14 @@ TEST(TypeChecker, Assignment) // Pointer assignment, attempt to remove const EXPECT_THROW({ - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray", true); typeCheckStatements("int *x = intArray;", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer assignment without explicit cast EXPECT_THROW({ - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray"); typeCheckStatements("float *x = intArray;", typeEnvironment);}, TypeChecker::TypeCheckError); @@ -176,7 +176,7 @@ TEST(TypeChecker, Cast) { // Numeric cast { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intVal"); const auto type = typeCheckExpression("(float)intVal", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Float::getInstance()->getTypeHash()); @@ -186,7 +186,7 @@ TEST(TypeChecker, Cast) // Numeric cast to const { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intVal"); const auto type = 
typeCheckExpression("(const int)intVal", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); @@ -196,7 +196,7 @@ TEST(TypeChecker, Cast) // Pointer cast to value const { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray"); const auto type = typeCheckExpression("(const int*)intArray", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash()); @@ -206,7 +206,7 @@ TEST(TypeChecker, Cast) // Pointer cast to pointer const { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray"); const auto type = typeCheckExpression("(int * const)intArray", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash()); @@ -216,42 +216,42 @@ TEST(TypeChecker, Cast) // Can't remove value const from numeric EXPECT_THROW({ - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intVal", true); typeCheckExpression("(int)intVal", typeEnvironment);}, TypeChecker::TypeCheckError); // Can't remove value const from pointer EXPECT_THROW({ - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray", true); typeCheckExpression("(int*)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Can't remove pointer const from pointer EXPECT_THROW({ - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray", false, true); typeCheckExpression("(int*)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer cast can't reinterpret EXPECT_THROW({ - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray"); typeCheckExpression("(float*)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer can't be cast to numeric EXPECT_THROW({ - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray"); typeCheckExpression("(int)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Numeric can't be cast to pointer EXPECT_THROW({ - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intVal"); typeCheckExpression("(int*)intVal", typeEnvironment);}, TypeChecker::TypeCheckError); @@ -265,7 +265,7 @@ TEST(TypeChecker, IncDec) { // Can increment numeric { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intVal"); const auto type = typeCheckExpression("intVal++", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); @@ -275,7 +275,7 @@ TEST(TypeChecker, IncDec) // Can increment pointer { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray"); const auto type = typeCheckExpression("intArray++", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash()); @@ -285,7 +285,7 @@ TEST(TypeChecker, IncDec) // Can increment pointer to const { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray", true); const auto type = 
typeCheckExpression("intArray++", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash()); @@ -295,14 +295,14 @@ TEST(TypeChecker, IncDec) // Can't increment const number EXPECT_THROW({ - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intVal", true); typeCheckExpression("intVal++", typeEnvironment);}, TypeChecker::TypeCheckError); // Can't increment const pointer EXPECT_THROW({ - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray", false, true); typeCheckExpression("intArray++", typeEnvironment);}, TypeChecker::TypeCheckError); @@ -312,7 +312,7 @@ TEST(TypeChecker, Literal) { // Float { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; const auto type = typeCheckExpression("1.0f", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Float::getInstance()->getTypeHash()); EXPECT_TRUE(type.constValue); @@ -321,7 +321,7 @@ TEST(TypeChecker, Literal) // Double { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; const auto type = typeCheckExpression("1.0", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Double::getInstance()->getTypeHash()); EXPECT_TRUE(type.constValue); @@ -330,7 +330,7 @@ TEST(TypeChecker, Literal) // Integer { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; const auto type = typeCheckExpression("100", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); EXPECT_TRUE(type.constValue); @@ -339,7 +339,7 @@ TEST(TypeChecker, Literal) // Unsigned integer { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; const auto type = typeCheckExpression("100U", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Uint32::getInstance()->getTypeHash()); EXPECT_TRUE(type.constValue); @@ -351,7 +351,7 @@ TEST(TypeChecker, Unary) { // Dereference pointer { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray"); const auto type = typeCheckExpression("*intArray", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); @@ -361,7 +361,7 @@ TEST(TypeChecker, Unary) // Dereference pointer to const { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray", true); const auto type = typeCheckExpression("*intArray", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); @@ -371,7 +371,7 @@ TEST(TypeChecker, Unary) // Dereference const pointer { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray", false, true); const auto type = typeCheckExpression("*intArray", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); @@ -381,7 +381,7 @@ TEST(TypeChecker, Unary) // Dereference const pointer to const { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray", true, true); const auto type = typeCheckExpression("*intArray", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); @@ -391,14 +391,14 @@ 
TEST(TypeChecker, Unary) // Dereference numeric EXPECT_THROW({ - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intVal"); typeCheckExpression("*intVal", typeEnvironment); }, TypeChecker::TypeCheckError); // Address of numeric { - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intVal"); const auto type = typeCheckExpression("&intVal", typeEnvironment); EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash()); @@ -408,7 +408,7 @@ TEST(TypeChecker, Unary) // Address of pointer EXPECT_THROW({ - TypeChecker::Environment typeEnvironment; + TypeChecker::EnvironmentExternal typeEnvironment; typeEnvironment.define("intArray"); typeCheckExpression("&intArray", typeEnvironment);}, TypeChecker::TypeCheckError); From 874303ad4b4c62be9ad8fab697722f8c14924e2e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 10 Jan 2023 16:36:54 +0000 Subject: [PATCH 028/725] internal type environments ALWAYS have enclosing type environment so use reference and remove null check --- src/genn/genn/transpiler/typeChecker.cc | 38 +++++++------------------ 1 file changed, 10 insertions(+), 28 deletions(-) diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 3ca2b1b6ec..97ea6f43b4 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -29,7 +29,7 @@ namespace class EnvironmentInternal : public EnvironmentBase { public: - EnvironmentInternal(EnvironmentBase *enclosing = nullptr) + EnvironmentInternal(EnvironmentBase &enclosing) : m_Enclosing(enclosing) { } @@ -51,14 +51,8 @@ class EnvironmentInternal : public EnvironmentBase // If type isn't found auto existingType = m_Types.find(name.lexeme); if(existingType == m_Types.end()) { - if(m_Enclosing) { - return m_Enclosing->assign(name, op, assignedType, - errorHandler, initializer); - } - else { - errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); - } + return m_Enclosing.assign(name, op, assignedType, + errorHandler, initializer); } // Perform standard type-checking logic @@ -70,13 +64,7 @@ class EnvironmentInternal : public EnvironmentBase // If type isn't found auto existingType = m_Types.find(name.lexeme); if(existingType == m_Types.end()) { - if(m_Enclosing) { - return m_Enclosing->incDec(name, op, errorHandler); - } - else { - errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); - } + return m_Enclosing.incDec(name, op, errorHandler); } // Perform standard type-checking logic @@ -87,13 +75,7 @@ class EnvironmentInternal : public EnvironmentBase { auto type = m_Types.find(std::string{name.lexeme}); if(type == m_Types.end()) { - if(m_Enclosing) { - return m_Enclosing->getType(name, errorHandler); - } - else { - errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); - } + return m_Enclosing.getType(name, errorHandler); } else { return type->second; @@ -104,7 +86,7 @@ class EnvironmentInternal : public EnvironmentBase //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- - EnvironmentBase *m_Enclosing; + EnvironmentBase &m_Enclosing; std::unordered_map m_Types; }; @@ -458,7 +440,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::Compound &compound) final { - EnvironmentInternal 
environment(m_Environment); + EnvironmentInternal environment(*m_Environment); typeCheck(compound.getStatements(), environment); } @@ -486,7 +468,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { // Create new environment for loop initialisation EnvironmentInternal *previous = m_Environment; - EnvironmentInternal environment(m_Environment); + EnvironmentInternal environment(*m_Environment); m_Environment = &environment; // Interpret initialiser if statement present @@ -763,7 +745,7 @@ void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &st ErrorHandler &errorHandler) { Visitor visitor(errorHandler); - EnvironmentInternal internalEnvironment(&environment); + EnvironmentInternal internalEnvironment(environment); visitor.typeCheck(statements, internalEnvironment); } //--------------------------------------------------------------------------- @@ -772,6 +754,6 @@ Type::QualifiedType GeNN::Transpiler::TypeChecker::typeCheck(const Expression::B ErrorHandler &errorHandler) { Visitor visitor(errorHandler); - EnvironmentInternal internalEnvironment(&environment); + EnvironmentInternal internalEnvironment(environment); return visitor.typeCheck(expression, internalEnvironment); } From 360abdeb64a66eb9b15bd3964508fae717decdd0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 10 Jan 2023 17:05:04 +0000 Subject: [PATCH 029/725] typo fix --- src/genn/genn/code_generator/groupMerged.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index 7168cf9eff..fa53ca53d4 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -474,7 +474,7 @@ bool NeuronGroupMergedBase::isPSMVarInitParamReferenced(size_t childIndex, const const auto *varInitSnippet = getSortedArchetypeMergedInSyns().at(childIndex)->getPSVarInitialisers().at(varName).getSnippet(); return isParamReferenced({varInitSnippet->getCode()}, paramName); } -//---------------GeNN::------------------------------------------------------------- +//---------------------------------------------------------------------------- void NeuronGroupMergedBase::addMergedInSynPointerField(const std::string &type, const std::string &name, size_t archetypeIndex, const std::string &prefix) { From 4fb64c6f0632e0cc097346b5e6e400614435430c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 10 Jan 2023 18:05:37 +0000 Subject: [PATCH 030/725] notes on next environment steps --- include/genn/genn/transpiler/typeChecker.h | 34 ++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index 9f2d31681e..aca8e41c94 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -60,9 +60,16 @@ class EnvironmentBase //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker::EnvironmentExternal //--------------------------------------------------------------------------- +// template class EnvironmentExternal : public EnvironmentBase { public: + // **THINK** should type need to be same as enclosing group? perhaps this could help with child groups? 
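+    // e.g. a purely hypothetical sketch of the child-group idea (neither this
+    // constructor nor these names exist yet):
+    //EnvironmentExternal neuronEnvironment;
+    //EnvironmentExternal synapseEnvironment(&neuronEnvironment);
+    //synapseEnvironment.addVars(wum->getVars(), backend.getDeviceVarPrefix());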
+ //EnvironmentExternal(EnvironmentBase *enclosing) + + //typedef std::function GetFieldValueFunc; + //typedef std::tuple Field; + //--------------------------------------------------------------------------- // Public API //--------------------------------------------------------------------------- @@ -73,6 +80,29 @@ class EnvironmentExternal : public EnvironmentBase throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); } } + + /*void addPointerField(const Type::Base *type, const std::string &name, bool isConstValue = false) + { + assert(dynamic_cast(type)); + + // Define variable type + define(name, type, isConstValue); + + // Add field with pointer type + // **TODO** could also be a const pointer + addField(type->getPointerType(), name, [name](const G &g, size_t) { return devicePrefix + name + g.getName(); }); + + // **TODO** link from type back to field(s) - vector indices would work as we always push back fields + } + + void addVars(const Models::Base::VarVec &vars, const std::string &arrayPrefix) + { + // Loop through variables + for(const auto &v : vars) { + addPointerField(v.type, v.name, (v.access & VarAccessMode::READ_ONLY)); + } + } + */ //--------------------------------------------------------------------------- // EnvironmentBase virtuals @@ -88,6 +118,10 @@ class EnvironmentExternal : public EnvironmentBase // Members //--------------------------------------------------------------------------- std::unordered_map m_Types; + + // **THINK** should fields live in some sort of parent environment external? children are instantiated to type check e.g. child synapse groups + // but we eventually want a flat list of fields and we want that to be located somewhere permanent + //std::vector m_Fields; }; //--------------------------------------------------------------------------- From 4ad08db4cde16c57e2a95a47b25868fbcf019008 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 12 Jan 2023 16:36:57 +0000 Subject: [PATCH 031/725] implemented standard error handler classes --- include/genn/genn/transpiler/errorHandler.h | 75 ++++++++++++++++++- include/genn/genn/transpiler/parser.h | 8 +- include/genn/genn/transpiler/scanner.h | 4 +- include/genn/genn/transpiler/token.h | 6 ++ .../genn/genn/transpiler/transpilerUtils.h | 42 +++++------ include/genn/genn/transpiler/typeChecker.h | 30 ++++---- .../code_generator/customUpdateGroupMerged.cc | 7 ++ src/genn/genn/genn.vcxproj | 1 + src/genn/genn/transpiler/parser.cc | 10 +-- src/genn/genn/transpiler/scanner.cc | 6 +- src/genn/genn/transpiler/typeChecker.cc | 32 ++++---- src/genn/genn/type.cc | 34 +++------ tests/unit/scanner.cc | 2 +- tests/unit/typeChecker.cc | 2 +- 14 files changed, 163 insertions(+), 96 deletions(-) diff --git a/include/genn/genn/transpiler/errorHandler.h b/include/genn/genn/transpiler/errorHandler.h index ceb88fa66f..279dbfcab1 100644 --- a/include/genn/genn/transpiler/errorHandler.h +++ b/include/genn/genn/transpiler/errorHandler.h @@ -7,14 +7,83 @@ #include "transpiler/token.h" //--------------------------------------------------------------------------- -// GeNN::Transpiler::ErrorHandler +// GeNN::Transpiler::ErrorHandlerBase //--------------------------------------------------------------------------- namespace GeNN::Transpiler { -class ErrorHandler +class ErrorHandlerBase { public: + //------------------------------------------------------------------------ + // Declared virtuals + //------------------------------------------------------------------------ virtual void error(size_t line, 
std::string_view message) = 0; virtual void error(const Token &token, std::string_view message) = 0; }; -} + +//--------------------------------------------------------------------------- +// GeNN::Transpiler::ErrorHandler +//--------------------------------------------------------------------------- +class ErrorHandler : public ErrorHandlerBase +{ +public: + ErrorHandler() : m_Error(false) + { + } + + //------------------------------------------------------------------------ + // ErrorHandlerBase virtuals + //------------------------------------------------------------------------ + virtual void error(size_t line, std::string_view message) final; + virtual void error(const Token &token, std::string_view message) final; + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + bool hasError() const { return m_Error; } + +private: + //------------------------------------------------------------------------ + // Private methods + //------------------------------------------------------------------------ + void report(size_t line, std::string_view where, std::string_view message); + + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + bool m_Error; +}; + +//--------------------------------------------------------------------------- +// GeNN::Transpiler::SingleLineErrorHandler +//--------------------------------------------------------------------------- +class SingleLineErrorHandler : public ErrorHandlerBase +{ +public: + SingleLineErrorHandler() : m_Error(false) + { + } + + //------------------------------------------------------------------------ + // ErrorHandlerBase virtuals + //------------------------------------------------------------------------ + virtual void error(size_t line, std::string_view message) final; + virtual void error(const Token &token, std::string_view message) final; + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + bool hasError() const { return m_Error; } + +private: + //------------------------------------------------------------------------ + // Private methods + //------------------------------------------------------------------------ + void report(std::string_view where, std::string_view message); + + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + bool m_Error; +}; +} // namespace GeNN::Transpiler diff --git a/include/genn/genn/transpiler/parser.h b/include/genn/genn/transpiler/parser.h index 4ca4cd523f..717432d7bf 100644 --- a/include/genn/genn/transpiler/parser.h +++ b/include/genn/genn/transpiler/parser.h @@ -13,7 +13,7 @@ // Forward declarations namespace GeNN::Transpiler { -class ErrorHandler; +class ErrorHandlerBase; } //--------------------------------------------------------------------------- @@ -22,13 +22,13 @@ class ErrorHandler; namespace GeNN::Transpiler::Parser { //! Parse expression from tokens -Expression::ExpressionPtr parseExpression(const std::vector &tokens, ErrorHandler &errorHandler); +Expression::ExpressionPtr parseExpression(const std::vector &tokens, ErrorHandlerBase &errorHandler); //! Parse block item list from tokens /*! 
Block item lists are the lists of statements that make up a function body scope */ -Statement::StatementList parseBlockItemList(const std::vector &tokens, ErrorHandler &errorHandler); +Statement::StatementList parseBlockItemList(const std::vector &tokens, ErrorHandlerBase &errorHandler); //! Parse type from tokens -const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, ErrorHandler &errorHandler); +const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, ErrorHandlerBase &errorHandler); } // MiniParse::MiniParse \ No newline at end of file diff --git a/include/genn/genn/transpiler/scanner.h b/include/genn/genn/transpiler/scanner.h index abf3fbd45c..e80a7f5d68 100644 --- a/include/genn/genn/transpiler/scanner.h +++ b/include/genn/genn/transpiler/scanner.h @@ -13,7 +13,7 @@ // Forward declarations namespace GeNN::Transpiler { -class ErrorHandler; +class ErrorHandlerBase; } //--------------------------------------------------------------------------- @@ -21,6 +21,6 @@ class ErrorHandler; //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Scanner { -std::vector scanSource(const std::string_view &source, ErrorHandler &errorHandler); +std::vector scanSource(const std::string_view &source, ErrorHandlerBase &errorHandler); } // namespace Scanner \ No newline at end of file diff --git a/include/genn/genn/transpiler/token.h b/include/genn/genn/transpiler/token.h index 99dadfcbea..433a28dd9a 100644 --- a/include/genn/genn/transpiler/token.h +++ b/include/genn/genn/transpiler/token.h @@ -7,6 +7,12 @@ // Standard C includes #include +// **YUCK** on Windows undefine TRUE and FALSE macros +#ifdef _WIN32 + #undef TRUE + #undef FALSE +#endif + //--------------------------------------------------------------------------- // GeNN::Transpiler::Token //--------------------------------------------------------------------------- diff --git a/include/genn/genn/transpiler/transpilerUtils.h b/include/genn/genn/transpiler/transpilerUtils.h index f062fc8106..5c6340dcd6 100644 --- a/include/genn/genn/transpiler/transpilerUtils.h +++ b/include/genn/genn/transpiler/transpilerUtils.h @@ -7,28 +7,28 @@ namespace GeNN::Transpiler::Utils { - template struct Overload : Ts... { using Ts::operator()...; }; - template Overload(Ts...) -> Overload; // line not needed in C++20 +template struct Overload : Ts... { using Ts::operator()...; }; +template Overload(Ts...) -> Overload; // line not needed in C++20 - template - T toCharsThrow(std::string_view input, int base = 10) - { - T out; - std::from_chars_result result; - if constexpr (std::is_floating_point_v) { - result = std::from_chars(input.data(), input.data() + input.size(), out, - (base == 10) ? 
std::chars_format::general : std::chars_format::hex); + } + else { + result = std::from_chars(input.data(), input.data() + input.size(), out, base); + } - if(result.ec == std::errc::invalid_argument) { - throw std::invalid_argument("Unable to convert chars '" + std::string{input} + "'"); - } - else if(result.ec == std::errc::result_out_of_range) { - throw std::out_of_range("Unable to convert chars '" + std::string{input} + "'"); - } - return out; + if(result.ec == std::errc::invalid_argument) { + throw std::invalid_argument("Unable to convert chars '" + std::string{input} + "'"); + } + else if(result.ec == std::errc::result_out_of_range) { + throw std::out_of_range("Unable to convert chars '" + std::string{input} + "'"); } + return out; } +} // namespace GeNN::Transpiler::Utils diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index aca8e41c94..70568914c1 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -14,7 +14,7 @@ // Forward declarations namespace GeNN::Transpiler { -class ErrorHandler; +class ErrorHandlerBase; struct Token; } @@ -40,11 +40,11 @@ class EnvironmentBase //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ - virtual void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandler &errorHandler) = 0; + virtual void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandlerBase &errorHandler) = 0; virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, - ErrorHandler &errorHandler, bool initializer = false) = 0; - virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandler &errorHandler) = 0; - virtual const Type::QualifiedType &getType(const Token &name, ErrorHandler &errorHandler) = 0; + ErrorHandlerBase &errorHandler, bool initializer = false) = 0; + virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) = 0; + virtual const Type::QualifiedType &getType(const Token &name, ErrorHandlerBase &errorHandler) = 0; protected: //--------------------------------------------------------------------------- @@ -52,9 +52,9 @@ class EnvironmentBase //--------------------------------------------------------------------------- const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &existingType, const Type::QualifiedType &assignedType, - ErrorHandler &errorHandler, bool initializer = false) const; + ErrorHandlerBase &errorHandler, bool initializer = false) const; const Type::QualifiedType &incDec(const Token &name, Token::Type op, - const Type::QualifiedType &existingType, ErrorHandler &errorHandler) const; + const Type::QualifiedType &existingType, ErrorHandlerBase &errorHandler) const; }; //--------------------------------------------------------------------------- @@ -107,11 +107,11 @@ class EnvironmentExternal : public EnvironmentBase //--------------------------------------------------------------------------- // EnvironmentBase virtuals //--------------------------------------------------------------------------- - virtual void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandler &errorHandler) final; + virtual void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandlerBase 
&errorHandler) final; virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, - ErrorHandler &errorHandler, bool initializer = false) final; - virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandler &errorHandler) final; - virtual const Type::QualifiedType &getType(const Token &name, ErrorHandler &errorHandler) final; + ErrorHandlerBase &errorHandler, bool initializer = false) final; + virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final; + virtual const Type::QualifiedType &getType(const Token &name, ErrorHandlerBase &errorHandler) final; private: //--------------------------------------------------------------------------- @@ -127,9 +127,9 @@ class EnvironmentExternal : public EnvironmentBase //--------------------------------------------------------------------------- // Free functions //--------------------------------------------------------------------------- -void typeCheck(const Statement::StatementList &statements, EnvironmentExternal &environment, - ErrorHandler &errorHandler); +void typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment, + ErrorHandlerBase &errorHandler); -Type::QualifiedType typeCheck(const Expression::Base *expression, EnvironmentExternal &environment, - ErrorHandler &errorHandler); +Type::QualifiedType typeCheck(const Expression::Base *expression, EnvironmentBase &environment, + ErrorHandlerBase &errorHandler); } // namespace MiniParse::GeNN::Transpiler diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index a5a42caa3e..8255fbd9e2 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -3,6 +3,12 @@ // GeNN code generator includes #include "code_generator/modelSpecMerged.h" +// GeNN transpiler includes +#include "transpiler/scanner.h" +#include "transpiler/typeChecker.h" +#include "transpiler/parser.h" + + using namespace GeNN; using namespace GeNN::CodeGenerator; @@ -11,6 +17,7 @@ using namespace GeNN::CodeGenerator; //-------------------------------------------------------------------------- namespace { + template void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, const ModelSpecMerged &modelMerged, const std::string &index, diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj index fc8fb92f8e..9f40ed920d 100644 --- a/src/genn/genn/genn.vcxproj +++ b/src/genn/genn/genn.vcxproj @@ -55,6 +55,7 @@ + diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index c4a9577ff6..5e093a90c6 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -38,7 +38,7 @@ class ParseError class ParserState { public: - ParserState(const std::vector &tokens, ErrorHandler &errorHandler) + ParserState(const std::vector &tokens, ErrorHandlerBase &errorHandler) : m_Current(0), m_Tokens(tokens), m_ErrorHandler(errorHandler) {} @@ -136,7 +136,7 @@ class ParserState const std::vector &m_Tokens; - ErrorHandler &m_ErrorHandler; + ErrorHandlerBase &m_ErrorHandler; }; @@ -843,7 +843,7 @@ std::unique_ptr parseBlockItem(ParserState &parserState) //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Parser { -Expression::ExpressionPtr parseExpression(const std::vector &tokens, ErrorHandler &errorHandler) 
+Expression::ExpressionPtr parseExpression(const std::vector &tokens, ErrorHandlerBase &errorHandler) { ParserState parserState(tokens, errorHandler); @@ -855,7 +855,7 @@ Expression::ExpressionPtr parseExpression(const std::vector &tokens, Erro } } //--------------------------------------------------------------------------- -Statement::StatementList parseBlockItemList(const std::vector &tokens, ErrorHandler &errorHandler) +Statement::StatementList parseBlockItemList(const std::vector &tokens, ErrorHandlerBase &errorHandler) { ParserState parserState(tokens, errorHandler); std::vector> statements; @@ -866,7 +866,7 @@ Statement::StatementList parseBlockItemList(const std::vector &tokens, Er return statements; } //--------------------------------------------------------------------------- -const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, ErrorHandler &errorHandler) +const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, ErrorHandlerBase &errorHandler) { ParserState parserState(tokens, errorHandler); bool pointerFound = false; diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc index cbd3e43e86..63327228fd 100644 --- a/src/genn/genn/transpiler/scanner.cc +++ b/src/genn/genn/transpiler/scanner.cc @@ -58,7 +58,7 @@ const std::map, std::function &tokens) //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Scanner { -std::vector scanSource(const std::string_view &source, ErrorHandler &errorHandler) +std::vector scanSource(const std::string_view &source, ErrorHandlerBase &errorHandler) { std::vector tokens; diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 97ea6f43b4..6ea97d2721 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -37,7 +37,7 @@ class EnvironmentInternal : public EnvironmentBase //--------------------------------------------------------------------------- // EnvironmentBase virtuals //--------------------------------------------------------------------------- - virtual void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandler &errorHandler) final + virtual void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandlerBase &errorHandler) final { if(!m_Types.try_emplace(name.lexeme, qualifiedType).second) { errorHandler.error(name, "Redeclaration of variable"); @@ -46,7 +46,7 @@ class EnvironmentInternal : public EnvironmentBase } virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, - ErrorHandler &errorHandler, bool initializer = false) final + ErrorHandlerBase &errorHandler, bool initializer = false) final { // If type isn't found auto existingType = m_Types.find(name.lexeme); @@ -59,7 +59,7 @@ class EnvironmentInternal : public EnvironmentBase return EnvironmentBase::assign(name, op, existingType->second, assignedType, errorHandler, initializer); } - virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandler &errorHandler) final + virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final { // If type isn't found auto existingType = m_Types.find(name.lexeme); @@ -71,7 +71,7 @@ class EnvironmentInternal : public EnvironmentBase return EnvironmentBase::incDec(name, op, existingType->second, errorHandler); } - virtual const 
Type::QualifiedType &getType(const Token &name, ErrorHandler &errorHandler) final + virtual const Type::QualifiedType &getType(const Token &name, ErrorHandlerBase &errorHandler) final { auto type = m_Types.find(std::string{name.lexeme}); if(type == m_Types.end()) { @@ -96,7 +96,7 @@ class EnvironmentInternal : public EnvironmentBase class Visitor : public Expression::Visitor, public Statement::Visitor { public: - Visitor(ErrorHandler &errorHandler) + Visitor(ErrorHandlerBase &errorHandler) : m_Environment(nullptr), m_QualifiedType{nullptr, false, false}, m_ErrorHandler(errorHandler), m_InLoop(false), m_InSwitch(false) { @@ -580,7 +580,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor EnvironmentInternal *m_Environment; Type::QualifiedType m_QualifiedType; - ErrorHandler &m_ErrorHandler; + ErrorHandlerBase &m_ErrorHandler; bool m_InLoop; bool m_InSwitch; }; @@ -591,7 +591,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- const Type::QualifiedType &EnvironmentBase::assign(const Token &name, Token::Type op, const Type::QualifiedType &existingType, const Type::QualifiedType &assignedType, - ErrorHandler &errorHandler, bool initializer) const + ErrorHandlerBase &errorHandler, bool initializer) const { // If existing type is a constant numeric value or if it's a constant pointer give errors auto numericExistingType = dynamic_cast(existingType.type); @@ -673,7 +673,7 @@ const Type::QualifiedType &EnvironmentBase::assign(const Token &name, Token::Typ } //--------------------------------------------------------------------------- const Type::QualifiedType &EnvironmentBase::incDec(const Token &name, Token::Type, - const Type::QualifiedType &existingType, ErrorHandler &errorHandler) const + const Type::QualifiedType &existingType, ErrorHandlerBase &errorHandler) const { // If existing type is a constant numeric value or if it's a constant pointer give errors auto numericExistingType = dynamic_cast(existingType.type); @@ -693,14 +693,14 @@ const Type::QualifiedType &EnvironmentBase::incDec(const Token &name, Token::Typ //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker::EnvironmentExternal //--------------------------------------------------------------------------- -void EnvironmentExternal::define(const Token &name, const Type::QualifiedType &, ErrorHandler &errorHandler) +void EnvironmentExternal::define(const Token &name, const Type::QualifiedType &, ErrorHandlerBase &errorHandler) { errorHandler.error(name, "Cannot declare variable in external environment"); throw TypeCheckError(); } //--------------------------------------------------------------------------- const Type::QualifiedType &EnvironmentExternal::assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, - ErrorHandler &errorHandler, bool initializer) + ErrorHandlerBase &errorHandler, bool initializer) { // If type isn't found auto existingType = m_Types.find(name.lexeme); @@ -713,7 +713,7 @@ const Type::QualifiedType &EnvironmentExternal::assign(const Token &name, Token: return EnvironmentBase::assign(name, op, existingType->second, assignedType, errorHandler, initializer); } //--------------------------------------------------------------------------- -const Type::QualifiedType &EnvironmentExternal::incDec(const Token &name, Token::Type op, ErrorHandler &errorHandler) +const Type::QualifiedType 
&EnvironmentExternal::incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) { auto existingType = m_Types.find(name.lexeme); if(existingType == m_Types.end()) { @@ -726,7 +726,7 @@ const Type::QualifiedType &EnvironmentExternal::incDec(const Token &name, Token: } //--------------------------------------------------------------------------- -const Type::QualifiedType &EnvironmentExternal::getType(const Token &name, ErrorHandler &errorHandler) +const Type::QualifiedType &EnvironmentExternal::getType(const Token &name, ErrorHandlerBase &errorHandler) { auto type = m_Types.find(std::string{name.lexeme}); if(type == m_Types.end()) { @@ -741,8 +741,8 @@ const Type::QualifiedType &EnvironmentExternal::getType(const Token &name, Error //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker //--------------------------------------------------------------------------- -void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &statements, EnvironmentExternal &environment, - ErrorHandler &errorHandler) +void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment, + ErrorHandlerBase &errorHandler) { Visitor visitor(errorHandler); EnvironmentInternal internalEnvironment(environment); @@ -750,8 +750,8 @@ void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &st } //--------------------------------------------------------------------------- Type::QualifiedType GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, - EnvironmentExternal &environment, - ErrorHandler &errorHandler) + EnvironmentBase &environment, + ErrorHandlerBase &errorHandler) { Visitor visitor(errorHandler); EnvironmentInternal internalEnvironment(environment); diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index 233d9919e8..6c2aa54a54 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -47,29 +47,6 @@ const std::unordered_map uns {Type::Int16::getInstance(), Type::Uint16::getInstance()}, {Type::Int32::getInstance(), Type::Uint32::getInstance()} }; - -//---------------------------------------------------------------------------- -// SimpleErrorHandler -//---------------------------------------------------------------------------- -//! 
Simple error handler used for type parsing - just logs to transpiler log channel -class SimpleErrorHandler : public Transpiler::ErrorHandler -{ -public: - virtual void error(size_t line, std::string_view message) final - { - LOGE_TRANSPILER << "Error: " << message; - } - - virtual void error(const Transpiler::Token &token, std::string_view message) final - { - if(token.type == Transpiler::Token::Type::END_OF_FILE) { - LOGE_TRANSPILER << "Error at end: " << message; - } - else { - LOGE_TRANSPILER << "Error at '" << token.lexeme << "': " << message; - } - } -}; } // Anonymous namespace //---------------------------------------------------------------------------- @@ -100,11 +77,18 @@ const NumericBase *parseNumeric(std::string_view typeString) using namespace Transpiler; // Scan type - SimpleErrorHandler errorHandler; + SingleLineErrorHandler errorHandler; const auto tokens = Scanner::scanSource(typeString, errorHandler); - // Parse type, cast to numeric and return + // Parse type and cast to numeric const auto *type = dynamic_cast(Parser::parseType(tokens, false, errorHandler)); + + // If an error was encountered while scanning or parsing, throw exception + if (errorHandler.hasError()) { + throw std::runtime_error("Error parsing type"); + } + + // If tokens did not contain a valid numeric type, throw exception if (!type) { throw std::runtime_error("Unable to parse type"); } diff --git a/tests/unit/scanner.cc b/tests/unit/scanner.cc index 5c24f9fa5e..9151dcdff9 100644 --- a/tests/unit/scanner.cc +++ b/tests/unit/scanner.cc @@ -12,7 +12,7 @@ using namespace GeNN::Transpiler; //-------------------------------------------------------------------------- namespace { -class TestErrorHandler : public ErrorHandler +class TestErrorHandler : public ErrorHandlerBase { public: TestErrorHandler() : m_Error(false) diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index 47a0b62d95..3cbcbf5f77 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -18,7 +18,7 @@ using namespace GeNN::Transpiler; //-------------------------------------------------------------------------- namespace { -class TestErrorHandler : public ErrorHandler +class TestErrorHandler : public ErrorHandlerBase { public: TestErrorHandler() : m_Error(false) From c3f591323706c9ca1e05b7ac7d0118103871c836 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 12 Jan 2023 17:05:07 +0000 Subject: [PATCH 032/725] started adding a type environment to custom update group merged (no weirdness in that one!) 
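The idea is that GroupMergedTypeEnvironment resolves identifiers against the merged group's fields, falling back to an optional enclosing environment. Once custom update code is routed through the transpiler, the flow should look roughly like the following sketch (illustrative only - error handling is simplified and the exact call sites are not final):

    // Scan and parse the user-supplied update code
    SingleLineErrorHandler errorHandler;
    const auto tokens = Scanner::scanSource(cm->getUpdateCode(), errorHandler);
    const auto statements = Parser::parseBlockItemList(tokens, errorHandler);

    // Type-check it against the group's type environment, which resolves
    // identifiers to fields and defers unknown names to any enclosing scope
    TypeChecker::typeCheck(statements, typeEnvironment, errorHandler);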
--- .../code_generator/customUpdateGroupMerged.cc | 151 +++++++++++++++++- 1 file changed, 149 insertions(+), 2 deletions(-) diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 8255fbd9e2..cbd98fd81c 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -11,12 +11,155 @@ using namespace GeNN; using namespace GeNN::CodeGenerator; +using namespace GeNN::Transpiler; //-------------------------------------------------------------------------- // Anonymous namespace //-------------------------------------------------------------------------- namespace { +template +class GroupMergedTypeEnvironment : public TypeChecker::EnvironmentBase +{ +public: + GroupMergedTypeEnvironment(G &groupMerged, const Type::NumericBase *scalarType, + TypeChecker::EnvironmentBase *enclosing = nullptr) + : m_GroupMerged(groupMerged), m_ScalarType(scalarType), m_Enclosing(enclosing) + { + } + + //--------------------------------------------------------------------------- + // EnvironmentBase virtuals + //--------------------------------------------------------------------------- + virtual void define(const Token &name, const Type::QualifiedType &, ErrorHandlerBase &errorHandler) final + { + errorHandler.error(name, "Cannot declare variable in external environment"); + throw TypeCheckError(); + } + + virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, + ErrorHandlerBase &errorHandler, bool initializer) final + { + // If type isn't found + auto existingType = m_Types.find(name.lexeme); + if(existingType == m_Types.end()) { + if(m_Enclosing) { + return m_Enclosing->assign(name, op, assignedType, errorHandler, initializer); + } + else { + errorHandler.error(name, "Undefined variable"); + throw TypeCheckError(); + } + } + + // Perform standard type-checking logic + return EnvironmentBase::assign(name, op, existingType->second, assignedType, errorHandler, initializer); + } + + virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final + { + auto existingType = m_Types.find(name.lexeme); + if(existingType == m_Types.end()) { + if(m_Enclosing) { + return m_Enclosing->incDec(name, op, errorHandler); + } + else { + errorHandler.error(name, "Undefined variable"); + throw TypeCheckError(); + } + } + + // Perform standard type-checking logic + return EnvironmentBase::incDec(name, op, existingType->second, errorHandler); + + } + + virtual const Type::QualifiedType &getType(const Token &name, ErrorHandlerBase &errorHandler) final + { + auto type = m_Types.find(std::string{name.lexeme}); + if(type == m_Types.end()) { + if(m_Enclosing) { + return m_Enclosing->getType(name, errorHandler); + } + else { + errorHandler.error(name, "Undefined variable"); + throw TypeCheckError(); + } + } + else { + return type->second; + } + } + + //--------------------------------------------------------------------------- + // Public API + //--------------------------------------------------------------------------- + void define(std::string_view name, const Type::Base *type, bool isConstValue = false, bool isConstPointer = false) + { + if(!m_Types.try_emplace(name, type, isConstValue, isConstPointer).second) { + throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); + } + } + template + void define(std::string_view name, bool isConstValue = false, bool 
isConstPointer = false) { define(name, T::getInstance(), isConstValue, isConstPointer); } + + template + void addHeterogeneousParams(const Snippet::Base::StringVec &paramNames, const std::string &suffix, + P getParamValues, H isHeterogeneous) + { + // Loop through params + for(const auto &p : paramNames) { + // Define constant + define(p + suffix, m_ScalarType, true); + + // If parameter is heterogeneous + if((m_GroupMerged.*isHeterogeneous)(p)) { + // Add field + m_GroupMerged.addScalarField(p + suffix, + [p, getParamValues](const G &g, size_t) + { + const auto &values = getParamValues(g); + return Utils::writePreciseString(values.at(p)); + }); + } + } + } + + template + void addHeterogeneousDerivedParams(const Snippet::Base::DerivedParamVec &derivedParams, const std::string &suffix, + D getDerivedParamValues, H isHeterogeneous) + { + // Loop through derived params + for(const auto &d : derivedParams) { + // Define constant + define(d.name + suffix, m_ScalarType, true); + + // If derived parameter is heterogeneous + if((m_GroupMerged.*isHeterogeneous)(d.name)) { + // Add field + m_GroupMerged.addScalarField(d.name + suffix, + [d, getDerivedParamValues](const G &g, size_t) + { + const auto &values = getDerivedParamValues(g); + return Utils::writePreciseString(values.at(d.name)); + }); + } + } + } + +private: + //--------------------------------------------------------------------------- + // Members + //--------------------------------------------------------------------------- + G &m_GroupMerged; + const Type::NumericBase *m_ScalarType; + TypeChecker::EnvironmentBase *m_Enclosing; + + std::unordered_map m_Types; +}; template void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, @@ -113,6 +256,10 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string const std::vector> &groups) : GroupMerged(index, precision, groups) { + // Create type environment + // **TEMP** parse precision to get scalar type + GroupMergedTypeEnvironment typeEnvironment(*this, Type::parseNumeric(precision)); + addField("unsigned int", "size", [](const CustomUpdateInternal &c, size_t) { return std::to_string(c.getSize()); }); @@ -127,13 +274,13 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string // Add heterogeneous custom update model parameters const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); - addHeterogeneousParams( + typeEnvironment.addHeterogeneousParams( cm->getParamNames(), "", [](const CustomUpdateInternal &cg) { return cg.getParams(); }, &CustomUpdateGroupMerged::isParamHeterogeneous); // Add heterogeneous weight update model derived parameters - addHeterogeneousDerivedParams( + typeEnvironment.addHeterogeneousDerivedParams( cm->getDerivedParams(), "", [](const CustomUpdateInternal &cg) { return cg.getDerivedParams(); }, &CustomUpdateGroupMerged::isDerivedParamHeterogeneous); From 04c9f489c85cad2e5ba0afe363fd090414c9861c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 12 Jan 2023 18:02:12 +0000 Subject: [PATCH 033/725] added numeric pointer parsing --- include/genn/genn/type.h | 3 +++ src/genn/genn/type.cc | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 6541293a87..ddfe7d414a 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -311,6 +311,9 @@ DECLARE_FOREIGN_FUNCTION_TYPE(Sqrt, Double, Double); //! Parse a numeric type const NumericBase *parseNumeric(std::string_view typeString); +//! 
+//! Parse a numeric pointer type
+const NumericPtrBase *parseNumericPtr(std::string_view typeString);
+
 //! Look up numeric type based on set of type specifiers
 const NumericBase *getNumericType(const std::set<std::string> &typeSpecifiers);

diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc
index 6c2aa54a54..a70b8f6c49 100644
--- a/src/genn/genn/type.cc
+++ b/src/genn/genn/type.cc
@@ -95,6 +95,29 @@ const NumericBase *parseNumeric(std::string_view typeString)
     return type;
 }
 //----------------------------------------------------------------------------
+const NumericPtrBase *parseNumericPtr(std::string_view typeString)
+{
+    using namespace Transpiler;
+
+    // Scan type
+    SingleLineErrorHandler errorHandler;
+    const auto tokens = Scanner::scanSource(typeString, errorHandler);
+
+    // Parse type and cast to numeric pointer
+    const auto *type = dynamic_cast<const NumericPtrBase*>(Parser::parseType(tokens, true, errorHandler));
+
+    // If an error was encountered while scanning or parsing, throw exception
+    if (errorHandler.hasError()) {
+        throw std::runtime_error("Error parsing type");
+    }
+
+    // If tokens did not contain a valid numeric pointer type, throw exception
+    if (!type) {
+        throw std::runtime_error("Unable to parse type");
+    }
+    return type;
+}
+//----------------------------------------------------------------------------
 const NumericBase *getNumericType(const std::set<std::string> &typeSpecifiers)
 {
     const auto type = numericTypes.find(typeSpecifiers);

From 4c5294f0ecbfaedcdaacce5d1e215379d35d2d78 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 12 Jan 2023 18:03:07 +0000
Subject: [PATCH 034/725] Replaced BackendBase::getMergedGroupFieldHostType with
 BackendBase::getMergedGroupFieldHostTypeName which converts a Type::Base
 pointer to a typename (normally by calling getTypeName!)

---
 include/genn/backends/cuda/backend.h                | 2 +-
 include/genn/backends/opencl/backend.h              | 2 +-
 include/genn/backends/single_threaded_cpu/backend.h | 2 +-
 include/genn/genn/code_generator/backendBase.h      | 2 +-
 src/genn/backends/cuda/backend.cc                   | 4 ++--
 src/genn/backends/opencl/backend.cc                 | 8 ++++----
 src/genn/backends/single_threaded_cpu/backend.cc    | 4 ++--
 7 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h
index 30a91ec87b..5044e1cd7e 100644
--- a/include/genn/backends/cuda/backend.h
+++ b/include/genn/backends/cuda/backend.h
@@ -210,7 +210,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT
                                               const std::string &egpName) const override;

     //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g.
OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostType(const std::string &type) const override; + virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const override; //! When generating merged structures what type to use for simulation RNGs virtual std::string getMergedGroupSimRNGType() const override { return "clrngLfsr113HostStream"; } diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 0190308a19..27c3fe8927 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -82,7 +82,7 @@ class BACKEND_EXPORT Backend : public BackendBase const std::string &egpName) const override; //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostType(const std::string &type) const override; + virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const override; //! When generating merged structures what type to use for simulation RNGs virtual std::string getMergedGroupSimRNGType() const override; diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index e775d27c3d..046a208f2b 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -271,7 +271,7 @@ class GENN_EXPORT BackendBase const std::string &egpName) const = 0; //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostType(const std::string &type) const = 0; + virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const = 0; //! 
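
The switch from a string parameter to a Type::Base pointer means backends can now dispatch on the type object itself instead of re-parsing a string. A rough sketch of the intent, mirroring the OpenCL override further down in this patch (the MyBackend class is illustrative, not part of GeNN):

    std::string MyBackend::getMergedGroupFieldHostTypeName(const Type::Base *type) const
    {
        // Device pointer fields are represented on the host by buffer handles,
        // detected with a dynamic_cast rather than string matching
        if(dynamic_cast<const Type::NumericPtrBase*>(type)) {
            return "cl::Buffer";
        }
        // Other fields keep the name of their language-level type
        else {
            return type->getTypeName();
        }
    }

    //!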
When generating merged structures what type to use for simulation RNGs virtual std::string getMergedGroupSimRNGType() const = 0; diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index e205802a6e..92fa08a8ca 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -1748,9 +1748,9 @@ void Backend::genMergedExtraGlobalParamPush(CodeStream &os, const std::string &s os << ", (sizeof(" << structName << ") * (" << groupIdx << ")) + offsetof(" << structName << ", " << fieldName << ")));" << std::endl; } //-------------------------------------------------------------------------- -std::string Backend::getMergedGroupFieldHostType(const std::string &type) const +std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) const { - return type; + return type->getTypeName(); } //-------------------------------------------------------------------------- void Backend::genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const diff --git a/src/genn/backends/opencl/backend.cc b/src/genn/backends/opencl/backend.cc index d383f4da60..ca11216228 100644 --- a/src/genn/backends/opencl/backend.cc +++ b/src/genn/backends/opencl/backend.cc @@ -2057,18 +2057,18 @@ void Backend::genMergedExtraGlobalParamPush(CodeStream &os, const std::string &s os << "CHECK_OPENCL_ERRORS(commandQueue.enqueueNDRangeKernel(" << kernelName << ", cl::NullRange, globalWorkSize, localWorkSize));" << std::endl; } //-------------------------------------------------------------------------- -std::string Backend::getMergedGroupFieldHostType(const std::string &type) const +std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) const { // If type is a pointer, on the host it is represented by an OpenCL buffer - if(GeNN::Utils::isTypePointerToPointer(type)) { + /*if(GeNN::Utils::isTypePointerToPointer(type)) { return "cl::Buffer*"; } - else if(GeNN::Utils::isTypePointer(type)) { + else */if(dynamic_cast(type)) { return "cl::Buffer"; } // Otherwise, type remains the same else { - return type; + return type->getTypeName(); } } //-------------------------------------------------------------------------- diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index cdb3498520..e2702088c2 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1367,9 +1367,9 @@ void Backend::genMergedExtraGlobalParamPush(CodeStream &os, const std::string &s os << "merged" << suffix << "Group" << mergedGroupIdx << "[" << groupIdx << "]." 
<< fieldName << " = " << egpName << ";" << std::endl; } //-------------------------------------------------------------------------- -std::string Backend::getMergedGroupFieldHostType(const std::string &type) const +std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) const { - return type; + return type->getTypeName(); } //-------------------------------------------------------------------------- std::string Backend::getMergedGroupSimRNGType() const From d1c28fc9f1e9ab7a65bee7ab858ac838f79730cc Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 12 Jan 2023 18:03:38 +0000 Subject: [PATCH 035/725] hacking at GroupMerged to use type system for fields --- .../genn/genn/code_generator/groupMerged.h | 49 ++++++++++--------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 0dd3e89cea..d9effbf26c 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -13,6 +13,7 @@ #include "customUpdateInternal.h" #include "neuronGroupInternal.h" #include "synapseGroupInternal.h" +#include "type.h" // GeNN code generator includes #include "code_generator/backendBase.h" @@ -62,11 +63,11 @@ class GroupMerged //------------------------------------------------------------------------ typedef G GroupInternal; typedef std::function GetFieldValueFunc; - typedef std::tuple Field; - + typedef std::tuple Field; + // **HACK** type should come in as type not string GroupMerged(size_t index, const std::string &precision, const std::vector> groups) - : m_Index(index), m_LiteralSuffix((precision == "float") ? "f" : ""), m_Groups(std::move(groups)) + : m_Index(index), m_LiteralSuffix((precision == "float") ? "f" : ""), m_ScalarType(Type::parseNumeric(precision)), m_Groups(std::move(groups)) {} //------------------------------------------------------------------------ @@ -89,13 +90,14 @@ class GroupMerged //! Get group fields, sorted into order they will appear in struct std::vector getSortedFields(const BackendBase &backend) const { + // **TODO** size should come from type system itself - numerics are easy pointer size is a little trickier // Make a copy of fields and sort so largest come first. This should mean that due // to structure packing rules, significant memory is saved and estimate is more precise auto sortedFields = m_Fields; std::sort(sortedFields.begin(), sortedFields.end(), [&backend](const Field &a, const Field &b) { - return (backend.getSize(std::get<0>(a)) > backend.getSize(std::get<0>(b))); + return (backend.getSize(std::get<0>(a)->getTypeName()) > backend.getSize(std::get<0>(b)->getTypeName())); }); return sortedFields; @@ -113,11 +115,11 @@ class GroupMerged for(const auto &f : sortedFields) { // If field is a pointer and not marked as being a host field // (in which case the backend should leave its type alone!) 
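
Sorting fields by decreasing size before emitting the struct is what makes the padding estimate tight. For illustration, on a typical 64-bit target (assuming 8-byte pointers and 4-byte unsigned int):

    // Natural declaration order: 24 bytes
    struct Unsorted
    {
        unsigned int a;   // offset 0, then 4 bytes of padding before the pointer
        double *b;        // offset 8
        unsigned int c;   // offset 16, then 4 bytes of tail padding
    };

    // Largest-first order, as produced by getSortedFields(): 16 bytes
    struct Sorted
    {
        double *b;        // offset 0
        unsigned int a;   // offset 8
        unsigned int c;   // offset 12
    };
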
- const std::string &type = std::get<0>(f); - if(Utils::isTypePointer(type) && !(std::get<3>(f) & GroupMergedFieldType::HOST)) { + const auto *type = std::get<0>(f); + if(dynamic_cast(type) && !(std::get<3>(f) & GroupMergedFieldType::HOST)) { // If we are generating a host structure, allow the backend to override the type if(host) { - os << backend.getMergedGroupFieldHostType(type); + os << backend.getMergedGroupFieldHostTypeName(type); } // Otherwise, allow the backend to add a prefix else { @@ -126,7 +128,7 @@ class GroupMerged } // Otherwise, leave the type alone else { - os << type; + os << type->getTypeName(); } os << " " << std::get<1>(f) << ";" << std::endl; } @@ -142,7 +144,7 @@ class GroupMerged const auto sortedFields = getSortedFields(backend); for(size_t fieldIndex = 0; fieldIndex < sortedFields.size(); fieldIndex++) { const auto &f = sortedFields[fieldIndex]; - os << backend.getMergedGroupFieldHostType(std::get<0>(f)) << " " << std::get<1>(f); + os << backend.getMergedGroupFieldHostTypeName(std::get<0>(f)) << " " << std::get<1>(f); if(fieldIndex != (sortedFields.size() - 1)) { os << ", "; } @@ -157,7 +159,8 @@ class GroupMerged const auto sortedFields = getSortedFields(backend); for(const auto &f : sortedFields) { // Add size of field to total - const size_t fieldSize = backend.getSize(std::get<0>(f)); + // **TODO** size should be built into type system + const size_t fieldSize = backend.getSize(std::get<0>(f)->getTypeName()); structSize += fieldSize; // Update largest field size @@ -195,7 +198,7 @@ class GroupMerged } } -protected: +//protected: //------------------------------------------------------------------------ // Protected methods //------------------------------------------------------------------------ @@ -239,7 +242,7 @@ class GroupMerged }); } - void addField(const std::string &type, const std::string &name, GetFieldValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) + void addField(const Type::Base *type, const std::string &name, GetFieldValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) { // Add field to data structure m_Fields.emplace_back(type, name, getFieldValue, fieldType); @@ -247,7 +250,7 @@ class GroupMerged void addScalarField(const std::string &name, GetFieldValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) { - addField("scalar", name, + addField(m_ScalarType, name, [getFieldValue, this](const G &g, size_t i) { return getFieldValue(g, i) + m_LiteralSuffix; @@ -255,10 +258,9 @@ class GroupMerged fieldType); } - void addPointerField(const std::string &type, const std::string &name, const std::string &prefix) + void addPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) { - assert(!Utils::isTypePointer(type)); - addField(type + "*", name, [prefix](const G &g, size_t) { return prefix + g.getName(); }); + addField(type->getPointerType(), name, [prefix](const G &g, size_t) { return prefix + g.getName(); }); } @@ -266,7 +268,7 @@ class GroupMerged { // Loop through variables for(const auto &v : vars) { - addPointerField(v.type, v.name, arrayPrefix + v.name); + addPointerField(Type::parseNumeric(v.type), v.name, arrayPrefix + v.name); } } @@ -275,7 +277,7 @@ class GroupMerged { // Loop through variables for(const auto &v : varReferences) { - addField(v.type + "*", v.name, + addField(Type::parseNumeric(v.type)->getPointerType(), v.name, [getVarRefFn, arrayPrefix, v](const G &g, size_t) { const auto 
varRef = getVarRefFn(g).at(v.name); @@ -287,9 +289,9 @@ class GroupMerged void addEGPs(const Snippet::Base::EGPVec &egps, const std::string &arrayPrefix, const std::string &varName = "") { for(const auto &e : egps) { - const std::string prefix = Utils::isTypePointer(e.type) ? arrayPrefix : ""; - addField(e.type, e.name + varName, - [e, prefix, varName](const G &g, size_t) { return prefix + e.name + varName + g.getName(); }, + assert(Utils::isTypePointer(e.type)); + addField(Type::parseNumericPtr(e.type), e.name + varName, + [e, arrayPrefix, varName](const G &g, size_t) { return arrayPrefix + e.name + varName + g.getName(); }, GroupMergedFieldType::DYNAMIC); } } @@ -461,9 +463,9 @@ class GroupMerged // Loop through fields again to generate any EGP pushing functions that are require for(const auto &f : sortedFields) { // If this field is a dynamic pointer - if((std::get<3>(f) & GroupMergedFieldType::DYNAMIC) && Utils::isTypePointer(std::get<0>(f))) { + if((std::get<3>(f) & GroupMergedFieldType::DYNAMIC) && dynamic_cast(std::get<0>(f))) { definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << getIndex() << std::get<1>(f) << "ToDevice(unsigned int idx, "; - definitionsInternalFunc << backend.getMergedGroupFieldHostType(std::get<0>(f)) << " value);" << std::endl; + definitionsInternalFunc << backend.getMergedGroupFieldHostTypeName(std::get<0>(f)) << " value);" << std::endl; } // Raise error if this field is a host field but this isn't a host structure @@ -526,6 +528,7 @@ class GroupMerged //------------------------------------------------------------------------ const size_t m_Index; const std::string m_LiteralSuffix; + const Type::Base *m_ScalarType; std::string m_MemorySpace; std::vector m_Fields; std::vector> m_Groups; From c8afc46e0c801b95f1f7747bcbeaa9ca6bfa4ee6 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 12 Jan 2023 18:04:08 +0000 Subject: [PATCH 036/725] rough integration of type environment into merged custom update groups (simplest choice) --- .../code_generator/customUpdateGroupMerged.h | 8 +- .../code_generator/customUpdateGroupMerged.cc | 118 ++++++++++++------ 2 files changed, 86 insertions(+), 40 deletions(-) diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index c394377c7b..21f3028a34 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -162,18 +162,22 @@ class CustomUpdateHostReductionGroupMergedBase : public GroupMerged const std::vector> &groups) : GroupMerged(index, precision, groups) { + // Create type environment + // **TEMP** parse precision to get scalar type + //GroupMergedTypeEnvironment typeEnvironment(*this, Type::parseNumeric(precision)); + // Loop through variables and add pointers if they are reduction targets const CustomUpdateModels::Base *cm = this->getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { if(v.access & VarAccessModeAttribute::REDUCE) { - this->addPointerField(v.type, v.name, backend.getDeviceVarPrefix() + v.name); + this->addPointerField(Type::parseNumeric(v.type), v.name, backend.getDeviceVarPrefix() + v.name); } } // Loop through variable references and add pointers if they are reduction targets for(const auto &v : cm->getVarRefs()) { if(v.access & VarAccessModeAttribute::REDUCE) { - this->addPointerField(v.type, v.name, backend.getDeviceVarPrefix() + v.name); + this->addPointerField(Type::parseNumeric(v.type), v.name, 
backend.getDeviceVarPrefix() + v.name); } } } diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index cbd98fd81c..1c86927564 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -4,9 +4,10 @@ #include "code_generator/modelSpecMerged.h" // GeNN transpiler includes +#include "transpiler/errorHandler.h" +#include "transpiler/parser.h" #include "transpiler/scanner.h" #include "transpiler/typeChecker.h" -#include "transpiler/parser.h" using namespace GeNN; @@ -34,7 +35,7 @@ class GroupMergedTypeEnvironment : public TypeChecker::EnvironmentBase virtual void define(const Token &name, const Type::QualifiedType &, ErrorHandlerBase &errorHandler) final { errorHandler.error(name, "Cannot declare variable in external environment"); - throw TypeCheckError(); + throw TypeChecker::TypeCheckError(); } virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, @@ -48,7 +49,7 @@ class GroupMergedTypeEnvironment : public TypeChecker::EnvironmentBase } else { errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); + throw TypeChecker::TypeCheckError(); } } @@ -65,7 +66,7 @@ class GroupMergedTypeEnvironment : public TypeChecker::EnvironmentBase } else { errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); + throw TypeChecker::TypeCheckError(); } } @@ -83,7 +84,7 @@ class GroupMergedTypeEnvironment : public TypeChecker::EnvironmentBase } else { errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); + throw TypeChecker::TypeCheckError(); } } else { @@ -116,10 +117,10 @@ class GroupMergedTypeEnvironment : public TypeChecker::EnvironmentBase define(p + suffix, m_ScalarType, true); // If parameters is heterogeneous - if((static_cast(this)->*isHeterogeneous)(p)) { + if((static_cast(m_GroupMerged).*isHeterogeneous)(p)) { // Add field - m_GroupMerged->addScalarField(p + suffix, - [p, getParamValues](const G &g, size_t) + m_GroupMerged.addScalarField(p + suffix, + [p, getParamValues](const typename G::GroupInternal &g, size_t) { const auto &values = getParamValues(g); return Utils::writePreciseString(values.at(p)); @@ -135,21 +136,58 @@ class GroupMergedTypeEnvironment : public TypeChecker::EnvironmentBase // Loop through derived params for(const auto &d : derivedParams) { // If parameters isn't homogeneous - if((static_cast(this)->*isHeterogeneous)(d.name)) { + if((static_cast(m_GroupMerged).*isHeterogeneous)(d.name)) { // Define constant - define(p + suffix, m_ScalarType, true); + define(d.name + suffix, m_ScalarType, true); // Add field - addScalarField(d.name + suffix, - [d, getDerivedParamValues](const G &g, size_t) - { - const auto &values = getDerivedParamValues(g); - return Utils::writePreciseString(values.at(d.name)); - }); + m_GroupMerged.addScalarField(d.name + suffix, + [d, getDerivedParamValues](const typename G::GroupInternal &g, size_t) + { + const auto &values = getDerivedParamValues(g); + return Utils::writePreciseString(values.at(d.name)); + }); } } } + void addVars(const Models::Base::VarVec &vars, const std::string &arrayPrefix) + { + // Loop through variables + for(const auto &v : vars) { + const auto *type = Type::parseNumeric(v.type); + define(v.name, type, (v.access & VarAccessModeAttribute::READ_ONLY)); + m_GroupMerged.addPointerField(type, v.name, arrayPrefix + v.name); + } + } + + template + void addVarReferences(const 
Models::Base::VarRefVec &varReferences, const std::string &arrayPrefix, V getVarRefFn) + { + // Loop through variables + for(const auto &v : varReferences) { + const auto *type = Type::parseNumeric(v.type); + define(v.name, type, (v.access & VarAccessModeAttribute::READ_ONLY)); + m_GroupMerged.addField(type->getPointerType(), v.name, + [getVarRefFn, arrayPrefix, v](const typename G::GroupInternal &g, size_t) + { + const auto varRef = getVarRefFn(g).at(v.name); + return arrayPrefix + varRef.getVar().name + varRef.getTargetName(); + }); + } + } + + void addEGPs(const Snippet::Base::EGPVec &egps, const std::string &arrayPrefix, const std::string &varName = "") + { + for(const auto &e : egps) { + const auto *type = Type::parseNumericPtr(e.type); + define(e.name, type); + m_GroupMerged.addField(type, e.name + varName, + [e, arrayPrefix, varName](const typename G::GroupInternal &g, size_t) { return arrayPrefix + e.name + varName + g.getName(); }, + GroupMergedFieldType::DYNAMIC); + } + } + private: //--------------------------------------------------------------------------- // Members @@ -258,14 +296,14 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string { // Create type environment // **TEMP** parse precision to get scalar type - GroupMergedTypeEnvironment typeEnvironment(this, Type::parseNumeric(precision)); + GroupMergedTypeEnvironment typeEnvironment(*this, Type::parseNumeric(precision)); - addField("unsigned int", "size", + addField(Type::Uint32::getInstance(), "size", [](const CustomUpdateInternal &c, size_t) { return std::to_string(c.getSize()); }); // If some variables are delayed, add delay pointer if(getArchetype().getDelayNeuronGroup() != nullptr) { - addField("unsigned int*", "spkQuePtr", + addField(Type::Uint32Ptr::getInstance(), "spkQuePtr", [&backend](const CustomUpdateInternal &cg, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getDelayNeuronGroup()->getName(); @@ -286,14 +324,14 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string &CustomUpdateGroupMerged::isDerivedParamHeterogeneous); // Add variables to struct - addVars(cm->getVars(), backend.getDeviceVarPrefix()); + typeEnvironment.addVars(cm->getVars(), backend.getDeviceVarPrefix()); // Add variable references to struct - addVarReferences(cm->getVarRefs(), backend.getDeviceVarPrefix(), + typeEnvironment.addVarReferences(cm->getVarRefs(), backend.getDeviceVarPrefix(), [](const CustomUpdateInternal &cg) { return cg.getVarReferences(); }); // Add EGPs to struct - this->addEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); + typeEnvironment.addEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); } //---------------------------------------------------------------------------- bool CustomUpdateGroupMerged::isParamHeterogeneous(const std::string ¶mName) const @@ -427,13 +465,17 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const const std::vector> &groups) : GroupMerged(index, precision, groups) { + // Create type environment + // **TEMP** parse precision to get scalar type + GroupMergedTypeEnvironment typeEnvironment(*this, Type::parseNumeric(precision)); + // If underlying synapse group has kernel weights if (getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL) { // Loop through kernel size dimensions for (size_t d = 0; d < getArchetype().getSynapseGroup()->getKernelSize().size(); d++) { // If this dimension has a heterogeneous size, add it to struct if 
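
Whether a quantity becomes a literal or a struct field follows the same rule throughout these merged groups: values shared by every group in the merged set are baked into the generated code as constants, and only values that differ between groups (the "heterogeneous" ones) earn a field in the merged struct. An illustrative sketch, with invented names:

    // Two custom updates merged into one group:
    //   tau == 20.0 in both -> generated code uses the literal 20.0 directly
    //   size differs        -> the merged struct gains a per-group field
    struct MergedCustomUpdateGroup0
    {
        unsigned int size;   // heterogeneous, so stored per group
        // no tau member: homogeneous parameters stay compile-time constants
    };

            // If this dimension has a heterogeneous size, add it to struct
            if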
(isKernelSizeHeterogeneous(d)) { - addField("unsigned int", "kernelSize" + std::to_string(d), + addField(Type::Uint32::getInstance(), "kernelSize" + std::to_string(d), [d](const CustomUpdateWUInternal &cu, size_t) { return std::to_string(cu.getSynapseGroup()->getKernelSize().at(d)); @@ -443,21 +485,21 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const } // Otherwise else { - addField("unsigned int", "rowStride", + addField(Type::Uint32::getInstance(), "rowStride", [&backend](const CustomUpdateWUInternal &cg, size_t) { const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); return std::to_string(backend.getSynapticMatrixRowStride(*sgInternal)); }); - addField("unsigned int", "numSrcNeurons", + addField(Type::Uint32::getInstance(), "numSrcNeurons", [](const CustomUpdateWUInternal &cg, size_t) { const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons()); }); - addField("unsigned int", "numTrgNeurons", + addField(Type::Uint32::getInstance(), "numTrgNeurons", [](const CustomUpdateWUInternal &cg, size_t) { const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); @@ -466,13 +508,13 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // If synapse group has sparse connectivity if(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - addField(getArchetype().getSynapseGroup()->getSparseIndType() + "*", "ind", + addField(Type::parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())->getPointerType(), "ind", [&backend](const CustomUpdateWUInternal &cg, size_t) { return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); }); - addField("unsigned int*", "rowLength", + addField(Type::Uint32Ptr::getInstance(), "rowLength", [&backend](const CustomUpdateWUInternal &cg, size_t) { return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); @@ -482,31 +524,31 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // Add heterogeneous custom update model parameters const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); - addHeterogeneousParams( + typeEnvironment.addHeterogeneousParams( cm->getParamNames(), "", [](const CustomUpdateWUInternal &cg) { return cg.getParams(); }, &CustomUpdateWUGroupMergedBase::isParamHeterogeneous); // Add heterogeneous weight update model derived parameters - addHeterogeneousDerivedParams( + typeEnvironment.addHeterogeneousDerivedParams( cm->getDerivedParams(), "", [](const CustomUpdateWUInternal &cg) { return cg.getDerivedParams(); }, &CustomUpdateWUGroupMergedBase::isDerivedParamHeterogeneous); // Add variables to struct - addVars(cm->getVars(), backend.getDeviceVarPrefix()); + typeEnvironment.addVars(cm->getVars(), backend.getDeviceVarPrefix()); // Add variable references to struct const auto varRefs = cm->getVarRefs(); - addVarReferences(varRefs, backend.getDeviceVarPrefix(), - [](const CustomUpdateWUInternal &cg) { return cg.getVarReferences(); }); + typeEnvironment.addVarReferences(varRefs, backend.getDeviceVarPrefix(), + [](const CustomUpdateWUInternal &cg) { return cg.getVarReferences(); }); // Loop through variables for(const auto &v : varRefs) { // If variable has a transpose if(getArchetype().getVarReferences().at(v.name).getTransposeSynapseGroup() != nullptr) { // Add field with transpose suffix, pointing to transpose var - addField(v.type 
+ "*", v.name + "Transpose", + addField(Type::parseNumeric(v.type)->getPointerType(), v.name + "Transpose", [&backend, v](const CustomUpdateWUInternal &g, size_t) { const auto varRef = g.getVarReferences().at(v.name); @@ -515,7 +557,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const } } // Add EGPs to struct - this->addEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); + typeEnvironment.addEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); } // ---------------------------------------------------------------------------- @@ -557,13 +599,13 @@ CustomUpdateHostReductionGroupMerged::CustomUpdateHostReductionGroupMerged(size_ const std::vector> &groups) : CustomUpdateHostReductionGroupMergedBase(index, precision, backend, groups) { - addField("unsigned int", "size", + addField(Type::Uint32::getInstance(), "size", [](const CustomUpdateInternal &c, size_t) { return std::to_string(c.getSize()); }); // If some variables are delayed, add delay pointer // **NOTE** this is HOST delay pointer if(getArchetype().getDelayNeuronGroup() != nullptr) { - addField("unsigned int*", "spkQuePtr", + addField(Type::Uint32Ptr::getInstance(), "spkQuePtr", [&](const CustomUpdateInternal &cg, size_t) { return "spkQuePtr" + cg.getDelayNeuronGroup()->getName(); @@ -580,7 +622,7 @@ CustomWUUpdateHostReductionGroupMerged::CustomWUUpdateHostReductionGroupMerged(s const std::vector> &groups) : CustomUpdateHostReductionGroupMergedBase(index, precision, backend, groups) { - addField("unsigned int", "size", + addField(Type::Uint32::getInstance(), "size", [&backend](const CustomUpdateWUInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getMaxConnections() * (size_t)cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); From 9e5c37d60831a7822b9817dc53e91e862aa033e2 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 13 Jan 2023 10:32:00 +0000 Subject: [PATCH 037/725] slightly more complete integration into merged custom update groups --- .../groupMergedTypeEnvironment.h | 249 ++++++++++++++++++ .../code_generator/customUpdateGroupMerged.cc | 212 ++------------- src/genn/genn/genn.vcxproj | 1 + src/genn/genn/transpiler/errorHandler.cc | 55 ++++ 4 files changed, 325 insertions(+), 192 deletions(-) create mode 100644 include/genn/genn/code_generator/groupMergedTypeEnvironment.h create mode 100644 src/genn/genn/transpiler/errorHandler.cc diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h new file mode 100644 index 0000000000..96ec5dd9f2 --- /dev/null +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -0,0 +1,249 @@ +#pragma once + +// Standard C++ includes +#include + +// GeNN code generator includes +#include "code_generator/groupMerged.h" + +// GeNN transpiler includes +#include "transpiler/typeChecker.h" + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::GroupMergedTypeEnvironment +//---------------------------------------------------------------------------- +namespace GeNN::CodeGenerator +{ +template +class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBase +{ + using Token = Transpiler::Token; + using ErrorHandlerBase = Transpiler::ErrorHandlerBase; + using EnvironmentBase = Transpiler::TypeChecker::EnvironmentBase; + +public: + GroupMergedTypeEnvironment(G &groupMerged, const Type::NumericBase *scalarType, + EnvironmentBase *enclosing = nullptr) + : 
m_GroupMerged(groupMerged), m_ScalarType(scalarType), m_Enclosing(enclosing) + { + } + + //--------------------------------------------------------------------------- + // EnvironmentBase virtuals + //--------------------------------------------------------------------------- + virtual void define(const Transpiler::Token &name, const Type::QualifiedType &, ErrorHandlerBase &errorHandler) final + { + errorHandler.error(name, "Cannot declare variable in external environment"); + throw TypeChecker::TypeCheckError(); + } + + virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, + ErrorHandlerBase &errorHandler, bool initializer) final + { + // If type isn't found + auto existingType = m_Types.find(name.lexeme); + if(existingType == m_Types.end()) { + if(m_Enclosing) { + return m_Enclosing->assign(name, op, assignedType, errorHandler, initializer); + } + else { + errorHandler.error(name, "Undefined variable"); + throw TypeChecker::TypeCheckError(); + } + } + + // Add field to merged group if required + addField(existingType->second); + + // Perform standard type-checking logic + return EnvironmentBase::assign(name, op, existingType->second.first, assignedType, errorHandler, initializer); + } + + virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final + { + auto existingType = m_Types.find(name.lexeme); + if(existingType == m_Types.end()) { + if(m_Enclosing) { + return m_Enclosing->incDec(name, op, errorHandler); + } + else { + errorHandler.error(name, "Undefined variable"); + throw TypeChecker::TypeCheckError(); + } + } + + // Add field to merged group if required + addField(existingType->second); + + // Perform standard type-checking logic + return EnvironmentBase::incDec(name, op, existingType->second.first, errorHandler); + } + + virtual const Type::QualifiedType &getType(const Token &name, ErrorHandlerBase &errorHandler) final + { + auto type = m_Types.find(std::string{name.lexeme}); + if(type == m_Types.end()) { + if(m_Enclosing) { + return m_Enclosing->getType(name, errorHandler); + } + else { + errorHandler.error(name, "Undefined variable"); + throw TypeChecker::TypeCheckError(); + } + } + else { + // Add field to merged group if required + addField(type->second); + + return type->second.first; + } + } + + //--------------------------------------------------------------------------- + // Public API + //--------------------------------------------------------------------------- + void defineField(const Type::Base *type, std::string_view name, bool isConstValue = false, bool isConstPointer = false) + { + if(!m_Types.try_emplace(name, std::piecewise_construct, + std::forward_as_tuple(type, isConstValue, isConstPointer), + std::forward_as_tuple(std::nullopt)).second) + { + throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); + } + } + + template + void defineField(std::string_view name, bool isConstValue = false, bool isConstPointer = false) + { + defineField(T::getInstance(), name, isConstPointer, isConstPointer); + } + + void defineField(const Type::Base *type, std::string_view name, bool isConstValue, bool isConstPointer, + const Type::Base *fieldType, std::string_view fieldName, typename G::GetFieldValueFunc getFieldValue, + GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD) + { + if(!m_Types.try_emplace(name, std::piecewise_construct, + std::forward_as_tuple(type, isConstValue, isConstPointer), + 
std::forward_as_tuple(std::in_place, fieldType, fieldName, getFieldValue, mergedFieldType)).second) + { + throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); + } + } + + void defineField(const Type::Base *type, std::string_view name, bool isConstValue, bool isConstPointer, + typename G::GetFieldValueFunc getFieldValue, GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD) + { + defineField(type, name, isConstValue, isConstPointer, + type, name, getFieldValue, mergedFieldType); + } + + void definePointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix, VarAccessMode access) + { + defineField(type, name, (access & VarAccessModeAttribute::READ_ONLY), false, + type->getPointerType(), name, [prefix](const auto &g, size_t) { return prefix + g.getName(); }); + } + + template + void defineHeterogeneousParams(const Snippet::Base::StringVec ¶mNames, const std::string &suffix, + P getParamValues, H isHeterogeneous) + { + // Loop through params + for(const auto &p : paramNames) { + if (std::invoke(isHeterogeneous, m_GroupMerged, p)) { + defineField(m_ScalarType, p + suffix, true, false, + [p, getParamValues](const auto &g, size_t) + { + const auto &values = getParamValues(g); + return Utils::writePreciseString(values.at(p)); + }); + } + else { + defineField(m_ScalarType, p + suffix, true, false); + } + } + } + + template + void defineHeterogeneousDerivedParams(const Snippet::Base::DerivedParamVec &derivedParams, const std::string &suffix, + D getDerivedParamValues, H isHeterogeneous) + { + // Loop through derived params + for(const auto &d : derivedParams) { + if (std::invoke(isHeterogeneous, m_GroupMerged, d.name)) { + defineField(m_ScalarType, d.name + suffix, true, false, + [d, getDerivedParamValues](const auto &g, size_t) + { + const auto &values = getDerivedParamValues(g); + return Utils::writePreciseString(values.at(d.name)); + }); + } + else { + defineField(m_ScalarType, d.name + suffix, true, false); + } + } + } + + void defineVars(const Models::Base::VarVec &vars, const std::string &arrayPrefix) + { + // Loop through variables + for(const auto &v : vars) { + definePointerField(Type::parseNumeric(v.type), v.name, arrayPrefix, getVarAccessMode(v.access)); + } + } + + template + void defineVarReferences(const Models::Base::VarRefVec &varReferences, const std::string &arrayPrefix, V getVarRefFn) + { + // Loop through variables + for(const auto &v : varReferences) { + const auto *type = Type::parseNumeric(v.type); + defineField(type, v.name, (v.access & VarAccessModeAttribute::READ_ONLY), false, + type->getPointerType(), v.name, + [arrayPrefix, getVarRefFn, v](const auto &g, size_t) + { + const auto varRef = getVarRefFn(g).at(v.name); + return arrayPrefix + varRef.getVar().name + varRef.getTargetName(); + }); + } + } + + void defineEGPs(const Snippet::Base::EGPVec &egps, const std::string &arrayPrefix, const std::string &varName = "") + { + for(const auto &e : egps) { + const auto *type = Type::parseNumericPtr(e.type); + defineField(type, e.name, false, false, + type, e.name + varName, + [arrayPrefix, e, varName](const auto &g, size_t) + { + return arrayPrefix + e.name + varName + g.getName(); + }, + GroupMergedFieldType::DYNAMIC); + } + } + +private: + //--------------------------------------------------------------------------- + // Private methods + //--------------------------------------------------------------------------- + void addField(std::pair> &type) + { + // If this type has an associated field + if 
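
This lazy materialisation is the heart of the new header: defineField records a type for checking plus an optional field description, and the description is only turned into a real merged-struct field the first time user code actually references the name, after which the optional is cleared so nothing is added twice. A simplified sketch of the idea (FieldDesc and addStructField are placeholder names, not the real API):

    std::unordered_map<std::string, std::optional<FieldDesc>> lazyFields;

    void onReference(const std::string &name)
    {
        auto &field = lazyFields.at(name);
        // Materialise the field on first use only
        if(field) {
            addStructField(*field);
            field.reset();
        }
    }

+        if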
(type.second) { + // Call function to add field to underlying merge group + std::apply(&G::addField, std::tuple_cat(std::make_tuple(m_GroupMerged), + *type.second)); + + // Reset optional field so it doesn't get added again + type.second.reset(); + } + } + //--------------------------------------------------------------------------- + // Members + //--------------------------------------------------------------------------- + G &m_GroupMerged; + const Type::NumericBase *m_ScalarType; + EnvironmentBase *m_Enclosing; + + std::unordered_map>> m_Types; +}; +} // namespace GeNN::CodeGenerator \ No newline at end of file diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 1c86927564..c864dcad6a 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -1,6 +1,7 @@ #include "code_generator/customUpdateGroupMerged.h" // GeNN code generator includes +#include "code_generator/groupMergedTypeEnvironment.h" #include "code_generator/modelSpecMerged.h" // GeNN transpiler includes @@ -19,186 +20,6 @@ using namespace GeNN::Transpiler; //-------------------------------------------------------------------------- namespace { -template -class GroupMergedTypeEnvironment : public TypeChecker::EnvironmentBase -{ -public: - GroupMergedTypeEnvironment(G &groupMerged, const Type::NumericBase *scalarType, - TypeChecker::EnvironmentBase *enclosing = nullptr) - : m_GroupMerged(groupMerged), m_ScalarType(scalarType), m_Enclosing(enclosing) - { - } - - //--------------------------------------------------------------------------- - // EnvironmentBase virtuals - //--------------------------------------------------------------------------- - virtual void define(const Token &name, const Type::QualifiedType &, ErrorHandlerBase &errorHandler) final - { - errorHandler.error(name, "Cannot declare variable in external environment"); - throw TypeChecker::TypeCheckError(); - } - - virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, - ErrorHandlerBase &errorHandler, bool initializer) final - { - // If type isn't found - auto existingType = m_Types.find(name.lexeme); - if(existingType == m_Types.end()) { - if(m_Enclosing) { - return m_Enclosing->assign(name, op, assignedType, errorHandler, initializer); - } - else { - errorHandler.error(name, "Undefined variable"); - throw TypeChecker::TypeCheckError(); - } - } - - // Perform standard type-checking logic - return EnvironmentBase::assign(name, op, existingType->second, assignedType, errorHandler, initializer); - } - - virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final - { - auto existingType = m_Types.find(name.lexeme); - if(existingType == m_Types.end()) { - if(m_Enclosing) { - return m_Enclosing->incDec(name, op, errorHandler); - } - else { - errorHandler.error(name, "Undefined variable"); - throw TypeChecker::TypeCheckError(); - } - } - - // Perform standard type-checking logic - return EnvironmentBase::incDec(name, op, existingType->second, errorHandler); - - } - - virtual const Type::QualifiedType &getType(const Token &name, ErrorHandlerBase &errorHandler) final - { - auto type = m_Types.find(std::string{name.lexeme}); - if(type == m_Types.end()) { - if(m_Enclosing) { - return m_Enclosing->getType(name, errorHandler); - } - else { - errorHandler.error(name, "Undefined variable"); - throw 
TypeChecker::TypeCheckError(); - } - } - else { - return type->second; - } - } - - //--------------------------------------------------------------------------- - // Public API - //--------------------------------------------------------------------------- - void define(std::string_view name, const Type::Base *type, bool isConstValue = false, bool isConstPointer = false) - { - if(!m_Types.try_emplace(name, type, isConstValue, isConstPointer).second) { - throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); - } - } - template - void define(std::string_view name, bool isConstValue = false, bool isConstPointer = false) - { - define(name, T::getInstance(), isConstValue, isConstPointer); - } - - template - void addHeterogeneousParams(const Snippet::Base::StringVec ¶mNames, const std::string &suffix, - P getParamValues, H isHeterogeneous) - { - // Loop through params - for(const auto &p : paramNames) { - // Define constant - define(p + suffix, m_ScalarType, true); - - // If parameters is heterogeneous - if((static_cast(m_GroupMerged).*isHeterogeneous)(p)) { - // Add field - m_GroupMerged.addScalarField(p + suffix, - [p, getParamValues](const typename G::GroupInternal &g, size_t) - { - const auto &values = getParamValues(g); - return Utils::writePreciseString(values.at(p)); - }); - } - } - } - - template - void addHeterogeneousDerivedParams(const Snippet::Base::DerivedParamVec &derivedParams, const std::string &suffix, - D getDerivedParamValues, H isHeterogeneous) - { - // Loop through derived params - for(const auto &d : derivedParams) { - // If parameters isn't homogeneous - if((static_cast(m_GroupMerged).*isHeterogeneous)(d.name)) { - // Define constant - define(d.name + suffix, m_ScalarType, true); - - // Add field - m_GroupMerged.addScalarField(d.name + suffix, - [d, getDerivedParamValues](const typename G::GroupInternal &g, size_t) - { - const auto &values = getDerivedParamValues(g); - return Utils::writePreciseString(values.at(d.name)); - }); - } - } - } - - void addVars(const Models::Base::VarVec &vars, const std::string &arrayPrefix) - { - // Loop through variables - for(const auto &v : vars) { - const auto *type = Type::parseNumeric(v.type); - define(v.name, type, (v.access & VarAccessModeAttribute::READ_ONLY)); - m_GroupMerged.addPointerField(type, v.name, arrayPrefix + v.name); - } - } - - template - void addVarReferences(const Models::Base::VarRefVec &varReferences, const std::string &arrayPrefix, V getVarRefFn) - { - // Loop through variables - for(const auto &v : varReferences) { - const auto *type = Type::parseNumeric(v.type); - define(v.name, type, (v.access & VarAccessModeAttribute::READ_ONLY)); - m_GroupMerged.addField(type->getPointerType(), v.name, - [getVarRefFn, arrayPrefix, v](const typename G::GroupInternal &g, size_t) - { - const auto varRef = getVarRefFn(g).at(v.name); - return arrayPrefix + varRef.getVar().name + varRef.getTargetName(); - }); - } - } - - void addEGPs(const Snippet::Base::EGPVec &egps, const std::string &arrayPrefix, const std::string &varName = "") - { - for(const auto &e : egps) { - const auto *type = Type::parseNumericPtr(e.type); - define(e.name, type); - m_GroupMerged.addField(type, e.name + varName, - [e, arrayPrefix, varName](const typename G::GroupInternal &g, size_t) { return arrayPrefix + e.name + varName + g.getName(); }, - GroupMergedFieldType::DYNAMIC); - } - } - -private: - //--------------------------------------------------------------------------- - // Members - 
//--------------------------------------------------------------------------- - G &m_GroupMerged; - const Type::NumericBase *m_ScalarType; - TypeChecker::EnvironmentBase *m_Enclosing; - - std::unordered_map m_Types; -}; - template void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, const ModelSpecMerged &modelMerged, const std::string &index, @@ -312,26 +133,33 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string // Add heterogeneous custom update model parameters const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); - typeEnvironment.addHeterogeneousParams( + typeEnvironment.defineHeterogeneousParams( cm->getParamNames(), "", [](const CustomUpdateInternal &cg) { return cg.getParams(); }, &CustomUpdateGroupMerged::isParamHeterogeneous); // Add heterogeneous weight update model derived parameters - typeEnvironment.addHeterogeneousDerivedParams( + typeEnvironment.defineHeterogeneousDerivedParams( cm->getDerivedParams(), "", [](const CustomUpdateInternal &cg) { return cg.getDerivedParams(); }, &CustomUpdateGroupMerged::isDerivedParamHeterogeneous); // Add variables to struct - typeEnvironment.addVars(cm->getVars(), backend.getDeviceVarPrefix()); + typeEnvironment.defineVars(cm->getVars(), backend.getDeviceVarPrefix()); // Add variable references to struct - typeEnvironment.addVarReferences(cm->getVarRefs(), backend.getDeviceVarPrefix(), + typeEnvironment.defineVarReferences(cm->getVarRefs(), backend.getDeviceVarPrefix(), [](const CustomUpdateInternal &cg) { return cg.getVarReferences(); }); - // Add EGPs to struct - typeEnvironment.addEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); + // Add EGPs to struct + typeEnvironment.defineEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); + + // Scan, parse and type-check update code + Transpiler::ErrorHandler errorHandler; + const auto tokens = Transpiler::Scanner::scanSource(cm->getUpdateCode(), errorHandler); + const auto statements = Transpiler::Parser::parseBlockItemList(tokens, errorHandler); + Transpiler::TypeChecker::typeCheck(statements, typeEnvironment, errorHandler); + } //---------------------------------------------------------------------------- bool CustomUpdateGroupMerged::isParamHeterogeneous(const std::string ¶mName) const @@ -524,24 +352,24 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // Add heterogeneous custom update model parameters const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); - typeEnvironment.addHeterogeneousParams( + typeEnvironment.defineHeterogeneousParams( cm->getParamNames(), "", [](const CustomUpdateWUInternal &cg) { return cg.getParams(); }, &CustomUpdateWUGroupMergedBase::isParamHeterogeneous); // Add heterogeneous weight update model derived parameters - typeEnvironment.addHeterogeneousDerivedParams( + typeEnvironment.defineHeterogeneousDerivedParams( cm->getDerivedParams(), "", [](const CustomUpdateWUInternal &cg) { return cg.getDerivedParams(); }, &CustomUpdateWUGroupMergedBase::isDerivedParamHeterogeneous); // Add variables to struct - typeEnvironment.addVars(cm->getVars(), backend.getDeviceVarPrefix()); + typeEnvironment.defineVars(cm->getVars(), backend.getDeviceVarPrefix()); // Add variable references to struct const auto varRefs = cm->getVarRefs(); - typeEnvironment.addVarReferences(varRefs, backend.getDeviceVarPrefix(), - [](const CustomUpdateWUInternal &cg) { return cg.getVarReferences(); }); + typeEnvironment.defineVarReferences(varRefs, 
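
One detail worth pulling out of the CustomUpdateGroupMerged constructor above: the update code is now run through the full transpiler front end at merge time, but the error handler's state is not yet consulted afterwards. Restated in condensed form, with a final check added as an assumption about where this is presumably heading (hasError() is the query already used by parseNumericPtr earlier in the series):

    // Scan, parse and type-check the custom update code
    Transpiler::ErrorHandler errorHandler;
    const auto tokens = Transpiler::Scanner::scanSource(cm->getUpdateCode(), errorHandler);
    const auto statements = Transpiler::Parser::parseBlockItemList(tokens, errorHandler);
    Transpiler::TypeChecker::typeCheck(statements, typeEnvironment, errorHandler);

    // Assumed follow-up: surface failures rather than generating broken code
    if(errorHandler.hasError()) {
        throw std::runtime_error("Error type-checking custom update code");
    }

+    typeEnvironment.defineVarReferences(varRefs,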
backend.getDeviceVarPrefix(), + [](const CustomUpdateWUInternal &cg) { return cg.getVarReferences(); }); // Loop through variables for(const auto &v : varRefs) { @@ -557,7 +385,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const } } // Add EGPs to struct - typeEnvironment.addEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); + typeEnvironment.defineEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); } // ---------------------------------------------------------------------------- diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj index 9f40ed920d..aa133b54ce 100644 --- a/src/genn/genn/genn.vcxproj +++ b/src/genn/genn/genn.vcxproj @@ -67,6 +67,7 @@ + diff --git a/src/genn/genn/transpiler/errorHandler.cc b/src/genn/genn/transpiler/errorHandler.cc new file mode 100644 index 0000000000..92c8339cf0 --- /dev/null +++ b/src/genn/genn/transpiler/errorHandler.cc @@ -0,0 +1,55 @@ +#include "transpiler/errorHandler.h" + +// GeNN includes +#include "logging.h" + +//---------------------------------------------------------------------------- +// GeNN::Transpiler::ErrorHandler +//---------------------------------------------------------------------------- +namespace GeNN::Transpiler +{ +void ErrorHandler::error(size_t line, std::string_view message) +{ + report(line, "", message); +} +//---------------------------------------------------------------------------- +void ErrorHandler::error(const Token &token, std::string_view message) +{ + if(token.type == Token::Type::END_OF_FILE) { + report(token.line, " at end", message); + } + else { + report(token.line, " at '" + std::string{token.lexeme} + "'", message); + } +} +//---------------------------------------------------------------------------- +void ErrorHandler::report(size_t line, std::string_view where, std::string_view message) +{ + LOGE_TRANSPILER << "[line " << line << "] Error" << where << ": " << message; + m_Error = true; +} + +//---------------------------------------------------------------------------- +// GeNN::Transpiler::SingleLineErrorHandler +//---------------------------------------------------------------------------- +void SingleLineErrorHandler::error(size_t line, std::string_view message) +{ + report("", message); +} +//---------------------------------------------------------------------------- +void SingleLineErrorHandler::error(const Token &token, std::string_view message) +{ + if(token.type == Token::Type::END_OF_FILE) { + report(" at end", message); + } + else { + report(" at '" + std::string{token.lexeme} + "'", message); + } +} +//---------------------------------------------------------------------------- +void SingleLineErrorHandler::report(std::string_view where, std::string_view message) +{ + LOGE_TRANSPILER << "Error" << where << ": " << message; + m_Error = true; +} +} // namespace GeNN::Transpiler \ No newline at end of file From 8cc3a5ddc7f1ada2839a262de268bb6b1c509135 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 13 Jan 2023 13:31:04 +0000 Subject: [PATCH 038/725] * updated all merged groups to use type system * removed a bunch of scalar EGP crap * Used auto for long lambda function parameters --- .../genn/genn/code_generator/backendBase.h | 30 +-- .../genn/genn/code_generator/groupMerged.h | 36 ++- .../groupMergedTypeEnvironment.h | 7 +- .../genn/code_generator/initGroupMerged.h | 2 +- .../genn/code_generator/modelSpecMerged.h | 37 +-- src/genn/backends/cuda/backend.cc | 1 + .../customConnectivityUpdateGroupMerged.cc | 127 
+++++----- .../code_generator/customUpdateGroupMerged.cc | 117 ++++----- .../genn/code_generator/generateModules.cc | 41 ---- src/genn/genn/code_generator/groupMerged.cc | 225 +++++++++--------- .../genn/code_generator/initGroupMerged.cc | 155 ++++++------ .../genn/code_generator/modelSpecMerged.cc | 3 +- .../code_generator/neuronUpdateGroupMerged.cc | 43 ++-- .../synapseUpdateGroupMerged.cc | 10 +- 14 files changed, 417 insertions(+), 417 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 046a208f2b..797a4b67d1 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -50,6 +50,11 @@ class SynapseConnectivityInitGroupMerged; class SynapseInitGroupMerged; class SynapseSparseInitGroupMerged; } + +namespace Type +{ +class Base; +} } //-------------------------------------------------------------------------- @@ -197,35 +202,26 @@ class GENN_EXPORT BackendBase //! Generate platform-specific function to update the state of all neurons /*! \param os CodeStream to write function to \param modelMerged merged model to generate code for - \param preambleHandler callback to write functions for pushing extra-global parameters - \param pushEGPHandler callback to write required extra-global parameter pushing code to start of neuronUpdate function*/ - virtual void genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const = 0; + \param preambleHandler callback to write functions for pushing extra-global parameters*/ + virtual void genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; //! Generate platform-specific function to update the state of all synapses /*! \param os CodeStream to write function to \param modelMerged merged model to generate code for - \param preambleHandler callback to write functions for pushing extra-global parameters - \param pushEGPHandler callback to write required extra-global parameter pushing code to start of synapseUpdate function*/ - virtual void genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const = 0; + \param preambleHandler callback to write functions for pushing extra-global parameters*/ + virtual void genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; //! Generate platform-specific functions to perform custom updates /*! \param os CodeStream to write function to \param modelMerged merged model to generate code for - \param preambleHandler callback to write functions for pushing extra-global parameters - \param pushEGPHandler callback to write required extra-global parameter pushing code to start of customUpdate function*/ - virtual void genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const = 0; + \param preambleHandler callback to write functions for pushing extra-global parameters*/ + virtual void genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; //! Generate platform-specific function to initialise model /*! 
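
Dropping the separate pushEGPHandler arguments leaves each of these generator entry points with a single preamble callback. A hypothetical call site, assuming HostHandler is the usual std::function<void(CodeStream&)> alias used by BackendBase:

    backend.genNeuronUpdate(os, modelMerged,
                            [&modelMerged](CodeStream &os)
                            {
                                // Write EGP push helper functions for the
                                // merged neuron update groups here
                            });

    /*!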
\param os CodeStream to write function to \param modelMerged merged model to generate code for - \param preambleHandler callback to write functions for pushing extra-global parameters - \param initPushEGPHandler callback to write required extra-global parameter pushing code to start of initialize function - \param initSparsePushEGPHandler callback to write required extra-global parameter pushing code to start of initializeSparse function*/ - virtual void genInit(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler initPushEGPHandler, HostHandler initSparsePushEGPHandler) const = 0; + \param preambleHandler callback to write functions for pushing extra-global parameters*/ + virtual void genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; //! Gets the stride used to access synaptic matrix rows, taking into account sparse data structure, padding etc virtual size_t getSynapticMatrixRowStride(const SynapseGroupInternal &sg) const = 0; diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index d9effbf26c..3e8279c7b1 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -248,6 +248,13 @@ class GroupMerged m_Fields.emplace_back(type, name, getFieldValue, fieldType); } + template + void addField(const std::string &name, GetFieldValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) + { + // Add field to data structure + m_Fields.emplace_back(T::getInstance(), name, getFieldValue, fieldType); + } + void addScalarField(const std::string &name, GetFieldValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) { addField(m_ScalarType, name, @@ -263,6 +270,12 @@ class GroupMerged addField(type->getPointerType(), name, [prefix](const G &g, size_t) { return prefix + g.getName(); }); } + template>* = nullptr> + void addPointerField(const std::string &name, const std::string &prefix) + { + addField(T::getInstance()->getPointerType(), name, [prefix](const G &g, size_t) { return prefix + g.getName(); }); + } + void addVars(const Models::Base::VarVec &vars, const std::string &arrayPrefix) { @@ -303,6 +316,7 @@ class GroupMerged // Loop through params for(const auto &p : paramNames) { // If parameters is heterogeneous + // **TODO** std::invoke if((static_cast(this)->*isHeterogeneous)(p)) { // Add field addScalarField(p + suffix, @@ -322,6 +336,7 @@ class GroupMerged // Loop through derived params for(const auto &d : derivedParams) { // If parameters isn't homogeneous + // **TODO** std::invoke if((static_cast(this)->*isHeterogeneous)(d.name)) { // Add field addScalarField(d.name + suffix, @@ -816,11 +831,10 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged(*type.second), std::get<1>(*type.second), + std::get<2>(*type.second), std::get<3>(*type.second)); // Reset optional field so it doesn't get added again type.second.reset(); diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index 00772fedcf..4929d2a1c1 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -254,7 +254,7 @@ class CustomUpdateInitGroupMergedBase : public GroupMerged // If we're not initialising or if there is initialization code for this variable const auto &varInit = archetypeAdaptor.getVarInitialisers().at(var.name); if 
(!varInit.getSnippet()->getCode().empty()) { - this->addPointerField(var.type, var.name, backend.getDeviceVarPrefix() + var.name); + this->addPointerField(Type::parseNumeric(var.type), var.name, backend.getDeviceVarPrefix() + var.name); } // Add any var init EGPs to structure diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 3619863464..23a0ef45ec 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -40,11 +40,11 @@ class GENN_EXPORT ModelSpecMerged //! Immutable structure for tracking fields of merged group structure containing EGPs struct EGPField { - EGPField(size_t m, const std::string &t, const std::string &f, bool h) + EGPField(size_t m, const Type::Base *t, const std::string &f, bool h) : mergedGroupIndex(m), type(t), fieldName(f), hostGroup(h) {} const size_t mergedGroupIndex; - const std::string type; + const Type::Base *type; const std::string fieldName; const bool hostGroup; @@ -52,8 +52,8 @@ class GENN_EXPORT ModelSpecMerged //! lexicographically compares all three struct members bool operator < (const EGPField &other) const { - return (std::tie(mergedGroupIndex, type, fieldName, hostGroup) - < std::tie(other.mergedGroupIndex, other.type, other.fieldName, other.hostGroup)); + return (std::make_tuple(mergedGroupIndex, type->getTypeHash(), fieldName, hostGroup) + < std::make_tuple(other.mergedGroupIndex, other.type->getTypeHash(), other.fieldName, other.hostGroup)); } }; @@ -63,7 +63,7 @@ class GENN_EXPORT ModelSpecMerged //! Immutable structure for tracking where an extra global variable ends up after merging struct MergedEGP : public EGPField { - MergedEGP(size_t m, size_t g, const std::string &t, const std::string &f, bool h) + MergedEGP(size_t m, size_t g, const Type::Base *t, const std::string &f, bool h) : EGPField(m, t, f, h), groupIndex(g) {} const size_t groupIndex; @@ -226,26 +226,6 @@ class GENN_EXPORT ModelSpecMerged return m_MergedEGPs.at(name); } - //! Generate calls to update all target merged groups - template - void genScalarEGPPush(CodeStream &os, const BackendBase &backend) const - { - // Loop through all merged EGPs - for(const auto &e : m_MergedEGPs) { - // Loop through all destination structures with this suffix - const auto groupEGPs = e.second.equal_range(T::name); - for(auto g = groupEGPs.first; g != groupEGPs.second; ++g) { - // If EGP is scalar, generate code to copy - if(!Utils::isTypePointer(g->second.type)) { - backend.genMergedExtraGlobalParamPush(os, T::name, g->second.mergedGroupIndex, - std::to_string(g->second.groupIndex), - g->second.fieldName, e.first); - } - - } - } - } - // Get set of unique fields referenced in a merged group template std::set getMergedGroupFields() const @@ -283,14 +263,17 @@ class GENN_EXPORT ModelSpecMerged for(auto f : mergedGroupFields) { // If EGP is a pointer // **NOTE** this is common to all references! 
- if(Utils::isTypePointer(f.type)) { - os << "void pushMerged" << T::name << f.mergedGroupIndex << f.fieldName << "ToDevice(unsigned int idx, " << backend.getMergedGroupFieldHostType(f.type) << " value)"; + if(dynamic_cast(f.type)) { + os << "void pushMerged" << T::name << f.mergedGroupIndex << f.fieldName << "ToDevice(unsigned int idx, " << backend.getMergedGroupFieldHostTypeName(f.type) << " value)"; { CodeStream::Scope b(os); backend.genMergedExtraGlobalParamPush(os, T::name, f.mergedGroupIndex, "idx", f.fieldName, "value"); } os << std::endl; } + else { + assert(false); + } } } } diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 92fa08a8ca..43b8259711 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -7,6 +7,7 @@ // GeNN includes #include "gennUtils.h" #include "logging.h" +#include "type.h" // GeNN code generator includes #include "code_generator/codeStream.h" diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index e11ad99c27..f04fa716c5 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -16,15 +16,17 @@ CustomConnectivityUpdateGroupMergedBase::CustomConnectivityUpdateGroupMergedBase const std::vector> &groups) : GroupMerged(index, precision, groups) { - addField("unsigned int", "numSrcNeurons", - [](const CustomConnectivityUpdateInternal &cg, size_t) + using namespace Type; + + addField("numSrcNeurons", + [](const auto &cg, size_t) { const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons()); }); - addField("unsigned int", "numTrgNeurons", - [](const CustomConnectivityUpdateInternal &cg, size_t) + addField("numTrgNeurons", + [](const auto &cg, size_t) { const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); return std::to_string(sgInternal->getTrgNeuronGroup()->getNumNeurons()); @@ -33,13 +35,13 @@ CustomConnectivityUpdateGroupMergedBase::CustomConnectivityUpdateGroupMergedBase // Add heterogeneous custom update model parameters addHeterogeneousParams( getArchetype().getCustomConnectivityUpdateModel()->getParamNames(), "", - [](const CustomConnectivityUpdateInternal &cg) { return cg.getParams(); }, + [](const auto &cg) { return cg.getParams(); }, &CustomConnectivityUpdateGroupMergedBase::isParamHeterogeneous); // Add heterogeneous weight update model CustomConnectivityUpdateGroupMerged parameters addHeterogeneousDerivedParams( getArchetype().getCustomConnectivityUpdateModel()->getDerivedParams(), "", - [](const CustomConnectivityUpdateInternal &cg) { return cg.getDerivedParams(); }, + [](const auto &cg) { return cg.getDerivedParams(); }, &CustomConnectivityUpdateGroupMergedBase::isDerivedParamHeterogeneous); } //---------------------------------------------------------------------------- @@ -62,6 +64,8 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t const std::vector> &groups) : CustomConnectivityUpdateGroupMergedBase(index, precision, groups) { + using namespace Type; + // Reserve vector of vectors to hold variables to update for all custom connectivity update groups, in archetype order m_SortedDependentVars.reserve(getGroups().size()); @@ -73,7 +77,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t // Convert to list and 
sort // **NOTE** WUVarReferences are non-assignable so can't be sorted in a vector std::list dependentVarsList(dependentVars.cbegin(), dependentVars.cend()); - dependentVarsList.sort([](const Models::WUVarReference &a, const Models::WUVarReference &b) + dependentVarsList.sort([](const auto &a, const auto &b) { boost::uuids::detail::sha1 hashA; Utils::updateHash(a.getVar().type, hashA); @@ -98,48 +102,49 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t })); - addField("unsigned int", "rowStride", - [&backend](const CustomConnectivityUpdateInternal &cg, size_t) - { - const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); - return std::to_string(backend.getSynapticMatrixRowStride(*sgInternal)); - }); + addField("rowStride", + [&backend](const auto &cg, size_t) + { + const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); + return std::to_string(backend.getSynapticMatrixRowStride(*sgInternal)); + }); assert(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE); - addField(getArchetype().getSynapseGroup()->getSparseIndType() + "*", "ind", - [&backend](const CustomConnectivityUpdateInternal &cg, size_t) - { - return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); - }); - - addField("unsigned int*", "rowLength", - [&backend](const CustomConnectivityUpdateInternal &cg, size_t) - { - return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); - }); + addField(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())->getPointerType(), "ind", + [&backend](const auto &cg, size_t) + { + return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); + }); + + addField("rowLength", + [&backend](const auto &cg, size_t) + { + return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); + }); // If some presynaptic variables are delayed, add delay pointer if (getArchetype().getPreDelayNeuronGroup() != nullptr) { - addField("unsigned int*", "preSpkQuePtr", - [&backend](const CustomConnectivityUpdateInternal &cg, size_t) - { - return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getPreDelayNeuronGroup()->getName(); - }); + addField("preSpkQuePtr", + [&backend](const auto &cg, size_t) + { + return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getPreDelayNeuronGroup()->getName(); + }); } // If some postsynaptic variables are delayed, add delay pointer if (getArchetype().getPostDelayNeuronGroup() != nullptr) { - addField("unsigned int*", "postSpkQuePtr", - [&backend](const CustomConnectivityUpdateInternal &cg, size_t) - { - return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getPostDelayNeuronGroup()->getName(); - }); + addField("postSpkQuePtr", + [&backend](const auto &cg, size_t) + { + return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getPostDelayNeuronGroup()->getName(); + }); } // If this backend requires per-population RNGs and this group requires one if(backend.isPopulationRNGRequired() && getArchetype().isRowSimRNGRequired()){ - addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rowRNG"); + assert(false); + //addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rowRNG"); } // Add variables to struct @@ -150,11 +155,11 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t // Add variable references to struct addVarReferences(cm->getVarRefs(), 
backend.getDeviceVarPrefix(), - [](const CustomConnectivityUpdateInternal &cg) { return cg.getVarReferences(); }); + [](const auto &cg) { return cg.getVarReferences(); }); addVarReferences(cm->getPreVarRefs(), backend.getDeviceVarPrefix(), - [](const CustomConnectivityUpdateInternal &cg) { return cg.getPreVarReferences(); }); + [](const auto &cg) { return cg.getPreVarReferences(); }); addVarReferences(cm->getPostVarRefs(), backend.getDeviceVarPrefix(), - [](const CustomConnectivityUpdateInternal &cg) { return cg.getPostVarReferences(); }); + [](const auto &cg) { return cg.getPostVarReferences(); }); // Add EGPs to struct this->addEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); @@ -162,8 +167,8 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t // Loop through sorted dependent variables for(size_t i = 0; i < getSortedArchetypeDependentVars().size(); i++) { - addField(getSortedArchetypeDependentVars().at(i).getVar().type + "*", "_dependentVar" + std::to_string(i), - [i, &backend, this](const CustomConnectivityUpdateInternal&, size_t g) + addField(parseNumeric(getSortedArchetypeDependentVars().at(i).getVar().type)->getPointerType(), "_dependentVar" + std::to_string(i), + [i, &backend, this](const auto&, size_t g) { const auto &varRef = m_SortedDependentVars[g][i]; return backend.getDeviceVarPrefix() + varRef.getVar().name + varRef.getTargetName(); @@ -179,22 +184,22 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdateGroupMerged::get Utils::updateHash(getArchetype().getHashDigest(), hash); // Update hash with sizes of pre and postsynaptic neuron groups - updateHash([](const CustomConnectivityUpdateInternal &cg) + updateHash([](const auto &cg) { return static_cast(cg.getSynapseGroup())->getSrcNeuronGroup()->getNumNeurons(); }, hash); - updateHash([](const CustomConnectivityUpdateInternal &cg) + updateHash([](const auto &cg) { return static_cast(cg.getSynapseGroup())->getTrgNeuronGroup()->getNumNeurons(); }, hash); // Update hash with each group's parameters, derived parameters and variable references - updateHash([](const CustomConnectivityUpdateInternal &cg) { return cg.getParams(); }, hash); - updateHash([](const CustomConnectivityUpdateInternal &cg) { return cg.getDerivedParams(); }, hash); - updateHash([](const CustomConnectivityUpdateInternal &cg) { return cg.getVarReferences(); }, hash); - updateHash([](const CustomConnectivityUpdateInternal &cg) { return cg.getPreVarReferences(); }, hash); - updateHash([](const CustomConnectivityUpdateInternal &cg) { return cg.getPostVarReferences(); }, hash); + updateHash([](const auto &cg) { return cg.getParams(); }, hash); + updateHash([](const auto &cg) { return cg.getDerivedParams(); }, hash); + updateHash([](const auto &cg) { return cg.getVarReferences(); }, hash); + updateHash([](const auto &cg) { return cg.getPreVarReferences(); }, hash); + updateHash([](const auto &cg) { return cg.getPostVarReferences(); }, hash); return hash.get_digest(); } @@ -427,6 +432,8 @@ CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged const std::vector> &groups) : CustomConnectivityUpdateGroupMergedBase(index, precision, groups) { + using namespace Type; + // Add pre and postsynaptic variables const auto *cm = getArchetype().getCustomConnectivityUpdateModel(); addVars(backend, cm->getPreVars(), &CustomConnectivityUpdateInternal::getPreVarLocation); @@ -434,13 +441,14 @@ CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged // Add host extra global 
parameters for(const auto &e : cm->getExtraGlobalParams()) { - addField(e.type, e.name, - [e](const CustomConnectivityUpdateInternal &g, size_t) { return e.name + g.getName(); }, + const auto *pointerType = parseNumericPtr(e.type); + addField(pointerType, e.name, + [e](const auto &g, size_t) { return e.name + g.getName(); }, GroupMergedFieldType::HOST_DYNAMIC); - if(Utils::isTypePointer(e.type) && !backend.getDeviceVarPrefix().empty()) { - addField(e.type, backend.getDeviceVarPrefix() + e.name, - [e, &backend](const CustomConnectivityUpdateInternal &g, size_t) + if(!backend.getDeviceVarPrefix().empty()) { + addField(pointerType, backend.getDeviceVarPrefix() + e.name, + [e, &backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + e.name + g.getName(); }, @@ -549,17 +557,20 @@ void CustomConnectivityHostUpdateGroupMerged::addVarPushPullFuncSubs(const Backe void CustomConnectivityHostUpdateGroupMerged::addVars(const BackendBase &backend, const Models::Base::VarVec &vars, VarLocation(CustomConnectivityUpdateInternal:: *getVarLocationFn)(const std::string&) const) { + using namespace Type; + // Loop through variables for(const auto &v : vars) { // If var is located on the host - if ((getArchetype().*getVarLocationFn)(v.name) & VarLocation::HOST) { - addField(v.type + "*", v.name, - [v](const CustomConnectivityUpdateInternal &g, size_t) { return v.name + g.getName(); }, + if (std::invoke(getVarLocationFn, getArchetype(), v.name) & VarLocation::HOST) { + addField(parseNumeric(v.type)->getPointerType(), v.name, + [v](const auto &g, size_t) { return v.name + g.getName(); }, GroupMergedFieldType::HOST); - if(!backend.getDeviceVarPrefix().empty()) { - addField(v.type + "*", backend.getDeviceVarPrefix() + v.name, - [v, &backend](const CustomConnectivityUpdateInternal &g, size_t) + if(!backend.getDeviceVarPrefix().empty()) { + // **TODO** I think could use addPointerField + addField(parseNumeric(v.type)->getPointerType(), backend.getDeviceVarPrefix() + v.name, + [v, &backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + v.name + g.getName(); }); diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index c864dcad6a..0aefbb2e2b 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -115,20 +115,21 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string const std::vector> &groups) : GroupMerged(index, precision, groups) { + using namespace Type; + // Create type environment // **TEMP** parse precision to get scalar type GroupMergedTypeEnvironment typeEnvironment(*this, Type::parseNumeric(precision)); - addField(Type::Uint32::getInstance(), "size", - [](const CustomUpdateInternal &c, size_t) { return std::to_string(c.getSize()); }); + addField("size", [](const auto &c, size_t) { return std::to_string(c.getSize()); }); // If some variables are delayed, add delay pointer if(getArchetype().getDelayNeuronGroup() != nullptr) { - addField(Type::Uint32Ptr::getInstance(), "spkQuePtr", - [&backend](const CustomUpdateInternal &cg, size_t) - { - return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getDelayNeuronGroup()->getName(); - }); + addField("spkQuePtr", + [&backend](const auto &cg, size_t) + { + return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getDelayNeuronGroup()->getName(); + }); } // Add heterogeneous custom update model parameters @@ -293,6 +294,8 @@ 
CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const const std::vector> &groups) : GroupMerged(index, precision, groups) { + using namespace Type; + // Create type environment // **TEMP** parse precision to get scalar type GroupMergedTypeEnvironment typeEnvironment(*this, Type::parseNumeric(precision)); @@ -303,50 +306,50 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const for (size_t d = 0; d < getArchetype().getSynapseGroup()->getKernelSize().size(); d++) { // If this dimension has a heterogeneous size, add it to struct if (isKernelSizeHeterogeneous(d)) { - addField(Type::Uint32::getInstance(), "kernelSize" + std::to_string(d), - [d](const CustomUpdateWUInternal &cu, size_t) - { - return std::to_string(cu.getSynapseGroup()->getKernelSize().at(d)); - }); + addField("kernelSize" + std::to_string(d), + [d](const auto &cu, size_t) + { + return std::to_string(cu.getSynapseGroup()->getKernelSize().at(d)); + }); } } } // Otherwise else { - addField(Type::Uint32::getInstance(), "rowStride", - [&backend](const CustomUpdateWUInternal &cg, size_t) - { - const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); - return std::to_string(backend.getSynapticMatrixRowStride(*sgInternal)); - }); + addField("rowStride", + [&backend](const auto &cg, size_t) + { + const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); + return std::to_string(backend.getSynapticMatrixRowStride(*sgInternal)); + }); - addField(Type::Uint32::getInstance(), "numSrcNeurons", - [](const CustomUpdateWUInternal &cg, size_t) - { - const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); - return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons()); - }); - - addField(Type::Uint32::getInstance(), "numTrgNeurons", - [](const CustomUpdateWUInternal &cg, size_t) - { - const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); - return std::to_string(sgInternal->getTrgNeuronGroup()->getNumNeurons()); - }); + addField("numSrcNeurons", + [](const auto &cg, size_t) + { + const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); + return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons()); + }); + + addField("numTrgNeurons", + [](const auto &cg, size_t) + { + const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); + return std::to_string(sgInternal->getTrgNeuronGroup()->getNumNeurons()); + }); // If synapse group has sparse connectivity if(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { addField(Type::parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())->getPointerType(), "ind", - [&backend](const CustomUpdateWUInternal &cg, size_t) + [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); }); - addField(Type::Uint32Ptr::getInstance(), "rowLength", - [&backend](const CustomUpdateWUInternal &cg, size_t) - { - return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); - }); + addField("rowLength", + [&backend](const auto &cg, size_t) + { + return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); + }); } } @@ -354,13 +357,13 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); typeEnvironment.defineHeterogeneousParams( cm->getParamNames(), "", - [](const CustomUpdateWUInternal &cg) 
{ return cg.getParams(); }, + [](const auto &cg) { return cg.getParams(); }, &CustomUpdateWUGroupMergedBase::isParamHeterogeneous); // Add heterogeneous weight update model derived parameters typeEnvironment.defineHeterogeneousDerivedParams( cm->getDerivedParams(), "", - [](const CustomUpdateWUInternal &cg) { return cg.getDerivedParams(); }, + [](const auto &cg) { return cg.getDerivedParams(); }, &CustomUpdateWUGroupMergedBase::isDerivedParamHeterogeneous); // Add variables to struct @@ -369,7 +372,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // Add variable references to struct const auto varRefs = cm->getVarRefs(); typeEnvironment.defineVarReferences(varRefs, backend.getDeviceVarPrefix(), - [](const CustomUpdateWUInternal &cg) { return cg.getVarReferences(); }); + [](const auto &cg) { return cg.getVarReferences(); }); // Loop through variables for(const auto &v : varRefs) { @@ -377,7 +380,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const if(getArchetype().getVarReferences().at(v.name).getTransposeSynapseGroup() != nullptr) { // Add field with transpose suffix, pointing to transpose var addField(Type::parseNumeric(v.type)->getPointerType(), v.name + "Transpose", - [&backend, v](const CustomUpdateWUInternal &g, size_t) + [&backend, v](const auto &g, size_t) { const auto varRef = g.getVarReferences().at(v.name); return backend.getDeviceVarPrefix() + varRef.getTransposeVar().name + varRef.getTransposeTargetName(); @@ -396,7 +399,7 @@ const std::string CustomUpdateWUGroupMerged::name = "CustomUpdateWU"; void CustomUpdateWUGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { genCustomUpdate(os, popSubs, *this, modelMerged, "id_syn", - [this, &modelMerged](const Models::WUVarReference &varRef, const std::string &index) + [this, &modelMerged](const auto &varRef, const std::string &index) { return getVarRefIndex(getVarAccessDuplication(varRef.getVar().access), index); @@ -411,7 +414,7 @@ const std::string CustomUpdateTransposeWUGroupMerged::name = "CustomUpdateTransp void CustomUpdateTransposeWUGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { genCustomUpdate(os, popSubs, *this, modelMerged, "id_syn", - [this, &modelMerged](const Models::WUVarReference &varRef, const std::string &index) + [this, &modelMerged](const auto &varRef, const std::string &index) { return getVarRefIndex(getVarAccessDuplication(varRef.getVar().access), index); @@ -427,17 +430,19 @@ CustomUpdateHostReductionGroupMerged::CustomUpdateHostReductionGroupMerged(size_ const std::vector> &groups) : CustomUpdateHostReductionGroupMergedBase(index, precision, backend, groups) { - addField(Type::Uint32::getInstance(), "size", - [](const CustomUpdateInternal &c, size_t) { return std::to_string(c.getSize()); }); + using namespace Type; + + addField("size", + [](const auto &c, size_t) { return std::to_string(c.getSize()); }); // If some variables are delayed, add delay pointer // **NOTE** this is HOST delay pointer if(getArchetype().getDelayNeuronGroup() != nullptr) { - addField(Type::Uint32Ptr::getInstance(), "spkQuePtr", - [&](const CustomUpdateInternal &cg, size_t) - { - return "spkQuePtr" + cg.getDelayNeuronGroup()->getName(); - }); + addField("spkQuePtr", + [&](const auto &cg, size_t) + { + return "spkQuePtr" + cg.getDelayNeuronGroup()->getName(); + }); } } @@ -450,9 +455,11 @@ 
CustomWUUpdateHostReductionGroupMerged::CustomWUUpdateHostReductionGroupMerged(s const std::vector> &groups) : CustomUpdateHostReductionGroupMergedBase(index, precision, backend, groups) { - addField(Type::Uint32::getInstance(), "size", - [&backend](const CustomUpdateWUInternal &cg, size_t) - { - return std::to_string(cg.getSynapseGroup()->getMaxConnections() * (size_t)cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); - }); + using namespace Type; + + addField("size", + [&backend](const auto &cg, size_t) + { + return std::to_string(cg.getSynapseGroup()->getMaxConnections() * (size_t)cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); + }); } diff --git a/src/genn/genn/code_generator/generateModules.cc b/src/genn/genn/code_generator/generateModules.cc index 4805727513..4923b6ef35 100644 --- a/src/genn/genn/code_generator/generateModules.cc +++ b/src/genn/genn/code_generator/generateModules.cc @@ -203,11 +203,6 @@ void generateNeuronUpdate(const filesystem::path &outputPath, const ModelSpecMer // Generate functions to push merged neuron group structures modelMerged.genMergedGroupPush(os, modelMerged.getMergedNeuronSpikeQueueUpdateGroups(), backend); modelMerged.genMergedGroupPush(os, modelMerged.getMergedNeuronUpdateGroups(), backend); - }, - // Push EGP handler - [&backend, &modelMerged](CodeStream &os) - { - modelMerged.genScalarEGPPush(os, backend); }); } //-------------------------------------------------------------------------- @@ -231,15 +226,6 @@ void generateCustomUpdate(const filesystem::path &outputPath, const ModelSpecMer modelMerged.genMergedGroupPush(os, modelMerged.getMergedCustomUpdateWUGroups(), backend); modelMerged.genMergedGroupPush(os, modelMerged.getMergedCustomUpdateTransposeWUGroups(), backend); modelMerged.genMergedGroupPush(os, modelMerged.getMergedCustomConnectivityUpdateGroups(), backend); - }, - // Push EGP handler - // **TODO** this needs to be per-update group - [&backend, &modelMerged](CodeStream &os) - { - modelMerged.genScalarEGPPush(os, backend); - modelMerged.genScalarEGPPush(os, backend); - modelMerged.genScalarEGPPush(os, backend); - modelMerged.genScalarEGPPush(os, backend); }); } //-------------------------------------------------------------------------- @@ -265,13 +251,6 @@ void generateSynapseUpdate(const filesystem::path &outputPath, const ModelSpecMe modelMerged.genMergedGroupPush(os, modelMerged.getMergedPresynapticUpdateGroups(), backend); modelMerged.genMergedGroupPush(os, modelMerged.getMergedPostsynapticUpdateGroups(), backend); modelMerged.genMergedGroupPush(os, modelMerged.getMergedSynapseDynamicsGroups(), backend); - }, - // Push EGP handler - [&backend, &modelMerged](CodeStream &os) - { - modelMerged.genScalarEGPPush(os, backend); - modelMerged.genScalarEGPPush(os, backend); - modelMerged.genScalarEGPPush(os, backend); }); } //-------------------------------------------------------------------------- @@ -301,26 +280,6 @@ void generateInit(const filesystem::path &outputPath, const ModelSpecMerged &mod modelMerged.genMergedGroupPush(os, modelMerged.getMergedSynapseSparseInitGroups(), backend); modelMerged.genMergedGroupPush(os, modelMerged.getMergedCustomWUUpdateSparseInitGroups(), backend); modelMerged.genMergedGroupPush(os, modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups(), backend); - }, - // Initialise push EGP handler - [&backend, &modelMerged](CodeStream &os) - { - modelMerged.genScalarEGPPush(os, backend); - modelMerged.genScalarEGPPush(os, backend); - modelMerged.genScalarEGPPush(os, backend); - 
modelMerged.genScalarEGPPush(os, backend); - - modelMerged.genScalarEGPPush(os, backend); - modelMerged.genScalarEGPPush(os, backend); - - modelMerged.genScalarEGPPush(os, backend); - }, - // Initialise sparse push EGP handler - [&backend, &modelMerged](CodeStream &os) - { - modelMerged.genScalarEGPPush(os, backend); - modelMerged.genScalarEGPPush(os, backend); - modelMerged.genScalarEGPPush(os, backend); }); } } // namespace GeNN::CodeGenerator \ No newline at end of file diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index fa53ca53d4..ad43e1bd13 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -23,14 +23,16 @@ NeuronSpikeQueueUpdateGroupMerged::NeuronSpikeQueueUpdateGroupMerged(size_t inde const std::vector> &groups) : GroupMerged(index, precision, groups) { + using namespace Type; + if(getArchetype().isDelayRequired()) { - addPointerField("unsigned int", "spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); + addPointerField("spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); } - addPointerField("unsigned int", "spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); + addPointerField("spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); if(getArchetype().isSpikeEventRequired()) { - addPointerField("unsigned int", "spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); + addPointerField("spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); } } //---------------------------------------------------------------------------- @@ -70,28 +72,31 @@ NeuronPrevSpikeTimeUpdateGroupMerged::NeuronPrevSpikeTimeUpdateGroupMerged(size_ const std::vector> &groups) : GroupMerged(index, precision, groups) { + using namespace Type; + if(getArchetype().isDelayRequired()) { - addPointerField("unsigned int", "spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); + addPointerField("spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); } - addPointerField("unsigned int", "spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); + addPointerField("spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); if(getArchetype().isSpikeEventRequired()) { - addPointerField("unsigned int", "spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); + addPointerField("spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); } + const NumericBase *timeType = parseNumeric(timePrecision); if(getArchetype().isPrevSpikeTimeRequired()) { - addPointerField("unsigned int", "spk", backend.getDeviceVarPrefix() + "glbSpk"); - addPointerField(timePrecision, "prevST", backend.getDeviceVarPrefix() + "prevST"); + addPointerField("spk", backend.getDeviceVarPrefix() + "glbSpk"); + addPointerField(timeType, "prevST", backend.getDeviceVarPrefix() + "prevST"); } if(getArchetype().isPrevSpikeEventTimeRequired()) { - addPointerField("unsigned int", "spkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt"); - addPointerField(timePrecision, "prevSET", backend.getDeviceVarPrefix() + "prevSET"); + addPointerField("spkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt"); + addPointerField(timeType, "prevSET", backend.getDeviceVarPrefix() + "prevSET"); } if(getArchetype().isDelayRequired()) { - addField("unsigned int", "numNeurons", - [](const NeuronGroupInternal &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); + addField("numNeurons", + [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); } } @@ -183,6 +188,12 @@ 
NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr bool init, const std::vector<std::reference_wrapper<const NeuronGroupInternal>> &groups) : GroupMerged<NeuronGroupInternal>(index, precision, groups) { + using namespace Type; + + // **HACK** parse precisions + const NumericBase *scalarType = parseNumeric(precision); + const NumericBase *timeType = parseNumeric(timePrecision); + // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group orderNeuronGroupChildren(m_SortedMergedInSyns, &NeuronGroupInternal::getFusedPSMInSyn, init ? &SynapseGroupInternal::getPSInitHashDigest : &SynapseGroupInternal::getPSHashDigest); @@ -195,40 +206,41 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr orderNeuronGroupChildren(m_SortedCurrentSources, &NeuronGroupInternal::getCurrentSources, init ? &CurrentSourceInternal::getInitHashDigest : &CurrentSourceInternal::getHashDigest); - addField("unsigned int", "numNeurons", - [](const NeuronGroupInternal &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); + addField<Uint32>("numNeurons", + [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); - addPointerField("unsigned int", "spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); - addPointerField("unsigned int", "spk", backend.getDeviceVarPrefix() + "glbSpk"); + addPointerField<Uint32>("spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); + addPointerField<Uint32>("spk", backend.getDeviceVarPrefix() + "glbSpk"); if(getArchetype().isSpikeEventRequired()) { - addPointerField("unsigned int", "spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); - addPointerField("unsigned int", "spkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt"); + addPointerField<Uint32>("spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); + addPointerField<Uint32>("spkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt"); } if(getArchetype().isDelayRequired()) { - addPointerField("unsigned int", "spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); + addPointerField<Uint32>("spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); } if(getArchetype().isSpikeTimeRequired()) { - addPointerField(timePrecision, "sT", backend.getDeviceVarPrefix() + "sT"); + addPointerField(timeType, "sT", backend.getDeviceVarPrefix() + "sT"); } if(getArchetype().isSpikeEventTimeRequired()) { - addPointerField(timePrecision, "seT", backend.getDeviceVarPrefix() + "seT"); + addPointerField(timeType, "seT", backend.getDeviceVarPrefix() + "seT"); } if(getArchetype().isPrevSpikeTimeRequired()) { - addPointerField(timePrecision, "prevST", backend.getDeviceVarPrefix() + "prevST"); + addPointerField(timeType, "prevST", backend.getDeviceVarPrefix() + "prevST"); } if(getArchetype().isPrevSpikeEventTimeRequired()) { - addPointerField(timePrecision, "prevSET", backend.getDeviceVarPrefix() + "prevSET"); + addPointerField(timeType, "prevSET", backend.getDeviceVarPrefix() + "prevSET"); } // If this backend initialises population RNGs on device and this group requires one for simulation if(backend.isPopulationRNGRequired() && getArchetype().isSimRNGRequired() && (!init || backend.isPopulationRNGInitialisedOnDevice())) { - addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); + assert(false); + //addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); } // Loop through variables @@ -238,7 +250,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr for(const auto &var : vars) { // If we're not
initialising or if there is initialization code for this variable if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addPointerField(var.type, var.name, backend.getDeviceVarPrefix() + var.name); + addPointerField(parseNumeric(var.type), var.name, backend.getDeviceVarPrefix() + var.name); } // If we're initializing, add any var init EGPs to structure @@ -278,12 +290,12 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr const SynapseGroupInternal *sg = getSortedArchetypeMergedInSyns().at(i); // Add pointer to insyn - addMergedInSynPointerField(precision, "inSynInSyn", i, backend.getDeviceVarPrefix() + "inSyn"); + addMergedInSynPointerField(scalarType, "inSynInSyn", i, backend.getDeviceVarPrefix() + "inSyn"); // Add pointer to dendritic delay buffer if required if(sg->isDendriticDelayRequired()) { - addMergedInSynPointerField(precision, "denDelayInSyn", i, backend.getDeviceVarPrefix() + "denDelay"); - addMergedInSynPointerField("unsigned int", "denDelayPtrInSyn", i, backend.getScalarAddressPrefix() + "denDelayPtr"); + addMergedInSynPointerField(scalarType, "denDelayInSyn", i, backend.getDeviceVarPrefix() + "denDelay"); + addMergedInSynPointerField(Uint32::getInstance(), "denDelayPtrInSyn", i, backend.getScalarAddressPrefix() + "denDelayPtr"); } // Loop through variables @@ -291,7 +303,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr for(const auto &var : sg->getPSModel()->getVars()) { // Add pointers to state variable if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addMergedInSynPointerField(var.type, var.name + "InSyn", i, backend.getDeviceVarPrefix() + var.name); + addMergedInSynPointerField(parseNumeric(var.type), var.name + "InSyn", i, backend.getDeviceVarPrefix() + var.name); } // If we're generating an initialization structure, also add any heterogeneous parameters, derived parameters or extra global parameters required for initializers @@ -334,7 +346,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr // Loop through merged output synapses with presynaptic output of archetypical neuron group (0) in sorted order for(size_t i = 0; i < getSortedArchetypeMergedPreOutputOutSyns().size(); i++) { // Add pointer to revInSyn - addMergedPreOutputOutSynPointerField(precision, "revInSynOutSyn", i, backend.getDeviceVarPrefix() + "revInSyn"); + addMergedPreOutputOutSynPointerField(scalarType, "revInSynOutSyn", i, backend.getDeviceVarPrefix() + "revInSyn"); } // Loop through current sources to archetypical neuron group in sorted order @@ -346,9 +358,8 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr for(const auto &var : cs->getCurrentSourceModel()->getVars()) { // Add pointers to state variable if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - assert(!Utils::isTypePointer(var.type)); - addField(var.type + "*", var.name + "CS" + std::to_string(i), - [&backend, i, var, this](const NeuronGroupInternal &, size_t groupIndex) + addField(parseNumeric(var.type)->getPointerType(), var.name + "CS" + std::to_string(i), + [&backend, i, var, this](const auto&, size_t groupIndex) { return backend.getDeviceVarPrefix() + var.name + m_SortedCurrentSources.at(groupIndex).at(i)->getName(); }); @@ -475,23 +486,21 @@ bool NeuronGroupMergedBase::isPSMVarInitParamReferenced(size_t childIndex, const return isParamReferenced({varInitSnippet->getCode()}, paramName); } 
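The change running through this commit replaces string type names such as "unsigned int" with pointers into a polymorphic Type hierarchy, so field types can be compared and queried without string parsing. What follows is a minimal, self-contained sketch of that pattern, not GeNN code: Base, Pointer and Uint32 are simplified stand-ins for GeNN's Type namespace, and addField is reduced to appending to a vector.

    #include <iostream>
    #include <string>
    #include <vector>

    namespace Type {
    struct Base {
        virtual ~Base() = default;
        virtual std::string getName() const = 0;
    };

    // Pointer type wrapping a value type, analogous to what getPointerType() returns
    struct Pointer : Base {
        explicit Pointer(const Base *v) : valueType(v) {}
        std::string getName() const override { return valueType->getName() + "*"; }
        const Base *valueType;
    };

    // Numeric type with a singleton instance, mirroring the T::getInstance() calls in the diff
    struct Uint32 : Base {
        std::string getName() const override { return "unsigned int"; }
        static const Uint32 *getInstance() { static Uint32 i; return &i; }
        const Pointer *getPointerType() const { static Pointer p(this); return &p; }
    };
    } // namespace Type

    struct Field {
        const Type::Base *type;
        std::string name;
    };

    // Templated helper in the spirit of the addField<T>(name, ...) overload added above
    template<typename T>
    void addField(std::vector<Field> &fields, const std::string &name) {
        fields.push_back({T::getInstance(), name});
    }

    int main() {
        std::vector<Field> fields;
        addField<Type::Uint32>(fields, "numNeurons");
        fields.push_back({Type::Uint32::getInstance()->getPointerType(), "spkCnt"});
        for (const auto &f : fields) {
            std::cout << f.type->getName() << " " << f.name << "\n";
        }
        // RTTI identifies pointer fields without inspecting a type string
        std::cout << std::boolalpha
                  << (dynamic_cast<const Type::Pointer*>(fields.back().type) != nullptr)
                  << "\n";
        return 0;
    }

Keying fields on singleton instances makes pointer equality a valid type-equality test, and the dynamic_cast above plays the role that string checks like Utils::isTypePointer play in the code being removed.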
//---------------------------------------------------------------------------- -void NeuronGroupMergedBase::addMergedInSynPointerField(const std::string &type, const std::string &name, +void NeuronGroupMergedBase::addMergedInSynPointerField(const Type::NumericBase *type, const std::string &name, size_t archetypeIndex, const std::string &prefix) { - assert(!Utils::isTypePointer(type)); - addField(type + "*", name + std::to_string(archetypeIndex), - [prefix, archetypeIndex, this](const NeuronGroupInternal &, size_t groupIndex) + addField(type->getPointerType(), name + std::to_string(archetypeIndex), + [prefix, archetypeIndex, this](const auto&, size_t groupIndex) { return prefix + m_SortedMergedInSyns.at(groupIndex).at(archetypeIndex)->getFusedPSVarSuffix(); }); } //---------------------------------------------------------------------------- -void NeuronGroupMergedBase::addMergedPreOutputOutSynPointerField(const std::string &type, const std::string &name, - size_t archetypeIndex, const std::string &prefix) +void NeuronGroupMergedBase::addMergedPreOutputOutSynPointerField(const Type::NumericBase *type, const std::string &name, + size_t archetypeIndex, const std::string &prefix) { - assert(!Utils::isTypePointer(type)); - addField(type + "*", name + std::to_string(archetypeIndex), - [prefix, archetypeIndex, this](const NeuronGroupInternal &, size_t groupIndex) + addField(type->getPointerType(), name + std::to_string(archetypeIndex), + [prefix, archetypeIndex, this](const auto&, size_t groupIndex) { return prefix + m_SortedMergedPreOutputOutSyns.at(groupIndex).at(archetypeIndex)->getFusedPreOutputSuffix(); }); @@ -662,6 +671,12 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & Role role, const std::string &archetypeCode, const std::vector> &groups) : GroupMerged(index, precision, groups), m_ArchetypeCode(archetypeCode) { + using namespace Type; + + // **HACK** parse precisions + const NumericBase *scalarType = parseNumeric(precision); + const NumericBase *timeType = parseNumeric(timePrecision); + const bool updateRole = ((role == Role::PresynapticUpdate) || (role == Role::PostsynapticUpdate) || (role == Role::SynapseDynamics)); @@ -669,85 +684,85 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & // If role isn't an init role or weights aren't kernel if(role != Role::Init || !(getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL)) { - addField("unsigned int", "rowStride", - [&backend](const SynapseGroupInternal &sg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(sg)); }); - addField("unsigned int", "numSrcNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); - addField("unsigned int", "numTrgNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); }); + addField("rowStride", + [&backend](const SynapseGroupInternal &sg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(sg)); }); + addField("numSrcNeurons", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); + addField("numTrgNeurons", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); }); } if(role == Role::PostsynapticUpdate || role == Role::SparseInit) { - addField("unsigned int", "colStride", - [](const SynapseGroupInternal &sg, size_t) { return 
std::to_string(sg.getMaxSourceConnections()); }); + addField("colStride", + [](const auto &sg, size_t) { return std::to_string(sg.getMaxSourceConnections()); }); } // If this role is one where postsynaptic input can be provided if(role == Role::PresynapticUpdate || role == Role::SynapseDynamics) { if(getArchetype().isDendriticDelayRequired()) { - addPSPointerField(precision, "denDelay", backend.getDeviceVarPrefix() + "denDelay"); - addPSPointerField("unsigned int", "denDelayPtr", backend.getScalarAddressPrefix() + "denDelayPtr"); + addPSPointerField(scalarType, "denDelay", backend.getDeviceVarPrefix() + "denDelay"); + addPSPointerField(Uint32::getInstance(), "denDelayPtr", backend.getScalarAddressPrefix() + "denDelayPtr"); } else { - addPSPointerField(precision, "inSyn", backend.getDeviceVarPrefix() + "inSyn"); + addPSPointerField(scalarType, "inSyn", backend.getDeviceVarPrefix() + "inSyn"); } } if(role == Role::PresynapticUpdate) { if(getArchetype().isTrueSpikeRequired()) { - addSrcPointerField("unsigned int", "srcSpkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); - addSrcPointerField("unsigned int", "srcSpk", backend.getDeviceVarPrefix() + "glbSpk"); + addSrcPointerField(Uint32::getInstance(), "srcSpkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); + addSrcPointerField(Uint32::getInstance(), "srcSpk", backend.getDeviceVarPrefix() + "glbSpk"); } if(getArchetype().isSpikeEventRequired()) { - addSrcPointerField("unsigned int", "srcSpkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); - addSrcPointerField("unsigned int", "srcSpkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt"); + addSrcPointerField(Uint32::getInstance(), "srcSpkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); + addSrcPointerField(Uint32::getInstance(), "srcSpkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt"); } } else if(role == Role::PostsynapticUpdate) { - addTrgPointerField("unsigned int", "trgSpkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); - addTrgPointerField("unsigned int", "trgSpk", backend.getDeviceVarPrefix() + "glbSpk"); + addTrgPointerField(Uint32::getInstance(), "trgSpkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); + addTrgPointerField(Uint32::getInstance(), "trgSpk", backend.getDeviceVarPrefix() + "glbSpk"); } // If this structure is used for updating rather than initializing if(updateRole) { // for all types of roles if (getArchetype().isPresynapticOutputRequired()) { - addPreOutputPointerField(precision, "revInSyn", backend.getDeviceVarPrefix() + "revInSyn"); + addPreOutputPointerField(scalarType, "revInSyn", backend.getDeviceVarPrefix() + "revInSyn"); } // If presynaptic population has delay buffers if(getArchetype().getSrcNeuronGroup()->isDelayRequired()) { - addSrcPointerField("unsigned int", "srcSpkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); + addSrcPointerField(Uint32::getInstance(), "srcSpkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); } // If postsynaptic population has delay buffers if(getArchetype().getTrgNeuronGroup()->isDelayRequired()) { - addTrgPointerField("unsigned int", "trgSpkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); + addTrgPointerField(Uint32::getInstance(), "trgSpkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); } // Add heterogeneous presynaptic neuron model parameters addHeterogeneousParams( getArchetype().getSrcNeuronGroup()->getNeuronModel()->getParamNames(), "Pre", - [](const SynapseGroupInternal &sg) { return sg.getSrcNeuronGroup()->getParams(); }, + [](const auto &sg) { return 
sg.getSrcNeuronGroup()->getParams(); }, &SynapseGroupMergedBase::isSrcNeuronParamHeterogeneous); // Add heterogeneous presynaptic neuron model derived parameters addHeterogeneousDerivedParams( getArchetype().getSrcNeuronGroup()->getNeuronModel()->getDerivedParams(), "Pre", - [](const SynapseGroupInternal &sg) { return sg.getSrcNeuronGroup()->getDerivedParams(); }, + [](const auto &sg) { return sg.getSrcNeuronGroup()->getDerivedParams(); }, &SynapseGroupMergedBase::isSrcNeuronDerivedParamHeterogeneous); // Add heterogeneous postsynaptic neuron model parameters addHeterogeneousParams( getArchetype().getTrgNeuronGroup()->getNeuronModel()->getParamNames(), "Post", - [](const SynapseGroupInternal &sg) { return sg.getTrgNeuronGroup()->getParams(); }, + [](const auto &sg) { return sg.getTrgNeuronGroup()->getParams(); }, &SynapseGroupMergedBase::isTrgNeuronParamHeterogeneous); // Add heterogeneous postsynaptic neuron model derived parameters addHeterogeneousDerivedParams( getArchetype().getTrgNeuronGroup()->getNeuronModel()->getDerivedParams(), "Post", - [](const SynapseGroupInternal &sg) { return sg.getTrgNeuronGroup()->getDerivedParams(); }, + [](const auto &sg) { return sg.getTrgNeuronGroup()->getDerivedParams(); }, &SynapseGroupMergedBase::isTrgNeuronDerivedParamHeterogeneous); // Get correct code string @@ -758,7 +773,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & for(const auto &v : preVars) { // If variable is referenced in code string, add source pointer if(code.find("$(" + v.name + "_pre)") != std::string::npos) { - addSrcPointerField(v.type, v.name + "Pre", backend.getDeviceVarPrefix() + v.name); + addSrcPointerField(parseNumeric(v.type), v.name + "Pre", backend.getDeviceVarPrefix() + v.name); } } @@ -767,7 +782,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & for(const auto &v : postVars) { // If variable is referenced in code string, add target pointer if(code.find("$(" + v.name + "_post)") != std::string::npos) { - addTrgPointerField(v.type, v.name + "Post", backend.getDeviceVarPrefix() + v.name); + addTrgPointerField(parseNumeric(v.type), v.name + "Post", backend.getDeviceVarPrefix() + v.name); } } @@ -775,9 +790,9 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & const auto preEGPs = getArchetype().getSrcNeuronGroup()->getNeuronModel()->getExtraGlobalParams(); for(const auto &e : preEGPs) { if(code.find("$(" + e.name + "_pre)") != std::string::npos) { - const std::string prefix = Utils::isTypePointer(e.type) ? backend.getDeviceVarPrefix() : ""; - addField(e.type, e.name + "Pre", - [e, prefix](const SynapseGroupInternal &sg, size_t) { return prefix + e.name + sg.getSrcNeuronGroup()->getName(); }, + const std::string prefix = backend.getDeviceVarPrefix(); + addField(parseNumericPtr(e.type), e.name + "Pre", + [e, prefix](const auto &sg, size_t) { return prefix + e.name + sg.getSrcNeuronGroup()->getName(); }, GroupMergedFieldType::DYNAMIC); } } @@ -786,54 +801,56 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & const auto postEGPs = getArchetype().getTrgNeuronGroup()->getNeuronModel()->getExtraGlobalParams(); for(const auto &e : postEGPs) { if(code.find("$(" + e.name + "_post)") != std::string::npos) { - const std::string prefix = Utils::isTypePointer(e.type) ? 
backend.getDeviceVarPrefix() : ""; - addField(e.type, e.name + "Post", - [e, prefix](const SynapseGroupInternal &sg, size_t) { return prefix + e.name + sg.getTrgNeuronGroup()->getName(); }, + const std::string prefix = backend.getDeviceVarPrefix(); + addField(parseNumericPtr(e.type), e.name + "Post", + [e, prefix](const auto &sg, size_t) { return prefix + e.name + sg.getTrgNeuronGroup()->getName(); }, GroupMergedFieldType::DYNAMIC); } } // Add spike times if required if(wum->isPreSpikeTimeRequired()) { - addSrcPointerField(timePrecision, "sTPre", backend.getDeviceVarPrefix() + "sT"); + addSrcPointerField(timeType, "sTPre", backend.getDeviceVarPrefix() + "sT"); } if(wum->isPostSpikeTimeRequired()) { - addTrgPointerField(timePrecision, "sTPost", backend.getDeviceVarPrefix() + "sT"); + addTrgPointerField(timeType, "sTPost", backend.getDeviceVarPrefix() + "sT"); } if(wum->isPreSpikeEventTimeRequired()) { - addSrcPointerField(timePrecision, "seTPre", backend.getDeviceVarPrefix() + "seT"); + addSrcPointerField(timeType, "seTPre", backend.getDeviceVarPrefix() + "seT"); } if(wum->isPrevPreSpikeTimeRequired()) { - addSrcPointerField(timePrecision, "prevSTPre", backend.getDeviceVarPrefix() + "prevST"); + addSrcPointerField(timeType, "prevSTPre", backend.getDeviceVarPrefix() + "prevST"); } if(wum->isPrevPostSpikeTimeRequired()) { - addTrgPointerField(timePrecision, "prevSTPost", backend.getDeviceVarPrefix() + "prevST"); + addTrgPointerField(timeType, "prevSTPost", backend.getDeviceVarPrefix() + "prevST"); } if(wum->isPrevPreSpikeEventTimeRequired()) { - addSrcPointerField(timePrecision, "prevSETPre", backend.getDeviceVarPrefix() + "prevSET"); + addSrcPointerField(timeType, "prevSETPre", backend.getDeviceVarPrefix() + "prevSET"); } // Add heterogeneous weight update model parameters addHeterogeneousParams( wum->getParamNames(), "", - [](const SynapseGroupInternal &sg) { return sg.getWUParams(); }, + [](const auto &sg) { return sg.getWUParams(); }, &SynapseGroupMergedBase::isWUParamHeterogeneous); // Add heterogeneous weight update model derived parameters addHeterogeneousDerivedParams( wum->getDerivedParams(), "", - [](const SynapseGroupInternal &sg) { return sg.getWUDerivedParams(); }, + [](const auto &sg) { return sg.getWUDerivedParams(); }, &SynapseGroupMergedBase::isWUDerivedParamHeterogeneous); // Add presynaptic variables to struct for(const auto &v : wum->getPreVars()) { const std::string prefix = backend.getDeviceVarPrefix() + v.name; - addField(v.type + "*", v.name, [prefix](const SynapseGroupInternal &g, size_t) { return prefix + g.getFusedWUPreVarSuffix(); }); + addField(parseNumeric(v.type)->getPointerType(), v.name, + [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPreVarSuffix(); }); } // Add postsynaptic variables to struct for(const auto &v : wum->getPostVars()) { const std::string prefix = backend.getDeviceVarPrefix() + v.name; - addField(v.type + "*", v.name, [prefix](const SynapseGroupInternal &g, size_t) { return prefix + g.getFusedWUPostVarSuffix(); }); + addField(parseNumeric(v.type)->getPointerType(), v.name, + [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPostVarSuffix(); }); } // Add EGPs to struct @@ -842,19 +859,19 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & // Add pointers to connectivity data if(getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - addPointerField("unsigned int", "rowLength", backend.getDeviceVarPrefix() + "rowLength"); -
addPointerField(getArchetype().getSparseIndType(), "ind", backend.getDeviceVarPrefix() + "ind"); + addPointerField("rowLength", backend.getDeviceVarPrefix() + "rowLength"); + addPointerField(parseNumeric(getArchetype().getSparseIndType()), "ind", backend.getDeviceVarPrefix() + "ind"); // Add additional structure for postsynaptic access if(backend.isPostsynapticRemapRequired() && !wum->getLearnPostCode().empty() && (role == Role::PostsynapticUpdate || role == Role::SparseInit)) { - addPointerField("unsigned int", "colLength", backend.getDeviceVarPrefix() + "colLength"); - addPointerField("unsigned int", "remap", backend.getDeviceVarPrefix() + "remap"); + addPointerField("colLength", backend.getDeviceVarPrefix() + "colLength"); + addPointerField("remap", backend.getDeviceVarPrefix() + "remap"); } } else if(getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { - addPointerField("uint32_t", "gp", backend.getDeviceVarPrefix() + "gp"); + addPointerField("gp", backend.getDeviceVarPrefix() + "gp"); } // If we're updating a group with procedural connectivity or initialising connectivity @@ -927,8 +944,8 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & for(size_t d = 0; d < getArchetype().getKernelSize().size(); d++) { // If this dimension has a heterogeneous size, add it to struct if(isKernelSizeHeterogeneous(d)) { - addField("unsigned int", "kernelSize" + std::to_string(d), - [d](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getKernelSize().at(d)); }); + addField("kernelSize" + std::to_string(d), + [d](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getKernelSize().at(d)); }); } } } @@ -957,15 +974,15 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & // If we're performing an update with individual weights; or this variable should be initialised if((updateRole && individualWeights) || (kernelWeights && updateRole) || varInitRequired) { - addPointerField(var.type, var.name, backend.getDeviceVarPrefix() + var.name); + addPointerField(parseNumeric(var.type), var.name, backend.getDeviceVarPrefix() + var.name); } // If we're performing a procedural update or this variable should be initialised, add any var init EGPs to structure if((proceduralWeights && updateRole) || varInitRequired) { const auto egps = snippet->getExtraGlobalParams(); for(const auto &e : egps) { - const std::string prefix = Utils::isTypePointer(e.type) ? 
backend.getDeviceVarPrefix() : ""; - addField(e.type, e.name + var.name, + const std::string prefix = backend.getDeviceVarPrefix(); + addField(parseNumericPtr(e.type), e.name + var.name, [e, prefix, var](const SynapseGroupInternal &sg, size_t) { return prefix + e.name + var.name + sg.getName(); @@ -1090,28 +1107,24 @@ boost::uuids::detail::sha1::digest_type SynapseGroupMergedBase::getHashDigest(Ro return hash.get_digest(); } //---------------------------------------------------------------------------- -void SynapseGroupMergedBase::addPSPointerField(const std::string &type, const std::string &name, const std::string &prefix) +void SynapseGroupMergedBase::addPSPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) { - assert(!Utils::isTypePointer(type)); - addField(type + "*", name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPSVarSuffix(); }); + addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPSVarSuffix(); }); } //---------------------------------------------------------------------------- -void SynapseGroupMergedBase::addPreOutputPointerField(const std::string &type, const std::string &name, const std::string &prefix) +void SynapseGroupMergedBase::addPreOutputPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) { - assert(!Utils::isTypePointer(type)); - addField(type + "*", name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPreOutputSuffix(); }); + addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPreOutputSuffix(); }); } //---------------------------------------------------------------------------- -void SynapseGroupMergedBase::addSrcPointerField(const std::string &type, const std::string &name, const std::string &prefix) +void SynapseGroupMergedBase::addSrcPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) { - assert(!Utils::isTypePointer(type)); - addField(type + "*", name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getSrcNeuronGroup()->getName(); }); + addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getSrcNeuronGroup()->getName(); }); } //---------------------------------------------------------------------------- -void SynapseGroupMergedBase::addTrgPointerField(const std::string &type, const std::string &name, const std::string &prefix) +void SynapseGroupMergedBase::addTrgPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) { - assert(!Utils::isTypePointer(type)); - addField(type + "*", name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getTrgNeuronGroup()->getName(); }); + addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getTrgNeuronGroup()->getName(); }); } //---------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 9b780ae12b..ab76ef350d 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -430,6 +430,8 @@ void 
NeuronInitGroupMerged::generateWUVar(const BackendBase &backend, bool(NeuronInitGroupMerged::*isDerivedParamHeterogeneousFn)(size_t, const std::string&, const std::string&) const, const std::string&(SynapseGroupInternal::*getFusedVarSuffix)(void) const) { + using namespace Type; + // Loop through synapse groups const auto &archetypeSyns = sortedSyn.front(); for(size_t i = 0; i < archetypeSyns.size(); i++) { @@ -441,9 +443,8 @@ void NeuronInitGroupMerged::generateWUVar(const BackendBase &backend, for(const auto &var : vars) { // Add pointers to state variable if(!varInit.at(var.name).getSnippet()->getCode().empty()) { - assert(!Utils::isTypePointer(var.type)); - addField(var.type + "*", var.name + fieldPrefixStem + std::to_string(i), - [i, var, &backend, &sortedSyn, getFusedVarSuffix](const NeuronGroupInternal &, size_t groupIndex) + addField(parseNumeric(var.type)->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), + [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto&, size_t groupIndex) { const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); return backend.getDeviceVarPrefix() + var.name + varMergeSuffix; @@ -718,13 +719,15 @@ SynapseConnectivityHostInitGroupMerged::SynapseConnectivityHostInitGroupMerged(s const std::vector> &groups) : GroupMerged(index, precision, groups) { + using namespace Type; + // **TODO** these could be generic - addField("unsigned int", "numSrcNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); - addField("unsigned int", "numTrgNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); }); - addField("unsigned int", "rowStride", - [&backend](const SynapseGroupInternal &sg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(sg)); }); + addField("numSrcNeurons", + [](const auto &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); + addField("numTrgNeurons", + [](const auto &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); }); + addField("rowStride", + [&backend](const auto &sg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(sg)); }); // Add heterogeneous connectivity initialiser model parameters addHeterogeneousParams( @@ -741,7 +744,8 @@ SynapseConnectivityHostInitGroupMerged::SynapseConnectivityHostInitGroupMerged(s // Add EGP pointers to struct for both host and device EGPs if they are separate const auto egps = getArchetype().getConnectivityInitialiser().getSnippet()->getExtraGlobalParams(); for(const auto &e : egps) { - addField(e.type + "*", e.name, + assert(false); + /*addField(e.type + "*", e.name, [e](const SynapseGroupInternal &g, size_t) { return "&" + e.name + g.getName(); }, GroupMergedFieldType::HOST_DYNAMIC); @@ -760,7 +764,7 @@ SynapseConnectivityHostInitGroupMerged::SynapseConnectivityHostInitGroupMerged(s return "&" + backend.getHostVarPrefix() + e.name + g.getName(); }, GroupMergedFieldType::DYNAMIC); - } + }*/ } } //------------------------------------------------------------------------- @@ -857,8 +861,8 @@ CustomUpdateInitGroupMerged::CustomUpdateInitGroupMerged(size_t index, const std const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, precision, backend, groups) { - addField("unsigned int", "size", - [](const CustomUpdateInternal &c, size_t) { return std::to_string(c.getSize()); }); + addField("size", + [](const auto &c, size_t) { return
std::to_string(c.getSize()); }); } //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type CustomUpdateInitGroupMerged::getHashDigest() const @@ -892,23 +896,25 @@ CustomWUUpdateInitGroupMerged::CustomWUUpdateInitGroupMerged(size_t index, const const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, precision, backend, groups) { + using namespace Type; + if(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL) { // Loop through kernel size dimensions for (size_t d = 0; d < getArchetype().getSynapseGroup()->getKernelSize().size(); d++) { // If this dimension has a heterogeneous size, add it to struct if (isKernelSizeHeterogeneous(d)) { - addField("unsigned int", "kernelSize" + std::to_string(d), - [d](const CustomUpdateWUInternal &g, size_t) { return std::to_string(g.getSynapseGroup()->getKernelSize().at(d)); }); + addField("kernelSize" + std::to_string(d), + [d](const auto &g, size_t) { return std::to_string(g.getSynapseGroup()->getKernelSize().at(d)); }); } } } else { - addField("unsigned int", "rowStride", - [&backend](const CustomUpdateWUInternal &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); - addField("unsigned int", "numSrcNeurons", - [](const CustomUpdateWUInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); - addField("unsigned int", "numTrgNeurons", - [](const CustomUpdateWUInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); + addField("rowStride", + [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); + addField("numSrcNeurons", + [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); + addField("numTrgNeurons", + [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); } } //---------------------------------------------------------------------------- @@ -921,22 +927,22 @@ boost::uuids::detail::sha1::digest_type CustomWUUpdateInitGroupMerged::getHashDi // If underlying synapse group has kernel weights, update hash with kernel size if(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL) { - updateHash([](const CustomUpdateWUInternal &g) { return g.getSynapseGroup()->getKernelSize(); }, hash); + updateHash([](const auto &g) { return g.getSynapseGroup()->getKernelSize(); }, hash); } // Otherwise, update hash with sizes of pre and postsynaptic neuron groups else { - updateHash([](const CustomUpdateWUInternal &cg) + updateHash([](const auto &cg) { return static_cast(cg.getSynapseGroup())->getSrcNeuronGroup()->getNumNeurons(); }, hash); - updateHash([](const CustomUpdateWUInternal &cg) + updateHash([](const auto &cg) { return static_cast(cg.getSynapseGroup())->getTrgNeuronGroup()->getNumNeurons(); }, hash); - updateHash([](const CustomUpdateWUInternal &cg) + updateHash([](const auto &cg) { return static_cast(cg.getSynapseGroup())->getMaxConnections(); }, hash); @@ -1000,22 +1006,24 @@ CustomWUUpdateSparseInitGroupMerged::CustomWUUpdateSparseInitGroupMerged(size_t const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, precision, backend, groups) { - addField("unsigned int", "rowStride", - [&backend](const CustomUpdateWUInternal &cg, size_t) { return 
std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); + using namespace Type; - addField("unsigned int", "numSrcNeurons", - [](const CustomUpdateWUInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); - addField("unsigned int", "numTrgNeurons", - [](const CustomUpdateWUInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); + addField("rowStride", + [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); - addField("unsigned int*", "rowLength", - [&backend](const CustomUpdateWUInternal &cg, size_t) - { - const SynapseGroupInternal *sg = cg.getSynapseGroup(); - return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); - }); - addField(getArchetype().getSynapseGroup()->getSparseIndType() + "*", "ind", - [&backend](const CustomUpdateWUInternal &cg, size_t) + addField("numSrcNeurons", + [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); + addField("numTrgNeurons", + [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); + + addField("rowLength", + [&backend](const auto &cg, size_t) + { + const SynapseGroupInternal *sg = cg.getSynapseGroup(); + return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); + }); + addField(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())->getPointerType(), "ind", + [&backend](const auto &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); return backend.getDeviceVarPrefix() + "ind" + sg->getName(); @@ -1030,17 +1038,17 @@ boost::uuids::detail::sha1::digest_type CustomWUUpdateSparseInitGroupMerged::get updateBaseHash(hash); // Update hash with sizes of pre and postsynaptic neuron groups; and max row length - updateHash([](const CustomUpdateWUInternal &cg) + updateHash([](const auto &cg) { return static_cast(cg.getSynapseGroup())->getSrcNeuronGroup()->getNumNeurons(); }, hash); - updateHash([](const CustomUpdateWUInternal &cg) + updateHash([](const auto &cg) { return static_cast(cg.getSynapseGroup())->getTrgNeuronGroup()->getNumNeurons(); }, hash); - updateHash([](const CustomUpdateWUInternal& cg) + updateHash([](const auto& cg) { return cg.getSynapseGroup()->getMaxConnections(); }, hash); @@ -1070,15 +1078,18 @@ CustomConnectivityUpdatePreInitGroupMerged::CustomConnectivityUpdatePreInitGroup const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, precision, backend, groups) { - addField("unsigned int", "size", - [](const CustomConnectivityUpdateInternal &c, size_t) - { - return std::to_string(c.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); - }); + using namespace Type; + + addField("size", + [](const auto &c, size_t) + { + return std::to_string(c.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); + }); // If this backend initialises population RNGs on device and this group requires one for simulation if(backend.isPopulationRNGRequired() && getArchetype().isRowSimRNGRequired() && backend.isPopulationRNGInitialisedOnDevice()) { - addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rowRNG"); + assert(false); + //addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rowRNG"); } } //---------------------------------------------------------------------------- @@ -1090,7 +1101,7 @@ 
boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePreInitGroupMerg updateBaseHash(hash); // Update hash with size of custom update - updateHash([](const CustomConnectivityUpdateInternal &cg) + updateHash([](const auto &cg) { return cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(); }, hash); @@ -1116,11 +1127,11 @@ CustomConnectivityUpdatePostInitGroupMerged::CustomConnectivityUpdatePostInitGro const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, precision, backend, groups) { - addField("unsigned int", "size", - [](const CustomConnectivityUpdateInternal &c, size_t) - { - return std::to_string(c.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); - }); + addField("size", + [](const auto &c, size_t) + { + return std::to_string(c.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); + }); } //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePostInitGroupMerged::getHashDigest() const @@ -1131,7 +1142,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePostInitGroupMer updateBaseHash(hash); // Update hash with size of custom update - updateHash([](const CustomConnectivityUpdateInternal &cg) + updateHash([](const auto &cg) { return cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons(); }, hash); @@ -1157,22 +1168,24 @@ CustomConnectivityUpdateSparseInitGroupMerged::CustomConnectivityUpdateSparseIni const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, precision, backend, groups) { - addField("unsigned int", "rowStride", - [&backend](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); - - addField("unsigned int", "numSrcNeurons", - [](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); - addField("unsigned int", "numTrgNeurons", - [](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); - - addField("unsigned int*", "rowLength", - [&backend](const CustomConnectivityUpdateInternal &cg, size_t) - { - const SynapseGroupInternal *sg = cg.getSynapseGroup(); - return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); - }); - addField(getArchetype().getSynapseGroup()->getSparseIndType() + "*", "ind", - [&backend](const CustomConnectivityUpdateInternal &cg, size_t) + using namespace Type; + + addField("rowStride", + [&backend](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); + + addField("numSrcNeurons", + [](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); + addField("numTrgNeurons", + [](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); + + addField("rowLength", + [&backend](const CustomConnectivityUpdateInternal &cg, size_t) + { + const SynapseGroupInternal *sg = cg.getSynapseGroup(); + return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); + }); + addField(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())->getPointerType(), "ind", + [&backend](const auto &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); return backend.getDeviceVarPrefix() + 
"ind" + sg->getName(); diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index 63358d44b3..0e4d7ab788 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -568,10 +568,11 @@ bool ModelSpecMerged::anyPointerEGPs() const // Loop through grouped merged EGPs for(const auto &e : m_MergedEGPs) { // If there's any pointer EGPs, return true + // **TODO** without scalar EGPS, all EGPS are pointer EGPS! if(std::any_of(e.second.cbegin(), e.second.cend(), [](const MergedEGPDestinations::value_type &g) { - return Utils::isTypePointer(g.second.type); + return dynamic_cast(g.second.type); })) { return true; diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index eca0d61207..2825d4b6cb 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -15,6 +15,8 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string const std::vector> &groups) : NeuronGroupMergedBase(index, precision, timePrecision, backend, false, groups) { + using namespace Type; + // Build vector of vectors containing each child group's incoming synapse groups // with postsynaptic updates, ordered to match those of the archetype group orderNeuronGroupChildren(m_SortedInSynWithPostCode, &NeuronGroupInternal::getFusedInSynWithPostCode, @@ -63,9 +65,9 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string for(const auto &egp : sgEGPs) { // If EGP is referenced in event threshold code if(s.eventThresholdCode.find("$(" + egp.name + ")") != std::string::npos) { - const std::string prefix = Utils::isTypePointer(egp.type) ? 
backend.getDeviceVarPrefix() : ""; - addField(egp.type, egp.name + "EventThresh" + std::to_string(i), - [eventThresholdSGs, prefix, egp, i](const NeuronGroupInternal &, size_t groupIndex) + const std::string prefix = backend.getDeviceVarPrefix(); + addField(parseNumericPtr(egp.type), egp.name + "EventThresh" + std::to_string(i), + [eventThresholdSGs, prefix, egp, i](const auto &, size_t groupIndex) { return prefix + egp.name + eventThresholdSGs.at(groupIndex).at(i)->getName(); }, @@ -78,8 +80,8 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string for(const auto &var : sgPreVars) { // If variable is referenced in event threshold code if(s.eventThresholdCode.find("$(" + var.name + ")") != std::string::npos) { - addField(var.type + "*", var.name + "EventThresh" + std::to_string(i), - [&backend, eventThresholdSGs, var, i](const NeuronGroupInternal &, size_t groupIndex) + addField(parseNumeric(var.type)->getPointerType(), var.name + "EventThresh" + std::to_string(i), + [&backend, eventThresholdSGs, var, i](const auto&, size_t groupIndex) { return backend.getDeviceVarPrefix() + var.name + eventThresholdSGs.at(groupIndex).at(i)->getName(); }); @@ -91,24 +93,22 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string if(getArchetype().isSpikeRecordingEnabled()) { // Add field for spike recording - // **YUCK** this mechanism needs to be renamed from PointerEGP to RuntimeAlloc - addField("uint32_t*", "recordSpk", - [&backend](const NeuronGroupInternal &ng, size_t) - { - return backend.getDeviceVarPrefix() + "recordSpk" + ng.getName(); - }, - GroupMergedFieldType::DYNAMIC); + addField("recordSpk", + [&backend](const auto &ng, size_t) + { + return backend.getDeviceVarPrefix() + "recordSpk" + ng.getName(); + }, + GroupMergedFieldType::DYNAMIC); } if(getArchetype().isSpikeEventRecordingEnabled()) { // Add field for spike event recording - // **YUCK** this mechanism needs to be renamed from PointerEGP to RuntimeAlloc - addField("uint32_t*", "recordSpkEvent", - [&backend](const NeuronGroupInternal &ng, size_t) - { - return backend.getDeviceVarPrefix() + "recordSpkEvent" + ng.getName(); - }, - GroupMergedFieldType::DYNAMIC); + addField("recordSpkEvent", + [&backend](const auto &ng, size_t) + { + return backend.getDeviceVarPrefix() + "recordSpkEvent" + ng.getName(); + }, + GroupMergedFieldType::DYNAMIC); } } @@ -716,9 +716,8 @@ void NeuronUpdateGroupMerged::generateWUVar(const BackendBase &backend, const s for(size_t v = 0; v < vars.size(); v++) { // Add pointers to state variable const auto var = vars[v]; - assert(!Utils::isTypePointer(var.type)); - addField(var.type + "*", var.name + fieldPrefixStem + std::to_string(i), - [i, var, &backend, &sortedSyn, getFusedVarSuffix](const NeuronGroupInternal &, size_t groupIndex) + addField(Type::parseNumeric(var.type)->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), + [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto &, size_t groupIndex) { const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); return backend.getDeviceVarPrefix() + var.name + varMergeSuffix; diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 936f004b4c..9689dbf193 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -325,9 +325,9 @@ SynapseDendriticDelayUpdateGroupMerged::SynapseDendriticDelayUpdateGroupMerged(s 
const std::vector> &groups) : GroupMerged(index, precision, groups) { - addField("unsigned int*", "denDelayPtr", - [&backend](const SynapseGroupInternal &sg, size_t) - { - return backend.getScalarAddressPrefix() + "denDelayPtr" + sg.getFusedPSVarSuffix(); - }); + addField("denDelayPtr", + [&backend](const SynapseGroupInternal &sg, size_t) + { + return backend.getScalarAddressPrefix() + "denDelayPtr" + sg.getFusedPSVarSuffix(); + }); } From 38b5a66412a1c44d62f8c65ad6b1682100bc4ee0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 13 Jan 2023 16:58:14 +0000 Subject: [PATCH 039/725] fixed missing includes --- .../genn/code_generator/groupMergedTypeEnvironment.h | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index 6b9dabb907..95850a2548 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -7,6 +7,7 @@ #include "code_generator/groupMerged.h" // GeNN transpiler includes +#include "transpiler/errorHandler.h" #include "transpiler/typeChecker.h" //---------------------------------------------------------------------------- @@ -20,6 +21,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa using Token = Transpiler::Token; using ErrorHandlerBase = Transpiler::ErrorHandlerBase; using EnvironmentBase = Transpiler::TypeChecker::EnvironmentBase; + using TypeCheckError = Transpiler::TypeChecker::TypeCheckError; public: GroupMergedTypeEnvironment(G &groupMerged, const Type::NumericBase *scalarType, @@ -34,7 +36,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa virtual void define(const Transpiler::Token &name, const Type::QualifiedType &, ErrorHandlerBase &errorHandler) final { errorHandler.error(name, "Cannot declare variable in external environment"); - throw TypeChecker::TypeCheckError(); + throw TypeCheckError(); } virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, @@ -48,7 +50,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa } else { errorHandler.error(name, "Undefined variable"); - throw TypeChecker::TypeCheckError(); + throw TypeCheckError(); } } @@ -68,7 +70,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa } else { errorHandler.error(name, "Undefined variable"); - throw TypeChecker::TypeCheckError(); + throw TypeCheckError(); } } @@ -88,7 +90,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa } else { errorHandler.error(name, "Undefined variable"); - throw TypeChecker::TypeCheckError(); + throw TypeCheckError(); } } else { @@ -249,4 +251,4 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa std::unordered_map>> m_Types; }; -} // namespace GeNN::CodeGenerator \ No newline at end of file +} // namespace GeNN::CodeGenerator From ad06aafe5a967d74faa21976d30c384d193db935 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 13 Jan 2023 16:58:30 +0000 Subject: [PATCH 040/725] removed scalar EGP pushing and pulling from derived backends --- include/genn/backends/cuda/backend.h | 12 +++------ include/genn/backends/opencl/backend.h | 12 +++------ .../backends/single_threaded_cpu/backend.h | 12 +++------ src/genn/backends/cuda/backend.cc | 27 +++---------------- 
src/genn/backends/opencl/backend.cc | 27 +++---------------- .../backends/single_threaded_cpu/backend.cc | 27 +++---------------- 6 files changed, 24 insertions(+), 93 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index 5044e1cd7e..02b67358ac 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -171,17 +171,13 @@ class BACKEND_EXPORT Backend : public BackendSIMT //-------------------------------------------------------------------------- // CodeGenerator::BackendBase virtuals //-------------------------------------------------------------------------- - virtual void genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const override; + virtual void genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; - virtual void genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const override; + virtual void genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; - virtual void genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const override; + virtual void genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; - virtual void genInit(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler initPushEGPHandler, HostHandler initSparsePushEGPHandler) const override; + virtual void genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; virtual void genDefinitionsPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const override; virtual void genDefinitionsInternalPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const override; diff --git a/include/genn/backends/opencl/backend.h b/include/genn/backends/opencl/backend.h index 8c7629820f..55b89c7e39 100644 --- a/include/genn/backends/opencl/backend.h +++ b/include/genn/backends/opencl/backend.h @@ -134,17 +134,13 @@ class BACKEND_EXPORT Backend : public BackendSIMT //-------------------------------------------------------------------------- // CodeGenerator::BackendBase:: virtuals //-------------------------------------------------------------------------- - virtual void genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const override; + virtual void genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; - virtual void genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const override; + virtual void genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; - virtual void genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const override; + virtual void genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; - virtual void genInit(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler initPushEGPHandler, HostHandler 
initSparsePushEGPHandler) const override; + virtual void genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; virtual void genDefinitionsPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const override; virtual void genDefinitionsInternalPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const override; diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 27c3fe8927..c19db36852 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -41,17 +41,13 @@ class BACKEND_EXPORT Backend : public BackendBase //-------------------------------------------------------------------------- // CodeGenerator::BackendBase virtuals //-------------------------------------------------------------------------- - virtual void genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const override; + virtual void genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; - virtual void genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const override; + virtual void genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; - virtual void genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const override; + virtual void genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; - virtual void genInit(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler initPushEGPHandler, HostHandler initSparsePushEGPHandler) const override; + virtual void genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; virtual size_t getSynapticMatrixRowStride(const SynapseGroupInternal &sg) const override; diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 43b8259711..a5c0c32411 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -399,8 +399,7 @@ void Backend::genGlobalRNGSkipAhead(CodeStream &os, Substitutions &subs, const s subs.addVarSubstitution(name, "&localRNG"); } //-------------------------------------------------------------------------- -void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const +void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); @@ -492,9 +491,6 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged { CodeStream::Scope b(os); - // Push any required EGPS - pushEGPHandler(os); - if(idNeuronPrevSpikeTimeUpdate > 0) { CodeStream::Scope b(os); genKernelDimensions(os, KernelNeuronPrevSpikeTimeUpdate, idNeuronPrevSpikeTimeUpdate, model.getBatchSize()); @@ -523,8 +519,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged } } //-------------------------------------------------------------------------- -void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler 
preambleHandler, HostHandler pushEGPHandler) const +void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { // Generate struct definitions modelMerged.genMergedSynapseDendriticDelayUpdateStructs(os, *this); @@ -637,9 +632,6 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge { CodeStream::Scope b(os); - // Push any required EGPs - pushEGPHandler(os); - // Launch pre-synapse reset kernel if required if(idSynapseDendricDelayUpdate > 0) { CodeStream::Scope b(os); @@ -680,8 +672,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge } } //-------------------------------------------------------------------------- -void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const +void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); @@ -801,9 +792,6 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged } } - // Push any required EGPs - pushEGPHandler(os); - // Launch custom update kernel if required if(idCustomUpdateStart > 0) { CodeStream::Scope b(os); @@ -868,8 +856,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged } } //-------------------------------------------------------------------------- -void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler initPushEGPHandler, HostHandler initSparsePushEGPHandler) const +void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { os << "#include " << std::endl; os << "#include " << std::endl; @@ -1029,9 +1016,6 @@ void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, } } - // Push any required EGPs - initPushEGPHandler(os); - // If there are any initialisation threads if(idInitStart > 0) { CodeStream::Scope b(os); @@ -1049,9 +1033,6 @@ void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, { CodeStream::Scope b(os); - // Push any required EGPs - initSparsePushEGPHandler(os); - // Copy all uninitialised state variables to device if(!getPreferences().automaticCopy) { os << "copyStateToDevice(true);" << std::endl; diff --git a/src/genn/backends/opencl/backend.cc b/src/genn/backends/opencl/backend.cc index ca11216228..933b025ee4 100644 --- a/src/genn/backends/opencl/backend.cc +++ b/src/genn/backends/opencl/backend.cc @@ -267,8 +267,7 @@ void Backend::genGlobalRNGSkipAhead(CodeStream &os, Substitutions &subs, const s subs.addVarSubstitution(name, "&localStream"); } //-------------------------------------------------------------------------- -void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const +void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { // Generate reset kernel to be run before the neuron kernel const ModelSpecInternal &model = modelMerged.getModel(); @@ -484,9 +483,6 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged { CodeStream::Scope b(os); - // Push any required EGPS - pushEGPHandler(os); - if (idNeuronPrevSpikeTimeUpdate > 0) { CodeStream::Scope b(os); os << "CHECK_OPENCL_ERRORS(" << 
KernelNames[KernelNeuronPrevSpikeTimeUpdate] << ".setArg(" << modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups().size() << ", t));" << std::endl; @@ -517,8 +513,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged } } //-------------------------------------------------------------------------- -void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const +void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { // Generate reset kernel to be run before the neuron kernel const ModelSpecInternal &model = modelMerged.getModel(); @@ -776,9 +771,6 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge { CodeStream::Scope b(os); - // Push any required EGPs - pushEGPHandler(os); - // Launch pre-synapse reset kernel if required if (idSynapseDendriticDelayUpdate > 0) { CodeStream::Scope b(os); @@ -827,8 +819,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge } } //-------------------------------------------------------------------------- -void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const +void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); @@ -991,9 +982,6 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged } } - // Push any required EGPs - pushEGPHandler(os); - // If there are any custom update work-items if(g.second.first > 0) { CodeStream::Scope b(os); @@ -1099,8 +1087,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged } } //-------------------------------------------------------------------------- -void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler initPushEGPHandler, HostHandler initSparsePushEGPHandler) const +void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { // Generate reset kernel to be run before the neuron kernel const ModelSpecInternal &model = modelMerged.getModel(); @@ -1353,9 +1340,6 @@ void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, } os << std::endl; - // Push any required EGPs - initPushEGPHandler(os); - // If there are any initialisation work-items if (idInitStart > 0) { CodeStream::Scope b(os); @@ -1389,9 +1373,6 @@ void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, { CodeStream::Scope b(os); - // Push any required EGPs - initSparsePushEGPHandler(os); - // Copy all uninitialised state variables to device os << "copyStateToDevice(true);" << std::endl; os << "copyConnectivityToDevice(true);" << std::endl; diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index e2702088c2..d15e1765e2 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -121,8 +121,7 @@ void genKernelIteration(CodeStream &os, const G &g, size_t numKernelDims, const //-------------------------------------------------------------------------- namespace GeNN::CodeGenerator::SingleThreadedCPU { -void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler 
preambleHandler, HostHandler pushEGPHandler) const +void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); if(model.getBatchSize() != 1) { @@ -154,9 +153,6 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged funcSubs.addVarSubstitution("t", "t"); funcSubs.addVarSubstitution("batch", "0"); - // Push any required EGPs - pushEGPHandler(os); - Timer t(os, "neuronUpdate", model.isTimingEnabled()); // Loop through merged previous spike time update groups @@ -293,8 +289,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged } } //-------------------------------------------------------------------------- -void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const +void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); if(model.getBatchSize() != 1) { @@ -325,9 +320,6 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge funcSubs.addVarSubstitution("t", "t"); funcSubs.addVarSubstitution("batch", "0"); - // Push any required EGPs - pushEGPHandler(os); - // Synapse dynamics { // Loop through merged synapse dynamics groups @@ -496,8 +488,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge } } //-------------------------------------------------------------------------- -void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler pushEGPHandler) const +void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); @@ -545,9 +536,6 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged } } - // Push any required EGPs - pushEGPHandler(os); - { Timer t(os, "customUpdate" + g, model.isTimingEnabled()); @@ -782,8 +770,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged } } //-------------------------------------------------------------------------- -void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, - HostHandler preambleHandler, HostHandler initPushEGPHandler, HostHandler initSparsePushEGPHandler) const +void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); if(model.getBatchSize() != 1) { @@ -822,9 +809,6 @@ void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, CodeStream::Scope b(os); Substitutions funcSubs(getFunctionTemplates(model.getPrecision())); - // Push any required EGPs - initPushEGPHandler(os); - Timer t(os, "init", model.isTimingEnabled()); // If model requires a host RNG, add RNG to substitutions @@ -1080,9 +1064,6 @@ void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, CodeStream::Scope b(os); Substitutions funcSubs(getFunctionTemplates(model.getPrecision())); - // Push any required EGPs - initSparsePushEGPHandler(os); - Timer t(os, "initSparse", model.isTimingEnabled()); // If model requires RNG, add it to substitutions From f041c7bd89eb12dcb6d799d5733994f6dead03b1 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 13 Jan 
2023 17:03:08 +0000 Subject: [PATCH 041/725] improved type parsing errors --- src/genn/genn/transpiler/parser.cc | 4 ++-- src/genn/genn/type.cc | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 5e093a90c6..163a1a164e 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -871,7 +871,7 @@ const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPo ParserState parserState(tokens, errorHandler); bool pointerFound = false; std::set typeSpecifiers; - do { + while(parserState.match({Token::Type::TYPE_SPECIFIER, Token::Type::STAR})) { // If token is a star, set pointer found flag if(parserState.previous().type == Token::Type::STAR) { if (!allowPointers) { @@ -888,7 +888,7 @@ const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPo parserState.error(parserState.previous(), "duplicate type specifier"); } } - } while(parserState.match({Token::Type::TYPE_SPECIFIER, Token::Type::STAR})); + }; // Lookup type based on whether token was found return (pointerFound diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index a70b8f6c49..8b50645632 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -85,12 +85,12 @@ const NumericBase *parseNumeric(std::string_view typeString) // If an error was encountered while scanning or parsing, throw exception if (errorHandler.hasError()) { - throw std::runtime_error("Error parsing type"); + throw std::runtime_error("Error parsing type '" + std::string{typeString} + "'"); } // If tokens did not contain a valid numeric type, throw exception if (!type) { - throw std::runtime_error("Unable to parse type"); + throw std::runtime_error("Unable to parse type '" + std::string{typeString} + "'"); } return type; } @@ -108,12 +108,12 @@ const NumericPtrBase *parseNumericPtr(std::string_view typeString) // If an error was encountered while scanning or parsing, throw exception if (errorHandler.hasError()) { - throw std::runtime_error("Error parsing type"); + throw std::runtime_error("Error parsing type '" + std::string{typeString} + "'"); } // If tokens did not contain a valid numeric type, throw exception if (!type) { - throw std::runtime_error("Unable to parse type"); + throw std::runtime_error("Unable to parse type '" + std::string{typeString} + "'"); } return type; } @@ -192,4 +192,4 @@ const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b) } } } -} // namespace GeNN::Type \ No newline at end of file +} // namespace GeNN::Type From bddddeeed205756c3c9a61780010cf3b3744c4d6 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 16 Jan 2023 10:10:31 +0000 Subject: [PATCH 042/725] type system update part 1 (broken) - pointer types now created dynamically --- include/genn/genn/type.h | 128 +++++++----------- .../customConnectivityUpdateGroupMerged.cc | 2 +- src/genn/genn/transpiler/parser.cc | 34 +++-- src/genn/genn/transpiler/typeChecker.cc | 54 ++++---- src/genn/genn/type.cc | 36 +---- 5 files changed, 108 insertions(+), 146 deletions(-) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index ddfe7d414a..6a133189e6 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -12,12 +12,15 @@ #include #include +// GeNN includes +#include "gennExport.h" + //---------------------------------------------------------------------------- // Macros //---------------------------------------------------------------------------- #define 
DECLARE_TYPE(TYPE) \ private: \ - /*GENN_EXPORT*/ static TYPE *s_Instance; \ + GENN_EXPORT static TYPE *s_Instance; \ public: \ static const TYPE *getInstance() \ { \ @@ -34,20 +37,11 @@ DECLARE_TYPE(TYPE) \ virtual std::string getTypeName() const{ return #UNDERLYING_TYPE; } \ }; \ - class TYPE##Ptr : public NumericPtr \ - { \ - DECLARE_TYPE(TYPE##Ptr) \ - }; \ template<> \ struct TypeTraits \ { \ using NumericType = TYPE; \ - }; \ - template<> \ - struct TypeTraits \ - { \ - using NumericPtrType = TYPE##Ptr; \ - } + } #define DECLARE_FOREIGN_FUNCTION_TYPE(TYPE, RETURN_TYPE, ...) \ class TYPE : public ForeignFunction \ @@ -56,7 +50,7 @@ } #define IMPLEMENT_TYPE(TYPE) TYPE *TYPE::s_Instance = NULL -#define IMPLEMENT_NUMERIC_TYPE(TYPE) IMPLEMENT_TYPE(TYPE); IMPLEMENT_TYPE(TYPE##Ptr) +#define IMPLEMENT_NUMERIC_TYPE(TYPE) IMPLEMENT_TYPE(TYPE) //---------------------------------------------------------------------------- // GeNN::Type::TypeTraits @@ -69,6 +63,19 @@ struct TypeTraits { }; +//---------------------------------------------------------------------------- +// GeNN::Type::Qualifier +//---------------------------------------------------------------------------- +enum class Qualifier : unsigned int +{ + CONSTT = (1 << 0) +}; + +inline bool operator & (Qualifier a, Qualifier b) +{ + return (static_cast(a) & static_cast(b)) != 0; +} + //---------------------------------------------------------------------------- // GeNN::Type::Base //---------------------------------------------------------------------------- @@ -80,7 +87,35 @@ class Base // Declared virtuals //------------------------------------------------------------------------ virtual std::string getTypeName() const = 0; - virtual size_t getTypeHash() const = 0; +}; + +//---------------------------------------------------------------------------- +// GeNN::Type::Pointer +//---------------------------------------------------------------------------- +//! 
Type representing a pointer +class Pointer : public Base +{ +public: + Pointer(const Base *valueType) + : m_ValueType(valueType) + { + } + + //------------------------------------------------------------------------ + // Base virtuals + //------------------------------------------------------------------------ + virtual std::string getTypeName() const{ return getValueType()->getTypeName() + "*";} + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + const Base *getValueType() const{ return m_ValueType; } + +private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + const Base *m_ValueType; }; //---------------------------------------------------------------------------- @@ -114,20 +149,6 @@ class NumericBase : public Base virtual double getLowest() const = 0; virtual bool isSigned() const = 0; virtual bool isIntegral() const = 0; - - virtual const class NumericPtrBase *getPointerType() const = 0; -}; - -//---------------------------------------------------------------------------- -// GeNN::NumericPtrBase -//---------------------------------------------------------------------------- -class NumericPtrBase : public Base -{ -public: - //------------------------------------------------------------------------ - // Declared virtuals - //------------------------------------------------------------------------ - virtual const NumericBase *getValueType() const = 0; }; //---------------------------------------------------------------------------- @@ -156,30 +177,6 @@ class Numeric : public NumericBase virtual double getLowest() const final { return std::numeric_limits::lowest(); } virtual bool isSigned() const final { return std::is_signed::value; } virtual bool isIntegral() const final { return std::is_integral::value; } - - virtual const NumericPtrBase *getPointerType() const - { - return TypeTraits>::NumericPtrType::getInstance(); - } -}; - -//---------------------------------------------------------------------------- -// GeNN::NumericPtr -//---------------------------------------------------------------------------- -template -class NumericPtr : public NumericPtrBase -{ -public: - //------------------------------------------------------------------------ - // Base virtuals - //------------------------------------------------------------------------ - virtual std::string getTypeName() const final { return T::getInstance()->getTypeName() + "*"; } - virtual size_t getTypeHash() const final { return typeid(std::add_pointer_t).hash_code(); } - - //------------------------------------------------------------------------ - // NumericArrayBase virtuals - //------------------------------------------------------------------------ - virtual const NumericBase *getValueType() const final { return T::getInstance(); } -}; //---------------------------------------------------------------------------- @@ -192,7 +189,6 @@ class ForeignFunctionBase : public Base // Base virtuals //------------------------------------------------------------------------ virtual std::string getTypeName() const = 0; - virtual size_t getTypeHash() const = 0; //------------------------------------------------------------------------ // Declared virtuals @@ -219,14 +215,6 @@ class ForeignFunction : public ForeignFunctionBase return typeName; } - virtual size_t getTypeHash() const final { - // Start with
seed of return type hash - size_t seed = getReturnType()->getTypeHash(); - updateTypeHash(seed); - return seed; - } - //------------------------------------------------------------------------ // ForeignFunctionBase virtuals //------------------------------------------------------------------------ @@ -247,18 +235,6 @@ class ForeignFunction : public ForeignFunctionBase //------------------------------------------------------------------------ // Private methods //------------------------------------------------------------------------ - template - static void updateTypeHash(size_t &seed) - { - // Combine hashes with argument type - // **NOTE** this is the boost::hash_combine algorithm - seed ^= T::getInstance()->getTypeHash() + 0x9e3779b9 + (seed << 6) + (seed >> 2); - - // If there are more arguments left in pack, recurse - if constexpr (sizeof...(Args)) { - updateTypeHash(seed); - } - } template static void updateTypeName(std::string &typeName) @@ -308,18 +284,14 @@ DECLARE_NUMERIC_TYPE(Double, double, 60); DECLARE_FOREIGN_FUNCTION_TYPE(Exp, Double, Double); DECLARE_FOREIGN_FUNCTION_TYPE(Sqrt, Double, Double); +const Pointer *createPointer(const Base *valueType); + //! Parse a numeric type const NumericBase *parseNumeric(std::string_view typeString); -//! Parse a numeric pointer type -const NumericPtrBase *parseNumericPtr(std::string_view typeString); - //! Look up numeric type based on set of type specifiers const NumericBase *getNumericType(const std::set &typeSpecifiers); -//! Look up numeric pointer type based on set of type specifiers -const NumericPtrBase *getNumericPtrType(const std::set &typeSpecifiers); - //! Apply C type promotion rules to numeric type const NumericBase *getPromotedType(const NumericBase *type); diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index f04fa716c5..a0e763a868 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -167,7 +167,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t // Loop through sorted dependent variables for(size_t i = 0; i < getSortedArchetypeDependentVars().size(); i++) { - addField(parseNumeric(getSortedArchetypeDependentVars().at(i).getVar().type)->getPointerType(), "_dependentVar" + std::to_string(i), + addField(createPointer(parseNumeric(getSortedArchetypeDependentVars().at(i).getVar().type)), "_dependentVar" + std::to_string(i), [i, &backend, this](const auto&, size_t g) { const auto &varRef = m_SortedDependentVars[g][i]; diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 163a1a164e..717451a13e 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -212,15 +212,20 @@ GeNN::Type::QualifiedType parseDeclarationSpecifiers(ParserState &parserState) } } while(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER, Token::Type::STAR})); - // Lookup type based on whether token was found - const GeNN::Type::Base *type = (pointerFound - ? 
static_cast(GeNN::Type::getNumericPtrType(typeSpecifiers)) - : static_cast(GeNN::Type::getNumericType(typeSpecifiers))); - - // Return qualified type + // Lookup numeric type + const auto *numericType = GeNN::Type::getNumericType(typeSpecifiers); + + // If pointer, return pointer to numeric type // **THINK** this relies of const being only qualifier // **TODO** warn of duplicate type qualifiers - return GeNN::Type::QualifiedType{type, !valueTypeQualifiers.empty(), !pointerTypeQualifiers.empty()}; + if (pointerFound) { + return GeNN::Type::QualifiedType{GeNN::Type::createPointer(numericType), + !valueTypeQualifiers.empty(), !pointerTypeQualifiers.empty()}; + } + // Otherwise, return numeric type directly + else { + return GeNN::Type::QualifiedType{numericType, !valueTypeQualifiers.empty(), !pointerTypeQualifiers.empty()}; + } } Expression::ExpressionPtr parsePrimary(ParserState &parserState) @@ -890,9 +895,16 @@ const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPo } }; - // Lookup type based on whether token was found - return (pointerFound - ? static_cast(GeNN::Type::getNumericPtrType(typeSpecifiers)) - : static_cast(GeNN::Type::getNumericType(typeSpecifiers))); + // Lookup numeric type + const auto *numericType = GeNN::Type::getNumericType(typeSpecifiers); + + // If pointer, return pointer to numeric type + if (pointerFound) { + return GeNN::Type::createPointer(numericType); + } + // Otherwise, return numeric type directly + else { + return numericType; + } } } diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 6ea97d2721..b2806f0e2e 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -128,7 +128,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { // Get pointer type auto arrayType = m_Environment->getType(arraySubscript.getPointerName(), m_ErrorHandler); - auto pointerType = dynamic_cast(arrayType.type); + auto pointerType = dynamic_cast(arrayType.type); // If pointer is indeed a pointer if (pointerType) { @@ -169,11 +169,11 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const auto leftType = evaluateType(binary.getLeft()); auto leftNumericType = dynamic_cast(leftType.type); auto rightNumericType = dynamic_cast(rightType.type); - auto leftNumericPtrType = dynamic_cast(leftType.type); - auto rightNumericPtrType = dynamic_cast(rightType.type); - if (leftNumericPtrType && rightNumericPtrType && opType == Token::Type::MINUS) { + auto leftPointerType = dynamic_cast(leftType.type); + auto rightPointerType = dynamic_cast(rightType.type); + if (leftPointerType && rightPointerType && opType == Token::Type::MINUS) { // Check pointers are compatible - if (leftNumericPtrType->getTypeHash() != rightNumericPtrType->getTypeHash()) { + if (leftPointerType->getTypeName() != rightPointerType->getTypeName()) { m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getTypeName() + "' and '" + rightType.type->getTypeName()); throw TypeCheckError(); } @@ -182,7 +182,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor m_QualifiedType = Type::QualifiedType{Type::Int32::getInstance(), false, false}; } // Otherwise, if we're adding to or subtracting from pointers - else if (leftNumericPtrType && rightNumericType && (opType == Token::Type::PLUS || opType == Token::Type::MINUS)) // P + n or P - n + else if (leftPointerType && rightNumericType && (opType == Token::Type::PLUS || opType == 
Token::Type::MINUS)) // P + n or P - n { // Check that numeric operand is integer if (!rightNumericType->isIntegral()) { @@ -194,7 +194,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor m_QualifiedType = leftType; } // Otherwise, if we're adding a number to a pointer - else if (leftNumericType && rightNumericPtrType && opType == Token::Type::PLUS) // n + P + else if (leftNumericType && rightPointerType && opType == Token::Type::PLUS) // n + P { // Check that numeric operand is integer if (!leftNumericType->isIntegral()) { @@ -294,11 +294,11 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If we're trying to cast pointer to pointer auto rightNumericType = dynamic_cast(rightType.type); - auto rightNumericPtrType = dynamic_cast(rightType.type); + auto rightPointerType = dynamic_cast(rightType.type); auto leftNumericType = dynamic_cast(cast.getQualifiedType().type); - auto leftNumericPtrType = dynamic_cast(cast.getQualifiedType().type); - if (rightNumericPtrType && leftNumericPtrType) { - if (rightNumericPtrType->getTypeHash() != leftNumericPtrType->getTypeHash()) { + auto leftPointerType = dynamic_cast(cast.getQualifiedType().type); + if (rightPointerType && leftPointerType) { + if (rightPointerType->getTypeName() != leftPointerType->getTypeName()) { m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getTypeName() + "' and '" + rightType.type->getTypeName()); throw TypeCheckError(); } @@ -377,15 +377,15 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If operator is pointer de-reference if (unary.getOperator().type == Token::Type::STAR) { - auto rightNumericPtrType = dynamic_cast(rightType.type); - if (!rightNumericPtrType) { + auto rightPointerType = dynamic_cast(rightType.type); + if (!rightPointerType) { m_ErrorHandler.error(unary.getOperator(), "Invalid operand type '" + rightType.type->getTypeName() + "'"); throw TypeCheckError(); } // Return value type - m_QualifiedType = Type::QualifiedType{rightNumericPtrType->getValueType(), rightType.constValue, false}; + m_QualifiedType = Type::QualifiedType{rightPointerType->getValueType(), rightType.constValue, false}; } // Otherwise else { @@ -416,7 +416,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // Otherwise, if operator is address of, return pointer type else if (unary.getOperator().type == Token::Type::AMPERSAND) { - m_QualifiedType = Type::QualifiedType{rightNumericType->getPointerType(), + m_QualifiedType = Type::QualifiedType{createPointer(rightType), rightType.constValue, false}; } } @@ -595,9 +595,9 @@ const Type::QualifiedType &EnvironmentBase::assign(const Token &name, Token::Typ { // If existing type is a constant numeric value or if it's a constant pointer give errors auto numericExistingType = dynamic_cast(existingType.type); - auto numericPtrExistingType = dynamic_cast(existingType.type); + auto pointerExistingType = dynamic_cast(existingType.type); if(!initializer && ((numericExistingType && existingType.constValue) - || (numericPtrExistingType && existingType.constPointer))) + || (pointerExistingType && existingType.constPointer))) { errorHandler.error(name, "Assignment of read-only variable"); throw TypeCheckError(); @@ -605,24 +605,24 @@ const Type::QualifiedType &EnvironmentBase::assign(const Token &name, Token::Typ // If assignment operation is plain equals, any type is fine so return auto numericAssignedType = dynamic_cast(assignedType.type); - auto 
numericPtrAssignedType = dynamic_cast(assignedType.type);
+    auto pointerAssignedType = dynamic_cast(assignedType.type);
     if(op == Token::Type::EQUAL) {
         // If we're initialising a pointer with another pointer
-        if (numericPtrAssignedType && numericPtrExistingType) {
+        if (pointerAssignedType && pointerExistingType) {
             // If we're trying to assign a pointer to a const value to a pointer
             if (assignedType.constValue && !existingType.constValue) {
-                errorHandler.error(name, "Invalid operand types '" + numericPtrExistingType->getTypeName() + "' and '" + numericPtrAssignedType->getTypeName());
+                errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getTypeName() + "' and '" + pointerAssignedType->getTypeName());
                 throw TypeCheckError();
             }

             // If pointer types aren't compatible
-            if (numericPtrExistingType->getTypeHash() != numericPtrAssignedType->getTypeHash()) {
-                errorHandler.error(name, "Invalid operand types '" + numericPtrExistingType->getTypeName() + "' and '" + numericPtrAssignedType->getTypeName());
+            if (pointerExistingType->getTypeName() != pointerAssignedType->getTypeName()) {
+                errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getTypeName() + "' and '" + pointerAssignedType->getTypeName());
                 throw TypeCheckError();
             }
         }
         // Otherwise, if we're trying to initialise a pointer with a non-pointer or vice-versa
-        else if (numericPtrAssignedType || numericPtrExistingType) {
+        else if (pointerAssignedType || pointerExistingType) {
             errorHandler.error(name, "Invalid operand types '" + existingType.type->getTypeName() + "' and '" + assignedType.type->getTypeName());
             throw TypeCheckError();
         }
@@ -630,14 +630,14 @@ const Type::QualifiedType &EnvironmentBase::assign(const Token &name, Token::Typ
     // Otherwise, if operation is += or -=
     else if (op == Token::Type::PLUS_EQUAL || op == Token::Type::MINUS_EQUAL) {
         // If the operand being added isn't numeric or the type being added to is neither numeric nor a pointer
-        if (!numericAssignedType || (!numericPtrExistingType && !numericExistingType))
+        if (!numericAssignedType || (!pointerExistingType && !numericExistingType))
         {
             errorHandler.error(name, "Invalid operand types '" + existingType.type->getTypeName() + "' and '" + assignedType.type->getTypeName() + "'");
             throw TypeCheckError();
         }

         // If we're adding a numeric type to a pointer, check it's an integer
-        if (numericPtrExistingType && numericAssignedType->isIntegral()) {
+        if (pointerExistingType && numericAssignedType->isIntegral()) {
             errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getTypeName() + "'");
             throw TypeCheckError();
         }
@@ -677,9 +677,9 @@ const Type::QualifiedType &EnvironmentBase::incDec(const Token &name, Token::Typ
 {
     // If existing type is a constant numeric value or if it's a constant pointer give errors
     auto numericExistingType = dynamic_cast(existingType.type);
-    auto numericPtrExistingType = dynamic_cast(existingType.type);
+    auto pointerExistingType = dynamic_cast(existingType.type);
     if((numericExistingType && existingType.constValue)
-       || (numericPtrExistingType && existingType.constPointer))
+       || (pointerExistingType && existingType.constPointer))
     {
         errorHandler.error(name, "Increment/decrement of read-only variable");
         throw TypeCheckError();

diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc
index 8b50645632..3366e987f9 100644
--- a/src/genn/genn/type.cc
+++ b/src/genn/genn/type.cc
@@ -72,6 +72,13 @@ IMPLEMENT_TYPE(Sqrt);

 //----------------------------------------------------------------------------
 // Free
functions //---------------------------------------------------------------------------- +const Pointer *createPointer(const Base *valueType) +{ + // **TODO** befriend constructor + // **TODO** don't just leak these! + return new Pointer(valueType); +} +//---------------------------------------------------------------------------- const NumericBase *parseNumeric(std::string_view typeString) { using namespace Transpiler; @@ -95,41 +102,12 @@ const NumericBase *parseNumeric(std::string_view typeString) return type; } //---------------------------------------------------------------------------- -const NumericPtrBase *parseNumericPtr(std::string_view typeString) -{ - using namespace Transpiler; - - // Scan type - SingleLineErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(typeString, errorHandler); - - // Parse type and cast to numeric pointer - const auto *type = dynamic_cast(Parser::parseType(tokens, true, errorHandler)); - - // If an error was encountered while scanning or parsing, throw exception - if (errorHandler.hasError()) { - throw std::runtime_error("Error parsing type '" + std::string{typeString} + "'"); - } - - // If tokens did not contain a valid numeric type, throw exception - if (!type) { - throw std::runtime_error("Unable to parse type '" + std::string{typeString} + "'"); - } - return type; -} -//---------------------------------------------------------------------------- const NumericBase *getNumericType(const std::set &typeSpecifiers) { const auto type = numericTypes.find(typeSpecifiers); return (type == numericTypes.cend()) ? nullptr : type->second; } //---------------------------------------------------------------------------- -const NumericPtrBase *getNumericPtrType(const std::set &typeSpecifiers) -{ - const auto type = numericTypes.find(typeSpecifiers); - return (type == numericTypes.cend()) ? nullptr : type->second->getPointerType(); -} -//---------------------------------------------------------------------------- const NumericBase *getPromotedType(const NumericBase *type) { // If a small integer type is used in an expression, it is implicitly converted to int which is always signed. From 504269c50294400f638143f8985914dd1d4c9b30 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 16 Jan 2023 11:45:58 +0000 Subject: [PATCH 043/725] fixed up type so everything compiles --- .../genn/genn/code_generator/groupMerged.h | 10 +- .../groupMergedTypeEnvironment.h | 6 +- .../genn/code_generator/modelSpecMerged.h | 6 +- include/genn/genn/transpiler/typeChecker.h | 67 ----- include/genn/genn/type.h | 12 +- .../customConnectivityUpdateGroupMerged.cc | 38 +-- .../code_generator/customUpdateGroupMerged.cc | 64 ++--- src/genn/genn/code_generator/groupMerged.cc | 18 +- .../genn/code_generator/initGroupMerged.cc | 30 +-- .../genn/code_generator/modelSpecMerged.cc | 2 +- .../code_generator/neuronUpdateGroupMerged.cc | 28 +-- .../synapseUpdateGroupMerged.cc | 10 +- src/genn/genn/transpiler/typeChecker.cc | 50 +--- src/genn/genn/type.cc | 33 ++- tests/unit/typeChecker.cc | 234 ++++++++++++------ 15 files changed, 306 insertions(+), 302 deletions(-) diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 3e8279c7b1..e85da70444 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -116,7 +116,7 @@ class GroupMerged // If field is a pointer and not marked as being a host field // (in which case the backend should leave its type alone!) 
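            // For reference, a minimal sketch of the pointer-field detection this
            // code now relies on (illustrative only; Uint32 stands in for any
            // numeric value type):
            //   const Type::Pointer *ptr = Type::createPointer(Type::Uint32::getInstance());
            //   dynamic_cast<const Type::Pointer*>(ptr);  // non-null, so treated as a pointer field
            //   ptr->getValueType();                      // recovers the wrapped numeric type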
const auto *type = std::get<0>(f);
-            if(dynamic_cast(type) && !(std::get<3>(f) & GroupMergedFieldType::HOST)) {
+            if(dynamic_cast(type) && !(std::get<3>(f) & GroupMergedFieldType::HOST)) {
                 // If we are generating a host structure, allow the backend to override the type
                 if(host) {
                     os << backend.getMergedGroupFieldHostTypeName(type);
@@ -267,13 +267,13 @@ class GroupMerged

     void addPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix)
     {
-        addField(type->getPointerType(), name, [prefix](const G &g, size_t) { return prefix + g.getName(); });
+        addField(createPointer(type), name, [prefix](const G &g, size_t) { return prefix + g.getName(); });
     }

     template>* = nullptr>
     void addPointerField(const std::string &name, const std::string &prefix)
     {
-        addField(T::getInstance()->getPointerType(), name, [prefix](const G &g, size_t) { return prefix + g.getName(); });
+        addField(createPointer(T::getInstance()), name, [prefix](const G &g, size_t) { return prefix + g.getName(); });
     }

@@ -290,7 +290,7 @@ class GroupMerged
     {
         // Loop through variables
         for(const auto &v : varReferences) {
-            addField(Type::parseNumeric(v.type)->getPointerType(), v.name,
+            addField(createPointer(Type::parseNumeric(v.type)), v.name,
                      [getVarRefFn, arrayPrefix, v](const G &g, size_t)
                      {
                          const auto varRef = getVarRefFn(g).at(v.name);
@@ -478,7 +478,7 @@ class GroupMerged
         // Loop through fields again to generate any EGP pushing functions that are required
         for(const auto &f : sortedFields) {
             // If this field is a dynamic pointer
-            if((std::get<3>(f) & GroupMergedFieldType::DYNAMIC) && dynamic_cast(std::get<0>(f))) {
+            if((std::get<3>(f) & GroupMergedFieldType::DYNAMIC) && dynamic_cast(std::get<0>(f))) {
                 definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << getIndex() << std::get<1>(f) << "ToDevice(unsigned int idx, ";
                 definitionsInternalFunc << backend.getMergedGroupFieldHostTypeName(std::get<0>(f)) << " value);" << std::endl;
             }

diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h
index 95850a2548..b45feedb27 100644
--- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h
+++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h
@@ -57,7 +57,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa
         // Add field to merged group if required
         addField(existingType->second);

         // Perform standard type-checking logic
         return EnvironmentBase::assign(name, op, existingType->second.first, assignedType, errorHandler, initializer);
     }

@@ -142,7 +142,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa
     void definePointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix, VarAccessMode access)
     {
         defineField(type, name, (access & VarAccessModeAttribute::READ_ONLY), false,
-                    type->getPointerType(), name, [prefix](const auto &g, size_t) { return prefix + g.getName(); });
+                    Type::createPointer(type), name, [prefix](const auto &g, size_t) { return prefix + g.getName(); });
     }

     template
@@ -200,7 +200,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa
         for(const auto &v : varReferences) {
             const auto *type = Type::parseNumeric(v.type);
             defineField(type, v.name, (v.access & VarAccessModeAttribute::READ_ONLY), false,
-                        type->getPointerType(), v.name,
+                        Type::createPointer(type), v.name,
                        [arrayPrefix,
getVarRefFn, v](const auto &g, size_t) { const auto varRef = getVarRefFn(g).at(v.name); diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 23a0ef45ec..875f8e4cca 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -52,8 +52,8 @@ class GENN_EXPORT ModelSpecMerged //! lexicographically compares all three struct members bool operator < (const EGPField &other) const { - return (std::make_tuple(mergedGroupIndex, type->getTypeHash(), fieldName, hostGroup) - < std::make_tuple(other.mergedGroupIndex, other.type->getTypeHash(), other.fieldName, other.hostGroup)); + return (std::make_tuple(mergedGroupIndex, type->getTypeName(), fieldName, hostGroup) + < std::make_tuple(other.mergedGroupIndex, other.type->getTypeName(), other.fieldName, other.hostGroup)); } }; @@ -263,7 +263,7 @@ class GENN_EXPORT ModelSpecMerged for(auto f : mergedGroupFields) { // If EGP is a pointer // **NOTE** this is common to all references! - if(dynamic_cast(f.type)) { + if(dynamic_cast(f.type)) { os << "void pushMerged" << T::name << f.mergedGroupIndex << f.fieldName << "ToDevice(unsigned int idx, " << backend.getMergedGroupFieldHostTypeName(f.type) << " value)"; { CodeStream::Scope b(os); diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index 70568914c1..da2f9322e0 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -57,73 +57,6 @@ class EnvironmentBase const Type::QualifiedType &existingType, ErrorHandlerBase &errorHandler) const; }; -//--------------------------------------------------------------------------- -// GeNN::Transpiler::TypeChecker::EnvironmentExternal -//--------------------------------------------------------------------------- -// template -class EnvironmentExternal : public EnvironmentBase -{ -public: - // **THINK** should type need to be same as enclosing group? perhaps this could help with child groups? 
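 // Hypothetical usage sketch for this environment (names illustrative, based
 // only on the define<T> helper declared below):
 //   EnvironmentExternal env;
 //   env.define<Type::Float>("V");           // mutable value variable
 //   env.define<Type::Float>("Isyn", true);  // const value variable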
- //EnvironmentExternal(EnvironmentBase *enclosing) - - //typedef std::function GetFieldValueFunc; - //typedef std::tuple Field; - - //--------------------------------------------------------------------------- - // Public API - //--------------------------------------------------------------------------- - template - void define(std::string_view name, bool isConstValue = false, bool isConstPointer = false) - { - if(!m_Types.try_emplace(name, T::getInstance(), isConstValue, isConstPointer).second) { - throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); - } - } - - /*void addPointerField(const Type::Base *type, const std::string &name, bool isConstValue = false) - { - assert(dynamic_cast(type)); - - // Define variable type - define(name, type, isConstValue); - - // Add field with pointer type - // **TODO** could also be a const pointer - addField(type->getPointerType(), name, [name](const G &g, size_t) { return devicePrefix + name + g.getName(); }); - - // **TODO** link from type back to field(s) - vector indices would work as we always push back fields - } - - void addVars(const Models::Base::VarVec &vars, const std::string &arrayPrefix) - { - // Loop through variables - for(const auto &v : vars) { - addPointerField(v.type, v.name, (v.access & VarAccessMode::READ_ONLY)); - } - } - */ - - //--------------------------------------------------------------------------- - // EnvironmentBase virtuals - //--------------------------------------------------------------------------- - virtual void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandlerBase &errorHandler) final; - virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, - ErrorHandlerBase &errorHandler, bool initializer = false) final; - virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final; - virtual const Type::QualifiedType &getType(const Token &name, ErrorHandlerBase &errorHandler) final; - -private: - //--------------------------------------------------------------------------- - // Members - //--------------------------------------------------------------------------- - std::unordered_map m_Types; - - // **THINK** should fields live in some sort of parent environment external? children are instantiated to type check e.g. child synapse groups - // but we eventually want a flat list of fields and we want that to be located somewhere permanent - //std::vector m_Fields; -}; - //--------------------------------------------------------------------------- // Free functions //--------------------------------------------------------------------------- diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 6a133189e6..eca9dc31c2 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -93,7 +93,7 @@ class Base // GeNN::Type::Pointer //---------------------------------------------------------------------------- //! Type representing a pointer -class Pointer : Base +class Pointer : public Base { public: Pointer(const Base *valueType) @@ -284,11 +284,21 @@ DECLARE_NUMERIC_TYPE(Double, double, 60); DECLARE_FOREIGN_FUNCTION_TYPE(Exp, Double, Double); DECLARE_FOREIGN_FUNCTION_TYPE(Sqrt, Double, Double); +//! Create a pointer type to the given value type const Pointer *createPointer(const Base *valueType); +template +const Pointer *createPointer() +{ + return createPointer(T::getInstance()); +} + //! 
Parse a numeric type const NumericBase *parseNumeric(std::string_view typeString); +//! Parse a numeric pointer type +const Pointer *parseNumericPtr(std::string_view typeString); + //! Look up numeric type based on set of type specifiers const NumericBase *getNumericType(const std::set &typeSpecifiers); diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index a0e763a868..71b0e79fa5 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -111,34 +111,34 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t assert(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE); - addField(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())->getPointerType(), "ind", + addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())), "ind", [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); }); - addField("rowLength", - [&backend](const auto &cg, size_t) - { - return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); - }); - + addField(createPointer(), "rowLength", + [&backend](const auto &cg, size_t) + { + return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); + }); + // If some presynaptic variables are delayed, add delay pointer if (getArchetype().getPreDelayNeuronGroup() != nullptr) { - addField("preSpkQuePtr", - [&backend](const auto &cg, size_t) - { - return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getPreDelayNeuronGroup()->getName(); - }); + addField(createPointer(), "preSpkQuePtr", + [&backend](const auto &cg, size_t) + { + return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getPreDelayNeuronGroup()->getName(); + }); } // If some postsynaptic variables are delayed, add delay pointer if (getArchetype().getPostDelayNeuronGroup() != nullptr) { - addField("postSpkQuePtr", - [&backend](const auto &cg, size_t) - { - return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getPostDelayNeuronGroup()->getName(); - }); + addField(createPointer(), "postSpkQuePtr", + [&backend](const auto &cg, size_t) + { + return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getPostDelayNeuronGroup()->getName(); + }); } // If this backend requires per-population RNGs and this group requires one @@ -563,13 +563,13 @@ void CustomConnectivityHostUpdateGroupMerged::addVars(const BackendBase &backend for(const auto &v : vars) { // If var is located on the host if (std::invoke(getVarLocationFn, getArchetype(), v.name) & VarLocation::HOST) { - addField(parseNumeric(v.type)->getPointerType(), v.name, + addField(createPointer(parseNumeric(v.type)), v.name, [v](const auto &g, size_t) { return v.name + g.getName(); }, GroupMergedFieldType::HOST); if(!backend.getDeviceVarPrefix().empty()) { // **TODO** I think could use addPointerField - addField(parseNumeric(v.type)->getPointerType(), backend.getDeviceVarPrefix() + v.name, + addField(createPointer(parseNumeric(v.type)), backend.getDeviceVarPrefix() + v.name, [v, &backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + v.name + g.getName(); diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 0aefbb2e2b..b43903cbc3 100644 --- 
a/src/genn/genn/code_generator/customUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc
@@ -125,24 +125,24 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string

     // If some variables are delayed, add delay pointer
     if(getArchetype().getDelayNeuronGroup() != nullptr) {
-        addField("spkQuePtr",
-                 [&backend](const auto &cg, size_t)
-                 {
-                     return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getDelayNeuronGroup()->getName();
-                 });
+        addField(createPointer(), "spkQuePtr",
+                 [&backend](const auto &cg, size_t)
+                 {
+                     return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getDelayNeuronGroup()->getName();
+                 });
     }

     // Add heterogeneous custom update model parameters
     const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel();
     typeEnvironment.defineHeterogeneousParams(
         cm->getParamNames(), "",
-        [](const CustomUpdateInternal &cg) { return cg.getParams(); },
+        [](const auto &cg) { return cg.getParams(); },
         &CustomUpdateGroupMerged::isParamHeterogeneous);

     // Add heterogeneous custom update model derived parameters
     typeEnvironment.defineHeterogeneousDerivedParams(
         cm->getDerivedParams(), "",
-        [](const CustomUpdateInternal &cg) { return cg.getDerivedParams(); },
+        [](const auto &cg) { return cg.getDerivedParams(); },
         &CustomUpdateGroupMerged::isDerivedParamHeterogeneous);

     // Add variables to struct
@@ -150,7 +150,7 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string

     // Add variable references to struct
     typeEnvironment.defineVarReferences(cm->getVarRefs(), backend.getDeviceVarPrefix(),
-                                        [](const CustomUpdateInternal &cg) { return cg.getVarReferences(); });
+                                        [](const auto &cg) { return cg.getVarReferences(); });

     // Add EGPs to struct
     typeEnvironment.defineEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix());
@@ -165,12 +165,12 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string
 //----------------------------------------------------------------------------
 bool CustomUpdateGroupMerged::isParamHeterogeneous(const std::string &paramName) const
 {
-    return isParamValueHeterogeneous(paramName, [](const CustomUpdateInternal &cg) { return cg.getParams(); });
+    return isParamValueHeterogeneous(paramName, [](const auto &cg) { return cg.getParams(); });
 }
 //----------------------------------------------------------------------------
 bool CustomUpdateGroupMerged::isDerivedParamHeterogeneous(const std::string &paramName) const
 {
-    return isParamValueHeterogeneous(paramName, [](const CustomUpdateInternal &cg) { return cg.getDerivedParams(); });
+    return isParamValueHeterogeneous(paramName, [](const auto &cg) { return cg.getDerivedParams(); });
 }
 //----------------------------------------------------------------------------
 boost::uuids::detail::sha1::digest_type CustomUpdateGroupMerged::getHashDigest() const
@@ -181,12 +181,12 @@ boost::uuids::detail::sha1::digest_type CustomUpdateGroupMerged::getHashDigest()
     Utils::updateHash(getArchetype().getHashDigest(), hash);

     // Update hash with each group's custom update size
-    updateHash([](const CustomUpdateInternal &cg) { return cg.getSize(); }, hash);
+    updateHash([](const auto &cg) { return cg.getSize(); }, hash);

     // Update hash with each group's parameters, derived parameters and variable references
-    updateHash([](const CustomUpdateInternal &cg) { return cg.getParams(); }, hash);
-    updateHash([](const CustomUpdateInternal &cg) { return cg.getDerivedParams(); }, hash);
-    updateHash([](const CustomUpdateInternal &cg) { return
cg.getVarReferences(); }, hash);
+    updateHash([](const auto &cg) { return cg.getParams(); }, hash);
+    updateHash([](const auto &cg) { return cg.getDerivedParams(); }, hash);
+    updateHash([](const auto &cg) { return cg.getVarReferences(); }, hash);

     return hash.get_digest();
 }
@@ -194,7 +194,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateGroupMerged::getHashDigest()
 void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const
 {
     genCustomUpdate(os, popSubs, *this, modelMerged, "id",
-                    [this](const Models::VarReference &varRef, const std::string &index)
+                    [this](const auto &varRef, const std::string &index)
                     {
                         return getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr,
                                               getVarAccessDuplication(varRef.getVar().access),
@@ -260,20 +260,20 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWUGroupMergedBase::getHashDi
     Utils::updateHash(getArchetype().getHashDigest(), hash);

     // Update hash with sizes of pre and postsynaptic neuron groups
-    updateHash([](const CustomUpdateWUInternal &cg)
+    updateHash([](const auto &cg)
               {
                   return static_cast(cg.getSynapseGroup())->getSrcNeuronGroup()->getNumNeurons();
               }, hash);

-    updateHash([](const CustomUpdateWUInternal &cg)
+    updateHash([](const auto &cg)
               {
                   return static_cast(cg.getSynapseGroup())->getTrgNeuronGroup()->getNumNeurons();
               }, hash);

     // Update hash with each group's parameters, derived parameters and variable references
-    updateHash([](const CustomUpdateWUInternal &cg) { return cg.getParams(); }, hash);
-    updateHash([](const CustomUpdateWUInternal &cg) { return cg.getDerivedParams(); }, hash);
-    updateHash([](const CustomUpdateWUInternal &cg) { return cg.getVarReferences(); }, hash);
+    updateHash([](const auto &cg) { return cg.getParams(); }, hash);
+    updateHash([](const auto &cg) { return cg.getDerivedParams(); }, hash);
+    updateHash([](const auto &cg) { return cg.getVarReferences(); }, hash);

     return hash.get_digest();
 }
@@ -339,17 +339,17 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const
     // If synapse group has sparse connectivity
     if(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
-        addField(Type::parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())->getPointerType(), "ind",
+        addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())), "ind",
                  [&backend](const auto &cg, size_t)
                  {
                      return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName();
                  });

-        addField("rowLength",
-                 [&backend](const auto &cg, size_t)
-                 {
-                     return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName();
-                 });
+        addField(createPointer(), "rowLength",
+                 [&backend](const auto &cg, size_t)
+                 {
+                     return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName();
+                 });
     }
 }

@@ -379,7 +379,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const
             // If variable has a transpose
             if(getArchetype().getVarReferences().at(v.name).getTransposeSynapseGroup() != nullptr) {
                 // Add field with transpose suffix, pointing to transpose var
-                addField(Type::parseNumeric(v.type)->getPointerType(), v.name + "Transpose",
+                addField(createPointer(parseNumeric(v.type)), v.name + "Transpose",
                          [&backend, v](const auto &g, size_t)
                          {
                              const auto varRef = g.getVarReferences().at(v.name);
@@ -438,11 +438,11 @@ CustomUpdateHostReductionGroupMerged::CustomUpdateHostReductionGroupMerged(size_
     // If some variables
are delayed, add delay pointer
     // **NOTE** this is HOST delay pointer
     if(getArchetype().getDelayNeuronGroup() != nullptr) {
-        addField("spkQuePtr",
-                 [&](const auto &cg, size_t)
-                 {
-                     return "spkQuePtr" + cg.getDelayNeuronGroup()->getName();
-                 });
+        addField(createPointer(), "spkQuePtr",
+                 [](const auto &cg, size_t)
+                 {
+                     return "spkQuePtr" + cg.getDelayNeuronGroup()->getName();
+                 });
     }
 }

diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc
index ad43e1bd13..58d446c597 100644
--- a/src/genn/genn/code_generator/groupMerged.cc
+++ b/src/genn/genn/code_generator/groupMerged.cc
@@ -358,7 +358,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr
             for(const auto &var : cs->getCurrentSourceModel()->getVars()) {
                 // Add pointers to state variable
                 if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) {
-                    addField(parseNumeric(var.type)->getPointerType(), var.name + "CS" + std::to_string(i),
+                    addField(Type::createPointer(parseNumeric(var.type)), var.name + "CS" + std::to_string(i),
                             [&backend, i, var, this](const auto&, size_t groupIndex)
                             {
                                 return backend.getDeviceVarPrefix() + var.name + m_SortedCurrentSources.at(groupIndex).at(i)->getName();
@@ -489,7 +489,7 @@ bool NeuronGroupMergedBase::isPSMVarInitParamReferenced(size_t childIndex, const
 void NeuronGroupMergedBase::addMergedInSynPointerField(const Type::NumericBase *type, const std::string &name,
                                                        size_t archetypeIndex, const std::string &prefix)
 {
-    addField(type->getPointerType(), name + std::to_string(archetypeIndex),
+    addField(Type::createPointer(type), name + std::to_string(archetypeIndex),
             [prefix, archetypeIndex, this](const auto&, size_t groupIndex)
             {
                 return prefix + m_SortedMergedInSyns.at(groupIndex).at(archetypeIndex)->getFusedPSVarSuffix();
@@ -499,7 +499,7 @@ void NeuronGroupMergedBase::addMergedPreOutputOutSynPointerField(const Type::NumericBase *type, const std::string &name,
                                                                  size_t archetypeIndex, const std::string &prefix)
 {
-    addField(type->getPointerType(), name + std::to_string(archetypeIndex),
+    addField(Type::createPointer(type), name + std::to_string(archetypeIndex),
             [prefix, archetypeIndex, this](const auto&, size_t groupIndex)
             {
                 return prefix + m_SortedMergedPreOutputOutSyns.at(groupIndex).at(archetypeIndex)->getFusedPreOutputSuffix();
@@ -842,14 +842,14 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string &
         // Add presynaptic variables to struct
         for(const auto &v : wum->getPreVars()) {
             const std::string prefix = backend.getDeviceVarPrefix() + v.name;
-            addField(parseNumeric(v.type)->getPointerType(), v.name,
+            addField(createPointer(parseNumeric(v.type)), v.name,
                      [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPreVarSuffix(); });
         }

         // Add postsynaptic variables to struct
         for(const auto &v : wum->getPostVars()) {
             const std::string prefix = backend.getDeviceVarPrefix() + v.name;
-            addField(parseNumeric(v.type)->getPointerType(), v.name,
+            addField(createPointer(parseNumeric(v.type)), v.name,
                      [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPostVarSuffix(); });
         }
@@ -1109,22 +1109,22 @@ boost::uuids::detail::sha1::digest_type SynapseGroupMergedBase::getHashDigest(Ro
 //----------------------------------------------------------------------------
 void SynapseGroupMergedBase::addPSPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix)
 {
-
addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPSVarSuffix(); }); + addField(createPointer(type), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPSVarSuffix(); }); } //---------------------------------------------------------------------------- void SynapseGroupMergedBase::addPreOutputPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) { - addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPreOutputSuffix(); }); + addField(createPointer(type), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPreOutputSuffix(); }); } //---------------------------------------------------------------------------- void SynapseGroupMergedBase::addSrcPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) { - addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getSrcNeuronGroup()->getName(); }); + addField(createPointer(type), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getSrcNeuronGroup()->getName(); }); } //---------------------------------------------------------------------------- void SynapseGroupMergedBase::addTrgPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) { - addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getTrgNeuronGroup()->getName(); }); + addField(createPointer(type), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getTrgNeuronGroup()->getName(); }); } //---------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index ab76ef350d..2612dfff7c 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -443,7 +443,7 @@ void NeuronInitGroupMerged::generateWUVar(const BackendBase &backend, for(const auto &var : vars) { // Add pointers to state variable if(!varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(parseNumeric(var.type)->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), + addField(createPointer(parseNumeric(var.type)), var.name + fieldPrefixStem + std::to_string(i), [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto&, size_t groupIndex) { const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); @@ -1016,13 +1016,13 @@ CustomWUUpdateSparseInitGroupMerged::CustomWUUpdateSparseInitGroupMerged(size_t addField("numTrgNeurons", [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); - addField("rowLength", - [&backend](const auto &cg, size_t) - { - const SynapseGroupInternal *sg = cg.getSynapseGroup(); - return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); - }); - addField(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())->getPointerType(), "ind", + addField(createPointer(), "rowLength", + [&backend](const auto &cg, size_t) + { + const SynapseGroupInternal *sg = cg.getSynapseGroup(); + return backend.getDeviceVarPrefix() + 
"rowLength" + sg->getName(); + }); + addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())), "ind", [&backend](const auto &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); @@ -1178,13 +1178,13 @@ CustomConnectivityUpdateSparseInitGroupMerged::CustomConnectivityUpdateSparseIni addField("numTrgNeurons", [](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); - addField("rowLength", - [&backend](const CustomConnectivityUpdateInternal &cg, size_t) - { - const SynapseGroupInternal *sg = cg.getSynapseGroup(); - return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); - }); - addField(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())->getPointerType(), "ind", + addField(createPointer(), "rowLength", + [&backend](const CustomConnectivityUpdateInternal &cg, size_t) + { + const SynapseGroupInternal *sg = cg.getSynapseGroup(); + return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); + }); + addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())), "ind", [&backend](const auto &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index 0e4d7ab788..6aa77abb71 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -572,7 +572,7 @@ bool ModelSpecMerged::anyPointerEGPs() const if(std::any_of(e.second.cbegin(), e.second.cend(), [](const MergedEGPDestinations::value_type &g) { - return dynamic_cast(g.second.type); + return dynamic_cast(g.second.type); })) { return true; diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 2825d4b6cb..062ad5fe9c 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -80,7 +80,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string for(const auto &var : sgPreVars) { // If variable is referenced in event threshold code if(s.eventThresholdCode.find("$(" + var.name + ")") != std::string::npos) { - addField(parseNumeric(var.type)->getPointerType(), var.name + "EventThresh" + std::to_string(i), + addField(createPointer(parseNumeric(var.type)), var.name + "EventThresh" + std::to_string(i), [&backend, eventThresholdSGs, var, i](const auto&, size_t groupIndex) { return backend.getDeviceVarPrefix() + var.name + eventThresholdSGs.at(groupIndex).at(i)->getName(); @@ -93,22 +93,22 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string if(getArchetype().isSpikeRecordingEnabled()) { // Add field for spike recording - addField("recordSpk", - [&backend](const auto &ng, size_t) - { - return backend.getDeviceVarPrefix() + "recordSpk" + ng.getName(); - }, - GroupMergedFieldType::DYNAMIC); + addField(createPointer(), "recordSpk", + [&backend](const auto &ng, size_t) + { + return backend.getDeviceVarPrefix() + "recordSpk" + ng.getName(); + }, + GroupMergedFieldType::DYNAMIC); } if(getArchetype().isSpikeEventRecordingEnabled()) { // Add field for spike event recording - addField("recordSpkEvent", - [&backend](const auto &ng, size_t) - { - return backend.getDeviceVarPrefix() + "recordSpkEvent" + ng.getName(); - }, - GroupMergedFieldType::DYNAMIC); + addField(createPointer(), 
"recordSpkEvent", + [&backend](const auto &ng, size_t) + { + return backend.getDeviceVarPrefix() + "recordSpkEvent" + ng.getName(); + }, + GroupMergedFieldType::DYNAMIC); } } @@ -716,7 +716,7 @@ void NeuronUpdateGroupMerged::generateWUVar(const BackendBase &backend, const s for(size_t v = 0; v < vars.size(); v++) { // Add pointers to state variable const auto var = vars[v]; - addField(Type::parseNumeric(var.type)->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), + addField(Type::createPointer(Type::parseNumeric(var.type)), var.name + fieldPrefixStem + std::to_string(i), [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto &, size_t groupIndex) { const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 9689dbf193..04e385e4dc 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -325,9 +325,9 @@ SynapseDendriticDelayUpdateGroupMerged::SynapseDendriticDelayUpdateGroupMerged(s const std::vector> &groups) : GroupMerged(index, precision, groups) { - addField("denDelayPtr", - [&backend](const SynapseGroupInternal &sg, size_t) - { - return backend.getScalarAddressPrefix() + "denDelayPtr" + sg.getFusedPSVarSuffix(); - }); + addField(Type::createPointer(), "denDelayPtr", + [&backend](const SynapseGroupInternal &sg, size_t) + { + return backend.getScalarAddressPrefix() + "denDelayPtr" + sg.getFusedPSVarSuffix(); + }); } diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index b2806f0e2e..d18b94798b 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -416,7 +416,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // Otherwise, if operator is address of, return pointer type else if (unary.getOperator().type == Token::Type::AMPERSAND) { - m_QualifiedType = Type::QualifiedType{createPointer(rightType), + m_QualifiedType = Type::QualifiedType{Type::createPointer(rightType.type), rightType.constValue, false}; } } @@ -690,54 +690,6 @@ const Type::QualifiedType &EnvironmentBase::incDec(const Token &name, Token::Typ } } -//--------------------------------------------------------------------------- -// GeNN::Transpiler::TypeChecker::EnvironmentExternal -//--------------------------------------------------------------------------- -void EnvironmentExternal::define(const Token &name, const Type::QualifiedType &, ErrorHandlerBase &errorHandler) -{ - errorHandler.error(name, "Cannot declare variable in external environment"); - throw TypeCheckError(); -} -//--------------------------------------------------------------------------- -const Type::QualifiedType &EnvironmentExternal::assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, - ErrorHandlerBase &errorHandler, bool initializer) -{ - // If type isn't found - auto existingType = m_Types.find(name.lexeme); - if(existingType == m_Types.end()) { - errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); - } - - // Perform standard type-checking logic - return EnvironmentBase::assign(name, op, existingType->second, assignedType, errorHandler, initializer); -} -//--------------------------------------------------------------------------- -const Type::QualifiedType &EnvironmentExternal::incDec(const Token &name, Token::Type op, 
ErrorHandlerBase &errorHandler) -{ - auto existingType = m_Types.find(name.lexeme); - if(existingType == m_Types.end()) { - errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); - } - - // Perform standard type-checking logic - return EnvironmentBase::incDec(name, op, existingType->second, errorHandler); - -} -//--------------------------------------------------------------------------- -const Type::QualifiedType &EnvironmentExternal::getType(const Token &name, ErrorHandlerBase &errorHandler) -{ - auto type = m_Types.find(std::string{name.lexeme}); - if(type == m_Types.end()) { - errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); - } - else { - return type->second; - } -} - //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker //--------------------------------------------------------------------------- diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index 3366e987f9..15a5092c2b 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -102,6 +102,29 @@ const NumericBase *parseNumeric(std::string_view typeString) return type; } //---------------------------------------------------------------------------- +const Pointer *parseNumericPtr(std::string_view typeString) +{ + using namespace Transpiler; + + // Scan type + SingleLineErrorHandler errorHandler; + const auto tokens = Scanner::scanSource(typeString, errorHandler); + + // Parse type and cast to numeric pointer + const auto *type = dynamic_cast(Parser::parseType(tokens, true, errorHandler)); + + // If an error was encountered while scanning or parsing, throw exception + if (errorHandler.hasError()) { + throw std::runtime_error("Error parsing type '" + std::string{typeString} + "'"); + } + + // If tokens did not contain a valid numeric type, throw exception + if (!type) { + throw std::runtime_error("Unable to parse type '" + std::string{typeString} + "'"); + } + return type; +} +//---------------------------------------------------------------------------- const NumericBase *getNumericType(const std::set &typeSpecifiers) { const auto type = numericTypes.find(typeSpecifiers); @@ -124,13 +147,13 @@ const NumericBase *getPromotedType(const NumericBase *type) const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b) { // If either type is double, common type is double - const auto aTypeHash = a->getTypeHash(); - const auto bTypeHash = b->getTypeHash(); - if(aTypeHash == Double::getInstance()->getTypeHash() || bTypeHash == Double::getInstance()->getTypeHash()) { + const auto &aTypeName = a->getTypeName(); + const auto &bTypeName = b->getTypeName(); + if(aTypeName == Double::getInstance()->getTypeName() || bTypeName == Double::getInstance()->getTypeName()) { return Double::getInstance(); } // Otherwise, if either type is float, common type is float - if(aTypeHash == Float::getInstance()->getTypeHash() || bTypeHash == Float::getInstance()->getTypeHash()) { + if(aTypeName == Float::getInstance()->getTypeName() || bTypeName == Float::getInstance()->getTypeName()) { return Float::getInstance(); } // Otherwise, must be an integer type @@ -140,7 +163,7 @@ const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b) const auto *bPromoted = getPromotedType(b); // If both promoted operands have the same type, then no further conversion is needed. 
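    // Worked examples of these rules (illustrative, assuming the standard C
    // arithmetic conversions this function follows):
    //   getCommonType(Int8, Int16)   -> Int32  (both promote to int, then match)
    //   getCommonType(Float, Int32)  -> Float  (the float check above short-circuits)
    //   getCommonType(Double, Float) -> Double (the double check short-circuits first)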
- if(aPromoted->getTypeHash() == bPromoted->getTypeHash()) { + if(aPromoted->getTypeName() == bPromoted->getTypeName()) { return aPromoted; } // Otherwise, if both promoted operands have signed integer numericTypes or both have unsigned integer numericTypes, diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index 3cbcbf5f77..7d4473c455 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -51,7 +51,93 @@ class TestErrorHandler : public ErrorHandlerBase bool m_Error; }; -void typeCheckStatements(std::string_view code, TypeChecker::EnvironmentExternal &typeEnvironment) +class TestEnvironment : public TypeChecker::EnvironmentBase +{ +public: + //--------------------------------------------------------------------------- + // Public API + //--------------------------------------------------------------------------- + void define(const Type::Base *type, std::string_view name, bool isConstValue = false, bool isConstPointer = false) + { + if(!m_Types.try_emplace(name, type, isConstValue, isConstPointer).second) { + throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); + } + } + + template + void define(std::string_view name, bool isConstValue = false, bool isConstPointer = false) + { + define(T::getInstance(), name, isConstValue, isConstPointer); + } + + template + void definePointer(std::string_view name, bool isConstValue = false, bool isConstPointer = false) + { + define(Type::createPointer(T::getInstance()), name, isConstValue, isConstPointer); + } + + + //--------------------------------------------------------------------------- + // EnvironmentBase virtuals + //--------------------------------------------------------------------------- + virtual void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandlerBase &errorHandler) final + { + errorHandler.error(name, "Cannot declare variable in external environment"); + throw TypeChecker::TypeCheckError(); + } + + virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, + ErrorHandlerBase &errorHandler, bool initializer = false) final + { + // If type isn't found + auto existingType = m_Types.find(name.lexeme); + if(existingType == m_Types.end()) { + errorHandler.error(name, "Undefined variable"); + throw TypeChecker::TypeCheckError(); + } + + // Perform standard type-checking logic + return EnvironmentBase::assign(name, op, existingType->second, assignedType, errorHandler, initializer); + } + + virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final + { + auto existingType = m_Types.find(name.lexeme); + if(existingType == m_Types.end()) { + errorHandler.error(name, "Undefined variable"); + throw TypeChecker::TypeCheckError(); + } + + // Perform standard type-checking logic + return EnvironmentBase::incDec(name, op, existingType->second, errorHandler); + } + + virtual const Type::QualifiedType &getType(const Token &name, ErrorHandlerBase &errorHandler) final + { + auto type = m_Types.find(std::string{name.lexeme}); + if(type == m_Types.end()) { + errorHandler.error(name, "Undefined variable"); + throw TypeChecker::TypeCheckError(); + } + else { + return type->second; + } + } + +private: + //--------------------------------------------------------------------------- + // Members + //--------------------------------------------------------------------------- + std::unordered_map m_Types; +}; + +template +std::string getPointerTypeName() +{ + return 
createPointer(T::getInstance())->getTypeName();
+}
+
+void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment)
 {
     // Scan
     TestErrorHandler errorHandler;
     const auto tokens = Scanner::scanSource(code, errorHandler);
@@ -67,7 +153,7 @@ void typeCheckStatements(std::string_view code, TypeChecker::EnvironmentExternal
     ASSERT_FALSE(errorHandler.hasError());
 }

-Type::QualifiedType typeCheckExpression(std::string_view code, TypeChecker::EnvironmentExternal &typeEnvironment)
+Type::QualifiedType typeCheckExpression(std::string_view code, TestEnvironment &typeEnvironment)
 {
     // Scan
     TestErrorHandler errorHandler;
     const auto tokens = Scanner::scanSource(code, errorHandler);
@@ -92,26 +178,26 @@ TEST(TypeChecker, ArraySubscript)
 {
     // Integer array indexing
     {
-        TypeChecker::EnvironmentExternal typeEnvironment;
-        typeEnvironment.define("intArray");
+        TestEnvironment typeEnvironment;
+        typeEnvironment.definePointer("intArray");
         const auto type = typeCheckExpression("intArray[4]", typeEnvironment);
-        EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash());
+        EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName());
         EXPECT_FALSE(type.constValue);
         EXPECT_FALSE(type.constPointer);
     }

     // Float array indexing
     EXPECT_THROW({
-        TypeChecker::EnvironmentExternal typeEnvironment;
-        typeEnvironment.define("intArray");
+        TestEnvironment typeEnvironment;
+        typeEnvironment.definePointer("intArray");
         typeCheckExpression("intArray[4.0f]", typeEnvironment);},
        TypeChecker::TypeCheckError);

     // Pointer indexing
     EXPECT_THROW({
-        TypeChecker::EnvironmentExternal typeEnvironment;
-        typeEnvironment.define("intArray");
-        typeEnvironment.define("indexArray");
+        TestEnvironment typeEnvironment;
+        typeEnvironment.definePointer("intArray");
+        typeEnvironment.definePointer("indexArray");
         typeCheckExpression("intArray[indexArray]", typeEnvironment);},
        TypeChecker::TypeCheckError);
 }
@@ -120,7 +206,7 @@ TEST(TypeChecker, Assignment)
 {
     // Numeric assignment
     {
-        TypeChecker::EnvironmentExternal typeEnvironment;
+        TestEnvironment typeEnvironment;
         typeEnvironment.define("intVal");
         typeEnvironment.define("floatVal");
         typeEnvironment.define("intValConst", true);
@@ -137,9 +223,9 @@ TEST(TypeChecker, Assignment)

     // Pointer assignment
     {
-        TypeChecker::EnvironmentExternal typeEnvironment;
-        typeEnvironment.define("intArray");
-        typeEnvironment.define("intArrayConst", true);
+        TestEnvironment typeEnvironment;
+        typeEnvironment.definePointer("intArray");
+        typeEnvironment.definePointer("intArrayConst", true);
         typeCheckStatements(
             "int *x = intArray;\n"
             "const int *y = intArray;\n"
@@ -149,15 +235,15 @@ TEST(TypeChecker, Assignment)

     // Pointer assignment, attempt to remove const
     EXPECT_THROW({
-        TypeChecker::EnvironmentExternal typeEnvironment;
-        typeEnvironment.define("intArray", true);
+        TestEnvironment typeEnvironment;
+        typeEnvironment.definePointer("intArray", true);
         typeCheckStatements("int *x = intArray;", typeEnvironment);},
        TypeChecker::TypeCheckError);

     // Pointer assignment without explicit cast
     EXPECT_THROW({
-        TypeChecker::EnvironmentExternal typeEnvironment;
-        typeEnvironment.define("intArray");
+        TestEnvironment typeEnvironment;
+        typeEnvironment.definePointer("intArray");
         typeCheckStatements("float *x = intArray;", typeEnvironment);},
        TypeChecker::TypeCheckError);

@@ -176,82 +262,82 @@ TEST(TypeChecker, Cast)
 {
     // Numeric cast
     {
-        TypeChecker::EnvironmentExternal typeEnvironment;
+        TestEnvironment typeEnvironment;
         typeEnvironment.define("intVal");
         const auto type = typeCheckExpression("(float)intVal", typeEnvironment);
-        EXPECT_EQ(type.type->getTypeHash(),
Type::Float::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), Type::Float::getInstance()->getTypeName()); EXPECT_FALSE(type.constValue); EXPECT_FALSE(type.constPointer); } // Numeric cast to const { - TypeChecker::EnvironmentExternal typeEnvironment; + TestEnvironment typeEnvironment; typeEnvironment.define("intVal"); const auto type = typeCheckExpression("(const int)intVal", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } // Pointer cast to value const { - TypeChecker::EnvironmentExternal typeEnvironment; - typeEnvironment.define("intArray"); + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray"); const auto type = typeCheckExpression("(const int*)intArray", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), getPointerTypeName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } // Pointer cast to pointer const { - TypeChecker::EnvironmentExternal typeEnvironment; - typeEnvironment.define("intArray"); + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray"); const auto type = typeCheckExpression("(int * const)intArray", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), getPointerTypeName()); EXPECT_FALSE(type.constValue); EXPECT_TRUE(type.constPointer); } // Can't remove value const from numeric EXPECT_THROW({ - TypeChecker::EnvironmentExternal typeEnvironment; + TestEnvironment typeEnvironment; typeEnvironment.define("intVal", true); typeCheckExpression("(int)intVal", typeEnvironment);}, TypeChecker::TypeCheckError); // Can't remove value const from pointer EXPECT_THROW({ - TypeChecker::EnvironmentExternal typeEnvironment; - typeEnvironment.define("intArray", true); + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray", true); typeCheckExpression("(int*)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Can't remove pointer const from pointer EXPECT_THROW({ - TypeChecker::EnvironmentExternal typeEnvironment; - typeEnvironment.define("intArray", false, true); + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray", false, true); typeCheckExpression("(int*)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer cast can't reinterpret EXPECT_THROW({ - TypeChecker::EnvironmentExternal typeEnvironment; - typeEnvironment.define("intArray"); + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray"); typeCheckExpression("(float*)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer can't be cast to numeric EXPECT_THROW({ - TypeChecker::EnvironmentExternal typeEnvironment; - typeEnvironment.define("intArray"); + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray"); typeCheckExpression("(int)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Numeric can't be cast to pointer EXPECT_THROW({ - TypeChecker::EnvironmentExternal typeEnvironment; + TestEnvironment typeEnvironment; typeEnvironment.define("intVal"); typeCheckExpression("(int*)intVal", typeEnvironment);}, TypeChecker::TypeCheckError); @@ -265,45 +351,45 @@ TEST(TypeChecker, IncDec) { // Can increment numeric { - 
TypeChecker::EnvironmentExternal typeEnvironment; + TestEnvironment typeEnvironment; typeEnvironment.define("intVal"); const auto type = typeCheckExpression("intVal++", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName()); EXPECT_FALSE(type.constValue); EXPECT_FALSE(type.constPointer); } // Can increment pointer { - TypeChecker::EnvironmentExternal typeEnvironment; - typeEnvironment.define("intArray"); + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray"); const auto type = typeCheckExpression("intArray++", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), getPointerTypeName()); EXPECT_FALSE(type.constValue); EXPECT_FALSE(type.constPointer); } // Can increment pointer to const { - TypeChecker::EnvironmentExternal typeEnvironment; - typeEnvironment.define("intArray", true); + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray", true); const auto type = typeCheckExpression("intArray++", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), getPointerTypeName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } // Can't increment const number EXPECT_THROW({ - TypeChecker::EnvironmentExternal typeEnvironment; + TestEnvironment typeEnvironment; typeEnvironment.define("intVal", true); typeCheckExpression("intVal++", typeEnvironment);}, TypeChecker::TypeCheckError); // Can't increment const pointer EXPECT_THROW({ - TypeChecker::EnvironmentExternal typeEnvironment; - typeEnvironment.define("intArray", false, true); + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray", false, true); typeCheckExpression("intArray++", typeEnvironment);}, TypeChecker::TypeCheckError); } @@ -312,36 +398,36 @@ TEST(TypeChecker, Literal) { // Float { - TypeChecker::EnvironmentExternal typeEnvironment; + TestEnvironment typeEnvironment; const auto type = typeCheckExpression("1.0f", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Float::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), Type::Float::getInstance()->getTypeName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } // Double { - TypeChecker::EnvironmentExternal typeEnvironment; + TestEnvironment typeEnvironment; const auto type = typeCheckExpression("1.0", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Double::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), Type::Double::getInstance()->getTypeName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } // Integer { - TypeChecker::EnvironmentExternal typeEnvironment; + TestEnvironment typeEnvironment; const auto type = typeCheckExpression("100", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } // Unsigned integer { - TypeChecker::EnvironmentExternal typeEnvironment; + TestEnvironment typeEnvironment; const auto type = typeCheckExpression("100U", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Uint32::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), Type::Uint32::getInstance()->getTypeName()); 
EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -351,65 +437,65 @@ TEST(TypeChecker, Unary) { // Dereference pointer { - TypeChecker::EnvironmentExternal typeEnvironment; - typeEnvironment.define("intArray"); + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray"); const auto type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName()); EXPECT_FALSE(type.constValue); EXPECT_FALSE(type.constPointer); } // Dereference pointer to const { - TypeChecker::EnvironmentExternal typeEnvironment; - typeEnvironment.define("intArray", true); + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray", true); const auto type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } // Dereference const pointer { - TypeChecker::EnvironmentExternal typeEnvironment; - typeEnvironment.define("intArray", false, true); + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray", false, true); const auto type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName()); EXPECT_FALSE(type.constValue); EXPECT_FALSE(type.constPointer); } // Dereference const pointer to const { - TypeChecker::EnvironmentExternal typeEnvironment; - typeEnvironment.define("intArray", true, true); + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray", true, true); const auto type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Int32::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } // Dereference numeric EXPECT_THROW({ - TypeChecker::EnvironmentExternal typeEnvironment; + TestEnvironment typeEnvironment; typeEnvironment.define("intVal"); typeCheckExpression("*intVal", typeEnvironment); }, TypeChecker::TypeCheckError); // Address of numeric { - TypeChecker::EnvironmentExternal typeEnvironment; + TestEnvironment typeEnvironment; typeEnvironment.define("intVal"); const auto type = typeCheckExpression("&intVal", typeEnvironment); - EXPECT_EQ(type.type->getTypeHash(), Type::Int32Ptr::getInstance()->getTypeHash()); + EXPECT_EQ(type.type->getTypeName(), getPointerTypeName()); EXPECT_FALSE(type.constValue); EXPECT_FALSE(type.constPointer); } // Address of pointer EXPECT_THROW({ - TypeChecker::EnvironmentExternal typeEnvironment; - typeEnvironment.define("intArray"); + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray"); typeCheckExpression("&intArray", typeEnvironment);}, TypeChecker::TypeCheckError); } From 23dddffadbf5bb346ad8171f96aaf76bdf110caf Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 16 Jan 2023 12:09:13 +0000 Subject: [PATCH 044/725] adding support for 'scalar' type to scanner --- include/genn/genn/transpiler/scanner.h | 8 +- .../code_generator/customUpdateGroupMerged.cc | 5 +- src/genn/genn/transpiler/scanner.cc | 102 ++++++++---------- src/genn/genn/type.cc | 4 +- tests/unit/scanner.cc | 
37 ++++++- tests/unit/typeChecker.cc | 28 ++++- 6 files changed, 113 insertions(+), 71 deletions(-) diff --git a/include/genn/genn/transpiler/scanner.h b/include/genn/genn/transpiler/scanner.h index e80a7f5d68..976c593b64 100644 --- a/include/genn/genn/transpiler/scanner.h +++ b/include/genn/genn/transpiler/scanner.h @@ -11,6 +11,10 @@ #include "transpiler/token.h" // Forward declarations +namespace GeNN::Type +{ +class NumericBase; +} namespace GeNN::Transpiler { class ErrorHandlerBase; @@ -21,6 +25,6 @@ class ErrorHandlerBase; //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Scanner { -std::vector scanSource(const std::string_view &source, ErrorHandlerBase &errorHandler); +std::vector scanSource(const std::string_view &source, const Type::NumericBase *scalarType, ErrorHandlerBase &errorHandler); -} // namespace Scanner \ No newline at end of file +} // namespace Scanner diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index b43903cbc3..74a768bfaf 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -119,7 +119,8 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string // Create type environment // **TEMP** parse precision to get scalar type - GroupMergedTypeEnvironment typeEnvironment(*this, Type::parseNumeric(precision)); + const auto *scalarType = Type::parseNumeric(precision); + GroupMergedTypeEnvironment typeEnvironment(*this, scalarType); addField("size", [](const auto &c, size_t) { return std::to_string(c.getSize()); }); @@ -157,7 +158,7 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string // Scan, parse and type-check update code Transpiler::ErrorHandler errorHandler; - const auto tokens = Transpiler::Scanner::scanSource(cm->getUpdateCode(), errorHandler); + const auto tokens = Transpiler::Scanner::scanSource(cm->getUpdateCode(), scalarType, errorHandler); const auto statements = Transpiler::Parser::parseBlockItemList(tokens, errorHandler); Transpiler::TypeChecker::typeCheck(statements, typeEnvironment, errorHandler); diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc index 63327228fd..c49fdf3f1a 100644 --- a/src/genn/genn/transpiler/scanner.cc +++ b/src/genn/genn/transpiler/scanner.cc @@ -8,12 +8,17 @@ #include // Standard C includes +#include #include +// GeNN includes +#include "type.h" + // Transpiler includes #include "transpiler/errorHandler.h" #include "transpiler/transpilerUtils.h" +using namespace GeNN; using namespace GeNN::Transpiler; using namespace GeNN::Transpiler::Scanner; @@ -43,6 +48,7 @@ const std::unordered_map keywords{ {"long", Token::Type::TYPE_SPECIFIER}, {"float", Token::Type::TYPE_SPECIFIER}, {"double", Token::Type::TYPE_SPECIFIER}, + {"scalar", Token::Type::TYPE_SPECIFIER}, {"signed", Token::Type::TYPE_SPECIFIER}, {"unsigned", Token::Type::TYPE_SPECIFIER}, {"bool", Token::Type::TYPE_SPECIFIER}}; @@ -58,8 +64,8 @@ const std::map, std::function &tokens) scanState.advance(); } - // Read decimal place - const bool isFloat = scanState.match('.'); + // If a decimal point is found, give an error + if(scanState.match('.')) { + scanState.error("Hexadecimal floating point literals unsupported."); + } // Read hexadecimal digits while(std::isxdigit(scanState.peek())) { scanState.advance(); } - // If number is float - if(isFloat) { - // Check there's an exponent
as these are REQUIRED for floating point literals - if(scanState.peek() != 'p') { - scanState.error("Hexadecimal floating point literal missing exponent."); - } - else { - // Read p - scanState.advance(); - - // Read sign - if(scanState.peek() == '-' || scanState.peek() == '+') { - scanState.advance(); - } - - // Read DECIMAL digits - while(std::isdigit(scanState.peek())) { - scanState.advance(); - } - - // If literal has floating point suffix - if(std::tolower(scanState.peek()) == 'f') { - // Add single-precision token - // **NOTE** skip 0x prefix - emplaceToken(tokens, Token::Type::NUMBER, scanState, - Utils::toCharsThrow(scanState.getLexeme().substr(2), 16)); - - // Advance - // **NOTE** we do this AFTER parsing float as std::to_chars doesn't deal with suffixes - scanState.advance(); - } - // Add double-precision token - // **NOTE** skip 0x prefix - else { - emplaceToken(tokens, Token::Type::NUMBER, scanState, - Utils::toCharsThrow(scanState.getLexeme().substr(2), 16)); - } - } - } - // Otherwise, number is hexadecimal integer - else { - // Add integer token - // **NOTE** skip 0x prefix - const auto suffix = scanIntegerSuffix(scanState); - emplaceToken(tokens, Token::Type::NUMBER, scanState, - integerLiteralSuffixParsers.at(suffix)(scanState.getLexeme().substr(2), 16)); - } + // Add integer token + // **NOTE** skip 0x prefix + const auto suffix = scanIntegerSuffix(scanState); + emplaceToken(tokens, Token::Type::NUMBER, scanState, + integerLiteralSuffixParsers.at(suffix)(scanState.getLexeme().substr(2), 16)); } // Otherwise, if this is an octal integer else if(c == '0' && isodigit(scanState.peek())){ @@ -264,11 +236,29 @@ void scanNumber(char c, ScanState &scanState, std::vector &tokens) // **NOTE** we do this AFTER parsing float as std::to_chars doesn't deal with suffixes scanState.advance(); } - // Otherwise, add double-precision token - else { + // Otherwise, if literal has double precision suffix + // **NOTE** this is a GeNN extension not standard C + else if(std::tolower(scanState.peek()) == 'd') { emplaceToken(tokens, Token::Type::NUMBER, scanState, Utils::toCharsThrow(scanState.getLexeme())); } + // Otherwise, this is a scalar literal + else { + // If the scalar type is float, add single-precision token + if(scanState.getScalarType()->getTypeName() == "float") { + emplaceToken(tokens, Token::Type::NUMBER, scanState, + Utils::toCharsThrow(scanState.getLexeme())); + + } + // Otherwise, add double-precision token + else if(scanState.getScalarType()->getTypeName() == "double") { + emplaceToken(tokens, Token::Type::NUMBER, scanState, + Utils::toCharsThrow(scanState.getLexeme())); + } + else { + assert(false); + } + } } // Otherwise, number is integer else { @@ -466,11 +456,11 @@ void scanToken(ScanState &scanState, std::vector &tokens) //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Scanner { -std::vector scanSource(const std::string_view &source, ErrorHandlerBase &errorHandler) +std::vector scanSource(const std::string_view &source, const Type::NumericBase *scalarType, ErrorHandlerBase &errorHandler) { std::vector tokens; - ScanState scanState(source, errorHandler); + ScanState scanState(source, scalarType, errorHandler); // Scan tokens while(!scanState.isAtEnd()) { diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index 15a5092c2b..cd776ddb0a 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -85,7 +85,7 @@ const NumericBase *parseNumeric(std::string_view typeString) // Scan type SingleLineErrorHandler 
errorHandler; - const auto tokens = Scanner::scanSource(typeString, errorHandler); + const auto tokens = Scanner::scanSource(typeString, nullptr, errorHandler); // Parse type and cast to numeric const auto *type = dynamic_cast(Parser::parseType(tokens, false, errorHandler)); @@ -108,7 +108,7 @@ const Pointer *parseNumericPtr(std::string_view typeString) // Scan type SingleLineErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(typeString, errorHandler); + const auto tokens = Scanner::scanSource(typeString, nullptr, errorHandler); // Parse type and cast to numeric pointer const auto *type = dynamic_cast(Parser::parseType(tokens, true, errorHandler)); diff --git a/tests/unit/scanner.cc b/tests/unit/scanner.cc index 9151dcdff9..52ca2f0b42 100644 --- a/tests/unit/scanner.cc +++ b/tests/unit/scanner.cc @@ -1,10 +1,14 @@ // Google test includes #include "gtest/gtest.h" +// GeNN includes +#include "type.h" + // GeNN transpiler includes #include "transpiler/errorHandler.h" #include "transpiler/scanner.h" +using namespace GeNN; using namespace GeNN::Transpiler; //-------------------------------------------------------------------------- @@ -52,7 +56,7 @@ class TestErrorHandler : public ErrorHandlerBase TEST(Scanner, DecimalInt) { TestErrorHandler errorHandler; - const auto tokens = Scanner::scanSource("1234 4294967295U -2345 -2147483647", errorHandler); + const auto tokens = Scanner::scanSource("1234 4294967295U -2345 -2147483647", GeNN::Type::Float::getInstance(), errorHandler); ASSERT_FALSE(errorHandler.hasError()); ASSERT_EQ(tokens.size(), 7); @@ -73,7 +77,7 @@ TEST(Scanner, DecimalInt) TEST(Scanner, HexInt) { TestErrorHandler errorHandler; - const auto tokens = Scanner::scanSource("0x1234 0xFFFFFFFFU -0x1234 -0x7FFFFFFF", errorHandler); + const auto tokens = Scanner::scanSource("0x1234 0xFFFFFFFFU -0x1234 -0x7FFFFFFF", GeNN::Type::Float::getInstance(), errorHandler); ASSERT_FALSE(errorHandler.hasError()); ASSERT_EQ(tokens.size(), 7); @@ -91,10 +95,35 @@ TEST(Scanner, HexInt) ASSERT_EQ(std::get(tokens[5].literalValue), 0x7FFFFFFF); } //-------------------------------------------------------------------------- -TEST(Scanner, DecimalFloat) +TEST(Scanner, DecimalFloatFloatScalar) +{ + TestErrorHandler errorHandler; + const auto tokens = Scanner::scanSource("1.0 0.2 100.0f 0.2f -12.0d -0.0004f", GeNN::Type::Float::getInstance(), errorHandler); + ASSERT_FALSE(errorHandler.hasError()); + + ASSERT_EQ(tokens.size(), 9); + ASSERT_EQ(tokens[0].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[1].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[2].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[3].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[4].type, Token::Type::MINUS); + ASSERT_EQ(tokens[5].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[6].type, Token::Type::MINUS); + ASSERT_EQ(tokens[7].type, Token::Type::NUMBER); + ASSERT_EQ(tokens[8].type, Token::Type::END_OF_FILE); + + ASSERT_EQ(std::get(tokens[0].literalValue), 1.0f); + ASSERT_EQ(std::get(tokens[1].literalValue), 0.2f); + ASSERT_EQ(std::get(tokens[2].literalValue), 100.0f); + ASSERT_EQ(std::get(tokens[3].literalValue), 0.2f); + ASSERT_EQ(std::get(tokens[5].literalValue), 12.0); + ASSERT_EQ(std::get(tokens[7].literalValue), 0.0004f); +} +//-------------------------------------------------------------------------- +TEST(Scanner, DecimalFloatDoubleScalar) { TestErrorHandler errorHandler; - const auto tokens = Scanner::scanSource("1.0 0.2 100.0f 0.2f -12.0 -0.0004f", errorHandler); + const auto tokens = Scanner::scanSource("1.0 0.2 100.0f 
0.2f -12.0d -0.0004f", GeNN::Type::Double::getInstance(), errorHandler); ASSERT_FALSE(errorHandler.hasError()); ASSERT_EQ(tokens.size(), 9); diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index 7d4473c455..61d524e0dd 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -80,7 +80,7 @@ class TestEnvironment : public TypeChecker::EnvironmentBase //--------------------------------------------------------------------------- // EnvironmentBase virtuals //--------------------------------------------------------------------------- - virtual void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandlerBase &errorHandler) final + virtual void define(const Token &name, const Type::QualifiedType &, ErrorHandlerBase &errorHandler) final { errorHandler.error(name, "Cannot declare variable in external environment"); throw TypeChecker::TypeCheckError(); @@ -141,7 +141,7 @@ void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment { // Scan TestErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(code, errorHandler); + const auto tokens = Scanner::scanSource(code, Type::Float::getInstance(), errorHandler); ASSERT_FALSE(errorHandler.hasError()); // Parse @@ -153,11 +153,11 @@ void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment ASSERT_FALSE(errorHandler.hasError()); } -Type::QualifiedType typeCheckExpression(std::string_view code, TestEnvironment &typeEnvironment) +Type::QualifiedType typeCheckExpression(std::string_view code, TestEnvironment &typeEnvironment, const Type::NumericBase *scalarType = Type::Float::getInstance()) { // Scan TestErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(code, errorHandler); + const auto tokens = Scanner::scanSource(code, scalarType, errorHandler); EXPECT_FALSE(errorHandler.hasError()); // Parse @@ -405,10 +405,28 @@ TEST(TypeChecker, Literal) EXPECT_FALSE(type.constPointer); } - // Double + // Scalar with single-precision { TestEnvironment typeEnvironment; const auto type = typeCheckExpression("1.0", typeEnvironment); + EXPECT_EQ(type.type->getTypeName(), Type::Float::getInstance()->getTypeName()); + EXPECT_TRUE(type.constValue); + EXPECT_FALSE(type.constPointer); + } + + // Scalar with double-precision + { + TestEnvironment typeEnvironment; + const auto type = typeCheckExpression("1.0", typeEnvironment, Type::Double::getInstance()); + EXPECT_EQ(type.type->getTypeName(), Type::Double::getInstance()->getTypeName()); + EXPECT_TRUE(type.constValue); + EXPECT_FALSE(type.constPointer); + } + + // Double + { + TestEnvironment typeEnvironment; + const auto type = typeCheckExpression("1.0d", typeEnvironment); EXPECT_EQ(type.type->getTypeName(), Type::Double::getInstance()->getTypeName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); From d7e4b4eff09030b35ab5ee05bed8f41974d77102 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 16 Jan 2023 12:41:53 +0000 Subject: [PATCH 045/725] pass scalar type through to parser --- .../code_generator/customUpdateGroupMerged.h | 6 +-- .../genn/genn/code_generator/groupMerged.h | 14 +++--- .../groupMergedTypeEnvironment.h | 6 +-- .../genn/code_generator/initGroupMerged.h | 2 +- include/genn/genn/transpiler/parser.h | 11 +++-- include/genn/genn/type.h | 6 +-- .../customConnectivityUpdateGroupMerged.cc | 10 ++--- .../code_generator/customUpdateGroupMerged.cc | 14 +++--- src/genn/genn/code_generator/groupMerged.cc | 44 +++++++++---------- 
.../genn/code_generator/initGroupMerged.cc | 6 +-- .../code_generator/neuronUpdateGroupMerged.cc | 6 +-- src/genn/genn/transpiler/parser.cc | 28 +++++++----- src/genn/genn/type.cc | 39 ++++++++++------ tests/unit/typeChecker.cc | 8 ++-- 14 files changed, 110 insertions(+), 90 deletions(-) diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index 21f3028a34..66ed67cf6e 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -170,14 +170,14 @@ class CustomUpdateHostReductionGroupMergedBase : public GroupMerged const CustomUpdateModels::Base *cm = this->getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { if(v.access & VarAccessModeAttribute::REDUCE) { - this->addPointerField(Type::parseNumeric(v.type), v.name, backend.getDeviceVarPrefix() + v.name); + this->addPointerField(Type::parseNumeric(v.type, this->getScalarType()), v.name, backend.getDeviceVarPrefix() + v.name); } } // Loop through variable references and add pointers if they are reduction targets for(const auto &v : cm->getVarRefs()) { if(v.access & VarAccessModeAttribute::REDUCE) { - this->addPointerField(Type::parseNumeric(v.type), v.name, backend.getDeviceVarPrefix() + v.name); + this->addPointerField(Type::parseNumeric(v.type, this->getScalarType()), v.name, backend.getDeviceVarPrefix() + v.name); } } } @@ -234,4 +234,4 @@ class GENN_EXPORT CustomWUUpdateHostReductionGroupMerged : public CustomUpdateHo //---------------------------------------------------------------------------- static const std::string name; }; -} // namespace GeNN::CodeGenerator \ No newline at end of file +} // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index e85da70444..8a22be6353 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -67,7 +67,7 @@ class GroupMerged // **HACK** type should come in as type not string GroupMerged(size_t index, const std::string &precision, const std::vector> groups) - : m_Index(index), m_LiteralSuffix((precision == "float") ? "f" : ""), m_ScalarType(Type::parseNumeric(precision)), m_Groups(std::move(groups)) + : m_Index(index), m_LiteralSuffix((precision == "float") ? "f" : ""), m_ScalarType(Type::parseNumeric(precision, nullptr)), m_Groups(std::move(groups)) {} //------------------------------------------------------------------------ @@ -202,6 +202,8 @@ class GroupMerged //------------------------------------------------------------------------ // Protected methods //------------------------------------------------------------------------ + const Type::NumericBase *getScalarType() const{ return m_ScalarType; } + //! 
Helper to test whether parameter is referenced in vector of codestrings bool isParamReferenced(const std::vector &codeStrings, const std::string ¶mName) const { @@ -281,7 +283,7 @@ class GroupMerged { // Loop through variables for(const auto &v : vars) { - addPointerField(Type::parseNumeric(v.type), v.name, arrayPrefix + v.name); + addPointerField(Type::parseNumeric(v.type, getScalarType()), v.name, arrayPrefix + v.name); } } @@ -290,7 +292,7 @@ class GroupMerged { // Loop through variables for(const auto &v : varReferences) { - addField(createPointer(Type::parseNumeric(v.type)), v.name, + addField(createPointer(Type::parseNumeric(v.type, getScalarType())), v.name, [getVarRefFn, arrayPrefix, v](const G &g, size_t) { const auto varRef = getVarRefFn(g).at(v.name); @@ -303,7 +305,7 @@ class GroupMerged { for(const auto &e : egps) { assert(Utils::isTypePointer(e.type)); - addField(Type::parseNumericPtr(e.type), e.name + varName, + addField(Type::parseNumericPtr(e.type, getScalarType()), e.name + varName, [e, arrayPrefix, varName](const G &g, size_t) { return arrayPrefix + e.name + varName + g.getName(); }, GroupMergedFieldType::DYNAMIC); } @@ -543,7 +545,7 @@ class GroupMerged //------------------------------------------------------------------------ const size_t m_Index; const std::string m_LiteralSuffix; - const Type::Base *m_ScalarType; + const Type::NumericBase *m_ScalarType; std::string m_MemorySpace; std::vector m_Fields; std::vector> m_Groups; @@ -831,7 +833,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged // If we're not initialising or if there is initialization code for this variable const auto &varInit = archetypeAdaptor.getVarInitialisers().at(var.name); if (!varInit.getSnippet()->getCode().empty()) { - this->addPointerField(Type::parseNumeric(var.type), var.name, backend.getDeviceVarPrefix() + var.name); + this->addPointerField(Type::parseNumeric(var.type, this->getScalarType()), var.name, backend.getDeviceVarPrefix() + var.name); } // Add any var init EGPs to structure diff --git a/include/genn/genn/transpiler/parser.h b/include/genn/genn/transpiler/parser.h index 717432d7bf..6068ec65ca 100644 --- a/include/genn/genn/transpiler/parser.h +++ b/include/genn/genn/transpiler/parser.h @@ -22,13 +22,16 @@ class ErrorHandlerBase; namespace GeNN::Transpiler::Parser { //! Parse expression from tokens -Expression::ExpressionPtr parseExpression(const std::vector &tokens, ErrorHandlerBase &errorHandler); +Expression::ExpressionPtr parseExpression(const std::vector &tokens, const GeNN::Type::NumericBase *scalarType, + ErrorHandlerBase &errorHandler); //! Parse block item list from tokens /*! Block item lists are function body scope list of statements */ -Statement::StatementList parseBlockItemList(const std::vector &tokens, ErrorHandlerBase &errorHandler); +Statement::StatementList parseBlockItemList(const std::vector &tokens, const GeNN::Type::NumericBase *scalarType, + ErrorHandlerBase &errorHandler); //! 
Parse type from tokens -const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, ErrorHandlerBase &errorHandler); +const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, + const GeNN::Type::NumericBase *scalarType, ErrorHandlerBase &errorHandler); -} // MiniParse::MiniParse \ No newline at end of file +} // MiniParse::MiniParse diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index eca9dc31c2..1e092f09fe 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -294,13 +294,13 @@ const Pointer *createPointer() } //! Parse a numeric type -const NumericBase *parseNumeric(std::string_view typeString); +const NumericBase *parseNumeric(std::string_view typeString, const NumericBase *scalarType); //! Parse a numeric pointer type -const Pointer *parseNumericPtr(std::string_view typeString); +const Pointer *parseNumericPtr(std::string_view typeString, const NumericBase *scalarType); //! Look up numeric type based on set of type specifiers -const NumericBase *getNumericType(const std::set &typeSpecifiers); +const NumericBase *getNumericType(const std::set &typeSpecifiers, const NumericBase *scalarType); //! Apply C type promotion rules to numeric type const NumericBase *getPromotedType(const NumericBase *type); diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 71b0e79fa5..bbfa29e061 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -111,7 +111,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t assert(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE); - addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())), "ind", + addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType(), getScalarType())), "ind", [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); @@ -167,7 +167,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t // Loop through sorted dependent variables for(size_t i = 0; i < getSortedArchetypeDependentVars().size(); i++) { - addField(createPointer(parseNumeric(getSortedArchetypeDependentVars().at(i).getVar().type)), "_dependentVar" + std::to_string(i), + addField(createPointer(parseNumeric(getSortedArchetypeDependentVars().at(i).getVar().type, getScalarType())), "_dependentVar" + std::to_string(i), [i, &backend, this](const auto&, size_t g) { const auto &varRef = m_SortedDependentVars[g][i]; @@ -441,7 +441,7 @@ CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged // Add host extra global parameters for(const auto &e : cm->getExtraGlobalParams()) { - const auto *pointerType = parseNumericPtr(e.type); + const auto *pointerType = parseNumericPtr(e.type, getScalarType()); addField(pointerType, e.name, [e](const auto &g, size_t) { return e.name + g.getName(); }, GroupMergedFieldType::HOST_DYNAMIC); @@ -563,13 +563,13 @@ void CustomConnectivityHostUpdateGroupMerged::addVars(const BackendBase &backend for(const auto &v : vars) { // If var is located on the host if (std::invoke(getVarLocationFn, getArchetype(), v.name) & VarLocation::HOST) { - addField(createPointer(parseNumeric(v.type)), v.name, + addField(createPointer(parseNumeric(v.type, 
getScalarType())), v.name, [v](const auto &g, size_t) { return v.name + g.getName(); }, GroupMergedFieldType::HOST); if(!backend.getDeviceVarPrefix().empty()) { // **TODO** I think could use addPointerField - addField(createPointer(parseNumeric(v.type)), backend.getDeviceVarPrefix() + v.name, + addField(createPointer(parseNumeric(v.type, getScalarType())), backend.getDeviceVarPrefix() + v.name, [v, &backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + v.name + g.getName(); diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 74a768bfaf..093ff071b1 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -118,9 +118,7 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string using namespace Type; // Create type environment - // **TEMP** parse precision to get scalar type - const auto *scalarType = Type::parseNumeric(precision); - GroupMergedTypeEnvironment typeEnvironment(*this, scalarType); + GroupMergedTypeEnvironment typeEnvironment(*this, getScalarType()); addField("size", [](const auto &c, size_t) { return std::to_string(c.getSize()); }); @@ -158,8 +156,8 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string // Scan, parse and type-check update code Transpiler::ErrorHandler errorHandler; - const auto tokens = Transpiler::Scanner::scanSource(cm->getUpdateCode(), scalarType, errorHandler); - const auto statements = Transpiler::Parser::parseBlockItemList(tokens, errorHandler); + const auto tokens = Transpiler::Scanner::scanSource(cm->getUpdateCode(), getScalarType(), errorHandler); + const auto statements = Transpiler::Parser::parseBlockItemList(tokens, getScalarType(), errorHandler); Transpiler::TypeChecker::typeCheck(statements, typeEnvironment, errorHandler); } @@ -299,7 +297,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // Create type environment // **TEMP** parse precision to get scalar type - GroupMergedTypeEnvironment typeEnvironment(*this, Type::parseNumeric(precision)); + GroupMergedTypeEnvironment typeEnvironment(*this, getScalarType()); // If underlying synapse group has kernel weights if (getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL) { @@ -340,7 +338,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // If synapse group has sparse connectivity if(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())), "ind", + addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType(), getScalarType())), "ind", [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); @@ -380,7 +378,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // If variable has a transpose if(getArchetype().getVarReferences().at(v.name).getTransposeSynapseGroup() != nullptr) { // Add field with transpose suffix, pointing to transpose var - addField(createPointer(parseNumeric(v.type)), v.name + "Transpose", + addField(createPointer(parseNumeric(v.type, getScalarType())), v.name + "Transpose", [&backend, v](const auto &g, size_t) { const auto varRef = g.getVarReferences().at(v.name); diff --git a/src/genn/genn/code_generator/groupMerged.cc 
b/src/genn/genn/code_generator/groupMerged.cc index 58d446c597..f0c857bb96 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -84,7 +84,7 @@ NeuronPrevSpikeTimeUpdateGroupMerged::NeuronPrevSpikeTimeUpdateGroupMerged(size_ addPointerField("spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); } - const NumericBase *timeType = parseNumeric(timePrecision); + const NumericBase *timeType = parseNumeric(timePrecision, nullptr); if(getArchetype().isPrevSpikeTimeRequired()) { addPointerField("spk", backend.getDeviceVarPrefix() + "glbSpk"); addPointerField(timeType, "prevST", backend.getDeviceVarPrefix() + "prevST"); @@ -191,8 +191,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr using namespace Type; // **HACK** parse precisions - const NumericBase *scalarType = parseNumeric(precision); - const NumericBase *timeType = parseNumeric(timePrecision); + const NumericBase *timeType = parseNumeric(timePrecision, nullptr); // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group orderNeuronGroupChildren(m_SortedMergedInSyns, &NeuronGroupInternal::getFusedPSMInSyn, @@ -250,7 +249,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr for(const auto &var : vars) { // If we're not initialising or if there is initialization code for this variable if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addPointerField(parseNumeric(var.type), var.name, backend.getDeviceVarPrefix() + var.name); + addPointerField(parseNumeric(var.type, getScalarType()), var.name, backend.getDeviceVarPrefix() + var.name); } // If we're initializing, add any var init EGPs to structure @@ -290,11 +289,11 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr const SynapseGroupInternal *sg = getSortedArchetypeMergedInSyns().at(i); // Add pointer to insyn - addMergedInSynPointerField(scalarType, "inSynInSyn", i, backend.getDeviceVarPrefix() + "inSyn"); + addMergedInSynPointerField(getScalarType(), "inSynInSyn", i, backend.getDeviceVarPrefix() + "inSyn"); // Add pointer to dendritic delay buffer if required if(sg->isDendriticDelayRequired()) { - addMergedInSynPointerField(scalarType, "denDelayInSyn", i, backend.getDeviceVarPrefix() + "denDelay"); + addMergedInSynPointerField(getScalarType(), "denDelayInSyn", i, backend.getDeviceVarPrefix() + "denDelay"); addMergedInSynPointerField(Uint32::getInstance(), "denDelayPtrInSyn", i, backend.getScalarAddressPrefix() + "denDelayPtr"); } @@ -303,7 +302,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr for(const auto &var : sg->getPSModel()->getVars()) { // Add pointers to state variable if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addMergedInSynPointerField(parseNumeric(var.type), var.name + "InSyn", i, backend.getDeviceVarPrefix() + var.name); + addMergedInSynPointerField(parseNumeric(var.type, getScalarType()), var.name + "InSyn", i, backend.getDeviceVarPrefix() + var.name); } // If we're generating an initialization structure, also add any heterogeneous parameters, derived parameters or extra global parameters required for initializers @@ -346,7 +345,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr // Loop through merged output synapses with presynaptic output of archetypical neuron group (0) in sorted order for(size_t i = 0; i < 
getSortedArchetypeMergedPreOutputOutSyns().size(); i++) { // Add pointer to revInSyn - addMergedPreOutputOutSynPointerField(scalarType, "revInSynOutSyn", i, backend.getDeviceVarPrefix() + "revInSyn"); + addMergedPreOutputOutSynPointerField(getScalarType(), "revInSynOutSyn", i, backend.getDeviceVarPrefix() + "revInSyn"); } // Loop through current sources to archetypical neuron group in sorted order @@ -358,7 +357,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr for(const auto &var : cs->getCurrentSourceModel()->getVars()) { // Add pointers to state variable if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(Type::createPointer(parseNumeric(var.type)), var.name + "CS" + std::to_string(i), + addField(Type::createPointer(parseNumeric(var.type, getScalarType())), var.name + "CS" + std::to_string(i), [&backend, i, var, this](const auto&, size_t groupIndex) { return backend.getDeviceVarPrefix() + var.name + m_SortedCurrentSources.at(groupIndex).at(i)->getName(); @@ -674,8 +673,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & using namespace Type; // **HACK** parse precisions - const NumericBase *scalarType = parseNumeric(precision); - const NumericBase *timeType = parseNumeric(timePrecision); + const NumericBase *timeType = parseNumeric(timePrecision, nullptr); const bool updateRole = ((role == Role::PresynapticUpdate) || (role == Role::PostsynapticUpdate) @@ -700,11 +698,11 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & // If this role is one where postsynaptic input can be provided if(role == Role::PresynapticUpdate || role == Role::SynapseDynamics) { if(getArchetype().isDendriticDelayRequired()) { - addPSPointerField(scalarType, "denDelay", backend.getDeviceVarPrefix() + "denDelay"); + addPSPointerField(getScalarType(), "denDelay", backend.getDeviceVarPrefix() + "denDelay"); addPSPointerField(Uint32::getInstance(), "denDelayPtr", backend.getScalarAddressPrefix() + "denDelayPtr"); } else { - addPSPointerField(scalarType, "inSyn", backend.getDeviceVarPrefix() + "inSyn"); + addPSPointerField(getScalarType(), "inSyn", backend.getDeviceVarPrefix() + "inSyn"); } } @@ -728,7 +726,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & if(updateRole) { // for all types of roles if (getArchetype().isPresynapticOutputRequired()) { - addPreOutputPointerField(scalarType, "revInSyn", backend.getDeviceVarPrefix() + "revInSyn"); + addPreOutputPointerField(getScalarType(), "revInSyn", backend.getDeviceVarPrefix() + "revInSyn"); } // If presynaptic population has delay buffers @@ -773,7 +771,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & for(const auto &v : preVars) { // If variable is referenced in code string, add source pointer if(code.find("$(" + v.name + "_pre)") != std::string::npos) { - addSrcPointerField(parseNumeric(v.type), v.name + "Pre", backend.getDeviceVarPrefix() + v.name); + addSrcPointerField(parseNumeric(v.type, getScalarType()), v.name + "Pre", backend.getDeviceVarPrefix() + v.name); } } @@ -782,7 +780,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & for(const auto &v : postVars) { // If variable is referenced in code string, add target pointer if(code.find("$(" + v.name + "_post)") != std::string::npos) { - addTrgPointerField(parseNumeric(v.type), v.name + "Post", backend.getDeviceVarPrefix() + v.name); + addTrgPointerField(parseNumeric(v.type, 
getScalarType()), v.name + "Post", backend.getDeviceVarPrefix() + v.name); } } @@ -791,7 +789,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & for(const auto &e : preEGPs) { if(code.find("$(" + e.name + "_pre)") != std::string::npos) { const std::string prefix = backend.getDeviceVarPrefix(); - addField(parseNumericPtr(e.type), e.name + "Pre", + addField(parseNumericPtr(e.type, getScalarType()), e.name + "Pre", [e, prefix](const auto &sg, size_t) { return prefix + e.name + sg.getSrcNeuronGroup()->getName(); }, GroupMergedFieldType::DYNAMIC); } @@ -802,7 +800,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & for(const auto &e : postEGPs) { if(code.find("$(" + e.name + "_post)") != std::string::npos) { const std::string prefix = backend.getDeviceVarPrefix(); - addField(parseNumericPtr(e.type), e.name + "Post", + addField(parseNumericPtr(e.type, getScalarType()), e.name + "Post", [e, prefix](const auto &sg, size_t) { return prefix + e.name + sg.getTrgNeuronGroup()->getName(); }, GroupMergedFieldType::DYNAMIC); } @@ -842,14 +840,14 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & // Add presynaptic variables to struct for(const auto &v : wum->getPreVars()) { const std::string prefix = backend.getDeviceVarPrefix() + v.name; - addField(createPointer(parseNumeric(v.type)), v.name, + addField(createPointer(parseNumeric(v.type, getScalarType())), v.name, [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPreVarSuffix(); }); } // Add presynaptic variables to struct for(const auto &v : wum->getPostVars()) { const std::string prefix = backend.getDeviceVarPrefix() + v.name; - addField(createPointer(parseNumeric(v.type)), v.name, + addField(createPointer(parseNumeric(v.type, getScalarType())), v.name, [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPostVarSuffix(); }); } @@ -860,7 +858,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & // Add pointers to connectivity data if(getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { addPointerField("rowLength", backend.getDeviceVarPrefix() + "rowLength"); - addPointerField(parseNumeric(getArchetype().getSparseIndType()), "ind", backend.getDeviceVarPrefix() + "ind"); + addPointerField(parseNumeric(getArchetype().getSparseIndType(), getScalarType()), "ind", backend.getDeviceVarPrefix() + "ind"); // Add additional structure for postsynaptic access if(backend.isPostsynapticRemapRequired() && !wum->getLearnPostCode().empty() @@ -974,7 +972,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & // If we're performing an update with individual weights; or this variable should be initialised if((updateRole && individualWeights) || (kernelWeights && updateRole) || varInitRequired) { - addPointerField(parseNumeric(var.type), var.name, backend.getDeviceVarPrefix() + var.name); + addPointerField(parseNumeric(var.type, getScalarType()), var.name, backend.getDeviceVarPrefix() + var.name); } // If we're performing a procedural update or this variable should be initialised, add any var init EGPs to structure @@ -982,7 +980,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & const auto egps = snippet->getExtraGlobalParams(); for(const auto &e : egps) { const std::string prefix = backend.getDeviceVarPrefix(); - addField(parseNumericPtr(e.type), e.name + var.name, + addField(parseNumericPtr(e.type, getScalarType()), e.name + 
var.name, [e, prefix, var](const SynapseGroupInternal &sg, size_t) { return prefix + e.name + var.name + sg.getName(); diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 2612dfff7c..a6b323b8ce 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -443,7 +443,7 @@ void NeuronInitGroupMerged::generateWUVar(const BackendBase &backend, for(const auto &var : vars) { // Add pointers to state variable if(!varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(createPointer(parseNumeric(var.type)), var.name + fieldPrefixStem + std::to_string(i), + addField(createPointer(parseNumeric(var.type, getScalarType())), var.name + fieldPrefixStem + std::to_string(i), [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto&, size_t groupIndex) { const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); @@ -1022,7 +1022,7 @@ CustomWUUpdateSparseInitGroupMerged::CustomWUUpdateSparseInitGroupMerged(size_t const SynapseGroupInternal *sg = cg.getSynapseGroup(); return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); }); - addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())), "ind", + addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType(), getScalarType())), "ind", [&backend](const auto &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); @@ -1184,7 +1184,7 @@ CustomConnectivityUpdateSparseInitGroupMerged::CustomConnectivityUpdateSparseIni const SynapseGroupInternal *sg = cg.getSynapseGroup(); return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); }); - addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType())), "ind", + addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType(), getScalarType())), "ind", [&backend](const auto &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 062ad5fe9c..6fe074f8e3 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -66,7 +66,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string // If EGP is referenced in event threshold code if(s.eventThresholdCode.find("$(" + egp.name + ")") != std::string::npos) { const std::string prefix = backend.getDeviceVarPrefix(); - addField(parseNumericPtr(egp.type), egp.name + "EventThresh" + std::to_string(i), + addField(parseNumericPtr(egp.type, getScalarType()), egp.name + "EventThresh" + std::to_string(i), [eventThresholdSGs, prefix, egp, i](const auto &, size_t groupIndex) { return prefix + egp.name + eventThresholdSGs.at(groupIndex).at(i)->getName(); @@ -80,7 +80,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string for(const auto &var : sgPreVars) { // If variable is referenced in event threshold code if(s.eventThresholdCode.find("$(" + var.name + ")") != std::string::npos) { - addField(createPointer(parseNumeric(var.type)), var.name + "EventThresh" + std::to_string(i), + addField(createPointer(parseNumeric(var.type, getScalarType())), var.name + "EventThresh" + std::to_string(i), [&backend, eventThresholdSGs, var, i](const auto&, size_t groupIndex) { return backend.getDeviceVarPrefix() + var.name + 
eventThresholdSGs.at(groupIndex).at(i)->getName(); @@ -716,7 +716,7 @@ void NeuronUpdateGroupMerged::generateWUVar(const BackendBase &backend, const s for(size_t v = 0; v < vars.size(); v++) { // Add pointers to state variable const auto var = vars[v]; - addField(Type::createPointer(Type::parseNumeric(var.type)), var.name + fieldPrefixStem + std::to_string(i), + addField(Type::createPointer(Type::parseNumeric(var.type, getScalarType())), var.name + fieldPrefixStem + std::to_string(i), [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto &, size_t groupIndex) { const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 717451a13e..dcac94988e 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -38,8 +38,8 @@ class ParseError class ParserState { public: - ParserState(const std::vector &tokens, ErrorHandlerBase &errorHandler) - : m_Current(0), m_Tokens(tokens), m_ErrorHandler(errorHandler) + ParserState(const std::vector &tokens, const GeNN::Type::NumericBase *scalarType, ErrorHandlerBase &errorHandler) + : m_Current(0), m_Tokens(tokens), m_ScalarType(scalarType), m_ErrorHandler(errorHandler) {} //--------------------------------------------------------------------------- @@ -127,6 +127,8 @@ class ParserState } bool isAtEnd() const { return (peek().type == Token::Type::END_OF_FILE); } + + const GeNN::Type::NumericBase *getScalarType() const{ return m_ScalarType; } private: //--------------------------------------------------------------------------- @@ -135,7 +137,7 @@ class ParserState size_t m_Current; const std::vector &m_Tokens; - + const GeNN::Type::NumericBase *m_ScalarType; ErrorHandlerBase &m_ErrorHandler; }; @@ -213,7 +215,7 @@ GeNN::Type::QualifiedType parseDeclarationSpecifiers(ParserState &parserState) } while(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER, Token::Type::STAR})); // Lookup numeric type - const auto *numericType = GeNN::Type::getNumericType(typeSpecifiers); + const auto *numericType = GeNN::Type::getNumericType(typeSpecifiers, parserState.getScalarType()); // If pointer, return pointer to numeric type // **THINK** this relies of const being only qualifier @@ -786,6 +788,7 @@ Statement::StatementPtr parseDeclaration(ParserState &parserState) // "long" // "float" // "double" + // "scalar" // "signed" // "unsigned" // "bool" @@ -848,9 +851,10 @@ std::unique_ptr parseBlockItem(ParserState &parserState) //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Parser { -Expression::ExpressionPtr parseExpression(const std::vector &tokens, ErrorHandlerBase &errorHandler) +Expression::ExpressionPtr parseExpression(const std::vector &tokens, const GeNN::Type::NumericBase *scalarType, + ErrorHandlerBase &errorHandler) { - ParserState parserState(tokens, errorHandler); + ParserState parserState(tokens, scalarType, errorHandler); try { return parseExpression(parserState); @@ -860,9 +864,10 @@ Expression::ExpressionPtr parseExpression(const std::vector &tokens, Erro } } //--------------------------------------------------------------------------- -Statement::StatementList parseBlockItemList(const std::vector &tokens, ErrorHandlerBase &errorHandler) +Statement::StatementList parseBlockItemList(const std::vector &tokens, const GeNN::Type::NumericBase *scalarType, + ErrorHandlerBase &errorHandler) { - ParserState parserState(tokens, 
errorHandler); + ParserState parserState(tokens, scalarType, errorHandler); std::vector> statements; while(!parserState.isAtEnd()) { @@ -871,9 +876,10 @@ Statement::StatementList parseBlockItemList(const std::vector &tokens, Er return statements; } //--------------------------------------------------------------------------- -const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, ErrorHandlerBase &errorHandler) +const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, + const GeNN::Type::NumericBase *scalarType, ErrorHandlerBase &errorHandler) { - ParserState parserState(tokens, errorHandler); + ParserState parserState(tokens, scalarType, errorHandler); bool pointerFound = false; std::set typeSpecifiers; while(parserState.match({Token::Type::TYPE_SPECIFIER, Token::Type::STAR})) { @@ -896,7 +902,7 @@ const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPo }; // Lookup numeric type - const auto *numericType = GeNN::Type::getNumericType(typeSpecifiers); + const auto *numericType = GeNN::Type::getNumericType(typeSpecifiers, scalarType); // If pointer, return pointer to numeric type if (pointerFound) { diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index cd776ddb0a..c6679c73d6 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -17,7 +17,7 @@ using namespace GeNN; // Anonymous namespace namespace { -const std::map, const Type::NumericBase*> numericTypes{ +const std::map, const Type::NumericBase*> numericTypeSpecifiers{ {{"char"}, Type::Int8::getInstance()}, {{"unsigned", "char"}, Type::Uint8::getInstance()}, @@ -41,7 +41,9 @@ const std::map, const Type::NumericBase*> numericType {{"double"}, Type::Double::getInstance()}, }; //---------------------------------------------------------------------------- -// Mapping of signed integer numericTypes to their unsigned equivalents +const std::set scalarTypeSpecifier{{"scalar"}}; +//---------------------------------------------------------------------------- +// Mapping of signed integer numericTypeSpecifiers to their unsigned equivalents const std::unordered_map unsignedType{ {Type::Int8::getInstance(), Type::Uint8::getInstance()}, {Type::Int16::getInstance(), Type::Uint16::getInstance()}, @@ -79,16 +81,17 @@ const Pointer *createPointer(const Base *valueType) return new Pointer(valueType); } //---------------------------------------------------------------------------- -const NumericBase *parseNumeric(std::string_view typeString) +const NumericBase *parseNumeric(std::string_view typeString, const NumericBase *scalarType) { using namespace Transpiler; // Scan type SingleLineErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(typeString, nullptr, errorHandler); + const auto tokens = Scanner::scanSource(typeString, scalarType, errorHandler); // Parse type and cast to numeric - const auto *type = dynamic_cast(Parser::parseType(tokens, false, errorHandler)); + const auto *type = dynamic_cast(Parser::parseType(tokens, false, scalarType, + errorHandler)); // If an error was encountered while scanning or parsing, throw exception if (errorHandler.hasError()) { @@ -102,16 +105,16 @@ const NumericBase *parseNumeric(std::string_view typeString) return type; } //---------------------------------------------------------------------------- -const Pointer *parseNumericPtr(std::string_view typeString) +const Pointer *parseNumericPtr(std::string_view typeString, const NumericBase *scalarType) { using namespace Transpiler; // Scan type SingleLineErrorHandler 
errorHandler; - const auto tokens = Scanner::scanSource(typeString, nullptr, errorHandler); + const auto tokens = Scanner::scanSource(typeString, scalarType, errorHandler); // Parse type and cast to numeric pointer - const auto *type = dynamic_cast(Parser::parseType(tokens, true, errorHandler)); + const auto *type = dynamic_cast(Parser::parseType(tokens, true, scalarType, errorHandler)); // If an error was encountered while scanning or parsing, throw exception if (errorHandler.hasError()) { @@ -125,10 +128,20 @@ const Pointer *parseNumericPtr(std::string_view typeString) return type; } //---------------------------------------------------------------------------- -const NumericBase *getNumericType(const std::set &typeSpecifiers) +const NumericBase *getNumericType(const std::set &typeSpecifiers, const NumericBase *scalarType) { - const auto type = numericTypes.find(typeSpecifiers); - return (type == numericTypes.cend()) ? nullptr : type->second; + if(typeSpecifiers == scalarTypeSpecifier) { + if(scalarType) { + return scalarType; + } + else { + throw std::runtime_error("'scalar' type is not available in this context"); + } + } + else { + const auto type = numericTypeSpecifiers.find(typeSpecifiers); + return (type == numericTypeSpecifiers.cend()) ? nullptr : type->second; + } } //---------------------------------------------------------------------------- const NumericBase *getPromotedType(const NumericBase *type) @@ -158,7 +171,7 @@ const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b) } // Otherwise, must be an integer type else { - // Promote both numericTypes + // Promote both numeric types const auto *aPromoted = getPromotedType(a); const auto *bPromoted = getPromotedType(b); @@ -166,7 +179,7 @@ const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b) if(aPromoted->getTypeName() == bPromoted->getTypeName()) { return aPromoted; } - // Otherwise, if both promoted operands have signed integer numericTypes or both have unsigned integer numericTypes, + // Otherwise, if both promoted operands have signed integer numeric types or both have unsigned integer numeric types, // the operand with the type of lesser integer conversion rank is converted to the type of the operand with greater rank. else if(aPromoted->isSigned() == bPromoted->isSigned()) { return (aPromoted->getRank() > bPromoted->getRank()) ? 
                aPromoted : bPromoted;
diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index 61d524e0dd..866e8fa71b 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -137,15 +137,15 @@ std::string getPointerTypeName()
     return createPointer(T::getInstance())->getTypeName();
 }

-void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment)
+void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment, const Type::NumericBase *scalarType = Type::Float::getInstance())
 {
     // Scan
     TestErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource(code, Type::Float::getInstance(), errorHandler);
+    const auto tokens = Scanner::scanSource(code, scalarType, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());

     // Parse
-    const auto statements = Parser::parseBlockItemList(tokens, errorHandler);
+    const auto statements = Parser::parseBlockItemList(tokens, scalarType, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());

     // Typecheck
@@ -161,7 +161,7 @@ Type::QualifiedType typeCheckExpression(std::string_view code, TestEnvironment &
     EXPECT_FALSE(errorHandler.hasError());

     // Parse
-    const auto expression = Parser::parseExpression(tokens, errorHandler);
+    const auto expression = Parser::parseExpression(tokens, scalarType, errorHandler);
     EXPECT_FALSE(errorHandler.hasError());

     // Typecheck

From 50849a8e099b12f2c63675ba81656c88cdc29265 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 16 Jan 2023 12:45:15 +0000
Subject: [PATCH 046/725] add support for sized types

---
 src/genn/genn/transpiler/scanner.cc | 6 ++++++
 src/genn/genn/type.cc               | 6 ++++++
 2 files changed, 12 insertions(+)

diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc
index c49fdf3f1a..473218f496 100644
--- a/src/genn/genn/transpiler/scanner.cc
+++ b/src/genn/genn/transpiler/scanner.cc
@@ -51,6 +51,12 @@ const std::unordered_map<std::string_view, Token::Type> keywords{
     {"scalar", Token::Type::TYPE_SPECIFIER},
     {"signed", Token::Type::TYPE_SPECIFIER},
     {"unsigned", Token::Type::TYPE_SPECIFIER},
+    {"uint8_t", Token::Type::TYPE_SPECIFIER},
+    {"int8_t", Token::Type::TYPE_SPECIFIER},
+    {"uint16_t", Token::Type::TYPE_SPECIFIER},
+    {"int16_t", Token::Type::TYPE_SPECIFIER},
+    {"uint32_t", Token::Type::TYPE_SPECIFIER},
+    {"int32_t", Token::Type::TYPE_SPECIFIER},
     {"bool", Token::Type::TYPE_SPECIFIER}};
 //---------------------------------------------------------------------------
 const std::map, std::function> integerLiteralSuffixParsers{
diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc
index c6679c73d6..c54558a4b3 100644
--- a/src/genn/genn/type.cc
+++ b/src/genn/genn/type.cc
@@ -19,23 +19,29 @@ namespace
 {
 const std::map<std::set<std::string>, const Type::NumericBase*> numericTypeSpecifiers{
     {{"char"}, Type::Int8::getInstance()},
+    {{"int8_t"}, Type::Int8::getInstance()},
     {{"unsigned", "char"}, Type::Uint8::getInstance()},
+    {{"uint8_t"}, Type::Uint8::getInstance()},
     {{"short"}, Type::Int16::getInstance()},
     {{"short", "int"}, Type::Int16::getInstance()},
     {{"signed", "short"}, Type::Int16::getInstance()},
     {{"signed", "short", "int"}, Type::Int16::getInstance()},
+    {{"int16_t"}, Type::Int16::getInstance()},
     {{"unsigned", "short"}, Type::Uint16::getInstance()},
     {{"unsigned", "short", "int"}, Type::Uint16::getInstance()},
+    {{"uint16_t"}, Type::Uint16::getInstance()},
     {{"int"}, Type::Int32::getInstance()},
     {{"signed"}, Type::Int32::getInstance()},
     {{"signed", "int"}, Type::Int32::getInstance()},
+    {{"int32_t"}, Type::Int32::getInstance()},
     {{"unsigned"}, Type::Uint32::getInstance()},
     {{"unsigned", "int"}, Type::Uint32::getInstance()},
+    {{"uint32_t"}, Type::Uint32::getInstance()},
     {{"float"}, Type::Float::getInstance()},
     {{"double"}, Type::Double::getInstance()},

From 6a2a00b65e628ab41611c5356d9d703ed3edd1b9 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 16 Jan 2023 13:19:35 +0000
Subject: [PATCH 047/725] fixed small bug in scanner

---
 src/genn/genn/transpiler/scanner.cc | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc
index 473218f496..19d689c483 100644
--- a/src/genn/genn/transpiler/scanner.cc
+++ b/src/genn/genn/transpiler/scanner.cc
@@ -247,6 +247,10 @@ void scanNumber(char c, ScanState &scanState, std::vector<Token> &tokens)
     else if(std::tolower(scanState.peek()) == 'd') {
         emplaceToken(tokens, Token::Type::NUMBER, scanState, Utils::toCharsThrow<double>(scanState.getLexeme()));
+
+        // Advance
+        // **NOTE** we do this AFTER parsing float as std::to_chars doesn't deal with suffixes
+        scanState.advance();
     }
     // Otherwise, this is a scalar literal
     else {

From 0d6294c827f16822dd7acac2906495898735d0a4 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 16 Jan 2023 13:47:49 +0000
Subject: [PATCH 048/725] preprocess codestrings to get rid of legacy $(XX)
 syntax

---
 include/genn/genn/code_generator/codeGenUtils.h    |  2 ++
 src/genn/genn/code_generator/codeGenUtils.cc       | 10 ++++++++++
 .../genn/code_generator/customUpdateGroupMerged.cc |  3 ++-
 3 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h
index 22f3e1c702..223ba23864 100644
--- a/include/genn/genn/code_generator/codeGenUtils.h
+++ b/include/genn/genn/code_generator/codeGenUtils.h
@@ -96,6 +96,8 @@ GENN_EXPORT void checkUnreplacedVariables(const std::string &code, const std::st
 //--------------------------------------------------------------------------
 GENN_EXPORT std::string disambiguateNamespaceFunction(const std::string supportCode, const std::string code, std::string namespaceName);

+GENN_EXPORT std::string upgradeCodeString(const std::string &codeString);
+
 //-------------------------------------------------------------------------
 /*! \brief Function for performing the code and value substitutions necessary to insert neuron related variables, parameters, and extraGlobal parameters into synaptic code.
diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index 2f358781fe..8e69b934b0 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -526,4 +526,14 @@ std::string disambiguateNamespaceFunction(const std::string supportCode, const s } return newCode; } +//---------------------------------------------------------------------------- +std::string upgradeCodeString(const std::string &codeString) +{ + // **TODO** snake-case -> camel case known built in variables e.g id_pre -> idPre + // **TODO** old style function call to standard C (these are ambiguous so need to be applied to existing genn functions) + std::regex variable(R"(\$\(([_a-zA-Z][a-zA-Z0-9]*)\))"); + + return std::regex_replace(codeString, variable, "$1"); +} + } // namespace GeNN::CodeGenerator diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 093ff071b1..4fc3d51068 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -156,7 +156,8 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string // Scan, parse and type-check update code Transpiler::ErrorHandler errorHandler; - const auto tokens = Transpiler::Scanner::scanSource(cm->getUpdateCode(), getScalarType(), errorHandler); + const std::string code = upgradeCodeString(cm->getUpdateCode()); + const auto tokens = Transpiler::Scanner::scanSource(code, getScalarType(), errorHandler); const auto statements = Transpiler::Parser::parseBlockItemList(tokens, getScalarType(), errorHandler); Transpiler::TypeChecker::typeCheck(statements, typeEnvironment, errorHandler); From 9ebf8735606489eef2fcac9f46ae1ad03720e9df Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 16 Jan 2023 13:49:02 +0000 Subject: [PATCH 049/725] names of external fields in GroupMergedTypeEnvironment have unknown lifetime so can't be stored in std::string_view --- .../code_generator/groupMergedTypeEnvironment.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index 7dd368a611..047dea08e3 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -43,7 +43,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa ErrorHandlerBase &errorHandler, bool initializer) final { // If type isn't found - auto existingType = m_Types.find(name.lexeme); + auto existingType = m_Types.find(std::string{name.lexeme}); if(existingType == m_Types.end()) { if(m_Enclosing) { return m_Enclosing->assign(name, op, assignedType, errorHandler, initializer); @@ -63,7 +63,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final { - auto existingType = m_Types.find(name.lexeme); + auto existingType = m_Types.find(std::string{name.lexeme}); if(existingType == m_Types.end()) { if(m_Enclosing) { return m_Enclosing->incDec(name, op, errorHandler); @@ -104,7 +104,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa //--------------------------------------------------------------------------- // Public API 
//--------------------------------------------------------------------------- - void defineField(const Type::Base *type, std::string_view name, bool isConstValue = false, bool isConstPointer = false) + void defineField(const Type::Base *type, const std::string &name, bool isConstValue = false, bool isConstPointer = false) { if(!m_Types.try_emplace(name, std::piecewise_construct, std::forward_as_tuple(type, isConstValue, isConstPointer), @@ -115,12 +115,12 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa } template - void defineField(std::string_view name, bool isConstValue = false, bool isConstPointer = false) + void defineField(const std::string &name, bool isConstValue = false, bool isConstPointer = false) { defineField(T::getInstance(), name, isConstPointer, isConstPointer); } - void defineField(const Type::Base *type, std::string_view name, bool isConstValue, bool isConstPointer, + void defineField(const Type::Base *type, const std::string &name, bool isConstValue, bool isConstPointer, const Type::Base *fieldType, std::string_view fieldName, typename G::GetFieldValueFunc getFieldValue, GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD) { @@ -132,7 +132,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa } } - void defineField(const Type::Base *type, std::string_view name, bool isConstValue, bool isConstPointer, + void defineField(const Type::Base *type, const std::string &name, bool isConstValue, bool isConstPointer, typename G::GetFieldValueFunc getFieldValue, GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD) { defineField(type, name, isConstValue, isConstPointer, @@ -249,6 +249,6 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa const Type::NumericBase *m_ScalarType; EnvironmentBase *m_Enclosing; - std::unordered_map>> m_Types; + std::unordered_map>> m_Types; }; } // namespace GeNN::CodeGenerator From e3e9dc5abf5b9ddfce771874d14f2d1d371137e7 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 16 Jan 2023 14:03:35 +0000 Subject: [PATCH 050/725] updated generator --- src/genn/generator/generator.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/genn/generator/generator.cc b/src/genn/generator/generator.cc index f47041cbd4..1eebeb20d2 100644 --- a/src/genn/generator/generator.cc +++ b/src/genn/generator/generator.cc @@ -52,7 +52,8 @@ int main(int argc, //!< number of arguments; expected to be 3 // Initialise logging, appending all to console plog::ConsoleAppender consoleAppender; - Logging::init(GENN_PREFERENCES.logLevel, GENN_PREFERENCES.logLevel, &consoleAppender, &consoleAppender); + Logging::init(GENN_PREFERENCES.logLevel, GENN_PREFERENCES.logLevel, GENN_PREFERENCES.logLevel, + &consoleAppender, &consoleAppender, &consoleAppender); // Finalize model model.finalize(); From 460f43d6367f55821b00c73055ca731efb941926 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 16 Jan 2023 14:03:52 +0000 Subject: [PATCH 051/725] fixed small bug --- include/genn/genn/code_generator/groupMerged.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 8a22be6353..50564f9678 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -123,7 +123,7 @@ class GroupMerged } // Otherwise, allow the backend to add a prefix else { - os << backend.getPointerPrefix() 
<< type; + os << backend.getPointerPrefix() << type->getTypeName(); } } // Otherwise, leave the type alone From 5ad13d7c28704f8a264726f719e120e14bcc055d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 16 Jan 2023 14:15:27 +0000 Subject: [PATCH 052/725] switched ``SynapseGroup::getSparseIndType`` to using type system --- include/genn/genn/synapseGroup.h | 2 +- .../customConnectivityUpdateGroupMerged.cc | 2 +- .../code_generator/customUpdateGroupMerged.cc | 2 +- src/genn/genn/code_generator/generateRunner.cc | 4 ++-- src/genn/genn/code_generator/groupMerged.cc | 2 +- .../genn/code_generator/initGroupMerged.cc | 4 ++-- src/genn/genn/customConnectivityUpdate.cc | 2 +- src/genn/genn/customUpdate.cc | 6 +++--- src/genn/genn/synapseGroup.cc | 18 +++++++++--------- 9 files changed, 21 insertions(+), 21 deletions(-) diff --git a/include/genn/genn/synapseGroup.h b/include/genn/genn/synapseGroup.h index 1a358305ee..48e089f302 100644 --- a/include/genn/genn/synapseGroup.h +++ b/include/genn/genn/synapseGroup.h @@ -314,7 +314,7 @@ class GENN_EXPORT SynapseGroup bool isWUPostModelFused() const { return m_FusedWUPostVarSuffix != getName(); } //! Get the type to use for sparse connectivity indices for synapse group - std::string getSparseIndType() const; + const Type::NumericBase *getSparseIndType() const; //! Generate hash of weight update component of this synapse group /*! NOTE: this can only be called after model is finalized */ diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index bbfa29e061..1cce726eba 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -111,7 +111,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t assert(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE); - addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType(), getScalarType())), "ind", + addField(createPointer(getArchetype().getSynapseGroup()->getSparseIndType()), "ind", [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 4fc3d51068..7ef2aa19c7 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -339,7 +339,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // If synapse group has sparse connectivity if(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType(), getScalarType())), "ind", + addField(createPointer(getArchetype().getSynapseGroup()->getSparseIndType()), "ind", [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 8671b6193f..fbe1f541b4 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -1327,7 +1327,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Target indices backend.genArray(definitionsVar, 
definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - s.second.getSparseIndType(), "ind" + s.second.getName(), varLoc, size, mem); + s.second.getSparseIndType()->getTypeName(), "ind" + s.second.getName(), varLoc, size, mem); // **TODO** remap is not always required if(backend.isPostsynapticRemapRequired() && !s.second.getWUModel()->getLearnPostCode().empty()) { @@ -1352,7 +1352,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, s.second.getSparseConnectivityLocation(), autoInitialized, s.second.getSrcNeuronGroup()->getNumNeurons()); // Target indices - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, s.second.getSparseIndType(), "ind" + s.second.getName(), + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, s.second.getSparseIndType()->getTypeName(), "ind" + s.second.getName(), s.second.getSparseConnectivityLocation(), autoInitialized, size); }); } diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index f0c857bb96..3fd5933200 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -858,7 +858,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & // Add pointers to connectivity data if(getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { addPointerField("rowLength", backend.getDeviceVarPrefix() + "rowLength"); - addPointerField(parseNumeric(getArchetype().getSparseIndType(), getScalarType()), "ind", backend.getDeviceVarPrefix() + "ind"); + addPointerField(getArchetype().getSparseIndType(), "ind", backend.getDeviceVarPrefix() + "ind"); // Add additional structure for postsynaptic access if(backend.isPostsynapticRemapRequired() && !wum->getLearnPostCode().empty() diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index a6b323b8ce..9f2078ffe0 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -1022,7 +1022,7 @@ CustomWUUpdateSparseInitGroupMerged::CustomWUUpdateSparseInitGroupMerged(size_t const SynapseGroupInternal *sg = cg.getSynapseGroup(); return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); }); - addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType(), getScalarType())), "ind", + addField(createPointer(getArchetype().getSynapseGroup()->getSparseIndType()), "ind", [&backend](const auto &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); @@ -1184,7 +1184,7 @@ CustomConnectivityUpdateSparseInitGroupMerged::CustomConnectivityUpdateSparseIni const SynapseGroupInternal *sg = cg.getSynapseGroup(); return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); }); - addField(createPointer(parseNumeric(getArchetype().getSynapseGroup()->getSparseIndType(), getScalarType())), "ind", + addField(createPointer(getArchetype().getSynapseGroup()->getSparseIndType()), "ind", [&backend](const auto &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index a08296d799..b550a70e53 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -311,7 +311,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdate::getHashDigest( Utils::updateHash(getUpdateGroupName(), hash); 
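// Note on the statements below: merged groups are keyed on these SHA-1 digests, so
// the sparse index type now contributes its stable type name (e.g. "uint16_t") to
// the hash rather than the address of the singleton type object, keeping digests
// reproducible from run to run.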
 Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash);
-    Utils::updateHash(getSynapseGroup()->getSparseIndType(), hash);
+    Utils::updateHash(getSynapseGroup()->getSparseIndType()->getTypeName(), hash);

     // Because it adds and removes synapses, connectivity update has to update
     // ALL variables associated with synapse group being modified as well as
diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc
index 3ac751379f..a9981dd929 100644
--- a/src/genn/genn/customUpdate.cc
+++ b/src/genn/genn/customUpdate.cc
@@ -266,7 +266,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getHashDigest() const
     CustomUpdateBase::updateHash(hash);

     Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash);
-    Utils::updateHash(getSynapseGroup()->getSparseIndType(), hash);
+    Utils::updateHash(getSynapseGroup()->getSparseIndType()->getTypeName(), hash);

     // Loop through variable references
     for(const auto &v : getVarReferences()) {
@@ -287,7 +287,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getInitHashDigest() cons
     CustomUpdateBase::updateInitHash(hash);

     Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash);
-    Utils::updateHash(getSynapseGroup()->getSparseIndType(), hash);
+    Utils::updateHash(getSynapseGroup()->getSparseIndType()->getTypeName(), hash);

     return hash.get_digest();
 }
-} // namespace GeNN
\ No newline at end of file
+} // namespace GeNN
diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc
index 02d3a56c6a..22d6cf0bcd 100644
--- a/src/genn/genn/synapseGroup.cc
+++ b/src/genn/genn/synapseGroup.cc
@@ -713,23 +713,23 @@ bool SynapseGroup::canPreOutputBeFused() const
     return true;
 }
 //----------------------------------------------------------------------------
-std::string SynapseGroup::getSparseIndType() const
+const Type::NumericBase *SynapseGroup::getSparseIndType() const
 {
     // If narrow sparse inds are enabled
     if(m_NarrowSparseIndEnabled) {
         // If number of target neurons can be represented using a uint8, use this type
         const unsigned int numTrgNeurons = getTrgNeuronGroup()->getNumNeurons();
-        if(numTrgNeurons <= std::numeric_limits<uint8_t>::max()) {
-            return "uint8_t";
+        if(numTrgNeurons <= Type::Uint8::getInstance()->getMax()) {
+            return Type::Uint8::getInstance();
+        }
         // Otherwise, if they can be represented as a uint16, use this type
-        else if(numTrgNeurons <= std::numeric_limits<uint16_t>::max()) {
-            return "uint16_t";
+        else if(numTrgNeurons <= Type::Uint16::getInstance()->getMax()) {
+            return Type::Uint16::getInstance();
+        }
     }
     // Otherwise, use 32-bit int
-    return "uint32_t";
+    return Type::Uint32::getInstance();
 }
 //----------------------------------------------------------------------------
 boost::uuids::detail::sha1::digest_type SynapseGroup::getWUHashDigest() const
@@ -739,7 +739,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getWUHashDigest() const
     Utils::updateHash(getDelaySteps(), hash);
     Utils::updateHash(getBackPropDelaySteps(), hash);
     Utils::updateHash(getMaxDendriticDelayTimesteps(), hash);
-    Utils::updateHash(getSparseIndType(), hash);
+    Utils::updateHash(getSparseIndType()->getTypeName(), hash);
     Utils::updateHash(getNumThreadsPerSpike(), hash);
     Utils::updateHash(isEventThresholdReTestRequired(), hash);
     Utils::updateHash(getSpanType(), hash);
@@ -904,7 +904,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getWUInitHashDigest() cons
 {
     boost::uuids::detail::sha1 hash;
     Utils::updateHash(getMatrixType(), hash);
-    Utils::updateHash(getSparseIndType(), hash);
+    Utils::updateHash(getSparseIndType()->getTypeName(), hash);
     Utils::updateHash(getWUModel()->getVars(), hash);
     Utils::updateHash(getWUModel()->getSynapseDynamicsCode().empty(), hash);
@@ -969,7 +969,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getConnectivityInitHashDig
     boost::uuids::detail::sha1 hash;
     Utils::updateHash(getConnectivityInitialiser().getHashDigest(), hash);
     Utils::updateHash(getMatrixType(), hash);
-    Utils::updateHash(getSparseIndType(), hash);
+    Utils::updateHash(getSparseIndType()->getTypeName(), hash);

     return hash.get_digest();
 }
 //----------------------------------------------------------------------------

From b474b82a9c7af94b71f86fb7ab8665979be9fe4f Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 16 Jan 2023 14:36:35 +0000
Subject: [PATCH 053/725] rename Type::Base::getTypeName to Type::Base::getName
 as the typeness is obvious!

---
 .../genn/genn/code_generator/groupMerged.h  |  8 +--
 .../genn/code_generator/modelSpecMerged.h   |  4 +-
 include/genn/genn/type.h                    | 33 ++++++---
 src/genn/backends/cuda/backend.cc           |  2 +-
 src/genn/backends/opencl/backend.cc         |  2 +-
 .../backends/single_threaded_cpu/backend.cc |  2 +-
 .../genn/code_generator/generateRunner.cc   |  4 +-
 src/genn/genn/customConnectivityUpdate.cc   |  2 +-
 src/genn/genn/customUpdate.cc               |  4 +-
 src/genn/genn/synapseGroup.cc               |  6 +-
 src/genn/genn/transpiler/prettyPrinter.cc   |  2 +-
 src/genn/genn/transpiler/scanner.cc         |  4 +-
 src/genn/genn/transpiler/typeChecker.cc     | 56 +++++++++----------
 src/genn/genn/type.cc                       | 20 +++++--
 tests/unit/typeChecker.cc                   | 40 ++++++-------
 15 files changed, 107 insertions(+), 82 deletions(-)

diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h
index 50564f9678..520170a058 100644
--- a/include/genn/genn/code_generator/groupMerged.h
+++ b/include/genn/genn/code_generator/groupMerged.h
@@ -97,7 +97,7 @@ class GroupMerged
         std::sort(sortedFields.begin(), sortedFields.end(),
                   [&backend](const Field &a, const Field &b)
                   {
-                      return (backend.getSize(std::get<0>(a)->getTypeName()) > backend.getSize(std::get<0>(b)->getTypeName()));
+                      return (backend.getSize(std::get<0>(a)->getName()) > backend.getSize(std::get<0>(b)->getName()));
                   });
         return sortedFields;
@@ -123,12 +123,12 @@ class GroupMerged
             }
             // Otherwise, allow the backend to add a prefix
             else {
-                os << backend.getPointerPrefix() << type->getTypeName();
+                os << backend.getPointerPrefix() << type->getName();
             }
         }
         // Otherwise, leave the type alone
         else {
-            os << type->getTypeName();
+            os << type->getName();
         }
         os << " " << std::get<1>(f) << ";" << std::endl;
     }
@@ -160,7 +160,7 @@ class GroupMerged
         for(const auto &f : sortedFields) {
             // Add size of field to total
             // **TODO** size should be built into type system
-            const size_t fieldSize = backend.getSize(std::get<0>(f)->getTypeName());
+            const size_t fieldSize = backend.getSize(std::get<0>(f)->getName());
             structSize += fieldSize;
             // Update largest field size
diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h
index 875f8e4cca..7af74bb2f7 100644
--- a/include/genn/genn/code_generator/modelSpecMerged.h
+++ b/include/genn/genn/code_generator/modelSpecMerged.h
@@ -52,8 +52,8 @@ class GENN_EXPORT ModelSpecMerged
     //!
lexicographically compares all three struct members bool operator < (const EGPField &other) const { - return (std::make_tuple(mergedGroupIndex, type->getTypeName(), fieldName, hostGroup) - < std::make_tuple(other.mergedGroupIndex, other.type->getTypeName(), other.fieldName, other.hostGroup)); + return (std::make_tuple(mergedGroupIndex, type->getName(), fieldName, hostGroup) + < std::make_tuple(other.mergedGroupIndex, other.type->getName(), other.fieldName, other.hostGroup)); } }; diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 1e092f09fe..e73ebc10df 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -35,7 +35,7 @@ class TYPE : public Numeric \ { \ DECLARE_TYPE(TYPE) \ - virtual std::string getTypeName() const{ return #UNDERLYING_TYPE; } \ + virtual std::string getName() const{ return #UNDERLYING_TYPE; } \ }; \ template<> \ struct TypeTraits \ @@ -52,6 +52,11 @@ #define IMPLEMENT_TYPE(TYPE) TYPE *TYPE::s_Instance = NULL #define IMPLEMENT_NUMERIC_TYPE(TYPE) IMPLEMENT_TYPE(TYPE) +// **YUCK** on Windows undefine CONST macro (some part of wincrypt) +#ifdef _WIN32 + #undef CONST +#endif + //---------------------------------------------------------------------------- // GeNN::Type::TypeTraits //---------------------------------------------------------------------------- @@ -68,7 +73,7 @@ struct TypeTraits //---------------------------------------------------------------------------- enum class Qualifier : unsigned int { - CONSTT = (1 << 0) + CONST = (1 << 0) }; inline bool operator & (Qualifier a, Qualifier b) @@ -76,6 +81,11 @@ inline bool operator & (Qualifier a, Qualifier b) return (static_cast(a) & static_cast(b)) != 0; } +inline Qualifier operator | (Qualifier a, Qualifier b) +{ + return static_cast(static_cast(a) | static_cast(b)); +} + //---------------------------------------------------------------------------- // GeNN::Type::Base //---------------------------------------------------------------------------- @@ -86,7 +96,12 @@ class Base //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ - virtual std::string getTypeName() const = 0; + virtual std::string getName() const = 0; + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + const Base *getPointerType() const; }; //---------------------------------------------------------------------------- @@ -104,8 +119,8 @@ class Pointer : public Base //------------------------------------------------------------------------ // Base virtuals //------------------------------------------------------------------------ - virtual std::string getTypeName() const{ return getValueType()->getTypeName() + "*";} - + virtual std::string getName() const{ return getValueType()->getName() + "*";} + //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ @@ -188,7 +203,7 @@ class ForeignFunctionBase : public Base //------------------------------------------------------------------------ // Base virtuals //------------------------------------------------------------------------ - virtual std::string getTypeName() const = 0; + virtual std::string getName() const = 0; //------------------------------------------------------------------------ // Declared virtuals @@ -207,9 +222,9 @@ 
class ForeignFunction : public ForeignFunctionBase //------------------------------------------------------------------------ // Base virtuals //------------------------------------------------------------------------ - virtual std::string getTypeName() const final + virtual std::string getName() const final { - std::string typeName = getReturnType()->getTypeName() + "("; + std::string typeName = getReturnType()->getName() + "("; updateTypeName(typeName); typeName += ")"; return typeName; @@ -240,7 +255,7 @@ class ForeignFunction : public ForeignFunctionBase static void updateTypeName(std::string &typeName) { // Add argument typename to string - typeName += T::getInstance()->getTypeName(); + typeName += T::getInstance()->getName(); // If there are more arguments left in pack, add comma and recurse if constexpr (sizeof...(Args)) { diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index a5c0c32411..1e55ca0966 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -1732,7 +1732,7 @@ void Backend::genMergedExtraGlobalParamPush(CodeStream &os, const std::string &s //-------------------------------------------------------------------------- std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) const { - return type->getTypeName(); + return type->getName(); } //-------------------------------------------------------------------------- void Backend::genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const diff --git a/src/genn/backends/opencl/backend.cc b/src/genn/backends/opencl/backend.cc index 933b025ee4..1942898695 100644 --- a/src/genn/backends/opencl/backend.cc +++ b/src/genn/backends/opencl/backend.cc @@ -2049,7 +2049,7 @@ std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) con } // Otherwise, type remains the same else { - return type->getTypeName(); + return type->getName(); } } //-------------------------------------------------------------------------- diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index d15e1765e2..653a342e15 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1350,7 +1350,7 @@ void Backend::genMergedExtraGlobalParamPush(CodeStream &os, const std::string &s //-------------------------------------------------------------------------- std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) const { - return type->getTypeName(); + return type->getName(); } //-------------------------------------------------------------------------- std::string Backend::getMergedGroupSimRNGType() const diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index fbe1f541b4..4597505f42 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -1327,7 +1327,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Target indices backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - s.second.getSparseIndType()->getTypeName(), "ind" + s.second.getName(), varLoc, size, mem); + s.second.getSparseIndType()->getName(), "ind" + s.second.getName(), varLoc, size, mem); // **TODO** remap is not always required if(backend.isPostsynapticRemapRequired() && 
!s.second.getWUModel()->getLearnPostCode().empty()) { @@ -1352,7 +1352,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, s.second.getSparseConnectivityLocation(), autoInitialized, s.second.getSrcNeuronGroup()->getNumNeurons()); // Target indices - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, s.second.getSparseIndType()->getTypeName(), "ind" + s.second.getName(), + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, s.second.getSparseIndType()->getName(), "ind" + s.second.getName(), s.second.getSparseConnectivityLocation(), autoInitialized, size); }); } diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index b550a70e53..296821bfd3 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -311,7 +311,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdate::getHashDigest( Utils::updateHash(getUpdateGroupName(), hash); Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash); - Utils::updateHash(getSynapseGroup()->getSparseIndType()->getTypeName(), hash); + Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName(), hash); // Because it adds and removes synapses, connectivity update has to update // ALL variables associated with synapse group being modified as well as diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index a9981dd929..5d7fc7af13 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -266,7 +266,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getHashDigest() const CustomUpdateBase::updateHash(hash); Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash); - Utils::updateHash(getSynapseGroup()->getSparseIndType()->getTypeName(), hash); + Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName(), hash); // Loop through variable references for(const auto &v : getVarReferences()) { @@ -287,7 +287,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getInitHashDigest() cons CustomUpdateBase::updateInitHash(hash); Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash); - Utils::updateHash(getSynapseGroup()->getSparseIndType()->getTypeName(), hash); + Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName(), hash); return hash.get_digest(); } } // namespace GeNN diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc index 22d6cf0bcd..0c83d81aa0 100644 --- a/src/genn/genn/synapseGroup.cc +++ b/src/genn/genn/synapseGroup.cc @@ -739,7 +739,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getWUHashDigest() const Utils::updateHash(getDelaySteps(), hash); Utils::updateHash(getBackPropDelaySteps(), hash); Utils::updateHash(getMaxDendriticDelayTimesteps(), hash); - Utils::updateHash(getSparseIndType()->getTypeName(), hash); + Utils::updateHash(getSparseIndType()->getName(), hash); Utils::updateHash(getNumThreadsPerSpike(), hash); Utils::updateHash(isEventThresholdReTestRequired(), hash); Utils::updateHash(getSpanType(), hash); @@ -904,7 +904,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getWUInitHashDigest() cons { boost::uuids::detail::sha1 hash; Utils::updateHash(getMatrixType(), hash); - Utils::updateHash(getSparseIndType()->getTypeName(), hash); + Utils::updateHash(getSparseIndType()->getName(), hash); Utils::updateHash(getWUModel()->getVars(), hash); 
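// For reference, the getSparseIndType() hashed here (reworked earlier in this series
// to return a Type::NumericBase* instead of a string) picks the narrowest unsigned
// type able to index every postsynaptic neuron. A standalone sketch of that
// selection rule (pickSparseIndType is an invented name, shown purely for
// illustration):
//
//     const Type::NumericBase *pickSparseIndType(unsigned int numTrgNeurons)
//     {
//         if(numTrgNeurons <= Type::Uint8::getInstance()->getMax()) {
//             return Type::Uint8::getInstance();    // indices fit in 8 bits
//         }
//         else if(numTrgNeurons <= Type::Uint16::getInstance()->getMax()) {
//             return Type::Uint16::getInstance();   // indices fit in 16 bits
//         }
//         return Type::Uint32::getInstance();       // otherwise use 32-bit indices
//     }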
Utils::updateHash(getWUModel()->getSynapseDynamicsCode().empty(), hash); @@ -969,7 +969,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getConnectivityInitHashDig boost::uuids::detail::sha1 hash; Utils::updateHash(getConnectivityInitialiser().getHashDigest(), hash); Utils::updateHash(getMatrixType(), hash); - Utils::updateHash(getSparseIndType()->getTypeName(), hash); + Utils::updateHash(getSparseIndType()->getName(), hash); return hash.get_digest(); } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index 4d58023a90..c42350c84e 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -263,7 +263,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor if(qualifiedType.constValue) { m_StringStream << "const "; } - m_StringStream << qualifiedType.type->getTypeName() << " "; + m_StringStream << qualifiedType.type->getName() << " "; if(qualifiedType.constPointer) { m_StringStream << "const "; diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc index 19d689c483..5afe181268 100644 --- a/src/genn/genn/transpiler/scanner.cc +++ b/src/genn/genn/transpiler/scanner.cc @@ -255,13 +255,13 @@ void scanNumber(char c, ScanState &scanState, std::vector &tokens) // Otherwise, this is a scalar literal else { // If the scalar type is float, add single-precision token - if(scanState.getScalarType()->getTypeName() == "float") { + if(scanState.getScalarType()->getName() == "float") { emplaceToken(tokens, Token::Type::NUMBER, scanState, Utils::toCharsThrow(scanState.getLexeme())); } // Otherwise, add double-precision token - else if(scanState.getScalarType()->getTypeName() == "double") { + else if(scanState.getScalarType()->getName() == "double") { emplaceToken(tokens, Token::Type::NUMBER, scanState, Utils::toCharsThrow(scanState.getLexeme())); } diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index d18b94798b..b4bf0cb57d 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -137,7 +137,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto indexNumericType = dynamic_cast(indexType.type); if (!indexNumericType || !indexNumericType->isIntegral()) { m_ErrorHandler.error(arraySubscript.getPointerName(), - "Invalid subscript index type '" + indexType.type->getTypeName() + "'"); + "Invalid subscript index type '" + indexType.type->getName() + "'"); throw TypeCheckError(); } @@ -173,8 +173,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto rightPointerType = dynamic_cast(rightType.type); if (leftPointerType && rightPointerType && opType == Token::Type::MINUS) { // Check pointers are compatible - if (leftPointerType->getTypeName() != rightPointerType->getTypeName()) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getTypeName() + "' and '" + rightType.type->getTypeName()); + if (leftPointerType->getName() != rightPointerType->getName()) { + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getName() + "' and '" + rightType.type->getName()); throw TypeCheckError(); } @@ -186,7 +186,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { // Check that numeric operand is integer if (!rightNumericType->isIntegral()) { - 
m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getTypeName() + "' and '" + rightType.type->getTypeName()); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getName() + "' and '" + rightType.type->getName()); throw TypeCheckError(); } @@ -198,7 +198,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { // Check that numeric operand is integer if (!leftNumericType->isIntegral()) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getTypeName() + "' and '" + rightType.type->getTypeName()); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getName() + "' and '" + rightType.type->getName()); throw TypeCheckError(); } @@ -214,7 +214,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { // Check that operands are integers if (!leftNumericType->isIntegral() || !rightNumericType->isIntegral()) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getTypeName() + "' and '" + rightType.type->getTypeName()); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getName() + "' and '" + rightType.type->getName()); throw TypeCheckError(); } @@ -234,7 +234,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } } else { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getTypeName() + "' and '" + rightType.type->getTypeName()); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getName() + "' and '" + rightType.type->getName()); throw TypeCheckError(); } } @@ -283,12 +283,12 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If value const is being removed if (rightType.constValue && !cast.getQualifiedType().constValue) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getTypeName() + "' and '" + rightType.type->getTypeName()); + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getName() + "' and '" + rightType.type->getName()); throw TypeCheckError(); } // Otherwise, if pointer const is being removed else if (rightType.constPointer && !cast.getQualifiedType().constPointer) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getTypeName() + "' and '" + rightType.type->getTypeName()); + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getName() + "' and '" + rightType.type->getName()); throw TypeCheckError(); } @@ -298,14 +298,14 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto leftNumericType = dynamic_cast(cast.getQualifiedType().type); auto leftPointerType = dynamic_cast(cast.getQualifiedType().type); if (rightPointerType && leftPointerType) { - if (rightPointerType->getTypeName() != leftPointerType->getTypeName()) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getTypeName() + "' and '" + rightType.type->getTypeName()); + if (rightPointerType->getName() != leftPointerType->getName()) { + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getName() + "' and '" + rightType.type->getName()); throw TypeCheckError(); } } // Otherwise, if either operand isn't numeric else 
if(!leftNumericType | !rightNumericType) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getTypeName() + "' and '" + rightType.type->getTypeName()); + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getName() + "' and '" + rightType.type->getName()); throw TypeCheckError(); } @@ -326,7 +326,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } else { m_ErrorHandler.error(conditional.getQuestion(), - "Invalid operand types '" + trueType.type->getTypeName() + "' and '" + falseType.type->getTypeName() + "' to conditional"); + "Invalid operand types '" + trueType.type->getName() + "' and '" + falseType.type->getName() + "' to conditional"); throw TypeCheckError(); } } @@ -380,7 +380,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto rightPointerType = dynamic_cast(rightType.type); if (!rightPointerType) { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType.type->getTypeName() + "'"); + "Invalid operand type '" + rightType.type->getName() + "'"); throw TypeCheckError(); } @@ -405,7 +405,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } else { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType.type->getTypeName() + "'"); + "Invalid operand type '" + rightType.type->getName() + "'"); throw TypeCheckError(); } } @@ -422,7 +422,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } else { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType.type->getTypeName() + "'"); + "Invalid operand type '" + rightType.type->getName() + "'"); throw TypeCheckError(); } } @@ -512,7 +512,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto valNumericType = dynamic_cast(valType.type); if (!valNumericType || !valNumericType->isIntegral()) { m_ErrorHandler.error(labelled.getKeyword(), - "Invalid case value '" + valType.type->getTypeName() + "'"); + "Invalid case value '" + valType.type->getName() + "'"); throw TypeCheckError(); } } @@ -526,7 +526,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto condNumericType = dynamic_cast(condType.type); if (!condNumericType || !condNumericType->isIntegral()) { m_ErrorHandler.error(switchStatement.getSwitch(), - "Invalid condition '" + condType.type->getTypeName() + "'"); + "Invalid condition '" + condType.type->getName() + "'"); throw TypeCheckError(); } @@ -611,19 +611,19 @@ const Type::QualifiedType &EnvironmentBase::assign(const Token &name, Token::Typ if (pointerAssignedType && pointerExistingType) { // If we're trying to assign a pointer to a const value to a pointer if (assignedType.constValue && !existingType.constValue) { - errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getTypeName() + "' and '" + pointerAssignedType->getTypeName()); + errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName() + "' and '" + pointerAssignedType->getName()); throw TypeCheckError(); } // If pointer types aren't compatible - if (pointerExistingType->getTypeName() != pointerAssignedType->getTypeName()) { - errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getTypeName() + "' and '" + pointerAssignedType->getTypeName()); + if (pointerExistingType->getName() != pointerAssignedType->getName()) { + errorHandler.error(name, "Invalid operand types '" + 
pointerExistingType->getName() + "' and '" + pointerAssignedType->getName()); throw TypeCheckError(); } } // Otherwise, if we're trying to initialise a pointer with a non-pointer or vice-versa else if (pointerAssignedType || pointerExistingType) { - errorHandler.error(name, "Invalid operand types '" + existingType.type->getTypeName() + "' and '" + assignedType.type->getTypeName()); + errorHandler.error(name, "Invalid operand types '" + existingType.type->getName() + "' and '" + assignedType.type->getName()); throw TypeCheckError(); } } @@ -632,13 +632,13 @@ const Type::QualifiedType &EnvironmentBase::assign(const Token &name, Token::Typ // If the operand being added isn't numeric or the type being added to is neither numeric or a pointer if (!numericAssignedType || (!pointerExistingType && !numericExistingType)) { - errorHandler.error(name, "Invalid operand types '" + existingType.type->getTypeName() + "' and '" + assignedType.type->getTypeName() + "'"); + errorHandler.error(name, "Invalid operand types '" + existingType.type->getName() + "' and '" + assignedType.type->getName() + "'"); throw TypeCheckError(); } // If we're adding a numeric type to a pointer, check it's an integer if (pointerExistingType && numericAssignedType->isIntegral()) { - errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getTypeName() + "'"); + errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName() + "'"); throw TypeCheckError(); } } @@ -646,22 +646,22 @@ const Type::QualifiedType &EnvironmentBase::assign(const Token &name, Token::Typ else { // If either type is non-numeric, give error if(!numericAssignedType) { - errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getTypeName() + "'"); + errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName() + "'"); throw TypeCheckError(); } if(!numericExistingType) { - errorHandler.error(name, "Invalid operand types '" + existingType.type->getTypeName() + "'"); + errorHandler.error(name, "Invalid operand types '" + existingType.type->getName() + "'"); throw TypeCheckError(); } // If operand isn't one that takes any numeric type, check both operands are integral if (op != Token::Type::STAR_EQUAL && op != Token::Type::SLASH_EQUAL) { if(!numericAssignedType->isIntegral()) { - errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getTypeName() + "'"); + errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName() + "'"); throw TypeCheckError(); } if(!numericExistingType->isIntegral()) { - errorHandler.error(name, "Invalid operand types '" + numericExistingType->getTypeName() + "'"); + errorHandler.error(name, "Invalid operand types '" + numericExistingType->getName() + "'"); throw TypeCheckError(); } } diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index c54558a4b3..3793ee563a 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -77,6 +77,16 @@ IMPLEMENT_NUMERIC_TYPE(Double); IMPLEMENT_TYPE(Exp); IMPLEMENT_TYPE(Sqrt); +//---------------------------------------------------------------------------- +// GeNN::Type::Base +//---------------------------------------------------------------------------- +const Base *Base::getPointerType() const +{ + // **TODO** befriend constructor + // **TODO** don't just leak these! 
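// One conventional fix for the leak flagged in the TODO above would be to intern
// pointer types in a cache, so repeated calls return the same instance and
// ownership lives in one place. A hypothetical sketch, not part of this patch,
// assuming single-threaded code generation:
//
//     const Base *Base::getPointerType() const
//     {
//         static std::unordered_map<const Base*, std::unique_ptr<Pointer>> cache;
//         auto &ptr = cache[this];
//         if(!ptr) {
//             ptr = std::make_unique<Pointer>(this);
//         }
//         return ptr.get();
//     }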
+ return new Pointer(this); +} + //---------------------------------------------------------------------------- // Free functions //---------------------------------------------------------------------------- @@ -166,13 +176,13 @@ const NumericBase *getPromotedType(const NumericBase *type) const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b) { // If either type is double, common type is double - const auto &aTypeName = a->getTypeName(); - const auto &bTypeName = b->getTypeName(); - if(aTypeName == Double::getInstance()->getTypeName() || bTypeName == Double::getInstance()->getTypeName()) { + const auto &aTypeName = a->getName(); + const auto &bTypeName = b->getName(); + if(aTypeName == Double::getInstance()->getName() || bTypeName == Double::getInstance()->getName()) { return Double::getInstance(); } // Otherwise, if either type is float, common type is float - if(aTypeName == Float::getInstance()->getTypeName() || bTypeName == Float::getInstance()->getTypeName()) { + if(aTypeName == Float::getInstance()->getName() || bTypeName == Float::getInstance()->getName()) { return Float::getInstance(); } // Otherwise, must be an integer type @@ -182,7 +192,7 @@ const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b) const auto *bPromoted = getPromotedType(b); // If both promoted operands have the same type, then no further conversion is needed. - if(aPromoted->getTypeName() == bPromoted->getTypeName()) { + if(aPromoted->getName() == bPromoted->getName()) { return aPromoted; } // Otherwise, if both promoted operands have signed integer numeric types or both have unsigned integer numeric types, diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index 866e8fa71b..b34e26ecf3 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -134,7 +134,7 @@ class TestEnvironment : public TypeChecker::EnvironmentBase template std::string getPointerTypeName() { - return createPointer(T::getInstance())->getTypeName(); + return createPointer(T::getInstance())->getName(); } void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment, const Type::NumericBase *scalarType = Type::Float::getInstance()) @@ -181,7 +181,7 @@ TEST(TypeChecker, ArraySubscript) TestEnvironment typeEnvironment; typeEnvironment.definePointer("intArray"); const auto type = typeCheckExpression("intArray[4]", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName()); + EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); EXPECT_FALSE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -265,7 +265,7 @@ TEST(TypeChecker, Cast) TestEnvironment typeEnvironment; typeEnvironment.define("intVal"); const auto type = typeCheckExpression("(float)intVal", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), Type::Float::getInstance()->getTypeName()); + EXPECT_EQ(type.type->getName(), Type::Float::getInstance()->getName()); EXPECT_FALSE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -275,7 +275,7 @@ TEST(TypeChecker, Cast) TestEnvironment typeEnvironment; typeEnvironment.define("intVal"); const auto type = typeCheckExpression("(const int)intVal", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName()); + EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -285,7 +285,7 @@ TEST(TypeChecker, Cast) TestEnvironment typeEnvironment; 
typeEnvironment.definePointer("intArray"); const auto type = typeCheckExpression("(const int*)intArray", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), getPointerTypeName()); + EXPECT_EQ(type.type->getName(), getPointerTypeName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -295,7 +295,7 @@ TEST(TypeChecker, Cast) TestEnvironment typeEnvironment; typeEnvironment.definePointer("intArray"); const auto type = typeCheckExpression("(int * const)intArray", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), getPointerTypeName()); + EXPECT_EQ(type.type->getName(), getPointerTypeName()); EXPECT_FALSE(type.constValue); EXPECT_TRUE(type.constPointer); } @@ -354,7 +354,7 @@ TEST(TypeChecker, IncDec) TestEnvironment typeEnvironment; typeEnvironment.define("intVal"); const auto type = typeCheckExpression("intVal++", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName()); + EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); EXPECT_FALSE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -364,7 +364,7 @@ TEST(TypeChecker, IncDec) TestEnvironment typeEnvironment; typeEnvironment.definePointer("intArray"); const auto type = typeCheckExpression("intArray++", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), getPointerTypeName()); + EXPECT_EQ(type.type->getName(), getPointerTypeName()); EXPECT_FALSE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -374,7 +374,7 @@ TEST(TypeChecker, IncDec) TestEnvironment typeEnvironment; typeEnvironment.definePointer("intArray", true); const auto type = typeCheckExpression("intArray++", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), getPointerTypeName()); + EXPECT_EQ(type.type->getName(), getPointerTypeName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -400,7 +400,7 @@ TEST(TypeChecker, Literal) { TestEnvironment typeEnvironment; const auto type = typeCheckExpression("1.0f", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), Type::Float::getInstance()->getTypeName()); + EXPECT_EQ(type.type->getName(), Type::Float::getInstance()->getName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -409,7 +409,7 @@ TEST(TypeChecker, Literal) { TestEnvironment typeEnvironment; const auto type = typeCheckExpression("1.0", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), Type::Float::getInstance()->getTypeName()); + EXPECT_EQ(type.type->getName(), Type::Float::getInstance()->getName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -418,7 +418,7 @@ TEST(TypeChecker, Literal) { TestEnvironment typeEnvironment; const auto type = typeCheckExpression("1.0", typeEnvironment, Type::Double::getInstance()); - EXPECT_EQ(type.type->getTypeName(), Type::Double::getInstance()->getTypeName()); + EXPECT_EQ(type.type->getName(), Type::Double::getInstance()->getName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -427,7 +427,7 @@ TEST(TypeChecker, Literal) { TestEnvironment typeEnvironment; const auto type = typeCheckExpression("1.0d", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), Type::Double::getInstance()->getTypeName()); + EXPECT_EQ(type.type->getName(), Type::Double::getInstance()->getName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -436,7 +436,7 @@ TEST(TypeChecker, Literal) { TestEnvironment typeEnvironment; const auto type = typeCheckExpression("100", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), 
Type::Int32::getInstance()->getTypeName()); + EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -445,7 +445,7 @@ TEST(TypeChecker, Literal) { TestEnvironment typeEnvironment; const auto type = typeCheckExpression("100U", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), Type::Uint32::getInstance()->getTypeName()); + EXPECT_EQ(type.type->getName(), Type::Uint32::getInstance()->getName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -458,7 +458,7 @@ TEST(TypeChecker, Unary) TestEnvironment typeEnvironment; typeEnvironment.definePointer("intArray"); const auto type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName()); + EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); EXPECT_FALSE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -468,7 +468,7 @@ TEST(TypeChecker, Unary) TestEnvironment typeEnvironment; typeEnvironment.definePointer("intArray", true); const auto type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName()); + EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -478,7 +478,7 @@ TEST(TypeChecker, Unary) TestEnvironment typeEnvironment; typeEnvironment.definePointer("intArray", false, true); const auto type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName()); + EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); EXPECT_FALSE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -488,7 +488,7 @@ TEST(TypeChecker, Unary) TestEnvironment typeEnvironment; typeEnvironment.definePointer("intArray", true, true); const auto type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), Type::Int32::getInstance()->getTypeName()); + EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); EXPECT_TRUE(type.constValue); EXPECT_FALSE(type.constPointer); } @@ -505,7 +505,7 @@ TEST(TypeChecker, Unary) TestEnvironment typeEnvironment; typeEnvironment.define("intVal"); const auto type = typeCheckExpression("&intVal", typeEnvironment); - EXPECT_EQ(type.type->getTypeName(), getPointerTypeName()); + EXPECT_EQ(type.type->getName(), getPointerTypeName()); EXPECT_FALSE(type.constValue); EXPECT_FALSE(type.constPointer); } From 170bb055ed572aca06ff351972e6a129818ad5b9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 16 Jan 2023 14:43:29 +0000 Subject: [PATCH 054/725] slightly more elegant API for getting pointer types --- include/genn/genn/code_generator/groupMerged.h | 6 +++--- .../groupMergedTypeEnvironment.h | 4 ++-- include/genn/genn/type.h | 9 --------- .../customConnectivityUpdateGroupMerged.cc | 14 +++++++------- .../code_generator/customUpdateGroupMerged.cc | 10 +++++----- src/genn/genn/code_generator/groupMerged.cc | 18 +++++++++--------- .../genn/code_generator/initGroupMerged.cc | 10 +++++----- .../code_generator/neuronUpdateGroupMerged.cc | 8 ++++---- .../code_generator/synapseUpdateGroupMerged.cc | 2 +- src/genn/genn/transpiler/parser.cc | 4 ++-- src/genn/genn/transpiler/typeChecker.cc | 2 +- src/genn/genn/type.cc | 7 ------- tests/unit/typeChecker.cc | 4 ++-- 13 files changed, 41 insertions(+), 57 deletions(-) diff --git 
a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 520170a058..2e8fc26ece 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -269,13 +269,13 @@ class GroupMerged void addPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) { - addField(createPointer(type), name, [prefix](const G &g, size_t) { return prefix + g.getName(); }); + addField(type->getPointerType(), name, [prefix](const G &g, size_t) { return prefix + g.getName(); }); } template>* = nullptr> void addPointerField(const std::string &name, const std::string &prefix) { - addField(createPointer(T::getInstance()), name, [prefix](const G &g, size_t) { return prefix + g.getName(); }); + addField(T::getInstance()->getPointerType(), name, [prefix](const G &g, size_t) { return prefix + g.getName(); }); } @@ -292,7 +292,7 @@ class GroupMerged { // Loop through variables for(const auto &v : varReferences) { - addField(createPointer(Type::parseNumeric(v.type, getScalarType())), v.name, + addField(Type::parseNumeric(v.type, getScalarType())->getPointerType(), v.name, [getVarRefFn, arrayPrefix, v](const G &g, size_t) { const auto varRef = getVarRefFn(g).at(v.name); diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index 047dea08e3..641631d4b9 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -142,7 +142,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa void definePointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix, VarAccessMode access) { defineField(type, name, (access & VarAccessModeAttribute::READ_ONLY), false, - Type::createPointer(type), name, [prefix](const auto &g, size_t) { return prefix + g.getName(); }); + type->getPointerType(), name, [prefix](const auto &g, size_t) { return prefix + g.getName(); }); } template @@ -200,7 +200,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa for(const auto &v : varReferences) { const auto *type = Type::parseNumeric(v.type, m_ScalarType); defineField(type, v.name, (v.access & VarAccessModeAttribute::READ_ONLY), false, - Type::createPointer(type), v.name, + type->getPointerType(), v.name, [arrayPrefix, getVarRefFn, v](const auto &g, size_t) { const auto varRef = getVarRefFn(g).at(v.name); diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index e73ebc10df..182c5fa3d3 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -299,15 +299,6 @@ DECLARE_NUMERIC_TYPE(Double, double, 60); DECLARE_FOREIGN_FUNCTION_TYPE(Exp, Double, Double); DECLARE_FOREIGN_FUNCTION_TYPE(Sqrt, Double, Double); -//! Create a pointer type to the given value type -const Pointer *createPointer(const Base *valueType); - -template -const Pointer *createPointer() -{ - return createPointer(T::getInstance()); -} - //! 
Parse a numeric type const NumericBase *parseNumeric(std::string_view typeString, const NumericBase *scalarType); diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 1cce726eba..9551255875 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -111,13 +111,13 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t assert(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE); - addField(createPointer(getArchetype().getSynapseGroup()->getSparseIndType()), "ind", + addField(getArchetype().getSynapseGroup()->getSparseIndType()->getPointerType(), "ind", [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); }); - addField(createPointer(), "rowLength", + addField(Uint32::getInstance()->getPointerType(), "rowLength", [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); @@ -125,7 +125,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t // If some presynaptic variables are delayed, add delay pointer if (getArchetype().getPreDelayNeuronGroup() != nullptr) { - addField(createPointer(), "preSpkQuePtr", + addField(Uint32::getInstance()->getPointerType(), "preSpkQuePtr", [&backend](const auto &cg, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getPreDelayNeuronGroup()->getName(); @@ -134,7 +134,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t // If some postsynaptic variables are delayed, add delay pointer if (getArchetype().getPostDelayNeuronGroup() != nullptr) { - addField(createPointer(), "postSpkQuePtr", + addField(Uint32::getInstance()->getPointerType(), "postSpkQuePtr", [&backend](const auto &cg, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getPostDelayNeuronGroup()->getName(); @@ -167,7 +167,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t // Loop through sorted dependent variables for(size_t i = 0; i < getSortedArchetypeDependentVars().size(); i++) { - addField(createPointer(parseNumeric(getSortedArchetypeDependentVars().at(i).getVar().type, getScalarType())), "_dependentVar" + std::to_string(i), + addField(parseNumeric(getSortedArchetypeDependentVars().at(i).getVar().type, getScalarType())->getPointerType(), "_dependentVar" + std::to_string(i), [i, &backend, this](const auto&, size_t g) { const auto &varRef = m_SortedDependentVars[g][i]; @@ -563,13 +563,13 @@ void CustomConnectivityHostUpdateGroupMerged::addVars(const BackendBase &backend for(const auto &v : vars) { // If var is located on the host if (std::invoke(getVarLocationFn, getArchetype(), v.name) & VarLocation::HOST) { - addField(createPointer(parseNumeric(v.type, getScalarType())), v.name, + addField(parseNumeric(v.type, getScalarType())->getPointerType(), v.name, [v](const auto &g, size_t) { return v.name + g.getName(); }, GroupMergedFieldType::HOST); if(!backend.getDeviceVarPrefix().empty()) { // **TODO** I think could use addPointerField - addField(createPointer(parseNumeric(v.type, getScalarType())), backend.getDeviceVarPrefix() + v.name, + addField(parseNumeric(v.type, getScalarType())->getPointerType(), backend.getDeviceVarPrefix() + v.name, [v, &backend](const auto &g, size_t) { return 
backend.getDeviceVarPrefix() + v.name + g.getName(); diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 7ef2aa19c7..f563d932c8 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -124,7 +124,7 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string // If some variables are delayed, add delay pointer if(getArchetype().getDelayNeuronGroup() != nullptr) { - addField(createPointer(), "spkQuePtr", + addField(Uint32::getInstance()->getPointerType(), "spkQuePtr", [&backend](const auto &cg, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getDelayNeuronGroup()->getName(); @@ -339,13 +339,13 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // If synapse group has sparse connectivity if(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - addField(createPointer(getArchetype().getSynapseGroup()->getSparseIndType()), "ind", + addField(getArchetype().getSynapseGroup()->getSparseIndType()->getPointerType(), "ind", [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); }); - addField(createPointer(), "rowLength", + addField(Uint32::getInstance()->getPointerType(), "rowLength", [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); @@ -379,7 +379,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // If variable has a transpose if(getArchetype().getVarReferences().at(v.name).getTransposeSynapseGroup() != nullptr) { // Add field with transpose suffix, pointing to transpose var - addField(createPointer(parseNumeric(v.type, getScalarType())), v.name + "Transpose", + addField(parseNumeric(v.type, getScalarType())->getPointerType(), v.name + "Transpose", [&backend, v](const auto &g, size_t) { const auto varRef = g.getVarReferences().at(v.name); @@ -438,7 +438,7 @@ CustomUpdateHostReductionGroupMerged::CustomUpdateHostReductionGroupMerged(size_ // If some variables are delayed, add delay pointer // **NOTE** this is HOST delay pointer if(getArchetype().getDelayNeuronGroup() != nullptr) { - addField(createPointer(), "spkQuePtr", + addField(Uint32::getInstance()->getPointerType(), "spkQuePtr", [](const auto &cg, size_t) { return "spkQuePtr" + cg.getDelayNeuronGroup()->getName(); diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index 3fd5933200..ecc33d2289 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -357,7 +357,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr for(const auto &var : cs->getCurrentSourceModel()->getVars()) { // Add pointers to state variable if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(Type::createPointer(parseNumeric(var.type, getScalarType())), var.name + "CS" + std::to_string(i), + addField(parseNumeric(var.type, getScalarType())->getPointerType(), var.name + "CS" + std::to_string(i), [&backend, i, var, this](const auto&, size_t groupIndex) { return backend.getDeviceVarPrefix() + var.name + m_SortedCurrentSources.at(groupIndex).at(i)->getName(); @@ -488,7 +488,7 @@ bool NeuronGroupMergedBase::isPSMVarInitParamReferenced(size_t childIndex, const void 
NeuronGroupMergedBase::addMergedInSynPointerField(const Type::NumericBase *type, const std::string &name, size_t archetypeIndex, const std::string &prefix) { - addField(Type::createPointer(type), name + std::to_string(archetypeIndex), + addField(type->getPointerType(), name + std::to_string(archetypeIndex), [prefix, archetypeIndex, this](const auto&, size_t groupIndex) { return prefix + m_SortedMergedInSyns.at(groupIndex).at(archetypeIndex)->getFusedPSVarSuffix(); @@ -498,7 +498,7 @@ void NeuronGroupMergedBase::addMergedInSynPointerField(const Type::NumericBase * void NeuronGroupMergedBase::addMergedPreOutputOutSynPointerField(const Type::NumericBase *type, const std::string &name, size_t archetypeIndex, const std::string &prefix) { - addField(Type::createPointer(type), name + std::to_string(archetypeIndex), + addField(type->getPointerType(), name + std::to_string(archetypeIndex), [prefix, archetypeIndex, this](const auto&, size_t groupIndex) { return prefix + m_SortedMergedPreOutputOutSyns.at(groupIndex).at(archetypeIndex)->getFusedPreOutputSuffix(); @@ -840,14 +840,14 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & // Add presynaptic variables to struct for(const auto &v : wum->getPreVars()) { const std::string prefix = backend.getDeviceVarPrefix() + v.name; - addField(createPointer(parseNumeric(v.type, getScalarType())), v.name, + addField(parseNumeric(v.type, getScalarType())->getPointerType(), v.name, [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPreVarSuffix(); }); } // Add presynaptic variables to struct for(const auto &v : wum->getPostVars()) { const std::string prefix = backend.getDeviceVarPrefix() + v.name; - addField(createPointer(parseNumeric(v.type, getScalarType())), v.name, + addField(parseNumeric(v.type, getScalarType())->getPointerType(), v.name, [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPostVarSuffix(); }); } @@ -1107,22 +1107,22 @@ boost::uuids::detail::sha1::digest_type SynapseGroupMergedBase::getHashDigest(Ro //---------------------------------------------------------------------------- void SynapseGroupMergedBase::addPSPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) { - addField(createPointer(type), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPSVarSuffix(); }); + addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPSVarSuffix(); }); } //---------------------------------------------------------------------------- void SynapseGroupMergedBase::addPreOutputPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) { - addField(createPointer(type), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPreOutputSuffix(); }); + addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPreOutputSuffix(); }); } //---------------------------------------------------------------------------- void SynapseGroupMergedBase::addSrcPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) { - addField(createPointer(type), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getSrcNeuronGroup()->getName(); }); + addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getSrcNeuronGroup()->getName(); }); } 
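
The helpers above and below all make the same one-line substitution: the free function createPointer(type) becomes type->getPointerType(). A minimal, self-contained sketch of the call shape this enables follows — the Type class, its caching behaviour and main() are illustrative assumptions for this sketch only, not GeNN code (the patch itself returns a freshly allocated Pointer, as in the return new Pointer(this) added to type.cc above):

// Mock stand-in for the GeNN::Type hierarchy -- illustration only.
#include <iostream>
#include <memory>
#include <string>

class Type
{
public:
    explicit Type(std::string name) : m_Name(std::move(name)) {}

    const std::string &getName() const { return m_Name; }

    // Member form introduced by this patch: a value type hands out its own
    // pointer type, so call sites can chain, e.g.
    // Uint32::getInstance()->getPointerType(). This mock caches the result so
    // the sketch does not leak; the real patch constructs a new Pointer.
    const Type *getPointerType() const
    {
        if(!m_PointerType) {
            m_PointerType = std::make_unique<Type>(m_Name + "*");
        }
        return m_PointerType.get();
    }

private:
    std::string m_Name;
    mutable std::unique_ptr<Type> m_PointerType;
};

int main()
{
    Type uint32Type("unsigned int");

    // Old style: createPointer(&uint32Type)   (free function)
    // New style: uint32Type.getPointerType()  (member on the value type)
    std::cout << uint32Type.getPointerType()->getName() << '\n';   // prints "unsigned int*"
}
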
//---------------------------------------------------------------------------- void SynapseGroupMergedBase::addTrgPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) { - addField(createPointer(type), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getTrgNeuronGroup()->getName(); }); + addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getTrgNeuronGroup()->getName(); }); } //---------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 9f2078ffe0..8ec6db0517 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -443,7 +443,7 @@ void NeuronInitGroupMerged::generateWUVar(const BackendBase &backend, for(const auto &var : vars) { // Add pointers to state variable if(!varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(createPointer(parseNumeric(var.type, getScalarType())), var.name + fieldPrefixStem + std::to_string(i), + addField(parseNumeric(var.type, getScalarType())->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto&, size_t groupIndex) { const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); @@ -1016,13 +1016,13 @@ CustomWUUpdateSparseInitGroupMerged::CustomWUUpdateSparseInitGroupMerged(size_t addField("numTrgNeurons", [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); - addField(createPointer(), "rowLength", + addField(Uint32::getInstance()->getPointerType(), "rowLength", [&backend](const auto &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); }); - addField(createPointer(getArchetype().getSynapseGroup()->getSparseIndType()), "ind", + addField(getArchetype().getSynapseGroup()->getSparseIndType()->getPointerType(), "ind", [&backend](const auto &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); @@ -1178,13 +1178,13 @@ CustomConnectivityUpdateSparseInitGroupMerged::CustomConnectivityUpdateSparseIni addField("numTrgNeurons", [](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); - addField(createPointer(), "rowLength", + addField(Uint32::getInstance()->getPointerType(), "rowLength", [&backend](const CustomConnectivityUpdateInternal &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); }); - addField(createPointer(getArchetype().getSynapseGroup()->getSparseIndType()), "ind", + addField(getArchetype().getSynapseGroup()->getSparseIndType()->getPointerType(), "ind", [&backend](const auto &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 6fe074f8e3..8725670a7b 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -80,7 +80,7 @@ 
NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string for(const auto &var : sgPreVars) { // If variable is referenced in event threshold code if(s.eventThresholdCode.find("$(" + var.name + ")") != std::string::npos) { - addField(createPointer(parseNumeric(var.type, getScalarType())), var.name + "EventThresh" + std::to_string(i), + addField(parseNumeric(var.type, getScalarType())->getPointerType(), var.name + "EventThresh" + std::to_string(i), [&backend, eventThresholdSGs, var, i](const auto&, size_t groupIndex) { return backend.getDeviceVarPrefix() + var.name + eventThresholdSGs.at(groupIndex).at(i)->getName(); @@ -93,7 +93,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string if(getArchetype().isSpikeRecordingEnabled()) { // Add field for spike recording - addField(createPointer(), "recordSpk", + addField(Uint32::getInstance()->getPointerType(), "recordSpk", [&backend](const auto &ng, size_t) { return backend.getDeviceVarPrefix() + "recordSpk" + ng.getName(); @@ -103,7 +103,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string if(getArchetype().isSpikeEventRecordingEnabled()) { // Add field for spike event recording - addField(createPointer(), "recordSpkEvent", + addField(Uint32::getInstance()->getPointerType(), "recordSpkEvent", [&backend](const auto &ng, size_t) { return backend.getDeviceVarPrefix() + "recordSpkEvent" + ng.getName(); @@ -716,7 +716,7 @@ void NeuronUpdateGroupMerged::generateWUVar(const BackendBase &backend, const s for(size_t v = 0; v < vars.size(); v++) { // Add pointers to state variable const auto var = vars[v]; - addField(Type::createPointer(Type::parseNumeric(var.type, getScalarType())), var.name + fieldPrefixStem + std::to_string(i), + addField(Type::parseNumeric(var.type, getScalarType())->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto &, size_t groupIndex) { const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 04e385e4dc..98c50035e1 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -325,7 +325,7 @@ SynapseDendriticDelayUpdateGroupMerged::SynapseDendriticDelayUpdateGroupMerged(s const std::vector> &groups) : GroupMerged(index, precision, groups) { - addField(Type::createPointer(), "denDelayPtr", + addField(Type::Uint32::getInstance()->getPointerType(), "denDelayPtr", [&backend](const SynapseGroupInternal &sg, size_t) { return backend.getScalarAddressPrefix() + "denDelayPtr" + sg.getFusedPSVarSuffix(); diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index dcac94988e..33dfc169bc 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -221,7 +221,7 @@ GeNN::Type::QualifiedType parseDeclarationSpecifiers(ParserState &parserState) // **THINK** this relies of const being only qualifier // **TODO** warn of duplicate type qualifiers if (pointerFound) { - return GeNN::Type::QualifiedType{GeNN::Type::createPointer(numericType), + return GeNN::Type::QualifiedType{numericType->getPointerType(), !valueTypeQualifiers.empty(), !pointerTypeQualifiers.empty()}; } // Otherwise, return numeric type directly @@ -906,7 +906,7 @@ const GeNN::Type::Base *parseType(const std::vector &tokens, 
bool allowPo // If pointer, return pointer to numeric type if (pointerFound) { - return GeNN::Type::createPointer(numericType); + return numericType->getPointerType(); } // Otherwise, return numeric type directly else { diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index b4bf0cb57d..e476c83f63 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -416,7 +416,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // Otherwise, if operator is address of, return pointer type else if (unary.getOperator().type == Token::Type::AMPERSAND) { - m_QualifiedType = Type::QualifiedType{Type::createPointer(rightType.type), + m_QualifiedType = Type::QualifiedType{rightType.type->getPointerType(), rightType.constValue, false}; } } diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index 3793ee563a..dbf0af2cc5 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -90,13 +90,6 @@ const Base *Base::getPointerType() const //---------------------------------------------------------------------------- // Free functions //---------------------------------------------------------------------------- -const Pointer *createPointer(const Base *valueType) -{ - // **TODO** befriend constructor - // **TODO** don't just leak these! - return new Pointer(valueType); -} -//---------------------------------------------------------------------------- const NumericBase *parseNumeric(std::string_view typeString, const NumericBase *scalarType) { using namespace Transpiler; diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index b34e26ecf3..a091584a09 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -73,7 +73,7 @@ class TestEnvironment : public TypeChecker::EnvironmentBase template void definePointer(std::string_view name, bool isConstValue = false, bool isConstPointer = false) { - define(Type::createPointer(T::getInstance()), name, isConstValue, isConstPointer); + define(T::getInstance()->getPointerType(), name, isConstValue, isConstPointer); } @@ -134,7 +134,7 @@ class TestEnvironment : public TypeChecker::EnvironmentBase template std::string getPointerTypeName() { - return createPointer(T::getInstance())->getName(); + return T::getInstance()->getPointerType()->getName(); } void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment, const Type::NumericBase *scalarType = Type::Float::getInstance()) From 52276b5f311c76a403a006e41be006ed0e4ee4ca Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 16 Jan 2023 17:13:51 +0000 Subject: [PATCH 055/725] integrated type qualifiers into type classes --- .../groupMergedTypeEnvironment.h | 60 +++-- include/genn/genn/transpiler/expression.h | 8 +- include/genn/genn/transpiler/statement.h | 8 +- include/genn/genn/transpiler/typeChecker.h | 22 +- include/genn/genn/type.h | 74 +++--- src/genn/genn/transpiler/parser.cc | 47 ++-- src/genn/genn/transpiler/prettyPrinter.cc | 46 +++- src/genn/genn/transpiler/typeChecker.cc | 217 ++++++++---------- src/genn/genn/type.cc | 4 +- tests/unit/typeChecker.cc | 208 +++++++++-------- 10 files changed, 360 insertions(+), 334 deletions(-) diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index 641631d4b9..b2c158c894 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -33,14 
+33,14 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa //--------------------------------------------------------------------------- // EnvironmentBase virtuals //--------------------------------------------------------------------------- - virtual void define(const Transpiler::Token &name, const Type::QualifiedType &, ErrorHandlerBase &errorHandler) final + virtual void define(const Transpiler::Token &name, const Type::Base*, ErrorHandlerBase &errorHandler) final { errorHandler.error(name, "Cannot declare variable in external environment"); throw TypeCheckError(); } - virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, - ErrorHandlerBase &errorHandler, bool initializer) final + virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType, + ErrorHandlerBase &errorHandler, bool initializer) final { // If type isn't found auto existingType = m_Types.find(std::string{name.lexeme}); @@ -61,7 +61,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa return EnvironmentBase::assign(name, op, existingType->second.first, assignedType, errorHandler, initializer); } - virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final + virtual const Type::Base *incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final { auto existingType = m_Types.find(std::string{name.lexeme}); if(existingType == m_Types.end()) { @@ -81,7 +81,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa return EnvironmentBase::incDec(name, op, existingType->second.first, errorHandler); } - virtual const Type::QualifiedType &getType(const Token &name, ErrorHandlerBase &errorHandler) final + virtual const Type::Base *getType(const Token &name, ErrorHandlerBase &errorHandler) final { auto type = m_Types.find(std::string{name.lexeme}); if(type == m_Types.end()) { @@ -104,47 +104,39 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa //--------------------------------------------------------------------------- // Public API //--------------------------------------------------------------------------- - void defineField(const Type::Base *type, const std::string &name, bool isConstValue = false, bool isConstPointer = false) + void defineField(const Type::Base *type, const std::string &name) { - if(!m_Types.try_emplace(name, std::piecewise_construct, - std::forward_as_tuple(type, isConstValue, isConstPointer), - std::forward_as_tuple(std::nullopt)).second) + if(!m_Types.try_emplace(name, type, std::nullopt).second) { throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); } } template - void defineField(const std::string &name, bool isConstValue = false, bool isConstPointer = false) + void defineField(const std::string &name) { - defineField(T::getInstance(), name, isConstPointer, isConstPointer); + defineField(T::getInstance(), name); } - void defineField(const Type::Base *type, const std::string &name, bool isConstValue, bool isConstPointer, + void defineField(const Type::Base *type, const std::string &name, const Type::Base *fieldType, std::string_view fieldName, typename G::GetFieldValueFunc getFieldValue, GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD) { if(!m_Types.try_emplace(name, std::piecewise_construct, - std::forward_as_tuple(type, isConstValue, isConstPointer), + 
std::forward_as_tuple(type), std::forward_as_tuple(std::in_place, fieldType, fieldName, getFieldValue, mergedFieldType)).second) { throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); } } - void defineField(const Type::Base *type, const std::string &name, bool isConstValue, bool isConstPointer, - typename G::GetFieldValueFunc getFieldValue, GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD) - { - defineField(type, name, isConstValue, isConstPointer, - type, name, getFieldValue, mergedFieldType); - } - - void definePointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix, VarAccessMode access) + void definePointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix, VarAccessMode access) { - defineField(type, name, (access & VarAccessModeAttribute::READ_ONLY), false, + const auto *qualifiedType = (access & VarAccessModeAttribute::READ_ONLY) ? type->getQualifiedType(Type::Qualifier::CONST) : type; + defineField(qualifiedType, name, type->getPointerType(), name, [prefix](const auto &g, size_t) { return prefix + g.getName(); }); } - + template<typename P, typename H> void defineHeterogeneousParams(const Snippet::Base::StringVec &paramNames, const std::string &suffix, P getParamValues, H isHeterogeneous) { // Loop through params for(const auto &p : paramNames) { if (std::invoke(isHeterogeneous, m_GroupMerged, p)) { - defineField(m_ScalarType, p + suffix, true, false, + defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONST), p + suffix, + m_ScalarType, p + suffix, [p, getParamValues](const auto &g, size_t) { const auto &values = getParamValues(g); return Utils::writePreciseString(values.at(p)); }); } + // Otherwise, just add a const-qualified scalar to the type environment else { - defineField(m_ScalarType, p + suffix, true, false); + defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONST), p + suffix); } } } @@ -172,7 +166,8 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa // Loop through derived params for(const auto &d : derivedParams) { if (std::invoke(isHeterogeneous, m_GroupMerged, d.name)) { - defineField(m_ScalarType, d.name + suffix, true, false, + defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONST), d.name + suffix, + m_ScalarType, d.name + suffix, [d, getDerivedParamValues](const auto &g, size_t) { const auto &values = getDerivedParamValues(g); @@ -180,7 +175,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa }); } else { - defineField(m_ScalarType, d.name + suffix, true, false); + defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONST), d.name + suffix); } } } @@ -199,7 +194,10 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa // Loop through variables for(const auto &v : varReferences) { const auto *type = Type::parseNumeric(v.type, m_ScalarType); - defineField(type, v.name, (v.access & VarAccessModeAttribute::READ_ONLY), false, + + // If variable access is read-only, qualify type with const + const auto *qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ?
type->getQualifiedType(Type::Qualifier::CONST) : type; + defineField(qualifiedType, v.name, type->getPointerType(), v.name, [arrayPrefix, getVarRefFn, v](const auto &g, size_t) { @@ -213,7 +211,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa { for(const auto &e : egps) { const auto *type = Type::parseNumericPtr(e.type, m_ScalarType); - defineField(type, e.name, false, false, + defineField(type, e.name, type, e.name + varName, [arrayPrefix, e, varName](const auto &g, size_t) { @@ -227,7 +225,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa //--------------------------------------------------------------------------- // Private methods //--------------------------------------------------------------------------- - void addField(std::pair> &type) + void addField(std::pair> &type) { // If this type has an associated field if (type.second) { @@ -249,6 +247,6 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa const Type::NumericBase *m_ScalarType; EnvironmentBase *m_Enclosing; - std::unordered_map>> m_Types; + std::unordered_map>> m_Types; }; } // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h index f936bf6b59..9963240520 100644 --- a/include/genn/genn/transpiler/expression.h +++ b/include/genn/genn/transpiler/expression.h @@ -122,18 +122,18 @@ class Call : public Base class Cast : public Base { public: - Cast(const Type::QualifiedType &qualifiedType, ExpressionPtr expression, Token closingParen) - : m_QualifiedType(qualifiedType), m_Expression(std::move(expression)), m_ClosingParen(closingParen) + Cast(const Type::Base *type, ExpressionPtr expression, Token closingParen) + : m_Type(type), m_Expression(std::move(expression)), m_ClosingParen(closingParen) {} virtual void accept(Visitor &visitor) const final; - const Type::QualifiedType &getQualifiedType() const{ return m_QualifiedType; } + const Type::Base *getType() const{ return m_Type; } const Base *getExpression() const { return m_Expression.get(); } const Token &getClosingParen() const { return m_ClosingParen; } private: - const Type::QualifiedType m_QualifiedType; + const Type::Base *m_Type; const ExpressionPtr m_Expression; const Token m_ClosingParen; }; diff --git a/include/genn/genn/transpiler/statement.h b/include/genn/genn/transpiler/statement.h index af0c4f49f2..d0b3bdbe87 100644 --- a/include/genn/genn/transpiler/statement.h +++ b/include/genn/genn/transpiler/statement.h @@ -228,17 +228,17 @@ class VarDeclaration : public Base public: typedef std::vector> InitDeclaratorList; - VarDeclaration(const Type::QualifiedType &qualifiedType, InitDeclaratorList initDeclaratorList) - : m_QualifiedType(qualifiedType), m_InitDeclaratorList(std::move(initDeclaratorList)) + VarDeclaration(const Type::Base *type, InitDeclaratorList initDeclaratorList) + : m_Type(type), m_InitDeclaratorList(std::move(initDeclaratorList)) {} virtual void accept(Visitor &visitor) const override; - const Type::QualifiedType &getQualifiedType() const{ return m_QualifiedType; } + const Type::Base *getType() const{ return m_Type; } const InitDeclaratorList &getInitDeclaratorList() const { return m_InitDeclaratorList; } private: - const Type::QualifiedType m_QualifiedType; + const Type::Base *m_Type; const std::vector m_DeclarationSpecifiers; const InitDeclaratorList m_InitDeclaratorList; }; diff --git a/include/genn/genn/transpiler/typeChecker.h 
b/include/genn/genn/transpiler/typeChecker.h index da2f9322e0..3bdc62cb89 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -40,21 +40,21 @@ class EnvironmentBase //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ - virtual void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandlerBase &errorHandler) = 0; - virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, + virtual void define(const Token &name, const Type::Base *type, ErrorHandlerBase &errorHandler) = 0; + virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType, ErrorHandlerBase &errorHandler, bool initializer = false) = 0; - virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) = 0; - virtual const Type::QualifiedType &getType(const Token &name, ErrorHandlerBase &errorHandler) = 0; + virtual const Type::Base *incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) = 0; + virtual const Type::Base *getType(const Token &name, ErrorHandlerBase &errorHandler) = 0; protected: //--------------------------------------------------------------------------- // Protected API //--------------------------------------------------------------------------- - const Type::QualifiedType &assign(const Token &name, Token::Type op, - const Type::QualifiedType &existingType, const Type::QualifiedType &assignedType, - ErrorHandlerBase &errorHandler, bool initializer = false) const; - const Type::QualifiedType &incDec(const Token &name, Token::Type op, - const Type::QualifiedType &existingType, ErrorHandlerBase &errorHandler) const; + const Type::Base *assign(const Token &name, Token::Type op, + const Type::Base *existingType, const Type::Base *assignedType, + ErrorHandlerBase &errorHandler, bool initializer = false) const; + const Type::Base *incDec(const Token &name, Token::Type op, + const Type::Base *existingType, ErrorHandlerBase &errorHandler) const; }; //--------------------------------------------------------------------------- @@ -63,6 +63,6 @@ class EnvironmentBase void typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment, ErrorHandlerBase &errorHandler); -Type::QualifiedType typeCheck(const Expression::Base *expression, EnvironmentBase &environment, - ErrorHandlerBase &errorHandler); +const Type::Base *typeCheck(const Expression::Base *expression, EnvironmentBase &environment, + ErrorHandlerBase &errorHandler); } // namespace MiniParse::GeNN::Transpiler diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 182c5fa3d3..5db472750f 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -31,22 +31,26 @@ return s_Instance; \ } -#define DECLARE_NUMERIC_TYPE(TYPE, UNDERLYING_TYPE, RANK) \ - class TYPE : public Numeric \ - { \ - DECLARE_TYPE(TYPE) \ - virtual std::string getName() const{ return #UNDERLYING_TYPE; } \ - }; \ - template<> \ - struct TypeTraits \ - { \ - using NumericType = TYPE; \ +#define DECLARE_NUMERIC_TYPE(TYPE, UNDERLYING_TYPE, RANK) \ + class TYPE : public Numeric \ + { \ + DECLARE_TYPE(TYPE) \ + TYPE(Qualifier qualifiers = Qualifier{0}) : Numeric(qualifiers){} \ + virtual std::string getName() const final{ return #UNDERLYING_TYPE; } \ + virtual Base *getQualifiedType(Qualifier qualifiers) const 
final{ return new TYPE(qualifiers); } \ + }; \ + template<> \ + struct TypeTraits \ + { \ + using NumericType = TYPE; \ } -#define DECLARE_FOREIGN_FUNCTION_TYPE(TYPE, RETURN_TYPE, ...) \ - class TYPE : public ForeignFunction \ - { \ - DECLARE_TYPE(TYPE) \ +#define DECLARE_FOREIGN_FUNCTION_TYPE(TYPE, RETURN_TYPE, ...) \ + class TYPE : public ForeignFunction \ + { \ + DECLARE_TYPE(TYPE) \ + TYPE(Qualifier qualifiers = Qualifier{0}) : ForeignFunction(qualifiers){} \ + virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new TYPE(qualifiers); } \ } #define IMPLEMENT_TYPE(TYPE) TYPE *TYPE::s_Instance = NULL @@ -93,15 +97,26 @@ inline Qualifier operator | (Qualifier a, Qualifier b) class Base { public: + Base(Qualifier qualifiers = Qualifier{0}) : m_Qualifiers(qualifiers){} + //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ virtual std::string getName() const = 0; + virtual Base *getQualifiedType(Qualifier qualifiers) const = 0; //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - const Base *getPointerType() const; + const Base *getPointerType(Qualifier qualifiers = Qualifier{0}) const; + + bool hasQualifier(Qualifier qualifier) const{ return (m_Qualifiers & qualifier); }; + +private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + const Qualifier m_Qualifiers; }; //---------------------------------------------------------------------------- @@ -111,8 +126,8 @@ class Base class Pointer : public Base { public: - Pointer(const Base *valueType) - : m_ValueType(valueType) + Pointer(const Base *valueType, Qualifier qualifiers = Qualifier{0}) + : Base(qualifiers), m_ValueType(valueType) { } @@ -120,6 +135,7 @@ class Pointer : public Base // Base virtuals //------------------------------------------------------------------------ virtual std::string getName() const{ return getValueType()->getName() + "*";} + virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new Pointer(m_ValueType, qualifiers); } //------------------------------------------------------------------------ // Public API @@ -133,28 +149,14 @@ class Pointer : public Base const Base *m_ValueType; }; -//---------------------------------------------------------------------------- -// GeNN::Type::QualifiedType -//---------------------------------------------------------------------------- -//! 
A type with qualifiers attached -struct QualifiedType -{ - QualifiedType(const Base *t, bool v, bool p) - : type(t), constValue(v), constPointer(p) - { - } - - const Base *type; - bool constValue; - bool constPointer; -}; - //---------------------------------------------------------------------------- // GeNN::Type::NumericBase //---------------------------------------------------------------------------- class NumericBase : public Base { public: + NumericBase(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} + //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ @@ -173,6 +175,8 @@ template class Numeric : public NumericBase { public: + Numeric(Qualifier qualifiers = Qualifier{0}) : NumericBase(qualifiers){} + //------------------------------------------------------------------------ // Typedefines //------------------------------------------------------------------------ @@ -200,6 +204,8 @@ class Numeric : public NumericBase class ForeignFunctionBase : public Base { public: + ForeignFunctionBase(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} + //------------------------------------------------------------------------ // Base virtuals //------------------------------------------------------------------------ @@ -219,6 +225,8 @@ template class ForeignFunction : public ForeignFunctionBase { public: + ForeignFunction(Qualifier qualifiers = Qualifier{0}) : ForeignFunctionBase(qualifiers){} + //------------------------------------------------------------------------ // Base virtuals //------------------------------------------------------------------------ diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 33dfc169bc..3bc11d5343 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -185,27 +185,29 @@ Expression::ExpressionPtr parseBinary(ParserState &parserState, N nonTerminal, s return expression; } -GeNN::Type::QualifiedType parseDeclarationSpecifiers(ParserState &parserState) +const GeNN::Type::Base *parseDeclarationSpecifiers(ParserState &parserState) { - bool pointerFound = false; + using namespace GeNN::Type; + std::set<std::string_view> typeSpecifiers; - std::set<std::string_view> valueTypeQualifiers; - std::set<std::string_view> pointerTypeQualifiers; + std::set<std::string_view> typeQualifiers; + std::vector<std::set<std::string_view>> pointerTypeQualifiers; + do { - // If token is a star, set pointer found flag + // If token is a star, add new set of pointer type qualifiers if(parserState.previous().type == Token::Type::STAR) { - pointerFound = true; + pointerTypeQualifiers.emplace_back(); } // Otherwise, if type is a qualifier else if(parserState.previous().type == Token::Type::TYPE_QUALIFIER) { // Add qualifier lexeme to correct list - auto &typeQualifiers = pointerFound ? pointerTypeQualifiers : valueTypeQualifiers; + std::set<std::string_view> &qualifiers = pointerTypeQualifiers.empty() ? typeQualifiers : pointerTypeQualifiers.back(); - if(!typeQualifiers.insert(parserState.previous().lexeme).second) { + if(!qualifiers.insert(parserState.previous().lexeme).second) { parserState.error(parserState.previous(), "duplicate type qualifier"); } } else if(parserState.previous().type == Token::Type::TYPE_SPECIFIER) { - if(pointerFound) { + if(!pointerTypeQualifiers.empty()) { parserState.error(parserState.previous(), "invalid type specifier"); } else if(!typeSpecifiers.insert(parserState.previous().lexeme).second) { @@ -215,19 +217,20 @@ GeNN::Type::QualifiedType parseDeclarationSpecifiers(ParserState &parserState) } while(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER, Token::Type::STAR})); // Lookup numeric type - const auto *numericType = GeNN::Type::getNumericType(typeSpecifiers, parserState.getScalarType()); - - // If pointer, return pointer to numeric type + const Base *type = getNumericType(typeSpecifiers, parserState.getScalarType()); + + // If there are any type qualifiers, add const // **THINK** this relies on const being the only qualifier - // **TODO** warn of duplicate type qualifiers - if (pointerFound) { - return GeNN::Type::QualifiedType{numericType->getPointerType(), - !valueTypeQualifiers.empty(), !pointerTypeQualifiers.empty()}; + if(!typeQualifiers.empty()) { + type = type->getQualifiedType(Qualifier::CONST); } - // Otherwise, return numeric type directly - else { - return GeNN::Type::QualifiedType{numericType, !valueTypeQualifiers.empty(), !pointerTypeQualifiers.empty()}; + + // Loop through levels of pointer indirection + // **THINK** this relies on const being the only qualifier + for(const auto &p : pointerTypeQualifiers) { + type = type->getPointerType(p.empty() ? Qualifier{0} : Qualifier::CONST); } + return type; } Expression::ExpressionPtr parsePrimary(ParserState &parserState) @@ -378,11 +381,11 @@ Expression::ExpressionPtr parseCast(ParserState &parserState) // If this is followed by some part of a type declarator if(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER})) { // Parse declaration specifiers - const auto qualifiedType = parseDeclarationSpecifiers(parserState); + const auto *type = parseDeclarationSpecifiers(parserState); const auto closingParen = parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after cast type."); - return std::make_unique<Expression::Cast>(qualifiedType, parseCast(parserState), closingParen); + return std::make_unique<Expression::Cast>(type, parseCast(parserState), closingParen); } // Otherwise, rewind parser state so left parenthesis can be parsed again // **YUCK** @@ -798,7 +801,7 @@ Statement::StatementPtr parseDeclaration(ParserState &parserState) // "const" // Parse declaration specifiers - const auto qualifiedType = parseDeclarationSpecifiers(parserState); + const auto *type = parseDeclarationSpecifiers(parserState); // Read init declarator list std::vector> initDeclaratorList; @@ -822,7 +825,7 @@ Statement::StatementPtr parseDeclaration(ParserState &parserState) } while(!parserState.isAtEnd() && parserState.match(Token::Type::COMMA)); parserState.consume(Token::Type::SEMICOLON, "Expect ';' after variable declaration"); - return std::make_unique<Statement::VarDeclaration>(qualifiedType, std::move(initDeclaratorList)); + return std::make_unique<Statement::VarDeclaration>(type, std::move(initDeclaratorList)); } std::unique_ptr parseBlockItem(ParserState &parserState) diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index c42350c84e..d9761a9ef8 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -3,6 +3,7 @@ // Standard
C++ includes #include #include +#include <iterator> #include // GeNN includes @@ -78,7 +79,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Cast &cast) final { m_StringStream << "("; - printQualifiedType(cast.getQualifiedType()); + printType(cast.getType()); m_StringStream << ")"; cast.getExpression()->accept(*this); } @@ -229,7 +230,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::VarDeclaration &varDeclaration) final { - printQualifiedType(varDeclaration.getQualifiedType()); + printType(varDeclaration.getType()); for(const auto &var : varDeclaration.getInitDeclaratorList()) { m_StringStream << std::get<0>(var).lexeme; @@ -258,16 +259,43 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } private: - void printQualifiedType(const GeNN::Type::QualifiedType &qualifiedType) + void printType(const GeNN::Type::Base *type) { - if(qualifiedType.constValue) { - m_StringStream << "const "; - } - m_StringStream << qualifiedType.type->getName() << " "; + // **THINK** this should be Type::getName! - if(qualifiedType.constPointer) { - m_StringStream << "const "; + // Loop, building reversed list of tokens + std::vector<std::string> tokens; + while(true) { + // If type is a pointer + const auto *pointerType = dynamic_cast<const GeNN::Type::Pointer*>(type); + if(pointerType) { + // If pointer has const qualifier, add const + if(pointerType->hasQualifier(GeNN::Type::Qualifier::CONST)) { + tokens.push_back("const"); + } + + // Add * + tokens.push_back("*"); + + // Go to value type + type = pointerType->getValueType(); + } + // Otherwise + else { + // Add type specifier + tokens.push_back(type->getName()); + + // If value type has const qualifier, add const + if(type->hasQualifier(GeNN::Type::Qualifier::CONST)) { + tokens.push_back("const"); + } + break; + } } + + // Copy tokens backwards into string stream, separating with spaces + std::copy(tokens.rbegin(), tokens.rend(), std::ostream_iterator<std::string>(m_StringStream, " ")); + } //--------------------------------------------------------------------------- // Members diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index e476c83f63..c0749bef78 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -37,16 +37,16 @@ class EnvironmentInternal : public EnvironmentBase //--------------------------------------------------------------------------- // EnvironmentBase virtuals //--------------------------------------------------------------------------- - virtual void define(const Token &name, const Type::QualifiedType &qualifiedType, ErrorHandlerBase &errorHandler) final + virtual void define(const Token &name, const Type::Base *type, ErrorHandlerBase &errorHandler) final { - if(!m_Types.try_emplace(name.lexeme, qualifiedType).second) { + if(!m_Types.try_emplace(name.lexeme, type).second) { errorHandler.error(name, "Redeclaration of variable"); throw TypeCheckError(); } } - virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, - ErrorHandlerBase &errorHandler, bool initializer = false) final + virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType, + ErrorHandlerBase &errorHandler, bool initializer = false) final { // If type isn't found auto existingType = m_Types.find(name.lexeme); @@ -59,7 +59,7 @@ class EnvironmentInternal : public EnvironmentBase return EnvironmentBase::assign(name, op, existingType->second,
assignedType, errorHandler, initializer); } - virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final + virtual const Type::Base *incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final { // If type isn't found auto existingType = m_Types.find(name.lexeme); @@ -71,9 +71,9 @@ class EnvironmentInternal : public EnvironmentBase return EnvironmentBase::incDec(name, op, existingType->second, errorHandler); } - virtual const Type::QualifiedType &getType(const Token &name, ErrorHandlerBase &errorHandler) final + virtual const Type::Base *getType(const Token &name, ErrorHandlerBase &errorHandler) final { - auto type = m_Types.find(std::string{name.lexeme}); + auto type = m_Types.find(name.lexeme); if(type == m_Types.end()) { return m_Enclosing.getType(name, errorHandler); } @@ -87,7 +87,7 @@ class EnvironmentInternal : public EnvironmentBase // Members //--------------------------------------------------------------------------- EnvironmentBase &m_Enclosing; - std::unordered_map<std::string, Type::QualifiedType> m_Types; + std::unordered_map<std::string_view, const Type::Base*> m_Types; }; //--------------------------------------------------------------------------- @@ -97,7 +97,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { public: Visitor(ErrorHandlerBase &errorHandler) - : m_Environment(nullptr), m_QualifiedType{nullptr, false, false}, m_ErrorHandler(errorHandler), + : m_Environment(nullptr), m_Type(nullptr), m_ErrorHandler(errorHandler), m_InLoop(false), m_InSwitch(false) { } @@ -114,7 +114,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } } - const Type::QualifiedType typeCheck(const Expression::Base *expression, EnvironmentInternal &environment) + const Type::Base *typeCheck(const Expression::Base *expression, EnvironmentInternal &environment) { m_Environment = &environment; @@ -128,21 +128,21 @@ { // Get pointer type auto arrayType = m_Environment->getType(arraySubscript.getPointerName(), m_ErrorHandler); - auto pointerType = dynamic_cast<const Type::Pointer*>(arrayType.type); + auto pointerType = dynamic_cast<const Type::Pointer*>(arrayType); // If pointer is indeed a pointer if (pointerType) { // Evaluate pointer type auto indexType = evaluateType(arraySubscript.getIndex().get()); - auto indexNumericType = dynamic_cast<const Type::NumericBase*>(indexType.type); + auto indexNumericType = dynamic_cast<const Type::NumericBase*>(indexType); if (!indexNumericType || !indexNumericType->isIntegral()) { m_ErrorHandler.error(arraySubscript.getPointerName(), - "Invalid subscript index type '" + indexType.type->getName() + "'"); + "Invalid subscript index type '" + indexType->getName() + "'"); throw TypeCheckError(); } // Use value type of array - m_QualifiedType = Type::QualifiedType{pointerType->getValueType(), arrayType.constValue, false}; + m_Type = pointerType->getValueType(); } // Otherwise else { @@ -154,7 +154,7 @@ virtual void visit(const Expression::Assignment &assignment) final { const auto rhsType = evaluateType(assignment.getValue()); - m_QualifiedType = m_Environment->assign(assignment.getVarName(), assignment.getOperator().type, rhsType, m_ErrorHandler); + m_Type = m_Environment->assign(assignment.getVarName(), assignment.getOperator().type, rhsType, m_ErrorHandler); } virtual void visit(const Expression::Binary &binary) final @@ -162,48 +162,48 @@ const auto opType = binary.getOperator().type; const auto
rightType = evaluateType(binary.getRight()); if (opType == Token::Type::COMMA) { - m_QualifiedType = rightType; + m_Type = rightType; } else { // If we're subtracting two pointers const auto leftType = evaluateType(binary.getLeft()); - auto leftNumericType = dynamic_cast<const Type::NumericBase *>(leftType.type); - auto rightNumericType = dynamic_cast<const Type::NumericBase *>(rightType.type); - auto leftPointerType = dynamic_cast<const Type::Pointer *>(leftType.type); - auto rightPointerType = dynamic_cast<const Type::Pointer *>(rightType.type); + auto leftNumericType = dynamic_cast<const Type::NumericBase *>(leftType); + auto rightNumericType = dynamic_cast<const Type::NumericBase *>(rightType); + auto leftPointerType = dynamic_cast<const Type::Pointer *>(leftType); + auto rightPointerType = dynamic_cast<const Type::Pointer *>(rightType); if (leftPointerType && rightPointerType && opType == Token::Type::MINUS) { // Check pointers are compatible if (leftPointerType->getName() != rightPointerType->getName()) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getName() + "' and '" + rightType.type->getName()); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } // **TODO** should be std::ptrdiff/Int64 - m_QualifiedType = Type::QualifiedType{Type::Int32::getInstance(), false, false}; + m_Type = Type::Int32::getInstance(); } // Otherwise, if we're adding to or subtracting from pointers else if (leftPointerType && rightNumericType && (opType == Token::Type::PLUS || opType == Token::Type::MINUS)) // P + n or P - n { // Check that numeric operand is integer if (!rightNumericType->isIntegral()) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getName() + "' and '" + rightType.type->getName()); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } // Use left type - m_QualifiedType = leftType; + m_Type = leftType; } // Otherwise, if we're adding a number to a pointer else if (leftNumericType && rightPointerType && opType == Token::Type::PLUS) // n + P { // Check that numeric operand is integer if (!leftNumericType->isIntegral()) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getName() + "' and '" + rightType.type->getName()); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } // Use right type - m_QualifiedType = leftType; + m_Type = rightType; } // Otherwise, if both operands are numeric else if (leftNumericType && rightNumericType) { @@ -214,27 +214,27 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { // Check that operands are integers if (!leftNumericType->isIntegral() || !rightNumericType->isIntegral()) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getName() + "' and '" + rightType.type->getName()); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } // If operator is a shift, promote left type if (opType == Token::Type::SHIFT_LEFT || opType == Token::Type::SHIFT_RIGHT) { - m_QualifiedType = Type::QualifiedType{Type::getPromotedType(leftNumericType), false, false}; + m_Type = Type::getPromotedType(leftNumericType); } // Otherwise, take common type else { - m_QualifiedType = Type::QualifiedType{Type::getCommonType(leftNumericType, rightNumericType), false, false}; + m_Type =
Type::getCommonType(leftNumericType, rightNumericType); } } // Otherwise, any numeric type will do, take common type else { - m_QualifiedType = Type::QualifiedType{Type::getCommonType(leftNumericType, rightNumericType), false, false}; + m_Type = Type::getCommonType(leftNumericType, rightNumericType); } } else { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.type->getName() + "' and '" + rightType.type->getName()); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } } @@ -244,7 +244,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { // Evaluate callee type auto calleeType = evaluateType(call.getCallee()); - auto calleeFunctionType = dynamic_cast<const Type::ForeignFunctionBase *>(calleeType.type); + auto calleeFunctionType = dynamic_cast<const Type::ForeignFunctionBase *>(calleeType); // If callee's a function if (calleeFunctionType) { @@ -266,7 +266,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto callArgType = evaluateType(call.getArguments().at(i).get()); }*/ // Type is return type of function - m_QualifiedType = Type::QualifiedType{calleeFunctionType->getReturnType(), false, false}; + m_Type = calleeFunctionType->getReturnType(); } } // Otherwise @@ -281,94 +281,87 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // Evaluate type of expression we're casting const auto rightType = evaluateType(cast.getExpression()); - // If value const is being removed - if (rightType.constValue && !cast.getQualifiedType().constValue) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getName() + "' and '" + rightType.type->getName()); - throw TypeCheckError(); - } - // Otherwise, if pointer const is being removed - else if (rightType.constPointer && !cast.getQualifiedType().constPointer) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getName() + "' and '" + rightType.type->getName()); + // If const is being removed + if (rightType->hasQualifier(Type::Qualifier::CONST) && !cast.getType()->hasQualifier(Type::Qualifier::CONST)) { + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } // If we're trying to cast pointer to pointer - auto rightNumericType = dynamic_cast<const Type::NumericBase *>(rightType.type); - auto rightPointerType = dynamic_cast<const Type::Pointer *>(rightType.type); - auto leftNumericType = dynamic_cast<const Type::NumericBase *>(cast.getQualifiedType().type); - auto leftPointerType = dynamic_cast<const Type::Pointer *>(cast.getQualifiedType().type); + auto rightNumericType = dynamic_cast<const Type::NumericBase *>(rightType); + auto rightPointerType = dynamic_cast<const Type::Pointer *>(rightType); + auto leftNumericType = dynamic_cast<const Type::NumericBase *>(cast.getType()); + auto leftPointerType = dynamic_cast<const Type::Pointer *>(cast.getType()); if (rightPointerType && leftPointerType) { if (rightPointerType->getName() != leftPointerType->getName()) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getName() + "' and '" + rightType.type->getName()); + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } } // Otherwise, if either operand isn't numeric else if(!leftNumericType || !rightNumericType) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getQualifiedType().type->getName() + "' and '" +
rightType.type->getName()); + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } - m_QualifiedType = cast.getQualifiedType(); + m_Type = cast.getType(); } virtual void visit(const Expression::Conditional &conditional) final { const auto trueType = evaluateType(conditional.getTrue()); const auto falseType = evaluateType(conditional.getFalse()); - auto trueNumericType = dynamic_cast<const Type::NumericBase *>(trueType.type); - auto falseNumericType = dynamic_cast<const Type::NumericBase *>(falseType.type); + auto trueNumericType = dynamic_cast<const Type::NumericBase *>(trueType); + auto falseNumericType = dynamic_cast<const Type::NumericBase *>(falseType); if (trueNumericType && falseNumericType) { // **TODO** check behaviour - m_QualifiedType = Type::QualifiedType{Type::getCommonType(trueNumericType, falseNumericType), - trueType.constValue || falseType.constValue, - trueType.constPointer || falseType.constPointer}; + m_Type = Type::getCommonType(trueNumericType, falseNumericType); + if(trueType->hasQualifier(Type::Qualifier::CONST) || falseType->hasQualifier(Type::Qualifier::CONST)) { + m_Type = m_Type->getQualifiedType(Type::Qualifier::CONST); + } } else { m_ErrorHandler.error(conditional.getQuestion(), - "Invalid operand types '" + trueType.type->getName() + "' and '" + falseType.type->getName() + "' to conditional"); + "Invalid operand types '" + trueType->getName() + "' and '" + falseType->getName() + "' to conditional"); throw TypeCheckError(); } } virtual void visit(const Expression::Grouping &grouping) final { - m_QualifiedType = evaluateType(grouping.getExpression()); + m_Type = evaluateType(grouping.getExpression()); } virtual void visit(const Expression::Literal &literal) final { - m_QualifiedType = Type::QualifiedType{ - std::visit(Utils::Overload{ - [](auto v)->const Type::NumericBase *{ return Type::TypeTraits<decltype(v)>::NumericType::getInstance(); }, - [](std::monostate)->const Type::NumericBase *{ return nullptr; }}, - literal.getValue()), - true, - false}; + m_Type = std::visit(Utils::Overload{ + [](auto v)->const Type::NumericBase *{ return Type::TypeTraits<decltype(v)>::NumericType::getInstance(); }, + [](std::monostate)->const Type::NumericBase *{ return nullptr; }}, + literal.getValue()); } virtual void visit(const Expression::Logical &logical) final { logical.getLeft()->accept(*this); logical.getRight()->accept(*this); - m_QualifiedType = Type::QualifiedType{Type::Int32::getInstance(), false, false}; + m_Type = Type::Int32::getInstance(); } virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final { - m_QualifiedType = m_Environment->incDec(postfixIncDec.getVarName(), - postfixIncDec.getOperator().type, m_ErrorHandler); + m_Type = m_Environment->incDec(postfixIncDec.getVarName(), + postfixIncDec.getOperator().type, m_ErrorHandler); } virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final { - m_QualifiedType = m_Environment->incDec(prefixIncDec.getVarName(), - prefixIncDec.getOperator().type, m_ErrorHandler); + m_Type = m_Environment->incDec(prefixIncDec.getVarName(), + prefixIncDec.getOperator().type, m_ErrorHandler); } virtual void visit(const Expression::Variable &variable) { - m_QualifiedType = m_Environment->getType(variable.getName(), m_ErrorHandler); + m_Type = m_Environment->getType(variable.getName(), m_ErrorHandler); } virtual void visit(const Expression::Unary &unary) final @@ -377,52 +370,50 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If operator is pointer de-reference if (unary.getOperator().type ==
Token::Type::STAR) { - auto rightPointerType = dynamic_cast<const Type::Pointer *>(rightType.type); + auto rightPointerType = dynamic_cast<const Type::Pointer *>(rightType); if (!rightPointerType) { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType.type->getName() + "'"); + "Invalid operand type '" + rightType->getName() + "'"); throw TypeCheckError(); } // Return value type - m_QualifiedType = Type::QualifiedType{rightPointerType->getValueType(), rightType.constValue, false}; + m_Type = rightPointerType->getValueType(); } // Otherwise else { - auto rightNumericType = dynamic_cast<const Type::NumericBase *>(rightType.type); + auto rightNumericType = dynamic_cast<const Type::NumericBase *>(rightType); if (rightNumericType) { // If operator is arithmetic, return promoted type if (unary.getOperator().type == Token::Type::PLUS || unary.getOperator().type == Token::Type::MINUS) { - m_QualifiedType = Type::QualifiedType{Type::getPromotedType(rightNumericType), - rightType.constValue, false}; + // **THINK** const through these? + m_Type = Type::getPromotedType(rightNumericType); } // Otherwise, if operator is bitwise else if (unary.getOperator().type == Token::Type::TILDA) { // If type is integer, return promoted type if (rightNumericType->isIntegral()) { - m_QualifiedType = Type::QualifiedType{Type::getPromotedType(rightNumericType), - rightType.constValue, false}; + // **THINK** const through these? + m_Type = Type::getPromotedType(rightNumericType); } else { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType.type->getName() + "'"); + "Invalid operand type '" + rightType->getName() + "'"); throw TypeCheckError(); } } // Otherwise, if operator is logical else if (unary.getOperator().type == Token::Type::NOT) { - m_QualifiedType = Type::QualifiedType{Type::Int32::getInstance(), - rightType.constValue, false}; + m_Type = Type::Int32::getInstance(); } // Otherwise, if operator is address of, return pointer type else if (unary.getOperator().type == Token::Type::AMPERSAND) { - m_QualifiedType = Type::QualifiedType{rightType.type->getPointerType(), - rightType.constValue, false}; + m_Type = rightType->getPointerType(); } } else { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType.type->getName() + "'"); + "Invalid operand type '" + rightType->getName() + "'"); throw TypeCheckError(); } } @@ -509,10 +500,10 @@ class Visitor : public Expression::Visitor, public Statement::Visitor if (labelled.getValue()) { auto valType = evaluateType(labelled.getValue()); - auto valNumericType = dynamic_cast<const Type::NumericBase *>(valType.type); + auto valNumericType = dynamic_cast<const Type::NumericBase *>(valType); if (!valNumericType || !valNumericType->isIntegral()) { m_ErrorHandler.error(labelled.getKeyword(), - "Invalid case value '" + valType.type->getName() + "'"); + "Invalid case value '" + valType->getName() + "'"); throw TypeCheckError(); } } @@ -523,10 +514,10 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::Switch &switchStatement) final { auto condType = evaluateType(switchStatement.getCondition()); - auto condNumericType = dynamic_cast<const Type::NumericBase *>(condType.type); + auto condNumericType = dynamic_cast<const Type::NumericBase *>(condType); if (!condNumericType || !condNumericType->isIntegral()) { m_ErrorHandler.error(switchStatement.getSwitch(), - "Invalid condition '" + condType.type->getName() + "'"); + "Invalid condition '" + condType->getName() + "'"); throw TypeCheckError(); }
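Both the Binary and Unary visitors above lean on two numeric-type helpers. A minimal illustration of what they compute; Type::Int8 is an assumed sub-int numeric type (only Int32, Uint32, Float and Double appear elsewhere in this patch), and the expected results follow C-style promotion rules:

    // Illustrative sketch only - not part of the patch
    // Integer promotion: small integer types promote to (at least) int
    const auto *promoted = Type::getPromotedType(Type::Int8::getInstance());    // expected: Int32
    // Usual arithmetic conversions: mixed int/float yields the floating type
    const auto *common = Type::getCommonType(Type::Int32::getInstance(),
                                             Type::Double::getInstance());      // expected: Double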
@@ -538,7 +529,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::VarDeclaration &varDeclaration) final { for (const auto &var : varDeclaration.getInitDeclaratorList()) { - m_Environment->define(std::get<0>(var), varDeclaration.getQualifiedType(), m_ErrorHandler); + m_Environment->define(std::get<0>(var), varDeclaration.getType(), m_ErrorHandler); // If variable has an initialiser expression if (std::get<1>(var)) { @@ -568,17 +559,17 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- // Private methods //--------------------------------------------------------------------------- - const Type::QualifiedType &evaluateType(const Expression::Base *expression) + const Type::Base *evaluateType(const Expression::Base *expression) { expression->accept(*this); - return m_QualifiedType; + return m_Type; } //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- EnvironmentInternal *m_Environment; - Type::QualifiedType m_QualifiedType; + const Type::Base *m_Type; ErrorHandlerBase &m_ErrorHandler; bool m_InLoop; @@ -589,28 +580,26 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker::EnvironmentBase //--------------------------------------------------------------------------- -const Type::QualifiedType &EnvironmentBase::assign(const Token &name, Token::Type op, - const Type::QualifiedType &existingType, const Type::QualifiedType &assignedType, - ErrorHandlerBase &errorHandler, bool initializer) const +const Type::Base *EnvironmentBase::assign(const Token &name, Token::Type op, + const Type::Base *existingType, const Type::Base *assignedType, + ErrorHandlerBase &errorHandler, bool initializer) const { - // If existing type is a constant numeric value or if it's a constant pointer give errors - auto numericExistingType = dynamic_cast<const Type::NumericBase *>(existingType.type); - auto pointerExistingType = dynamic_cast<const Type::Pointer *>(existingType.type); - if(!initializer && ((numericExistingType && existingType.constValue) - || (pointerExistingType && existingType.constPointer))) - { + // If existing type is const qualified and isn't being initialized, give an error + if(!initializer && existingType->hasQualifier(Type::Qualifier::CONST)) { errorHandler.error(name, "Assignment of read-only variable"); throw TypeCheckError(); } // If assignment operation is plain equals, any type is fine so return - auto numericAssignedType = dynamic_cast<const Type::NumericBase *>(assignedType.type); - auto pointerAssignedType = dynamic_cast<const Type::Pointer *>(assignedType.type); + auto numericExistingType = dynamic_cast<const Type::NumericBase *>(existingType); + auto pointerExistingType = dynamic_cast<const Type::Pointer *>(existingType); + auto numericAssignedType = dynamic_cast<const Type::NumericBase *>(assignedType); + auto pointerAssignedType = dynamic_cast<const Type::Pointer *>(assignedType); if(op == Token::Type::EQUAL) { // If we're initialising a pointer with another pointer if (pointerAssignedType && pointerExistingType) { // If we're trying to assign a pointer to a const value to a pointer - if (assignedType.constValue && !existingType.constValue) { + if (assignedType->hasQualifier(Type::Qualifier::CONST) && !existingType->hasQualifier(Type::Qualifier::CONST)) { errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName() + "' and '" + pointerAssignedType->getName()); throw TypeCheckError(); } @@ -623,7 +612,7 @@ } // Otherwise, if we're trying to
initialise a pointer with a non-pointer or vice-versa else if (pointerAssignedType || pointerExistingType) { - errorHandler.error(name, "Invalid operand types '" + existingType.type->getName() + "' and '" + assignedType.type->getName()); + errorHandler.error(name, "Invalid operand types '" + existingType->getName() + "' and '" + assignedType->getName()); throw TypeCheckError(); } } @@ -632,7 +621,7 @@ // If the operand being added isn't numeric or the type being added to is neither numeric nor a pointer if (!numericAssignedType || (!pointerExistingType && !numericExistingType)) { - errorHandler.error(name, "Invalid operand types '" + existingType.type->getName() + "' and '" + assignedType.type->getName() + "'"); + errorHandler.error(name, "Invalid operand types '" + existingType->getName() + "' and '" + assignedType->getName() + "'"); throw TypeCheckError(); } @@ -650,7 +639,7 @@ throw TypeCheckError(); } if(!numericExistingType) { - errorHandler.error(name, "Invalid operand types '" + existingType.type->getName() + "'"); + errorHandler.error(name, "Invalid operand types '" + existingType->getName() + "'"); throw TypeCheckError(); } @@ -672,15 +661,11 @@ return existingType; } //--------------------------------------------------------------------------- -const Type::QualifiedType &EnvironmentBase::incDec(const Token &name, Token::Type, - const Type::QualifiedType &existingType, ErrorHandlerBase &errorHandler) const +const Type::Base *EnvironmentBase::incDec(const Token &name, Token::Type, + const Type::Base *existingType, ErrorHandlerBase &errorHandler) const { - // If existing type is a constant numeric value or if it's a constant pointer give errors - auto numericExistingType = dynamic_cast<const Type::NumericBase *>(existingType.type); - auto pointerExistingType = dynamic_cast<const Type::Pointer *>(existingType.type); - if((numericExistingType && existingType.constValue) - || (pointerExistingType && existingType.constPointer)) - { + // If existing type has a constant qualifier, give an error + if(existingType->hasQualifier(Type::Qualifier::CONST)) { errorHandler.error(name, "Increment/decrement of read-only variable"); throw TypeCheckError(); } @@ -701,9 +686,9 @@ void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &st visitor.typeCheck(statements, internalEnvironment); } //--------------------------------------------------------------------------- -Type::QualifiedType GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, - EnvironmentBase &environment, - ErrorHandlerBase &errorHandler) +const Type::Base *GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, + EnvironmentBase &environment, + ErrorHandlerBase &errorHandler) { Visitor visitor(errorHandler); EnvironmentInternal internalEnvironment(environment);
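A minimal caller sketch of the reworked expression entry point; this is hypothetical driver code, and the Scanner/Parser/TestErrorHandler names are assumptions mirroring the unit-test plumbing further below:

    // Illustrative sketch only - names other than typeCheck()/hasQualifier() are assumptions
    TestErrorHandler errorHandler;
    const auto tokens = Scanner::scanSource("intArray[4]", errorHandler);
    const auto expression = Parser::parseExpression(tokens, errorHandler);
    const Type::Base *type = TypeChecker::typeCheck(expression.get(), environment, errorHandler);
    if(type->hasQualifier(Type::Qualifier::CONST)) {
        // expression yields a read-only value
    }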
diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index dbf0af2cc5..5a092a5a61 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -80,11 +80,11 @@ IMPLEMENT_TYPE(Sqrt); //---------------------------------------------------------------------------- // GeNN::Type::Base //---------------------------------------------------------------------------- -const Base *Base::getPointerType() const +const Base *Base::getPointerType(Qualifier qualifiers) const { // **TODO** befriend constructor // **TODO** don't just leak these! - return new Pointer(this); + return new Pointer(this, qualifiers); } //---------------------------------------------------------------------------- diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index a091584a09..dbd8f7561c 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -57,37 +57,38 @@ class TestEnvironment : public TypeChecker::EnvironmentBase //--------------------------------------------------------------------------- // Public API //--------------------------------------------------------------------------- - void define(const Type::Base *type, std::string_view name, bool isConstValue = false, bool isConstPointer = false) + void define(std::string_view name, const Type::Base *type) { - if(!m_Types.try_emplace(name, type, isConstValue, isConstPointer).second) { + if(!m_Types.try_emplace(name, type).second) { throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); } } template<typename T> - void define(std::string_view name, bool isConstValue = false, bool isConstPointer = false) + void define(std::string_view name, Type::Qualifier qualifiers = Type::Qualifier{0}) { - define(T::getInstance(), name, isConstValue, isConstPointer); + define(name, T::getInstance()->getQualifiedType(qualifiers)); } template<typename T> - void definePointer(std::string_view name, bool isConstValue = false, bool isConstPointer = false) + void definePointer(std::string_view name, Type::Qualifier valueQualifiers = Type::Qualifier{0}, + Type::Qualifier pointerQualifiers = Type::Qualifier{0}) { - define(T::getInstance()->getPointerType(), name, isConstValue, isConstPointer); + define(name, T::getInstance()->getQualifiedType(valueQualifiers)->getPointerType(pointerQualifiers)); }
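With these helpers, the two qualifier parameters compose as follows (a sketch; Type::Int32 as the value type, matching the tests below):

    // Illustrative sketch only
    typeEnvironment.definePointer<Type::Int32>("constValArray", Type::Qualifier::CONST);    // const int *
    typeEnvironment.definePointer<Type::Int32>("constPtrArray", Type::Qualifier{0},
                                               Type::Qualifier::CONST);                     // int *const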
//--------------------------------------------------------------------------- // EnvironmentBase virtuals //--------------------------------------------------------------------------- - virtual void define(const Token &name, const Type::QualifiedType &, ErrorHandlerBase &errorHandler) final + virtual void define(const Token &name, const Type::Base*, ErrorHandlerBase &errorHandler) final { errorHandler.error(name, "Cannot declare variable in external environment"); throw TypeChecker::TypeCheckError(); } - virtual const Type::QualifiedType &assign(const Token &name, Token::Type op, const Type::QualifiedType &assignedType, - ErrorHandlerBase &errorHandler, bool initializer = false) final + virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType, + ErrorHandlerBase &errorHandler, bool initializer = false) final { // If type isn't found auto existingType = m_Types.find(name.lexeme); @@ -100,7 +101,7 @@ class TestEnvironment : public TypeChecker::EnvironmentBase return EnvironmentBase::assign(name, op, existingType->second, assignedType, errorHandler, initializer); } - virtual const Type::QualifiedType &incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final + virtual const Type::Base *incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final { auto existingType = m_Types.find(name.lexeme); if(existingType == m_Types.end()) { @@ -112,7 +113,7 @@ class TestEnvironment : public TypeChecker::EnvironmentBase return EnvironmentBase::incDec(name, op, existingType->second, errorHandler); } - virtual const Type::QualifiedType &getType(const Token &name, ErrorHandlerBase &errorHandler) final + virtual const Type::Base *getType(const Token &name, ErrorHandlerBase &errorHandler) final { auto type = m_Types.find(std::string{name.lexeme}); if(type == m_Types.end()) { @@ -128,7 +129,7 @@ class TestEnvironment : public TypeChecker::EnvironmentBase //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- - std::unordered_map<std::string_view, Type::QualifiedType> m_Types; + std::unordered_map<std::string_view, const Type::Base*> m_Types; }; template<typename T> @@ -153,7 +154,7 @@ void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment ASSERT_FALSE(errorHandler.hasError()); } -Type::QualifiedType typeCheckExpression(std::string_view code, TestEnvironment &typeEnvironment, const Type::NumericBase *scalarType = Type::Float::getInstance()) +const Type::Base *typeCheckExpression(std::string_view code, TestEnvironment &typeEnvironment, const Type::NumericBase *scalarType = Type::Float::getInstance()) { // Scan TestErrorHandler errorHandler; @@ -165,7 +166,7 @@ Type::QualifiedType typeCheckExpression(std::string_view code, TestEnvironment & EXPECT_FALSE(errorHandler.hasError()); // Typecheck - const auto type = TypeChecker::typeCheck(expression.get(), typeEnvironment, errorHandler); + const auto *type = TypeChecker::typeCheck(expression.get(), typeEnvironment, errorHandler); EXPECT_FALSE(errorHandler.hasError()); return type; } @@ -180,10 +181,9 @@ TEST(TypeChecker, ArraySubscript) { TestEnvironment typeEnvironment; typeEnvironment.definePointer<Type::Int32>("intArray"); - const auto type = typeCheckExpression("intArray[4]", typeEnvironment); - EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_FALSE(type.constValue); - EXPECT_FALSE(type.constPointer); + const auto *type = typeCheckExpression("intArray[4]", typeEnvironment); + EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); + EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST)); } // Float array indexing @@ -209,7 +209,7 @@ TEST(TypeChecker, Assignment) { TestEnvironment typeEnvironment; typeEnvironment.define<Type::Int32>("intVal"); typeEnvironment.define<Type::Float>("floatVal"); - typeEnvironment.define<Type::Int32>("intValConst", true); + typeEnvironment.define<Type::Int32>("intValConst", Type::Qualifier::CONST); typeCheckStatements( "int w = intVal;\n" "float x = floatVal;\n" @@ -225,7 +225,7 @@ TEST(TypeChecker, Assignment) { TestEnvironment typeEnvironment; typeEnvironment.definePointer<Type::Int32>("intArray"); - typeEnvironment.definePointer<Type::Int32>("intArrayConst", true); + typeEnvironment.definePointer<Type::Int32>("intArrayConst", Type::Qualifier::CONST); typeCheckStatements( "int *x = intArray;\n" "const int *y = intArray;\n" @@ -236,7 +236,7 @@ TEST(TypeChecker, Assignment) // Pointer assignment, attempt to remove const EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer<Type::Int32>("intArray", true); + typeEnvironment.definePointer<Type::Int32>("intArray", Type::Qualifier::CONST); typeCheckStatements("int *x = intArray;", typeEnvironment);}, TypeChecker::TypeCheckError); @@ -264,60 +264,64 @@ TEST(TypeChecker, Cast) { TestEnvironment typeEnvironment; typeEnvironment.define<Type::Int32>("intVal"); - const auto type = typeCheckExpression("(float)intVal", typeEnvironment); - EXPECT_EQ(type.type->getName(), Type::Float::getInstance()->getName()); - EXPECT_FALSE(type.constValue); - EXPECT_FALSE(type.constPointer); + const auto *type = typeCheckExpression("(float)intVal", typeEnvironment); + EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName()); + EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST)); } // Numeric cast to const { TestEnvironment typeEnvironment; typeEnvironment.define<Type::Int32>("intVal"); - const auto type =
typeCheckExpression("(const int)intVal", typeEnvironment); - EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_TRUE(type.constValue); - EXPECT_FALSE(type.constPointer); + const auto *type = typeCheckExpression("(const int)intVal", typeEnvironment); + EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); + EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONST)); } // Pointer cast to value const { TestEnvironment typeEnvironment; typeEnvironment.definePointer("intArray"); - const auto type = typeCheckExpression("(const int*)intArray", typeEnvironment); - EXPECT_EQ(type.type->getName(), getPointerTypeName()); - EXPECT_TRUE(type.constValue); - EXPECT_FALSE(type.constPointer); + const auto *type = typeCheckExpression("(const int*)intArray", typeEnvironment); + EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST)); + + const auto *pointerType = dynamic_cast(type); + EXPECT_TRUE(pointerType); + EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); + EXPECT_TRUE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONST)); } // Pointer cast to pointer const { TestEnvironment typeEnvironment; typeEnvironment.definePointer("intArray"); - const auto type = typeCheckExpression("(int * const)intArray", typeEnvironment); - EXPECT_EQ(type.type->getName(), getPointerTypeName()); - EXPECT_FALSE(type.constValue); - EXPECT_TRUE(type.constPointer); + const auto *type = typeCheckExpression("(int * const)intArray", typeEnvironment); + EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONST)); + + const auto *pointerType = dynamic_cast(type); + EXPECT_TRUE(pointerType); + EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); + EXPECT_FALSE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONST)); } // Can't remove value const from numeric EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.define("intVal", true); + typeEnvironment.define("intVal", Type::Qualifier::CONST); typeCheckExpression("(int)intVal", typeEnvironment);}, TypeChecker::TypeCheckError); // Can't remove value const from pointer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray", true); + typeEnvironment.definePointer("intArray", Type::Qualifier::CONST); typeCheckExpression("(int*)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Can't remove pointer const from pointer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray", false, true); + typeEnvironment.definePointer("intArray", Type::Qualifier{0}, Type::Qualifier::CONST); typeCheckExpression("(int*)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); @@ -353,43 +357,44 @@ TEST(TypeChecker, IncDec) { TestEnvironment typeEnvironment; typeEnvironment.define("intVal"); - const auto type = typeCheckExpression("intVal++", typeEnvironment); - EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_FALSE(type.constValue); - EXPECT_FALSE(type.constPointer); + const auto *type = typeCheckExpression("intVal++", typeEnvironment); + EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); + EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST)); } // Can increment pointer { TestEnvironment typeEnvironment; typeEnvironment.definePointer("intArray"); - const auto type = typeCheckExpression("intArray++", typeEnvironment); - EXPECT_EQ(type.type->getName(), getPointerTypeName()); - EXPECT_FALSE(type.constValue); - 
EXPECT_FALSE(type.constPointer); + const auto *type = typeCheckExpression("intArray++", typeEnvironment); + EXPECT_EQ(type->getName(), getPointerTypeName<Type::Int32>()); + EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST)); } // Can increment pointer to const { TestEnvironment typeEnvironment; - typeEnvironment.definePointer<Type::Int32>("intArray", true); - const auto type = typeCheckExpression("intArray++", typeEnvironment); - EXPECT_EQ(type.type->getName(), getPointerTypeName<Type::Int32>()); - EXPECT_TRUE(type.constValue); - EXPECT_FALSE(type.constPointer); + typeEnvironment.definePointer<Type::Int32>("intArray", Type::Qualifier::CONST); + const auto *type = typeCheckExpression("intArray++", typeEnvironment); + EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST)); + + const auto *pointerType = dynamic_cast<const Type::Pointer *>(type); + EXPECT_TRUE(pointerType); + EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); + EXPECT_TRUE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONST)); } // Can't increment const number EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.define<Type::Int32>("intVal", true); + typeEnvironment.define<Type::Int32>("intVal", Type::Qualifier::CONST); typeCheckExpression("intVal++", typeEnvironment);}, TypeChecker::TypeCheckError); // Can't increment const pointer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer<Type::Int32>("intArray", false, true); + typeEnvironment.definePointer<Type::Int32>("intArray", Type::Qualifier{0}, Type::Qualifier::CONST); typeCheckExpression("intArray++", typeEnvironment);}, TypeChecker::TypeCheckError); } @@ -399,55 +404,55 @@ TEST(TypeChecker, Literal) // Float { TestEnvironment typeEnvironment; - const auto type = typeCheckExpression("1.0f", typeEnvironment); - EXPECT_EQ(type.type->getName(), Type::Float::getInstance()->getName()); - EXPECT_TRUE(type.constValue); - EXPECT_FALSE(type.constPointer); + const auto *type = typeCheckExpression("1.0f", typeEnvironment); + EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName()); + //EXPECT_TRUE(type.constValue); + //EXPECT_FALSE(type.constPointer); } // Scalar with single-precision { TestEnvironment typeEnvironment; - const auto type = typeCheckExpression("1.0", typeEnvironment); - EXPECT_EQ(type.type->getName(), Type::Float::getInstance()->getName()); - EXPECT_TRUE(type.constValue); - EXPECT_FALSE(type.constPointer); + const auto *type = typeCheckExpression("1.0", typeEnvironment); + EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName()); + //EXPECT_TRUE(type.constValue); + //EXPECT_FALSE(type.constPointer); } // Scalar with double-precision { TestEnvironment typeEnvironment; - const auto type = typeCheckExpression("1.0", typeEnvironment, Type::Double::getInstance()); - EXPECT_EQ(type.type->getName(), Type::Double::getInstance()->getName()); - EXPECT_TRUE(type.constValue); - EXPECT_FALSE(type.constPointer); + const auto *type = typeCheckExpression("1.0", typeEnvironment, Type::Double::getInstance()); + EXPECT_EQ(type->getName(), Type::Double::getInstance()->getName()); + //EXPECT_TRUE(type.constValue); + //EXPECT_FALSE(type.constPointer); } // Double { TestEnvironment typeEnvironment; - const auto type = typeCheckExpression("1.0d", typeEnvironment); - EXPECT_EQ(type.type->getName(), Type::Double::getInstance()->getName()); - EXPECT_TRUE(type.constValue); + const auto *type = typeCheckExpression("1.0d", typeEnvironment); + EXPECT_EQ(type->getName(), Type::Double::getInstance()->getName()); + //EXPECT_TRUE(type.constValue); +
//EXPECT_FALSE(type.constPointer); } // Integer { TestEnvironment typeEnvironment; - const auto type = typeCheckExpression("100", typeEnvironment); - EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_TRUE(type.constValue); - EXPECT_FALSE(type.constPointer); + const auto *type = typeCheckExpression("100", typeEnvironment); + EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); + //EXPECT_TRUE(type.constValue); + //EXPECT_FALSE(type.constPointer); } // Unsigned integer { TestEnvironment typeEnvironment; - const auto type = typeCheckExpression("100U", typeEnvironment); - EXPECT_EQ(type.type->getName(), Type::Uint32::getInstance()->getName()); - EXPECT_TRUE(type.constValue); - EXPECT_FALSE(type.constPointer); + const auto *type = typeCheckExpression("100U", typeEnvironment); + EXPECT_EQ(type->getName(), Type::Uint32::getInstance()->getName()); + //EXPECT_TRUE(type.constValue); + //EXPECT_FALSE(type.constPointer); } } //-------------------------------------------------------------------------- @@ -457,40 +462,36 @@ TEST(TypeChecker, Unary) { TestEnvironment typeEnvironment; typeEnvironment.definePointer("intArray"); - const auto type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_FALSE(type.constValue); - EXPECT_FALSE(type.constPointer); + const auto *type = typeCheckExpression("*intArray", typeEnvironment); + EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); + EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST)); } // Dereference pointer to const { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray", true); - const auto type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_TRUE(type.constValue); - EXPECT_FALSE(type.constPointer); + typeEnvironment.definePointer("intArray", Type::Qualifier::CONST); + const auto *type = typeCheckExpression("*intArray", typeEnvironment); + EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); + EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONST)); } // Dereference const pointer { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray", false, true); - const auto type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_FALSE(type.constValue); - EXPECT_FALSE(type.constPointer); + typeEnvironment.definePointer("intArray", Type::Qualifier{0}, Type::Qualifier::CONST); + const auto *type = typeCheckExpression("*intArray", typeEnvironment); + EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); + EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST)); } // Dereference const pointer to const { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray", true, true); - const auto type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type.type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_TRUE(type.constValue); - EXPECT_FALSE(type.constPointer); + typeEnvironment.definePointer("intArray", Type::Qualifier::CONST, Type::Qualifier::CONST); + const auto *type = typeCheckExpression("*intArray", typeEnvironment); + EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); + EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONST)); } // Dereference numeric @@ -504,10 +505,13 @@ TEST(TypeChecker, Unary) { TestEnvironment 
typeEnvironment; typeEnvironment.define<Type::Int32>("intVal"); - const auto type = typeCheckExpression("&intVal", typeEnvironment); - EXPECT_EQ(type.type->getName(), getPointerTypeName<Type::Int32>()); - EXPECT_FALSE(type.constValue); - EXPECT_FALSE(type.constPointer); + const auto *type = typeCheckExpression("&intVal", typeEnvironment); + EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST)); + + const auto *pointerType = dynamic_cast<const Type::Pointer *>(type); + EXPECT_TRUE(pointerType); + EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); + EXPECT_FALSE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONST)); } // Address of pointer From 3abbf2f443e19b301767e8216e5c1485301e5772 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 16 Jan 2023 17:53:34 +0000 Subject: [PATCH 056/725] added getSizeBytes() to Type::Base and ripped out old Backend-based type size getting functionality --- include/genn/backends/cuda/backend.h | 2 +- include/genn/backends/opencl/backend.h | 2 +- .../backends/single_threaded_cpu/backend.h | 2 +- .../genn/genn/code_generator/backendBase.h | 26 +------- .../genn/genn/code_generator/backendSIMT.h | 9 --- .../genn/genn/code_generator/codeGenUtils.h | 4 ++ .../genn/genn/code_generator/groupMerged.h | 10 ++- include/genn/genn/type.h | 17 +++-- src/genn/backends/cuda/backend.cc | 49 ++++++++++++++- src/genn/backends/opencl/backend.cc | 48 ++++++++++++-- .../backends/single_threaded_cpu/backend.cc | 4 +- src/genn/genn/code_generator/backendBase.cc | 63 +------------------ src/genn/genn/code_generator/backendSIMT.cc | 15 ----- .../customConnectivityUpdateGroupMerged.cc | 3 +- src/genn/genn/code_generator/groupMerged.cc | 3 +- .../genn/code_generator/initGroupMerged.cc | 3 +- 16 files changed, 120 insertions(+), 140 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index 02b67358ac..e8335bd09c 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -209,7 +209,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const override; //! When generating merged structures what type to use for simulation RNGs - virtual std::string getMergedGroupSimRNGType() const override { return "curandState"; } + virtual const Type::Base *getMergedGroupSimRNGType() const override; virtual void genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const override; virtual void genVariablePull(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, size_t count) const override; diff --git a/include/genn/backends/opencl/backend.h b/include/genn/backends/opencl/backend.h index 55b89c7e39..2b657fdad3 100644 --- a/include/genn/backends/opencl/backend.h +++ b/include/genn/backends/opencl/backend.h @@ -172,7 +172,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const override; //!
When generating merged structures what type to use for simulation RNGs - virtual std::string getMergedGroupSimRNGType() const override { return "clrngLfsr113HostStream"; } + virtual const Type::Base *getMergedGroupSimRNGType() const; virtual void genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const override; virtual void genVariablePull(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, size_t count) const override; diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index c19db36852..b59cea525e 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -81,7 +81,7 @@ class BACKEND_EXPORT Backend : public BackendBase virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const override; //! When generating merged structures what type to use for simulation RNGs - virtual std::string getMergedGroupSimRNGType() const override; + virtual const Type::Base *getMergedGroupSimRNGType() const override; virtual void genPopVariableInit(CodeStream &os,const Substitutions &kernelSubs, Handler handler) const override; virtual void genVariableInit(CodeStream &os, const std::string &count, const std::string &indexVarName, diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 797a4b67d1..51aad00465 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -270,7 +270,7 @@ class GENN_EXPORT BackendBase virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const = 0; //! When generating merged structures what type to use for simulation RNGs - virtual std::string getMergedGroupSimRNGType() const = 0; + virtual const Type::Base *getMergedGroupSimRNGType() const = 0; virtual void genPopVariableInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const = 0; virtual void genVariableInit(CodeStream &os, const std::string &count, const std::string &indexVarName, @@ -430,12 +430,6 @@ class GENN_EXPORT BackendBase genVariableAllocation(allocations, type, name, loc, count, memAlloc); } - //! Get the size of the type - size_t getSize(const std::string &type) const; - - //! Get the lowest value of a type - std::string getLowestValue(const std::string &type) const; - //! 
Get the prefix for accessing the address of 'scalar' variables std::string getScalarAddressPrefix() const { @@ -470,17 +464,6 @@ class GENN_EXPORT BackendBase //-------------------------------------------------------------------------- // Protected API //-------------------------------------------------------------------------- - void addType(const std::string &type, size_t size, const std::string &lowestValue = "") - { - m_Types.emplace(std::piecewise_construct, std::forward_as_tuple(type), - std::forward_as_tuple(size, lowestValue)); - } - - void setPointerBytes(size_t pointerBytes) - { - m_PointerBytes = pointerBytes; - } - void genNeuronIndexCalculation(CodeStream &os, const NeuronUpdateGroupMerged &ng, unsigned int batchSize) const; void genSynapseIndexCalculation(CodeStream &os, const SynapseGroupMergedBase &sg, unsigned int batchSize) const; @@ -534,13 +517,6 @@ class GENN_EXPORT BackendBase //-------------------------------------------------------------------------- // Members //-------------------------------------------------------------------------- - //! How large is a device pointer? E.g. on some AMD devices this != sizeof(char*) - size_t m_PointerBytes; - - //! Size of supported types in bytes and string containing their lowest value - //! used for estimating memory usage and for reduction operations - std::unordered_map> m_Types; - //! Preferences const PreferencesBase &m_Preferences; }; diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index d7502c7bc0..0525766fcf 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -209,12 +209,6 @@ class GENN_EXPORT BackendSIMT : public BackendBase void genInitializeSparseKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t numInitializeThreads, size_t &idStart) const; - //! Adds a type - both to backend base's list of sized types but also to device types set - void addDeviceType(const std::string &type, size_t size, const std::string &maxValue = ""); - - //! Is type a a device only type? - bool isDeviceType(const std::string &type) const; - //! Helper wrapper around padSize to pad size to a kernel size size_t padKernelSize(size_t size, Kernel kernel) const; @@ -475,9 +469,6 @@ class GENN_EXPORT BackendSIMT : public BackendBase //-------------------------------------------------------------------------- const KernelBlockSize m_KernelBlockSizes; - //! Types that are only supported on device i.e. should never be exposed to user code - std::unordered_set m_DeviceTypes; - //-------------------------------------------------------------------------- // Static members //-------------------------------------------------------------------------- diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 223ba23864..fcec71cffa 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -96,6 +96,10 @@ GENN_EXPORT void checkUnreplacedVariables(const std::string &code, const std::st //-------------------------------------------------------------------------- GENN_EXPORT std::string disambiguateNamespaceFunction(const std::string supportCode, const std::string code, std::string namespaceName); +//-------------------------------------------------------------------------- +/*! 
\brief This function automatically replaces old style $(variable) variable references and $(function, arg1, arg2) syntax with new form. */ //-------------------------------------------------------------------------- GENN_EXPORT std::string upgradeCodeString(const std::string &codeString); //------------------------------------------------------------------------- diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 2e8fc26ece..f167918a73 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -90,14 +90,13 @@ class GroupMerged //! Get group fields, sorted into order they will appear in struct std::vector<Field> getSortedFields(const BackendBase &backend) const { - // **TODO** size should come from type system itself - numerics are easy pointer size is a little trickier // Make a copy of fields and sort so largest come first. This should mean that due // to structure packing rules, significant memory is saved and estimate is more precise auto sortedFields = m_Fields; std::sort(sortedFields.begin(), sortedFields.end(), [&backend](const Field &a, const Field &b) { - return (backend.getSize(std::get<0>(a)->getName()) > backend.getSize(std::get<0>(b)->getName())); + return (std::get<0>(a)->getSizeBytes() > std::get<0>(b)->getSizeBytes()); }); return sortedFields; @@ -159,8 +158,7 @@ class GroupMerged const auto sortedFields = getSortedFields(backend); for(const auto &f : sortedFields) { // Add size of field to total - // **TODO** size should be built into type system - const size_t fieldSize = backend.getSize(std::get<0>(f)->getName()); + const size_t fieldSize = std::get<0>(f)->getSizeBytes(); structSize += fieldSize; // Update largest field size @@ -267,12 +265,12 @@ class GroupMerged fieldType); } - void addPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) + void addPointerField(const Type::Base *type, const std::string &name, const std::string &prefix) { addField(type->getPointerType(), name, [prefix](const G &g, size_t) { return prefix + g.getName(); }); } - template<typename T, std::enable_if_t<std::is_base_of_v<Type::NumericBase, T>>* = nullptr> + template<typename T> void addPointerField(const std::string &name, const std::string &prefix) { addField(T::getInstance()->getPointerType(), name, [prefix](const G &g, size_t) { return prefix + g.getName(); }); diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 5db472750f..d1b5a96ee9 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -1,6 +1,7 @@ #pragma once // Standard C includes +#include <cstddef> #include <cstdint> // Standard C++ includes @@ -104,6 +105,7 @@ class Base //------------------------------------------------------------------------ virtual std::string getName() const = 0; virtual Base *getQualifiedType(Qualifier qualifiers) const = 0; + virtual size_t getSizeBytes() const = 0; //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ @@ -136,7 +138,8 @@ class Pointer : public Base //------------------------------------------------------------------------ virtual std::string getName() const{ return getValueType()->getName() + "*";} virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new Pointer(m_ValueType, qualifiers); } - + virtual size_t getSizeBytes() const final{ return sizeof(char*); } + //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ @@ -186,6 +189,7 @@
class Numeric : public NumericBase // Base virtuals //------------------------------------------------------------------------ virtual size_t getTypeHash() const final { return typeid(T).hash_code(); } + virtual size_t getSizeBytes() const final{ return sizeof(T); } //------------------------------------------------------------------------ // NumericBase virtuals @@ -205,11 +209,6 @@ class ForeignFunctionBase : public Base { public: ForeignFunctionBase(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} - - //------------------------------------------------------------------------ - // Base virtuals - //------------------------------------------------------------------------ - virtual std::string getName() const = 0; //------------------------------------------------------------------------ // Declared virtuals @@ -237,6 +236,12 @@ class ForeignFunction : public ForeignFunctionBase typeName += ")"; return typeName; } + + virtual size_t getSizeBytes() const final + { + assert(false); + return 0; + } //------------------------------------------------------------------------ // ForeignFunctionBase virtuals diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 1e55ca0966..a49f122ed8 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -43,6 +43,45 @@ const std::vector<Substitutions::FunctionTemplate> cudaDoublePrecisionFunctions {"gennrand_gamma", 1, "gammaDistDouble($(rng), $(0))"}, {"gennrand_binomial", 2, "binomialDistDouble($(rng), $(0), $(1))"} }; + +//-------------------------------------------------------------------------- +// CURandState +//-------------------------------------------------------------------------- +class CURandState : public Type::Base +{ +public: + DECLARE_TYPE(CURandState); + + CURandState(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} + + //------------------------------------------------------------------------ + // Base overloads + //------------------------------------------------------------------------ + virtual std::string getName() const final{ return "curandState"; } + virtual Base *getQualifiedType(Qualifier qualifiers) const { return new CURandState(qualifiers); } + virtual size_t getSizeBytes() const final{ return 44; } +}; +IMPLEMENT_TYPE(CURandState); + +//-------------------------------------------------------------------------- +// CURandStatePhilox43210 +//-------------------------------------------------------------------------- +class CURandStatePhilox43210 : public Type::Base +{ +public: + DECLARE_TYPE(CURandStatePhilox43210); + + CURandStatePhilox43210(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} + + //------------------------------------------------------------------------ + // Base overloads + //------------------------------------------------------------------------ + virtual std::string getName() const final{ return "curandStatePhilox4_32_10_t"; } + virtual Base *getQualifiedType(Qualifier qualifiers) const { return new CURandStatePhilox43210(qualifiers); } + virtual size_t getSizeBytes() const final{ return 64; } +}; +IMPLEMENT_TYPE(CURandStatePhilox43210);
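The disabled "half" registration in the constructor below suggests FP16 would follow the same pattern; a purely hypothetical sketch reusing the DECLARE_TYPE/IMPLEMENT_TYPE machinery shown above (not part of the patch; size and name assumed):

    // Hypothetical FP16 device type - illustrative only
    class Half : public Type::Base
    {
    public:
        DECLARE_TYPE(Half);

        Half(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){}

        virtual std::string getName() const final{ return "half"; }
        virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new Half(qualifiers); }
        virtual size_t getSizeBytes() const final{ return 2; }
    };
    IMPLEMENT_TYPE(Half);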
addDeviceType("curandStatePhilox4_32_10_t", 64); - addDeviceType("half", 2); + + //addDeviceType("half", 2); } //-------------------------------------------------------------------------- bool Backend::areSharedMemAtomicsSlow() const @@ -1735,6 +1773,11 @@ std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) con return type->getName(); } //-------------------------------------------------------------------------- +const Type::Base *Backend::getMergedGroupSimRNGType() const +{ + return CLRRNGLFSR113Stream::getInstance(); +} +//-------------------------------------------------------------------------- void Backend::genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const { assert(!getPreferences().automaticCopy); diff --git a/src/genn/backends/opencl/backend.cc b/src/genn/backends/opencl/backend.cc index 1942898695..e395cca026 100644 --- a/src/genn/backends/opencl/backend.cc +++ b/src/genn/backends/opencl/backend.cc @@ -41,7 +41,46 @@ const std::vector openclPhilloxFunctions = { {"gennrand_gamma", 1, "gammaDistPhilox432($(rng), $(0))"}, {"gennrand_binomial", 2, "binomialDistPhilox432($(rng), $(0), $(1))"} }; + +//-------------------------------------------------------------------------- +// CLRRNGLFSR113Stream //-------------------------------------------------------------------------- +class CLRRNGLFSR113Stream : public Type::Base +{ +public: + DECLARE_TYPE(CLRRNGLFSR113Stream); + + CLRRNGLFSR113Stream(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} + + //------------------------------------------------------------------------ + // Base overloads + //------------------------------------------------------------------------ + virtual std::string getName() const final{ return "clrngLfsr113Stream"; } + virtual Base *getQualifiedType(Qualifier qualifiers) const { return new CLRRNGLFSR113Stream(qualifiers); } + virtual size_t getSizeBytes() const final{ return 48; } +}; +IMPLEMENT_TYPE(CLRRNGLFSR113Stream); + +//-------------------------------------------------------------------------- +// CLRRNGPhilox432Stream +//-------------------------------------------------------------------------- +class CLRRNGPhilox432Stream : public Type::Base +{ +public: + DECLARE_TYPE(CLRRNGPhilox432Stream); + + CLRRNGPhilox432Stream(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} + + //------------------------------------------------------------------------ + // Base overloads + //------------------------------------------------------------------------ + virtual std::string getName() const final{ return "clrngPhilox432Stream"; } + virtual Base *getQualifiedType(Qualifier qualifiers) const { return new CLRRNGLFSR113Stream(qualifiers); } + virtual size_t getSizeBytes() const final{ return 132; } +}; +IMPLEMENT_TYPE(CLRRNGLFSR113Stream); + + template void genMergedGroupKernelParams(CodeStream &os, const std::vector &groups, bool includeFinalComma = false) { @@ -193,10 +232,6 @@ Backend::Backend(const KernelBlockSize &kernelBlockSizes, const Preferences &pre // Determine minimum alignement m_AllocationAlignementBytes = m_ChosenDevice.getInfo() / 8; LOGI_BACKEND << "Device uses " << m_AllocationAlignementBytes << " byte alignement"; - - // Add OpenCL-specific types - addType("clrngLfsr113Stream", 48); - addType("clrngPhilox432Stream", 132); } //-------------------------------------------------------------------------- bool Backend::areSharedMemAtomicsSlow() const @@ -2053,6 +2088,11 @@ std::string 
@@ -2053,6 +2088,11 @@ std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) con
     }
 }
 //--------------------------------------------------------------------------
+const Type::Base *Backend::getMergedGroupSimRNGType() const
+{
+    return CLRRNGLFSR113Stream::getInstance();
+}
+//--------------------------------------------------------------------------
 void Backend::genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const
 {
     if (!(loc & VarLocation::ZERO_COPY)) {
diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc
index 653a342e15..59b1d85c23 100644
--- a/src/genn/backends/single_threaded_cpu/backend.cc
+++ b/src/genn/backends/single_threaded_cpu/backend.cc
@@ -1353,10 +1353,10 @@ std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) con
     return type->getName();
 }
 //--------------------------------------------------------------------------
-std::string Backend::getMergedGroupSimRNGType() const
+const Type::Base *Backend::getMergedGroupSimRNGType() const
 {
     assert(false);
-    return "";
+    return nullptr;
 }
 //--------------------------------------------------------------------------
 void Backend::genPopVariableInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const
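Taken together, these overrides show the new contract: getMergedGroupSimRNGType() hands back a type object rather than a type-name string, so a backend with no per-population RNG (the single-threaded CPU one) can return nullptr behind an assert instead of a magic empty string. A hedged sketch of how calling code might consume the typed result; RNGType and Backend here are hypothetical stand-ins, not GeNN classes:

    #include <cassert>
    #include <cstddef>
    #include <string>

    struct RNGType
    {
        std::string name;
        std::size_t sizeBytes;
    };

    struct Backend
    {
        bool populationRNGRequired;
        const RNGType *rngType;

        // Mirrors the shape of Backend::getMergedGroupSimRNGType(): only
        // meaningful when the backend actually uses per-population RNGs
        const RNGType *getMergedGroupSimRNGType() const
        {
            assert(populationRNGRequired);
            return rngType;
        }
    };

    int main()
    {
        const RNGType lfsr{"clrngLfsr113Stream", 48};
        const Backend opencl{true, &lfsr};

        // A merged-group field can take its name AND size from one object,
        // where the old string-returning interface only gave the name
        if (opencl.populationRNGRequired) {
            const RNGType *rng = opencl.getMergedGroupSimRNGType();
            assert(rng->sizeBytes == 48);
        }
        return 0;
    }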
diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc
index 9c40a8ee8a..8a110517a4 100644
--- a/src/genn/genn/code_generator/backendBase.cc
+++ b/src/genn/genn/code_generator/backendBase.cc
@@ -10,73 +10,14 @@
 #include "code_generator/customUpdateGroupMerged.h"
 #include "code_generator/neuronUpdateGroupMerged.h"
 
-
-// Macro for simplifying defining type sizes
-#define TYPE(T) {#T, {sizeof(T), std::to_string(std::numeric_limits<T>::lowest())}}
-#define FLOAT_TYPE(T) {#T, {sizeof(T), Utils::writePreciseString(std::numeric_limits<T>::lowest())}}
-
 //--------------------------------------------------------------------------
 // GeNN::CodeGenerator::BackendBase
 //--------------------------------------------------------------------------
 namespace GeNN::CodeGenerator
 {
 BackendBase::BackendBase(const std::string &scalarType, const PreferencesBase &preferences)
-: m_PointerBytes(sizeof(char*)), m_Types{{TYPE(char), TYPE(wchar_t), TYPE(signed char), TYPE(short),
-    TYPE(signed short), TYPE(short int), TYPE(signed short int), TYPE(int), TYPE(signed int), TYPE(long),
-    TYPE(signed long), TYPE(long int), TYPE(signed long int), TYPE(long long), TYPE(signed long long), TYPE(long long int),
-    TYPE(signed long long int), TYPE(unsigned char), TYPE(unsigned short), TYPE(unsigned short int), TYPE(unsigned),
-    TYPE(unsigned int), TYPE(unsigned long), TYPE(unsigned long int), TYPE(unsigned long long),
-    TYPE(unsigned long long int), TYPE(bool), TYPE(intmax_t), TYPE(uintmax_t), TYPE(int8_t), TYPE(uint8_t),
-    TYPE(int16_t), TYPE(uint16_t), TYPE(int32_t), TYPE(uint32_t), TYPE(int64_t), TYPE(uint64_t),
-    TYPE(int_least8_t), TYPE(uint_least8_t), TYPE(int_least16_t), TYPE(uint_least16_t), TYPE(int_least32_t),
-    TYPE(uint_least32_t), TYPE(int_least64_t), TYPE(uint_least64_t), TYPE(int_fast8_t), TYPE(uint_fast8_t),
-    TYPE(int_fast16_t), TYPE(uint_fast16_t), TYPE(int_fast32_t), TYPE(uint_fast32_t), TYPE(int_fast64_t),
-    TYPE(uint_fast64_t), FLOAT_TYPE(float), FLOAT_TYPE(double), FLOAT_TYPE(long double)}}, m_Preferences(preferences)
-{
-    // Add scalar type
-    if(scalarType == "float") {
-        addType("scalar", sizeof(float), Utils::writePreciseString(std::numeric_limits<float>::lowest()));
-    }
-    else {
-        addType("scalar", sizeof(double), Utils::writePreciseString(std::numeric_limits<double>::lowest()));
-    }
-}
-//--------------------------------------------------------------------------
-size_t BackendBase::getSize(const std::string &type) const
-{
-    // If type is a pointer, any pointer should have the same type
-    if(Utils::isTypePointer(type)) {
-        return m_PointerBytes;
-    }
-    // Otherwise
-    else {
-        // If type isn't found in dictionary, give a warning and return 0
-        const auto typeSizeLowest = m_Types.find(type);
-        if(typeSizeLowest == m_Types.cend()) {
-            LOGW_CODE_GEN << "Unable to estimate size of type '" << type << "'";
-            return 0;
-        }
-        // Otherwise, return its size
-        else {
-            return typeSizeLowest->second.first;
-        }
-    }
-}
-//--------------------------------------------------------------------------
-std::string BackendBase::getLowestValue(const std::string &type) const
+: m_Preferences(preferences)
 {
-    assert(!Utils::isTypePointer(type));
-
-    // If type's found in dictionary and it has a lowest value
-    const auto typeSizeLowest = m_Types.find(type);
-    if(typeSizeLowest != m_Types.cend() && !typeSizeLowest->second.second.empty()) {
-        return typeSizeLowest->second.second;
-    }
-    // Otherwise, give warning and return empty string
-    else {
-        LOGW_CODE_GEN << "Unable to get lowest value for type '" << type << "'";
-        return "";
-    }
 }
 //--------------------------------------------------------------------------
 bool BackendBase::areSixtyFourBitSynapseIndicesRequired(const SynapseGroupMergedBase &sg) const
@@ -269,4 +210,4 @@ std::vector BackendBase::genInitReductionTargets(C
         index);
     });
-} // namespace GeNN::CodeGenerator
\ No newline at end of file
+} // namespace GeNN::CodeGenerator
diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc
index 35df3ca9d2..d02d5287c0 100644
--- a/src/genn/genn/code_generator/backendSIMT.cc
+++ b/src/genn/genn/code_generator/backendSIMT.cc
@@ -1704,21 +1704,6 @@ void BackendSIMT::genInitializeSparseKernel(CodeStream &os, const Substitutions
     });
 }
 //--------------------------------------------------------------------------
-void BackendSIMT::addDeviceType(const std::string &type, size_t size, const std::string &maxValue)
-{
-    addType(type, size, maxValue);
-    m_DeviceTypes.emplace(type);
-}
-//--------------------------------------------------------------------------
-bool BackendSIMT::isDeviceType(const std::string &type) const
-{
-    // Get underlying type
-    const std::string underlyingType = Utils::isTypePointer(type) ?
Utils::getUnderlyingType(type) : type; - - // Return true if it is in device types set - return (m_DeviceTypes.find(underlyingType) != m_DeviceTypes.cend()); -} -//-------------------------------------------------------------------------- size_t BackendSIMT::padKernelSize(size_t size, Kernel kernel) const { return padSize(size, getKernelBlockSize(kernel)); diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 9551255875..047d933262 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -143,8 +143,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t // If this backend requires per-population RNGs and this group requires one if(backend.isPopulationRNGRequired() && getArchetype().isRowSimRNGRequired()){ - assert(false); - //addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rowRNG"); + addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rowRNG"); } // Add variables to struct diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index ecc33d2289..0822bd5bed 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -238,8 +238,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr if(backend.isPopulationRNGRequired() && getArchetype().isSimRNGRequired() && (!init || backend.isPopulationRNGInitialisedOnDevice())) { - assert(false); - //addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); + addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); } // Loop through variables diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 8ec6db0517..7016cc4db2 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -1088,8 +1088,7 @@ CustomConnectivityUpdatePreInitGroupMerged::CustomConnectivityUpdatePreInitGroup // If this backend initialises population RNGs on device and this group requires one for simulation if(backend.isPopulationRNGRequired() && getArchetype().isRowSimRNGRequired() && backend.isPopulationRNGInitialisedOnDevice()) { - assert(false); - //addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rowRNG"); + addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rowRNG"); } } //---------------------------------------------------------------------------- From bcacfa76c21905447b999072b644f86228a1a82f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 16 Jan 2023 18:13:52 +0000 Subject: [PATCH 057/725] commenting --- include/genn/genn/type.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index d1b5a96ee9..c311638901 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -103,21 +103,29 @@ class Base //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ + //! Get the (unqualified) name of this type virtual std::string getName() const = 0; + + //! 
Return new version of this type with specified qualifiers
     virtual Base *getQualifiedType(Qualifier qualifiers) const = 0;
+
+    //! Get size of this type in bytes
     virtual size_t getSizeBytes() const = 0;
 
     //------------------------------------------------------------------------
     // Public API
     //------------------------------------------------------------------------
+    //! Return a pointer to this type, optionally with specified qualifiers
     const Base *getPointerType(Qualifier qualifiers = Qualifier{0}) const;
 
+    //! Does this type have a given qualifier?
     bool hasQualifier(Qualifier qualifier) const{ return (m_Qualifiers & qualifier); };
 
 private:
     //------------------------------------------------------------------------
     // Members
     //------------------------------------------------------------------------
+    //! Bitfield of qualifiers
     const Qualifier m_Qualifiers;
 };
 
@@ -188,7 +196,6 @@ class Numeric : public NumericBase
 {
     //------------------------------------------------------------------------
     // Base virtuals
     //------------------------------------------------------------------------
-    virtual size_t getTypeHash() const final { return typeid(T).hash_code(); }
     virtual size_t getSizeBytes() const final{ return sizeof(T); }
 
     //------------------------------------------------------------------------

From ad0e5699b2902f4a917ad71f2d297844bb060d38 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 17 Jan 2023 12:22:10 +0000
Subject: [PATCH 058/725] given up fighting Windows' stupid CONST macro - just
 renamed qualifier to CONSTANT

---
 .../genn/code_generator/groupMergedTypeEnvironment.h | 12 ++++++------
 include/genn/genn/type.h                             |  7 +------
 src/genn/genn/transpiler/parser.cc                   |  8 ++++----
 src/genn/genn/transpiler/prettyPrinter.cc            |  4 ++--
 src/genn/genn/transpiler/typeChecker.cc              | 12 ++++++------
 5 files changed, 19 insertions(+), 24 deletions(-)

diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h
index b2c158c894..fb55907a19 100644
--- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h
+++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h
@@ -132,7 +132,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa
     void definePointerField(const Type::NumericBase *type, const std::string &name,const std::string &prefix, VarAccessMode access)
     {
-        const auto *qualifiedType = (access & VarAccessModeAttribute::READ_ONLY) ?
type->getQualifiedType(Type::Qualifier::CONSTANT) : type; defineField(qualifiedType, name, type->getPointerType(), name, [prefix](const auto &g, size_t) { return prefix + g.getName(); }); } @@ -144,7 +144,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa // Loop through params for(const auto &p : paramNames) { if (std::invoke(isHeterogeneous, m_GroupMerged, p)) { - defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONST), p + suffix, + defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONSTANT), p + suffix, m_ScalarType, p + suffix, [p, getParamValues](const auto &g, size_t) { @@ -154,7 +154,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa } // Otherwise, just add a const-qualified scalar to the type environment else { - defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONST), p + suffix); + defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONSTANT), p + suffix); } } } @@ -166,7 +166,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa // Loop through derived params for(const auto &d : derivedParams) { if (std::invoke(isHeterogeneous, m_GroupMerged, d.name)) { - defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONST), d.name + suffix, + defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONSTANT), d.name + suffix, m_ScalarType, d.name + suffix, [d, getDerivedParamValues](const auto &g, size_t) { @@ -175,7 +175,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa }); } else { - defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONST), d.name + suffix); + defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONSTANT), d.name + suffix); } } } @@ -196,7 +196,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa const auto *type = Type::parseNumeric(v.type, m_ScalarType); // If variable access is read-only, qualify type with const - const auto *qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? type->getQualifiedType(Type::Qualifier::CONST) : type; + const auto *qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? 
type->getQualifiedType(Type::Qualifier::CONSTANT) : type; defineField(qualifiedType, v.name, type->getPointerType(), v.name, [arrayPrefix, getVarRefFn, v](const auto &g, size_t) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index c311638901..4528c90f53 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -57,11 +57,6 @@ #define IMPLEMENT_TYPE(TYPE) TYPE *TYPE::s_Instance = NULL #define IMPLEMENT_NUMERIC_TYPE(TYPE) IMPLEMENT_TYPE(TYPE) -// **YUCK** on Windows undefine CONST macro (some part of wincrypt) -#ifdef _WIN32 - #undef CONST -#endif - //---------------------------------------------------------------------------- // GeNN::Type::TypeTraits //---------------------------------------------------------------------------- @@ -78,7 +73,7 @@ struct TypeTraits //---------------------------------------------------------------------------- enum class Qualifier : unsigned int { - CONST = (1 << 0) + CONSTANT = (1 << 0) }; inline bool operator & (Qualifier a, Qualifier b) diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 3bc11d5343..6ba80ce2b9 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -201,8 +201,8 @@ const GeNN::Type::Base *parseDeclarationSpecifiers(ParserState &parserState) // Otherwise, if type is a qualifier else if(parserState.previous().type == Token::Type::TYPE_QUALIFIER) { // Add qualifier lexeme to correct list - std::set &typeQualifiers = pointerTypeQualifiers.empty() ? typeQualifiers : pointerTypeQualifiers.back(); - if(!typeQualifiers.insert(parserState.previous().lexeme).second) { + std::set &qualifiers = pointerTypeQualifiers.empty() ? typeQualifiers : pointerTypeQualifiers.back(); + if(!qualifiers.insert(parserState.previous().lexeme).second) { parserState.error(parserState.previous(), "duplicate type qualifier"); } } @@ -222,13 +222,13 @@ const GeNN::Type::Base *parseDeclarationSpecifiers(ParserState &parserState) // If there are any type qualifiers, add const // **THINK** this relies of const being only qualifier if(!typeQualifiers.empty()) { - type = type->getQualifiedType(Qualifier::CONST); + type = type->getQualifiedType(Qualifier::CONSTANT); } // Loop through levels of pointer indirection // **THINK** this relies of const being only qualifier for(const auto &p : pointerTypeQualifiers) { - type = type->getPointerType(p.empty() ? Qualifier{0} : Qualifier::CONST); + type = type->getPointerType(p.empty() ? 
Qualifier{0} : Qualifier::CONSTANT); } return type; } diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index d9761a9ef8..3faa576e20 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -270,7 +270,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const auto *pointerType = dynamic_cast(type); if(pointerType) { // If pointer has const qualifier, add const - if(pointerType->hasQualifier(GeNN::Type::Qualifier::CONST)) { + if(pointerType->hasQualifier(GeNN::Type::Qualifier::CONSTANT)) { tokens.push_back("const"); } @@ -286,7 +286,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor tokens.push_back(type->getName()); - if(pointerType->hasQualifier(GeNN::Type::Qualifier::CONST)) { + if(pointerType->hasQualifier(GeNN::Type::Qualifier::CONSTANT)) { tokens.push_back("const"); } break; diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index c0749bef78..bb17a711b4 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -282,7 +282,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const auto rightType = evaluateType(cast.getExpression()); // If const is being removed - if (rightType->hasQualifier(Type::Qualifier::CONST) && !cast.getType()->hasQualifier(Type::Qualifier::CONST)) { + if (rightType->hasQualifier(Type::Qualifier::CONSTANT) && !cast.getType()->hasQualifier(Type::Qualifier::CONSTANT)) { m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } @@ -316,8 +316,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor if (trueNumericType && falseNumericType) { // **TODO** check behaviour m_Type = Type::getCommonType(trueNumericType, falseNumericType); - if(trueType->hasQualifier(Type::Qualifier::CONST) || falseType->hasQualifier(Type::Qualifier::CONST)) { - m_Type = m_Type->getQualifiedType(Type::Qualifier::CONST); + if(trueType->hasQualifier(Type::Qualifier::CONSTANT) || falseType->hasQualifier(Type::Qualifier::CONSTANT)) { + m_Type = m_Type->getQualifiedType(Type::Qualifier::CONSTANT); } } else { @@ -585,7 +585,7 @@ const Type::Base *EnvironmentBase::assign(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler, bool initializer) const { // If existing type is a const qualified and isn't being initialized, give error - if(!initializer && existingType->hasQualifier(Type::Qualifier::CONST)) { + if(!initializer && existingType->hasQualifier(Type::Qualifier::CONSTANT)) { errorHandler.error(name, "Assignment of read-only variable"); throw TypeCheckError(); } @@ -599,7 +599,7 @@ const Type::Base *EnvironmentBase::assign(const Token &name, Token::Type op, // If we're initialising a pointer with another pointer if (pointerAssignedType && pointerExistingType) { // If we're trying to assign a pointer to a const value to a pointer - if (assignedType->hasQualifier(Type::Qualifier::CONST) && !existingType->hasQualifier(Type::Qualifier::CONST)) { + if (assignedType->hasQualifier(Type::Qualifier::CONSTANT) && !existingType->hasQualifier(Type::Qualifier::CONSTANT)) { errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName() + "' and '" + pointerAssignedType->getName()); throw TypeCheckError(); } @@ -665,7 +665,7 @@ const Type::Base *EnvironmentBase::incDec(const Token &name, Token::Type, const 
Type::Base *existingType, ErrorHandlerBase &errorHandler) const { // If existing type has a constant qualifier, give errors - if(existingType->hasQualifier(Type::Qualifier::CONST)) { + if(existingType->hasQualifier(Type::Qualifier::CONSTANT)) { errorHandler.error(name, "Increment/decrement of read-only variable"); throw TypeCheckError(); } From ad0e5699b2902f4a917ad71f2d297844bb060d38 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 17 Jan 2023 12:50:51 +0000 Subject: [PATCH 059/725] Numeric::UnderlyingType not required --- include/genn/genn/type.h | 5 ----- 1 file changed, 5 deletions(-) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 4528c90f53..0f278ea3d1 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -182,11 +182,6 @@ class Numeric : public NumericBase { public: Numeric(Qualifier qualifiers = Qualifier{0}) : NumericBase(qualifiers){} - - //------------------------------------------------------------------------ - // Typedefines - //------------------------------------------------------------------------ - typedef T UnderlyingType; //------------------------------------------------------------------------ // Base virtuals From 2be4fa42475076ef91a26d123f4aaaa96936a4ae Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 17 Jan 2023 13:02:06 +0000 Subject: [PATCH 060/725] scanner no longer evaluates literal tokens --- include/genn/genn/transpiler/scanner.h | 3 +- include/genn/genn/transpiler/token.h | 7 +- src/genn/genn/transpiler/scanner.cc | 89 ++++++++------------------ 3 files changed, 29 insertions(+), 70 deletions(-) diff --git a/include/genn/genn/transpiler/scanner.h b/include/genn/genn/transpiler/scanner.h index 976c593b64..6eb5d6c1dd 100644 --- a/include/genn/genn/transpiler/scanner.h +++ b/include/genn/genn/transpiler/scanner.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -25,6 +26,6 @@ class ErrorHandlerBase; //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Scanner { -std::vector scanSource(const std::string_view &source, const Type::NumericBase *scalarType, ErrorHandlerBase &errorHandler); +std::vector scanSource(const std::string_view &source, const std::unordered_set &typedefNames, ErrorHandlerBase &errorHandler); } // namespace Scanner diff --git a/include/genn/genn/transpiler/token.h b/include/genn/genn/transpiler/token.h index 433a28dd9a..3a67f38d3a 100644 --- a/include/genn/genn/transpiler/token.h +++ b/include/genn/genn/transpiler/token.h @@ -20,8 +20,6 @@ namespace GeNN::Transpiler { struct Token { - typedef std::variant LiteralValue; - enum class Type { // Single-character tokens @@ -54,14 +52,13 @@ struct Token END_OF_FILE, }; - Token(Type type, std::string_view lexeme, size_t line, LiteralValue literalValue = LiteralValue()) - : type(type), lexeme(lexeme), line(line), literalValue(literalValue) + Token(Type type, std::string_view lexeme, size_t line) + : type(type), lexeme(lexeme), line(line) { } const Type type; const std::string_view lexeme; const size_t line; - const LiteralValue literalValue; }; } // namespace GeNN::Transpiler diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc index 5afe181268..8acc0285f4 100644 --- a/src/genn/genn/transpiler/scanner.cc +++ b/src/genn/genn/transpiler/scanner.cc @@ -11,9 +11,6 @@ #include #include -// GeNN includes -#include "type.h" - // Transpiler includes #include "transpiler/errorHandler.h" #include "transpiler/transpilerUtils.h" @@ -48,7 +45,6 
@@ const std::unordered_map keywords{ {"long", Token::Type::TYPE_SPECIFIER}, {"float", Token::Type::TYPE_SPECIFIER}, {"double", Token::Type::TYPE_SPECIFIER}, - {"scalar", Token::Type::TYPE_SPECIFIER}, {"signed", Token::Type::TYPE_SPECIFIER}, {"unsigned", Token::Type::TYPE_SPECIFIER}, {"uint8_t", Token::Type::TYPE_SPECIFIER}, @@ -58,11 +54,7 @@ const std::unordered_map keywords{ {"uint32_t", Token::Type::TYPE_SPECIFIER}, {"int32_t", Token::Type::TYPE_SPECIFIER}, {"bool", Token::Type::TYPE_SPECIFIER}}; -//--------------------------------------------------------------------------- -const std::map, std::function> integerLiteralSuffixParsers{ - {{}, [](std::string_view input, int base) { return Utils::toCharsThrow(input, base); }}, - {{'U'}, [](std::string_view input, int base) { return Utils::toCharsThrow(input, base); }}, -}; + //--------------------------------------------------------------------------- // ScanState //--------------------------------------------------------------------------- @@ -70,8 +62,8 @@ const std::map, std::function &typedefNames, ErrorHandlerBase &errorHandler) + : m_Start(0), m_Current(0), m_Line(1), m_Source(source), m_TypedefNames(typedefNames), m_ErrorHandler(errorHandler) {} //--------------------------------------------------------------------------- @@ -134,9 +126,8 @@ class ScanState m_ErrorHandler.error(getLine(), message); } - const Type::NumericBase *getScalarType() const - { - return m_ScalarType; + bool isTypedefIdentifier(std::string_view lexeme) { + return (m_TypedefNames.find(std::string{lexeme}) != m_TypedefNames.cend()); } private: @@ -148,7 +139,7 @@ class ScanState size_t m_Line; const std::string_view m_Source; - const Type::NumericBase *m_ScalarType; + const std::unordered_set m_TypedefNames; ErrorHandlerBase &m_ErrorHandler; }; @@ -158,19 +149,17 @@ bool isodigit(char c) } //--------------------------------------------------------------------------- -void emplaceToken(std::vector &tokens, Token::Type type, const ScanState &scanState, Token::LiteralValue literalValue = Token::LiteralValue()) +void emplaceToken(std::vector &tokens, Token::Type type, const ScanState &scanState) { - tokens.emplace_back(type, scanState.getLexeme(), scanState.getLine(), literalValue); + tokens.emplace_back(type, scanState.getLexeme(), scanState.getLine()); } //--------------------------------------------------------------------------- -std::set scanIntegerSuffix(ScanState &scanState) +void scanIntegerSuffix(ScanState &scanState) { // Read suffix - std::set suffix; while(std::toupper(scanState.peek()) == 'U' || std::toupper(scanState.peek()) == 'L') { - suffix.insert(std::toupper(scanState.advance())); + scanState.advance(); } - return suffix; } //--------------------------------------------------------------------------- void scanNumber(char c, ScanState &scanState, std::vector &tokens) @@ -193,10 +182,8 @@ void scanNumber(char c, ScanState &scanState, std::vector &tokens) } // Add integer token - // **NOTE** skip 0x prefix - const auto suffix = scanIntegerSuffix(scanState); - emplaceToken(tokens, Token::Type::NUMBER, scanState, - integerLiteralSuffixParsers.at(suffix)(scanState.getLexeme().substr(2), 16)); + scanIntegerSuffix(scanState); + emplaceToken(tokens, Token::Type::NUMBER, scanState); } // Otherwise, if this is an octal integer else if(c == '0' && isodigit(scanState.peek())){ @@ -232,50 +219,20 @@ void scanNumber(char c, ScanState &scanState, std::vector &tokens) } } - // If literal has floating point suffix - if(std::tolower(scanState.peek()) == 'f') { - // 
Add single-precision token - emplaceToken(tokens, Token::Type::NUMBER, scanState, - Utils::toCharsThrow(scanState.getLexeme())); - - // Advance - // **NOTE** we do this AFTER parsing float as std::to_chars doesn't deal with suffixes - scanState.advance(); - } - // Otherwise, if literal has double precision suffix - // **NOTE** this is a GeNN extension not standard C - else if(std::tolower(scanState.peek()) == 'd') { - emplaceToken(tokens, Token::Type::NUMBER, scanState, - Utils::toCharsThrow(scanState.getLexeme())); - - // Advance - // **NOTE** we do this AFTER parsing float as std::to_chars doesn't deal with suffixes + // Read possible floating point suffix + // **NOTE** 'd' is a GeNN extension not standard C + if (std::tolower(scanState.peek()) == 'f' || std::tolower(scanState.peek()) == 'd') { scanState.advance(); } - // Otherwise, this is a scalar literal - else { - // If the scalar type is float, add single-precision token - if(scanState.getScalarType()->getName() == "float") { - emplaceToken(tokens, Token::Type::NUMBER, scanState, - Utils::toCharsThrow(scanState.getLexeme())); - } - // Otherwise, add double-precision token - else if(scanState.getScalarType()->getName() == "double") { - emplaceToken(tokens, Token::Type::NUMBER, scanState, - Utils::toCharsThrow(scanState.getLexeme())); - } - else { - assert(false); - } - } + // Emplace token + emplaceToken(tokens, Token::Type::NUMBER, scanState); } // Otherwise, number is integer else { // Add integer token - const auto suffix = scanIntegerSuffix(scanState); - emplaceToken(tokens, Token::Type::NUMBER, scanState, - integerLiteralSuffixParsers.at(suffix)(scanState.getLexeme(), 10)); + scanIntegerSuffix(scanState); + emplaceToken(tokens, Token::Type::NUMBER, scanState); } } } @@ -292,6 +249,10 @@ void scanIdentifier(ScanState &scanState, std::vector &tokens) if(k != keywords.cend()) { emplaceToken(tokens, k->second, scanState); } + // Otherwise, if identifier is typedef, add type specifier token + else if (scanState.isTypedefIdentifier(scanState.getLexeme())) { + emplaceToken(tokens, Token::Type::TYPE_SPECIFIER, scanState); + } // Otherwise, add identifier token else { emplaceToken(tokens, Token::Type::IDENTIFIER, scanState); @@ -466,11 +427,11 @@ void scanToken(ScanState &scanState, std::vector &tokens) //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Scanner { -std::vector scanSource(const std::string_view &source, const Type::NumericBase *scalarType, ErrorHandlerBase &errorHandler) +std::vector scanSource(const std::string_view &source, const std::unordered_set &typedefNames, ErrorHandlerBase &errorHandler) { std::vector tokens; - ScanState scanState(source, scalarType, errorHandler); + ScanState scanState(source, typedefNames, errorHandler); // Scan tokens while(!scanState.isAtEnd()) { From 770ad33ba0a95f352c1888a15f266a04c5ce2b14 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 17 Jan 2023 13:43:27 +0000 Subject: [PATCH 061/725] parser leaves token string contents alone in literal token --- include/genn/genn/transpiler/expression.h | 6 +++--- include/genn/genn/transpiler/parser.h | 7 ++++--- src/genn/genn/transpiler/parser.cc | 19 ++++++------------- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h index 9963240520..ecdc1d48d6 100644 --- a/include/genn/genn/transpiler/expression.h +++ b/include/genn/genn/transpiler/expression.h @@ -186,16 +186,16 @@ class Grouping : 
public Base class Literal : public Base { public: - Literal(Token::LiteralValue value) + Literal(std::string_view value) : m_Value(value) {} virtual void accept(Visitor &visitor) const final; - Token::LiteralValue getValue() const { return m_Value; } + std::string_view getValue() const { return m_Value; } private: - const Token::LiteralValue m_Value; + const std::string_view m_Value; }; //--------------------------------------------------------------------------- diff --git a/include/genn/genn/transpiler/parser.h b/include/genn/genn/transpiler/parser.h index 6068ec65ca..44d2119c92 100644 --- a/include/genn/genn/transpiler/parser.h +++ b/include/genn/genn/transpiler/parser.h @@ -3,6 +3,7 @@ // Standard C++ includes #include #include +#include #include // Transpiler includes @@ -22,16 +23,16 @@ class ErrorHandlerBase; namespace GeNN::Transpiler::Parser { //! Parse expression from tokens -Expression::ExpressionPtr parseExpression(const std::vector &tokens, const GeNN::Type::NumericBase *scalarType, +Expression::ExpressionPtr parseExpression(const std::vector &tokens, const std::unordered_set &typedefNames, ErrorHandlerBase &errorHandler); //! Parse block item list from tokens /*! Block item lists are function body scope list of statements */ -Statement::StatementList parseBlockItemList(const std::vector &tokens, const GeNN::Type::NumericBase *scalarType, +Statement::StatementList parseBlockItemList(const std::vector &tokens, const std::unordered_set &typedefNames, ErrorHandlerBase &errorHandler); //! Parse type from tokens const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, - const GeNN::Type::NumericBase *scalarType, ErrorHandlerBase &errorHandler); + const std::unordered_set &typedefNames, ErrorHandlerBase &errorHandler); } // MiniParse::MiniParse diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 6ba80ce2b9..216c17fa22 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -38,8 +38,8 @@ class ParseError class ParserState { public: - ParserState(const std::vector &tokens, const GeNN::Type::NumericBase *scalarType, ErrorHandlerBase &errorHandler) - : m_Current(0), m_Tokens(tokens), m_ScalarType(scalarType), m_ErrorHandler(errorHandler) + ParserState(const std::vector &tokens, const std::unordered_set &typedefNames, ErrorHandlerBase &errorHandler) + : m_Current(0), m_Tokens(tokens), m_TypedefNames(typedefNames), m_ErrorHandler(errorHandler) {} //--------------------------------------------------------------------------- @@ -128,8 +128,7 @@ class ParserState bool isAtEnd() const { return (peek().type == Token::Type::END_OF_FILE); } - const GeNN::Type::NumericBase *getScalarType() const{ return m_ScalarType; } - + private: //--------------------------------------------------------------------------- // Members @@ -137,7 +136,7 @@ class ParserState size_t m_Current; const std::vector &m_Tokens; - const GeNN::Type::NumericBase *m_ScalarType; + const std::unordered_set m_TypedefNames; ErrorHandlerBase &m_ErrorHandler; }; @@ -239,14 +238,8 @@ Expression::ExpressionPtr parsePrimary(ParserState &parserState) // identifier // constant // "(" expression ")" - if(parserState.match(Token::Type::FALSE)) { - return std::make_unique(false); - } - else if(parserState.match(Token::Type::TRUE)) { - return std::make_unique(true); - } - else if(parserState.match(Token::Type::NUMBER)) { - return std::make_unique(parserState.previous().literalValue); + if (parserState.match({Token::Type::FALSE, 
Token::Type::TRUE, Token::Type::NUMBER})) { + return std::make_unique(parserState.previous().lexeme); } else if(parserState.match(Token::Type::IDENTIFIER)) { return std::make_unique(parserState.previous()); From fc08ec3011962b5916805774fc4e25411ce7f970 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 17 Jan 2023 13:44:06 +0000 Subject: [PATCH 062/725] typedef type and added type context argument to all type virtuals --- include/genn/genn/type.h | 109 ++++++++++++++++++++++++++------------ src/genn/genn/type.cc | 110 +++++++++++++++++++++++++++++++-------- 2 files changed, 164 insertions(+), 55 deletions(-) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 0f278ea3d1..e2a566ee15 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -11,6 +11,8 @@ #include #include #include +#include +#include #include // GeNN includes @@ -37,7 +39,7 @@ { \ DECLARE_TYPE(TYPE) \ TYPE(Qualifier qualifiers = Qualifier{0}) : Numeric(qualifiers){} \ - virtual std::string getName() const final{ return #UNDERLYING_TYPE; } \ + virtual std::string getName(const TypeContext&) const final{ return #UNDERLYING_TYPE; } \ virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new TYPE(qualifiers); } \ }; \ template<> \ @@ -68,6 +70,8 @@ struct TypeTraits { }; +typedef std::unordered_map TypeContext; + //---------------------------------------------------------------------------- // GeNN::Type::Qualifier //---------------------------------------------------------------------------- @@ -99,14 +103,14 @@ class Base // Declared virtuals //------------------------------------------------------------------------ //! Get the (unqualified) name of this type - virtual std::string getName() const = 0; + virtual std::string getName(const TypeContext &context) const = 0; + //! Get size of this type in bytes + virtual size_t getSizeBytes(const TypeContext &context) const = 0; + //! Return new version of this type with specified qualifiers virtual Base *getQualifiedType(Qualifier qualifiers) const = 0; - //! 
Get size of this type in bytes
-    virtual size_t getSizeBytes() const = 0;
-
     //------------------------------------------------------------------------
     // Public API
     //------------------------------------------------------------------------
@@ -139,10 +143,10 @@ class Pointer : public Base
     //------------------------------------------------------------------------
     // Base virtuals
     //------------------------------------------------------------------------
-    virtual std::string getName() const{ return getValueType()->getName() + "*";}
+    virtual std::string getName(const TypeContext &context) const{ return getValueType()->getName(context) + "*";}
+    virtual size_t getSizeBytes(const TypeContext&) const final{ return sizeof(char*); }
     virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new Pointer(m_ValueType, qualifiers); }
-    virtual size_t getSizeBytes() const final{ return sizeof(char*); }
-
+
     //------------------------------------------------------------------------
     // Public API
     //------------------------------------------------------------------------
@@ -166,12 +170,12 @@ class NumericBase : public Base
     //------------------------------------------------------------------------
     // Declared virtuals
     //------------------------------------------------------------------------
-    virtual int getRank() const = 0;
-    virtual double getMin() const = 0;
-    virtual double getMax() const = 0;
-    virtual double getLowest() const = 0;
-    virtual bool isSigned() const = 0;
-    virtual bool isIntegral() const = 0;
+    virtual int getRank(const TypeContext&) const = 0;
+    virtual double getMin(const TypeContext&) const = 0;
+    virtual double getMax(const TypeContext&) const = 0;
+    virtual double getLowest(const TypeContext&) const = 0;
+    virtual bool isSigned(const TypeContext&) const = 0;
+    virtual bool isIntegral(const TypeContext&) const = 0;
 };
 
 //----------------------------------------------------------------------------
@@ -186,17 +190,56 @@ class Numeric : public NumericBase
     //------------------------------------------------------------------------
     // Base virtuals
    //------------------------------------------------------------------------
-    virtual size_t getSizeBytes() const final{ return sizeof(T); }
+    virtual size_t getSizeBytes(const TypeContext&) const final{ return sizeof(T); }
 
     //------------------------------------------------------------------------
     // NumericBase virtuals
     //------------------------------------------------------------------------
-    virtual int getRank() const final { return Rank; }
-    virtual double getMin() const final { return std::numeric_limits<T>::min(); }
-    virtual double getMax() const final { return std::numeric_limits<T>::max(); }
-    virtual double getLowest() const final { return std::numeric_limits<T>::lowest(); }
-    virtual bool isSigned() const final { return std::is_signed<T>::value; }
-    virtual bool isIntegral() const final { return std::is_integral<T>::value; }
+    virtual int getRank(const TypeContext&) const final { return Rank; }
+    virtual double getMin(const TypeContext&) const final { return std::numeric_limits<T>::min(); }
+    virtual double getMax(const TypeContext&) const final { return std::numeric_limits<T>::max(); }
+    virtual double getLowest(const TypeContext&) const final { return std::numeric_limits<T>::lowest(); }
+    virtual bool isSigned(const TypeContext&) const final { return std::is_signed<T>::value; }
+    virtual bool isIntegral(const TypeContext&) const final { return std::is_integral<T>::value; }
+};
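From here on, every type virtual threads a TypeContext through so that a named typedef such as "scalar" can be resolved lazily, per model, at query time; the NumericTypedef class introduced just below is the consumer of that map. A minimal sketch of the lookup it performs, with simplified Base and Float stand-ins (only the error path mirrors the real getNumeric()):

    #include <iostream>
    #include <stdexcept>
    #include <string>
    #include <unordered_map>

    struct Base
    {
        virtual ~Base() = default;
        virtual std::string getName() const = 0;
    };

    struct Float : Base
    {
        virtual std::string getName() const final { return "float"; }
    };

    using TypeContext = std::unordered_map<std::string, const Base*>;

    // Mirrors NumericTypedef::getNumeric(): the typedef stores only its
    // name and defers to whatever context the current model supplies
    const Base *resolve(const TypeContext &context, const std::string &name)
    {
        const auto t = context.find(name);
        if (t == context.cend()) {
            throw std::runtime_error("No context for typedef '" + name + "'");
        }
        return t->second;
    }

    int main()
    {
        const Float floatType;
        const TypeContext context{{"scalar", &floatType}};

        // "scalar" resolves to float for this model; a double-precision
        // model would simply supply a different context
        std::cout << resolve(context, "scalar")->getName() << std::endl;
        return 0;
    }

The design pay-off is that the type objects themselves stay immutable singletons: precision is a property of the model (carried in the context), not of the type.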
+//---------------------------------------------------------------------------- +// GeNN::Type::NumericTypedef +//---------------------------------------------------------------------------- +class NumericTypedef : public NumericBase +{ +public: + NumericTypedef(const std::string &name, Qualifier qualifiers = Qualifier{0}) + : m_Name(name), NumericBase(qualifiers){} + + //------------------------------------------------------------------------ + // Base virtuals + //------------------------------------------------------------------------ + virtual std::string getName(const TypeContext &context) const final; + virtual size_t getSizeBytes(const TypeContext &context) const final; + + virtual Base *getQualifiedType(Qualifier qualifiers) const final; + + //------------------------------------------------------------------------ + // NumericBase virtuals + //------------------------------------------------------------------------ + virtual int getRank(const TypeContext &context) const final; + virtual double getMin(const TypeContext &context) const final; + virtual double getMax(const TypeContext &context) const final; + virtual double getLowest(const TypeContext &context) const final; + virtual bool isSigned(const TypeContext &context) const final; + virtual bool isIntegral(const TypeContext &context) const final; + +private: + //------------------------------------------------------------------------ + // Private methods + //------------------------------------------------------------------------ + const Type::NumericBase *getNumeric(const TypeContext &context) const; + + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + const std::string m_Name; }; //---------------------------------------------------------------------------- @@ -226,15 +269,15 @@ class ForeignFunction : public ForeignFunctionBase //------------------------------------------------------------------------ // Base virtuals //------------------------------------------------------------------------ - virtual std::string getName() const final + virtual std::string getName(const TypeContext &context) const final { - std::string typeName = getReturnType()->getName() + "("; - updateTypeName(typeName); + std::string typeName = getReturnType()->getName(context) + "("; + updateTypeName(context, typeName); typeName += ")"; return typeName; } - virtual size_t getSizeBytes() const final + virtual size_t getSizeBytes(const TypeContext&) const final { assert(false); return 0; @@ -262,15 +305,15 @@ class ForeignFunction : public ForeignFunctionBase //------------------------------------------------------------------------ template - static void updateTypeName(std::string &typeName) + static void updateTypeName(const TypeContext &context, std::string &typeName) { // Add argument typename to string - typeName += T::getInstance()->getName(); + typeName += T::getInstance()->getName(context); // If there are more arguments left in pack, add comma and recurse if constexpr (sizeof...(Args)) { typeName += ", "; - updateTypeName(typeName); + updateTypeName(context, typeName); } } @@ -310,17 +353,17 @@ DECLARE_FOREIGN_FUNCTION_TYPE(Exp, Double, Double); DECLARE_FOREIGN_FUNCTION_TYPE(Sqrt, Double, Double); //! Parse a numeric type -const NumericBase *parseNumeric(std::string_view typeString, const NumericBase *scalarType); +const NumericBase *parseNumeric(std::string_view typeString, const std::unordered_set &typedefNames); //! 
Parse a numeric pointer type
-const Pointer *parseNumericPtr(std::string_view typeString, const NumericBase *scalarType);
+const Pointer *parseNumericPtr(std::string_view typeString, const std::unordered_set &typedefNames);
 
 //! Look up numeric type based on set of type specifiers
-const NumericBase *getNumericType(const std::set &typeSpecifiers, const NumericBase *scalarType);
+const NumericBase *getNumericType(const std::set &typeSpecifiers, const std::unordered_set &typedefNames);
 
 //! Apply C type promotion rules to numeric type
-const NumericBase *getPromotedType(const NumericBase *type);
+const NumericBase *getPromotedType(const NumericBase *type, const TypeContext &context);
 
 //! Apply C rules to get common type between numeric types a and b
-const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b);
+const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b, const TypeContext &context);
 }   // namespace GeNN::Type
diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc
index 5a092a5a61..0764eca3a8 100644
--- a/src/genn/genn/type.cc
+++ b/src/genn/genn/type.cc
@@ -87,19 +87,83 @@ const Base *Base::getPointerType(Qualifier qualifiers) const
     return new Pointer(this, qualifiers);
 }
 
+//----------------------------------------------------------------------------
+// GeNN::Type::NumericTypedef
+//----------------------------------------------------------------------------
+std::string NumericTypedef::getName(const TypeContext &context) const
+{
+    return getNumeric(context)->getName(context);
+}
+//----------------------------------------------------------------------------
+size_t NumericTypedef::getSizeBytes(const TypeContext &context) const
+{
+    return getNumeric(context)->getSizeBytes(context);
+}
+//----------------------------------------------------------------------------
+Base *NumericTypedef::getQualifiedType(Qualifier qualifiers) const
+{
+    return new NumericTypedef(m_Name, qualifiers);
+}
+//----------------------------------------------------------------------------
+int NumericTypedef::getRank(const TypeContext &context) const
+{
+    return getNumeric(context)->getRank(context);
+}
+//----------------------------------------------------------------------------
+double NumericTypedef::getMin(const TypeContext &context) const
+{
+    return getNumeric(context)->getMin(context);
+}
+//----------------------------------------------------------------------------
+double NumericTypedef::getMax(const TypeContext &context) const
+{
+    return getNumeric(context)->getMax(context);
+}
+//----------------------------------------------------------------------------
+double NumericTypedef::getLowest(const TypeContext &context) const
+{
+    return getNumeric(context)->getLowest(context);
+}
+//----------------------------------------------------------------------------
+bool NumericTypedef::isSigned(const TypeContext &context) const
+{
+    return getNumeric(context)->isSigned(context);
+}
+//----------------------------------------------------------------------------
+bool NumericTypedef::isIntegral(const TypeContext &context) const
+{
+    return getNumeric(context)->isIntegral(context);
+}
+//----------------------------------------------------------------------------
+const Type::NumericBase *NumericTypedef::getNumeric(const TypeContext &context) const
+{
+    const auto t = context.find(m_Name);
+    if (t == context.cend()) {
+        throw std::runtime_error("No context for typedef '" + m_Name + "'");
+    }
+    else {
+        const NumericBase *numericType =
dynamic_cast<const NumericBase*>(t->second);
+        if (numericType) {
+            return numericType;
+        }
+        else {
+            throw std::runtime_error("Numeric typedef '" + m_Name + "' resolved to non-numeric type '" + t->second->getName(context) + "'");
+        }
+    }
+}
 //----------------------------------------------------------------------------
 // Free functions
 //----------------------------------------------------------------------------
-const NumericBase *parseNumeric(std::string_view typeString, const NumericBase *scalarType)
+const NumericBase *parseNumeric(std::string_view typeString, const std::unordered_set &typedefNames)
 {
     using namespace Transpiler;
 
     // Scan type
     SingleLineErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource(typeString, scalarType, errorHandler);
+    const auto tokens = Scanner::scanSource(typeString, typedefNames, errorHandler);
 
     // Parse type and cast to numeric
-    const auto *type = dynamic_cast<const NumericBase*>(Parser::parseType(tokens, false, scalarType,
+    const auto *type = dynamic_cast<const NumericBase*>(Parser::parseType(tokens, false, typedefNames,
                                                                           errorHandler));
 
     // If an error was encountered while scanning or parsing, throw exception
@@ -114,16 +178,16 @@ const NumericBase *parseNumeric(std::string_view typeString, const NumericBase *
     return type;
 }
 //----------------------------------------------------------------------------
-const Pointer *parseNumericPtr(std::string_view typeString, const NumericBase *scalarType)
+const Pointer *parseNumericPtr(std::string_view typeString, const std::unordered_set &typedefNames)
 {
     using namespace Transpiler;
 
     // Scan type
     SingleLineErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource(typeString, scalarType, errorHandler);
+    const auto tokens = Scanner::scanSource(typeString, typedefNames, errorHandler);
 
     // Parse type and cast to numeric pointer
-    const auto *type = dynamic_cast<const Pointer*>(Parser::parseType(tokens, true, scalarType, errorHandler));
+    const auto *type = dynamic_cast<const Pointer*>(Parser::parseType(tokens, true, typedefNames, errorHandler));
 
     // If an error was encountered while scanning or parsing, throw exception
     if (errorHandler.hasError()) {
@@ -153,12 +217,12 @@ const NumericBase *getNumericType(const std::set &typeSpecifie
     }
 }
 //----------------------------------------------------------------------------
-const NumericBase *getPromotedType(const NumericBase *type)
+const NumericBase *getPromotedType(const NumericBase *type, const TypeContext &context)
 {
     // If a small integer type is used in an expression, it is implicitly converted to int which is always signed.
// This is known as the integer promotions or the integer promotion rule // **NOTE** this is true because in our type system unsigned short is uint16 which can be represented in int32 - if(type->getRank() < Int32::getInstance()->getRank()) { + if(type->getRank(context) < Int32::getInstance()->getRank(context)) { return Int32::getInstance(); } else { @@ -166,46 +230,48 @@ const NumericBase *getPromotedType(const NumericBase *type) } } //---------------------------------------------------------------------------- -const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b) +const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b, const TypeContext &context) { // If either type is double, common type is double - const auto &aTypeName = a->getName(); - const auto &bTypeName = b->getName(); - if(aTypeName == Double::getInstance()->getName() || bTypeName == Double::getInstance()->getName()) { + const auto &aTypeName = a->getName(context); + const auto &bTypeName = b->getName(context); + if(aTypeName == Double::getInstance()->getName(context) || bTypeName == Double::getInstance()->getName(context)) { return Double::getInstance(); } // Otherwise, if either type is float, common type is float - if(aTypeName == Float::getInstance()->getName() || bTypeName == Float::getInstance()->getName()) { + if(aTypeName == Float::getInstance()->getName(context) || bTypeName == Float::getInstance()->getName(context)) { return Float::getInstance(); } // Otherwise, must be an integer type else { // Promote both numeric types - const auto *aPromoted = getPromotedType(a); - const auto *bPromoted = getPromotedType(b); + const auto *aPromoted = getPromotedType(a, context); + const auto *bPromoted = getPromotedType(b, context); // If both promoted operands have the same type, then no further conversion is needed. - if(aPromoted->getName() == bPromoted->getName()) { + if(aPromoted->getName(context) == bPromoted->getName(context)) { return aPromoted; } // Otherwise, if both promoted operands have signed integer numeric types or both have unsigned integer numeric types, // the operand with the type of lesser integer conversion rank is converted to the type of the operand with greater rank. - else if(aPromoted->isSigned() == bPromoted->isSigned()) { - return (aPromoted->getRank() > bPromoted->getRank()) ? aPromoted : bPromoted; + else if(aPromoted->isSigned(context) == bPromoted->isSigned(context)) { + return (aPromoted->getRank(context) > bPromoted->getRank(context)) ? aPromoted : bPromoted; } // Otherwise, if signedness of promoted operands differ else { - const auto *signedOp = aPromoted->isSigned() ? aPromoted : bPromoted; - const auto *unsignedOp = aPromoted->isSigned() ? bPromoted : aPromoted; + const auto *signedOp = aPromoted->isSigned(context) ? aPromoted : bPromoted; + const auto *unsignedOp = aPromoted->isSigned(context) ? bPromoted : aPromoted; // Otherwise, if the operand that has unsigned integer type has rank greater or equal to the rank of the type of the other operand, // then the operand with signed integer type is converted to the type of the operand with unsigned integer type. 
- if(unsignedOp->getRank() >= signedOp->getRank()) { + if(unsignedOp->getRank(context) >= signedOp->getRank(context)) { return unsignedOp; } // Otherwise, if the type of the operand with signed integer type can represent all of the values of the type of the operand with unsigned integer type, // then the operand with unsigned integer type is converted to the type of the operand with signed integer type. - else if(signedOp->getMin() <= unsignedOp->getMin() && signedOp->getMax() >= unsignedOp->getMax()) { + else if((signedOp->getMin(context) <= unsignedOp->getMin(context)) + && (signedOp->getMax(context) >= unsignedOp->getMax(context))) + { return signedOp; } // Otherwise, both operands are converted to the unsigned integer type corresponding to the type of the operand with signed integer type. From 333327ee6cae056e7505ada7e5600db6be050883 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 17 Jan 2023 17:30:03 +0000 Subject: [PATCH 063/725] WIP passing around type context --- include/genn/backends/cuda/backend.h | 2 +- include/genn/backends/opencl/backend.h | 2 +- .../backends/single_threaded_cpu/backend.h | 2 +- .../genn/genn/code_generator/backendBase.h | 8 +- .../genn/genn/code_generator/codeGenUtils.h | 6 +- .../customConnectivityUpdateGroupMerged.h | 22 ++-- .../code_generator/customUpdateGroupMerged.h | 49 +++++---- .../genn/genn/code_generator/groupMerged.h | 74 ++++++------- .../groupMergedTypeEnvironment.h | 6 +- .../genn/code_generator/initGroupMerged.h | 101 ++++++++++-------- .../genn/code_generator/modelSpecMerged.h | 77 ++++++------- .../code_generator/neuronUpdateGroupMerged.h | 9 +- .../code_generator/synapseUpdateGroupMerged.h | 38 ++++--- include/genn/genn/transpiler/parser.h | 10 +- include/genn/genn/type.h | 8 +- src/genn/backends/cuda/backend.cc | 4 +- src/genn/backends/opencl/backend.cc | 4 +- .../backends/single_threaded_cpu/backend.cc | 4 +- src/genn/genn/code_generator/codeGenUtils.cc | 18 ++-- .../customConnectivityUpdateGroupMerged.cc | 8 +- .../code_generator/customUpdateGroupMerged.cc | 4 +- .../genn/code_generator/generateRunner.cc | 52 ++++----- src/genn/genn/code_generator/groupMerged.cc | 28 ++--- .../genn/code_generator/initGroupMerged.cc | 2 +- .../genn/code_generator/modelSpecMerged.cc | 2 +- .../code_generator/neuronUpdateGroupMerged.cc | 6 +- src/genn/genn/customConnectivityUpdate.cc | 2 +- src/genn/genn/customUpdate.cc | 4 +- src/genn/genn/synapseGroup.cc | 10 +- src/genn/genn/transpiler/parser.cc | 20 ++-- src/genn/genn/type.cc | 26 ++--- 31 files changed, 313 insertions(+), 295 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index e8335bd09c..964ee8a77b 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -206,7 +206,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT const std::string &egpName) const override; //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const override; + virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const override; //! 
When generating merged structures what type to use for simulation RNGs virtual const Type::Base *getMergedGroupSimRNGType() const override; diff --git a/include/genn/backends/opencl/backend.h b/include/genn/backends/opencl/backend.h index 2b657fdad3..43f8b9e7be 100644 --- a/include/genn/backends/opencl/backend.h +++ b/include/genn/backends/opencl/backend.h @@ -169,7 +169,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT const std::string &egpName) const override; //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const override; + virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const override; //! When generating merged structures what type to use for simulation RNGs virtual const Type::Base *getMergedGroupSimRNGType() const; diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index b59cea525e..066dfbe3e2 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -78,7 +78,7 @@ class BACKEND_EXPORT Backend : public BackendBase const std::string &egpName) const override; //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const override; + virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const override; //! When generating merged structures what type to use for simulation RNGs virtual const Type::Base *getMergedGroupSimRNGType() const override; diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 51aad00465..11423f6785 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -18,6 +18,7 @@ #include "codeStream.h" #include "gennExport.h" #include "gennUtils.h" +#include "type.h" #include "varAccess.h" #include "variableMode.h" @@ -50,11 +51,6 @@ class SynapseConnectivityInitGroupMerged; class SynapseInitGroupMerged; class SynapseSparseInitGroupMerged; } - -namespace Type -{ -class Base; -} } //-------------------------------------------------------------------------- @@ -267,7 +263,7 @@ class GENN_EXPORT BackendBase const std::string &egpName) const = 0; //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const = 0; + virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const = 0; //! 
When generating merged structures what type to use for simulation RNGs virtual const Type::Base *getMergedGroupSimRNGType() const = 0; diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index fcec71cffa..0a01ff429d 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -12,6 +12,7 @@ #include "gennExport.h" #include "gennUtils.h" #include "neuronGroupInternal.h" +#include "type.h" #include "variableMode.h" // GeNN code generator includes @@ -77,12 +78,13 @@ GENN_EXPORT std::string ensureFtype(const std::string &oldcode, const std::strin //-------------------------------------------------------------------------- //! \brief Get the initial value to start reduction operations from //-------------------------------------------------------------------------- -GENN_EXPORT std::string getReductionInitialValue(const BackendBase &backend, VarAccessMode access, const std::string &type); +GENN_EXPORT std::string getReductionInitialValue(VarAccessMode access, const Type::NumericBase *type, const Type::TypeContext &context); //-------------------------------------------------------------------------- //! \brief Generate a reduction operation to reduce value into reduction //-------------------------------------------------------------------------- -GENN_EXPORT std::string getReductionOperation(const std::string &reduction, const std::string &value, VarAccessMode access, const std::string &type); +GENN_EXPORT std::string getReductionOperation(const std::string &reduction, const std::string &value, VarAccessMode access, + const Type::NumericBase *type, const Type::TypeContext &context); //-------------------------------------------------------------------------- /*! 
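The reworked reduction helpers above now take the numeric type itself so the initial value can be derived from it: sum reductions start at zero, max reductions start at the lowest representable value. A self-contained sketch of those semantics (a hypothetical helper, not the generated code):

#include <algorithm>
#include <cstddef>
#include <limits>

template<typename T>
T reduce(const T *values, std::size_t count, bool isSum)
{
    // Sums start from zero; maxima start from the lowest representable
    // value so the first real element always replaces it
    T result = isSum ? T{0} : std::numeric_limits<T>::lowest();
    for(std::size_t i = 0; i < count; i++) {
        result = isSum ? T(result + values[i]) : std::max(result, values[i]);
    }
    return result;
}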
\brief This function checks for unknown variable definitions and returns a gennError if any are found diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index d66e26abcd..3474e10331 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -40,12 +40,13 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public CustomConnectivit //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, - runnerVarDecl, runnerMergedStructAlloc, name); + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + runnerVarDecl, runnerMergedStructAlloc, name); } void generateUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; @@ -79,12 +80,13 @@ class GENN_EXPORT CustomConnectivityHostUpdateGroupMerged : public CustomConnect //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, - runnerVarDecl, runnerMergedStructAlloc, name, true); + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + runnerVarDecl, runnerMergedStructAlloc, name, true); } void generateUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged) const; diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index 66ed67cf6e..b57a0341e6 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -23,11 +23,12 @@ class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged const CustomUpdateModels::Base *cm = this->getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { if(v.access & VarAccessModeAttribute::REDUCE) { - this->addPointerField(Type::parseNumeric(v.type, this->getScalarType()), v.name, backend.getDeviceVarPrefix() + v.name); + this->addPointerField(Type::parseNumeric(v.type), 
v.name, backend.getDeviceVarPrefix() + v.name); } } // Loop through variable references and add pointers if they are reduction targets for(const auto &v : cm->getVarRefs()) { if(v.access & VarAccessModeAttribute::REDUCE) { - this->addPointerField(Type::parseNumeric(v.type, this->getScalarType()), v.name, backend.getDeviceVarPrefix() + v.name); + this->addPointerField(Type::parseNumeric(v.type), v.name, backend.getDeviceVarPrefix() + v.name); } } } @@ -195,11 +198,12 @@ class GENN_EXPORT CustomUpdateHostReductionGroupMerged : public CustomUpdateHost //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name, true); } @@ -221,11 +225,12 @@ class GENN_EXPORT CustomWUUpdateHostReductionGroupMerged : public CustomUpdateHo //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name, true); } diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index f167918a73..331f40d1ec 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -67,7 +67,7 @@ class GroupMerged // **HACK** type should come in as type not string GroupMerged(size_t index, const std::string &precision, const std::vector> groups) - : m_Index(index), m_LiteralSuffix((precision == "float") ? "f" : ""), m_ScalarType(Type::parseNumeric(precision, nullptr)), m_Groups(std::move(groups)) + : m_Index(index), m_LiteralSuffix((precision == "float") ? "f" : ""), m_ScalarType(Type::parseNumeric(precision)), m_Groups(std::move(groups)) {} //------------------------------------------------------------------------ @@ -88,29 +88,29 @@ class GroupMerged const std::vector &getFields() const{ return m_Fields; } //! 
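The literal suffix tracked by the GroupMerged constructor above exists because unsuffixed floating-point literals are double in C and C++, which would silently promote single-precision arithmetic in generated code:

#include <type_traits>

static_assert(std::is_same_v<decltype(1.0), double>, "unsuffixed literals are double");
static_assert(std::is_same_v<decltype(1.0f), float>, "the f suffix keeps them float");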
Get group fields, sorted into order they will appear in struct - std::vector getSortedFields(const BackendBase &backend) const + std::vector getSortedFields(const BackendBase &backend, const Type::TypeContext &context) const { // Make a copy of fields and sort so largest come first. This should mean that due // to structure packing rules, significant memory is saved and estimate is more precise auto sortedFields = m_Fields; std::sort(sortedFields.begin(), sortedFields.end(), - [&backend](const Field &a, const Field &b) + [&backend, &context](const Field &a, const Field &b) { - return (std::get<0>(a)->getSizeBytes() > std::get<0>(b)->getSizeBytes()); + return (std::get<0>(a)->getSizeBytes(context) > std::get<0>(b)->getSizeBytes(context)); }); return sortedFields; } //! Generate declaration of struct to hold this merged group - void generateStruct(CodeStream &os, const BackendBase &backend, const std::string &name, - bool host = false) const + void generateStruct(CodeStream &os, const BackendBase &backend, const Type::TypeContext &context, + const std::string &name, bool host = false) const { os << "struct Merged" << name << "Group" << getIndex() << std::endl; { // Loop through fields and write to structure CodeStream::Scope b(os); - const auto sortedFields = getSortedFields(backend); + const auto sortedFields = getSortedFields(backend, context); for(const auto &f : sortedFields) { // If field is a pointer and not marked as being a host field // (in which case the backend should leave its type alone!) @@ -118,16 +118,16 @@ class GroupMerged if(dynamic_cast(type) && !(std::get<3>(f) & GroupMergedFieldType::HOST)) { // If we are generating a host structure, allow the backend to override the type if(host) { - os << backend.getMergedGroupFieldHostTypeName(type); + os << backend.getMergedGroupFieldHostTypeName(type, context); } // Otherwise, allow the backend to add a prefix else { - os << backend.getPointerPrefix() << type->getName(); + os << backend.getPointerPrefix() << type->getName(context); } } // Otherwise, leave the type alone else { - os << type->getName(); + os << type->getName(context); } os << " " << std::get<1>(f) << ";" << std::endl; } @@ -137,28 +137,28 @@ class GroupMerged os << ";" << std::endl; } - void generateStructFieldArgumentDefinitions(CodeStream &os, const BackendBase &backend) const + void generateStructFieldArgumentDefinitions(CodeStream &os, const BackendBase &backend, const Type::TypeContext &context) const { // Get sorted fields - const auto sortedFields = getSortedFields(backend); + const auto sortedFields = getSortedFields(backend, context); for(size_t fieldIndex = 0; fieldIndex < sortedFields.size(); fieldIndex++) { const auto &f = sortedFields[fieldIndex]; - os << backend.getMergedGroupFieldHostTypeName(std::get<0>(f)) << " " << std::get<1>(f); + os << backend.getMergedGroupFieldHostTypeName(std::get<0>(f), context) << " " << std::get<1>(f); if(fieldIndex != (sortedFields.size() - 1)) { os << ", "; } } } - size_t getStructArraySize(const BackendBase &backend) const + size_t getStructArraySize(const BackendBase &backend, const Type::TypeContext &context) const { // Loop through fields again to generate any EGP pushing functions that are required and to calculate struct size size_t structSize = 0; size_t largestFieldSize = 0; - const auto sortedFields = getSortedFields(backend); + const auto sortedFields = getSortedFields(backend, context); for(const auto &f : sortedFields) { // Add size of field to total - const size_t fieldSize = std::get<0>(f)->getSizeBytes(); + 
const size_t fieldSize = std::get<0>(f)->getSizeBytes(context); structSize += fieldSize; // Update largest field size @@ -281,7 +281,7 @@ class GroupMerged { // Loop through variables for(const auto &v : vars) { - addPointerField(Type::parseNumeric(v.type, getScalarType()), v.name, arrayPrefix + v.name); + addPointerField(Type::parseNumeric(v.type), v.name, arrayPrefix + v.name); } } @@ -290,7 +290,7 @@ class GroupMerged { // Loop through variables for(const auto &v : varReferences) { - addField(Type::parseNumeric(v.type, getScalarType())->getPointerType(), v.name, + addField(Type::parseNumeric(v.type)->getPointerType(), v.name, [getVarRefFn, arrayPrefix, v](const G &g, size_t) { const auto varRef = getVarRefFn(g).at(v.name); @@ -303,7 +303,7 @@ class GroupMerged { for(const auto &e : egps) { assert(Utils::isTypePointer(e.type)); - addField(Type::parseNumericPtr(e.type, getScalarType()), e.name + varName, + addField(Type::parseNumericPtr(e.type), e.name + varName, [e, arrayPrefix, varName](const G &g, size_t) { return arrayPrefix + e.name + varName + g.getName(); }, GroupMergedFieldType::DYNAMIC); } @@ -459,19 +459,19 @@ class GroupMerged } } - void generateRunnerBase(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc, - const std::string &name, bool host = false) const + void generateRunnerBase(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc, const std::string &name, bool host = false) const { // Make a copy of fields and sort so largest come first. 
This should mean that due // to structure packing rules, significant memory is saved and estimate is more precise - auto sortedFields = getSortedFields(backend); + auto sortedFields = getSortedFields(backend, context); // If this isn't a host merged structure, generate definition for function to push group if(!host) { definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << "Group" << getIndex() << "ToDevice(unsigned int idx, "; - generateStructFieldArgumentDefinitions(definitionsInternalFunc, backend); + generateStructFieldArgumentDefinitions(definitionsInternalFunc, backend, context); definitionsInternalFunc << ");" << std::endl; } @@ -480,7 +480,7 @@ class GroupMerged // If this field is a dynamic pointer if((std::get<3>(f) & GroupMergedFieldType::DYNAMIC) && dynamic_cast(std::get<0>(f))) { definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << getIndex() << std::get<1>(f) << "ToDevice(unsigned int idx, "; - definitionsInternalFunc << backend.getMergedGroupFieldHostTypeName(std::get<0>(f)) << " value);" << std::endl; + definitionsInternalFunc << backend.getMergedGroupFieldHostTypeName(std::get<0>(f), context) << " value);" << std::endl; } // Raise error if this field is a host field but this isn't a host structure @@ -491,7 +491,7 @@ class GroupMerged if(host) { // Generate struct directly into internal definitions // **NOTE** we ignore any backend prefix as we're generating this struct for use on the host - generateStruct(definitionsInternal, backend, name, true); + generateStruct(definitionsInternal, backend, context, name, true); // Declare array of these structs containing individual neuron group pointers etc runnerVarDecl << "Merged" << name << "Group" << getIndex() << " merged" << name << "Group" << getIndex() << "[" << getGroups().size() << "];" << std::endl; @@ -561,11 +561,12 @@ class GENN_EXPORT NeuronSpikeQueueUpdateGroupMerged : public GroupMergedgetQualifiedType(Type::Qualifier::CONSTANT) : type; @@ -210,7 +210,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa void defineEGPs(const Snippet::Base::EGPVec &egps, const std::string &arrayPrefix, const std::string &varName = "") { for(const auto &e : egps) { - const auto *type = Type::parseNumericPtr(e.type, m_ScalarType); + const auto *type = Type::parseNumericPtr(e.type); defineField(type, e.name, type, e.name + varName, [arrayPrefix, e, varName](const auto &g, size_t) diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index 0c4c0280b7..007b3c6d4f 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -20,11 +20,12 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //! 
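getSortedFields keeps ordering merged-struct fields largest-first so that, under the usual alignment rules, padding (and therefore the size estimate) is minimised. A toy illustration of the effect (not GeNN code), assuming a typical ABI where uint64_t is 8-byte aligned:

#include <cstdint>

// Typically 1 + 7 pad + 8 + 1 + 7 pad = 24 bytes
struct SmallFirst { std::uint8_t a; std::uint64_t b; std::uint8_t c; };
// Typically 8 + 1 + 1 + 6 pad = 16 bytes
struct LargeFirst { std::uint64_t b; std::uint8_t a; std::uint8_t c; };

static_assert(sizeof(LargeFirst) <= sizeof(SmallFirst),
              "largest-first ordering saves padding here");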
Get hash digest used for detecting changes boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -105,11 +106,12 @@ class GENN_EXPORT SynapseInitGroupMerged : public SynapseGroupMergedBase return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::Init); } - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -137,11 +139,12 @@ class GENN_EXPORT SynapseSparseInitGroupMerged : public SynapseGroupMergedBase return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::SparseInit); } - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -169,11 +172,12 @@ class GENN_EXPORT SynapseConnectivityInitGroupMerged : public SynapseGroupMerged return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::ConnectivityInit); } - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + 
generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -207,11 +211,12 @@ class GENN_EXPORT SynapseConnectivityHostInitGroupMerged : public GroupMerged // If we're not initialising or if there is initialization code for this variable const auto &varInit = archetypeAdaptor.getVarInitialisers().at(var.name); if (!varInit.getSnippet()->getCode().empty()) { - this->addPointerField(Type::parseNumeric(var.type, this->getScalarType()), var.name, backend.getDeviceVarPrefix() + var.name); + this->addPointerField(Type::parseNumeric(var.type), var.name, backend.getDeviceVarPrefix() + var.name); } // Add any var init EGPs to structure @@ -343,11 +348,12 @@ class GENN_EXPORT CustomUpdateInitGroupMerged : public CustomUpdateInitGroupMerg //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -375,11 +381,12 @@ class GENN_EXPORT CustomWUUpdateInitGroupMerged : public CustomUpdateInitGroupMe //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -433,11 +440,12 @@ class GENN_EXPORT CustomWUUpdateSparseInitGroupMerged : public CustomUpdateInitG //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, 
definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -464,11 +472,12 @@ class GENN_EXPORT CustomConnectivityUpdatePreInitGroupMerged : public CustomUpda //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -495,11 +504,12 @@ class GENN_EXPORT CustomConnectivityUpdatePostInitGroupMerged : public CustomUpd //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -526,11 +536,12 @@ class GENN_EXPORT CustomConnectivityUpdateSparseInitGroupMerged : public CustomU //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 7af74bb2f7..98e2e2f7c3 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -40,11 +40,11 @@ class GENN_EXPORT ModelSpecMerged //! 
Immutable structure for tracking fields of merged group structure containing EGPs struct EGPField { - EGPField(size_t m, const Type::Base *t, const std::string &f, bool h) - : mergedGroupIndex(m), type(t), fieldName(f), hostGroup(h) {} + EGPField(size_t m, const std::string &t, const std::string &f, bool h) + : mergedGroupIndex(m), typeName(t), fieldName(f), hostGroup(h) {} const size_t mergedGroupIndex; - const Type::Base *type; + const std::string typeName; const std::string fieldName; const bool hostGroup; @@ -52,8 +52,8 @@ class GENN_EXPORT ModelSpecMerged //! lexicographically compares all three struct members bool operator < (const EGPField &other) const { - return (std::make_tuple(mergedGroupIndex, type->getName(), fieldName, hostGroup) - < std::make_tuple(other.mergedGroupIndex, other.type->getName(), other.fieldName, other.hostGroup)); + return (std::tie(mergedGroupIndex, typeName, fieldName, hostGroup) + < std::tie(other.mergedGroupIndex, other.typeName, other.fieldName, other.hostGroup)); } }; @@ -63,7 +63,7 @@ class GENN_EXPORT ModelSpecMerged //! Immutable structure for tracking where an extra global variable ends up after merging struct MergedEGP : public EGPField { - MergedEGP(size_t m, size_t g, const Type::Base *t, const std::string &f, bool h) + MergedEGP(size_t m, size_t g, const std::string &t, const std::string &f, bool h) : EGPField(m, t, f, h), groupIndex(g) {} const size_t groupIndex; @@ -82,6 +82,9 @@ class GENN_EXPORT ModelSpecMerged //-------------------------------------------------------------------------- //! Get underlying, unmerged model const ModelSpecInternal &getModel() const{ return m_Model; } + + //! Get type context used to resolve all types used in model + const Type::TypeContext &getTypeContext() const{ return m_TypeContext; } //! Get merged neuron groups which require updating const std::vector &getMergedNeuronUpdateGroups() const{ return m_MergedNeuronUpdateGroups; } @@ -158,31 +161,31 @@ class GENN_EXPORT ModelSpecMerged //! 
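The switch from std::make_tuple to std::tie in EGPField::operator< above compares the members lexicographically through a tuple of references, avoiding string copies on every comparison. The idiom in isolation (a stripped-down sketch, not the GeNN struct):

#include <cstddef>
#include <string>
#include <tuple>

struct Field
{
    std::size_t index;
    std::string typeName;
    std::string fieldName;

    bool operator < (const Field &other) const
    {
        // std::tie builds tuples of references, so the strings are
        // compared in place rather than copied into a temporary tuple
        return (std::tie(index, typeName, fieldName)
                < std::tie(other.index, other.typeName, other.fieldName));
    }
};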
Get merged custom connectivity update groups where host processing needs to be performed const std::vector &getMergedCustomConnectivityHostUpdateGroups() const { return m_MergedCustomConnectivityHostUpdateGroups; } - void genMergedNeuronUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedNeuronUpdateGroups); } - void genMergedPresynapticUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedPresynapticUpdateGroups); } - void genMergedPostsynapticUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedPostsynapticUpdateGroups); } - void genMergedSynapseDynamicsGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseDynamicsGroups); } - void genMergedNeuronInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedNeuronInitGroups); } - void genMergedCustomUpdateInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomUpdateInitGroups); } - void genMergedCustomWUUpdateInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomWUUpdateInitGroups); } - void genMergedSynapseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseInitGroups); } - void genMergedSynapseConnectivityInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseConnectivityInitGroups); } - void genMergedSynapseSparseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseSparseInitGroups); } - void genMergedCustomWUUpdateSparseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomWUUpdateSparseInitGroups); } - void genMergedCustomConnectivityUpdatePreInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomConnectivityUpdatePreInitGroups); } - void genMergedCustomConnectivityUpdatePostInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomConnectivityUpdatePostInitGroups); } - void genMergedCustomConnectivityUpdateSparseInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomConnectivityUpdateSparseInitGroups); } - void genMergedNeuronSpikeQueueUpdateStructs(CodeStream &os, const BackendBase &backend) const{ genMergedStructures(os, backend, m_MergedNeuronSpikeQueueUpdateGroups); } - void genMergedNeuronPrevSpikeTimeUpdateStructs(CodeStream &os, const BackendBase &backend) const{ genMergedStructures(os, backend, m_MergedNeuronPrevSpikeTimeUpdateGroups); } - void genMergedSynapseDendriticDelayUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseDendriticDelayUpdateGroups); } - void genMergedSynapseConnectivityHostInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseConnectivityHostInitGroups); } - void genMergedCustomUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomUpdateGroups); } - void genMergedCustomUpdateWUStructs(CodeStream &os, const BackendBase 
&backend) const { genMergedStructures(os, backend, m_MergedCustomUpdateWUGroups); } - void genMergedCustomUpdateTransposeWUStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomUpdateTransposeWUGroups); } - void genMergedCustomUpdateHostReductionStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomUpdateHostReductionGroups); } - void genMergedCustomWUUpdateHostReductionStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomWUUpdateHostReductionGroups); } - void genMergedCustomConnectivityUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomConnectivityUpdateGroups); } - void genMergedCustomConnectivityHostUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomConnectivityHostUpdateGroups); } + void genMergedNeuronUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedNeuronUpdateGroups); } + void genMergedPresynapticUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedPresynapticUpdateGroups); } + void genMergedPostsynapticUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedPostsynapticUpdateGroups); } + void genMergedSynapseDynamicsGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedSynapseDynamicsGroups); } + void genMergedNeuronInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedNeuronInitGroups); } + void genMergedCustomUpdateInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomUpdateInitGroups); } + void genMergedCustomWUUpdateInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomWUUpdateInitGroups); } + void genMergedSynapseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedSynapseInitGroups); } + void genMergedSynapseConnectivityInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedSynapseConnectivityInitGroups); } + void genMergedSynapseSparseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedSynapseSparseInitGroups); } + void genMergedCustomWUUpdateSparseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomWUUpdateSparseInitGroups); } + void genMergedCustomConnectivityUpdatePreInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomConnectivityUpdatePreInitGroups); } + void genMergedCustomConnectivityUpdatePostInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomConnectivityUpdatePostInitGroups); } + void genMergedCustomConnectivityUpdateSparseInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, 
m_MergedCustomConnectivityUpdateSparseInitGroups); } + void genMergedNeuronSpikeQueueUpdateStructs(CodeStream &os, const BackendBase &backend) const{ genMergedStructures(os, backend, m_TypeContext, m_MergedNeuronSpikeQueueUpdateGroups); } + void genMergedNeuronPrevSpikeTimeUpdateStructs(CodeStream &os, const BackendBase &backend) const{ genMergedStructures(os, backend, m_TypeContext, m_MergedNeuronPrevSpikeTimeUpdateGroups); } + void genMergedSynapseDendriticDelayUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedSynapseDendriticDelayUpdateGroups); } + void genMergedSynapseConnectivityHostInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedSynapseConnectivityHostInitGroups); } + void genMergedCustomUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomUpdateGroups); } + void genMergedCustomUpdateWUStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomUpdateWUGroups); } + void genMergedCustomUpdateTransposeWUStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomUpdateTransposeWUGroups); } + void genMergedCustomUpdateHostReductionStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomUpdateHostReductionGroups); } + void genMergedCustomWUUpdateHostReductionStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomWUUpdateHostReductionGroups); } + void genMergedCustomConnectivityUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomConnectivityUpdateGroups); } + void genMergedCustomConnectivityHostUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomConnectivityHostUpdateGroups); } void genNeuronUpdateGroupSupportCode(CodeStream &os, bool supportsNamespace = true) const{ m_NeuronUpdateSupportCode.gen(os, getModel().getPrecision(), supportsNamespace); } void genPostsynapticDynamicsSupportCode(CodeStream &os, bool supportsNamespace = true) const{ m_PostsynapticDynamicsSupportCode.gen(os, getModel().getPrecision(), supportsNamespace); } @@ -240,7 +243,7 @@ class GENN_EXPORT ModelSpecMerged std::transform(groupEGPs.first, groupEGPs.second, std::inserter(mergedGroupFields, mergedGroupFields.end()), [](const MergedEGPMap::value_type::second_type::value_type &g) { - return EGPField{g.second.mergedGroupIndex, g.second.type, g.second.fieldName, g.second.hostGroup}; + return EGPField{g.second.mergedGroupIndex, g.second.typeName, g.second.fieldName, g.second.hostGroup}; }); } @@ -283,11 +286,11 @@ class GENN_EXPORT ModelSpecMerged // Private methods //-------------------------------------------------------------------------- template - void genMergedStructures(CodeStream &os, const BackendBase &backend, const std::vector &mergedGroups) const + void genMergedStructures(CodeStream &os, const BackendBase &backend, const Type::TypeContext &context, const std::vector &mergedGroups) const { // Loop through all merged groups and generate struct for(const auto &g : mergedGroups) { - g.generateStruct(os, backend, T::name); + g.generateStruct(os, backend, context, T::name); } } @@ -449,8 +452,10 
@@ class GENN_EXPORT ModelSpecMerged //! Unique support code strings for synapse dynamics SupportCodeMerged m_SynapseDynamicsSupportCode; - // Map containing mapping of original extra global param names to their locations within merged groups + //! Map containing mapping of original extra global param names to their locations within merged groups MergedEGPMap m_MergedEGPs; - + + //! Type context used to resolve all types used in model + const Type::TypeContext m_TypeContext; }; } // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index 5725414baf..383a689e67 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -38,11 +38,12 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //! Get hash digest used for detecting changes boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } diff --git a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h index 823abf7ec9..ce9cfc6530 100644 --- a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h @@ -22,11 +22,12 @@ class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::PresynapticUpdate); } - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -59,11 +60,12 @@ class GENN_EXPORT PostsynapticUpdateGroupMerged : public SynapseGroupMergedBase return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::PostsynapticUpdate); } - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + 
CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -92,11 +94,12 @@ class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::SynapseDynamics); } - void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, - CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const + void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -120,11 +123,12 @@ class GENN_EXPORT SynapseDendriticDelayUpdateGroupMerged : public GroupMerged #include -#include #include // Transpiler includes @@ -23,16 +22,13 @@ class ErrorHandlerBase; namespace GeNN::Transpiler::Parser { //! Parse expression from tokens -Expression::ExpressionPtr parseExpression(const std::vector &tokens, const std::unordered_set &typedefNames, - ErrorHandlerBase &errorHandler); +Expression::ExpressionPtr parseExpression(const std::vector &tokens, ErrorHandlerBase &errorHandler); //! Parse block item list from tokens /*! Block item lists are function body scope list of statements */ -Statement::StatementList parseBlockItemList(const std::vector &tokens, const std::unordered_set &typedefNames, - ErrorHandlerBase &errorHandler); +Statement::StatementList parseBlockItemList(const std::vector &tokens, ErrorHandlerBase &errorHandler); //! Parse type from tokens -const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, - const std::unordered_set &typedefNames, ErrorHandlerBase &errorHandler); +const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, ErrorHandlerBase &errorHandler); } // MiniParse::MiniParse diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index e2a566ee15..dd3e78d582 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -210,7 +210,7 @@ class NumericTypedef : public NumericBase { public: NumericTypedef(const std::string &name, Qualifier qualifiers = Qualifier{0}) - : m_Name(name), NumericBase(qualifiers){} + : NumericBase(qualifiers), m_Name(name){} //------------------------------------------------------------------------ // Base virtuals @@ -353,13 +353,13 @@ DECLARE_FOREIGN_FUNCTION_TYPE(Exp, Double, Double); DECLARE_FOREIGN_FUNCTION_TYPE(Sqrt, Double, Double); //! Parse a numeric type -const NumericBase *parseNumeric(std::string_view typeString, const std::unordered_set &typedefNames); +const NumericBase *parseNumeric(std::string_view typeString); //! 
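The NumericTypedef constructor change above is a classic initialisation-order fix: bases and members are always initialised in declaration order, not in the order they appear in the initialiser list, so listing the base after a member earns a -Wreorder warning. In miniature (hypothetical types, same shape as the patched code):

struct Base
{
    explicit Base(int qualifiers) : m_Qualifiers(qualifiers) {}
    int m_Qualifiers;
};

struct NumericTypedefLike : Base
{
    // The base is constructed before m_Name whatever order the list uses;
    // writing Base first makes the code match what actually happens
    NumericTypedefLike(const char *name, int qualifiers)
    :   Base(qualifiers), m_Name(name)
    {}

    const char *m_Name;
};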
Parse a numeric pointer type -const Pointer *parseNumericPtr(std::string_view typeString, const std::unordered_set &typedefNames); +const Pointer *parseNumericPtr(std::string_view typeString); //! Look up numeric type based on set of type specifiers -const NumericBase *getNumericType(const std::set &typeSpecifiers, const std::unordered_set &typedefNames); +const NumericBase *getNumericType(const std::set &typeSpecifiers); //! Apply C type promotion rules to numeric type const NumericBase *getPromotedType(const NumericBase *type, const TypeContext &context); diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index a49f122ed8..3f912a5b9d 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -1768,9 +1768,9 @@ void Backend::genMergedExtraGlobalParamPush(CodeStream &os, const std::string &s os << ", (sizeof(" << structName << ") * (" << groupIdx << ")) + offsetof(" << structName << ", " << fieldName << ")));" << std::endl; } //-------------------------------------------------------------------------- -std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) const +std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const { - return type->getName(); + return type->getName(context); } //-------------------------------------------------------------------------- const Type::Base *Backend::getMergedGroupSimRNGType() const diff --git a/src/genn/backends/opencl/backend.cc b/src/genn/backends/opencl/backend.cc index e395cca026..3d794d2335 100644 --- a/src/genn/backends/opencl/backend.cc +++ b/src/genn/backends/opencl/backend.cc @@ -2073,7 +2073,7 @@ void Backend::genMergedExtraGlobalParamPush(CodeStream &os, const std::string &s os << "CHECK_OPENCL_ERRORS(commandQueue.enqueueNDRangeKernel(" << kernelName << ", cl::NullRange, globalWorkSize, localWorkSize));" << std::endl; } //-------------------------------------------------------------------------- -std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) const +std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const { // If type is a pointer, on the host it is represented by an OpenCL buffer /*if(GeNN::Utils::isTypePointerToPointer(type)) { @@ -2084,7 +2084,7 @@ std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) con } // Otherwise, type remains the same else { - return type->getName(); + return type->getName(context); } } //-------------------------------------------------------------------------- diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 59b1d85c23..bc9ccc8ee7 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1348,9 +1348,9 @@ void Backend::genMergedExtraGlobalParamPush(CodeStream &os, const std::string &s os << "merged" << suffix << "Group" << mergedGroupIdx << "[" << groupIdx << "]." 
<< fieldName << " = " << egpName << ";" << std::endl; } //-------------------------------------------------------------------------- -std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) const +std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const { - return type->getName(); + return type->getName(context); } //-------------------------------------------------------------------------- const Type::Base *Backend::getMergedGroupSimRNGType() const diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index 8e69b934b0..78346f3fb2 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -447,7 +447,7 @@ std::string ensureFtype(const std::string &oldcode, const std::string &type) return code; } //---------------------------------------------------------------------------- -std::string getReductionInitialValue(const BackendBase &backend, VarAccessMode access, const std::string &type) +std::string getReductionInitialValue(VarAccessMode access, const Type::NumericBase *type, const Type::TypeContext &context) { // If reduction is a sum, initialise to zero if(access & VarAccessModeAttribute::SUM) { @@ -455,7 +455,7 @@ std::string getReductionInitialValue(const BackendBase &backend, VarAccessMode a } // Otherwise, reduction is a maximum operation, return lowest value for type else if(access & VarAccessModeAttribute::MAX) { - return backend.getLowestValue(type); + return Utils::writePreciseString(type->getLowest(context)); } else { assert(false); @@ -463,7 +463,8 @@ std::string getReductionInitialValue(const BackendBase &backend, VarAccessMode a } } //---------------------------------------------------------------------------- -std::string getReductionOperation(const std::string &reduction, const std::string &value, VarAccessMode access, const std::string &type) +std::string getReductionOperation(const std::string &reduction, const std::string &value, VarAccessMode access, + const Type::NumericBase *type, const Type::TypeContext &context) { // If operation is sum, add output of custom update to sum if(access & VarAccessModeAttribute::SUM) { @@ -471,13 +472,14 @@ std::string getReductionOperation(const std::string &reduction, const std::strin } // Otherwise, if it's max else if(access & VarAccessModeAttribute::MAX) { - // If type is floating point, generate fmax call - if(Utils::isTypeFloatingPoint(type)) { - return reduction + " = " + "fmax(" + reduction + ", " + value + ")"; + // If type is integral, generate max call + if(type->isIntegral(context)) { + return reduction + " = " + "max(" + reduction + ", " + value + ")"; + } - // Otherwise, generate max call + // Otherwise, generate gmax call else { - return reduction + " = " + "max(" + reduction + ", " + value + ")"; + return reduction + " = " + "fmax(" + reduction + ", " + value + ")"; } } else { diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 047d933262..55e5c2f1d2 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -166,7 +166,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t // Loop through sorted dependent variables for(size_t i = 0; i < getSortedArchetypeDependentVars().size(); i++) { - 
addField(parseNumeric(getSortedArchetypeDependentVars().at(i).getVar().type, getScalarType())->getPointerType(), "_dependentVar" + std::to_string(i), + addField(parseNumeric(getSortedArchetypeDependentVars().at(i).getVar().type)->getPointerType(), "_dependentVar" + std::to_string(i), [i, &backend, this](const auto&, size_t g) { const auto &varRef = m_SortedDependentVars[g][i]; @@ -440,7 +440,7 @@ CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged // Add host extra global parameters for(const auto &e : cm->getExtraGlobalParams()) { - const auto *pointerType = parseNumericPtr(e.type, getScalarType()); + const auto *pointerType = parseNumericPtr(e.type); addField(pointerType, e.name, [e](const auto &g, size_t) { return e.name + g.getName(); }, GroupMergedFieldType::HOST_DYNAMIC); @@ -562,13 +562,13 @@ void CustomConnectivityHostUpdateGroupMerged::addVars(const BackendBase &backend for(const auto &v : vars) { // If var is located on the host if (std::invoke(getVarLocationFn, getArchetype(), v.name) & VarLocation::HOST) { - addField(parseNumeric(v.type, getScalarType())->getPointerType(), v.name, + addField(parseNumeric(v.type)->getPointerType(), v.name, [v](const auto &g, size_t) { return v.name + g.getName(); }, GroupMergedFieldType::HOST); if(!backend.getDeviceVarPrefix().empty()) { // **TODO** I think could use addPointerField - addField(parseNumeric(v.type, getScalarType())->getPointerType(), backend.getDeviceVarPrefix() + v.name, + addField(parseNumeric(v.type)->getPointerType(), backend.getDeviceVarPrefix() + v.name, [v, &backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + v.name + g.getName(); diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index f563d932c8..07f8193d80 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -157,7 +157,7 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string // Scan, parse and type-check update code Transpiler::ErrorHandler errorHandler; const std::string code = upgradeCodeString(cm->getUpdateCode()); - const auto tokens = Transpiler::Scanner::scanSource(code, getScalarType(), errorHandler); + const auto tokens = Transpiler::Scanner::scanSource(code, errorHandler); const auto statements = Transpiler::Parser::parseBlockItemList(tokens, getScalarType(), errorHandler); Transpiler::TypeChecker::typeCheck(statements, typeEnvironment, errorHandler); @@ -379,7 +379,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // If variable has a transpose if(getArchetype().getVarReferences().at(v.name).getTransposeSynapseGroup() != nullptr) { // Add field with transpose suffix, pointing to transpose var - addField(parseNumeric(v.type, getScalarType())->getPointerType(), v.name + "Transpose", + addField(parseNumeric(v.type)->getPointerType(), v.name + "Transpose", [&backend, v](const auto &g, size_t) { const auto varRef = g.getVarReferences().at(v.name); diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 4597505f42..927bf722fa 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -730,7 +730,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Loop through merged synapse connectivity host initialisation groups for(const auto &m : 
modelMerged.getMergedSynapseConnectivityHostInitGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } @@ -742,145 +742,145 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Generate merged neuron initialisation groups for(const auto &m : modelMerged.getMergedNeuronInitGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through merged synapse init groups for(const auto &m : modelMerged.getMergedSynapseInitGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through merged synapse connectivity initialisation groups for(const auto &m : modelMerged.getMergedSynapseConnectivityInitGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through merged sparse synapse init groups for(const auto &m : modelMerged.getMergedSynapseSparseInitGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Generate merged custom update initialisation groups for(const auto &m : modelMerged.getMergedCustomUpdateInitGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Generate merged custom WU update initialisation groups for(const auto &m : modelMerged.getMergedCustomWUUpdateInitGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Generate merged custom sparse WU update initialisation groups for(const auto &m : modelMerged.getMergedCustomWUUpdateSparseInitGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Generate merged custom connectivity update presynaptic initialisation groups for(const auto &m : modelMerged.getMergedCustomConnectivityUpdatePreInitGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, 
definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Generate merged custom connectivity update postsynaptic initialisation groups for(const auto &m : modelMerged.getMergedCustomConnectivityUpdatePostInitGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Generate merged custom connectivity update synaptic initialisation groups for(const auto &m : modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through merged neuron update groups for(const auto &m : modelMerged.getMergedNeuronUpdateGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through merged presynaptic update groups for(const auto &m : modelMerged.getMergedPresynapticUpdateGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through merged postsynaptic update groups for(const auto &m : modelMerged.getMergedPostsynapticUpdateGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through synapse dynamics groups for(const auto &m : modelMerged.getMergedSynapseDynamicsGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through neuron groups whose previous spike times need resetting for(const auto &m : modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through neuron groups whose spike queues need resetting for(const auto &m : modelMerged.getMergedNeuronSpikeQueueUpdateGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through synapse groups whose dendritic delay pointers need updating for(const auto &m : modelMerged.getMergedSynapseDendriticDelayUpdateGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, 
definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through custom variable update groups for(const auto &m : modelMerged.getMergedCustomUpdateGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through custom WU variable update groups for(const auto &m : modelMerged.getMergedCustomUpdateWUGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through custom WU transpose variable update groups for(const auto &m : modelMerged.getMergedCustomUpdateTransposeWUGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through custom update host reduction groups for(const auto &m : modelMerged.getMergedCustomUpdateHostReductionGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through custom weight update host reduction groups for(const auto &m : modelMerged.getMergedCustomWUUpdateHostReductionGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through custom connectivity update groups for(const auto &m : modelMerged.getMergedCustomConnectivityUpdateGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through custom connectivity host update groups for(const auto &m : modelMerged.getMergedCustomConnectivityHostUpdateGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } @@ -1327,7 +1327,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Target indices backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - s.second.getSparseIndType()->getName(), "ind" + s.second.getName(), varLoc, size, mem); + s.second.getSparseIndType()->getName(modelMerged.getTypeContext()), "ind" + s.second.getName(), varLoc, size, mem); // **TODO** remap is not always required if(backend.isPostsynapticRemapRequired() && !s.second.getWUModel()->getLearnPostCode().empty()) { diff --git 
a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index 0822bd5bed..fb67a37766 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -84,7 +84,7 @@ NeuronPrevSpikeTimeUpdateGroupMerged::NeuronPrevSpikeTimeUpdateGroupMerged(size_ addPointerField("spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); } - const NumericBase *timeType = parseNumeric(timePrecision, nullptr); + const NumericBase *timeType = parseNumeric(timePrecision); if(getArchetype().isPrevSpikeTimeRequired()) { addPointerField("spk", backend.getDeviceVarPrefix() + "glbSpk"); addPointerField(timeType, "prevST", backend.getDeviceVarPrefix() + "prevST"); @@ -191,7 +191,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr using namespace Type; // **HACK** parse precisions - const NumericBase *timeType = parseNumeric(timePrecision, nullptr); + const NumericBase *timeType = parseNumeric(timePrecision); // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group orderNeuronGroupChildren(m_SortedMergedInSyns, &NeuronGroupInternal::getFusedPSMInSyn, @@ -248,7 +248,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr for(const auto &var : vars) { // If we're not initialising or if there is initialization code for this variable if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addPointerField(parseNumeric(var.type, getScalarType()), var.name, backend.getDeviceVarPrefix() + var.name); + addPointerField(parseNumeric(var.type), var.name, backend.getDeviceVarPrefix() + var.name); } // If we're initializing, add any var init EGPs to structure @@ -301,7 +301,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr for(const auto &var : sg->getPSModel()->getVars()) { // Add pointers to state variable if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addMergedInSynPointerField(parseNumeric(var.type, getScalarType()), var.name + "InSyn", i, backend.getDeviceVarPrefix() + var.name); + addMergedInSynPointerField(parseNumeric(var.type), var.name + "InSyn", i, backend.getDeviceVarPrefix() + var.name); } // If we're generating an initialization structure, also add any heterogeneous parameters, derived parameters or extra global parameters required for initializers @@ -356,7 +356,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr for(const auto &var : cs->getCurrentSourceModel()->getVars()) { // Add pointers to state variable if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(parseNumeric(var.type, getScalarType())->getPointerType(), var.name + "CS" + std::to_string(i), + addField(parseNumeric(var.type)->getPointerType(), var.name + "CS" + std::to_string(i), [&backend, i, var, this](const auto&, size_t groupIndex) { return backend.getDeviceVarPrefix() + var.name + m_SortedCurrentSources.at(groupIndex).at(i)->getName(); @@ -672,7 +672,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & using namespace Type; // **HACK** parse precisions - const NumericBase *timeType = parseNumeric(timePrecision, nullptr); + const NumericBase *timeType = parseNumeric(timePrecision); const bool updateRole = ((role == Role::PresynapticUpdate) || (role == Role::PostsynapticUpdate) @@ -770,7 +770,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & 
for(const auto &v : preVars) { // If variable is referenced in code string, add source pointer if(code.find("$(" + v.name + "_pre)") != std::string::npos) { - addSrcPointerField(parseNumeric(v.type, getScalarType()), v.name + "Pre", backend.getDeviceVarPrefix() + v.name); + addSrcPointerField(parseNumeric(v.type), v.name + "Pre", backend.getDeviceVarPrefix() + v.name); } } @@ -779,7 +779,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & for(const auto &v : postVars) { // If variable is referenced in code string, add target pointer if(code.find("$(" + v.name + "_post)") != std::string::npos) { - addTrgPointerField(parseNumeric(v.type, getScalarType()), v.name + "Post", backend.getDeviceVarPrefix() + v.name); + addTrgPointerField(parseNumeric(v.type), v.name + "Post", backend.getDeviceVarPrefix() + v.name); } } @@ -788,7 +788,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & for(const auto &e : preEGPs) { if(code.find("$(" + e.name + "_pre)") != std::string::npos) { const std::string prefix = backend.getDeviceVarPrefix(); - addField(parseNumericPtr(e.type, getScalarType()), e.name + "Pre", + addField(parseNumericPtr(e.type), e.name + "Pre", [e, prefix](const auto &sg, size_t) { return prefix + e.name + sg.getSrcNeuronGroup()->getName(); }, GroupMergedFieldType::DYNAMIC); } @@ -799,7 +799,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & for(const auto &e : postEGPs) { if(code.find("$(" + e.name + "_post)") != std::string::npos) { const std::string prefix = backend.getDeviceVarPrefix(); - addField(parseNumericPtr(e.type, getScalarType()), e.name + "Post", + addField(parseNumericPtr(e.type), e.name + "Post", [e, prefix](const auto &sg, size_t) { return prefix + e.name + sg.getTrgNeuronGroup()->getName(); }, GroupMergedFieldType::DYNAMIC); } @@ -839,14 +839,14 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & // Add presynaptic variables to struct for(const auto &v : wum->getPreVars()) { const std::string prefix = backend.getDeviceVarPrefix() + v.name; - addField(parseNumeric(v.type, getScalarType())->getPointerType(), v.name, + addField(parseNumeric(v.type)->getPointerType(), v.name, [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPreVarSuffix(); }); } // Add presynaptic variables to struct for(const auto &v : wum->getPostVars()) { const std::string prefix = backend.getDeviceVarPrefix() + v.name; - addField(parseNumeric(v.type, getScalarType())->getPointerType(), v.name, + addField(parseNumeric(v.type)->getPointerType(), v.name, [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPostVarSuffix(); }); } @@ -971,7 +971,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & // If we're performing an update with individual weights; or this variable should be initialised if((updateRole && individualWeights) || (kernelWeights && updateRole) || varInitRequired) { - addPointerField(parseNumeric(var.type, getScalarType()), var.name, backend.getDeviceVarPrefix() + var.name); + addPointerField(parseNumeric(var.type), var.name, backend.getDeviceVarPrefix() + var.name); } // If we're performing a procedural update or this variable should be initialised, add any var init EGPs to structure @@ -979,7 +979,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & const auto egps = snippet->getExtraGlobalParams(); for(const auto &e : egps) { const std::string prefix = 
backend.getDeviceVarPrefix(); - addField(parseNumericPtr(e.type, getScalarType()), e.name + var.name, + addField(parseNumericPtr(e.type), e.name + var.name, [e, prefix, var](const SynapseGroupInternal &sg, size_t) { return prefix + e.name + var.name + sg.getName(); diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 7016cc4db2..17fda21165 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -443,7 +443,7 @@ void NeuronInitGroupMerged::generateWUVar(const BackendBase &backend, for(const auto &var : vars) { // Add pointers to state variable if(!varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(parseNumeric(var.type, getScalarType())->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), + addField(parseNumeric(var.type)->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto&, size_t groupIndex) { const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index 6aa77abb71..c57f1b0cde 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -31,7 +31,7 @@ void assignGroups(const BackendBase &backend, std::vector &groups, BackendBas ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBase &backend) : m_Model(model), m_NeuronUpdateSupportCode("NeuronUpdateSupportCode"), m_PostsynapticDynamicsSupportCode("PostsynapticDynamicsSupportCode"), m_PresynapticUpdateSupportCode("PresynapticUpdateSupportCode"), m_PostsynapticUpdateSupportCode("PostsynapticUpdateSupportCode"), - m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode") + m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode"), m_TypeContext{{"scalar", Type::parseNumeric(model.getPrecision())}} { LOGD_CODE_GEN << "Merging neuron update groups:"; createMergedGroupsHash(model, backend, model.getNeuronGroups(), m_MergedNeuronUpdateGroups, diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 8725670a7b..81a3892889 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -66,7 +66,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string // If EGP is referenced in event threshold code if(s.eventThresholdCode.find("$(" + egp.name + ")") != std::string::npos) { const std::string prefix = backend.getDeviceVarPrefix(); - addField(parseNumericPtr(egp.type, getScalarType()), egp.name + "EventThresh" + std::to_string(i), + addField(parseNumericPtr(egp.type), egp.name + "EventThresh" + std::to_string(i), [eventThresholdSGs, prefix, egp, i](const auto &, size_t groupIndex) { return prefix + egp.name + eventThresholdSGs.at(groupIndex).at(i)->getName(); @@ -80,7 +80,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string for(const auto &var : sgPreVars) { // If variable is referenced in event threshold code if(s.eventThresholdCode.find("$(" + var.name + ")") != std::string::npos) { - addField(parseNumeric(var.type, getScalarType())->getPointerType(), var.name + "EventThresh" + std::to_string(i), + addField(parseNumeric(var.type)->getPointerType(), var.name + "EventThresh" + 
std::to_string(i), [&backend, eventThresholdSGs, var, i](const auto&, size_t groupIndex) { return backend.getDeviceVarPrefix() + var.name + eventThresholdSGs.at(groupIndex).at(i)->getName(); @@ -716,7 +716,7 @@ void NeuronUpdateGroupMerged::generateWUVar(const BackendBase &backend, const s for(size_t v = 0; v < vars.size(); v++) { // Add pointers to state variable const auto var = vars[v]; - addField(Type::parseNumeric(var.type, getScalarType())->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), + addField(Type::parseNumeric(var.type)->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto &, size_t groupIndex) { const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index 296821bfd3..21d8f86773 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -311,7 +311,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdate::getHashDigest( Utils::updateHash(getUpdateGroupName(), hash); Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash); - Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName(), hash); + Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName({}), hash); // Because it adds and removes synapses, connectivity update has to update // ALL variables associated with synapse group being modified as well as diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index 5d7fc7af13..1a3c816e1c 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -266,7 +266,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getHashDigest() const CustomUpdateBase::updateHash(hash); Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash); - Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName(), hash); + Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName({}), hash); // Loop through variable references for(const auto &v : getVarReferences()) { @@ -287,7 +287,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getInitHashDigest() cons CustomUpdateBase::updateInitHash(hash); Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash); - Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName(), hash); + Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName({}), hash); return hash.get_digest(); } } // namespace GeNN diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc index 0c83d81aa0..1bad25f7be 100644 --- a/src/genn/genn/synapseGroup.cc +++ b/src/genn/genn/synapseGroup.cc @@ -719,11 +719,11 @@ const Type::NumericBase *SynapseGroup::getSparseIndType() const if(m_NarrowSparseIndEnabled) { // If number of target neurons can be represented using a uint8, use this type const unsigned int numTrgNeurons = getTrgNeuronGroup()->getNumNeurons(); - if(numTrgNeurons <= Type::Uint8::getInstance()->getMax()) { + if(numTrgNeurons <= Type::Uint8::getInstance()->getMax({})) { return Type::Uint8::getInstance();; } // Otherwise, if they can be represented as a uint16, use this type - else if(numTrgNeurons <= Type::Uint16::getInstance()->getMax()) { + else if(numTrgNeurons <= Type::Uint16::getInstance()->getMax({})) { return Type::Uint16::getInstance(); } } @@ -739,7 +739,7 @@ 
boost::uuids::detail::sha1::digest_type SynapseGroup::getWUHashDigest() const Utils::updateHash(getDelaySteps(), hash); Utils::updateHash(getBackPropDelaySteps(), hash); Utils::updateHash(getMaxDendriticDelayTimesteps(), hash); - Utils::updateHash(getSparseIndType()->getName(), hash); + Utils::updateHash(getSparseIndType()->getName({}), hash); Utils::updateHash(getNumThreadsPerSpike(), hash); Utils::updateHash(isEventThresholdReTestRequired(), hash); Utils::updateHash(getSpanType(), hash); @@ -904,7 +904,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getWUInitHashDigest() cons { boost::uuids::detail::sha1 hash; Utils::updateHash(getMatrixType(), hash); - Utils::updateHash(getSparseIndType()->getName(), hash); + Utils::updateHash(getSparseIndType()->getName({}), hash); Utils::updateHash(getWUModel()->getVars(), hash); Utils::updateHash(getWUModel()->getSynapseDynamicsCode().empty(), hash); @@ -969,7 +969,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getConnectivityInitHashDig boost::uuids::detail::sha1 hash; Utils::updateHash(getConnectivityInitialiser().getHashDigest(), hash); Utils::updateHash(getMatrixType(), hash); - Utils::updateHash(getSparseIndType()->getName(), hash); + Utils::updateHash(getSparseIndType()->getName({}), hash); return hash.get_digest(); } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 216c17fa22..99a69b6e98 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -38,8 +38,8 @@ class ParseError class ParserState { public: - ParserState(const std::vector &tokens, const std::unordered_set &typedefNames, ErrorHandlerBase &errorHandler) - : m_Current(0), m_Tokens(tokens), m_TypedefNames(typedefNames), m_ErrorHandler(errorHandler) + ParserState(const std::vector &tokens, ErrorHandlerBase &errorHandler) + : m_Current(0), m_Tokens(tokens), m_ErrorHandler(errorHandler) {} //--------------------------------------------------------------------------- @@ -136,7 +136,6 @@ class ParserState size_t m_Current; const std::vector &m_Tokens; - const std::unordered_set m_TypedefNames; ErrorHandlerBase &m_ErrorHandler; }; @@ -216,7 +215,7 @@ const GeNN::Type::Base *parseDeclarationSpecifiers(ParserState &parserState) } while(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER, Token::Type::STAR})); // Lookup numeric type - const Base *type = getNumericType(typeSpecifiers, parserState.getScalarType()); + const Base *type = getNumericType(typeSpecifiers); // If there are any type qualifiers, add const // **THINK** this relies of const being only qualifier @@ -847,10 +846,9 @@ std::unique_ptr parseBlockItem(ParserState &parserState) //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Parser { -Expression::ExpressionPtr parseExpression(const std::vector &tokens, const GeNN::Type::NumericBase *scalarType, - ErrorHandlerBase &errorHandler) +Expression::ExpressionPtr parseExpression(const std::vector &tokens, ErrorHandlerBase &errorHandler) { - ParserState parserState(tokens, scalarType, errorHandler); + ParserState parserState(tokens, errorHandler); try { return parseExpression(parserState); @@ -860,8 +858,7 @@ Expression::ExpressionPtr parseExpression(const std::vector &tokens, cons } } //--------------------------------------------------------------------------- -Statement::StatementList parseBlockItemList(const std::vector &tokens, 
const GeNN::Type::NumericBase *scalarType, - ErrorHandlerBase &errorHandler) +Statement::StatementList parseBlockItemList(const std::vector &tokens, ErrorHandlerBase &errorHandler) { ParserState parserState(tokens, scalarType, errorHandler); std::vector> statements; @@ -872,8 +869,7 @@ Statement::StatementList parseBlockItemList(const std::vector &tokens, co return statements; } //--------------------------------------------------------------------------- -const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, - const GeNN::Type::NumericBase *scalarType, ErrorHandlerBase &errorHandler) +const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, ErrorHandlerBase &errorHandler) { ParserState parserState(tokens, scalarType, errorHandler); bool pointerFound = false; @@ -898,7 +894,7 @@ const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPo }; // Lookup numeric type - const auto *numericType = GeNN::Type::getNumericType(typeSpecifiers, scalarType); + const auto *numericType = GeNN::Type::getNumericType(typeSpecifiers); // If pointer, return pointer to numeric type if (pointerFound) { diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index 0764eca3a8..9d641339b4 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -154,17 +154,16 @@ const Type::NumericBase *NumericTypedef::getNumeric(const TypeContext &context) //---------------------------------------------------------------------------- // Free functions //---------------------------------------------------------------------------- -const NumericBase *parseNumeric(std::string_view typeString, const std::unordered_set &typedefNames) +const NumericBase *parseNumeric(std::string_view typeString) { using namespace Transpiler; // Scan type SingleLineErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(typeString, typedefNames, errorHandler); + const auto tokens = Scanner::scanSource(typeString, {"scalar"}, errorHandler); // Parse type and cast to numeric - const auto *type = dynamic_cast(Parser::parseType(tokens, false, typedefNames, - errorHandler)); + const auto *type = dynamic_cast(Parser::parseType(tokens, false, errorHandler)); // If an error was encountered while scanning or parsing, throw exception if (errorHandler.hasError()) { @@ -178,16 +177,16 @@ const NumericBase *parseNumeric(std::string_view typeString, const std::unordere return type; } //---------------------------------------------------------------------------- -const Pointer *parseNumericPtr(std::string_view typeString, const std::unordered_set &typedefNames) +const Pointer *parseNumericPtr(std::string_view typeString) { - using namespace Transpiler; + using namespace Transpiler; // Scan type SingleLineErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(typeString, typedefNames, errorHandler); + const auto tokens = Scanner::scanSource(typeString, {"scalar"}, errorHandler); // Parse type and cast to numeric pointer - const auto *type = dynamic_cast(Parser::parseType(tokens, true, typedefNames, errorHandler)); + const auto *type = dynamic_cast(Parser::parseType(tokens, true, errorHandler)); // If an error was encountered while scanning or parsing, throw exception if (errorHandler.hasError()) { @@ -201,16 +200,13 @@ const Pointer *parseNumericPtr(std::string_view typeString, const std::unordered return type; } //---------------------------------------------------------------------------- -const NumericBase *getNumericType(const std::set &typeSpecifiers, const 
NumericBase *scalarType) +const NumericBase *getNumericType(const std::set<std::string> &typeSpecifiers) { + // If type matches scalar type specifiers if(typeSpecifiers == scalarTypeSpecifier) { - if(scalarType) { - return scalarType; - } - else { - throw std::runtime_error("'scalar' type is not available in this context"); - } + return new NumericTypedef("scalar"); } + // Otherwise else { const auto type = numericTypeSpecifiers.find(typeSpecifiers); return (type == numericTypeSpecifiers.cend()) ? nullptr : type->second; From 95f48320bc1272f654426fb154ef58969a0f1e37 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 18 Jan 2023 11:09:00 +0000 Subject: [PATCH 064/725] Refactor of literal parsing * Token types for each form of literal, e.g. INT32_NUMBER, SCALAR_NUMBER * Scanner uses existing logic to identify the token type of literals * Parser sticks the entire token, not just the lexeme, into the literal expression * Type checker turns the token type into an actual type
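A rough sketch of the resulting pipeline (simplified, for orientation only; the real implementation is in the typeChecker.cc hunks below). The scanner tags each numeric literal with a typed token and the type checker resolves that token type to a GeNN type:
    // "42" -> INT32_NUMBER, "42U" -> UINT32_NUMBER, "1.0d" -> DOUBLE_NUMBER, unsuffixed "1.0" -> SCALAR_NUMBER
    switch(literal.getValue().type) {
    case Token::Type::INT32_NUMBER:  return Type::Int32::getInstance();
    case Token::Type::UINT32_NUMBER: return Type::Uint32::getInstance();
    case Token::Type::DOUBLE_NUMBER: return Type::Double::getInstance();
    case Token::Type::SCALAR_NUMBER: return new Type::NumericTypedef("scalar"); // resolved against the TypeContext
    // ...
    }
--- .../groupMergedTypeEnvironment.h | 14 +- include/genn/genn/transpiler/expression.h | 6 +- include/genn/genn/transpiler/scanner.h | 2 +- include/genn/genn/transpiler/token.h | 2 +- include/genn/genn/transpiler/typeChecker.h | 13 +- .../code_generator/customUpdateGroupMerged.cc | 2 +- src/genn/genn/transpiler/parser.cc | 10 +- src/genn/genn/transpiler/scanner.cc | 42 +++-- src/genn/genn/transpiler/typeChecker.cc | 159 ++++++++++-------- src/genn/genn/type.cc | 4 +- 10 files changed, 148 insertions(+), 106 deletions(-) diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index bc48e61c73..30f650f605 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -40,13 +40,15 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa } virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType, - ErrorHandlerBase &errorHandler, bool initializer) final + const Type::TypeContext &context, ErrorHandlerBase &errorHandler, + bool initializer) final { // If type isn't found auto existingType = m_Types.find(std::string{name.lexeme}); if(existingType == m_Types.end()) { if(m_Enclosing) { - return m_Enclosing->assign(name, op, assignedType, errorHandler, initializer); + return m_Enclosing->assign(name, op, assignedType, + context, errorHandler, initializer); } else { errorHandler.error(name, "Undefined variable"); @@ -58,15 +60,17 @@ addField(existingType->second); // Perform standard type-checking logic - return EnvironmentBase::assign(name, op, existingType->second.first, assignedType, errorHandler, initializer); + return EnvironmentBase::assign(name, op, existingType->second.first, assignedType, + context, errorHandler, initializer); } - virtual const Type::Base *incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final + virtual const Type::Base *incDec(const Token &name, Token::Type op, + const Type::TypeContext &context, ErrorHandlerBase &errorHandler) final { auto existingType = m_Types.find(std::string{name.lexeme}); if(existingType == m_Types.end()) { if(m_Enclosing) { - return m_Enclosing->incDec(name, op, errorHandler); + return m_Enclosing->incDec(name, op, context, errorHandler); } else { errorHandler.error(name, "Undefined variable"); } } diff --git a/include/genn/genn/transpiler/expression.h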
b/include/genn/genn/transpiler/expression.h index ecdc1d48d6..53b77fa472 100644 --- a/include/genn/genn/transpiler/expression.h +++ b/include/genn/genn/transpiler/expression.h @@ -186,16 +186,16 @@ class Grouping : public Base class Literal : public Base { public: - Literal(std::string_view value) + Literal(Token value) : m_Value(value) {} virtual void accept(Visitor &visitor) const final; - std::string_view getValue() const { return m_Value; } + Token getValue() const { return m_Value; } private: - const std::string_view m_Value; + const Token m_Value; }; //--------------------------------------------------------------------------- diff --git a/include/genn/genn/transpiler/scanner.h b/include/genn/genn/transpiler/scanner.h index 6eb5d6c1dd..b3af6b8b75 100644 --- a/include/genn/genn/transpiler/scanner.h +++ b/include/genn/genn/transpiler/scanner.h @@ -26,6 +26,6 @@ class ErrorHandlerBase; //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Scanner { -std::vector<Token> scanSource(const std::string_view &source, const std::unordered_set<std::string> &typedefNames, ErrorHandlerBase &errorHandler); +std::vector<Token> scanSource(const std::string_view &source, ErrorHandlerBase &errorHandler); } // namespace Scanner diff --git a/include/genn/genn/transpiler/token.h b/include/genn/genn/transpiler/token.h index 3a67f38d3a..ed8022d05b 100644 --- a/include/genn/genn/transpiler/token.h +++ b/include/genn/genn/transpiler/token.h @@ -40,7 +40,7 @@ struct Token SHIFT_LEFT_EQUAL, SHIFT_RIGHT_EQUAL, // Literals - IDENTIFIER, NUMBER, + IDENTIFIER, UINT32_NUMBER, INT32_NUMBER, FLOAT_NUMBER, DOUBLE_NUMBER, SCALAR_NUMBER, // Types TYPE_SPECIFIER, diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index 3bdc62cb89..e5df2098a0 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -42,8 +42,10 @@ class EnvironmentBase //------------------------------------------------------------------------ virtual void define(const Token &name, const Type::Base *type, ErrorHandlerBase &errorHandler) = 0; virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType, - ErrorHandlerBase &errorHandler, bool initializer = false) = 0; - virtual const Type::Base *incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) = 0; + const Type::TypeContext &context, ErrorHandlerBase &errorHandler, + bool initializer = false) = 0; + virtual const Type::Base *incDec(const Token &name, Token::Type op, + const Type::TypeContext &context, ErrorHandlerBase
&errorHandler); + const Type::TypeContext &context, ErrorHandlerBase &errorHandler); const Type::Base *typeCheck(const Expression::Base *expression, EnvironmentBase &environment, - ErrorHandlerBase &errorHandler); + const Type::TypeContext &context, ErrorHandlerBase &errorHandler); } // namespace MiniParse::GeNN::Transpiler diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 07f8193d80..0e4e5969e2 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -158,7 +158,7 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string Transpiler::ErrorHandler errorHandler; const std::string code = upgradeCodeString(cm->getUpdateCode()); const auto tokens = Transpiler::Scanner::scanSource(code, errorHandler); - const auto statements = Transpiler::Parser::parseBlockItemList(tokens, getScalarType(), errorHandler); + const auto statements = Transpiler::Parser::parseBlockItemList(tokens, errorHandler); Transpiler::TypeChecker::typeCheck(statements, typeEnvironment, errorHandler); } diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 99a69b6e98..f4bb83a544 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -237,8 +237,10 @@ Expression::ExpressionPtr parsePrimary(ParserState &parserState) // identifier // constant // "(" expression ")" - if (parserState.match({Token::Type::FALSE, Token::Type::TRUE, Token::Type::NUMBER})) { - return std::make_unique(parserState.previous().lexeme); + if (parserState.match({Token::Type::FALSE, Token::Type::TRUE, Token::Type::DOUBLE_NUMBER, + Token::Type::FLOAT_NUMBER, Token::Type::SCALAR_NUMBER, + Token::Type::INT32_NUMBER, Token::Type::UINT32_NUMBER})) { + return std::make_unique(parserState.previous()); } else if(parserState.match(Token::Type::IDENTIFIER)) { return std::make_unique(parserState.previous()); @@ -860,7 +862,7 @@ Expression::ExpressionPtr parseExpression(const std::vector &tokens, Erro //--------------------------------------------------------------------------- Statement::StatementList parseBlockItemList(const std::vector &tokens, ErrorHandlerBase &errorHandler) { - ParserState parserState(tokens, scalarType, errorHandler); + ParserState parserState(tokens, errorHandler); std::vector> statements; while(!parserState.isAtEnd()) { @@ -871,7 +873,7 @@ Statement::StatementList parseBlockItemList(const std::vector &tokens, Er //--------------------------------------------------------------------------- const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, ErrorHandlerBase &errorHandler) { - ParserState parserState(tokens, scalarType, errorHandler); + ParserState parserState(tokens, errorHandler); bool pointerFound = false; std::set typeSpecifiers; while(parserState.match({Token::Type::TYPE_SPECIFIER, Token::Type::STAR})) { diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc index 8acc0285f4..6f13db090c 100644 --- a/src/genn/genn/transpiler/scanner.cc +++ b/src/genn/genn/transpiler/scanner.cc @@ -54,7 +54,11 @@ const std::unordered_map keywords{ {"uint32_t", Token::Type::TYPE_SPECIFIER}, {"int32_t", Token::Type::TYPE_SPECIFIER}, {"bool", Token::Type::TYPE_SPECIFIER}}; - +//--------------------------------------------------------------------------- +const std::map, Token::Type> integerLiteralTokenTypes{ + {{}, Token::Type::INT32_NUMBER}, + {{'U'}, 
Token::Type::UINT32_NUMBER} }; //--------------------------------------------------------------------------- // ScanState //--------------------------------------------------------------------------- class ScanState @@ -154,12 +158,14 @@ void emplaceToken(std::vector<Token> &tokens, Token::Type type, const ScanState tokens.emplace_back(type, scanState.getLexeme(), scanState.getLine()); } //--------------------------------------------------------------------------- -void scanIntegerSuffix(ScanState &scanState) +Token::Type scanIntegerSuffix(ScanState &scanState) { // Read suffix + std::set<char> suffix; while(std::toupper(scanState.peek()) == 'U' || std::toupper(scanState.peek()) == 'L') { - scanState.advance(); + suffix.insert(std::toupper(scanState.advance())); } + return integerLiteralTokenTypes.at(suffix); } //--------------------------------------------------------------------------- void scanNumber(char c, ScanState &scanState, std::vector<Token> &tokens) @@ -182,8 +188,7 @@ void scanNumber(char c, ScanState &scanState, std::vector<Token> &tokens) } // Add integer token - scanIntegerSuffix(scanState); - emplaceToken(tokens, Token::Type::NUMBER, scanState); + emplaceToken(tokens, scanIntegerSuffix(scanState), scanState); } // Otherwise, if this is an octal integer else if(c == '0' && isodigit(scanState.peek())){ @@ -219,20 +224,25 @@ void scanNumber(char c, ScanState &scanState, std::vector<Token> &tokens) } } - // Read possible floating point suffix + // If number has an f suffix, emplace FLOAT_NUMBER token + if (std::tolower(scanState.peek()) == 'f') { + emplaceToken(tokens, Token::Type::FLOAT_NUMBER, scanState); + scanState.advance(); + } + // Otherwise, if it has a d suffix, emplace DOUBLE_NUMBER token // **NOTE** 'd' is a GeNN extension not standard C - if (std::tolower(scanState.peek()) == 'f' || std::tolower(scanState.peek()) == 'd') { + else if (std::tolower(scanState.peek()) == 'd') { + emplaceToken(tokens, Token::Type::DOUBLE_NUMBER, scanState); scanState.advance(); } - - // Emplace token - emplaceToken(tokens, Token::Type::NUMBER, scanState); + // Otherwise, emplace SCALAR_NUMBER token + else { + emplaceToken(tokens, Token::Type::SCALAR_NUMBER, scanState); + } } - // Otherwise, number is integer + // Otherwise, emplace integer token else { - // Add integer token - scanIntegerSuffix(scanState); - emplaceToken(tokens, Token::Type::NUMBER, scanState); + emplaceToken(tokens, scanIntegerSuffix(scanState), scanState); } } } @@ -427,11 +437,11 @@ void scanToken(ScanState &scanState, std::vector<Token> &tokens) //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Scanner { -std::vector<Token> scanSource(const std::string_view &source, const std::unordered_set<std::string> &typedefNames, ErrorHandlerBase &errorHandler) +std::vector<Token> scanSource(const std::string_view &source, ErrorHandlerBase &errorHandler) { std::vector<Token> tokens; - ScanState scanState(source, typedefNames, errorHandler); + ScanState scanState(source, {"scalar"}, errorHandler); // Scan tokens while(!scanState.isAtEnd()) { diff --git a/src/genn/genn/transpiler/typeChecker.cc index bb17a711b4..8b348edfda 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -1,6 +1,7 @@ #include "transpiler/typeChecker.h" // Standard C++ includes +#include  #include  // Standard C includes @@ -46,25 +47,28 @@ class EnvironmentInternal : public EnvironmentBase } virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType, -
ErrorHandlerBase &errorHandler, bool initializer = false) final + const Type::TypeContext &context, ErrorHandlerBase &errorHandler, + bool initializer = false) final { // If type isn't found auto existingType = m_Types.find(name.lexeme); if(existingType == m_Types.end()) { return m_Enclosing.assign(name, op, assignedType, - errorHandler, initializer); + context, errorHandler, initializer); } // Perform standard type-checking logic - return EnvironmentBase::assign(name, op, existingType->second, assignedType, errorHandler, initializer); + return EnvironmentBase::assign(name, op, existingType->second, assignedType, + context, errorHandler, initializer); } - virtual const Type::Base *incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final + virtual const Type::Base *incDec(const Token &name, Token::Type op, + const Type::TypeContext &context, ErrorHandlerBase &errorHandler) final { // If type isn't found auto existingType = m_Types.find(name.lexeme); if(existingType == m_Types.end()) { - return m_Enclosing.incDec(name, op, errorHandler); + return m_Enclosing.incDec(name, op, context, errorHandler); } // Perform standard type-checking logic @@ -96,8 +100,8 @@ class EnvironmentInternal : public EnvironmentBase class Visitor : public Expression::Visitor, public Statement::Visitor { public: - Visitor(ErrorHandlerBase &errorHandler) - : m_Environment(nullptr), m_Type(nullptr), m_ErrorHandler(errorHandler), + Visitor(const Type::TypeContext &context, ErrorHandlerBase &errorHandler) + : m_Environment(nullptr), m_Type(nullptr), m_Context(context), m_ErrorHandler(errorHandler), m_InLoop(false), m_InSwitch(false) { } @@ -135,9 +139,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // Evaluate pointer type auto indexType = evaluateType(arraySubscript.getIndex().get()); auto indexNumericType = dynamic_cast(indexType); - if (!indexNumericType || !indexNumericType->isIntegral()) { + if (!indexNumericType || !indexNumericType->isIntegral(m_Context)) { m_ErrorHandler.error(arraySubscript.getPointerName(), - "Invalid subscript index type '" + indexType->getName() + "'"); + "Invalid subscript index type '" + indexType->getName(m_Context) + "'"); throw TypeCheckError(); } @@ -154,7 +158,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Assignment &assignment) final { const auto rhsType = evaluateType(assignment.getValue()); - m_Type = m_Environment->assign(assignment.getVarName(), assignment.getOperator().type, rhsType, m_ErrorHandler); + m_Type = m_Environment->assign(assignment.getVarName(), assignment.getOperator().type, rhsType, + m_Context, m_ErrorHandler); } virtual void visit(const Expression::Binary &binary) final @@ -173,8 +178,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto rightPointerType = dynamic_cast(rightType); if (leftPointerType && rightPointerType && opType == Token::Type::MINUS) { // Check pointers are compatible - if (leftPointerType->getName() != rightPointerType->getName()) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); + if (leftPointerType->getName(m_Context) != rightPointerType->getName(m_Context)) { + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName(m_Context) + "' and '" + rightType->getName(m_Context)); throw TypeCheckError(); } @@ -185,8 +190,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor 
else if (leftPointerType && rightNumericType && (opType == Token::Type::PLUS || opType == Token::Type::MINUS)) // P + n or P - n { // Check that numeric operand is integer - if (!rightNumericType->isIntegral()) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); + if (!rightNumericType->isIntegral(m_Context)) { + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName(m_Context) + "' and '" + rightType->getName(m_Context)); throw TypeCheckError(); } @@ -197,8 +202,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor else if (leftNumericType && rightPointerType && opType == Token::Type::PLUS) // n + P { // Check that numeric operand is integer - if (!leftNumericType->isIntegral()) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); + if (!leftNumericType->isIntegral(m_Context)) { + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName(m_Context) + "' and '" + rightType->getName(m_Context)); throw TypeCheckError(); } @@ -213,28 +218,28 @@ class Visitor : public Expression::Visitor, public Statement::Visitor || opType == Token::Type::AMPERSAND || opType == Token::Type::PIPE) { // Check that operands are integers - if (!leftNumericType->isIntegral() || !rightNumericType->isIntegral()) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); + if (!leftNumericType->isIntegral(m_Context) || !rightNumericType->isIntegral(m_Context)) { + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName(m_Context) + "' and '" + rightType->getName(m_Context)); throw TypeCheckError(); } // If operator is a shift, promote left type if (opType == Token::Type::SHIFT_LEFT || opType == Token::Type::SHIFT_RIGHT) { - m_Type = Type::getPromotedType(leftNumericType); + m_Type = Type::getPromotedType(leftNumericType, m_Context); } // Otherwise, take common type else { - m_Type = Type::getCommonType(leftNumericType, rightNumericType); + m_Type = Type::getCommonType(leftNumericType, rightNumericType, m_Context); } } // Otherwise, any numeric type will do, take common type else { - m_Type = Type::getCommonType(leftNumericType, rightNumericType); + m_Type = Type::getCommonType(leftNumericType, rightNumericType, m_Context); } } else { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName(m_Context) + "' and '" + rightType->getName(m_Context)); throw TypeCheckError(); } } @@ -283,7 +288,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If const is being removed if (rightType->hasQualifier(Type::Qualifier::CONSTANT) && !cast.getType()->hasQualifier(Type::Qualifier::CONSTANT)) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName()); + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName(m_Context) + "' and '" + rightType->getName(m_Context)); throw TypeCheckError(); } @@ -293,14 +298,14 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto leftNumericType = dynamic_cast(cast.getType()); auto leftPointerType = 
dynamic_cast(cast.getType()); if (rightPointerType && leftPointerType) { - if (rightPointerType->getName() != leftPointerType->getName()) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName()); + if (rightPointerType->getName(m_Context) != leftPointerType->getName(m_Context)) { + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName(m_Context) + "' and '" + rightType->getName(m_Context)); throw TypeCheckError(); } } // Otherwise, if either operand isn't numeric else if(!leftNumericType | !rightNumericType) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName()); + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName(m_Context) + "' and '" + rightType->getName(m_Context)); throw TypeCheckError(); } @@ -315,14 +320,14 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto falseNumericType = dynamic_cast(falseType); if (trueNumericType && falseNumericType) { // **TODO** check behaviour - m_Type = Type::getCommonType(trueNumericType, falseNumericType); + m_Type = Type::getCommonType(trueNumericType, falseNumericType, m_Context); if(trueType->hasQualifier(Type::Qualifier::CONSTANT) || falseType->hasQualifier(Type::Qualifier::CONSTANT)) { m_Type = m_Type->getQualifiedType(Type::Qualifier::CONSTANT); } } else { m_ErrorHandler.error(conditional.getQuestion(), - "Invalid operand types '" + trueType->getName() + "' and '" + falseType->getName() + "' to conditional"); + "Invalid operand types '" + trueType->getName(m_Context) + "' and '" + falseType->getName(m_Context) + "' to conditional"); throw TypeCheckError(); } } @@ -334,10 +339,27 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Literal &literal) final { - m_Type = std::visit(Utils::Overload{ - [](auto v)->const Type::NumericBase *{ return Type::TypeTraits::NumericType::getInstance(); }, - [](std::monostate)->const Type::NumericBase *{ return nullptr; }}, - literal.getValue()); + // Convert number token type to type + // **THINK** is it better to use typedef for scalar or resolve from m_Context + if (literal.getValue().type == Token::Type::DOUBLE_NUMBER) { + m_Type = Type::Double::getInstance(); + } + else if (literal.getValue().type == Token::Type::FLOAT_NUMBER) { + m_Type = Type::Double::getInstance(); + } + else if (literal.getValue().type == Token::Type::SCALAR_NUMBER) { + // **TODO** cache + m_Type = new Type::NumericTypedef("scalar"); + } + else if (literal.getValue().type == Token::Type::INT32_NUMBER) { + m_Type = Type::Int32::getInstance(); + } + else if (literal.getValue().type == Token::Type::UINT32_NUMBER) { + m_Type = Type::Uint32::getInstance(); + } + else { + assert(false); + } } virtual void visit(const Expression::Logical &logical) final @@ -349,14 +371,14 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final { - m_Type = m_Environment->incDec(postfixIncDec.getVarName(), - postfixIncDec.getOperator().type, m_ErrorHandler); + m_Type = m_Environment->incDec(postfixIncDec.getVarName(), postfixIncDec.getOperator().type, + m_Context, m_ErrorHandler); } virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final { - m_Type = m_Environment->incDec(prefixIncDec.getVarName(), - 
prefixIncDec.getOperator().type, m_ErrorHandler); + m_Type = m_Environment->incDec(prefixIncDec.getVarName(), prefixIncDec.getOperator().type, + m_Context, m_ErrorHandler); } virtual void visit(const Expression::Variable &variable) @@ -373,7 +395,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto rightPointerType = dynamic_cast(rightType); if (!rightPointerType) { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getName() + "'"); + "Invalid operand type '" + rightType->getName(m_Context) + "'"); throw TypeCheckError(); } @@ -387,18 +409,18 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If operator is arithmetic, return promoted type if (unary.getOperator().type == Token::Type::PLUS || unary.getOperator().type == Token::Type::MINUS) { // **THINK** const through these? - m_Type = Type::getPromotedType(rightNumericType); + m_Type = Type::getPromotedType(rightNumericType, m_Context); } // Otherwise, if operator is bitwise else if (unary.getOperator().type == Token::Type::TILDA) { // If type is integer, return promoted type - if (rightNumericType->isIntegral()) { + if (rightNumericType->isIntegral(m_Context)) { // **THINK** const through these? - m_Type = Type::getPromotedType(rightNumericType); + m_Type = Type::getPromotedType(rightNumericType, m_Context); } else { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getName() + "'"); + "Invalid operand type '" + rightType->getName(m_Context) + "'"); throw TypeCheckError(); } } @@ -413,7 +435,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } else { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getName() + "'"); + "Invalid operand type '" + rightType->getName(m_Context) + "'"); throw TypeCheckError(); } } @@ -501,9 +523,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor if (labelled.getValue()) { auto valType = evaluateType(labelled.getValue()); auto valNumericType = dynamic_cast(valType); - if (!valNumericType || !valNumericType->isIntegral()) { + if (!valNumericType || !valNumericType->isIntegral(m_Context)) { m_ErrorHandler.error(labelled.getKeyword(), - "Invalid case value '" + valType->getName() + "'"); + "Invalid case value '" + valType->getName(m_Context) + "'"); throw TypeCheckError(); } } @@ -515,9 +537,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { auto condType = evaluateType(switchStatement.getCondition()); auto condNumericType = dynamic_cast(condType); - if (!condNumericType || !condNumericType->isIntegral()) { + if (!condNumericType || !condNumericType->isIntegral(m_Context)) { m_ErrorHandler.error(switchStatement.getSwitch(), - "Invalid condition '" + condType->getName() + "'"); + "Invalid condition '" + condType->getName(m_Context) + "'"); throw TypeCheckError(); } @@ -537,7 +559,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const auto initialiserType = evaluateType(std::get<1>(var).get()); // Assign initialiser expression to variable - m_Environment->assign(std::get<0>(var), Token::Type::EQUAL, initialiserType, m_ErrorHandler, true); + m_Environment->assign(std::get<0>(var), Token::Type::EQUAL, initialiserType, + m_Context, m_ErrorHandler, true); } } } @@ -570,7 +593,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- EnvironmentInternal *m_Environment; const 
Type::Base *m_Type; - + const Type::TypeContext &m_Context; ErrorHandlerBase &m_ErrorHandler; bool m_InLoop; bool m_InSwitch; @@ -582,7 +605,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- const Type::Base *EnvironmentBase::assign(const Token &name, Token::Type op, const Type::Base *existingType, const Type::Base *assignedType, - ErrorHandlerBase &errorHandler, bool initializer) const + const Type::TypeContext &context, ErrorHandlerBase &errorHandler, + bool initializer) const { // If existing type is a const qualified and isn't being initialized, give error if(!initializer && existingType->hasQualifier(Type::Qualifier::CONSTANT)) { @@ -600,19 +624,19 @@ const Type::Base *EnvironmentBase::assign(const Token &name, Token::Type op, if (pointerAssignedType && pointerExistingType) { // If we're trying to assign a pointer to a const value to a pointer if (assignedType->hasQualifier(Type::Qualifier::CONSTANT) && !existingType->hasQualifier(Type::Qualifier::CONSTANT)) { - errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName() + "' and '" + pointerAssignedType->getName()); + errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName(context) + "' and '" + pointerAssignedType->getName(context)); throw TypeCheckError(); } // If pointer types aren't compatible - if (pointerExistingType->getName() != pointerAssignedType->getName()) { - errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName() + "' and '" + pointerAssignedType->getName()); + if (pointerExistingType->getName(context) != pointerAssignedType->getName(context)) { + errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName(context) + "' and '" + pointerAssignedType->getName(context)); throw TypeCheckError(); } } // Otherwise, if we're trying to initialise a pointer with a non-pointer or vice-versa else if (pointerAssignedType || pointerExistingType) { - errorHandler.error(name, "Invalid operand types '" + existingType->getName() + "' and '" + assignedType->getName()); + errorHandler.error(name, "Invalid operand types '" + existingType->getName(context) + "' and '" + assignedType->getName(context)); throw TypeCheckError(); } } @@ -621,13 +645,13 @@ const Type::Base *EnvironmentBase::assign(const Token &name, Token::Type op, // If the operand being added isn't numeric or the type being added to is neither numeric or a pointer if (!numericAssignedType || (!pointerExistingType && !numericExistingType)) { - errorHandler.error(name, "Invalid operand types '" + existingType->getName() + "' and '" + assignedType->getName() + "'"); + errorHandler.error(name, "Invalid operand types '" + existingType->getName(context) + "' and '" + assignedType->getName(context) + "'"); throw TypeCheckError(); } // If we're adding a numeric type to a pointer, check it's an integer - if (pointerExistingType && numericAssignedType->isIntegral()) { - errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName() + "'"); + if (pointerExistingType && numericAssignedType->isIntegral(context)) { + errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName(context) + "'"); throw TypeCheckError(); } } @@ -635,22 +659,22 @@ const Type::Base *EnvironmentBase::assign(const Token &name, Token::Type op, else { // If either type is non-numeric, give error if(!numericAssignedType) { - errorHandler.error(name, "Invalid operand types '" + 
numericAssignedType->getName() + "'"); + errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName(context) + "'"); throw TypeCheckError(); } if(!numericExistingType) { - errorHandler.error(name, "Invalid operand types '" + existingType->getName() + "'"); + errorHandler.error(name, "Invalid operand types '" + existingType->getName(context) + "'"); throw TypeCheckError(); } // If operand isn't one that takes any numeric type, check both operands are integral if (op != Token::Type::STAR_EQUAL && op != Token::Type::SLASH_EQUAL) { - if(!numericAssignedType->isIntegral()) { - errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName() + "'"); + if(!numericAssignedType->isIntegral(context)) { + errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName(context) + "'"); throw TypeCheckError(); } - if(!numericExistingType->isIntegral()) { - errorHandler.error(name, "Invalid operand types '" + numericExistingType->getName() + "'"); + if(!numericExistingType->isIntegral(context)) { + errorHandler.error(name, "Invalid operand types '" + numericExistingType->getName(context) + "'"); throw TypeCheckError(); } } @@ -679,18 +703,17 @@ const Type::Base *EnvironmentBase::incDec(const Token &name, Token::Type, // GeNN::Transpiler::TypeChecker //--------------------------------------------------------------------------- void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment, - ErrorHandlerBase &errorHandler) + const Type::TypeContext &context, ErrorHandlerBase &errorHandler) { - Visitor visitor(errorHandler); + Visitor visitor(context, errorHandler); EnvironmentInternal internalEnvironment(environment); visitor.typeCheck(statements, internalEnvironment); } //--------------------------------------------------------------------------- -const Type::Base *GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, - EnvironmentBase &environment, - ErrorHandlerBase &errorHandler) +const Type::Base *GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, EnvironmentBase &environment, + const Type::TypeContext &context, ErrorHandlerBase &errorHandler) { - Visitor visitor(errorHandler); + Visitor visitor(context, errorHandler); EnvironmentInternal internalEnvironment(environment); return visitor.typeCheck(expression, internalEnvironment); } diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index 9d641339b4..2f9aa118df 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -160,7 +160,7 @@ const NumericBase *parseNumeric(std::string_view typeString) // Scan type SingleLineErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(typeString, {"scalar"}, errorHandler); + const auto tokens = Scanner::scanSource(typeString, errorHandler); // Parse type and cast to numeric const auto *type = dynamic_cast<const NumericBase *>(Parser::parseType(tokens, false, errorHandler)); @@ -183,7 +183,7 @@ const Pointer *parseNumericPtr(std::string_view typeString) // Scan type SingleLineErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(typeString, {"scalar"}, errorHandler); + const auto tokens = Scanner::scanSource(typeString, errorHandler); // Parse type and cast to numeric pointer const auto *type = dynamic_cast<const Pointer *>(Parser::parseType(tokens, true, errorHandler)); From 0f174268cbac00ab04e53ea85df4371ce5b8bb4a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 18 Jan 2023 12:15:01 +0000 Subject: [PATCH 065/725] started hacking out
string types * Model variable and variable reference type is now a Type::NumericBase - can be parsed from string * ModelSpec precision and time precision are now Type::NumericBase --- .../genn/genn/code_generator/backendBase.h | 24 ++++---- .../customConnectivityUpdateGroupMerged.h | 6 +- .../code_generator/customUpdateGroupMerged.h | 18 +++--- .../genn/genn/code_generator/groupMerged.h | 21 ++++--- .../groupMergedTypeEnvironment.h | 8 +-- .../genn/code_generator/initGroupMerged.h | 28 +++++----- .../code_generator/neuronUpdateGroupMerged.h | 4 +- .../genn/code_generator/supportCodeMerged.h | 6 +- .../code_generator/synapseUpdateGroupMerged.h | 8 +-- include/genn/genn/modelSpec.h | 15 ++--- include/genn/genn/models.h | 30 +++++----- include/genn/genn/type.h | 47 ++++++++++++---- src/genn/backends/cuda/backend.cc | 2 +- src/genn/backends/opencl/backend.cc | 2 +- .../backends/single_threaded_cpu/backend.cc | 2 +- .../customConnectivityUpdateGroupMerged.cc | 20 +++---- .../code_generator/customUpdateGroupMerged.cc | 12 ++-- .../genn/code_generator/generateRunner.cc | 18 +++--- src/genn/genn/code_generator/groupMerged.cc | 55 ++++++++----------- .../genn/code_generator/initGroupMerged.cc | 44 +++++++-------- .../code_generator/neuronUpdateGroupMerged.cc | 34 ++++++------ .../synapseUpdateGroupMerged.cc | 12 ++-- src/genn/genn/modelSpec.cc | 50 +---------------- src/genn/genn/models.cc | 35 +++++++++++- src/genn/genn/transpiler/typeChecker.cc | 54 +++++++++--------- src/genn/genn/type.cc | 16 +++--- 26 files changed, 285 insertions(+), 286 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 11423f6785..88e91a0121 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -243,9 +243,9 @@ class GENN_EXPORT BackendBase //! After all timestep logic is complete virtual void genStepTimeFinalisePreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const = 0; - virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, const std::string &type, const std::string &name, VarLocation loc) const = 0; - virtual void genVariableImplementation(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc) const = 0; - virtual void genVariableAllocation(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const = 0; + virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, const Type::Base *type, const std::string &name, VarLocation loc) const = 0; + virtual void genVariableImplementation(CodeStream &os, const Type::Base *type, const std::string &name, VarLocation loc) const = 0; + virtual void genVariableAllocation(CodeStream &os, const Type::Base *type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const = 0; virtual void genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const = 0; virtual void genExtraGlobalParamDefinition(CodeStream &definitions, CodeStream &definitionsInternal, const std::string &type, const std::string &name, VarLocation loc) const = 0; @@ -277,17 +277,17 @@ class GENN_EXPORT BackendBase virtual void genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUpdateInitGroupMerged &cu, const Substitutions &kernelSubs, Handler handler) const = 0; //! 
Generate code for pushing a variable to the 'device' - virtual void genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const = 0; + virtual void genVariablePush(CodeStream &os, const Type::Base *type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const = 0; //! Generate code for pulling a variable from the 'device' - virtual void genVariablePull(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, size_t count) const = 0; + virtual void genVariablePull(CodeStream &os, const Type::Base *type, const std::string &name, VarLocation loc, size_t count) const = 0; //! Generate code for pushing a variable's value in the current timestep to the 'device' - virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, const std::string &type, + virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, const Type::Base *type, const std::string &name, VarLocation loc, unsigned int batchSize) const = 0; //! Generate code for pulling a variable's value in the current timestep from the 'device' - virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, const std::string &type, + virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, const Type::Base *type, const std::string &name, VarLocation loc, unsigned int batchSize) const = 0; //! Generate code for pushing true spikes emitted by a neuron group in the current timestep to the 'device' @@ -402,14 +402,14 @@ class GENN_EXPORT BackendBase //-------------------------------------------------------------------------- //! Helper function to generate matching push and pull functions for a variable void genVariablePushPull(CodeStream &push, CodeStream &pull, - const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const + const Type::Base *type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const { genVariablePush(push, type, name, loc, autoInitialized, count); genVariablePull(pull, type, name, loc, count); } //! Helper function to generate matching push and pull functions for the current state of a variable - void genCurrentVariablePushPull(CodeStream &push, CodeStream &pull, const NeuronGroupInternal &ng, const std::string &type, + void genCurrentVariablePushPull(CodeStream &push, CodeStream &pull, const NeuronGroupInternal &ng, const Type::Base *type, const std::string &name, VarLocation loc, unsigned int batchSize) const { genCurrentVariablePush(push, ng, type, name, loc, batchSize); @@ -418,10 +418,10 @@ class GENN_EXPORT BackendBase //! 
Helper function to generate matching definition, declaration, allocation and free code for an array void genArray(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const std::string &type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const + const Type::Base *type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const { - genVariableDefinition(definitions, definitionsInternal, type + "*", name, loc); - genVariableImplementation(runner, type + "*", name, loc); + genVariableDefinition(definitions, definitionsInternal, type->getPointerType(), name, loc); + genVariableImplementation(runner, type->getPointerType(), name, loc); genVariableFree(free, name, loc); genVariableAllocation(allocations, type, name, loc, count, memAlloc); } diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index 3474e10331..806784508f 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -15,7 +15,7 @@ namespace GeNN::CodeGenerator class GENN_EXPORT CustomConnectivityUpdateGroupMergedBase : public GroupMerged { public: - CustomConnectivityUpdateGroupMergedBase(size_t index, const std::string &precision, + CustomConnectivityUpdateGroupMergedBase(size_t index, const Type::NumericBase *precision, const std::vector> &groups); protected: @@ -32,7 +32,7 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMergedBase : public GroupMerged> &groups); //---------------------------------------------------------------------------- @@ -75,7 +75,7 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public CustomConnectivit class GENN_EXPORT CustomConnectivityHostUpdateGroupMerged : public CustomConnectivityUpdateGroupMergedBase { public: - CustomConnectivityHostUpdateGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, + CustomConnectivityHostUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- // Public API diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index b57a0341e6..23a9263969 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -12,7 +12,7 @@ namespace GeNN::CodeGenerator class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged { public: - CustomUpdateGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, + CustomUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -79,7 +79,7 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged> &groups); private: @@ -95,7 +95,7 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged> &groups) : CustomUpdateWUGroupMergedBase(index, precision, timePrecision, backend, groups) { @@ -127,7 +127,7 @@ class GENN_EXPORT CustomUpdateWUGroupMerged : public 
CustomUpdateWUGroupMergedBa class GENN_EXPORT CustomUpdateTransposeWUGroupMerged : public CustomUpdateWUGroupMergedBase { public: - CustomUpdateTransposeWUGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, + CustomUpdateTransposeWUGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector<std::reference_wrapper<const CustomUpdateWUInternal>> &groups) : CustomUpdateWUGroupMergedBase(index, precision, timePrecision, backend, groups) { @@ -161,7 +161,7 @@ template<typename G> class CustomUpdateHostReductionGroupMergedBase : public GroupMerged<G> { protected: - CustomUpdateHostReductionGroupMergedBase(size_t index, const std::string &precision, const BackendBase &backend, + CustomUpdateHostReductionGroupMergedBase(size_t index, const Type::NumericBase *precision, const BackendBase &backend, const std::vector<std::reference_wrapper<const G>> &groups) : GroupMerged<G>(index, precision, groups) { @@ -173,14 +173,14 @@ class CustomUpdateHostReductionGroupMergedBase : public GroupMerged const CustomUpdateModels::Base *cm = this->getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { if(v.access & VarAccessModeAttribute::REDUCE) { - this->addPointerField(Type::parseNumeric(v.type), v.name, backend.getDeviceVarPrefix() + v.name); + this->addPointerField(v.type, v.name, backend.getDeviceVarPrefix() + v.name); } } // Loop through variable references and add pointers if they are reduction targets for(const auto &v : cm->getVarRefs()) { if(v.access & VarAccessModeAttribute::REDUCE) { - this->addPointerField(Type::parseNumeric(v.type), v.name, backend.getDeviceVarPrefix() + v.name); + this->addPointerField(v.type, v.name, backend.getDeviceVarPrefix() + v.name); } } } @@ -192,7 +192,7 @@ class CustomUpdateHostReductionGroupMergedBase : public GroupMerged class GENN_EXPORT CustomUpdateHostReductionGroupMerged : public CustomUpdateHostReductionGroupMergedBase<CustomUpdateInternal> { public: - CustomUpdateHostReductionGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, + CustomUpdateHostReductionGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector<std::reference_wrapper<const CustomUpdateInternal>> &groups); //------------------------------------------------------------------------ @@ -219,7 +219,7 @@ class GENN_EXPORT CustomUpdateHostReductionGroupMerged : public CustomUpdateHost class GENN_EXPORT CustomWUUpdateHostReductionGroupMerged : public CustomUpdateHostReductionGroupMergedBase<CustomUpdateWUInternal> { public: - CustomWUUpdateHostReductionGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, + CustomWUUpdateHostReductionGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector<std::reference_wrapper<const CustomUpdateWUInternal>> &groups); //------------------------------------------------------------------------ diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 331f40d1ec..378c891933 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -65,9 +65,8 @@ class GroupMerged typedef std::function<std::string(const G &, size_t)> GetFieldValueFunc; typedef std::tuple<const Type::Base*, std::string, GetFieldValueFunc, GroupMergedFieldType> Field; - // **HACK** type should come in as type not string - GroupMerged(size_t index, const std::string &precision, const std::vector<std::reference_wrapper<const G>> groups) : m_Index(index), m_LiteralSuffix((precision == "float") ? "f" : ""), m_ScalarType(Type::parseNumeric(precision)), m_Groups(std::move(groups)) + GroupMerged(size_t index, const Type::NumericBase *precision, const std::vector<std::reference_wrapper<const G>> groups) : m_Index(index), m_LiteralSuffix((precision->getName() == "float") ? "f" : ""), m_ScalarType(precision), m_Groups(std::move(groups)) {} //------------------------------------------------------------------------ @@ -122,12 +121,12 @@ class GroupMerged } // Otherwise, allow the backend to add a prefix else { - os << backend.getPointerPrefix() << type->getName(context); + os << backend.getPointerPrefix() << type->getResolvedName(context); } } // Otherwise, leave the type alone else { - os << type->getName(context); + os << type->getResolvedName(context); } os << " " << std::get<1>(f) << ";" << std::endl; } @@ -281,7 +280,7 @@ class GroupMerged { // Loop through variables for(const auto &v : vars) { - addPointerField(Type::parseNumeric(v.type), v.name, arrayPrefix + v.name); + addPointerField(v.type, v.name, arrayPrefix + v.name); } } @@ -290,7 +289,7 @@ class GroupMerged { // Loop through variables for(const auto &v : varReferences) { - addField(Type::parseNumeric(v.type)->getPointerType(), v.name, + addField(v.type->getPointerType(), v.name, [getVarRefFn, arrayPrefix, v](const G &g, size_t) { const auto varRef = getVarRefFn(g).at(v.name); @@ -555,7 +554,7 @@ class GroupMerged class GENN_EXPORT NeuronSpikeQueueUpdateGroupMerged : public GroupMerged<NeuronGroupInternal> { public: - NeuronSpikeQueueUpdateGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, + NeuronSpikeQueueUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector<std::reference_wrapper<const NeuronGroupInternal>> &groups); //------------------------------------------------------------------------ @@ -584,7 +583,7 @@ class GENN_EXPORT NeuronSpikeQueueUpdateGroupMerged : public GroupMerged { public: - NeuronPrevSpikeTimeUpdateGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, + NeuronPrevSpikeTimeUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector<std::reference_wrapper<const NeuronGroupInternal>> &groups); //------------------------------------------------------------------------ @@ -663,7 +662,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged<NeuronGroupInternal> protected: - NeuronGroupMergedBase(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, bool init, const std::vector<std::reference_wrapper<const NeuronGroupInternal>> &groups); + NeuronGroupMergedBase(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, bool init, const std::vector<std::reference_wrapper<const NeuronGroupInternal>> &groups); void updateBaseHash(bool init, boost::uuids::detail::sha1 &hash) const; @@ -1074,7 +1073,7 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged<SynapseGroupInternal> protected: - SynapseGroupMergedBase(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, Role role, const std::string &archetypeCode, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups); + SynapseGroupMergedBase(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, Role role, const std::string &archetypeCode, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups); //---------------------------------------------------------------------------- diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index 30f650f605..b8890d44ea 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -188,7 +188,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa { // Loop through variables for(const auto &v : vars) { - definePointerField(Type::parseNumeric(v.type), v.name, arrayPrefix, getVarAccessMode(v.access)); + definePointerField(v.type, v.name, arrayPrefix, getVarAccessMode(v.access)); } } @@ -197,12 +197,10 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa { // Loop through variables for(const auto &v : varReferences) { - const auto *type = Type::parseNumeric(v.type); -
// If variable access is read-only, qualify type with const - const auto *qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? type->getQualifiedType(Type::Qualifier::CONSTANT) : type; + const auto *qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? v.type->getQualifiedType(Type::Qualifier::CONSTANT) : v.type; defineField(qualifiedType, v.name, - type->getPointerType(), v.name, + v.type->getPointerType(), v.name, [arrayPrefix, getVarRefFn, v](const auto &g, size_t) { const auto varRef = getVarRefFn(g).at(v.name); diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index 007b3c6d4f..88c0e48083 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -11,7 +11,7 @@ namespace GeNN::CodeGenerator class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase { public: - NeuronInitGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, + NeuronInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -96,7 +96,7 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase class GENN_EXPORT SynapseInitGroupMerged : public SynapseGroupMergedBase { public: - SynapseInitGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, + SynapseInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups) : SynapseGroupMergedBase(index, precision, timePrecision, backend, SynapseGroupMergedBase::Role::Init, "", groups) {} @@ -129,7 +129,7 @@ class GENN_EXPORT SynapseInitGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT SynapseSparseInitGroupMerged : public SynapseGroupMergedBase { public: - SynapseSparseInitGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, + SynapseSparseInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups) : SynapseGroupMergedBase(index, precision, timePrecision, backend, SynapseGroupMergedBase::Role::SparseInit, "", groups) {} @@ -162,7 +162,7 @@ class GENN_EXPORT SynapseSparseInitGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT SynapseConnectivityInitGroupMerged : public SynapseGroupMergedBase { public: - SynapseConnectivityInitGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, + SynapseConnectivityInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups) : SynapseGroupMergedBase(index, precision, timePrecision, backend, SynapseGroupMergedBase::Role::ConnectivityInit, "", groups) {} @@ -195,7 +195,7 @@ class GENN_EXPORT SynapseConnectivityInitGroupMerged : public SynapseGroupMerged // Private methods //---------------------------------------------------------------------------- //! 
Generate either row or column connectivity init code - void genInitConnectivity(CodeStream &os, Substitutions &popSubs, const std::string &ftype, bool rowNotColumns) const; + void genInitConnectivity(CodeStream &os, Substitutions &popSubs, const Type::NumericBase *scalarType, bool rowNotColumns) const; }; @@ -205,7 +205,7 @@ class GENN_EXPORT SynapseConnectivityInitGroupMerged : public SynapseGroupMerged class GENN_EXPORT SynapseConnectivityHostInitGroupMerged : public GroupMerged { public: - SynapseConnectivityHostInitGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, + SynapseConnectivityHostInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups); //------------------------------------------------------------------------ @@ -249,7 +249,7 @@ template class CustomUpdateInitGroupMergedBase : public GroupMerged { protected: - CustomUpdateInitGroupMergedBase(size_t index, const std::string &precision, const BackendBase &backend, + CustomUpdateInitGroupMergedBase(size_t index, const Type::NumericBase *precision, const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, precision, groups) { @@ -259,7 +259,7 @@ class CustomUpdateInitGroupMergedBase : public GroupMerged // If we're not initialising or if there is initialization code for this variable const auto &varInit = archetypeAdaptor.getVarInitialisers().at(var.name); if (!varInit.getSnippet()->getCode().empty()) { - this->addPointerField(Type::parseNumeric(var.type), var.name, backend.getDeviceVarPrefix() + var.name); + this->addPointerField(var.type, var.name, backend.getDeviceVarPrefix() + var.name); } // Add any var init EGPs to structure @@ -340,7 +340,7 @@ class CustomUpdateInitGroupMergedBase : public GroupMerged class GENN_EXPORT CustomUpdateInitGroupMerged : public CustomUpdateInitGroupMergedBase { public: - CustomUpdateInitGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, + CustomUpdateInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -373,7 +373,7 @@ class GENN_EXPORT CustomWUUpdateInitGroupMerged : public CustomUpdateInitGroupMe CustomUpdateVarAdapter> { public: - CustomWUUpdateInitGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, + CustomWUUpdateInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -432,7 +432,7 @@ class GENN_EXPORT CustomWUUpdateSparseInitGroupMerged : public CustomUpdateInitG CustomUpdateVarAdapter> { public: - CustomWUUpdateSparseInitGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, + CustomWUUpdateSparseInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -464,7 +464,7 @@ class GENN_EXPORT CustomConnectivityUpdatePreInitGroupMerged : public CustomUpda CustomConnectivityUpdatePreVarAdapter> { 
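// Editorial aside, not part of the patch: every initialiser class in this header repeats the
// same migration -- variable types now arrive as Type::NumericBase* objects, so fields are
// registered from v.type directly instead of re-parsing a type string. A minimal sketch of
// the before/after, assuming a Type::Float singleton alongside the Type::Double/Int32 ones
// visible in this series; SketchVar and sketchDeclareVar() are hypothetical names, purely
// for illustration:
#include <string> // for the sketch below
struct SketchVar { std::string name; const GeNN::Type::NumericBase *type; };
inline SketchVar sketchDeclareVar()
{
    // Before this series the type travelled as a string ({"V", "scalar"}) and was parsed
    // at every use site; now it is a shared singleton type object, parsed at most once
    return {"V", GeNN::Type::Float::getInstance()};
}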
public: - CustomConnectivityUpdatePreInitGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, + CustomConnectivityUpdatePreInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -496,7 +496,7 @@ class GENN_EXPORT CustomConnectivityUpdatePostInitGroupMerged : public CustomUpd CustomConnectivityUpdatePostVarAdapter> { public: - CustomConnectivityUpdatePostInitGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, + CustomConnectivityUpdatePostInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -528,7 +528,7 @@ class GENN_EXPORT CustomConnectivityUpdateSparseInitGroupMerged : public CustomU CustomConnectivityUpdateVarAdapter> { public: - CustomConnectivityUpdateSparseInitGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, + CustomConnectivityUpdateSparseInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index 383a689e67..ef173dad79 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -11,7 +11,7 @@ namespace GeNN::CodeGenerator class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase { public: - NeuronUpdateGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, + NeuronUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups); //------------------------------------------------------------------------ @@ -83,7 +83,7 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase void addNeuronModelSubstitutions(Substitutions &substitution, const std::string &sourceSuffix = "", const std::string &destSuffix = "") const; void generateWUVarUpdate(CodeStream &os, const Substitutions &popSubs, - const std::string &fieldPrefixStem, const std::string &precision, const std::string &sourceSuffix, + const std::string &fieldPrefixStem, const std::string &sourceSuffix, bool useLocalNeuronVars, unsigned int batchSize, const std::vector &archetypeSyn, unsigned int(SynapseGroupInternal::*getDelaySteps)(void) const, diff --git a/include/genn/genn/code_generator/supportCodeMerged.h b/include/genn/genn/code_generator/supportCodeMerged.h index 3a61f3d5c9..4afcedbee2 100644 --- a/include/genn/genn/code_generator/supportCodeMerged.h +++ b/include/genn/genn/code_generator/supportCodeMerged.h @@ -44,7 +44,7 @@ class SupportCodeMerged } //! 
Generate support code - void gen(CodeStream &os, const std::string &ftype, const bool supportsNamespace = true) const + void gen(CodeStream &os, const Type::NumericBase *scalarType, const bool supportsNamespace = true) const { // Loop through support code for(const auto &s : m_SupportCode) { @@ -53,13 +53,13 @@ class SupportCodeMerged os << "namespace " << s.second; { CodeStream::Scope b(os); - os << ensureFtype(s.first, ftype) << std::endl; + os << s.first << std::endl; } } diff --git a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h @@ ... @@ class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase { public: - PresynapticUpdateGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, + PresynapticUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups) : SynapseGroupMergedBase(index, precision, timePrecision, backend, SynapseGroupMergedBase::Role::PresynapticUpdate, groups.front().get().getWUModel()->getSimCode() + groups.front().get().getWUModel()->getEventCode() + groups.front().get().getWUModel()->getEventThresholdConditionCode(), groups) @@ -49,7 +49,7 @@ class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT PostsynapticUpdateGroupMerged : public SynapseGroupMergedBase { public: - PostsynapticUpdateGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, + PostsynapticUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups) : SynapseGroupMergedBase(index, precision, timePrecision, backend, SynapseGroupMergedBase::Role::PostsynapticUpdate, groups.front().get().getWUModel()->getLearnPostCode(), groups) @@ -83,7 +83,7 @@ class GENN_EXPORT PostsynapticUpdateGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase { public: - SynapseDynamicsGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, + SynapseDynamicsGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups) : SynapseGroupMergedBase(index, precision, timePrecision, backend, SynapseGroupMergedBase::Role::SynapseDynamics, groups.front().get().getWUModel()->getSynapseDynamicsCode(), groups) @@ -117,7 +117,7 @@ class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT SynapseDendriticDelayUpdateGroupMerged : public GroupMerged<SynapseGroupInternal> { public: - SynapseDendriticDelayUpdateGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, + SynapseDendriticDelayUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &group); //------------------------------------------------------------------------ diff --git a/include/genn/genn/modelSpec.h b/include/genn/genn/modelSpec.h index fda88f416b..6d398c378d 100644 --- a/include/genn/genn/modelSpec.h +++ b/include/genn/genn/modelSpec.h @@ -231,10 +231,10 @@ class GENN_EXPORT ModelSpec void setName(const std::string &name){ m_Name = name; } //! Set numerical precision for floating point - void setPrecision(ScalarPrecision scalarPrecision); + void setPrecision(const Type::NumericBase *precision){ m_Precision = precision; } //! Set numerical precision for time - void setTimePrecision(TimePrecision timePrecision){ m_TimePrecision = timePrecision; } + void setTimePrecision(const Type::NumericBase *timePrecision){ m_TimePrecision = timePrecision; } //!
Set the integration step size of the model void setDT(double dt){ m_DT = dt; } @@ -278,10 +278,10 @@ class GENN_EXPORT ModelSpec const std::string &getName() const{ return m_Name; } //! Gets the floating point numerical precision - const std::string &getPrecision() const{ return m_Precision; } + const Type::NumericBase *getPrecision() const{ return m_Precision; } //! Gets the floating point numerical precision used to represent time - std::string getTimePrecision() const; + const Type::NumericBase *getTimePrecision() const{ return m_TimePrecision ? m_TimePrecision : m_Precision; } //! Gets the model integration step size double getDT() const { return m_DT; } @@ -665,9 +665,6 @@ class GENN_EXPORT ModelSpec //-------------------------------------------------------------------------- // Protected const methods //-------------------------------------------------------------------------- - //! Get the string literal that should be used to represent a value in the model's floating-point type - std::string scalarExpr(double) const; - //! Are any variables in any populations in this model using zero-copy memory? bool zeroCopyInUse() const; @@ -731,10 +728,10 @@ class GENN_EXPORT ModelSpec std::string m_Name; //! Type of floating point variables (float, double, ...; default: float) - std::string m_Precision; + const Type::NumericBase *m_Precision; //! Type of floating point variables used to store time - TimePrecision m_TimePrecision; + const Type::NumericBase *m_TimePrecision; //! The integration time step of the model double m_DT; diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 95d13fb8e3..659a248a30 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -28,6 +28,10 @@ namespace CodeGenerator { class BackendBase; } +namespace Type +{ +class NumericBase; +} } //---------------------------------------------------------------------------- @@ -54,35 +58,33 @@ class GENN_EXPORT Base : public Snippet::Base if not specified, this results in a -Wmissing-field-initializers warning on GCC and Clang*/ struct Var { - Var(const std::string &n, const std::string &t, VarAccess a) : name(n), type(t), access(a) + Var(const std::string &n, const Type::NumericBase *t, VarAccess a) : name(n), type(t), access(a) {} - Var(const std::string &n, const std::string &t) : Var(n, t, VarAccess::READ_WRITE) + Var(const std::string &n, const Type::NumericBase *t) : Var(n, t, VarAccess::READ_WRITE) {} + Var(const std::string &n, const std::string &t, VarAccess a); + Var(const std::string &n, const std::string &t); - bool operator == (const Var &other) const - { - return (std::tie(name, type, access) == std::tie(other.name, other.type, other.access)); - } + bool operator == (const Var &other) const; const std::string name; - const std::string type; + const Type::NumericBase *type; const VarAccess access; }; struct VarRef { - VarRef(const std::string &n, const std::string &t, VarAccessMode a) : name(n), type(t), access(a) + VarRef(const std::string &n, const Type::NumericBase *t, VarAccessMode a) : name(n), type(t), access(a) {} - VarRef(const std::string &n, const std::string &t) : VarRef(n, t, VarAccessMode::READ_WRITE) + VarRef(const std::string &n, const Type::NumericBase *t) : VarRef(n, t, VarAccessMode::READ_WRITE) {} + VarRef(const std::string &n, const std::string &t, VarAccessMode a); + VarRef(const std::string &n, const std::string &t); - bool operator == (const VarRef &other) const - { - return (std::tie(name, type, access) == std::tie(other.name, other.type, 
other.access)); - } + bool operator == (const VarRef &other) const; const std::string name; - const std::string type; + const Type::NumericBase *type; const VarAccessMode access; }; diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index dd3e78d582..5a279bbb0f 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -39,7 +39,8 @@ { \ DECLARE_TYPE(TYPE) \ TYPE(Qualifier qualifiers = Qualifier{0}) : Numeric(qualifiers){} \ - virtual std::string getName(const TypeContext&) const final{ return #UNDERLYING_TYPE; } \ + virtual std::string getName() const final{ return #UNDERLYING_TYPE; } \ + virtual std::string getResolvedName(const TypeContext&) const final{ return #UNDERLYING_TYPE; } \ virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new TYPE(qualifiers); } \ }; \ template<> \ @@ -103,7 +104,10 @@ class Base // Declared virtuals //------------------------------------------------------------------------ //! Get the (unqualified) name of this type - virtual std::string getName(const TypeContext &context) const = 0; + virtual std::string getName() const = 0; + + //! Get fully-resolved (unqualified) name of this type + virtual std::string getResolvedName(const TypeContext &context) const = 0; //! Get size of this type in bytes virtual size_t getSizeBytes(const TypeContext &context) const = 0; @@ -143,7 +147,8 @@ class Pointer : public Base //------------------------------------------------------------------------ // Base virtuals //------------------------------------------------------------------------ - virtual std::string getName(const TypeContext &context) const{ return getValueType()->getName(context) + "*";} + virtual std::string getName() const{ return getValueType()->getName() + "*";} + virtual std::string getResolvedName(const TypeContext &context) const{ return getValueType()->getResolvedName(context) + "*"; } virtual size_t getSizeBytes(const TypeContext&) const final{ return sizeof(char*); } virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new Pointer(m_ValueType, qualifiers); } @@ -215,9 +220,9 @@ class NumericTypedef : public NumericBase //------------------------------------------------------------------------ // Base virtuals //------------------------------------------------------------------------ - virtual std::string getName(const TypeContext &context) const final; + virtual std::string getName() const final{ return m_Name; } + virtual std::string getResolvedName(const TypeContext &context) const; virtual size_t getSizeBytes(const TypeContext &context) const final; - virtual Base *getQualifiedType(Qualifier qualifiers) const final; //------------------------------------------------------------------------ @@ -269,10 +274,18 @@ class ForeignFunction : public ForeignFunctionBase //------------------------------------------------------------------------ // Base virtuals //------------------------------------------------------------------------ - virtual std::string getName(const TypeContext &context) const final + virtual std::string getName() const final + { + std::string typeName = getReturnType()->getName() + "("; + updateTypeName(typeName); + typeName += ")"; + return typeName; + } + + virtual std::string getResolvedName(const TypeContext &context) const final { - std::string typeName = getReturnType()->getName(context) + "("; - updateTypeName(context, typeName); + std::string typeName = getReturnType()->getResolvedName(context) + "("; + updateResolvedTypeName(context, typeName); typeName += ")"; 
return typeName; } @@ -303,17 +316,29 @@ class ForeignFunction : public ForeignFunctionBase //------------------------------------------------------------------------ // Private methods //------------------------------------------------------------------------ + template<typename T, typename... Args> + static void updateTypeName(std::string &typeName) + { + // Add argument typename to string + typeName += T::getInstance()->getName(); + + // If there are more arguments left in pack, add comma and recurse + if constexpr (sizeof...(Args)) { + typeName += ", "; + updateTypeName<Args...>(typeName); + } + } template<typename T, typename... Args> - static void updateTypeName(const TypeContext &context, std::string &typeName) + static void updateResolvedTypeName(const TypeContext &context, std::string &typeName) { // Add argument typename to string - typeName += T::getInstance()->getName(context); + typeName += T::getInstance()->getResolvedName(context); // If there are more arguments left in pack, add comma and recurse if constexpr (sizeof...(Args)) { typeName += ", "; - updateTypeName<Args...>(context, typeName); + updateResolvedTypeName<Args...>(context, typeName); } } diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 3f912a5b9d..716222589d 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -1770,7 +1770,7 @@ void Backend::genMergedExtraGlobalParamPush(CodeStream &os, const std::string &s //-------------------------------------------------------------------------- std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const { - return type->getName(context); + return type->getResolvedName(context); } //-------------------------------------------------------------------------- const Type::Base *Backend::getMergedGroupSimRNGType() const diff --git a/src/genn/backends/opencl/backend.cc b/src/genn/backends/opencl/backend.cc index 3d794d2335..b958131c1a 100644 --- a/src/genn/backends/opencl/backend.cc +++ b/src/genn/backends/opencl/backend.cc @@ -2084,7 +2084,7 @@ std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, con } // Otherwise, type remains the same else { - return type->getName(context); + return type->getResolvedName(context); } } //-------------------------------------------------------------------------- diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index bc9ccc8ee7..43ab05ae3a 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1350,7 +1350,7 @@ void Backend::genMergedExtraGlobalParamPush(CodeStream &os, const std::string &s //-------------------------------------------------------------------------- std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const { - return type->getName(context); + return type->getResolvedName(context); } //-------------------------------------------------------------------------- const Type::Base *Backend::getMergedGroupSimRNGType() const diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 55e5c2f1d2..f21a43f142 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -12,7 +12,7 @@ using namespace GeNN::CodeGenerator;
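// Editorial aside, not part of the patch: the getName()/getResolvedName() split introduced in
// type.h above is what lets an alias like "scalar" survive until code generation. A minimal
// sketch, under the assumption that Type::TypeContext maps alias names to concrete types
// (which is how NumericTypedef::getResolvedName() is used throughout this series) and that a
// Type::Float singleton exists:
#include <iostream>
void sketchPrintScalar()
{
    using namespace GeNN::Type;
    const NumericTypedef scalar("scalar");
    const TypeContext context{{"scalar", Float::getInstance()}};
    std::cout << scalar.getName() << std::endl;                // prints the alias: "scalar"
    std::cout << scalar.getResolvedName(context) << std::endl; // prints the target: "float"
}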
//---------------------------------------------------------------------------- // CodeGenerator::CustomConnectivityUpdateGroupMergedBase //---------------------------------------------------------------------------- -CustomConnectivityUpdateGroupMergedBase::CustomConnectivityUpdateGroupMergedBase(size_t index, const std::string &precision, +CustomConnectivityUpdateGroupMergedBase::CustomConnectivityUpdateGroupMergedBase(size_t index, const Type::NumericBase *precision, const std::vector<std::reference_wrapper<const CustomConnectivityUpdateInternal>> &groups) : GroupMerged<CustomConnectivityUpdateInternal>(index, precision, groups) { @@ -60,7 +60,7 @@ bool CustomConnectivityUpdateGroupMergedBase::isDerivedParamHeterogeneous(const //---------------------------------------------------------------------------- const std::string CustomConnectivityUpdateGroupMerged::name = "CustomConnectivityUpdate"; //---------------------------------------------------------------------------- -CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, +CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector<std::reference_wrapper<const CustomConnectivityUpdateInternal>> &groups) : CustomConnectivityUpdateGroupMergedBase(index, precision, groups) { @@ -166,7 +166,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t // Loop through sorted dependent variables for(size_t i = 0; i < getSortedArchetypeDependentVars().size(); i++) { - addField(parseNumeric(getSortedArchetypeDependentVars().at(i).getVar().type)->getPointerType(), "_dependentVar" + std::to_string(i), + addField(getSortedArchetypeDependentVars().at(i).getVar().type->getPointerType(), "_dependentVar" + std::to_string(i), [i, &backend, this](const auto&, size_t g) { const auto &varRef = m_SortedDependentVars[g][i]; @@ -418,7 +418,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Apply substitutions to row update code and write out std::string code = cm->getRowUpdateCode(); updateSubs.applyCheckUnreplaced(code, "custom connectivity update : merged" + std::to_string(getIndex())); - code = ensureFtype(code, modelMerged.getModel().getPrecision()); + //code = ensureFtype(code, modelMerged.getModel().getPrecision()); os << code; } @@ -427,7 +427,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back //---------------------------------------------------------------------------- const std::string CustomConnectivityHostUpdateGroupMerged::name = "CustomConnectivityHostUpdate"; //---------------------------------------------------------------------------- -CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, +CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector<std::reference_wrapper<const CustomConnectivityUpdateInternal>> &groups) : CustomConnectivityUpdateGroupMergedBase(index, precision, groups) { @@ -516,7 +516,7 @@ void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase & // Apply substitutions to row update code and write out std::string code = cm->getHostUpdateCode(); subs.applyCheckUnreplaced(code, "custom connectivity host update : merged" + std::to_string(getIndex())); - code = ensureFtype(code, modelMerged.getModel().getPrecision()); +
//code = ensureFtype(code, modelMerged.getModel().getPrecision()); os << code; } } @@ -534,7 +534,7 @@ void CustomConnectivityHostUpdateGroupMerged::addVarPushPullFuncSubs(const Backe // **YUCK** these EGP functions should probably just be called dynamic or something std::stringstream pushStream; CodeStream push(pushStream); - backend.genExtraGlobalParamPush(push, v.type + "*", v.name, + backend.genExtraGlobalParamPush(push, v.type->getPointerType(), v.name, loc, count, "group->"); // Add substitution @@ -544,7 +544,7 @@ void CustomConnectivityHostUpdateGroupMerged::addVarPushPullFuncSubs(const Backe // **YUCK** these EGP functions should probably just be called dynamic or something std::stringstream pullStream; CodeStream pull(pullStream); - backend.genExtraGlobalParamPull(pull, v.type + "*", v.name, + backend.genExtraGlobalParamPull(pull, v.type->getPointerType(), v.name, loc, count, "group->"); // Add substitution @@ -562,13 +562,13 @@ void CustomConnectivityHostUpdateGroupMerged::addVars(const BackendBase &backend for(const auto &v : vars) { // If var is located on the host if (std::invoke(getVarLocationFn, getArchetype(), v.name) & VarLocation::HOST) { - addField(parseNumeric(v.type)->getPointerType(), v.name, + addField(v.type->getPointerType(), v.name, [v](const auto &g, size_t) { return v.name + g.getName(); }, GroupMergedFieldType::HOST); if(!backend.getDeviceVarPrefix().empty()) { // **TODO** I think could use addPointerField - addField(parseNumeric(v.type)->getPointerType(), backend.getDeviceVarPrefix() + v.name, + addField(v.type->getPointerType(), backend.getDeviceVarPrefix() + v.name, [v, &backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + v.name + g.getName(); diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 0e4e5969e2..6e87acef0d 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -111,7 +111,7 @@ void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, //---------------------------------------------------------------------------- const std::string CustomUpdateGroupMerged::name = "CustomUpdate"; //---------------------------------------------------------------------------- -CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const std::string &precision, const std::string&, const BackendBase &backend, +CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector<std::reference_wrapper<const CustomUpdateInternal>> &groups) : GroupMerged<CustomUpdateInternal>(index, precision, groups) { @@ -290,7 +290,7 @@ std::string CustomUpdateWUGroupMergedBase::getVarRefIndex(VarAccessDuplication v return ((varDuplication == VarAccessDuplication::SHARED || !getArchetype().isBatched()) ?
"" : "batchOffset + ") + index; } //---------------------------------------------------------------------------- -CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, +CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, precision, groups) { @@ -298,7 +298,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // Create type environment // **TEMP** parse precision to get scalar type - GroupMergedTypeEnvironment typeEnvironment(*this, getScalarType()); + GroupMergedTypeEnvironment typeEnvironment(*this, precision); // If underlying synapse group has kernel weights if (getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL) { @@ -379,7 +379,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // If variable has a transpose if(getArchetype().getVarReferences().at(v.name).getTransposeSynapseGroup() != nullptr) { // Add field with transpose suffix, pointing to transpose var - addField(parseNumeric(v.type)->getPointerType(), v.name + "Transpose", + addField(v.type->getPointerType(), v.name + "Transpose", [&backend, v](const auto &g, size_t) { const auto varRef = g.getVarReferences().at(v.name); @@ -426,7 +426,7 @@ void CustomUpdateTransposeWUGroupMerged::generateCustomUpdate(const BackendBase& //---------------------------------------------------------------------------- const std::string CustomUpdateHostReductionGroupMerged::name = "CustomUpdateHostReduction"; //---------------------------------------------------------------------------- -CustomUpdateHostReductionGroupMerged::CustomUpdateHostReductionGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, +CustomUpdateHostReductionGroupMerged::CustomUpdateHostReductionGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector> &groups) : CustomUpdateHostReductionGroupMergedBase(index, precision, backend, groups) { @@ -451,7 +451,7 @@ CustomUpdateHostReductionGroupMerged::CustomUpdateHostReductionGroupMerged(size_ //---------------------------------------------------------------------------- const std::string CustomWUUpdateHostReductionGroupMerged::name = "CustomWUUpdateHostReduction"; //---------------------------------------------------------------------------- -CustomWUUpdateHostReductionGroupMerged::CustomWUUpdateHostReductionGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, +CustomWUUpdateHostReductionGroupMerged::CustomWUUpdateHostReductionGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector> &groups) : CustomUpdateHostReductionGroupMergedBase(index, precision, backend, groups) { diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 927bf722fa..747456e333 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -152,15 +152,15 @@ void genVarPushPullScope(CodeStream &definitionsFunc, CodeStream &runnerPushFunc //------------------------------------------------------------------------- void genVarGetterScope(CodeStream 
&definitionsFunc, CodeStream &runnerGetterFunc, VarLocation loc, const std::string &description, - const std::string &type, std::function handler) + const std::string &typeName, std::function handler) { // If this variable has a location that allows pushing and pulling and hence getting a host pointer if(canPushPullVar(loc)) { // Export getter - definitionsFunc << "EXPORT_FUNC " << type << " get" << description << "(unsigned int batch = 0); " << std::endl; + definitionsFunc << "EXPORT_FUNC " << typeName << " get" << description << "(unsigned int batch = 0); " << std::endl; // Define getter - runnerGetterFunc << type << " get" << description << "(" << "unsigned int batch" << ")"; + runnerGetterFunc << typeName << " get" << description << "(" << "unsigned int batch" << ")"; { CodeStream::Scope a(runnerGetterFunc); handler(); @@ -251,10 +251,10 @@ void genStatePushPull(CodeStream &definitionsFunc, CodeStream &runnerPushFunc, C } //------------------------------------------------------------------------- void genVariable(const BackendBase &backend, CodeStream &definitionsVar, CodeStream &definitionsFunc, - CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - CodeStream &push, CodeStream &pull, const std::string &type, const std::string &name, - VarLocation loc, bool autoInitialized, size_t count, MemAlloc &mem, - std::vector &statePushPullFunction) + CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, + CodeStream &push, CodeStream &pull, const Type::Base *type, const std::string &name, + VarLocation loc, bool autoInitialized, size_t count, MemAlloc &mem, + std::vector &statePushPullFunction) { // Generate push and pull functions genVarPushPullScope(definitionsFunc, push, pull, loc, backend.getPreferences().automaticCopy, name, statePushPullFunction, @@ -1087,7 +1087,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Write getter to get access to correct pointer const bool delayRequired = (n.second.isVarQueueRequired(var.name) && n.second.isDelayRequired()); genVarGetterScope(definitionsFunc, runnerGetterFunc, n.second.getVarLocation(var.name), - "Current" + var.name + n.first, var.type + "*", + "Current" + var.name + n.first, var.type->getPointerType()->getResolvedName(modelMerged.getTypeContext()), [&]() { runnerGetterFunc << "return " << var.name << n.first; @@ -1327,7 +1327,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Target indices backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - s.second.getSparseIndType()->getName(modelMerged.getTypeContext()), "ind" + s.second.getName(), varLoc, size, mem); + s.second.getSparseIndType()->getResolvedName(modelMerged.getTypeContext()), "ind" + s.second.getName(), varLoc, size, mem); // **TODO** remap is not always required if(backend.isPostsynapticRemapRequired() && !s.second.getWUModel()->getLearnPostCode().empty()) { diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index fb67a37766..884d741e14 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -19,7 +19,7 @@ using namespace GeNN::CodeGenerator; //---------------------------------------------------------------------------- const std::string NeuronSpikeQueueUpdateGroupMerged::name = "NeuronSpikeQueueUpdate"; 
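The hunks above and below repeat one mechanical migration: merged groups and runner helpers that previously received a precision as a string ("float" or "double") now receive a pointer to a singleton type object. For readers following the diff, here is a minimal, self-contained sketch of that singleton pattern; the bodies are illustrative assumptions (only the names NumericBase, Float, Double, getInstance and getName appear in the diff itself), so this should not be read as GeNN's actual type.h:

    #include <iostream>
    #include <string>

    // Minimal sketch of a type-object hierarchy (assumed bodies)
    class NumericBase
    {
    public:
        virtual ~NumericBase() = default;

        // Unresolved name of the type, e.g. "float"
        virtual std::string getName() const = 0;
    };

    class Float : public NumericBase
    {
    public:
        virtual std::string getName() const final { return "float"; }

        // One shared instance per type so types can be compared by pointer
        static const Float *getInstance(){ static Float instance; return &instance; }
    };

    class Double : public NumericBase
    {
    public:
        virtual std::string getName() const final { return "double"; }

        static const Double *getInstance(){ static Double instance; return &instance; }
    };

    int main()
    {
        // A parameter like 'const NumericBase *precision' replaces 'const std::string &precision'
        const NumericBase *precision = Float::getInstance();
        std::cout << precision->getName() << std::endl;     // prints "float"
        return 0;
    }

Because each type has exactly one instance, equality reduces to pointer comparison and the same object can be threaded through every constructor without copying or re-parsing strings.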
//---------------------------------------------------------------------------- -NeuronSpikeQueueUpdateGroupMerged::NeuronSpikeQueueUpdateGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, +NeuronSpikeQueueUpdateGroupMerged::NeuronSpikeQueueUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, precision, groups) { @@ -68,7 +68,7 @@ void NeuronSpikeQueueUpdateGroupMerged::genMergedGroupSpikeCountReset(CodeStream //---------------------------------------------------------------------------- const std::string NeuronPrevSpikeTimeUpdateGroupMerged::name = "NeuronPrevSpikeTimeUpdate"; //---------------------------------------------------------------------------- -NeuronPrevSpikeTimeUpdateGroupMerged::NeuronPrevSpikeTimeUpdateGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, +NeuronPrevSpikeTimeUpdateGroupMerged::NeuronPrevSpikeTimeUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, precision, groups) { @@ -84,14 +84,13 @@ NeuronPrevSpikeTimeUpdateGroupMerged::NeuronPrevSpikeTimeUpdateGroupMerged(size_ addPointerField("spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); } - const NumericBase *timeType = parseNumeric(timePrecision); if(getArchetype().isPrevSpikeTimeRequired()) { addPointerField("spk", backend.getDeviceVarPrefix() + "glbSpk"); - addPointerField(timeType, "prevST", backend.getDeviceVarPrefix() + "prevST"); + addPointerField(timePrecision, "prevST", backend.getDeviceVarPrefix() + "prevST"); } if(getArchetype().isPrevSpikeEventTimeRequired()) { addPointerField("spkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt"); - addPointerField(timeType, "prevSET", backend.getDeviceVarPrefix() + "prevSET"); + addPointerField(timePrecision, "prevSET", backend.getDeviceVarPrefix() + "prevSET"); } if(getArchetype().isDelayRequired()) { @@ -184,15 +183,12 @@ bool NeuronGroupMergedBase::isPSMVarInitDerivedParamHeterogeneous(size_t childIn [varName](const SynapseGroupInternal *inSyn){ return inSyn->getPSVarInitialisers().at(varName).getDerivedParams(); })); } //---------------------------------------------------------------------------- -NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, +NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, bool init, const std::vector> &groups) : GroupMerged(index, precision, groups) { using namespace Type; - // **HACK** parse precisions - const NumericBase *timeType = parseNumeric(timePrecision); - // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group orderNeuronGroupChildren(m_SortedMergedInSyns, &NeuronGroupInternal::getFusedPSMInSyn, init ? 
&SynapseGroupInternal::getPSInitHashDigest : &SynapseGroupInternal::getPSHashDigest); @@ -221,17 +217,17 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr } if(getArchetype().isSpikeTimeRequired()) { - addPointerField(timeType, "sT", backend.getDeviceVarPrefix() + "sT"); + addPointerField(timePrecision, "sT", backend.getDeviceVarPrefix() + "sT"); } if(getArchetype().isSpikeEventTimeRequired()) { - addPointerField(timeType, "seT", backend.getDeviceVarPrefix() + "seT"); + addPointerField(timePrecision, "seT", backend.getDeviceVarPrefix() + "seT"); } if(getArchetype().isPrevSpikeTimeRequired()) { - addPointerField(timeType, "prevST", backend.getDeviceVarPrefix() + "prevST"); + addPointerField(timePrecision, "prevST", backend.getDeviceVarPrefix() + "prevST"); } if(getArchetype().isPrevSpikeEventTimeRequired()) { - addPointerField(timeType, "prevSET", backend.getDeviceVarPrefix() + "prevSET"); + addPointerField(timePrecision, "prevSET", backend.getDeviceVarPrefix() + "prevSET"); } // If this backend initialises population RNGs on device and this group requires one for simulation @@ -248,7 +244,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr for(const auto &var : vars) { // If we're not initialising or if there is initialization code for this variable if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addPointerField(parseNumeric(var.type), var.name, backend.getDeviceVarPrefix() + var.name); + addPointerField(var.type, var.name, backend.getDeviceVarPrefix() + var.name); } // If we're initializing, add any var init EGPs to structure @@ -301,7 +297,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr for(const auto &var : sg->getPSModel()->getVars()) { // Add pointers to state variable if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addMergedInSynPointerField(parseNumeric(var.type), var.name + "InSyn", i, backend.getDeviceVarPrefix() + var.name); + addMergedInSynPointerField(var.type, var.name + "InSyn", i, backend.getDeviceVarPrefix() + var.name); } // If we're generating an initialization structure, also add any heterogeneous parameters, derived parameters or extra global parameters required for initializers @@ -356,7 +352,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const std::string &pr for(const auto &var : cs->getCurrentSourceModel()->getVars()) { // Add pointers to state variable if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(parseNumeric(var.type)->getPointerType(), var.name + "CS" + std::to_string(i), + addField(var.type->getPointerType(), var.name + "CS" + std::to_string(i), [&backend, i, var, this](const auto&, size_t groupIndex) { return backend.getDeviceVarPrefix() + var.name + m_SortedCurrentSources.at(groupIndex).at(i)->getName(); @@ -665,15 +661,12 @@ std::string SynapseGroupMergedBase::getKernelVarIndex(unsigned int batchSize, Va return (singleBatch ?
"" : "kernBatchOffset + ") + index; } //---------------------------------------------------------------------------- -SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, +SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, Role role, const std::string &archetypeCode, const std::vector> &groups) : GroupMerged(index, precision, groups), m_ArchetypeCode(archetypeCode) { using namespace Type; - // **HACK** parse precisions - const NumericBase *timeType = parseNumeric(timePrecision); - const bool updateRole = ((role == Role::PresynapticUpdate) || (role == Role::PostsynapticUpdate) || (role == Role::SynapseDynamics)); @@ -770,7 +763,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & for(const auto &v : preVars) { // If variable is referenced in code string, add source pointer if(code.find("$(" + v.name + "_pre)") != std::string::npos) { - addSrcPointerField(parseNumeric(v.type), v.name + "Pre", backend.getDeviceVarPrefix() + v.name); + addSrcPointerField(v.type, v.name + "Pre", backend.getDeviceVarPrefix() + v.name); } } @@ -779,7 +772,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & for(const auto &v : postVars) { // If variable is referenced in code string, add target pointer if(code.find("$(" + v.name + "_post)") != std::string::npos) { - addTrgPointerField(parseNumeric(v.type), v.name + "Post", backend.getDeviceVarPrefix() + v.name); + addTrgPointerField(v.type, v.name + "Post", backend.getDeviceVarPrefix() + v.name); } } @@ -807,22 +800,22 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & // Add spike times if required if(wum->isPreSpikeTimeRequired()) { - addSrcPointerField(timeType, "sTPre", backend.getDeviceVarPrefix() + "sT"); + addSrcPointerField(timePrecision, "sTPre", backend.getDeviceVarPrefix() + "sT"); } if(wum->isPostSpikeTimeRequired()) { - addTrgPointerField(timeType, "sTPost", backend.getDeviceVarPrefix() + "sT"); + addTrgPointerField(timePrecision, "sTPost", backend.getDeviceVarPrefix() + "sT"); } if(wum->isPreSpikeEventTimeRequired()) { - addSrcPointerField(timeType, "seTPre", backend.getDeviceVarPrefix() + "seT"); + addSrcPointerField(timePrecision, "seTPre", backend.getDeviceVarPrefix() + "seT"); } if(wum->isPrevPreSpikeTimeRequired()) { - addSrcPointerField(timeType, "prevSTPre", backend.getDeviceVarPrefix() + "prevST"); + addSrcPointerField(timePrecision, "prevSTPre", backend.getDeviceVarPrefix() + "prevST"); } if(wum->isPrevPostSpikeTimeRequired()) { - addTrgPointerField(timeType, "prevSTPost", backend.getDeviceVarPrefix() + "prevST"); + addTrgPointerField(timePrecision, "prevSTPost", backend.getDeviceVarPrefix() + "prevST"); } if(wum->isPrevPreSpikeEventTimeRequired()) { - addSrcPointerField(timeType, "prevSETPre", backend.getDeviceVarPrefix() + "prevSET"); + addSrcPointerField(timePrecision, "prevSETPre", backend.getDeviceVarPrefix() + "prevSET"); } // Add heterogeneous weight update model parameters addHeterogeneousParams( @@ -839,14 +832,14 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & // Add presynaptic variables to struct for(const auto &v : wum->getPreVars()) { const std::string prefix = backend.getDeviceVarPrefix() + v.name; - addField(parseNumeric(v.type)->getPointerType(), v.name, + 
addField(v.type->getPointerType(), v.name, [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPreVarSuffix(); }); } // Add postsynaptic variables to struct for(const auto &v : wum->getPostVars()) { const std::string prefix = backend.getDeviceVarPrefix() + v.name; - addField(parseNumeric(v.type)->getPointerType(), v.name, + addField(v.type->getPointerType(), v.name, [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPostVarSuffix(); }); } @@ -971,7 +964,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const std::string & // If we're performing an update with individual weights; or this variable should be initialised if((updateRole && individualWeights) || (kernelWeights && updateRole) || varInitRequired) { - addPointerField(parseNumeric(var.type), var.name, backend.getDeviceVarPrefix() + var.name); + addPointerField(var.type, var.name, backend.getDeviceVarPrefix() + var.name); } // If we're performing a procedural update or this variable should be initialised, add any var init EGPs to structure diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 17fda21165..898ba27fbc 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -55,7 +55,7 @@ template void genInitNeuronVarCode(CodeStream &os, const BackendBase &backend, const Substitutions &popSubs, const Models::Base::VarVec &vars, const std::unordered_map &varInitialisers, const std::string &fieldSuffix, const std::string &countMember, - size_t numDelaySlots, const size_t groupIndex, const std::string &ftype, unsigned int batchSize, + size_t numDelaySlots, const size_t groupIndex, const Type::NumericBase *scalarType, unsigned int batchSize, Q isVarQueueRequired, P isParamHeterogeneousFn, D isDerivedParamHeterogeneousFn) { const std::string count = "group->" + countMember; @@ -82,7 +82,7 @@ void genInitNeuronVarCode(CodeStream &os, const BackendBase &backend, const Subs if (getVarAccessDuplication(var.access) == VarAccessDuplication::SHARED_NEURON) { backend.genPopVariableInit( os, varSubs, - [&var, &varInit, &fieldSuffix, &ftype, batchSize, groupIndex, numDelaySlots, isVarQueueRequired] + [&var, &varInit, &fieldSuffix, scalarType, batchSize, groupIndex, numDelaySlots, isVarQueueRequired] (CodeStream &os, Substitutions &varInitSubs) { // Generate initial value into temporary variable @@ -90,7 +90,7 @@ void genInitNeuronVarCode(CodeStream &os, const BackendBase &backend, const Subs varInitSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); varInitSubs.applyCheckUnreplaced(code, "initVar : " + var.name + "merged" + std::to_string(groupIndex)); - code = ensureFtype(code, ftype); + //code = ensureFtype(code, scalarType); os << code << std::endl; // Fill value across all delay slots and batches @@ -127,9 +127,9 @@ template void genInitNeuronVarCode(CodeStream &os, const BackendBase &backend, const Substitutions &popSubs, const Models::Base::VarVec &vars, const std::unordered_map &varInitialisers, const std::string &fieldSuffix, const std::string &countMember, const size_t groupIndex, - const std::string &ftype, unsigned int batchSize, P isParamHeterogeneousFn, D isDerivedParamHeterogeneousFn) + const Type::NumericBase *scalarType, unsigned int batchSize, P isParamHeterogeneousFn, D isDerivedParamHeterogeneousFn) { - genInitNeuronVarCode(os, backend, popSubs, vars, varInitialisers, fieldSuffix, countMember, 0, groupIndex, ftype, batchSize, +
genInitNeuronVarCode(os, backend, popSubs, vars, varInitialisers, fieldSuffix, countMember, 0, groupIndex, scalarType, batchSize, [](const std::string&){ return false; }, isParamHeterogeneousFn, isDerivedParamHeterogeneousFn); @@ -139,7 +139,7 @@ void genInitNeuronVarCode(CodeStream &os, const BackendBase &backend, const Subs template void genInitWUVarCode(CodeStream &os, const Substitutions &popSubs, const Models::Base::VarVec &vars, const std::unordered_map &varInitialisers, - const std::string &stride, const size_t groupIndex, const std::string &ftype, unsigned int batchSize, + const std::string &stride, const size_t groupIndex, const Type::NumericBase *scalarType, unsigned int batchSize, P isParamHeterogeneousFn, D isDerivedParamHeterogeneousFn, G genSynapseVariableRowInitFn) { for (const auto &var : vars) { @@ -151,7 +151,7 @@ void genInitWUVarCode(CodeStream &os, const Substitutions &popSubs, // Generate target-specific code to initialise variable genSynapseVariableRowInitFn(os, popSubs, - [&var, &varInit, &ftype, &stride, batchSize, groupIndex, isParamHeterogeneousFn, isDerivedParamHeterogeneousFn] + [&var, &varInit, &stride, batchSize, groupIndex, isParamHeterogeneousFn, isDerivedParamHeterogeneousFn, scalarType] (CodeStream &os, Substitutions &varSubs) { varSubs.addParamValueSubstitution(varInit.getSnippet()->getParamNames(), varInit.getParams(), @@ -168,7 +168,7 @@ void genInitWUVarCode(CodeStream &os, const Substitutions &popSubs, varSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); varSubs.applyCheckUnreplaced(code, "initVar : merged" + var.name + std::to_string(groupIndex)); - code = ensureFtype(code, ftype); + //code = ensureFtype(code, scalarType); os << code << std::endl; // Fill value across all batches @@ -185,7 +185,7 @@ void genInitWUVarCode(CodeStream &os, const Substitutions &popSubs, //---------------------------------------------------------------------------- const std::string NeuronInitGroupMerged::name = "NeuronInit"; //---------------------------------------------------------------------------- -NeuronInitGroupMerged::NeuronInitGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, +NeuronInitGroupMerged::NeuronInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups) : NeuronGroupMergedBase(index, precision, timePrecision, backend, true, groups) { @@ -443,7 +443,7 @@ void NeuronInitGroupMerged::generateWUVar(const BackendBase &backend, for(const auto &var : vars) { // Add pointers to state variable if(!varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(parseNumeric(var.type)->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), + addField(var.type->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto&, size_t groupIndex) { const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); @@ -658,7 +658,7 @@ void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase&, std::string code = varInit.getSnippet()->getCode(); //popSubs.applyCheckUnreplaced(code, "initVar : merged" + vars[k].name + std::to_string(sg.getIndex())); popSubs.apply(code); - code = ensureFtype(code, modelMerged.getModel().getPrecision()); + //code = ensureFtype(code, modelMerged.getModel().getPrecision()); os << code << std::endl; // 
Fill value across all batches @@ -668,7 +668,7 @@ void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase&, } } //---------------------------------------------------------------------------- -void SynapseConnectivityInitGroupMerged::genInitConnectivity(CodeStream &os, Substitutions &popSubs, const std::string &ftype, bool rowNotColumns) const +void SynapseConnectivityInitGroupMerged::genInitConnectivity(CodeStream &os, Substitutions &popSubs, const Type::NumericBase *scalarType, bool rowNotColumns) const { const auto &connectInit = getArchetype().getConnectivityInitialiser(); const auto *snippet = connectInit.getSnippet(); @@ -690,7 +690,7 @@ void SynapseConnectivityInitGroupMerged::genInitConnectivity(CodeStream &os, Sub // Apply substitutions to value std::string value = a.value; popSubs.applyCheckUnreplaced(value, "initSparseConnectivity state var : merged" + std::to_string(getIndex())); - value = ensureFtype(value, ftype); + //value = ensureFtype(value, ftype); os << a.type << " " << a.name << " = " << value << ";" << std::endl; } @@ -702,7 +702,7 @@ void SynapseConnectivityInitGroupMerged::genInitConnectivity(CodeStream &os, Sub std::string code = rowNotColumns ? snippet->getRowBuildCode() : snippet->getColBuildCode(); popSubs.addVarNameSubstitution(stateVars); popSubs.applyCheckUnreplaced(code, "initSparseConnectivity : merged" + std::to_string(getIndex())); - code = ensureFtype(code, ftype); + //code = ensureFtype(code, ftype); // Write out code os << code << std::endl; @@ -715,7 +715,7 @@ void SynapseConnectivityInitGroupMerged::genInitConnectivity(CodeStream &os, Sub //---------------------------------------------------------------------------- const std::string SynapseConnectivityHostInitGroupMerged::name = "SynapseConnectivityHostInit"; //------------------------------------------------------------------------ -SynapseConnectivityHostInitGroupMerged::SynapseConnectivityHostInitGroupMerged(size_t index, const std::string &precision, const std::string&, const BackendBase &backend, +SynapseConnectivityHostInitGroupMerged::SynapseConnectivityHostInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, precision, groups) { @@ -826,7 +826,7 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac } std::string code = connectInit.getSnippet()->getHostInitCode(); subs.applyCheckUnreplaced(code, "hostInitSparseConnectivity : merged" + std::to_string(getIndex())); - code = ensureFtype(code, modelMerged.getModel().getPrecision()); + //code = ensureFtype(code, modelMerged.getModel().getPrecision()); // Write out code os << code << std::endl; @@ -857,7 +857,7 @@ bool SynapseConnectivityHostInitGroupMerged::isSparseConnectivityInitParamRefere //---------------------------------------------------------------------------- const std::string CustomUpdateInitGroupMerged::name = "CustomUpdateInit"; //---------------------------------------------------------------------------- -CustomUpdateInitGroupMerged::CustomUpdateInitGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, +CustomUpdateInitGroupMerged::CustomUpdateInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, precision, backend, groups) { @@ -892,7 +892,7 @@ void 
CustomUpdateInitGroupMerged::generateInit(const BackendBase &backend, CodeS //---------------------------------------------------------------------------- const std::string CustomWUUpdateInitGroupMerged::name = "CustomWUUpdateInit"; //---------------------------------------------------------------------------- -CustomWUUpdateInitGroupMerged::CustomWUUpdateInitGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, +CustomWUUpdateInitGroupMerged::CustomWUUpdateInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, precision, backend, groups) { @@ -1002,7 +1002,7 @@ void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, Cod //---------------------------------------------------------------------------- const std::string CustomWUUpdateSparseInitGroupMerged::name = "CustomWUUpdateSparseInit"; //---------------------------------------------------------------------------- -CustomWUUpdateSparseInitGroupMerged::CustomWUUpdateSparseInitGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, +CustomWUUpdateSparseInitGroupMerged::CustomWUUpdateSparseInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, precision, backend, groups) { @@ -1074,7 +1074,7 @@ void CustomWUUpdateSparseInitGroupMerged::generateInit(const BackendBase &backen //---------------------------------------------------------------------------- const std::string CustomConnectivityUpdatePreInitGroupMerged::name = "CustomConnectivityUpdatePreInit"; //---------------------------------------------------------------------------- -CustomConnectivityUpdatePreInitGroupMerged::CustomConnectivityUpdatePreInitGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, +CustomConnectivityUpdatePreInitGroupMerged::CustomConnectivityUpdatePreInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, precision, backend, groups) { @@ -1122,7 +1122,7 @@ void CustomConnectivityUpdatePreInitGroupMerged::generateInit(const BackendBase //---------------------------------------------------------------------------- const std::string CustomConnectivityUpdatePostInitGroupMerged::name = "CustomConnectivityUpdatePostInit"; //---------------------------------------------------------------------------- -CustomConnectivityUpdatePostInitGroupMerged::CustomConnectivityUpdatePostInitGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, +CustomConnectivityUpdatePostInitGroupMerged::CustomConnectivityUpdatePostInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, precision, backend, groups) { @@ -1163,7 +1163,7 @@ void CustomConnectivityUpdatePostInitGroupMerged::generateInit(const BackendBase //---------------------------------------------------------------------------- const std::string CustomConnectivityUpdateSparseInitGroupMerged::name = "CustomConnectivityUpdateSparseInit"; 
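Every init-group constructor in this file now takes the scalar and time precisions as const Type::NumericBase * and forwards them to its base class. The toy sketch below, under invented names (ToyGroupMerged and this addPointerField signature are illustrative, not GeNN code), shows why that shape is convenient: a field declaration can be built straight from the type object, with no string parsing:

    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    // Assumed stand-in for GeNN::Type::NumericBase (see the earlier sketch)
    struct NumericBase
    {
        virtual ~NumericBase() = default;
        virtual std::string getName() const = 0;
    };

    struct Float : NumericBase
    {
        std::string getName() const final { return "float"; }
        static const Float *getInstance(){ static Float instance; return &instance; }
    };

    // Invented merged-group skeleton: the precision arrives as a type object
    class ToyGroupMerged
    {
    public:
        ToyGroupMerged(std::size_t index, const NumericBase *precision) : m_Index(index)
        {
            // Equivalent in spirit to addPointerField(timePrecision, "sT", ...)
            addPointerField(precision, "sT");
        }

        void addPointerField(const NumericBase *type, const std::string &name)
        {
            // Build a field declaration directly from the type object
            m_Fields.push_back(type->getName() + " *" + name + ";");
        }

        void print() const
        {
            std::cout << "// merged group " << m_Index << std::endl;
            for(const auto &f : m_Fields) {
                std::cout << f << std::endl;    // prints "float *sT;"
            }
        }

    private:
        std::size_t m_Index;
        std::vector<std::string> m_Fields;
    };

    int main()
    {
        ToyGroupMerged group(0, Float::getInstance());
        group.print();
        return 0;
    }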
//---------------------------------------------------------------------------- -CustomConnectivityUpdateSparseInitGroupMerged::CustomConnectivityUpdateSparseInitGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, +CustomConnectivityUpdateSparseInitGroupMerged::CustomConnectivityUpdateSparseInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, precision, backend, groups) { diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 81a3892889..40d2f2febc 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -11,7 +11,7 @@ using namespace GeNN::CodeGenerator; //---------------------------------------------------------------------------- const std::string NeuronUpdateGroupMerged::name = "NeuronUpdate"; //---------------------------------------------------------------------------- -NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string &precision, const std::string &timePrecision, const BackendBase &backend, +NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, const std::vector> &groups) : NeuronGroupMergedBase(index, precision, timePrecision, backend, false, groups) { @@ -80,7 +80,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const std::string for(const auto &var : sgPreVars) { // If variable is referenced in event threshold code if(s.eventThresholdCode.find("$(" + var.name + ")") != std::string::npos) { - addField(parseNumeric(var.type)->getPointerType(), var.name + "EventThresh" + std::to_string(i), + addField(var.type->getPointerType(), var.name + "EventThresh" + std::to_string(i), [&backend, eventThresholdSGs, var, i](const auto&, size_t groupIndex) { return backend.getDeviceVarPrefix() + var.name + eventThresholdSGs.at(groupIndex).at(i)->getName(); @@ -247,7 +247,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C // Apply substitutions to value std::string value = a.value; neuronSubs.applyCheckUnreplaced(value, "neuron additional input var : merged" + std::to_string(getIndex())); - value = ensureFtype(value, modelMerged.getModel().getPrecision()); + //value = ensureFtype(value, modelMerged.getModel().getPrecision()); os << a.type << " " << a.name << " = " << value << ";" << std::endl; } @@ -304,12 +304,12 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C // Apply substitutions to current converter code std::string psCode = psm->getApplyInputCode(); inSynSubs.applyCheckUnreplaced(psCode, "postSyntoCurrent : merged " + std::to_string(i)); - psCode = ensureFtype(psCode, model.getPrecision()); + //psCode = ensureFtype(psCode, model.getPrecision()); // Apply substitutions to decay code std::string pdCode = psm->getDecayCode(); inSynSubs.applyCheckUnreplaced(pdCode, "decayCode : merged " + std::to_string(i)); - pdCode = ensureFtype(pdCode, model.getPrecision()); + //pdCode = ensureFtype(pdCode, model.getPrecision()); if (!psm->getSupportCode().empty() && backend.supportsNamespace()) { os << "using namespace " << modelMerged.getPostsynapticDynamicsSupportCodeNamespace(psm->getSupportCode()) << ";" << std::endl; @@ -383,7 +383,7 @@ 
void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C std::string iCode = csm->getInjectionCode(); currSourceSubs.applyCheckUnreplaced(iCode, "injectionCode : merged" + std::to_string(i)); - iCode = ensureFtype(iCode, model.getPrecision()); + //iCode = ensureFtype(iCode, model.getPrecision()); os << iCode << std::endl; // Write read/write variables back to global memory @@ -405,7 +405,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C os << "// test whether spike condition was fulfilled previously" << std::endl; neuronSubs.applyCheckUnreplaced(thCode, "thresholdConditionCode : merged" + std::to_string(getIndex())); - thCode= ensureFtype(thCode, model.getPrecision()); + //thCode= ensureFtype(thCode, model.getPrecision()); if (!nm->getSupportCode().empty() && !backend.supportsNamespace()) { thCode = disambiguateNamespaceFunction(nm->getSupportCode(), thCode, modelMerged.getNeuronUpdateSupportCodeNamespace(nm->getSupportCode())); @@ -425,7 +425,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C os << "// calculate membrane potential" << std::endl; std::string sCode = nm->getSimCode(); neuronSubs.applyCheckUnreplaced(sCode, "simCode : merged" + std::to_string(getIndex())); - sCode = ensureFtype(sCode, model.getPrecision()); + //sCode = ensureFtype(sCode, model.getPrecision()); if (!nm->getSupportCode().empty() && !backend.supportsNamespace()) { sCode = disambiguateNamespaceFunction(nm->getSupportCode(), sCode, modelMerged.getNeuronUpdateSupportCodeNamespace(nm->getSupportCode())); @@ -434,7 +434,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C os << sCode << std::endl; // Generate var update for outgoing synaptic populations with presynaptic update code - generateWUVarUpdate(os, popSubs, "WUPre", modelMerged.getModel().getPrecision(), "_pre", true, batchSize, + generateWUVarUpdate(os, popSubs, "WUPre", "_pre", true, batchSize, getSortedArchetypeOutSynWithPreCode(), &SynapseGroupInternal::getDelaySteps, &WeightUpdateModels::Base::getPreVars, &WeightUpdateModels::Base::getPreDynamicsCode, &NeuronUpdateGroupMerged::isOutSynWUMParamHeterogeneous, @@ -442,7 +442,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C // Generate var update for incoming synaptic populations with postsynaptic code - generateWUVarUpdate(os, popSubs, "WUPost", modelMerged.getModel().getPrecision(), "_post", true, batchSize, + generateWUVarUpdate(os, popSubs, "WUPost", "_post", true, batchSize, getSortedArchetypeInSynWithPostCode(), &SynapseGroupInternal::getBackPropDelaySteps, &WeightUpdateModels::Base::getPostVars, &WeightUpdateModels::Base::getPostDynamicsCode, &NeuronUpdateGroupMerged::isInSynWUMParamHeterogeneous, @@ -477,7 +477,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C std::string eCode = spkEventCond.eventThresholdCode; spkEventCondSubs.applyCheckUnreplaced(eCode, "neuronSpkEvntCondition : merged" + std::to_string(getIndex())); - eCode = ensureFtype(eCode, model.getPrecision()); + //eCode = ensureFtype(eCode, model.getPrecision()); // Open scope for spike-like event test os << CodeStream::OB(31); @@ -537,7 +537,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C if (!nm->getResetCode().empty()) { std::string rCode = nm->getResetCode(); neuronSubs.applyCheckUnreplaced(rCode, "resetCode : merged" + std::to_string(getIndex())); - rCode = ensureFtype(rCode, model.getPrecision()); 
+ //rCode = ensureFtype(rCode, model.getPrecision()); os << "// spike reset code" << std::endl; os << rCode << std::endl; @@ -634,7 +634,7 @@ void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase&, CodeStream { // Generate var update for outgoing synaptic populations with presynaptic update code const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - generateWUVarUpdate(os, popSubs, "WUPre", modelMerged.getModel().getPrecision(), "_pre", false, batchSize, + generateWUVarUpdate(os, popSubs, "WUPre", "_pre", false, batchSize, getSortedArchetypeOutSynWithPreCode(), &SynapseGroupInternal::getDelaySteps, &WeightUpdateModels::Base::getPreVars, &WeightUpdateModels::Base::getPreSpikeCode, &NeuronUpdateGroupMerged::isOutSynWUMParamHeterogeneous, @@ -642,7 +642,7 @@ void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase&, CodeStream // Generate var update for incoming synaptic populations with postsynaptic code - generateWUVarUpdate(os, popSubs, "WUPost", modelMerged.getModel().getPrecision(), "_post", false, batchSize, + generateWUVarUpdate(os, popSubs, "WUPost", "_post", false, batchSize, getSortedArchetypeInSynWithPostCode(), &SynapseGroupInternal::getBackPropDelaySteps, &WeightUpdateModels::Base::getPostVars, &WeightUpdateModels::Base::getPostSpikeCode, &NeuronUpdateGroupMerged::isInSynWUMParamHeterogeneous, @@ -716,7 +716,7 @@ void NeuronUpdateGroupMerged::generateWUVar(const BackendBase &backend, const s for(size_t v = 0; v < vars.size(); v++) { // Add pointers to state variable const auto var = vars[v]; - addField(Type::parseNumeric(var.type)->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), + addField(var.type->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto &, size_t groupIndex) { const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); @@ -767,7 +767,7 @@ void NeuronUpdateGroupMerged::addNeuronModelSubstitutions(Substitutions &substit } //-------------------------------------------------------------------------- void NeuronUpdateGroupMerged::generateWUVarUpdate(CodeStream &os, const Substitutions &popSubs, - const std::string &fieldPrefixStem, const std::string &precision, const std::string &sourceSuffix, + const std::string &fieldPrefixStem, const std::string &sourceSuffix, bool useLocalNeuronVars, unsigned int batchSize, const std::vector &archetypeSyn, unsigned int(SynapseGroupInternal::*getDelaySteps)(void) const, @@ -821,7 +821,7 @@ void NeuronUpdateGroupMerged::generateWUVarUpdate(CodeStream &os, const Substitu // Perform standard substitutions subs.applyCheckUnreplaced(code, "spikeCode : merged" + std::to_string(i)); - code = ensureFtype(code, precision); + //code = ensureFtype(code, precision); os << code; // Write back presynaptic variables into global memory diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 98c50035e1..56e990f14f 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -167,7 +167,7 @@ void applySynapseSubstitutions(CodeStream &os, std::string code, const std::stri synapseSubs.apply(code); //synapseSubs.applyCheckUnreplaced(code, errorContext + " : " + sg.getName()); - code = ensureFtype(code, model.getPrecision()); + //code = ensureFtype(code, model.getPrecision()); os << code; } } // Anonymous namespace @@ -209,7 +209,7 @@ void 
PresynapticUpdateGroupMerged::generateSpikeEventThreshold(const BackendBase // Get event threshold condition code std::string code = wum->getEventThresholdConditionCode(); synapseSubs.applyCheckUnreplaced(code, "eventThresholdConditionCode"); - code = ensureFtype(code, modelMerged.getModel().getPrecision()); + //code = ensureFtype(code, modelMerged.getModel().getPrecision()); if (!backend.supportsNamespace() && !wum->getSimSupportCode().empty()) { code = disambiguateNamespaceFunction(wum->getSimSupportCode(), code, modelMerged.getPresynapticUpdateSupportCodeNamespace(wum->getSimSupportCode())); @@ -250,7 +250,7 @@ void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendB // Apply substitutions to value std::string value = a.value; popSubs.applyCheckUnreplaced(value, "proceduralSparseConnectivity row build state var : merged" + std::to_string(getIndex())); - value = ensureFtype(value, modelMerged.getModel().getPrecision()); + //value = ensureFtype(value, modelMerged.getModel().getPrecision()); os << a.type << " " << a.name << " = " << value << ";" << std::endl; } @@ -263,7 +263,7 @@ void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendB std::string pCode = connectInit.getSnippet()->getRowBuildCode(); popSubs.applyCheckUnreplaced(pCode, "proceduralSparseConnectivity : merged " + std::to_string(getIndex())); - pCode = ensureFtype(pCode, modelMerged.getModel().getPrecision()); + //pCode = ensureFtype(pCode, modelMerged.getModel().getPrecision()); // Write out code os << pCode << std::endl; @@ -277,7 +277,7 @@ void PresynapticUpdateGroupMerged::generateToeplitzConnectivity(const BackendBas // Apply substitutions to diagonal building code std::string pCode = connectInit.getSnippet()->getDiagonalBuildCode(); popSubs.applyCheckUnreplaced(pCode, "toeplitzSparseConnectivity : merged " + std::to_string(getIndex())); - pCode = ensureFtype(pCode, modelMerged.getModel().getPrecision()); + //pCode = ensureFtype(pCode, modelMerged.getModel().getPrecision()); // Write out code os << pCode << std::endl; @@ -321,7 +321,7 @@ void SynapseDynamicsGroupMerged::generateSynapseUpdate(const BackendBase &backen //---------------------------------------------------------------------------- const std::string SynapseDendriticDelayUpdateGroupMerged::name = "SynapseDendriticDelayUpdate"; //---------------------------------------------------------------------------- -SynapseDendriticDelayUpdateGroupMerged::SynapseDendriticDelayUpdateGroupMerged(size_t index, const std::string &precision, const std::string &, const BackendBase &backend, +SynapseDendriticDelayUpdateGroupMerged::SynapseDendriticDelayUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, precision, groups) { diff --git a/src/genn/genn/modelSpec.cc b/src/genn/genn/modelSpec.cc index 295f653cbd..5aa032abc3 100644 --- a/src/genn/genn/modelSpec.cc +++ b/src/genn/genn/modelSpec.cc @@ -36,33 +36,17 @@ namespace GeNN { ModelSpec::ModelSpec() -: m_TimePrecision(TimePrecision::DEFAULT), m_DT(0.5), m_TimingEnabled(false), m_Seed(0), +: m_Precision(Type::Float::getInstance()), m_TimePrecision(nullptr), m_DT(0.5), m_TimingEnabled(false), m_Seed(0), m_DefaultVarLocation(VarLocation::HOST_DEVICE), m_DefaultExtraGlobalParamLocation(VarLocation::HOST_DEVICE), m_DefaultSparseConnectivityLocation(VarLocation::HOST_DEVICE), m_DefaultNarrowSparseIndEnabled(false), m_ShouldFusePostsynapticModels(false), 
m_ShouldFusePrePostWeightUpdateModels(false), m_BatchSize(1) { - setPrecision(ScalarPrecision::FLOAT); } // --------------------------------------------------------------------------- ModelSpec::~ModelSpec() { } // --------------------------------------------------------------------------- -std::string ModelSpec::getTimePrecision() const -{ - // If time precision is set to match model precision - if(m_TimePrecision == TimePrecision::DEFAULT) { - return getPrecision(); - } - // Otherwise return appropriate type - else if(m_TimePrecision == TimePrecision::FLOAT) { - return "float"; - } - else { - return "double"; - } -} -// --------------------------------------------------------------------------- unsigned int ModelSpec::getNumNeurons() const { // Return sum of local neuron group sizes @@ -202,21 +186,6 @@ CustomUpdateWU *ModelSpec::addCustomUpdate(const std::string &name, const std::s } } // --------------------------------------------------------------------------- -void ModelSpec::setPrecision(ScalarPrecision scalarPrecision) -{ - switch (scalarPrecision) { - case ScalarPrecision::FLOAT: - m_Precision = "float"; - break; - case ScalarPrecision::DOUBLE: - m_Precision = "double"; // not supported by compute capability < 1.3 - break; - case ScalarPrecision::LONG_DOUBLE: - m_Precision = "long double"; // not supported by CUDA at the moment. - break; - } -} -// --------------------------------------------------------------------------- void ModelSpec::finalize() { // NEURON GROUPS @@ -322,19 +291,6 @@ void ModelSpec::finalize() } } // --------------------------------------------------------------------------- -std::string ModelSpec::scalarExpr(double val) const -{ - if (m_Precision == "float") { - return Utils::writePreciseString(val) + "f"; - } - else if (m_Precision == "double") { - return Utils::writePreciseString(val); - } - else { - throw std::runtime_error("Unrecognised floating-point type."); - } -} -// --------------------------------------------------------------------------- bool ModelSpec::zeroCopyInUse() const { // If any neuron groups use zero copy return true @@ -393,8 +349,8 @@ boost::uuids::detail::sha1::digest_type ModelSpec::getHashDigest() const boost::uuids::detail::sha1 hash; Utils::updateHash(getName(), hash); - Utils::updateHash(getPrecision(), hash); - Utils::updateHash(getTimePrecision(), hash); + Utils::updateHash(getPrecision()->getName(), hash); + Utils::updateHash(getTimePrecision()->getName(), hash); Utils::updateHash(getDT(), hash); Utils::updateHash(isTimingEnabled(), hash); Utils::updateHash(getBatchSize(), hash); diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 6ad949f3c6..cba22f803f 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -6,12 +6,41 @@ #include "currentSourceInternal.h" #include "neuronGroupInternal.h" #include "synapseGroupInternal.h" +#include "type.h" //---------------------------------------------------------------------------- -// GeNN::Models::Base +// GeNN::Models::Base::Var //---------------------------------------------------------------------------- namespace GeNN::Models { +Base::Var::Var(const std::string &n, const std::string &t, VarAccess a) : name(n), type(Type::parseNumeric(t)), access(a) +{} +//---------------------------------------------------------------------------- +Base::Var::Var(const std::string &n, const std::string &t) : Var(n, t, VarAccess::READ_WRITE) +{} +//---------------------------------------------------------------------------- +bool Base::Var::operator == (const Var 
&other) const +{ + return (std::make_tuple(name, type->getName(), access) == std::make_tuple(other.name, other.type->getName(), other.access)); +} + +//---------------------------------------------------------------------------- +// GeNN::Models::Base::VarRef +//---------------------------------------------------------------------------- +Base::VarRef::VarRef(const std::string &n, const std::string &t, VarAccessMode a) : name(n), type(Type::parseNumeric(t)), access(a) +{} +//---------------------------------------------------------------------------- +Base::VarRef::VarRef(const std::string &n, const std::string &t) : VarRef(n, t, VarAccessMode::READ_WRITE) +{} +//---------------------------------------------------------------------------- +bool Base::VarRef::operator == (const VarRef &other) const +{ + return (std::make_tuple(name, type->getName(), access) == std::make_tuple(other.name, other.type->getName(), other.access)); +} + +//---------------------------------------------------------------------------- +// GeNN::Models::Base +//---------------------------------------------------------------------------- void Base::updateHash(boost::uuids::detail::sha1 &hash) const { // Superclass @@ -210,14 +239,14 @@ SynapseGroup *WUVarReference::getTransposeSynapseGroup() const void updateHash(const Base::Var &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.name, hash); - Utils::updateHash(v.type, hash); + Utils::updateHash(v.type->getName(), hash); Utils::updateHash(v.access, hash); } //---------------------------------------------------------------------------- void updateHash(const Base::VarRef &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.name, hash); - Utils::updateHash(v.type, hash); + Utils::updateHash(v.type->getName(), hash); Utils::updateHash(v.access, hash); } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 8b348edfda..68325e193b 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -141,7 +141,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto indexNumericType = dynamic_cast(indexType); if (!indexNumericType || !indexNumericType->isIntegral(m_Context)) { m_ErrorHandler.error(arraySubscript.getPointerName(), - "Invalid subscript index type '" + indexType->getName(m_Context) + "'"); + "Invalid subscript index type '" + indexType->getName() + "'"); throw TypeCheckError(); } @@ -178,8 +178,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto rightPointerType = dynamic_cast(rightType); if (leftPointerType && rightPointerType && opType == Token::Type::MINUS) { // Check pointers are compatible - if (leftPointerType->getName(m_Context) != rightPointerType->getName(m_Context)) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName(m_Context) + "' and '" + rightType->getName(m_Context)); + if (leftPointerType->getResolvedName(m_Context) != rightPointerType->getResolvedName(m_Context)) { + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } @@ -191,7 +191,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { // Check that numeric operand is integer if (!rightNumericType->isIntegral(m_Context)) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + 
leftType->getName(m_Context) + "' and '" + rightType->getName(m_Context)); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } @@ -203,7 +203,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { // Check that numeric operand is integer if (!leftNumericType->isIntegral(m_Context)) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName(m_Context) + "' and '" + rightType->getName(m_Context)); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } @@ -219,7 +219,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { // Check that operands are integers if (!leftNumericType->isIntegral(m_Context) || !rightNumericType->isIntegral(m_Context)) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName(m_Context) + "' and '" + rightType->getName(m_Context)); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } @@ -239,7 +239,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } } else { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName(m_Context) + "' and '" + rightType->getName(m_Context)); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } } @@ -288,7 +288,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If const is being removed if (rightType->hasQualifier(Type::Qualifier::CONSTANT) && !cast.getType()->hasQualifier(Type::Qualifier::CONSTANT)) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName(m_Context) + "' and '" + rightType->getName(m_Context)); + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } @@ -298,14 +298,14 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto leftNumericType = dynamic_cast(cast.getType()); auto leftPointerType = dynamic_cast(cast.getType()); if (rightPointerType && leftPointerType) { - if (rightPointerType->getName(m_Context) != leftPointerType->getName(m_Context)) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName(m_Context) + "' and '" + rightType->getName(m_Context)); + if (rightPointerType->getResolvedName(m_Context) != leftPointerType->getResolvedName(m_Context)) { + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } } // Otherwise, if either operand isn't numeric else if(!leftNumericType || !rightNumericType) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName(m_Context) + "' and '" + rightType->getName(m_Context)); + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName()); throw TypeCheckError(); } @@ -327,7 +327,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } else { m_ErrorHandler.error(conditional.getQuestion(), - "Invalid operand types '" +
trueType->getName(m_Context) + "' and '" + falseType->getName(m_Context) + "' to conditional"); + "Invalid operand types '" + trueType->getName() + "' and '" + falseType->getName() + "' to conditional"); throw TypeCheckError(); } } @@ -395,7 +395,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto rightPointerType = dynamic_cast(rightType); if (!rightPointerType) { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getName(m_Context) + "'"); + "Invalid operand type '" + rightType->getName() + "'"); throw TypeCheckError(); } @@ -420,7 +420,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } else { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getName(m_Context) + "'"); + "Invalid operand type '" + rightType->getName() + "'"); throw TypeCheckError(); } } @@ -435,7 +435,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } else { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getName(m_Context) + "'"); + "Invalid operand type '" + rightType->getName() + "'"); throw TypeCheckError(); } } @@ -525,7 +525,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto valNumericType = dynamic_cast(valType); if (!valNumericType || !valNumericType->isIntegral(m_Context)) { m_ErrorHandler.error(labelled.getKeyword(), - "Invalid case value '" + valType->getName(m_Context) + "'"); + "Invalid case value '" + valType->getName() + "'"); throw TypeCheckError(); } } @@ -539,7 +539,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto condNumericType = dynamic_cast(condType); if (!condNumericType || !condNumericType->isIntegral(m_Context)) { m_ErrorHandler.error(switchStatement.getSwitch(), - "Invalid condition '" + condType->getName(m_Context) + "'"); + "Invalid condition '" + condType->getName() + "'"); throw TypeCheckError(); } @@ -624,19 +624,19 @@ const Type::Base *EnvironmentBase::assign(const Token &name, Token::Type op, if (pointerAssignedType && pointerExistingType) { // If we're trying to assign a pointer to a const value to a pointer if (assignedType->hasQualifier(Type::Qualifier::CONSTANT) && !existingType->hasQualifier(Type::Qualifier::CONSTANT)) { - errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName(context) + "' and '" + pointerAssignedType->getName(context)); + errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName() + "' and '" + pointerAssignedType->getName()); throw TypeCheckError(); } // If pointer types aren't compatible - if (pointerExistingType->getName(context) != pointerAssignedType->getName(context)) { - errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName(context) + "' and '" + pointerAssignedType->getName(context)); + if (pointerExistingType->getResolvedName(context) != pointerAssignedType->getResolvedName(context)) { + errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName() + "' and '" + pointerAssignedType->getName()); throw TypeCheckError(); } } // Otherwise, if we're trying to initialise a pointer with a non-pointer or vice-versa else if (pointerAssignedType || pointerExistingType) { - errorHandler.error(name, "Invalid operand types '" + existingType->getName(context) + "' and '" + assignedType->getName(context)); + errorHandler.error(name, "Invalid operand types '" + existingType->getName() + "' and '" + assignedType->getName()); throw 
TypeCheckError(); } } @@ -645,13 +645,13 @@ const Type::Base *EnvironmentBase::assign(const Token &name, Token::Type op, // If the operand being added isn't numeric or the type being added to is neither numeric or a pointer if (!numericAssignedType || (!pointerExistingType && !numericExistingType)) { - errorHandler.error(name, "Invalid operand types '" + existingType->getName(context) + "' and '" + assignedType->getName(context) + "'"); + errorHandler.error(name, "Invalid operand types '" + existingType->getName() + "' and '" + assignedType->getName() + "'"); throw TypeCheckError(); } // If we're adding a numeric type to a pointer, check it's an integer if (pointerExistingType && !numericAssignedType->isIntegral(context)) { - errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName(context) + "'"); + errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName() + "'"); throw TypeCheckError(); } } @@ -659,22 +659,22 @@ const Type::Base *EnvironmentBase::assign(const Token &name, Token::Type op, else { // If either type is non-numeric, give error if(!numericAssignedType) { - errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName(context) + "'"); + errorHandler.error(name, "Invalid operand types '" + assignedType->getName() + "'"); throw TypeCheckError(); } if(!numericExistingType) { - errorHandler.error(name, "Invalid operand types '" + existingType->getName(context) + "'"); + errorHandler.error(name, "Invalid operand types '" + existingType->getName() + "'"); throw TypeCheckError(); } // If operand isn't one that takes any numeric type, check both operands are integral if (op != Token::Type::STAR_EQUAL && op != Token::Type::SLASH_EQUAL) { if(!numericAssignedType->isIntegral(context)) { - errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName(context) + "'"); + errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName() + "'"); throw TypeCheckError(); } if(!numericExistingType->isIntegral(context)) { - errorHandler.error(name, "Invalid operand types '" + numericExistingType->getName(context) + "'"); + errorHandler.error(name, "Invalid operand types '" + numericExistingType->getName() + "'"); throw TypeCheckError(); } } diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index 2f9aa118df..c72db89cdf 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -90,9 +90,9 @@ const Base *Base::getPointerType(Qualifier qualifiers) const //---------------------------------------------------------------------------- // GeNN::Type::NumericTypedef //---------------------------------------------------------------------------- -std::string NumericTypedef::getName(const TypeContext &context) const +std::string NumericTypedef::getResolvedName(const TypeContext &context) const { - return getNumeric(context)->getName(context); + return getNumeric(context)->getResolvedName(context); } //---------------------------------------------------------------------------- size_t NumericTypedef::getSizeBytes(const TypeContext &context) const @@ -147,7 +147,7 @@ const Type::NumericBase *NumericTypedef::getNumeric(const TypeContext &context) return numericType; } else { - throw std::runtime_error("Numeric typedef '" + m_Name + "' resolved to non-numeric type '" + t->second->getName(context) + "'"); + throw std::runtime_error("Numeric typedef '" + m_Name + "' resolved to non-numeric type '" + t->second->getName() + "'"); } } } @@ -229,13 +229,13 @@ const NumericBase
*getPromotedType(const NumericBase *type, const TypeContext &c const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b, const TypeContext &context) { // If either type is double, common type is double - const auto &aTypeName = a->getName(context); - const auto &bTypeName = b->getName(context); - if(aTypeName == Double::getInstance()->getName(context) || bTypeName == Double::getInstance()->getName(context)) { + const auto &aTypeName = a->getResolvedName(context); + const auto &bTypeName = b->getResolvedName(context); + if(aTypeName == Double::getInstance()->getName() || bTypeName == Double::getInstance()->getName()) { return Double::getInstance(); } // Otherwise, if either type is float, common type is float - if(aTypeName == Float::getInstance()->getName(context) || bTypeName == Float::getInstance()->getName(context)) { + if(aTypeName == Float::getInstance()->getName() || bTypeName == Float::getInstance()->getName()) { return Float::getInstance(); } // Otherwise, must be an integer type @@ -245,7 +245,7 @@ const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b, con const auto *bPromoted = getPromotedType(b, context); // If both promoted operands have the same type, then no further conversion is needed. - if(aPromoted->getName(context) == bPromoted->getName(context)) { + if(aPromoted->getResolvedName(context) == bPromoted->getResolvedName(context)) { return aPromoted; } // Otherwise, if both promoted operands have signed integer numeric types or both have unsigned integer numeric types, From 5ef6743816e950a7d24308a9c265c258fae32f37 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 18 Jan 2023 14:31:26 +0000 Subject: [PATCH 066/725] lots more hacking * ModelSpecMerged and all GroupMerged classes should own TypeContexts rather than passing around precision/time precision --- .../genn/genn/code_generator/backendBase.h | 31 +++- .../genn/genn/code_generator/backendSIMT.h | 10 +- .../genn/genn/code_generator/codeGenUtils.h | 2 +- .../customConnectivityUpdateGroupMerged.h | 14 +- .../code_generator/customUpdateGroupMerged.h | 42 ++--- .../genn/genn/code_generator/groupMerged.h | 94 +++++----- .../groupMergedTypeEnvironment.h | 34 ++-- .../genn/code_generator/initGroupMerged.h | 76 ++++---- .../genn/code_generator/modelSpecMerged.h | 72 +++---- .../code_generator/neuronUpdateGroupMerged.h | 6 +- .../code_generator/synapseUpdateGroupMerged.h | 30 +-- include/genn/genn/gennUtils.h | 10 +- include/genn/genn/modelSpecInternal.h | 2 - include/genn/genn/transpiler/prettyPrinter.h | 5 +- include/genn/genn/type.h | 57 +++--- src/genn/genn/code_generator/backendSIMT.cc | 26 +-- src/genn/genn/code_generator/codeGenUtils.cc | 24 +-- .../customConnectivityUpdateGroupMerged.cc | 14 +- .../code_generator/customUpdateGroupMerged.cc | 22 +-- .../genn/code_generator/generateRunner.cc | 175 +++++++++--------- src/genn/genn/code_generator/groupMerged.cc | 42 ++--- .../genn/code_generator/initGroupMerged.cc | 32 ++-- .../genn/code_generator/modelSpecMerged.cc | 2 +- .../code_generator/neuronUpdateGroupMerged.cc | 4 +- .../presynapticUpdateStrategySIMT.cc | 2 +- .../synapseUpdateGroupMerged.cc | 4 +- src/genn/genn/customConnectivityUpdate.cc | 4 +- src/genn/genn/customUpdate.cc | 4 +- src/genn/genn/synapseGroup.cc | 6 +- src/genn/genn/transpiler/prettyPrinter.cc | 27 ++- src/genn/genn/type.cc | 10 + 31 files changed, 471 insertions(+), 412 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 
88e91a0121..61c24f1c05 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -408,6 +408,15 @@ class GENN_EXPORT BackendBase genVariablePull(pull, type, name, loc, count); } + //! Templated version of helper function to generate matching push and pull functions for + //! a variable when type is known at compile time + template<typename T> + void genVariablePushPull(CodeStream &push, CodeStream &pull, + const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const + { + genVariablePushPull(push, pull, T::getInstance(), name, loc, autoInitialized, count); + } + //! Helper function to generate matching push and pull functions for the current state of a variable void genCurrentVariablePushPull(CodeStream &push, CodeStream &pull, const NeuronGroupInternal &ng, const Type::Base *type, const std::string &name, VarLocation loc, unsigned int batchSize) const @@ -416,6 +425,15 @@ class GENN_EXPORT BackendBase genCurrentVariablePull(pull, ng, type, name, loc, batchSize); } + //! Templated version of helper function to generate matching push and pull functions + //! for the current state of a variable when type is known at compile time + template<typename T> + void genCurrentVariablePushPull(CodeStream &push, CodeStream &pull, const NeuronGroupInternal &ng, + const std::string &name, VarLocation loc, unsigned int batchSize) const + { + genCurrentVariablePushPull(push, pull, ng, T::getInstance(), name, loc, batchSize); + } + //! Helper function to generate matching definition, declaration, allocation and free code for an array void genArray(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, const Type::Base *type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const @@ -426,6 +444,15 @@ class GENN_EXPORT BackendBase genVariableAllocation(allocations, type, name, loc, count, memAlloc); } + //! Templated version of helper function to generate matching definition, declaration, + //! allocation and free code for an array when type is known at compile-time + template<typename T> + void genArray(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, + const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const + { + genArray(definitions, definitionsInternal, runner, allocations, free, T::getInstance(), name, loc, count, memAlloc); + + } //! Get the prefix for accessing the address of 'scalar' variables std::string getScalarAddressPrefix() const { @@ -446,13 +473,13 @@ class GENN_EXPORT BackendBase //! Simple struct to hold reduction targets struct ReductionTarget { - ReductionTarget(const std::string &n, const std::string &t, VarAccessMode a, const std::string &i) + ReductionTarget(const std::string &n, const Type::NumericBase *t, VarAccessMode a, const std::string &i) : name(n), type(t), access(a), index(i) { } const std::string name; - const std::string type; + const Type::NumericBase *type; const VarAccessMode access; const std::string index; }; diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 0525766fcf..0aabb1d6b8 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -92,7 +92,7 @@ class GENN_EXPORT BackendSIMT virtual std::string getCLZ() const = 0; //!
Get name of atomic operation - virtual std::string getAtomic(const std::string &type, AtomicOperation op = AtomicOperation::ADD, + virtual std::string getAtomic(const Type::NumericBase *type, AtomicOperation op = AtomicOperation::ADD, AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const = 0; //! Generate a shared memory barrier @@ -159,6 +159,14 @@ class GENN_EXPORT BackendSIMT size_t getPaddedNumCustomUpdateWUThreads(const CustomUpdateWUInternal &cg, unsigned int batchSize) const; size_t getPaddedNumCustomUpdateTransposeWUThreads(const CustomUpdateWUInternal &cg, unsigned int batchSize) const; + //! Helper to get name of atomic operation + template<typename T> + std::string getAtomic(AtomicOperation op = AtomicOperation::ADD, + AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const + { + return getAtomic(T::getInstance(), op, memSpace); + } + //-------------------------------------------------------------------------- // Static API //-------------------------------------------------------------------------- diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 0a01ff429d..88e04f1cb0 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -67,7 +67,7 @@ inline size_t padSize(size_t size, size_t blockSize) return ceilDivide(size, blockSize) * blockSize; } -GENN_EXPORT void genTypeRange(CodeStream &os, const std::string &precision, const std::string &prefix); +GENN_EXPORT void genTypeRange(CodeStream &os, const Type::NumericBase *precision, const Type::TypeContext &typeContext, const std::string &prefix); //-------------------------------------------------------------------------- /*! \brief This function implements a parser that converts any floating point constant in a code snippet to a floating point constant with an explicit precision (by appending "f" or removing it).
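The templated helpers introduced above (genVariablePushPull<T>, genCurrentVariablePushPull<T>, genArray<T> and getAtomic<T>) all follow one idiom: when the numeric type is known at compile time, a template wrapper forwards the type's singleton instance to the existing overload that takes a runtime Type::NumericBase pointer. The sketch below is a minimal, self-contained illustration of that idiom only; NumericBase, Float, getInstance() and genVariablePush() here are hypothetical stand-ins, not GeNN's actual classes or signatures.

#include <iostream>
#include <string>

// Hypothetical stand-in for a runtime type hierarchy with singleton instances
struct NumericBase {
    virtual ~NumericBase() = default;
    virtual std::string getName() const = 0;
};

struct Float : NumericBase {
    std::string getName() const override { return "float"; }

    // Singleton accessor, mirroring the T::getInstance() convention above
    static const Float *getInstance() { static Float instance; return &instance; }
};

// Runtime overload: the type is resolved through a base-class pointer
void genVariablePush(const NumericBase *type, const std::string &name) {
    std::cout << "push " << type->getName() << ' ' << name << '\n';
}

// Compile-time wrapper: forwards the singleton to the runtime overload,
// so callers can write genVariablePush<Float>("V")
template<typename T>
void genVariablePush(const std::string &name) {
    genVariablePush(T::getInstance(), name);
}

int main() {
    genVariablePush<Float>("V");  // equivalent to genVariablePush(Float::getInstance(), "V")
    return 0;
}

The wrapper adds no behaviour of its own, so the pointer-taking overload remains the single point of truth; the template exists purely to save callers from spelling out getInstance() at every call site.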
diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index 806784508f..4af7204d66 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -15,7 +15,7 @@ namespace GeNN::CodeGenerator class GENN_EXPORT CustomConnectivityUpdateGroupMergedBase : public GroupMerged { public: - CustomConnectivityUpdateGroupMergedBase(size_t index, const Type::NumericBase *precision, + CustomConnectivityUpdateGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const std::vector> &groups); protected: @@ -32,7 +32,7 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMergedBase : public GroupMerged> &groups); //---------------------------------------------------------------------------- @@ -40,12 +40,12 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public CustomConnectivit //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -75,17 +75,17 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public CustomConnectivit class GENN_EXPORT CustomConnectivityHostUpdateGroupMerged : public CustomConnectivityUpdateGroupMergedBase { public: - CustomConnectivityHostUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + CustomConnectivityHostUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name, true); } diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index 23a9263969..890286d051 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -12,7 +12,7 @@ namespace GeNN::CodeGenerator class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged { public: - CustomUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + 
CustomUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -23,12 +23,12 @@ class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged> &groups); private: @@ -95,21 +95,21 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged> &groups) - : CustomUpdateWUGroupMergedBase(index, precision, timePrecision, backend, groups) + : CustomUpdateWUGroupMergedBase(index, typeContext, backend, groups) { } //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -127,21 +127,21 @@ class GENN_EXPORT CustomUpdateWUGroupMerged : public CustomUpdateWUGroupMergedBa class GENN_EXPORT CustomUpdateTransposeWUGroupMerged : public CustomUpdateWUGroupMergedBase { public: - CustomUpdateTransposeWUGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + CustomUpdateTransposeWUGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) - : CustomUpdateWUGroupMergedBase(index, precision, timePrecision, backend, groups) + : CustomUpdateWUGroupMergedBase(index, typeContext, backend, groups) { } //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -161,9 +161,9 @@ template class CustomUpdateHostReductionGroupMergedBase : public GroupMerged { protected: - CustomUpdateHostReductionGroupMergedBase(size_t index, const Type::NumericBase *precision, const BackendBase &backend, - const std::vector> &groups) - : GroupMerged(index, precision, groups) + CustomUpdateHostReductionGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups) + : GroupMerged(index, typeContext, groups) { // Create type environment // **TEMP** parse precision to get scalar type @@ -192,18 +192,18 @@ class CustomUpdateHostReductionGroupMergedBase : public GroupMerged class GENN_EXPORT CustomUpdateHostReductionGroupMerged : public CustomUpdateHostReductionGroupMergedBase { public: - CustomUpdateHostReductionGroupMerged(size_t index, 
const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + CustomUpdateHostReductionGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name, true); } @@ -219,18 +219,18 @@ class GENN_EXPORT CustomUpdateHostReductionGroupMerged : public CustomUpdateHost class GENN_EXPORT CustomWUUpdateHostReductionGroupMerged : public CustomUpdateHostReductionGroupMergedBase { public: - CustomWUUpdateHostReductionGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + CustomWUUpdateHostReductionGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name, true); } diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 378c891933..6c3c112d47 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -63,10 +63,11 @@ class GroupMerged //------------------------------------------------------------------------ typedef G GroupInternal; typedef std::function GetFieldValueFunc; + typedef std::function GetFieldDoubleValueFunc; typedef std::tuple Field; - GroupMerged(size_t index, const Type::NumericBase *precision, const std::vector> groups) - : m_Index(index), m_LiteralSuffix((precision == "float") ? "f" : ""), m_ScalarType(precision), m_Groups(std::move(groups)) + GroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> groups) + : m_Index(index), m_TypeContext(typeContext), m_Groups(std::move(groups)) {} //------------------------------------------------------------------------ @@ -87,29 +88,28 @@ class GroupMerged const std::vector &getFields() const{ return m_Fields; } //! 
Get group fields, sorted into order they will appear in struct - std::vector getSortedFields(const BackendBase &backend, const Type::TypeContext &context) const + std::vector getSortedFields(const BackendBase &backend) const { // Make a copy of fields and sort so largest come first. This should mean that due // to structure packing rules, significant memory is saved and estimate is more precise auto sortedFields = m_Fields; std::sort(sortedFields.begin(), sortedFields.end(), - [&backend, &context](const Field &a, const Field &b) + [&backend, this](const Field &a, const Field &b) { - return (std::get<0>(a)->getSizeBytes(context) > std::get<0>(b)->getSizeBytes(context)); + return (std::get<0>(a)->getSizeBytes(m_TypeContext) > std::get<0>(b)->getSizeBytes(m_TypeContext)); }); return sortedFields; } //! Generate declaration of struct to hold this merged group - void generateStruct(CodeStream &os, const BackendBase &backend, const Type::TypeContext &context, - const std::string &name, bool host = false) const + void generateStruct(CodeStream &os, const BackendBase &backend, const std::string &name, bool host = false) const { os << "struct Merged" << name << "Group" << getIndex() << std::endl; { // Loop through fields and write to structure CodeStream::Scope b(os); - const auto sortedFields = getSortedFields(backend, context); + const auto sortedFields = getSortedFields(backend); for(const auto &f : sortedFields) { // If field is a pointer and not marked as being a host field // (in which case the backend should leave its type alone!) @@ -117,16 +117,16 @@ class GroupMerged if(dynamic_cast(type) && !(std::get<3>(f) & GroupMergedFieldType::HOST)) { // If we are generating a host structure, allow the backend to override the type if(host) { - os << backend.getMergedGroupFieldHostTypeName(type, context); + os << backend.getMergedGroupFieldHostTypeName(type, m_TypeContext); } // Otherwise, allow the backend to add a prefix else { - os << backend.getPointerPrefix() << type->getResolvedName(context); + os << backend.getPointerPrefix() << type->getResolvedName(m_TypeContext); } } // Otherwise, leave the type alone else { - os << type->getResolvedName(context); + os << type->getResolvedName(m_TypeContext); } os << " " << std::get<1>(f) << ";" << std::endl; } @@ -136,28 +136,28 @@ class GroupMerged os << ";" << std::endl; } - void generateStructFieldArgumentDefinitions(CodeStream &os, const BackendBase &backend, const Type::TypeContext &context) const + void generateStructFieldArgumentDefinitions(CodeStream &os, const BackendBase &backend) const { // Get sorted fields - const auto sortedFields = getSortedFields(backend, context); + const auto sortedFields = getSortedFields(backend); for(size_t fieldIndex = 0; fieldIndex < sortedFields.size(); fieldIndex++) { const auto &f = sortedFields[fieldIndex]; - os << backend.getMergedGroupFieldHostTypeName(std::get<0>(f), context) << " " << std::get<1>(f); + os << backend.getMergedGroupFieldHostTypeName(std::get<0>(f), m_TypeContext) << " " << std::get<1>(f); if(fieldIndex != (sortedFields.size() - 1)) { os << ", "; } } } - size_t getStructArraySize(const BackendBase &backend, const Type::TypeContext &context) const + size_t getStructArraySize(const BackendBase &backend) const { // Loop through fields again to generate any EGP pushing functions that are required and to calculate struct size size_t structSize = 0; size_t largestFieldSize = 0; - const auto sortedFields = getSortedFields(backend, context); + const auto sortedFields = getSortedFields(backend); 
for(const auto &f : sortedFields) { // Add size of field to total - const size_t fieldSize = std::get<0>(f)->getSizeBytes(context); + const size_t fieldSize = std::get<0>(f)->getSizeBytes(m_TypeContext); structSize += fieldSize; // Update largest field size @@ -199,7 +199,8 @@ class GroupMerged //------------------------------------------------------------------------ // Protected methods //------------------------------------------------------------------------ - const Type::NumericBase *getScalarType() const{ return m_ScalarType; } + const Type::NumericBase *getScalarType() const{ return dynamic_cast(m_TypeContext.at("scalar")); } + const Type::NumericBase *getTimeType() const{ return dynamic_cast(m_TypeContext.at("time")); } //! Helper to test whether parameter is referenced in vector of codestrings bool isParamReferenced(const std::vector &codeStrings, const std::string ¶mName) const @@ -254,12 +255,12 @@ class GroupMerged m_Fields.emplace_back(T::getInstance(), name, getFieldValue, fieldType); } - void addScalarField(const std::string &name, GetFieldValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) + void addScalarField(const std::string &name, GetFieldDoubleValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) { - addField(m_ScalarType, name, + addField(getScalarType(), name, [getFieldValue, this](const G &g, size_t i) { - return getFieldValue(g, i) + m_LiteralSuffix; + return Utils::writePreciseString(getFieldValue(g, i), getScalarType()->getMaxDigits10(m_TypeContext)) + getScalarType()->getLiteralSuffix(m_TypeContext); }, fieldType); } @@ -321,8 +322,7 @@ class GroupMerged addScalarField(p + suffix, [p, getParamValues](const G &g, size_t) { - const auto &values = getParamValues(g); - return Utils::writePreciseString(values.at(p)); + return getParamValues(g).at(p); }); } } @@ -341,8 +341,7 @@ class GroupMerged addScalarField(d.name + suffix, [d, getDerivedParamValues](const G &g, size_t) { - const auto &values = getDerivedParamValues(g); - return Utils::writePreciseString(values.at(d.name)); + return getDerivedParamValues(g).at(d.name); }); } } @@ -360,8 +359,7 @@ class GroupMerged addScalarField(p.first + v.name, [p, v](const G &g, size_t) { - const auto &values = A(g).getVarInitialisers().at(v.name).getParams(); - return Utils::writePreciseString(values.at(p.first)); + return A(g).getVarInitialisers().at(v.name).getParams().at(p.first); }); } } @@ -380,8 +378,7 @@ class GroupMerged addScalarField(p.first + v.name, [p, v](const G &g, size_t) { - const auto &values = A(g).getVarInitialisers().at(v.name).getDerivedParams(); - return Utils::writePreciseString(values.at(p.first)); + return A(g).getVarInitialisers().at(v.name).getDerivedParams().at(p.first); }); } } @@ -458,19 +455,19 @@ class GroupMerged } } - void generateRunnerBase(const BackendBase &backend, const Type::TypeContext &context, + void generateRunnerBase(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc, const std::string &name, bool host = false) const { // Make a copy of fields and sort so largest come first. 
This should mean that due // to structure packing rules, significant memory is saved and estimate is more precise - auto sortedFields = getSortedFields(backend, context); + auto sortedFields = getSortedFields(backend); // If this isn't a host merged structure, generate definition for function to push group if(!host) { definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << "Group" << getIndex() << "ToDevice(unsigned int idx, "; - generateStructFieldArgumentDefinitions(definitionsInternalFunc, backend, context); + generateStructFieldArgumentDefinitions(definitionsInternalFunc, backend); definitionsInternalFunc << ");" << std::endl; } @@ -479,7 +476,7 @@ class GroupMerged // If this field is a dynamic pointer if((std::get<3>(f) & GroupMergedFieldType::DYNAMIC) && dynamic_cast(std::get<0>(f))) { definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << getIndex() << std::get<1>(f) << "ToDevice(unsigned int idx, "; - definitionsInternalFunc << backend.getMergedGroupFieldHostTypeName(std::get<0>(f), context) << " value);" << std::endl; + definitionsInternalFunc << backend.getMergedGroupFieldHostTypeName(std::get<0>(f), m_TypeContext) << " value);" << std::endl; } // Raise error if this field is a host field but this isn't a host structure @@ -490,7 +487,7 @@ class GroupMerged if(host) { // Generate struct directly into internal definitions // **NOTE** we ignore any backend prefix as we're generating this struct for use on the host - generateStruct(definitionsInternal, backend, context, name, true); + generateStruct(definitionsInternal, backend, name, true); // Declare array of these structs containing individual neuron group pointers etc runnerVarDecl << "Merged" << name << "Group" << getIndex() << " merged" << name << "Group" << getIndex() << "[" << getGroups().size() << "];" << std::endl; @@ -541,8 +538,7 @@ class GroupMerged // Members //------------------------------------------------------------------------ const size_t m_Index; - const std::string m_LiteralSuffix; - const Type::NumericBase *m_ScalarType; + const Type::TypeContext &m_TypeContext; std::string m_MemorySpace; std::vector m_Fields; std::vector> m_Groups; @@ -554,18 +550,18 @@ class GroupMerged class GENN_EXPORT NeuronSpikeQueueUpdateGroupMerged : public GroupMerged { public: - NeuronSpikeQueueUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecison, const BackendBase &backend, + NeuronSpikeQueueUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -583,18 +579,18 @@ class GENN_EXPORT NeuronSpikeQueueUpdateGroupMerged : public GroupMerged { public: - NeuronPrevSpikeTimeUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase 
*timePrecison, const BackendBase &backend, + NeuronPrevSpikeTimeUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -662,7 +658,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged> &groups); void updateBaseHash(bool init, boost::uuids::detail::sha1 &hash) const; @@ -758,7 +754,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged*getValueFn)().at(p)); + return std::invoke(getValueFn, child).at(p); }); } } @@ -778,7 +774,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged*getValueFn)().at(p.name)); + return std::invoke(getValueFn, child).at(p.name); }); } } @@ -798,8 +794,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged*getVarInitialiserFn)(); - return Utils::writePreciseString(varInit.at(varName).getParams().at(p)); + return std::invoke(getVarInitialiserFn, child).at(varName).getParams().at(p); }); } } @@ -819,8 +814,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged*getVarInitialiserFn)(); - return Utils::writePreciseString(varInit.at(varName).getDerivedParams().at(d.name)); + return std::invoke(getVarInitialiserFn, child).at(varName).getDerivedParams().at(d.name); }); } } @@ -1073,7 +1067,7 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged> &groups); //---------------------------------------------------------------------------- diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index b8890d44ea..1b05a3a2da 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -140,6 +140,16 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa defineField(qualifiedType, name, type->getPointerType(), name, [prefix](const auto &g, size_t) { return prefix + g.getName(); }); } + + void defineScalarField(const std::string &name, typename G::GetFieldDoubleValueFunc getFieldValue) + { + defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONSTANT), name, + m_ScalarType, name, + [getFieldValue, this](const auto &g, size_t i) + { + return Utils::writePreciseString(getFieldValue(g, i), m_ScalarType->getMaxDigits10(m_TypeContext)) + m_ScalarType->getLiteralSuffix(m_TypeContext); + }); + } template void defineHeterogeneousParams(const Snippet::Base::StringVec ¶mNames, const std::string &suffix, @@ -148,13 +158,11 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa // Loop through params for(const auto &p : paramNames) { if (std::invoke(isHeterogeneous, m_GroupMerged, p)) { - defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONSTANT), p + suffix, - m_ScalarType, p + suffix, - [p, 
getParamValues](const auto &g, size_t) - { - const auto &values = getParamValues(g); - return Utils::writePreciseString(values.at(p)); - }); + defineScalarField(p + suffix, + [p, getParamValues](const auto &g, size_t) + { + return getParamValues(g).at(p); + }); } // Otherwise, just add a const-qualified scalar to the type environment else { @@ -170,13 +178,11 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa // Loop through derived params for(const auto &d : derivedParams) { if (std::invoke(isHeterogeneous, m_GroupMerged, d.name)) { - defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONSTANT), d.name + suffix, - m_ScalarType, d.name + suffix, - [d, getDerivedParamValues](const auto &g, size_t) - { - const auto &values = getDerivedParamValues(g); - return Utils::writePreciseString(values.at(d.name)); - }); + defineScalarField(d.name + suffix, + [d, getDerivedParamValues](const auto &g, size_t) + { + return getDerivedParamValues(g).at(d.name); + }); } else { defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONSTANT), d.name + suffix); diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index 88c0e48083..0f93872682 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -11,7 +11,7 @@ namespace GeNN::CodeGenerator class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase { public: - NeuronInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + NeuronInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -20,12 +20,12 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //! 
Get hash digest used for detecting changes boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -96,9 +96,9 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase class GENN_EXPORT SynapseInitGroupMerged : public SynapseGroupMergedBase { public: - SynapseInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + SynapseInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) - : SynapseGroupMergedBase(index, precision, timePrecision, backend, SynapseGroupMergedBase::Role::Init, "", groups) + : SynapseGroupMergedBase(index, typeContext, backend, SynapseGroupMergedBase::Role::Init, "", groups) {} boost::uuids::detail::sha1::digest_type getHashDigest() const @@ -106,12 +106,12 @@ class GENN_EXPORT SynapseInitGroupMerged : public SynapseGroupMergedBase return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::Init); } - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -129,9 +129,9 @@ class GENN_EXPORT SynapseInitGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT SynapseSparseInitGroupMerged : public SynapseGroupMergedBase { public: - SynapseSparseInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + SynapseSparseInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) - : SynapseGroupMergedBase(index, precision, timePrecision, backend, SynapseGroupMergedBase::Role::SparseInit, "", groups) + : SynapseGroupMergedBase(index, typeContext, backend, SynapseGroupMergedBase::Role::SparseInit, "", groups) {} boost::uuids::detail::sha1::digest_type getHashDigest() const @@ -139,12 +139,12 @@ class GENN_EXPORT SynapseSparseInitGroupMerged : public SynapseGroupMergedBase return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::SparseInit); } - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + 
generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -162,9 +162,9 @@ class GENN_EXPORT SynapseSparseInitGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT SynapseConnectivityInitGroupMerged : public SynapseGroupMergedBase { public: - SynapseConnectivityInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + SynapseConnectivityInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) - : SynapseGroupMergedBase(index, precision, timePrecision, backend, SynapseGroupMergedBase::Role::ConnectivityInit, "", groups) + : SynapseGroupMergedBase(index, typeContext, backend, SynapseGroupMergedBase::Role::ConnectivityInit, "", groups) {} boost::uuids::detail::sha1::digest_type getHashDigest() const @@ -172,12 +172,12 @@ class GENN_EXPORT SynapseConnectivityInitGroupMerged : public SynapseGroupMerged return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::ConnectivityInit); } - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -205,18 +205,18 @@ class GENN_EXPORT SynapseConnectivityInitGroupMerged : public SynapseGroupMerged class GENN_EXPORT SynapseConnectivityHostInitGroupMerged : public GroupMerged { public: - SynapseConnectivityHostInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + SynapseConnectivityHostInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name, true); } @@ -249,9 +249,9 @@ template class CustomUpdateInitGroupMergedBase : public GroupMerged { protected: - CustomUpdateInitGroupMergedBase(size_t index, const Type::NumericBase *precision, const BackendBase &backend, + CustomUpdateInitGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) - : GroupMerged(index, precision, groups) + : GroupMerged(index, typeContext, groups) { // Loop through variables A archetypeAdaptor(this->getArchetype()); @@ -340,7 +340,7 @@ class CustomUpdateInitGroupMergedBase : public GroupMerged 
class GENN_EXPORT CustomUpdateInitGroupMerged : public CustomUpdateInitGroupMergedBase { public: - CustomUpdateInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + CustomUpdateInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -348,12 +348,12 @@ class GENN_EXPORT CustomUpdateInitGroupMerged : public CustomUpdateInitGroupMerg //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -373,7 +373,7 @@ class GENN_EXPORT CustomWUUpdateInitGroupMerged : public CustomUpdateInitGroupMe CustomUpdateVarAdapter> { public: - CustomWUUpdateInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + CustomWUUpdateInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -381,12 +381,12 @@ class GENN_EXPORT CustomWUUpdateInitGroupMerged : public CustomUpdateInitGroupMe //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -432,7 +432,7 @@ class GENN_EXPORT CustomWUUpdateSparseInitGroupMerged : public CustomUpdateInitG CustomUpdateVarAdapter> { public: - CustomWUUpdateSparseInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + CustomWUUpdateSparseInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -440,12 +440,12 @@ class GENN_EXPORT CustomWUUpdateSparseInitGroupMerged : public CustomUpdateInitG //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream 
&definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -464,7 +464,7 @@ class GENN_EXPORT CustomConnectivityUpdatePreInitGroupMerged : public CustomUpda CustomConnectivityUpdatePreVarAdapter> { public: - CustomConnectivityUpdatePreInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + CustomConnectivityUpdatePreInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -472,12 +472,12 @@ class GENN_EXPORT CustomConnectivityUpdatePreInitGroupMerged : public CustomUpda //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -496,7 +496,7 @@ class GENN_EXPORT CustomConnectivityUpdatePostInitGroupMerged : public CustomUpd CustomConnectivityUpdatePostVarAdapter> { public: - CustomConnectivityUpdatePostInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + CustomConnectivityUpdatePostInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -504,12 +504,12 @@ class GENN_EXPORT CustomConnectivityUpdatePostInitGroupMerged : public CustomUpd //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -528,7 +528,7 @@ class GENN_EXPORT CustomConnectivityUpdateSparseInitGroupMerged : public CustomU CustomConnectivityUpdateVarAdapter> { public: - CustomConnectivityUpdateSparseInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + CustomConnectivityUpdateSparseInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase 
&backend, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -536,12 +536,12 @@ class GENN_EXPORT CustomConnectivityUpdateSparseInitGroupMerged : public CustomU //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 98e2e2f7c3..1aa21679e1 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -40,11 +40,11 @@ class GENN_EXPORT ModelSpecMerged //! Immutable structure for tracking fields of merged group structure containing EGPs struct EGPField { - EGPField(size_t m, const std::string &t, const std::string &f, bool h) - : mergedGroupIndex(m), typeName(t), fieldName(f), hostGroup(h) {} + EGPField(size_t m, const Type::Base *t, const std::string &f, bool h) + : mergedGroupIndex(m), type(t), fieldName(f), hostGroup(h) {} const size_t mergedGroupIndex; - const std::string typeName; + const Type::Base *type; const std::string fieldName; const bool hostGroup; @@ -52,8 +52,8 @@ class GENN_EXPORT ModelSpecMerged //! lexicographically compares all three struct members bool operator < (const EGPField &other) const { - return (std::tie(mergedGroupIndex, typeName, fieldName, hostGroup) - < std::tie(other.mergedGroupIndex, other.typeName, other.fieldName, other.hostGroup)); + return (std::make_tuple(mergedGroupIndex, type->getName(), fieldName, hostGroup) + < std::make_tuple(other.mergedGroupIndex, other.type->getName(), other.fieldName, other.hostGroup)); } }; @@ -63,7 +63,7 @@ class GENN_EXPORT ModelSpecMerged //! Immutable structure for tracking where an extra global variable ends up after merging struct MergedEGP : public EGPField { - MergedEGP(size_t m, size_t g, const std::string &t, const std::string &f, bool h) + MergedEGP(size_t m, size_t g, const Type::Base *t, const std::string &f, bool h) : EGPField(m, t, f, h), groupIndex(g) {} const size_t groupIndex; @@ -161,31 +161,31 @@ class GENN_EXPORT ModelSpecMerged //! 
Get merged custom connectivity update groups where host processing needs to be performed const std::vector &getMergedCustomConnectivityHostUpdateGroups() const { return m_MergedCustomConnectivityHostUpdateGroups; } - void genMergedNeuronUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedNeuronUpdateGroups); } - void genMergedPresynapticUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedPresynapticUpdateGroups); } - void genMergedPostsynapticUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedPostsynapticUpdateGroups); } - void genMergedSynapseDynamicsGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedSynapseDynamicsGroups); } - void genMergedNeuronInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedNeuronInitGroups); } - void genMergedCustomUpdateInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomUpdateInitGroups); } - void genMergedCustomWUUpdateInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomWUUpdateInitGroups); } - void genMergedSynapseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedSynapseInitGroups); } - void genMergedSynapseConnectivityInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedSynapseConnectivityInitGroups); } - void genMergedSynapseSparseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedSynapseSparseInitGroups); } - void genMergedCustomWUUpdateSparseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomWUUpdateSparseInitGroups); } - void genMergedCustomConnectivityUpdatePreInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomConnectivityUpdatePreInitGroups); } - void genMergedCustomConnectivityUpdatePostInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomConnectivityUpdatePostInitGroups); } - void genMergedCustomConnectivityUpdateSparseInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomConnectivityUpdateSparseInitGroups); } - void genMergedNeuronSpikeQueueUpdateStructs(CodeStream &os, const BackendBase &backend) const{ genMergedStructures(os, backend, m_TypeContext, m_MergedNeuronSpikeQueueUpdateGroups); } - void genMergedNeuronPrevSpikeTimeUpdateStructs(CodeStream &os, const BackendBase &backend) const{ genMergedStructures(os, backend, m_TypeContext, m_MergedNeuronPrevSpikeTimeUpdateGroups); } - void genMergedSynapseDendriticDelayUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedSynapseDendriticDelayUpdateGroups); } - void genMergedSynapseConnectivityHostInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, 
m_MergedSynapseConnectivityHostInitGroups); } - void genMergedCustomUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomUpdateGroups); } - void genMergedCustomUpdateWUStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomUpdateWUGroups); } - void genMergedCustomUpdateTransposeWUStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomUpdateTransposeWUGroups); } - void genMergedCustomUpdateHostReductionStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomUpdateHostReductionGroups); } - void genMergedCustomWUUpdateHostReductionStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomWUUpdateHostReductionGroups); } - void genMergedCustomConnectivityUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomConnectivityUpdateGroups); } - void genMergedCustomConnectivityHostUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_TypeContext, m_MergedCustomConnectivityHostUpdateGroups); } + void genMergedNeuronUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedNeuronUpdateGroups); } + void genMergedPresynapticUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedPresynapticUpdateGroups); } + void genMergedPostsynapticUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedPostsynapticUpdateGroups); } + void genMergedSynapseDynamicsGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseDynamicsGroups); } + void genMergedNeuronInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedNeuronInitGroups); } + void genMergedCustomUpdateInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomUpdateInitGroups); } + void genMergedCustomWUUpdateInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomWUUpdateInitGroups); } + void genMergedSynapseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseInitGroups); } + void genMergedSynapseConnectivityInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseConnectivityInitGroups); } + void genMergedSynapseSparseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseSparseInitGroups); } + void genMergedCustomWUUpdateSparseInitGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomWUUpdateSparseInitGroups); } + void genMergedCustomConnectivityUpdatePreInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomConnectivityUpdatePreInitGroups); } + void genMergedCustomConnectivityUpdatePostInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, 
m_MergedCustomConnectivityUpdatePostInitGroups); } + void genMergedCustomConnectivityUpdateSparseInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomConnectivityUpdateSparseInitGroups); } + void genMergedNeuronSpikeQueueUpdateStructs(CodeStream &os, const BackendBase &backend) const{ genMergedStructures(os, backend, m_MergedNeuronSpikeQueueUpdateGroups); } + void genMergedNeuronPrevSpikeTimeUpdateStructs(CodeStream &os, const BackendBase &backend) const{ genMergedStructures(os, backend, m_MergedNeuronPrevSpikeTimeUpdateGroups); } + void genMergedSynapseDendriticDelayUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseDendriticDelayUpdateGroups); } + void genMergedSynapseConnectivityHostInitStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedSynapseConnectivityHostInitGroups); } + void genMergedCustomUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomUpdateGroups); } + void genMergedCustomUpdateWUStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomUpdateWUGroups); } + void genMergedCustomUpdateTransposeWUStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomUpdateTransposeWUGroups); } + void genMergedCustomUpdateHostReductionStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomUpdateHostReductionGroups); } + void genMergedCustomWUUpdateHostReductionStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomWUUpdateHostReductionGroups); } + void genMergedCustomConnectivityUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomConnectivityUpdateGroups); } + void genMergedCustomConnectivityHostUpdateStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedCustomConnectivityHostUpdateGroups); } void genNeuronUpdateGroupSupportCode(CodeStream &os, bool supportsNamespace = true) const{ m_NeuronUpdateSupportCode.gen(os, getModel().getPrecision(), supportsNamespace); } void genPostsynapticDynamicsSupportCode(CodeStream &os, bool supportsNamespace = true) const{ m_PostsynapticDynamicsSupportCode.gen(os, getModel().getPrecision(), supportsNamespace); } @@ -243,7 +243,7 @@ class GENN_EXPORT ModelSpecMerged std::transform(groupEGPs.first, groupEGPs.second, std::inserter(mergedGroupFields, mergedGroupFields.end()), [](const MergedEGPMap::value_type::second_type::value_type &g) { - return EGPField{g.second.mergedGroupIndex, g.second.typeName, g.second.fieldName, g.second.hostGroup}; + return EGPField{g.second.mergedGroupIndex, g.second.type, g.second.fieldName, g.second.hostGroup}; }); } @@ -267,7 +267,7 @@ class GENN_EXPORT ModelSpecMerged // If EGP is a pointer // **NOTE** this is common to all references! 
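// Illustration (hypothetical group and field names, not from this diff): the loop below
// emits one push function per merged-group extra-global-parameter field, and the host
// type name is now resolved through m_TypeContext so a typedef such as "scalar" prints
// as its concrete type. With scalar mapped to float, a pointer EGP "spikeTimes" on
// merged group 0 would yield runner code roughly like:
//
//     void pushMergedNeuronUpdate0spikeTimesToDevice(unsigned int idx, float* value)
//     {
//         // backend-specific copy of "value" into the "spikeTimes" field of
//         // merged group struct 0, element "idx"
//     }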
if(dynamic_cast(f.type)) {
-                os << "void pushMerged" << T::name << f.mergedGroupIndex << f.fieldName << "ToDevice(unsigned int idx, " << backend.getMergedGroupFieldHostTypeName(f.type) << " value)";
+                os << "void pushMerged" << T::name << f.mergedGroupIndex << f.fieldName << "ToDevice(unsigned int idx, " << backend.getMergedGroupFieldHostTypeName(f.type, m_TypeContext) << " value)";
                 {
                     CodeStream::Scope b(os);
                     backend.genMergedExtraGlobalParamPush(os, T::name, f.mergedGroupIndex, "idx", f.fieldName, "value");
@@ -286,11 +286,11 @@ class GENN_EXPORT ModelSpecMerged
     // Private methods
     //--------------------------------------------------------------------------
     template<typename T>
-    void genMergedStructures(CodeStream &os, const BackendBase &backend, const Type::TypeContext &context, const std::vector<T> &mergedGroups) const
+    void genMergedStructures(CodeStream &os, const BackendBase &backend, const std::vector<T> &mergedGroups) const
     {
         // Loop through all merged groups and generate struct
         for(const auto &g : mergedGroups) {
-            g.generateStruct(os, backend, context, T::name);
+            g.generateStruct(os, backend, T::name);
         }
     }

@@ -316,7 +316,7 @@ class GENN_EXPORT ModelSpecMerged
         size_t i = 0;
         for(const auto &p : protoMergedGroups) {
             // Add group to vector
-            mergedGroups.emplace_back(i, model.getPrecision(), model.getTimePrecision(), backend, p.second);
+            mergedGroups.emplace_back(i, m_TypeContext, backend, p.second);

             // Loop through fields
             for(const auto &f : mergedGroups.back().getFields()) {
diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h
index ef173dad79..85c398e57b 100644
--- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h
+++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h
@@ -11,7 +11,7 @@ namespace GeNN::CodeGenerator
 class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase
 {
 public:
-    NeuronUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend,
+    NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
                             const std::vector<std::reference_wrapper<const NeuronGroupInternal>> &groups);

     //------------------------------------------------------------------------
@@ -38,12 +38,12 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase
     //!
Get hash digest used for detecting changes boost::uuids::detail::sha1::digest_type getHashDigest() const; - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } diff --git a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h index b24dc41f28..52f179f5b2 100644 --- a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h @@ -11,9 +11,9 @@ namespace GeNN::CodeGenerator class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase { public: - PresynapticUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + PresynapticUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) - : SynapseGroupMergedBase(index, precision, timePrecision, backend, SynapseGroupMergedBase::Role::PresynapticUpdate, + : SynapseGroupMergedBase(index, typeContext, backend, SynapseGroupMergedBase::Role::PresynapticUpdate, groups.front().get().getWUModel()->getSimCode() + groups.front().get().getWUModel()->getEventCode() + groups.front().get().getWUModel()->getEventThresholdConditionCode(), groups) {} @@ -22,12 +22,12 @@ class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::PresynapticUpdate); } - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -49,9 +49,9 @@ class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT PostsynapticUpdateGroupMerged : public SynapseGroupMergedBase { public: - PostsynapticUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + PostsynapticUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) - : SynapseGroupMergedBase(index, precision, timePrecision, backend, SynapseGroupMergedBase::Role::PostsynapticUpdate, + : SynapseGroupMergedBase(index, typeContext, backend, SynapseGroupMergedBase::Role::PostsynapticUpdate, groups.front().get().getWUModel()->getLearnPostCode(), groups) {} @@ -60,12 +60,12 @@ class GENN_EXPORT PostsynapticUpdateGroupMerged : public SynapseGroupMergedBase return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::PostsynapticUpdate); } - void 
generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -83,9 +83,9 @@ class GENN_EXPORT PostsynapticUpdateGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase { public: - SynapseDynamicsGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + SynapseDynamicsGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) - : SynapseGroupMergedBase(index, precision, timePrecision, backend, SynapseGroupMergedBase::Role::SynapseDynamics, + : SynapseGroupMergedBase(index, typeContext, backend, SynapseGroupMergedBase::Role::SynapseDynamics, groups.front().get().getWUModel()->getSynapseDynamicsCode(), groups) {} @@ -94,12 +94,12 @@ class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::SynapseDynamics); } - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } @@ -117,18 +117,18 @@ class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT SynapseDendriticDelayUpdateGroupMerged : public GroupMerged { public: - SynapseDendriticDelayUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, + SynapseDendriticDelayUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &group); //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - void generateRunner(const BackendBase &backend, const Type::TypeContext &context, + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc) const { - generateRunnerBase(backend, context, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc, name); } diff --git a/include/genn/genn/gennUtils.h b/include/genn/genn/gennUtils.h index fd610bffe1..94c29249c6 100644 --- a/include/genn/genn/gennUtils.h +++ b/include/genn/genn/gennUtils.h @@ -111,7 +111,7 @@ void validateVecNames(const 
std::vector<T> &vec, const std::string &description)
 //! \brief This function writes a floating point value to a stream - setting the precision so no digits are lost
 //--------------------------------------------------------------------------
 template<class T, typename std::enable_if<std::is_floating_point<T>::value>::type * = nullptr>
-void writePreciseString(std::ostream &os, T value)
+void writePreciseString(std::ostream &os, T value, int maxDigits10 = std::numeric_limits<T>::max_digits10)
 {
     // Cache previous precision
     const std::streamsize previousPrecision = os.precision();
@@ -119,8 +119,8 @@ void writePreciseString(std::ostream &os, T value)
     // Set scientific formatting
     os << std::scientific;

-    // Set precision to what is required to fully represent T
-    os << std::setprecision(std::numeric_limits<T>::max_digits10);
+    // Set precision
+    os << std::setprecision(maxDigits10);

     // Write value to stream
     os << value;
@@ -138,10 +138,10 @@ void writePreciseString(std::ostream &os, T value)
 //! \brief This function writes a floating point value to a string - setting the precision so no digits are lost
 //--------------------------------------------------------------------------
 template<class T, typename std::enable_if<std::is_floating_point<T>::value>::type * = nullptr>
-inline std::string writePreciseString(T value)
+inline std::string writePreciseString(T value, int maxDigits10 = std::numeric_limits<T>::max_digits10)
 {
     std::stringstream s;
-    writePreciseString(s, value);
+    writePreciseString(s, value, maxDigits10);

     return s.str();
 }
diff --git a/include/genn/genn/modelSpecInternal.h b/include/genn/genn/modelSpecInternal.h
index 7fc79c3327..4420f3873c 100644
--- a/include/genn/genn/modelSpecInternal.h
+++ b/include/genn/genn/modelSpecInternal.h
@@ -24,8 +24,6 @@ class ModelSpecInternal : public ModelSpec
     using ModelSpec::finalize;

-    using ModelSpec::scalarExpr;
-
     using ModelSpec::zeroCopyInUse;
     using ModelSpec::isRecordingInUse;
     using ModelSpec::getHashDigest;
diff --git a/include/genn/genn/transpiler/prettyPrinter.h b/include/genn/genn/transpiler/prettyPrinter.h
index ddbb2af9e2..af1a5cada1 100644
--- a/include/genn/genn/transpiler/prettyPrinter.h
+++ b/include/genn/genn/transpiler/prettyPrinter.h
@@ -3,6 +3,9 @@
 // Standard C++ includes
 #include <string>

+// GeNN includes
+#include "type.h"
+
 // Transpiler includes
 #include "transpiler/statement.h"

@@ -11,5 +14,5 @@
 //---------------------------------------------------------------------------
 namespace GeNN::Transpiler::PrettyPrinter
 {
-std::string print(const Statement::StatementList &statements);
+std::string print(const Statement::StatementList &statements, const Type::TypeContext &context);
 }
\ No newline at end of file
diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h
index 5a279bbb0f..6059ae1b85 100644
--- a/include/genn/genn/type.h
+++ b/include/genn/genn/type.h
@@ -23,7 +23,7 @@
 //----------------------------------------------------------------------------
 #define DECLARE_TYPE(TYPE) \
     private: \
-        GENN_EXPORT static TYPE *s_Instance;    \
+        GENN_EXPORT static TYPE *s_Instance; \
     public: \
         static const TYPE *getInstance() \
         { \
@@ -34,19 +34,20 @@
             return s_Instance; \
         }

-#define DECLARE_NUMERIC_TYPE(TYPE, UNDERLYING_TYPE, RANK) \
-    class TYPE : public Numeric<UNDERLYING_TYPE, RANK> \
-    { \
-        DECLARE_TYPE(TYPE) \
-        TYPE(Qualifier qualifiers = Qualifier{0}) : Numeric<UNDERLYING_TYPE, RANK>(qualifiers){} \
-        virtual std::string getName() const final{ return #UNDERLYING_TYPE; } \
-        virtual std::string getResolvedName(const TypeContext&) const final{ return #UNDERLYING_TYPE; } \
-        virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new TYPE(qualifiers); } \
-    }; \
-    template<> \
-    struct TypeTraits<UNDERLYING_TYPE> \
-    { \
-        using NumericType = TYPE; \
+#define DECLARE_NUMERIC_TYPE(TYPE, UNDERLYING_TYPE, RANK, LITERAL_SUFFIX) \
+    class TYPE : public Numeric<UNDERLYING_TYPE, RANK> \
+    { \
+        DECLARE_TYPE(TYPE) \
+        TYPE(Qualifier qualifiers = Qualifier{0}) : Numeric<UNDERLYING_TYPE, RANK>(qualifiers){} \
+        virtual std::string getName() const final{ return #UNDERLYING_TYPE; } \
+        virtual std::string getResolvedName(const TypeContext&) const final{ return #UNDERLYING_TYPE; } \
+        virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new TYPE(qualifiers); } \
+        virtual std::string getLiteralSuffix(const TypeContext &context) const final{ return LITERAL_SUFFIX; } \
+    }; \
+    template<> \
+    struct TypeTraits<UNDERLYING_TYPE> \
+    { \
+        using NumericType = TYPE; \
     }

 #define DECLARE_FOREIGN_FUNCTION_TYPE(TYPE, RETURN_TYPE, ...) \
@@ -179,8 +180,12 @@ class NumericBase : public Base
     virtual double getMin(const TypeContext&) const = 0;
     virtual double getMax(const TypeContext&) const = 0;
     virtual double getLowest(const TypeContext&) const = 0;
+    virtual int getMaxDigits10(const TypeContext&) const = 0;
+
     virtual bool isSigned(const TypeContext&) const = 0;
     virtual bool isIntegral(const TypeContext&) const = 0;
+
+    virtual std::string getLiteralSuffix(const TypeContext&) const = 0;
 };

 //----------------------------------------------------------------------------
@@ -204,6 +209,8 @@ class Numeric : public NumericBase
     virtual double getMin(const TypeContext&) const final { return std::numeric_limits<T>::min(); }
     virtual double getMax(const TypeContext&) const final { return std::numeric_limits<T>::max(); }
     virtual double getLowest(const TypeContext&) const final { return std::numeric_limits<T>::lowest(); }
+    virtual int getMaxDigits10(const TypeContext&) const final{ return std::numeric_limits<T>::max_digits10; }
+
     virtual bool isSigned(const TypeContext&) const final { return std::is_signed<T>::value; }
     virtual bool isIntegral(const TypeContext&) const final { return std::is_integral<T>::value; }
 };
@@ -232,9 +239,13 @@ class NumericTypedef : public NumericBase
     virtual double getMin(const TypeContext &context) const final;
     virtual double getMax(const TypeContext &context) const final;
     virtual double getLowest(const TypeContext &context) const final;
+    virtual int getMaxDigits10(const TypeContext &context) const final;
+
     virtual bool isSigned(const TypeContext &context) const final;
     virtual bool isIntegral(const TypeContext &context) const final;

+    virtual std::string getLiteralSuffix(const TypeContext &context) const final;
+
 private:
     //------------------------------------------------------------------------
     // Private methods
@@ -359,17 +370,17 @@ class ForeignFunction : public ForeignFunctionBase
 //----------------------------------------------------------------------------
 // Declare numeric types
 //----------------------------------------------------------------------------
-DECLARE_NUMERIC_TYPE(Bool, bool, 0);
-DECLARE_NUMERIC_TYPE(Int8, int8_t, 10);
-DECLARE_NUMERIC_TYPE(Int16, int16_t, 20);
-DECLARE_NUMERIC_TYPE(Int32, int32_t, 30);
+DECLARE_NUMERIC_TYPE(Bool, bool, 0, "");
+DECLARE_NUMERIC_TYPE(Int8, int8_t, 10, "");
+DECLARE_NUMERIC_TYPE(Int16, int16_t, 20, "");
+DECLARE_NUMERIC_TYPE(Int32, int32_t, 30, "");
//DECLARE_NUMERIC_TYPE(Int64, int64_t, 40);
-DECLARE_NUMERIC_TYPE(Uint8, uint8_t, 10);
-DECLARE_NUMERIC_TYPE(Uint16, uint16_t, 20);
-DECLARE_NUMERIC_TYPE(Uint32, uint32_t, 30);
+DECLARE_NUMERIC_TYPE(Uint8, uint8_t, 10, "u");
+DECLARE_NUMERIC_TYPE(Uint16, uint16_t, 20, "u");
+DECLARE_NUMERIC_TYPE(Uint32, uint32_t, 30, "u");
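// Illustration (a sketch, not from this diff; "ctx" is an assumed TypeContext built
// elsewhere): the new LITERAL_SUFFIX argument lets generated code attach the correct
// suffix to numeric literals instead of string-comparing type names, and
// getMaxDigits10() feeds the extended writePreciseString() above. For example:
//
//     inline std::string scalarMinMacro(const TypeContext &ctx)
//     {
//         const NumericBase *precision = Float::getInstance();
//         std::ostringstream os;
//         os << "#define SCALAR_MIN ";
//         Utils::writePreciseString(os, precision->getMin(ctx), precision->getMaxDigits10(ctx));
//         os << precision->getLiteralSuffix(ctx); // "f" here; "u" for Uint32, "" for Double
//         return os.str(); // roughly "#define SCALAR_MIN 1.175494351e-38f"
//     }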
//DECLARE_NUMERIC_TYPE(Uint64, uint64_t, 40); -DECLARE_NUMERIC_TYPE(Float, float, 50); -DECLARE_NUMERIC_TYPE(Double, double, 60); +DECLARE_NUMERIC_TYPE(Float, float, 50, "f"); +DECLARE_NUMERIC_TYPE(Double, double, 60, ""); //---------------------------------------------------------------------------- // Declare standard library foreign function types diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index d02d5287c0..b8b081600e 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -523,7 +523,7 @@ void BackendSIMT::genNeuronUpdateKernel(CodeStream &os, const Substitutions &ker os << "if (shSpkEvntCount > 0)"; { CodeStream::Scope b(os); - os << "shPosSpkEvnt = " << getAtomic("unsigned int") << "(&group->spkCntEvnt"; + os << "shPosSpkEvnt = " << getAtomic() << "(&group->spkCntEvnt"; if(ng.getArchetype().isDelayRequired()) { os << "[*group->spkQuePtr"; if(batchSize > 1) { @@ -546,7 +546,7 @@ void BackendSIMT::genNeuronUpdateKernel(CodeStream &os, const Substitutions &ker os << "if (shSpkCount > 0)"; { CodeStream::Scope b(os); - os << "shPosSpk = " << getAtomic("unsigned int") << "(&group->spkCnt"; + os << "shPosSpk = " << getAtomic() << "(&group->spkCnt"; if(ng.getArchetype().isDelayRequired() && ng.getArchetype().isTrueSpikeRequired()) { os << "[*group->spkQuePtr"; if(batchSize > 1) { @@ -937,7 +937,7 @@ void BackendSIMT::genCustomUpdateKernel(CodeStream &os, const Substitutions &ker // Loop through reduction targets and generate reduction for(const auto &r : reductionTargets) { - os << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; + os << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; } } @@ -980,7 +980,7 @@ void BackendSIMT::genCustomUpdateKernel(CodeStream &os, const Substitutions &ker // Loop through reduction targets and generate reduction for (const auto &r : reductionTargets) { - os << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; + os << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; } } @@ -989,7 +989,7 @@ void BackendSIMT::genCustomUpdateKernel(CodeStream &os, const Substitutions &ker for (unsigned int i = 16; i > 0; i /= 2) { for (const auto &r : reductionTargets) { os << getReductionOperation("lr" + r.name, "__shfl_down_sync(0xFFFFFFFF, lr" + r.name + ", " + std::to_string(i) + ")", - r.access, r.type) << ";" << std::endl; + r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; } } @@ -1146,7 +1146,7 @@ void BackendSIMT::genCustomUpdateWUKernel(CodeStream &os, const Substitutions &k if(cg.getArchetype().isBatchReduction()) { // Loop through reduction targets and generate reduction for(const auto &r : reductionTargets) { - os << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; + os << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; } // End for loop through batches @@ -1564,7 +1564,7 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne } // Otherwise else { - kernelInit << "group->ind[(($(0)) * group->rowStride) + " << getAtomic("unsigned int") << +"(&group->rowLength[$(0)], 1)] = " << popSubs["id_post"] << ";"; + kernelInit << "group->ind[(($(0)) * group->rowStride) + " << getAtomic() << 
+"(&group->rowLength[$(0)], 1)] = " << popSubs["id_post"] << ";"; } } // Otherwise, if it's bitmask @@ -1575,12 +1575,12 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne // If there is row-building code in this snippet if(!snippet->getRowBuildCode().empty()) { kernelInit << "const " << indexType << " rowStartGID = " << popSubs["id"] << " * (" << indexType << ")group->rowStride;" << std::endl; - kernelInit << getAtomic("unsigned int", AtomicOperation::OR) << "(&group->gp[(rowStartGID + ($(0))) / 32], 0x80000000 >> ((rowStartGID + ($(0))) & 31));" << std::endl; + kernelInit << getAtomic(AtomicOperation::OR) << "(&group->gp[(rowStartGID + ($(0))) / 32], 0x80000000 >> ((rowStartGID + ($(0))) & 31));" << std::endl; } // Otherwise else { kernelInit << "const " << indexType << " colStartGID = " << popSubs["id"] << ";" << std::endl; - kernelInit << getAtomic("unsigned int", AtomicOperation::OR) << "(&group->gp[(colStartGID + (($(0)) * group->rowStride)) / 32], 0x80000000 >> ((colStartGID + (($(0)) * group->rowStride)) & 31));" << std::endl; + kernelInit << getAtomic(AtomicOperation::OR) << "(&group->gp[(colStartGID + (($(0)) * group->rowStride)) / 32], 0x80000000 >> ((colStartGID + (($(0)) * group->rowStride)) & 31));" << std::endl; } } } @@ -1656,7 +1656,7 @@ void BackendSIMT::genInitializeSparseKernel(CodeStream &os, const Substitutions // Atomically increment length of column of connectivity associated with this target // **NOTE** this returns previous length i.e. where to insert new entry - os << "const unsigned int colLocation = " << getAtomic("unsigned int") << "(&group->colLength[postIndex], 1);" << std::endl; + os << "const unsigned int colLocation = " << getAtomic() << "(&group->colLength[postIndex], 1);" << std::endl; // From this calculate index into column-major matrix os << "const unsigned int colMajorIndex = (postIndex * group->colStride) + colLocation;" << std::endl; @@ -1711,16 +1711,16 @@ size_t BackendSIMT::padKernelSize(size_t size, Kernel kernel) const //-------------------------------------------------------------------------- void BackendSIMT::genEmitSpike(CodeStream &os, const Substitutions &subs, const std::string &suffix, bool recordingEnabled) const { - os << "const unsigned int spk" << suffix << "Idx = " << getAtomic("unsigned int", AtomicOperation::ADD, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Count, 1);" << std::endl; + os << "const unsigned int spk" << suffix << "Idx = " << getAtomic(AtomicOperation::ADD, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Count, 1);" << std::endl; os << "shSpk" << suffix << "[spk" << suffix << "Idx] = " << subs["id"] << ";" << std::endl; // If recording is enabled, set bit in recording word if(recordingEnabled) { if(m_KernelBlockSizes[KernelNeuronUpdate] == 32) { - os << getAtomic("unsigned int", AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record, 1 << " << getThreadID() << ");" << std::endl; + os << getAtomic(AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record, 1 << " << getThreadID() << ");" << std::endl; } else { - os << getAtomic("unsigned int", AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record[" << getThreadID() << " / 32], 1 << (" << getThreadID() << " % 32));" << std::endl; + os << getAtomic(AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record[" << getThreadID() << " / 32], 1 << (" << getThreadID() << " % 32));" << std::endl; } } } diff --git 
a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index 78346f3fb2..74d23452f4 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -326,28 +326,16 @@ void functionSubstitute(std::string &code, const std::string &funcName, } } //---------------------------------------------------------------------------- -void genTypeRange(CodeStream &os, const std::string &precision, const std::string &prefix) +void genTypeRange(CodeStream &os, const Type::NumericBase *precision, const Type::TypeContext &typeContext, const std::string &prefix) { + os << "#define " << prefix << "_MIN "; - if(precision == "float") { - Utils::writePreciseString(os, std::numeric_limits::min()); - os << "f" << std::endl; - } - else { - Utils::writePreciseString(os, std::numeric_limits::min()); - os << std::endl; - } + Utils::writePreciseString(os, precision->getMin(typeContext), precision->getMaxDigits10(typeContext)); + os << precision->getLiteralSuffix(typeContext) << std::endl << std::endl; os << "#define " << prefix << "_MAX "; - if(precision == "float") { - Utils::writePreciseString(os, std::numeric_limits::max()); - os << "f" << std::endl; - } - else { - Utils::writePreciseString(os, std::numeric_limits::max()); - os << std::endl; - } - os << std::endl; + Utils::writePreciseString(os, precision->getMax(typeContext), precision->getMaxDigits10(typeContext)); + os << precision->getLiteralSuffix(typeContext) << std::endl; } //---------------------------------------------------------------------------- std::string ensureFtype(const std::string &oldcode, const std::string &type) diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index f21a43f142..efbec90e60 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -12,9 +12,9 @@ using namespace GeNN::CodeGenerator; //---------------------------------------------------------------------------- // CodeGenerator::CustomConnectivityUpdateGroupMergedBase //---------------------------------------------------------------------------- -CustomConnectivityUpdateGroupMergedBase::CustomConnectivityUpdateGroupMergedBase(size_t index, const Type::NumericBase *precision, +CustomConnectivityUpdateGroupMergedBase::CustomConnectivityUpdateGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const std::vector> &groups) -: GroupMerged(index, precision, groups) +: GroupMerged(index, typeContext, groups) { using namespace Type; @@ -60,9 +60,9 @@ bool CustomConnectivityUpdateGroupMergedBase::isDerivedParamHeterogeneous(const //---------------------------------------------------------------------------- const std::string CustomConnectivityUpdateGroupMerged::name = "CustomConnectivityUpdate"; //---------------------------------------------------------------------------- -CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: CustomConnectivityUpdateGroupMergedBase(index, precision, groups) +: CustomConnectivityUpdateGroupMergedBase(index, typeContext, groups) { using namespace Type; @@ 
-427,9 +427,9 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back //---------------------------------------------------------------------------- const std::string CustomConnectivityHostUpdateGroupMerged::name = "CustomConnectivityHostUpdate"; //---------------------------------------------------------------------------- -CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: CustomConnectivityUpdateGroupMergedBase(index, precision, groups) +: CustomConnectivityUpdateGroupMergedBase(index, typeContext, groups) { using namespace Type; @@ -544,7 +544,7 @@ void CustomConnectivityHostUpdateGroupMerged::addVarPushPullFuncSubs(const Backe // **YUCK** these EGP functions should probably just be called dynamic or something std::stringstream pullStream; CodeStream pull(pullStream); - backend.genExtraGlobalParamPull(pull, v.->getPointerType(), v.name, + backend.genExtraGlobalParamPull(pull, v.type->getPointerType(), v.name, loc, count, "group->"); // Add substitution diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 6e87acef0d..be03687422 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -81,7 +81,7 @@ void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, std::string code = cm->getUpdateCode(); updateSubs.applyCheckUnreplaced(code, "custom update : merged" + std::to_string(cg.getIndex())); - code = ensureFtype(code, modelMerged.getModel().getPrecision()); + //code = ensureFtype(code, modelMerged.getModel().getPrecision()); os << code; // Write read/write variables back to global memory @@ -111,9 +111,9 @@ void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, //---------------------------------------------------------------------------- const std::string CustomUpdateGroupMerged::name = "CustomUpdate"; //---------------------------------------------------------------------------- -CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, precision, groups) +: GroupMerged(index, typeContext, groups) { using namespace Type; @@ -159,7 +159,7 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const Type::Numer const std::string code = upgradeCodeString(cm->getUpdateCode()); const auto tokens = Transpiler::Scanner::scanSource(code, errorHandler); const auto statements = Transpiler::Parser::parseBlockItemList(tokens, errorHandler); - Transpiler::TypeChecker::typeCheck(statements, typeEnvironment, errorHandler); + Transpiler::TypeChecker::typeCheck(statements, typeEnvironment, typeContext, errorHandler); } //---------------------------------------------------------------------------- @@ -290,15 +290,15 @@ std::string CustomUpdateWUGroupMergedBase::getVarRefIndex(VarAccessDuplication v return ((varDuplication == VarAccessDuplication::SHARED || !getArchetype().isBatched()) 
? "" : "batchOffset + ") + index; } //---------------------------------------------------------------------------- -CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, precision, groups) +: GroupMerged(index, typeContext, groups) { using namespace Type; // Create type environment // **TEMP** parse precision to get scalar type - GroupMergedTypeEnvironment typeEnvironment(*this, precision); + GroupMergedTypeEnvironment typeEnvironment(*this, getScalarType()); // If underlying synapse group has kernel weights if (getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL) { @@ -426,9 +426,9 @@ void CustomUpdateTransposeWUGroupMerged::generateCustomUpdate(const BackendBase& //---------------------------------------------------------------------------- const std::string CustomUpdateHostReductionGroupMerged::name = "CustomUpdateHostReduction"; //---------------------------------------------------------------------------- -CustomUpdateHostReductionGroupMerged::CustomUpdateHostReductionGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +CustomUpdateHostReductionGroupMerged::CustomUpdateHostReductionGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: CustomUpdateHostReductionGroupMergedBase(index, precision, backend, groups) +: CustomUpdateHostReductionGroupMergedBase(index, typeContext, backend, groups) { using namespace Type; @@ -451,9 +451,9 @@ CustomUpdateHostReductionGroupMerged::CustomUpdateHostReductionGroupMerged(size_ //---------------------------------------------------------------------------- const std::string CustomWUUpdateHostReductionGroupMerged::name = "CustomWUUpdateHostReduction"; //---------------------------------------------------------------------------- -CustomWUUpdateHostReductionGroupMerged::CustomWUUpdateHostReductionGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +CustomWUUpdateHostReductionGroupMerged::CustomWUUpdateHostReductionGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: CustomUpdateHostReductionGroupMergedBase(index, precision, backend, groups) +: CustomUpdateHostReductionGroupMergedBase(index, typeContext, backend, groups) { using namespace Type; diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 747456e333..0024d47ca6 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -87,23 +87,26 @@ void genSpikeMacros(CodeStream &os, const NeuronGroupInternal &ng, bool trueSpik os << std::endl << std::endl; } //-------------------------------------------------------------------------- -void genHostScalar(CodeStream &definitionsVar, CodeStream &runnerVarDecl, const std::string &type, const std::string &name, const std::string &value) +template +void genHostScalar(CodeStream &definitionsVar, CodeStream &runnerVarDecl, + const std::string &name, const std::string &value) { - definitionsVar << "EXPORT_VAR " << type << " " << name << ";" << std::endl; - 
runnerVarDecl << type << " " << name << " = " << value << ";" << std::endl; + definitionsVar << "EXPORT_VAR " << T::getInstance()->getName() << " " << name << ";" << std::endl; + runnerVarDecl << T::getInstance()->getName() << " " << name << " = " << value << ";" << std::endl; } //-------------------------------------------------------------------------- +template void genHostDeviceScalar(const BackendBase &backend, CodeStream &definitionsVar, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerVarAlloc, CodeStream &runnerVarFree, - const std::string &type, const std::string &name, const std::string &hostValue, MemAlloc &mem) + const std::string &name, const std::string &hostValue, MemAlloc &mem) { // Generate a host scalar - genHostScalar(definitionsVar, runnerVarDecl, type, name, hostValue); + genHostScalar(definitionsVar, runnerVarDecl, name, hostValue); // Generate a single-element array on device if(backend.isDeviceScalarRequired()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - type, name, VarLocation::DEVICE, 1, mem); + T::getInstance(), name, VarLocation::DEVICE, 1, mem); } } //-------------------------------------------------------------------------- @@ -550,18 +553,14 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // write DT macro const ModelSpecInternal &model = modelMerged.getModel(); - if (model.getTimePrecision() == "float") { - definitions << "#define DT " << Utils::writePreciseString(model.getDT()) << "f" << std::endl; - } else { - definitions << "#define DT " << Utils::writePreciseString(model.getDT()) << std::endl; - } - + definitions << "#define DT " << Utils::writePreciseString(model.getDT()) << model.getTimePrecision()->getLiteralSuffix(modelMerged.getTypeContext()) << std::endl; + // Typedefine scalar type definitions << "typedef " << model.getPrecision() << " scalar;" << std::endl; // Write ranges of scalar and time types - genTypeRange(definitions, model.getPrecision(), "SCALAR"); - genTypeRange(definitions, model.getTimePrecision(), "TIME"); + genTypeRange(definitions, model.getPrecision(), modelMerged.getTypeContext(), "SCALAR"); + genTypeRange(definitions, model.getTimePrecision(), modelMerged.getTypeContext(), "TIME"); definitions << "// ------------------------------------------------------------------------" << std::endl; definitions << "// bit tool macros" << std::endl; @@ -654,17 +653,17 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Generate variables to store total elapsed time // **NOTE** we ALWAYS generate these so usercode doesn't require #ifdefs around timing code - genHostScalar(definitionsVar, runnerVarDecl, "double", "initTime", "0.0"); - genHostScalar(definitionsVar, runnerVarDecl, "double", "initSparseTime", "0.0"); - genHostScalar(definitionsVar, runnerVarDecl, "double", "neuronUpdateTime", "0.0"); - genHostScalar(definitionsVar, runnerVarDecl, "double", "presynapticUpdateTime", "0.0"); - genHostScalar(definitionsVar, runnerVarDecl, "double", "postsynapticUpdateTime", "0.0"); - genHostScalar(definitionsVar, runnerVarDecl, "double", "synapseDynamicsTime", "0.0"); + genHostScalar(definitionsVar, runnerVarDecl, "initTime", "0.0"); + genHostScalar(definitionsVar, runnerVarDecl, "initSparseTime", "0.0"); + genHostScalar(definitionsVar, runnerVarDecl, "neuronUpdateTime", "0.0"); + genHostScalar(definitionsVar, runnerVarDecl, "presynapticUpdateTime", "0.0"); + genHostScalar(definitionsVar, 
runnerVarDecl, "postsynapticUpdateTime", "0.0"); + genHostScalar(definitionsVar, runnerVarDecl, "synapseDynamicsTime", "0.0"); // Generate variables to store total elapsed time for each custom update group for(const auto &g : customUpdateGroups) { - genHostScalar(definitionsVar, runnerVarDecl, "double", "customUpdate" + g + "Time", "0.0"); - genHostScalar(definitionsVar, runnerVarDecl, "double", "customUpdate" + g + "TransposeTime", "0.0"); + genHostScalar(definitionsVar, runnerVarDecl, "customUpdate" + g + "Time", "0.0"); + genHostScalar(definitionsVar, runnerVarDecl, "customUpdate" + g + "TransposeTime", "0.0"); } // If timing is actually enabled @@ -730,7 +729,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Loop through merged synapse connectivity host initialisation groups for(const auto &m : modelMerged.getMergedSynapseConnectivityHostInitGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } @@ -742,145 +741,145 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Generate merged neuron initialisation groups for(const auto &m : modelMerged.getMergedNeuronInitGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through merged synapse init groups for(const auto &m : modelMerged.getMergedSynapseInitGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through merged synapse connectivity initialisation groups for(const auto &m : modelMerged.getMergedSynapseConnectivityInitGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through merged sparse synapse init groups for(const auto &m : modelMerged.getMergedSynapseSparseInitGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Generate merged custom update initialisation groups for(const auto &m : modelMerged.getMergedCustomUpdateInitGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Generate merged custom WU update initialisation groups for(const auto &m : modelMerged.getMergedCustomWUUpdateInitGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, 
definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Generate merged custom sparse WU update initialisation groups for(const auto &m : modelMerged.getMergedCustomWUUpdateSparseInitGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Generate merged custom connectivity update presynaptic initialisation groups for(const auto &m : modelMerged.getMergedCustomConnectivityUpdatePreInitGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Generate merged custom connectivity update postsynaptic initialisation groups for(const auto &m : modelMerged.getMergedCustomConnectivityUpdatePostInitGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Generate merged custom connectivity update synaptic initialisation groups for(const auto &m : modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through merged neuron update groups for(const auto &m : modelMerged.getMergedNeuronUpdateGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through merged presynaptic update groups for(const auto &m : modelMerged.getMergedPresynapticUpdateGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through merged postsynaptic update groups for(const auto &m : modelMerged.getMergedPostsynapticUpdateGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through synapse dynamics groups for(const auto &m : modelMerged.getMergedSynapseDynamicsGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through neuron groups whose previous spike times need resetting for(const auto &m : modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), 
definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through neuron groups whose spike queues need resetting for(const auto &m : modelMerged.getMergedNeuronSpikeQueueUpdateGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through synapse groups whose dendritic delay pointers need updating for(const auto &m : modelMerged.getMergedSynapseDendriticDelayUpdateGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through custom variable update groups for(const auto &m : modelMerged.getMergedCustomUpdateGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through custom WU variable update groups for(const auto &m : modelMerged.getMergedCustomUpdateWUGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through custom WU transpose variable update groups for(const auto &m : modelMerged.getMergedCustomUpdateTransposeWUGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through custom update host reduction groups for(const auto &m : modelMerged.getMergedCustomUpdateHostReductionGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through custom weight update host reduction groups for(const auto &m : modelMerged.getMergedCustomWUUpdateHostReductionGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through custom connectivity update groups for(const auto &m : modelMerged.getMergedCustomConnectivityUpdateGroups()) { - m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } // Loop through custom connectivity host update groups for(const auto &m : modelMerged.getMergedCustomConnectivityHostUpdateGroups()) { - 
m.generateRunner(backend, modelMerged.getTypeContext(), definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc); } @@ -901,20 +900,20 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const size_t numNeuronDelaySlots = batchSize * (size_t)n.second.getNumNeurons() * (size_t)n.second.getNumDelaySlots(); const size_t numSpikeCounts = n.second.isTrueSpikeRequired() ? (batchSize * n.second.getNumDelaySlots()) : batchSize; const size_t numSpikes = n.second.isTrueSpikeRequired() ? numNeuronDelaySlots : (batchSize * n.second.getNumNeurons()); - backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "unsigned int", "glbSpkCnt" + n.first, n.second.getSpikeLocation(), numSpikeCounts, mem); - backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "unsigned int", "glbSpk" + n.first, n.second.getSpikeLocation(), numSpikes, mem); + backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + "glbSpkCnt" + n.first, n.second.getSpikeLocation(), numSpikeCounts, mem); + backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + "glbSpk" + n.first, n.second.getSpikeLocation(), numSpikes, mem); // True spike push and pull functions genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, n.second.getSpikeLocation(), backend.getPreferences().automaticCopy, n.first + "Spikes", [&]() { - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - "unsigned int", "glbSpkCnt" + n.first, n.second.getSpikeLocation(), true, numSpikeCounts); - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - "unsigned int", "glbSpk" + n.first, n.second.getSpikeLocation(), true, numSpikes); + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + "glbSpkCnt" + n.first, n.second.getSpikeLocation(), true, numSpikeCounts); + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + "glbSpk" + n.first, n.second.getSpikeLocation(), true, numSpikes); }); // Current true spike push and pull functions @@ -931,8 +930,9 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // If spike recording is enabled, define and declare variables and add free if(n.second.isSpikeRecordingEnabled()) { - backend.genVariableDefinition(definitionsVar, definitionsInternalVar, "uint32_t*", "recordSpk" + n.first, VarLocation::HOST_DEVICE); - backend.genVariableImplementation(runnerVarDecl, "uint32_t*", "recordSpk" + n.first, VarLocation::HOST_DEVICE); + const auto *uint32Pointer = Type::Uint32::getInstance()->getPointerType(); + backend.genVariableDefinition(definitionsVar, definitionsInternalVar, uint32Pointer, "recordSpk" + n.first, VarLocation::HOST_DEVICE); + backend.genVariableImplementation(runnerVarDecl, uint32Pointer, "recordSpk" + n.first, VarLocation::HOST_DEVICE); backend.genVariableFree(runnerVarFree, "recordSpk" + n.first, VarLocation::HOST_DEVICE); } @@ -944,22 +944,22 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, } // Spike-like event variables - backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "unsigned int", "glbSpkCntEvnt" + n.first, n.second.getSpikeEventLocation(), - batchSize * n.second.getNumDelaySlots(), mem); - 
backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "unsigned int", "glbSpkEvnt" + n.first, n.second.getSpikeEventLocation(), - numNeuronDelaySlots, mem); + backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + "glbSpkCntEvnt" + n.first, n.second.getSpikeEventLocation(), + batchSize * n.second.getNumDelaySlots(), mem); + backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + "glbSpkEvnt" + n.first, n.second.getSpikeEventLocation(), + numNeuronDelaySlots, mem); // Spike-like event push and pull functions genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, n.second.getSpikeEventLocation(), backend.getPreferences().automaticCopy, n.first + "SpikeEvents", [&]() { - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, "unsigned int", "glbSpkCntEvnt" + n.first, - n.second.getSpikeLocation(), true, batchSize * n.second.getNumDelaySlots()); - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, "unsigned int", "glbSpkEvnt" + n.first, - n.second.getSpikeLocation(), true, numNeuronDelaySlots); + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, "glbSpkCntEvnt" + n.first, + n.second.getSpikeLocation(), true, batchSize * n.second.getNumDelaySlots()); + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, "glbSpkEvnt" + n.first, + n.second.getSpikeLocation(), true, numNeuronDelaySlots); }); // Current spike-like event push and pull functions @@ -976,16 +976,17 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // If spike recording is enabled, define and declare variables and add free if(n.second.isSpikeEventRecordingEnabled()) { - backend.genVariableDefinition(definitionsVar, definitionsInternalVar, "uint32_t*", "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE); - backend.genVariableImplementation(runnerVarDecl, "uint32_t*", "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE); + const auto *uint32Pointer = Type::Uint32::getInstance()->getPointerType(); + backend.genVariableDefinition(definitionsVar, definitionsInternalVar, uint32Pointer, "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE); + backend.genVariableImplementation(runnerVarDecl, uint32Pointer, "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE); backend.genVariableFree(runnerVarFree, "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE); } } // If neuron group has axonal delays if (n.second.isDelayRequired()) { - genHostDeviceScalar(backend, definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "unsigned int", "spkQuePtr" + n.first, "0", mem); + genHostDeviceScalar(backend, definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + "spkQuePtr" + n.first, "0", mem); } // If neuron group needs to record its spike times @@ -1246,8 +1247,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, model.getPrecision(), "denDelay" + sg->getFusedPSVarSuffix(), sg->getDendriticDelayLocation(), (size_t)sg->getMaxDendriticDelayTimesteps() * (size_t)sg->getTrgNeuronGroup()->getNumNeurons() * batchSize, mem); - genHostDeviceScalar(backend, definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "unsigned int", "denDelayPtr" + sg->getFusedPSVarSuffix(), "0", mem); + 
genHostDeviceScalar(backend, definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + "denDelayPtr" + sg->getFusedPSVarSuffix(), "0", mem); } genRunnerFusedVars(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, @@ -1300,8 +1301,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, if(s.second.getMatrixType() & SynapseMatrixConnectivity::BITMASK) { const size_t gpSize = ceilDivide((size_t)s.second.getSrcNeuronGroup()->getNumNeurons() * backend.getSynapticMatrixRowStride(s.second), 32); - backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "uint32_t", "gp" + s.second.getName(), s.second.getSparseConnectivityLocation(), gpSize, mem); + backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + "gp" + s.second.getName(), s.second.getSparseConnectivityLocation(), gpSize, mem); // Generate push and pull functions for bitmask genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, s.second.getSparseConnectivityLocation(), @@ -1309,8 +1310,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { // Row lengths - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, "uint32_t", "gp" + s.second.getName(), - s.second.getSparseConnectivityLocation(), autoInitialized, gpSize); + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, "gp" + s.second.getName(), + s.second.getSparseConnectivityLocation(), autoInitialized, gpSize); }); } else if(s.second.getMatrixType() & SynapseMatrixConnectivity::SPARSE) { @@ -1322,24 +1323,24 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerVarDecl << "const unsigned int maxRowLength" << s.second.getName() << " = " << backend.getSynapticMatrixRowStride(s.second) << ";" << std::endl; // Row lengths - backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "unsigned int", "rowLength" + s.second.getName(), varLoc, s.second.getSrcNeuronGroup()->getNumNeurons(), mem); + backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + "rowLength" + s.second.getName(), varLoc, s.second.getSrcNeuronGroup()->getNumNeurons(), mem); // Target indices backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - s.second.getSparseIndType()->getResolvedName(modelMerged.getTypeContext()), "ind" + s.second.getName(), varLoc, size, mem); + s.second.getSparseIndType(), "ind" + s.second.getName(), varLoc, size, mem); // **TODO** remap is not always required if(backend.isPostsynapticRemapRequired() && !s.second.getWUModel()->getLearnPostCode().empty()) { const size_t postSize = (size_t)s.second.getTrgNeuronGroup()->getNumNeurons() * (size_t)s.second.getMaxSourceConnections(); // Allocate column lengths - backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "unsigned int", "colLength" + s.second.getName(), VarLocation::DEVICE, s.second.getTrgNeuronGroup()->getNumNeurons(), mem); + backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + "colLength" + s.second.getName(), VarLocation::DEVICE, s.second.getTrgNeuronGroup()->getNumNeurons(), mem); // Allocate remap - backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "unsigned 
int", "remap" + s.second.getName(), VarLocation::DEVICE, postSize, mem); + backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + "remap" + s.second.getName(), VarLocation::DEVICE, postSize, mem); } // Generate push and pull functions for sparse connectivity @@ -1348,11 +1349,11 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { // Row lengths - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, "unsigned int", "rowLength" + s.second.getName(), - s.second.getSparseConnectivityLocation(), autoInitialized, s.second.getSrcNeuronGroup()->getNumNeurons()); + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, "rowLength" + s.second.getName(), + s.second.getSparseConnectivityLocation(), autoInitialized, s.second.getSrcNeuronGroup()->getNumNeurons()); // Target indices - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, s.second.getSparseIndType()->getName(), "ind" + s.second.getName(), + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, s.second.getSparseIndType(), "ind" + s.second.getName(), s.second.getSparseConnectivityLocation(), autoInitialized, size); }); } diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index 884d741e14..d91a64cf5f 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -19,9 +19,9 @@ using namespace GeNN::CodeGenerator; //---------------------------------------------------------------------------- const std::string NeuronSpikeQueueUpdateGroupMerged::name = "NeuronSpikeQueueUpdate"; //---------------------------------------------------------------------------- -NeuronSpikeQueueUpdateGroupMerged::NeuronSpikeQueueUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +NeuronSpikeQueueUpdateGroupMerged::NeuronSpikeQueueUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, precision, groups) +: GroupMerged(index, typeContext, groups) { using namespace Type; @@ -68,9 +68,9 @@ void NeuronSpikeQueueUpdateGroupMerged::genMergedGroupSpikeCountReset(CodeStream //---------------------------------------------------------------------------- const std::string NeuronPrevSpikeTimeUpdateGroupMerged::name = "NeuronPrevSpikeTimeUpdate"; //---------------------------------------------------------------------------- -NeuronPrevSpikeTimeUpdateGroupMerged::NeuronPrevSpikeTimeUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, +NeuronPrevSpikeTimeUpdateGroupMerged::NeuronPrevSpikeTimeUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, precision, groups) +: GroupMerged(index, typeContext, groups) { using namespace Type; @@ -86,11 +86,11 @@ NeuronPrevSpikeTimeUpdateGroupMerged::NeuronPrevSpikeTimeUpdateGroupMerged(size_ if(getArchetype().isPrevSpikeTimeRequired()) { addPointerField("spk", backend.getDeviceVarPrefix() + "glbSpk"); - addPointerField(timePrecision, "prevST", backend.getDeviceVarPrefix() + "prevST"); + addPointerField(getTimeType(), "prevST", backend.getDeviceVarPrefix() + "prevST"); } if(getArchetype().isPrevSpikeEventTimeRequired()) { addPointerField("spkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt"); - 
addPointerField(timePrecision, "prevSET", backend.getDeviceVarPrefix() + "prevSET"); + addPointerField(getTimeType(), "prevSET", backend.getDeviceVarPrefix() + "prevSET"); } if(getArchetype().isDelayRequired()) { @@ -183,9 +183,9 @@ bool NeuronGroupMergedBase::isPSMVarInitDerivedParamHeterogeneous(size_t childIn [varName](const SynapseGroupInternal *inSyn){ return inSyn->getPSVarInitialisers().at(varName).getDerivedParams(); })); } //---------------------------------------------------------------------------- -NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, +NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, bool init, const std::vector> &groups) -: GroupMerged(index, precision, groups) +: GroupMerged(index, typeContext, groups) { using namespace Type; @@ -217,17 +217,17 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::NumericBa } if(getArchetype().isSpikeTimeRequired()) { - addPointerField(timePrecision, "sT", backend.getDeviceVarPrefix() + "sT"); + addPointerField(getTimeType(), "sT", backend.getDeviceVarPrefix() + "sT"); } if(getArchetype().isSpikeEventTimeRequired()) { - addPointerField(timePrecision, "seT", backend.getDeviceVarPrefix() + "seT"); + addPointerField(getTimeType(), "seT", backend.getDeviceVarPrefix() + "seT"); } if(getArchetype().isPrevSpikeTimeRequired()) { - addPointerField(timePrecision, "prevST", backend.getDeviceVarPrefix() + "prevST"); + addPointerField(getTimeType(), "prevST", backend.getDeviceVarPrefix() + "prevST"); } if(getArchetype().isPrevSpikeEventTimeRequired()) { - addPointerField(timePrecision, "prevSET", backend.getDeviceVarPrefix() + "prevSET"); + addPointerField(getTimeType(), "prevSET", backend.getDeviceVarPrefix() + "prevSET"); } // If this backend initialises population RNGs on device and this group requires on for simulation @@ -661,9 +661,9 @@ std::string SynapseGroupMergedBase::getKernelVarIndex(unsigned int batchSize, Va return (singleBatch ? 
"" : "kernBatchOffset + ") + index; } //---------------------------------------------------------------------------- -SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, +SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, Role role, const std::string &archetypeCode, const std::vector> &groups) -: GroupMerged(index, precision, groups), m_ArchetypeCode(archetypeCode) +: GroupMerged(index, typeContext, groups), m_ArchetypeCode(archetypeCode) { using namespace Type; @@ -800,22 +800,22 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::Numeric // Add spike times if required if(wum->isPreSpikeTimeRequired()) { - addSrcPointerField(timePrecision, "sTPre", backend.getDeviceVarPrefix() + "sT"); + addSrcPointerField(getTimeType(), "sTPre", backend.getDeviceVarPrefix() + "sT"); } if(wum->isPostSpikeTimeRequired()) { - addTrgPointerField(timePrecision, "sTPost", backend.getDeviceVarPrefix() + "sT"); + addTrgPointerField(getTimeType(), "sTPost", backend.getDeviceVarPrefix() + "sT"); } if(wum->isPreSpikeEventTimeRequired()) { - addSrcPointerField(timePrecision, "seTPre", backend.getDeviceVarPrefix() + "seT"); + addSrcPointerField(getTimeType(), "seTPre", backend.getDeviceVarPrefix() + "seT"); } if(wum->isPrevPreSpikeTimeRequired()) { - addSrcPointerField(timePrecision, "prevSTPre", backend.getDeviceVarPrefix() + "prevST"); + addSrcPointerField(getTimeType(), "prevSTPre", backend.getDeviceVarPrefix() + "prevST"); } if(wum->isPrevPostSpikeTimeRequired()) { - addTrgPointerField(timePrecision, "prevSTPost", backend.getDeviceVarPrefix() + "prevST"); + addTrgPointerField(getTimeType(), "prevSTPost", backend.getDeviceVarPrefix() + "prevST"); } if(wum->isPrevPreSpikeEventTimeRequired()) { - addSrcPointerField(timePrecision, "prevSETPre", backend.getDeviceVarPrefix() + "prevSET"); + addSrcPointerField(getTimeType(), "prevSETPre", backend.getDeviceVarPrefix() + "prevSET"); } // Add heterogeneous weight update model parameters addHeterogeneousParams( @@ -914,7 +914,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::Numeric addScalarField(var.name, [var](const SynapseGroupInternal &sg, size_t) { - return Utils::writePreciseString(sg.getWUConstInitVals().at(var.name)); + return sg.getWUConstInitVals().at(var.name); }); } } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 898ba27fbc..47b6cf6b64 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -185,9 +185,9 @@ void genInitWUVarCode(CodeStream &os, const Substitutions &popSubs, //---------------------------------------------------------------------------- const std::string NeuronInitGroupMerged::name = "NeuronInit"; //---------------------------------------------------------------------------- -NeuronInitGroupMerged::NeuronInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, +NeuronInitGroupMerged::NeuronInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: NeuronGroupMergedBase(index, precision, timePrecision, backend, true, groups) +: NeuronGroupMergedBase(index, typeContext, backend, true, groups) { // Build vector of vectors containing each 
child group's incoming // synapse groups, ordered to match those of the archetype group @@ -715,9 +715,9 @@ void SynapseConnectivityInitGroupMerged::genInitConnectivity(CodeStream &os, Sub //---------------------------------------------------------------------------- const std::string SynapseConnectivityHostInitGroupMerged::name = "SynapseConnectivityHostInit"; //------------------------------------------------------------------------ -SynapseConnectivityHostInitGroupMerged::SynapseConnectivityHostInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +SynapseConnectivityHostInitGroupMerged::SynapseConnectivityHostInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, precision, groups) +: GroupMerged(index, typeContext, groups) { using namespace Type; @@ -857,9 +857,9 @@ bool SynapseConnectivityHostInitGroupMerged::isSparseConnectivityInitParamRefere //---------------------------------------------------------------------------- const std::string CustomUpdateInitGroupMerged::name = "CustomUpdateInit"; //---------------------------------------------------------------------------- -CustomUpdateInitGroupMerged::CustomUpdateInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +CustomUpdateInitGroupMerged::CustomUpdateInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: CustomUpdateInitGroupMergedBase(index, precision, backend, groups) +: CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups) { addField("size", [](const auto &c, size_t) { return std::to_string(c.getSize()); }); @@ -892,9 +892,9 @@ void CustomUpdateInitGroupMerged::generateInit(const BackendBase &backend, CodeS //---------------------------------------------------------------------------- const std::string CustomWUUpdateInitGroupMerged::name = "CustomWUUpdateInit"; //---------------------------------------------------------------------------- -CustomWUUpdateInitGroupMerged::CustomWUUpdateInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +CustomWUUpdateInitGroupMerged::CustomWUUpdateInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: CustomUpdateInitGroupMergedBase(index, precision, backend, groups) +: CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups) { using namespace Type; @@ -1002,9 +1002,9 @@ void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, Cod //---------------------------------------------------------------------------- const std::string CustomWUUpdateSparseInitGroupMerged::name = "CustomWUUpdateSparseInit"; //---------------------------------------------------------------------------- -CustomWUUpdateSparseInitGroupMerged::CustomWUUpdateSparseInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +CustomWUUpdateSparseInitGroupMerged::CustomWUUpdateSparseInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: CustomUpdateInitGroupMergedBase(index, precision, backend, groups) +: CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups) { using namespace Type; @@ -1074,9 +1074,9 @@ void 
CustomWUUpdateSparseInitGroupMerged::generateInit(const BackendBase &backen //---------------------------------------------------------------------------- const std::string CustomConnectivityUpdatePreInitGroupMerged::name = "CustomConnectivityUpdatePreInit"; //---------------------------------------------------------------------------- -CustomConnectivityUpdatePreInitGroupMerged::CustomConnectivityUpdatePreInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +CustomConnectivityUpdatePreInitGroupMerged::CustomConnectivityUpdatePreInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: CustomUpdateInitGroupMergedBase(index, precision, backend, groups) +: CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups) { using namespace Type; @@ -1122,9 +1122,9 @@ void CustomConnectivityUpdatePreInitGroupMerged::generateInit(const BackendBase //---------------------------------------------------------------------------- const std::string CustomConnectivityUpdatePostInitGroupMerged::name = "CustomConnectivityUpdatePostInit"; //---------------------------------------------------------------------------- -CustomConnectivityUpdatePostInitGroupMerged::CustomConnectivityUpdatePostInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +CustomConnectivityUpdatePostInitGroupMerged::CustomConnectivityUpdatePostInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: CustomUpdateInitGroupMergedBase(index, precision, backend, groups) +: CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups) { addField("size", [](const auto &c, size_t) @@ -1163,9 +1163,9 @@ void CustomConnectivityUpdatePostInitGroupMerged::generateInit(const BackendBase //---------------------------------------------------------------------------- const std::string CustomConnectivityUpdateSparseInitGroupMerged::name = "CustomConnectivityUpdateSparseInit"; //---------------------------------------------------------------------------- -CustomConnectivityUpdateSparseInitGroupMerged::CustomConnectivityUpdateSparseInitGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +CustomConnectivityUpdateSparseInitGroupMerged::CustomConnectivityUpdateSparseInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: CustomUpdateInitGroupMergedBase(index, precision, backend, groups) +: CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups) { using namespace Type; diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index c57f1b0cde..e87cfed457 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -31,7 +31,7 @@ void assignGroups(const BackendBase &backend, std::vector &groups, BackendBas ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBase &backend) : m_Model(model), m_NeuronUpdateSupportCode("NeuronUpdateSupportCode"), m_PostsynapticDynamicsSupportCode("PostsynapticDynamicsSupportCode"), m_PresynapticUpdateSupportCode("PresynapticUpdateSupportCode"), m_PostsynapticUpdateSupportCode("PostsynapticUpdateSupportCode"), - m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode"), 
m_TypeContext{{"scalar", Type::parseNumeric(model.getPrecision())}} + m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode"), m_TypeContext{{"scalar", model.getPrecision()}, {"time", model.getTimePrecision()}} { LOGD_CODE_GEN << "Merging neuron update groups:"; createMergedGroupsHash(model, backend, model.getNeuronGroups(), m_MergedNeuronUpdateGroups, diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 40d2f2febc..86e9852f88 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -11,9 +11,9 @@ using namespace GeNN::CodeGenerator; //---------------------------------------------------------------------------- const std::string NeuronUpdateGroupMerged::name = "NeuronUpdate"; //---------------------------------------------------------------------------- -NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase *timePrecision, const BackendBase &backend, +NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: NeuronGroupMergedBase(index, precision, timePrecision, backend, false, groups) +: NeuronGroupMergedBase(index, typeContext, backend, false, groups) { using namespace Type; diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index ac5cf8396a..8d9f2b4476 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -869,7 +869,7 @@ void PostSpanToeplitz::genUpdate(CodeStream &os, const ModelSpecMerged &modelMer // Apply substitutions to value std::string value = d.value; connSubs.applyCheckUnreplaced(value, "toeplitz diagonal build state var : merged" + std::to_string(sg.getIndex())); - value = ensureFtype(value, modelMerged.getModel().getPrecision()); + //value = ensureFtype(value, modelMerged.getModel().getPrecision()); os << d.type << " " << d.name << " = " << value << ";" << std::endl; } diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 56e990f14f..853e8e67d4 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -321,9 +321,9 @@ void SynapseDynamicsGroupMerged::generateSynapseUpdate(const BackendBase &backen //---------------------------------------------------------------------------- const std::string SynapseDendriticDelayUpdateGroupMerged::name = "SynapseDendriticDelayUpdate"; //---------------------------------------------------------------------------- -SynapseDendriticDelayUpdateGroupMerged::SynapseDendriticDelayUpdateGroupMerged(size_t index, const Type::NumericBase *precision, const Type::NumericBase*, const BackendBase &backend, +SynapseDendriticDelayUpdateGroupMerged::SynapseDendriticDelayUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, precision, groups) +: GroupMerged(index, typeContext, groups) { addField(Type::Uint32::getInstance()->getPointerType(), "denDelayPtr", [&backend](const SynapseGroupInternal &sg, size_t) diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc 
index 21d8f86773..60ad5c8bc6 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -311,7 +311,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdate::getHashDigest( Utils::updateHash(getUpdateGroupName(), hash); Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash); - Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName({}), hash); + Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName(), hash); // Because it adds and removes synapses, connectivity update has to update // ALL variables associated with synapse group being modified as well as @@ -326,7 +326,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdate::getHashDigest( [](const Models::WUVarReference &v) { boost::uuids::detail::sha1 hash; - Utils::updateHash(v.getVar().type, hash); + Utils::updateHash(v.getVar().type->getName(), hash); Utils::updateHash(v.isDuplicated(), hash); return hash.get_digest(); }); diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index 1a3c816e1c..5d7fc7af13 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -266,7 +266,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getHashDigest() const CustomUpdateBase::updateHash(hash); Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash); - Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName({}), hash); + Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName(), hash); // Loop through variable references for(const auto &v : getVarReferences()) { @@ -287,7 +287,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getInitHashDigest() cons CustomUpdateBase::updateInitHash(hash); Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash); - Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName({}), hash); + Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName(), hash); return hash.get_digest(); } } // namespace GeNN diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc index 1bad25f7be..fee9114cf2 100644 --- a/src/genn/genn/synapseGroup.cc +++ b/src/genn/genn/synapseGroup.cc @@ -739,7 +739,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getWUHashDigest() const Utils::updateHash(getDelaySteps(), hash); Utils::updateHash(getBackPropDelaySteps(), hash); Utils::updateHash(getMaxDendriticDelayTimesteps(), hash); - Utils::updateHash(getSparseIndType()->getName({}), hash); + Utils::updateHash(getSparseIndType()->getName(), hash); Utils::updateHash(getNumThreadsPerSpike(), hash); Utils::updateHash(isEventThresholdReTestRequired(), hash); Utils::updateHash(getSpanType(), hash); @@ -904,7 +904,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getWUInitHashDigest() cons { boost::uuids::detail::sha1 hash; Utils::updateHash(getMatrixType(), hash); - Utils::updateHash(getSparseIndType()->getName({}), hash); + Utils::updateHash(getSparseIndType()->getName(), hash); Utils::updateHash(getWUModel()->getVars(), hash); Utils::updateHash(getWUModel()->getSynapseDynamicsCode().empty(), hash); @@ -969,7 +969,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getConnectivityInitHashDig boost::uuids::detail::sha1 hash; Utils::updateHash(getConnectivityInitialiser().getHashDigest(), hash); Utils::updateHash(getMatrixType(), hash); - Utils::updateHash(getSparseIndType()->getName({}), hash); + 
Utils::updateHash(getSparseIndType()->getName(), hash); return hash.get_digest(); } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index 3faa576e20..ed2d80c845 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -12,6 +12,7 @@ // Transpiler includes #include "transpiler/transpilerUtils.h" +using namespace GeNN; using namespace GeNN::Transpiler; using namespace GeNN::Transpiler::PrettyPrinter; @@ -26,6 +27,8 @@ namespace class Visitor : public Expression::Visitor, public Statement::Visitor { public: + Visitor(const Type::TypeContext &context) : m_Context(context) {} + //--------------------------------------------------------------------------- // Public API //--------------------------------------------------------------------------- @@ -102,11 +105,20 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Literal &literal) final { - std::visit( - Utils::Overload{ - [this](auto x) { m_StringStream << x; }, - [this](std::monostate) { m_StringStream << "invalid"; }}, - literal.getValue()); + // If literal is a double, we want to remove the d suffix in generated code + std::string_view lexeme = literal.getValue().lexeme; + if (literal.getValue().type == Token::Type::DOUBLE_NUMBER){ + m_StringStream << lexeme.substr(0, literal.getValue().lexeme.size() - 1); + } + // Otherwise, if literal is a scalar, we want to add appropriate suffix for scalar type + else if (literal.getValue().type == Token::Type::SCALAR_NUMBER) { + const Type::NumericBase *scalar = dynamic_cast<const Type::NumericBase*>(m_Context.at("scalar")); + m_StringStream << lexeme << scalar->getLiteralSuffix(m_Context); + } + // Otherwise, just write out original lexeme directly + else { + m_StringStream << lexeme; + } } virtual void visit(const Expression::Logical &logical) final @@ -301,14 +313,15 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // Members //--------------------------------------------------------------------------- std::ostringstream m_StringStream; + const Type::TypeContext &m_Context; }; } // Anonymous namespace //--------------------------------------------------------------------------- // GeNN::Transpiler::PrettyPrinter //--------------------------------------------------------------------------- -std::string GeNN::Transpiler::PrettyPrinter::print(const Statement::StatementList &statements) +std::string GeNN::Transpiler::PrettyPrinter::print(const Statement::StatementList &statements, const Type::TypeContext &context) { - Visitor visitor; + Visitor visitor(context); return visitor.print(statements); } diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index c72db89cdf..5b0f537840 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -125,6 +125,11 @@ double NumericTypedef::getLowest(const TypeContext &context) const return getNumeric(context)->getLowest(context); } //---------------------------------------------------------------------------- +int NumericTypedef::getMaxDigits10(const TypeContext &context) const +{ + return getNumeric(context)->getMaxDigits10(context); +} +//---------------------------------------------------------------------------- bool NumericTypedef::isSigned(const TypeContext &context) const { return getNumeric(context)->isSigned(context); } //---------------------------------------------------------------------------- bool NumericTypedef::isIntegral(const TypeContext &context) const { return
getNumeric(context)->isIntegral(context); } //---------------------------------------------------------------------------- +std::string NumericTypedef::getLiteralSuffix(const TypeContext &context) const +{ + return getNumeric(context)->getLiteralSuffix(context); +} +//---------------------------------------------------------------------------- const Type::NumericBase *NumericTypedef::getNumeric(const TypeContext &context) const { const auto t = context.find(m_Name); From ce172f865049abe839317a1ba8466ffc03f1f262 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 18 Jan 2023 15:57:00 +0000 Subject: [PATCH 067/725] fixed some warnings --- .../genn/genn/code_generator/backendBase.h | 6 +- .../genn/genn/code_generator/backendSIMT.h | 5 +- .../genn/genn/code_generator/groupMerged.h | 3 + .../groupMergedTypeEnvironment.h | 17 +++--- .../genn/code_generator/initGroupMerged.h | 2 +- .../genn/code_generator/modelSpecMerged.h | 9 ++- include/genn/genn/type.h | 2 +- src/genn/backends/opencl/backend.cc | 36 ++++++------ src/genn/genn/code_generator/backendBase.cc | 2 +- .../genn/code_generator/initGroupMerged.cc | 20 +++---- .../genn/code_generator/modelSpecMerged.cc | 57 +++++++++++-------- .../code_generator/neuronUpdateGroupMerged.cc | 2 +- 12 files changed, 86 insertions(+), 75 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 61c24f1c05..5603e64e63 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -189,7 +189,7 @@ class GENN_EXPORT BackendBase //! Vector of prefixes required to allocate in memory space and size of memory space typedef std::vector> MemorySpaces; - BackendBase(const std::string &scalarType, const PreferencesBase &preferences); + BackendBase(const PreferencesBase &preferences); virtual ~BackendBase(){} //-------------------------------------------------------------------------- @@ -516,7 +516,7 @@ class GENN_EXPORT BackendBase for (const auto &v : cm->getVars()) { // If variable is a reduction target, define variable initialised to correct initial value for reduction if (v.access & VarAccessModeAttribute::REDUCE) { - os << v.type << " lr" << v.name << " = " << getReductionInitialValue(*this, getVarAccessMode(v.access), v.type) << ";" << std::endl; + os << v.type << " lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(v.access), v.type, cg.getTypeContext()) << ";" << std::endl; reductionTargets.emplace_back(v.name, v.type, getVarAccessMode(v.access), cg.getVarIndex(getVarAccessDuplication(v.access), idx)); } @@ -528,7 +528,7 @@ class GENN_EXPORT BackendBase // If variable reference is a reduction target, define variable initialised to correct initial value for reduction if (modelVarRef.access & VarAccessModeAttribute::REDUCE) { - os << modelVarRef.type << " lr" << modelVarRef.name << " = " << getReductionInitialValue(*this, modelVarRef.access, modelVarRef.type) << ";" << std::endl; + os << modelVarRef.type << " lr" << modelVarRef.name << " = " << getReductionInitialValue(modelVarRef.access, modelVarRef.type, cg.getTypeContext()) << ";" << std::endl; reductionTargets.emplace_back(modelVarRef.name, modelVarRef.type, modelVarRef.access, getVarRefIndexFn(varRef, idx)); } diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 0aabb1d6b8..8449d85753 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ 
-51,9 +51,8 @@ using KernelBlockSize = std::array; class GENN_EXPORT BackendSIMT : public BackendBase { public: - BackendSIMT(const KernelBlockSize &kernelBlockSizes, const PreferencesBase &preferences, - const std::string &scalarType) - : BackendBase(scalarType, preferences), m_KernelBlockSizes(kernelBlockSizes) + BackendSIMT(const KernelBlockSize &kernelBlockSizes, const PreferencesBase &preferences) + : BackendBase(preferences), m_KernelBlockSizes(kernelBlockSizes) {} //------------------------------------------------------------------------ diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 6c3c112d47..eb09646ade 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -78,6 +78,9 @@ class GroupMerged //! Get 'archetype' neuron group - it's properties represent those of all other merged neuron groups const GroupInternal &getArchetype() const { return m_Groups.front().get(); } + //! Get type context used to resolve any types involved in this group + const Type::TypeContext &getTypeContext() const{ return m_TypeContext; } + //! Get name of memory space assigned to group const std::string &getMemorySpace() const { return m_MemorySpace; } diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index 1b05a3a2da..417dae24bc 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -24,9 +24,8 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa using TypeCheckError = Transpiler::TypeChecker::TypeCheckError; public: - GroupMergedTypeEnvironment(G &groupMerged, const Type::NumericBase *scalarType, - EnvironmentBase *enclosing = nullptr) - : m_GroupMerged(groupMerged), m_ScalarType(scalarType), m_Enclosing(enclosing) + GroupMergedTypeEnvironment(G &groupMerged, EnvironmentBase *enclosing = nullptr) + : m_GroupMerged(groupMerged), m_Enclosing(enclosing) { } @@ -143,11 +142,12 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa void defineScalarField(const std::string &name, typename G::GetFieldDoubleValueFunc getFieldValue) { - defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONSTANT), name, - m_ScalarType, name, + defineField(m_GroupMerged.getScalarType()->getQualifiedType(Type::Qualifier::CONSTANT), name, + m_GroupMerged.getScalarType(), name, [getFieldValue, this](const auto &g, size_t i) { - return Utils::writePreciseString(getFieldValue(g, i), m_ScalarType->getMaxDigits10(m_TypeContext)) + m_ScalarType->getLiteralSuffix(m_TypeContext); + return (Utils::writePreciseString(getFieldValue(g, i), m_GroupMerged.getScalarType()->getMaxDigits10(m_GroupMerged.getTypeContext())) + + m_GroupMerged.getScalarType()->getLiteralSuffix(m_GroupMerged.getTypeContext())); }); } @@ -166,7 +166,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa } // Otherwise, just add a const-qualified scalar to the type environment else { - defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONSTANT), p + suffix); + defineField(m_GroupMerged.getScalarType()->getQualifiedType(Type::Qualifier::CONSTANT), p + suffix); } } } @@ -185,7 +185,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa }); } else { - defineField(m_ScalarType->getQualifiedType(Type::Qualifier::CONSTANT), d.name + 
suffix); + defineField(m_GroupMerged.getScalarType()->getQualifiedType(Type::Qualifier::CONSTANT), d.name + suffix); } } @@ -252,7 +252,6 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa // Members //--------------------------------------------------------------------------- G &m_GroupMerged; - const Type::NumericBase *m_ScalarType; EnvironmentBase *m_Enclosing; std::unordered_map>> m_Types; diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index 0f93872682..342f1d81ad 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -195,7 +195,7 @@ class GENN_EXPORT SynapseConnectivityInitGroupMerged : public SynapseGroupMerged // Private methods //---------------------------------------------------------------------------- //! Generate either row or column connectivity init code - void genInitConnectivity(CodeStream &os, Substitutions &popSubs, const Type::NumericBase *scalarType, bool rowNotColumns) const; + void genInitConnectivity(CodeStream &os, Substitutions &popSubs, bool rowNotColumns) const; }; diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 1aa21679e1..449045bf43 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -214,6 +214,9 @@ class GENN_EXPORT ModelSpecMerged //! Get hash digest of init module boost::uuids::detail::sha1::digest_type getInitArchetypeHashDigest() const; + //! Get the string literal that should be used to represent a value of the scalar type + std::string scalarExpr(double value) const; + //! Does model have any EGPs? bool anyPointerEGPs() const; @@ -295,7 +298,7 @@ class GENN_EXPORT ModelSpecMerged } template - void createMergedGroupsHash(const ModelSpecInternal &model, const BackendBase &backend, + void createMergedGroupsHash(const BackendBase &backend, const std::vector> &unmergedGroups, std::vector &mergedGroups, D getHashDigest, bool host = false) { @@ -340,7 +343,7 @@ } template - void createMergedGroupsHash(const ModelSpecInternal &model, const BackendBase &backend, + void createMergedGroupsHash(const BackendBase &backend, const std::map &groups, std::vector &mergedGroups, F filter, U updateHash, bool host = false) { @@ -353,7 +356,7 @@ } // Merge filtered vector - createMergedGroupsHash(model, backend, unmergedGroups, mergedGroups, updateHash, host); + createMergedGroupsHash(backend, unmergedGroups, mergedGroups, updateHash, host); } //-------------------------------------------------------------------------- diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 6059ae1b85..10893bf57c 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -42,7 +42,7 @@ virtual std::string getName() const final{ return #UNDERLYING_TYPE; } \ virtual std::string getResolvedName(const TypeContext&) const final{ return #UNDERLYING_TYPE; } \ virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new TYPE(qualifiers); } \ - virtual std::string getLiteralSuffix(const TypeContext &context) const final{ return LITERAL_SUFFIX; } \ + virtual std::string getLiteralSuffix(const TypeContext&) const final{ return LITERAL_SUFFIX; } \ }; \ template<> \ struct TypeTraits \ diff --git a/src/genn/backends/opencl/backend.cc index
b958131c1a..69cc932dbc 100644 --- a/src/genn/backends/opencl/backend.cc +++ b/src/genn/backends/opencl/backend.cc @@ -2625,7 +2625,7 @@ void Backend::genKernelPreamble(CodeStream &os, const ModelSpecMerged &modelMerg os << "#include " << std::endl; os << "typedef " << precision << " scalar;" << std::endl; - os << "#define DT " << model.scalarExpr(model.getDT()) << std::endl; + os << "#define DT " << modelMerged.scalarExpr(model.getDT()) << std::endl; os << "#define SUPPORT_CODE_FUNC" << std::endl; genTypeRange(os, model.getTimePrecision(), "TIME"); @@ -2641,7 +2641,7 @@ void Backend::genKernelPreamble(CodeStream &os, const ModelSpecMerged &modelMerg { CodeStream::Scope b(os); os << "const " << precision << " u = clrng" << r << "RandomU01(rng);" << std::endl; - os << "if (u != " << model.scalarExpr(0.0) << ")"; + os << "if (u != " << modelMerged.scalarExpr(0.0) << ")"; { CodeStream::Scope b(os); os << "return -log(u);" << std::endl; @@ -2657,8 +2657,8 @@ void Backend::genKernelPreamble(CodeStream &os, const ModelSpecMerged &modelMerg const std::string pi = (model.getPrecision() == "float") ? "M_PI_F" : "M_PI"; os << "const " << precision << " u1 = clrng" << r << "RandomU01(rng);" << std::endl; os << "const " << precision << " u2 = clrng" << r << "RandomU01(rng);" << std::endl; - os << "const " << precision << " r = sqrt(" << model.scalarExpr(-2.0) << " * log(u1));" << std::endl; - os << "const " << precision << " theta = " << model.scalarExpr(2.0) << " * " << pi << " * u2;" << std::endl; + os << "const " << precision << " r = sqrt(" << modelMerged.scalarExpr(-2.0) << " * log(u1));" << std::endl; + os << "const " << precision << " theta = " << modelMerged.scalarExpr(2.0) << " * " << pi << " * u2;" << std::endl; os << "return r * sin(theta);" << std::endl; } os << std::endl; @@ -2683,9 +2683,9 @@ void Backend::genKernelPreamble(CodeStream &os, const ModelSpecMerged &modelMerg { CodeStream::Scope b(os); os << "x = normalDist" << r << "(rng);" << std::endl; - os << "v = " << model.scalarExpr(1.0) << " + c*x;" << std::endl; + os << "v = " << modelMerged.scalarExpr(1.0) << " + c*x;" << std::endl; } - os << "while (v <= " << model.scalarExpr(0.0) << ");" << std::endl; + os << "while (v <= " << modelMerged.scalarExpr(0.0) << ");" << std::endl; os << std::endl; os << "v = v*v*v;" << std::endl; os << "do"; @@ -2693,10 +2693,10 @@ void Backend::genKernelPreamble(CodeStream &os, const ModelSpecMerged &modelMerg CodeStream::Scope b(os); os << "u = clrng" << r << "RandomU01(rng);" << std::endl; } - os << "while (u == " << model.scalarExpr(1.0) << ");" << std::endl; + os << "while (u == " << modelMerged.scalarExpr(1.0) << ");" << std::endl; os << std::endl; - os << "if (u < " << model.scalarExpr(1.0) << " - " << model.scalarExpr(0.0331) << "*x*x*x*x) break;" << std::endl; - os << "if (log(u) < " << model.scalarExpr(0.5) << "*x*x + d*(" << model.scalarExpr(1.0) << " - v + log(v))) break;" << std::endl; + os << "if (u < " << modelMerged.scalarExpr(1.0) << " - " << modelMerged.scalarExpr(0.0331) << "*x*x*x*x) break;" << std::endl; + os << "if (log(u) < " << modelMerged.scalarExpr(0.5) << "*x*x + d*(" << modelMerged.scalarExpr(1.0) << " - v + log(v))) break;" << std::endl; } os << std::endl; os << "return d*v;" << std::endl; @@ -2710,15 +2710,15 @@ void Backend::genKernelPreamble(CodeStream &os, const ModelSpecMerged &modelMerg { CodeStream::Scope b(os); os << "const " << precision << " u = clrng" << r << "RandomU01 (rng);" << std::endl; - os << "const " << precision << " d = (" <<
model.scalarExpr(1.0) << " + a) - " << model.scalarExpr(1.0) << " / " << model.scalarExpr(3.0) << ";" << std::endl; - os << "const " << precision << " c = (" << model.scalarExpr(1.0) << " / " << model.scalarExpr(3.0) << ") / sqrt(d);" << std::endl; - os << "return gammaDistInternal" << r << "(rng, c, d) * pow(u, " << model.scalarExpr(1.0) << " / a);" << std::endl; + os << "const " << precision << " d = (" << modelMerged.scalarExpr(1.0) << " + a) - " << modelMerged.scalarExpr(1.0) << " / " << modelMerged.scalarExpr(3.0) << ";" << std::endl; + os << "const " << precision << " c = (" << modelMerged.scalarExpr(1.0) << " / " << modelMerged.scalarExpr(3.0) << ") / sqrt(d);" << std::endl; + os << "return gammaDistInternal" << r << "(rng, c, d) * pow(u, " << modelMerged.scalarExpr(1.0) << " / a);" << std::endl; } os << "else" << std::endl; { CodeStream::Scope b(os); - os << "const " << precision << " d = a - " << model.scalarExpr(1.0) << " / " << model.scalarExpr(3.0) << ";" << std::endl; - os << "const " << precision << " c = (" << model.scalarExpr(1.0) << " / " << model.scalarExpr(3.0) << ") / sqrt(d);" << std::endl; + os << "const " << precision << " d = a - " << modelMerged.scalarExpr(1.0) << " / " << modelMerged.scalarExpr(3.0) << ";" << std::endl; + os << "const " << precision << " c = (" << modelMerged.scalarExpr(1.0) << " / " << modelMerged.scalarExpr(3.0) << ") / sqrt(d);" << std::endl; os << "return gammaDistInternal" << r << "(rng, c, d);" << std::endl; } } @@ -2728,10 +2728,10 @@ void Backend::genKernelPreamble(CodeStream &os, const ModelSpecMerged &modelMerg os << "inline unsigned int binomialDist" << r << "Internal(clrng" << r << "Stream *rng, unsigned int n, " << precision << " p)" << std::endl; { CodeStream::Scope b(os); - os << "const " << precision << " q = " << model.scalarExpr(1.0) << " - p;" << std::endl; + os << "const " << precision << " q = " << modelMerged.scalarExpr(1.0) << " - p;" << std::endl; os << "const " << precision << " qn = exp(n * log(q));" << std::endl; os << "const " << precision << " np = n * p;" << std::endl; - os << "const unsigned int bound = min(n, (unsigned int)(np + (" << model.scalarExpr(10.0) << " * sqrt((np * q) + " << model.scalarExpr(1.0) << "))));" << std::endl; + os << "const unsigned int bound = min(n, (unsigned int)(np + (" << modelMerged.scalarExpr(10.0) << " * sqrt((np * q) + " << modelMerged.scalarExpr(1.0) << "))));" << std::endl; os << "unsigned int x = 0;" << std::endl; os << precision << " px = qn;" << std::endl; @@ -2761,7 +2761,7 @@ void Backend::genKernelPreamble(CodeStream &os, const ModelSpecMerged &modelMerg os << "inline unsigned int binomialDist" << r << "(clrng" << r << "Stream *rng, unsigned int n, " << precision << " p)" << std::endl; { CodeStream::Scope b(os); - os << "if(p <= " << model.scalarExpr(0.5) << ")"; + os << "if(p <= " << modelMerged.scalarExpr(0.5) << ")"; { CodeStream::Scope b(os); os << "return binomialDist" << r << "Internal(rng, n, p);" << std::endl; @@ -2770,7 +2770,7 @@ void Backend::genKernelPreamble(CodeStream &os, const ModelSpecMerged &modelMerg os << "else"; { CodeStream::Scope b(os); - os << "return (n - binomialDist" << r << "Internal(rng, n, " << model.scalarExpr(1.0) << " - p));" << std::endl; + os << "return (n - binomialDist" << r << "Internal(rng, n, " << modelMerged.scalarExpr(1.0) << " - p));" << std::endl; } } os << std::endl; diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 8a110517a4..879c4cf736 100644 ---
a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -15,7 +15,7 @@ //-------------------------------------------------------------------------- namespace GeNN::CodeGenerator { -BackendBase::BackendBase(const std::string &scalarType, const PreferencesBase &preferences) +BackendBase::BackendBase(const PreferencesBase &preferences) : m_Preferences(preferences) { } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc index 47b6cf6b64..81b70991f6 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -102,7 +102,7 @@ void genInitNeuronVarCode(CodeStream &os, const BackendBase &backend, const Subs else { backend.genVariableInit( os, count, "id", varSubs, - [&var, &varInit, &fieldSuffix, &ftype, batchSize, groupIndex, count, numDelaySlots, isVarQueueRequired] + [&var, &varInit, &fieldSuffix, batchSize, groupIndex, count, numDelaySlots, isVarQueueRequired] (CodeStream &os, Substitutions &varInitSubs) { // Generate initial value into temporary variable varInitSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); varInitSubs.applyCheckUnreplaced(code, "initVar : " + var.name + "merged" + std::to_string(groupIndex)); - code = ensureFtype(code, ftype); + //code = ensureFtype(code, ftype); os << code << std::endl; // Fill value across all delay slots and batches @@ -316,7 +316,7 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, CodeStream backend.genVariableInit(os, "group->numNeurons", "id", popSubs, [&model, &modelMerged, i] (CodeStream &os, Substitutions &varSubs) { - genVariableFill(os, "inSynInSyn" + std::to_string(i), model.scalarExpr(0.0), + genVariableFill(os, "inSynInSyn" + std::to_string(i), modelMerged.scalarExpr(0.0), varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, model.getBatchSize()); }); @@ -327,7 +327,7 @@ backend.genVariableInit(os, "group->numNeurons", "id", popSubs, [&model, &modelMerged, sg, i](CodeStream &os, Substitutions &varSubs) { - genVariableFill(os, "denDelayInSyn" + std::to_string(i), model.scalarExpr(0.0), + genVariableFill(os, "denDelayInSyn" + std::to_string(i), modelMerged.scalarExpr(0.0), varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, model.getBatchSize(), true, sg->getMaxDendriticDelayTimesteps()); }); @@ -376,7 +376,7 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, CodeStream backend.genVariableInit(os, "group->numNeurons", "id", popSubs, [&model, &modelMerged, i] (CodeStream &os, Substitutions &varSubs) { - genVariableFill(os, "revInSynOutSyn" + std::to_string(i), model.scalarExpr(0.0), + genVariableFill(os, "revInSynOutSyn" + std::to_string(i), modelMerged.scalarExpr(0.0), varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, model.getBatchSize()); }); } @@ -618,14 +618,14 @@ void SynapseSparseInitGroupMerged::generateInit(const BackendBase &backend, Code //---------------------------------------------------------------------------- const std::string SynapseConnectivityInitGroupMerged::name = "SynapseConnectivityInit"; //---------------------------------------------------------------------------- -void SynapseConnectivityInitGroupMerged::generateSparseRowInit(const BackendBase&, CodeStream &os, const ModelSpecMerged
&modelMerged, Substitutions &popSubs) const +void SynapseConnectivityInitGroupMerged::generateSparseRowInit(const BackendBase&, CodeStream &os, const ModelSpecMerged &, Substitutions &popSubs) const { - genInitConnectivity(os, popSubs, modelMerged.getModel().getPrecision(), true); + genInitConnectivity(os, popSubs, true); } //---------------------------------------------------------------------------- -void SynapseConnectivityInitGroupMerged::generateSparseColumnInit(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void SynapseConnectivityInitGroupMerged::generateSparseColumnInit(const BackendBase&, CodeStream &os, const ModelSpecMerged &, Substitutions &popSubs) const { - genInitConnectivity(os, popSubs, modelMerged.getModel().getPrecision(), false); + genInitConnectivity(os, popSubs, false); } //---------------------------------------------------------------------------- void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const @@ -668,7 +668,7 @@ void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase&, } } //---------------------------------------------------------------------------- -void SynapseConnectivityInitGroupMerged::genInitConnectivity(CodeStream &os, Substitutions &popSubs, const Type::NumericBase *scalarType, bool rowNotColumns) const +void SynapseConnectivityInitGroupMerged::genInitConnectivity(CodeStream &os, Substitutions &popSubs, bool rowNotColumns) const { const auto &connectInit = getArchetype().getConnectivityInitialiser(); const auto *snippet = connectInit.getSnippet(); diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index e87cfed457..de15687e2b 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -1,6 +1,7 @@ #include "code_generator/modelSpecMerged.h" // GeNN includes +#include "gennUtils.h" #include "logging.h" #include "modelSpecInternal.h" @@ -34,32 +35,32 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode"), m_TypeContext{{"scalar", model.getPrecision()}, {"time", model.getTimePrecision()}} { LOGD_CODE_GEN << "Merging neuron update groups:"; - createMergedGroupsHash(model, backend, model.getNeuronGroups(), m_MergedNeuronUpdateGroups, + createMergedGroupsHash(backend, model.getNeuronGroups(), m_MergedNeuronUpdateGroups, [](const NeuronGroupInternal &){ return true; }, &NeuronGroupInternal::getHashDigest); LOGD_CODE_GEN << "Merging presynaptic update groups:"; - createMergedGroupsHash(model, backend, model.getSynapseGroups(), m_MergedPresynapticUpdateGroups, + createMergedGroupsHash(backend, model.getSynapseGroups(), m_MergedPresynapticUpdateGroups, [](const SynapseGroupInternal &sg) { return (sg.isSpikeEventRequired() || sg.isTrueSpikeRequired()); }, &SynapseGroupInternal::getWUHashDigest); LOGD_CODE_GEN << "Merging postsynaptic update groups:"; - createMergedGroupsHash(model, backend, model.getSynapseGroups(), m_MergedPostsynapticUpdateGroups, + createMergedGroupsHash(backend, model.getSynapseGroups(), m_MergedPostsynapticUpdateGroups, [](const SynapseGroupInternal &sg){ return !sg.getWUModel()->getLearnPostCode().empty(); }, &SynapseGroupInternal::getWUHashDigest); LOGD_CODE_GEN << "Merging synapse dynamics update groups:"; - createMergedGroupsHash(model, backend, 
model.getSynapseGroups(), m_MergedSynapseDynamicsGroups, + createMergedGroupsHash(backend, model.getSynapseGroups(), m_MergedSynapseDynamicsGroups, [](const SynapseGroupInternal &sg){ return !sg.getWUModel()->getSynapseDynamicsCode().empty(); }, &SynapseGroupInternal::getWUHashDigest); LOGD_CODE_GEN << "Merging neuron initialization groups:"; - createMergedGroupsHash(model, backend, model.getNeuronGroups(), m_MergedNeuronInitGroups, + createMergedGroupsHash(backend, model.getNeuronGroups(), m_MergedNeuronInitGroups, [](const NeuronGroupInternal &){ return true; }, &NeuronGroupInternal::getInitHashDigest); LOGD_CODE_GEN << "Merging synapse initialization groups:"; - createMergedGroupsHash(model, backend, model.getSynapseGroups(), m_MergedSynapseInitGroups, + createMergedGroupsHash(backend, model.getSynapseGroups(), m_MergedSynapseInitGroups, [](const SynapseGroupInternal &sg) { return (((sg.getMatrixType() & SynapseMatrixConnectivity::DENSE) @@ -69,12 +70,12 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa &SynapseGroupInternal::getWUInitHashDigest); LOGD_CODE_GEN << "Merging custom update initialization groups:"; - createMergedGroupsHash(model, backend, model.getCustomUpdates(), m_MergedCustomUpdateInitGroups, + createMergedGroupsHash(backend, model.getCustomUpdates(), m_MergedCustomUpdateInitGroups, [](const CustomUpdateInternal &cg) { return cg.isVarInitRequired(); }, &CustomUpdateInternal::getInitHashDigest); LOGD_CODE_GEN << "Merging custom weight update initialization groups:"; - createMergedGroupsHash(model, backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateInitGroups, + createMergedGroupsHash(backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateInitGroups, [](const CustomUpdateWUInternal &cg) { return (((cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::DENSE) @@ -84,12 +85,12 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa &CustomUpdateWUInternal::getInitHashDigest); LOGD_CODE_GEN << "Merging synapse connectivity initialisation groups:"; - createMergedGroupsHash(model, backend, model.getSynapseGroups(), m_MergedSynapseConnectivityInitGroups, + createMergedGroupsHash(backend, model.getSynapseGroups(), m_MergedSynapseConnectivityInitGroups, [](const SynapseGroupInternal &sg){ return sg.isSparseConnectivityInitRequired(); }, &SynapseGroupInternal::getConnectivityInitHashDigest); LOGD_CODE_GEN << "Merging synapse sparse initialization groups:"; - createMergedGroupsHash(model, backend, model.getSynapseGroups(), m_MergedSynapseSparseInitGroups, + createMergedGroupsHash(backend, model.getSynapseGroups(), m_MergedSynapseSparseInitGroups, [&backend](const SynapseGroupInternal &sg) { return ((sg.getMatrixType() & SynapseMatrixConnectivity::SPARSE) && @@ -99,7 +100,7 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa &SynapseGroupInternal::getWUInitHashDigest); LOGD_CODE_GEN << "Merging custom sparse weight update initialization groups:"; - createMergedGroupsHash(model, backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateSparseInitGroups, + createMergedGroupsHash(backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateSparseInitGroups, [](const CustomUpdateWUInternal &cg) { return (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) && cg.isVarInitRequired(); @@ -107,7 +108,7 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa &CustomUpdateWUInternal::getInitHashDigest); LOGD_CODE_GEN << 
"Merging custom connectivity update presynaptic initialisation groups:"; - createMergedGroupsHash(model, backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePreInitGroups, + createMergedGroupsHash(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePreInitGroups, [&backend](const CustomConnectivityUpdateInternal &cg) { return (cg.isPreVarInitRequired() || (backend.isPopulationRNGInitialisedOnDevice() && cg.isRowSimRNGRequired())); @@ -115,22 +116,22 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa &CustomConnectivityUpdateInternal::getInitHashDigest); LOGD_CODE_GEN << "Merging custom connectivity update postsynaptic initialisation groups:"; - createMergedGroupsHash(model, backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePostInitGroups, + createMergedGroupsHash(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePostInitGroups, [](const CustomConnectivityUpdateInternal &cg) { return cg.isPostVarInitRequired(); }, &CustomConnectivityUpdateInternal::getInitHashDigest); LOGD_CODE_GEN << "Merging custom connectivity update sparse initialisation groups:"; - createMergedGroupsHash(model, backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateSparseInitGroups, + createMergedGroupsHash(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateSparseInitGroups, [](const CustomConnectivityUpdateInternal &cg) { return cg.isVarInitRequired(); }, &CustomConnectivityUpdateInternal::getInitHashDigest); LOGD_CODE_GEN << "Merging neuron groups which require their spike queues updating:"; - createMergedGroupsHash(model, backend, model.getNeuronGroups(), m_MergedNeuronSpikeQueueUpdateGroups, + createMergedGroupsHash(backend, model.getNeuronGroups(), m_MergedNeuronSpikeQueueUpdateGroups, [](const NeuronGroupInternal &){ return true; }, &NeuronGroupInternal::getSpikeQueueUpdateHashDigest); LOGD_CODE_GEN << "Merging neuron groups which require their previous spike times updating:"; - createMergedGroupsHash(model, backend, model.getNeuronGroups(), m_MergedNeuronPrevSpikeTimeUpdateGroups, + createMergedGroupsHash(backend, model.getNeuronGroups(), m_MergedNeuronPrevSpikeTimeUpdateGroups, [](const NeuronGroupInternal &ng){ return (ng.isPrevSpikeTimeRequired() || ng.isPrevSpikeEventTimeRequired()); }, &NeuronGroupInternal::getPrevSpikeTimeUpdateHashDigest); @@ -144,11 +145,11 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa } } LOGD_CODE_GEN << "Merging synapse groups which require their dendritic delay updating:"; - createMergedGroupsHash(model, backend, synapseGroupsWithDendriticDelay, m_MergedSynapseDendriticDelayUpdateGroups, + createMergedGroupsHash(backend, synapseGroupsWithDendriticDelay, m_MergedSynapseDendriticDelayUpdateGroups, &SynapseGroupInternal::getDendriticDelayUpdateHashDigest); LOGD_CODE_GEN << "Merging synapse groups which require host code to initialise their synaptic connectivity:"; - createMergedGroupsHash(model, backend, model.getSynapseGroups(), m_MergedSynapseConnectivityHostInitGroups, + createMergedGroupsHash(backend, model.getSynapseGroups(), m_MergedSynapseConnectivityHostInitGroups, [](const SynapseGroupInternal &sg) { return !sg.getConnectivityInitialiser().getSnippet()->getHostInitCode().empty(); @@ -156,39 +157,39 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa &SynapseGroupInternal::getConnectivityHostInitHashDigest, true); 
LOGD_CODE_GEN << "Merging custom update groups:"; - createMergedGroupsHash(model, backend, model.getCustomUpdates(), m_MergedCustomUpdateGroups, + createMergedGroupsHash(backend, model.getCustomUpdates(), m_MergedCustomUpdateGroups, [](const CustomUpdateInternal &) { return true; }, &CustomUpdateInternal::getHashDigest); LOGD_CODE_GEN << "Merging custom weight update groups:"; - createMergedGroupsHash(model, backend, model.getCustomWUUpdates(), m_MergedCustomUpdateWUGroups, + createMergedGroupsHash(backend, model.getCustomWUUpdates(), m_MergedCustomUpdateWUGroups, [](const CustomUpdateWUInternal &cg) { return !cg.isTransposeOperation(); }, &CustomUpdateWUInternal::getHashDigest); LOGD_CODE_GEN << "Merging custom weight transpose update groups:"; - createMergedGroupsHash(model, backend, model.getCustomWUUpdates(), m_MergedCustomUpdateTransposeWUGroups, + createMergedGroupsHash(backend, model.getCustomWUUpdates(), m_MergedCustomUpdateTransposeWUGroups, [](const CustomUpdateWUInternal &cg) { return cg.isTransposeOperation(); }, &CustomUpdateWUInternal::getHashDigest); if(backend.isHostReductionRequired()) { LOGD_CODE_GEN << "Merging custom weight update groups:"; - createMergedGroupsHash(model, backend, model.getCustomUpdates(), m_MergedCustomUpdateHostReductionGroups, + createMergedGroupsHash(backend, model.getCustomUpdates(), m_MergedCustomUpdateHostReductionGroups, [](const CustomUpdateInternal &cg) { return cg.isBatchReduction(); }, &CustomUpdateInternal::getHashDigest, true); LOGD_CODE_GEN << "Merging custom weight transpose update groups:"; - createMergedGroupsHash(model, backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateHostReductionGroups, + createMergedGroupsHash(backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateHostReductionGroups, [](const CustomUpdateWUInternal &cg) { return cg.isBatchReduction(); }, &CustomUpdateWUInternal::getHashDigest, true); } LOGD_CODE_GEN << "Merging custom connectivity update groups:"; - createMergedGroupsHash(model, backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateGroups, + createMergedGroupsHash(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateGroups, [](const CustomConnectivityUpdateInternal &cg) { return !cg.getCustomConnectivityUpdateModel()->getRowUpdateCode().empty(); }, &CustomConnectivityUpdateInternal::getHashDigest); LOGD_CODE_GEN << "Merging custom connectivity host update groups:"; - createMergedGroupsHash(model, backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityHostUpdateGroups, + createMergedGroupsHash(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityHostUpdateGroups, [](const CustomConnectivityUpdateInternal &cg) { return !cg.getCustomConnectivityUpdateModel()->getHostUpdateCode().empty(); }, &CustomConnectivityUpdateInternal::getHashDigest, true); @@ -563,6 +564,12 @@ boost::uuids::detail::sha1::digest_type ModelSpecMerged::getInitArchetypeHashDig return hash.get_digest(); } //---------------------------------------------------------------------------- +std::string ModelSpecMerged::scalarExp(double value) const +{ + const auto *scalarType = dynamic_cast(m_TypeContext.at("scalar")); + return Utils::writePreciseString(value, scalarType->getMaxDigits10(m_TypeContext)) + scalarType->getLiteralSuffix(m_TypeContext); +} +//---------------------------------------------------------------------------- bool ModelSpecMerged::anyPointerEGPs() const { // Loop through grouped merged EGPs diff --git 
a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 86e9852f88..2acbc4eacf 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -274,7 +274,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C os << "linSyn += *denDelayFront;" << std::endl; // Zero delay buffer slot - os << "*denDelayFront = " << model.scalarExpr(0.0) << ";" << std::endl; + os << "*denDelayFront = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; } // Pull postsynaptic model variables in a coalesced access From ec9c06e7dffc2400fc4015ca0486de4e48f8367e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 18 Jan 2023 17:19:48 +0000 Subject: [PATCH 068/725] almost compiles aside from EGPs --- .../genn/genn/code_generator/backendBase.h | 12 +++++- .../genn/genn/code_generator/codeGenUtils.h | 11 ----- .../genn/code_generator/modelSpecMerged.h | 2 +- src/genn/genn/code_generator/backendBase.cc | 41 +++++++++++++++++++ src/genn/genn/code_generator/codeGenUtils.cc | 41 ------------------- .../customConnectivityUpdateGroupMerged.cc | 6 +-- .../code_generator/customUpdateGroupMerged.cc | 14 +++---- .../genn/code_generator/initGroupMerged.cc | 6 +-- .../genn/code_generator/modelSpecMerged.cc | 2 +- 9 files changed, 66 insertions(+), 69 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 5603e64e63..8f2f0749b4 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -15,13 +15,15 @@ #include // GeNN includes -#include "codeStream.h" #include "gennExport.h" #include "gennUtils.h" #include "type.h" #include "varAccess.h" #include "variableMode.h" +// GeNN code generator includes +#include "code_generator/codeStream.h" + // Forward declarations namespace GeNN { @@ -494,6 +496,14 @@ class GENN_EXPORT BackendBase void genCustomUpdateIndexCalculation(CodeStream &os, const CustomUpdateGroupMerged &cu) const; void genCustomConnectivityUpdateIndexCalculation(CodeStream &os, const CustomConnectivityUpdateGroupMerged &cu) const; + + //! Get the initial value to start reduction operations from + std::string getReductionInitialValue(VarAccessMode access, const Type::NumericBase *type, const Type::TypeContext &context) const; + + //! Generate a reduction operation to reduce value into reduction + std::string getReductionOperation(const std::string &reduction, const std::string &value, VarAccessMode access, + const Type::NumericBase *type, const Type::TypeContext &context) const; + //! Helper function to generate initialisation code for any reduction operations carried out be custom update group. //! Returns vector of ReductionTarget structs, providing all information to write back reduction results to memory diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 88e04f1cb0..ac07a50899 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -75,17 +75,6 @@ GENN_EXPORT void genTypeRange(CodeStream &os, const Type::NumericBase *precision //-------------------------------------------------------------------------- GENN_EXPORT std::string ensureFtype(const std::string &oldcode, const std::string &type); -//-------------------------------------------------------------------------- -//! 
\brief Get the initial value to start reduction operations from -//-------------------------------------------------------------------------- -GENN_EXPORT std::string getReductionInitialValue(VarAccessMode access, const Type::NumericBase *type, const Type::TypeContext &context); - -//-------------------------------------------------------------------------- -//! \brief Generate a reduction operation to reduce value into reduction -//-------------------------------------------------------------------------- -GENN_EXPORT std::string getReductionOperation(const std::string &reduction, const std::string &value, VarAccessMode access, - const Type::NumericBase *type, const Type::TypeContext &context); - //-------------------------------------------------------------------------- /*! \brief This function checks for unknown variable definitions and returns a gennError if any are found */ diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 449045bf43..3215481e10 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -215,7 +215,7 @@ class GENN_EXPORT ModelSpecMerged boost::uuids::detail::sha1::digest_type getInitArchetypeHashDigest() const; //! Get the string literal that should be used to represent a value in scalar type - std::string scalarExp(double value) const; + std::string scalarExpr(double value) const; //! Does model have any EGPs? bool anyPointerEGPs() const; diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 879c4cf736..99e7f8802e 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -189,6 +189,47 @@ void BackendBase::genCustomConnectivityUpdateIndexCalculation(CodeStream &os, co os << "const unsigned int postDelayOffset = (*group->postSpkQuePtr * group->numTrgNeurons);" << std::endl; } } +//---------------------------------------------------------------------------- +std::string BackendBase::getReductionInitialValue(VarAccessMode access, const Type::NumericBase *type, const Type::TypeContext &context) const +{ + // If reduction is a sum, initialise to zero + if(access & VarAccessModeAttribute::SUM) { + return "0"; + } + // Otherwise, reduction is a maximum operation, return lowest value for type + else if(access & VarAccessModeAttribute::MAX) { + return Utils::writePreciseString(type->getLowest(context)); + } + else { + assert(false); + return ""; + } +} +//---------------------------------------------------------------------------- +std::string BackendBase::getReductionOperation(const std::string &reduction, const std::string &value, VarAccessMode access, + const Type::NumericBase *type, const Type::TypeContext &context) const +{ + // If operation is sum, add output of custom update to sum + if(access & VarAccessModeAttribute::SUM) { + return reduction + " += " + value; + } + // Otherwise, if it's max + else if(access & VarAccessModeAttribute::MAX) { + // If type is integral, generate max call + if(type->isIntegral(context)) { + return reduction + " = " + "max(" + reduction + ", " + value + ")"; + + } + // Otherwise, generate gmax call + else { + return reduction + " = " + "fmax(" + reduction + ", " + value + ")"; + } + } + else { + assert(false); + return ""; + } +} //----------------------------------------------------------------------- std::vector BackendBase::genInitReductionTargets(CodeStream &os, const CustomUpdateGroupMerged 
&cg, const std::string &idx) const { diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index 74d23452f4..4040ad149e 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -435,47 +435,6 @@ std::string ensureFtype(const std::string &oldcode, const std::string &type) return code; } //---------------------------------------------------------------------------- -std::string getReductionInitialValue(VarAccessMode access, const Type::NumericBase *type, const Type::TypeContext &context) -{ - // If reduction is a sum, initialise to zero - if(access & VarAccessModeAttribute::SUM) { - return "0"; - } - // Otherwise, reduction is a maximum operation, return lowest value for type - else if(access & VarAccessModeAttribute::MAX) { - return Utils::writePreciseString(type->getLowest(context)); - } - else { - assert(false); - return ""; - } -} -//---------------------------------------------------------------------------- -std::string getReductionOperation(const std::string &reduction, const std::string &value, VarAccessMode access, - const Type::NumericBase *type, const Type::TypeContext &context) -{ - // If operation is sum, add output of custom update to sum - if(access & VarAccessModeAttribute::SUM) { - return reduction + " += " + value; - } - // Otherwise, if it's max - else if(access & VarAccessModeAttribute::MAX) { - // If type is integral, generate max call - if(type->isIntegral(context)) { - return reduction + " = " + "max(" + reduction + ", " + value + ")"; - - } - // Otherwise, generate gmax call - else { - return reduction + " = " + "fmax(" + reduction + ", " + value + ")"; - } - } - else { - assert(false); - return ""; - } -} -//---------------------------------------------------------------------------- void checkUnreplacedVariables(const std::string &code, const std::string &codeName) { std::regex rgx("\\$\\([\\w]+\\)"); diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index efbec90e60..f2a5733f7b 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -80,11 +80,11 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t dependentVarsList.sort([](const auto &a, const auto &b) { boost::uuids::detail::sha1 hashA; - Utils::updateHash(a.getVar().type, hashA); + Utils::updateHash(a.getVar().type->getName(), hashA); Utils::updateHash(getVarAccessDuplication(a.getVar().access), hashA); boost::uuids::detail::sha1 hashB; - Utils::updateHash(b.getVar().type, hashB); + Utils::updateHash(b.getVar().type->getName(), hashB); Utils::updateHash(getVarAccessDuplication(b.getVar().access), hashB); return (hashA.get_digest() < hashB.get_digest()); @@ -528,7 +528,7 @@ void CustomConnectivityHostUpdateGroupMerged::addVarPushPullFuncSubs(const Backe // Loop through variables for(const auto &v : vars) { // If var is located on the host - const auto loc = (getArchetype().*getVarLocationFn)(v.name); + const auto loc = std::invoke(getVarLocationFn, getArchetype(), v.name); if (loc & VarLocation::HOST) { // Generate code to push this variable // **YUCK** these EGP functions should probably just be called dynamic or something diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index be03687422..033328459f 
100644
--- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc
@@ -21,8 +21,7 @@ using namespace GeNN::Transpiler;
 namespace
 {
 template<typename C, typename R>
-void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg,
-                     const ModelSpecMerged &modelMerged, const std::string &index,
+void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, const std::string &index,
                      R getVarRefIndex)
 {
     Substitutions updateSubs(&baseSubs);
@@ -118,7 +117,7 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const Type::TypeC
     using namespace Type;
     // Create type environment
-    GroupMergedTypeEnvironment typeEnvironment(*this, getScalarType());
+    GroupMergedTypeEnvironment typeEnvironment(*this);
     addField("size",
              [](const auto &c, size_t) { return std::to_string(c.getSize()); });
@@ -193,7 +192,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateGroupMerged::getHashDigest()
 //----------------------------------------------------------------------------
 void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const
 {
-    genCustomUpdate(os, popSubs, *this, modelMerged, "id",
+    genCustomUpdate(os, popSubs, *this, "id",
                     [this](const auto &varRef, const std::string &index)
                     {
                         return getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr,
@@ -297,8 +296,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const
     using namespace Type;
     // Create type environment
-    // **TEMP** parse precision to get scalar type
-    GroupMergedTypeEnvironment typeEnvironment(*this, getScalarType());
+    GroupMergedTypeEnvironment typeEnvironment(*this);
     // If underlying synapse group has kernel weights
     if (getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL) {
@@ -398,7 +396,7 @@ const std::string CustomUpdateWUGroupMerged::name = "CustomUpdateWU";
 //----------------------------------------------------------------------------
 void CustomUpdateWUGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const
 {
-    genCustomUpdate(os, popSubs, *this, modelMerged, "id_syn",
+    genCustomUpdate(os, popSubs, *this, "id_syn",
                     [this, &modelMerged](const auto &varRef, const std::string &index)
                     {
                         return getVarRefIndex(getVarAccessDuplication(varRef.getVar().access),
@@ -413,7 +411,7 @@ const std::string CustomUpdateTransposeWUGroupMerged::name = "CustomUpdateTransp
 //----------------------------------------------------------------------------
 void CustomUpdateTransposeWUGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const
 {
-    genCustomUpdate(os, popSubs, *this, modelMerged, "id_syn",
+    genCustomUpdate(os, popSubs, *this, "id_syn",
                     [this, &modelMerged](const auto &varRef, const std::string &index)
                     {
                         return getVarRefIndex(getVarAccessDuplication(varRef.getVar().access),
diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc
index 81b70991f6..4e0bbfb8b2 100644
--- a/src/genn/genn/code_generator/initGroupMerged.cc
+++ b/src/genn/genn/code_generator/initGroupMerged.cc
@@ -314,7 +314,7 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, CodeStream
         // Zero InSyn
         backend.genVariableInit(os, "group->numNeurons", "id", popSubs,
-            [&model, i] (CodeStream &os, Substitutions &varSubs)
+            [&model, &modelMerged, i] (CodeStream &os, Substitutions &varSubs)
             {
                 genVariableFill(os, "inSynInSyn" + std::to_string(i), modelMerged.scalarExpr(0.0),
                                 varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, model.getBatchSize());
@@ -325,7 +325,7 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, CodeStream
         if(sg->isDendriticDelayRequired()) {
             // Zero dendritic delay buffer
             backend.genVariableInit(os, "group->numNeurons", "id", popSubs,
-                [&model, sg, i](CodeStream &os, Substitutions &varSubs)
+                [&model, &modelMerged, sg, i](CodeStream &os, Substitutions &varSubs)
                 {
                     genVariableFill(os, "denDelayInSyn" + std::to_string(i), modelMerged.scalarExpr(0.0),
                                     varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, model.getBatchSize(),
@@ -374,7 +374,7 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, CodeStream
     for(size_t i = 0; i < getSortedArchetypeMergedPreOutputOutSyns().size(); i++) {
         // Zero revInSynOutSyn
         backend.genVariableInit(os, "group->numNeurons", "id", popSubs,
-            [&model, i] (CodeStream &os, Substitutions &varSubs)
+            [&model, &modelMerged, i] (CodeStream &os, Substitutions &varSubs)
             {
                 genVariableFill(os, "revInSynOutSyn" + std::to_string(i), modelMerged.scalarExpr(0.0),
                                 varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, model.getBatchSize());
diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc
index de15687e2b..67180589af 100644
--- a/src/genn/genn/code_generator/modelSpecMerged.cc
+++ b/src/genn/genn/code_generator/modelSpecMerged.cc
@@ -564,7 +564,7 @@ boost::uuids::detail::sha1::digest_type ModelSpecMerged::getInitArchetypeHashDig
     return hash.get_digest();
 }
 //----------------------------------------------------------------------------
-std::string ModelSpecMerged::scalarExp(double value) const
+std::string ModelSpecMerged::scalarExpr(double value) const
 {
     const auto *scalarType = dynamic_cast<const Type::NumericBase*>(m_TypeContext.at("scalar"));
     return Utils::writePreciseString(value, scalarType->getMaxDigits10(m_TypeContext)) + scalarType->getLiteralSuffix(m_TypeContext);

From 3305c7b095395988cb0ecf4380c4d5213c17c230 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Wed, 18 Jan 2023 17:52:55 +0000
Subject: [PATCH 069/725] extra global parameters also use type system

---
 .../genn/genn/code_generator/backendBase.h    | 14 +++---
 .../genn/genn/code_generator/groupMerged.h    |  7 +--
 .../groupMergedTypeEnvironment.h              |  5 +-
 include/genn/genn/models.h                    |  4 --
 include/genn/genn/snippet.h                   | 40 +++++++---------
 include/genn/genn/type.h                      |  2 +-
 .../customConnectivityUpdateGroupMerged.cc    | 43 ++++++++---------
 .../genn/code_generator/generateRunner.cc     | 27 +++++------
 src/genn/genn/code_generator/groupMerged.cc   |  6 +--
 .../genn/code_generator/initGroupMerged.cc    | 19 ++++----
 .../code_generator/neuronUpdateGroupMerged.cc |  2 +-
 src/genn/genn/currentSource.cc                |  8 +---
 src/genn/genn/customConnectivityUpdate.cc     |  1 +
 src/genn/genn/customUpdate.cc                 |  1 +
 src/genn/genn/neuronGroup.cc                  |  6 +--
 src/genn/genn/snippet.cc                      | 46 ++++++++++++++++++-
 src/genn/genn/synapseGroup.cc                 | 19 ++------
 src/genn/genn/type.cc                         |  2 +-
 18 files changed, 133 insertions(+), 119 deletions(-)

diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h
index 8f2f0749b4..2a9daf2d80 100644
--- a/include/genn/genn/code_generator/backendBase.h
+++ b/include/genn/genn/code_generator/backendBase.h
@@ -250,13 +250,13 @@ class GENN_EXPORT BackendBase
     virtual void genVariableAllocation(CodeStream &os, const Type::Base *type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const = 0;
     virtual void genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const = 0;
-    virtual void genExtraGlobalParamDefinition(CodeStream &definitions, CodeStream &definitionsInternal, const std::string &type, const std::string &name, VarLocation loc) const = 0;
-    virtual void genExtraGlobalParamImplementation(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc) const = 0;
-    virtual void genExtraGlobalParamAllocation(CodeStream &os, const std::string &type, const std::string &name,
+    virtual void genExtraGlobalParamDefinition(CodeStream &definitions, CodeStream &definitionsInternal, const Type::Base *type, const std::string &name, VarLocation loc) const = 0;
+    virtual void genExtraGlobalParamImplementation(CodeStream &os, const Type::Pointer *type, const std::string &name, VarLocation loc) const = 0;
+    virtual void genExtraGlobalParamAllocation(CodeStream &os, const Type::Pointer *type, const std::string &name,
                                                VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0;
-    virtual void genExtraGlobalParamPush(CodeStream &os, const std::string &type, const std::string &name,
+    virtual void genExtraGlobalParamPush(CodeStream &os, const Type::Pointer *type, const std::string &name,
                                          VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0;
-    virtual void genExtraGlobalParamPull(CodeStream &os, const std::string &type, const std::string &name,
+    virtual void genExtraGlobalParamPull(CodeStream &os, const Type::Pointer *type, const std::string &name,
                                          VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0;
     //! Generate code for pushing an updated EGP value into the merged group structure on 'device'
@@ -410,7 +410,7 @@ class GENN_EXPORT BackendBase
         genVariablePull(pull, type, name, loc, count);
     }
-    //! Templated version of gelper function to generate matching push and pull functions for
+    //! Templated version of helper function to generate matching push and pull functions for
     //! a variable when type is known at compile time
     template<typename T>
     void genVariablePushPull(CodeStream &push, CodeStream &pull,
@@ -438,7 +438,7 @@ class GENN_EXPORT BackendBase
     //!
Helper function to generate matching definition, declaration, allocation and free code for an array void genArray(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const Type::Base *type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const + const Type::NumericBase *type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const { genVariableDefinition(definitions, definitionsInternal, type->getPointerType(), name, loc); genVariableImplementation(runner, type->getPointerType(), name, loc); diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index eb09646ade..fdfe3bd620 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -305,8 +305,7 @@ class GroupMerged void addEGPs(const Snippet::Base::EGPVec &egps, const std::string &arrayPrefix, const std::string &varName = "") { for(const auto &e : egps) { - assert(Utils::isTypePointer(e.type)); - addField(Type::parseNumericPtr(e.type), e.name + varName, + addField(e.type->getPointerType(), e.name + varName, [e, arrayPrefix, varName](const G &g, size_t) { return arrayPrefix + e.name + varName + g.getName(); }, GroupMergedFieldType::DYNAMIC); } @@ -812,6 +811,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged(this)->*isChildDerivedParamHeterogeneousFn)(childIndex, varName, d.name)) { addScalarField(d.name + varName + prefix + std::to_string(childIndex), [&sortedGroupChildren, childIndex, varName, d, getVarInitialiserFn](const NeuronGroupInternal &, size_t groupIndex) @@ -829,7 +829,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMergedgetPointerType(), e.name + prefix + std::to_string(childIndex), [getEGPSuffixFn, childIndex, e, arrayPrefix](const NeuronGroupInternal&, size_t groupIndex) { return arrayPrefix + e.name + getEGPSuffixFn(groupIndex, childIndex); @@ -847,6 +847,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged*getValueFn)(); for(const auto &p : archetypeParams) { // If any of the code strings reference the parameter + // **TODO** std::invoke if((static_cast(this)->*isChildParamReferencedFn)(childIndex, p.first)) { // Loop through groups for(size_t g = 0; g < getGroups().size(); g++) { diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index 417dae24bc..5af65bc19e 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -218,9 +218,8 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa void defineEGPs(const Snippet::Base::EGPVec &egps, const std::string &arrayPrefix, const std::string &varName = "") { for(const auto &e : egps) { - const auto *type = Type::parseNumericPtr(e.type); - defineField(type, e.name, - type, e.name + varName, + defineField(e.type->getPointerType(), e.name, + e.type->getPointerType(), e.name + varName, [arrayPrefix, e, varName](const auto &g, size_t) { return arrayPrefix + e.name + varName + g.getName(); diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 659a248a30..06db894612 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -24,10 +24,6 @@ class CurrentSource; class NeuronGroupInternal; class SynapseGroupInternal; class CurrentSourceInternal; -namespace CodeGenerator -{ -class 
BackendBase; -} namespace Type { class NumericBase; diff --git a/include/genn/genn/snippet.h b/include/genn/genn/snippet.h index f6ff4ab67f..b25fe99050 100644 --- a/include/genn/genn/snippet.h +++ b/include/genn/genn/snippet.h @@ -13,7 +13,15 @@ // GeNN includes #include "gennExport.h" #include "gennUtils.h" -#include "type.h" + +// Forward declarations +namespace GeNN +{ +namespace Type +{ +class NumericBase; +} +} //---------------------------------------------------------------------------- // Macros @@ -57,13 +65,13 @@ class GENN_EXPORT Base //! An extra global parameter has a name and a type struct EGP { - bool operator == (const EGP &other) const - { - return ((name == other.name) && (type == other.type)); - } + EGP(const std::string &n, const Type::NumericBase *t); + EGP(const std::string &n, const std::string &t); + + bool operator == (const EGP &other) const; const std::string name; - const std::string type; + const Type::NumericBase *type; }; //! Additional input variables, row state variables and other things have a name, a type and an initial value @@ -208,21 +216,7 @@ class Init //---------------------------------------------------------------------------- // updateHash overrides //---------------------------------------------------------------------------- -inline void updateHash(const Base::EGP &e, boost::uuids::detail::sha1 &hash) -{ - Utils::updateHash(e.name, hash); - Utils::updateHash(e.type, hash); -} - -inline void updateHash(const Base::ParamVal &p, boost::uuids::detail::sha1 &hash) -{ - Utils::updateHash(p.name, hash); - Utils::updateHash(p.type, hash); - Utils::updateHash(p.value, hash); -} - -inline void updateHash(const Base::DerivedParam &d, boost::uuids::detail::sha1 &hash) -{ - Utils::updateHash(d.name, hash); -} +GENN_EXPORT void updateHash(const Base::EGP &e, boost::uuids::detail::sha1 &hash); +GENN_EXPORT void updateHash(const Base::ParamVal &p, boost::uuids::detail::sha1 &hash); +GENN_EXPORT void updateHash(const Base::DerivedParam &d, boost::uuids::detail::sha1 &hash); } // namespace GeNN::Snippet diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 10893bf57c..6d3d19cf5b 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -120,7 +120,7 @@ class Base // Public API //------------------------------------------------------------------------ //! Return a pointer to this type, optionally, with specified qualifiers - const Base *getPointerType(Qualifier qualifiers = Qualifier{0}) const; + const class Pointer *getPointerType(Qualifier qualifiers = Qualifier{0}) const; //! Does this type have qualifier? 
bool hasQualifier(Qualifier qualifier) const{ return (m_Qualifiers & qualifier); }; diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index f2a5733f7b..cf1f1896c8 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -440,13 +440,12 @@ CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged // Add host extra global parameters for(const auto &e : cm->getExtraGlobalParams()) { - const auto *pointerType = parseNumericPtr(e.type); - addField(pointerType, e.name, + addField(e.type->getPointerType(), e.name, [e](const auto &g, size_t) { return e.name + g.getName(); }, GroupMergedFieldType::HOST_DYNAMIC); if(!backend.getDeviceVarPrefix().empty()) { - addField(pointerType, backend.getDeviceVarPrefix() + e.name, + addField(e.type->getPointerType(), backend.getDeviceVarPrefix() + e.name, [e, &backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + e.name + g.getName(); @@ -486,26 +485,24 @@ void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase & // Loop through EGPs for(const auto &egp : cm->getExtraGlobalParams()) { - // If EGP is a pointer - if(Utils::isTypePointer(egp.type)) { - // Generate code to push this EGP with count specified by $(0) - std::stringstream pushStream; - CodeStream push(pushStream); - backend.genExtraGlobalParamPush(push, egp.type, egp.name, - VarLocation::HOST_DEVICE, "$(0)", "group->"); - - // Add substitution - subs.addFuncSubstitution("push" + egp.name + "ToDevice", 1, pushStream.str()); - - // Generate code to pull this EGP with count specified by $(0) - std::stringstream pullStream; - CodeStream pull(pullStream); - backend.genExtraGlobalParamPull(pull, egp.type, egp.name, - VarLocation::HOST_DEVICE, "$(0)", "group->"); - - // Add substitution - subs.addFuncSubstitution("pull" + egp.name + "FromDevice", 1, pullStream.str()); - } + // Generate code to push this EGP with count specified by $(0) + std::stringstream pushStream; + const auto *pointerType = egp.type->getPointerType(); + CodeStream push(pushStream); + backend.genExtraGlobalParamPush(push, pointerType, egp.name, + VarLocation::HOST_DEVICE, "$(0)", "group->"); + + // Add substitution + subs.addFuncSubstitution("push" + egp.name + "ToDevice", 1, pushStream.str()); + + // Generate code to pull this EGP with count specified by $(0) + std::stringstream pullStream; + CodeStream pull(pullStream); + backend.genExtraGlobalParamPull(pull, pointerType, egp.name, + VarLocation::HOST_DEVICE, "$(0)", "group->"); + + // Add substitution + subs.addFuncSubstitution("pull" + egp.name + "FromDevice", 1, pullStream.str()); } addVarPushPullFuncSubs(backend, subs, cm->getPreVars(), "group->numSrcNeurons", diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 0024d47ca6..37da7b7ff9 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -255,7 +255,7 @@ void genStatePushPull(CodeStream &definitionsFunc, CodeStream &runnerPushFunc, C //------------------------------------------------------------------------- void genVariable(const BackendBase &backend, CodeStream &definitionsVar, CodeStream &definitionsFunc, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - CodeStream &push, CodeStream &pull, const 
Type::Base *type, const std::string &name,
+                 CodeStream &push, CodeStream &pull, const Type::NumericBase *type, const std::string &name,
                  VarLocation loc, bool autoInitialized, size_t count, MemAlloc &mem,
                  std::vector<std::string> &statePushPullFunction)
 {
@@ -273,14 +273,15 @@ void genVariable(const BackendBase &backend, CodeStream &definitionsVar, CodeStr
 //-------------------------------------------------------------------------
 void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase &backend, CodeStream &definitionsVar,
                          CodeStream &definitionsFunc, CodeStream &definitionsInternalVar, CodeStream &runner,
-                         CodeStream &extraGlobalParam, const std::string &type, const std::string &name, bool apiRequired, VarLocation loc)
+                         CodeStream &extraGlobalParam, const Type::NumericBase *type, const std::string &name, bool apiRequired, VarLocation loc)
 {
     // Generate variables
-    backend.genExtraGlobalParamDefinition(definitionsVar, definitionsInternalVar, type, name, loc);
-    backend.genExtraGlobalParamImplementation(runner, type, name, loc);
+    const auto *pointerType = type->getPointerType();
+    backend.genExtraGlobalParamDefinition(definitionsVar, definitionsInternalVar, pointerType, name, loc);
+    backend.genExtraGlobalParamImplementation(runner, pointerType, name, loc);
-    // If type is a pointer and API is required
-    if(Utils::isTypePointer(type) && apiRequired) {
+    // If API is required
+    if(apiRequired) {
         // Write definitions for functions to allocate and free extra global param
         definitionsFunc << "EXPORT_FUNC void allocate" << name << "(unsigned int count);" << std::endl;
         definitionsFunc << "EXPORT_FUNC void free" << name << "();" << std::endl;
@@ -289,7 +290,7 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase &
     extraGlobalParam << "void allocate" << name << "(unsigned int count)";
     {
         CodeStream::Scope a(extraGlobalParam);
-        backend.genExtraGlobalParamAllocation(extraGlobalParam, type, name, loc);
+        backend.genExtraGlobalParamAllocation(extraGlobalParam, pointerType, name, loc);
         // Loop through destinations in merged structures, the device EGP needs to be copied to
         if(modelMerged.anyMergedEGPDestinations(backend.getDeviceVarPrefix() + name)) {
@@ -351,7 +352,7 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase &
     extraGlobalParam << "void push" << name << "ToDevice(unsigned int count)";
     {
         CodeStream::Scope a(extraGlobalParam);
-        backend.genExtraGlobalParamPush(extraGlobalParam, type, name, loc);
+        backend.genExtraGlobalParamPush(extraGlobalParam, pointerType, name, loc);
     }
     if(backend.getPreferences().generateExtraGlobalParamPull) {
@@ -362,7 +363,7 @@
         extraGlobalParam << "void pull" << name << "FromDevice(unsigned int count)";
         {
             CodeGenerator::CodeStream::Scope a(extraGlobalParam);
-            backend.genExtraGlobalParamPull(extraGlobalParam, type, name, loc);
+            backend.genExtraGlobalParamPull(extraGlobalParam, pointerType, name, loc);
         }
     }
 }
@@ -1592,7 +1593,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath,
             // **YUCK** maybe this should be renamed genDynamicArray
             if(n.second.isSpikeRecordingEnabled()) {
                 CodeStream::Scope b(runner);
-                backend.genExtraGlobalParamAllocation(runner, "uint32_t*", "recordSpk" + n.first, VarLocation::HOST_DEVICE, "numWords");
+                backend.genExtraGlobalParamAllocation(runner, Type::Uint32::getInstance()->getPointerType(), "recordSpk" + n.first, VarLocation::HOST_DEVICE, "numWords");
                 // Get destinations in merged structures, this EGP
                 // needs to be copied to and call push function
@@ -1607,7 +1608,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath,
             // **YUCK** maybe this should be renamed genDynamicArray
             if(n.second.isSpikeEventRecordingEnabled()) {
                 CodeStream::Scope b(runner);
-                backend.genExtraGlobalParamAllocation(runner, "uint32_t*", "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE, "numWords");
+                backend.genExtraGlobalParamAllocation(runner, Type::Uint32::getInstance()->getPointerType(), "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE, "numWords");
                 // Get destinations in merged structures, this EGP
                 // needs to be copied to and call push function
@@ -1646,13 +1647,13 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath,
             // **YUCK** maybe this should be renamed pullDynamicArray
             if(n.second.isSpikeRecordingEnabled()) {
                 CodeStream::Scope b(runner);
-                backend.genExtraGlobalParamPull(runner, "uint32_t*", "recordSpk" + n.first, VarLocation::HOST_DEVICE, "numWords");
+                backend.genExtraGlobalParamPull(runner, Type::Uint32::getInstance()->getPointerType(), "recordSpk" + n.first, VarLocation::HOST_DEVICE, "numWords");
             }
             // Pull spike event array if required
             // **YUCK** maybe this should be renamed pullDynamicArray
             if(n.second.isSpikeEventRecordingEnabled()) {
                 CodeStream::Scope b(runner);
-                backend.genExtraGlobalParamPull(runner, "uint32_t*", "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE, "numWords");
+                backend.genExtraGlobalParamPull(runner, Type::Uint32::getInstance()->getPointerType(), "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE, "numWords");
             }
         }
     }
diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc
index d91a64cf5f..e7d3529183 100644
--- a/src/genn/genn/code_generator/groupMerged.cc
+++ b/src/genn/genn/code_generator/groupMerged.cc
@@ -781,7 +781,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon
     for(const auto &e : preEGPs) {
         if(code.find("$(" + e.name + "_pre)") != std::string::npos) {
             const std::string prefix = backend.getDeviceVarPrefix();
-            addField(parseNumericPtr(e.type), e.name + "Pre",
+            addField(e.type->getPointerType(), e.name + "Pre",
                      [e, prefix](const auto &sg, size_t) { return prefix + e.name + sg.getSrcNeuronGroup()->getName(); },
                      GroupMergedFieldType::DYNAMIC);
         }
@@ -792,7 +792,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon
     for(const auto &e : postEGPs) {
         if(code.find("$(" + e.name + "_post)") != std::string::npos) {
             const std::string prefix = backend.getDeviceVarPrefix();
-            addField(parseNumericPtr(e.type), e.name + "Post",
+            addField(e.type->getPointerType(), e.name + "Post",
                      [e, prefix](const auto &sg, size_t) { return prefix + e.name + sg.getTrgNeuronGroup()->getName(); },
                      GroupMergedFieldType::DYNAMIC);
         }
@@ -972,7 +972,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon
             const auto egps = snippet->getExtraGlobalParams();
             for(const auto &e : egps) {
                 const std::string prefix = backend.getDeviceVarPrefix();
-                addField(parseNumericPtr(e.type), e.name + var.name,
+                addField(e.type->getPointerType(), e.name + var.name,
                          [e, prefix, var](const SynapseGroupInternal &sg, size_t)
                          {
                              return prefix + e.name + var.name + sg.getName();
diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc
index 4e0bbfb8b2..6b175d5b31 100644
--- a/src/genn/genn/code_generator/initGroupMerged.cc
+++ b/src/genn/genn/code_generator/initGroupMerged.cc
@@ -744,13 +744,13 @@ SynapseConnectivityHostInitGroupMerged::SynapseConnectivityHostInitGroupMerged(s
     // Add EGP pointers to struct for both host and device EGPs if they are separate
     const auto egps = getArchetype().getConnectivityInitialiser().getSnippet()->getExtraGlobalParams();
     for(const auto &e : egps) {
-        assert(false);
-        /*addField(e.type + "*", e.name,
+        const auto *pointerToPointerToEGP = e.type->getPointerType()->getPointerType();
+        addField(pointerToPointerToEGP, e.name,
                  [e](const SynapseGroupInternal &g, size_t) { return "&" + e.name + g.getName(); },
                  GroupMergedFieldType::HOST_DYNAMIC);
         if(!backend.getDeviceVarPrefix().empty()) {
-            addField(e.type + "*", backend.getDeviceVarPrefix() + e.name,
+            addField(pointerToPointerToEGP, backend.getDeviceVarPrefix() + e.name,
                      [e, &backend](const SynapseGroupInternal &g, size_t)
                      {
                          return "&" + backend.getDeviceVarPrefix() + e.name + g.getName();
@@ -758,13 +758,13 @@ SynapseConnectivityHostInitGroupMerged::SynapseConnectivityHostInitGroupMerged(s
                      },
                      GroupMergedFieldType::DYNAMIC);
         }
         if(!backend.getHostVarPrefix().empty()) {
-            addField(e.type + "*", backend.getHostVarPrefix() + e.name,
+            addField(pointerToPointerToEGP, backend.getHostVarPrefix() + e.name,
                      [e, &backend](const SynapseGroupInternal &g, size_t)
                      {
                          return "&" + backend.getHostVarPrefix() + e.name + g.getName();
                      },
                      GroupMergedFieldType::DYNAMIC);
-        }*/
+        }
     }
 }
 //-------------------------------------------------------------------------
@@ -801,13 +801,14 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac
     // Loop through EGPs
     for(const auto &egp : connectInit.getSnippet()->getExtraGlobalParams()) {
-        // If EGP is a pointer and located on the host
+        // If EGP is located on the host
         const auto loc = getArchetype().getSparseConnectivityExtraGlobalParamLocation(egp.name);
-        if(Utils::isTypePointer(egp.type) && (loc & VarLocation::HOST)) {
+        if(loc & VarLocation::HOST) {
             // Generate code to allocate this EGP with count specified by $(0)
             std::stringstream allocStream;
+            const auto *pointerToPointerToEGP = egp.type->getPointerType()->getPointerType();
             CodeGenerator::CodeStream alloc(allocStream);
-            backend.genExtraGlobalParamAllocation(alloc, egp.type + "*", egp.name,
+            backend.genExtraGlobalParamAllocation(alloc, pointerToPointerToEGP, egp.name,
                                                   loc, "$(0)", "group->");
             // Add substitution
@@ -816,7 +817,7 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac
             // Generate code to push this EGP with count specified by $(0)
             std::stringstream pushStream;
             CodeStream push(pushStream);
-            backend.genExtraGlobalParamPush(push, egp.type + "*", egp.name,
+            backend.genExtraGlobalParamPush(push, pointerToPointerToEGP, egp.name,
                                             loc, "$(0)", "group->");
diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
index 2acbc4eacf..53c86480c4 100644
--- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
@@ -66,7 +66,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC
         // If EGP is referenced in event threshold code
         if(s.eventThresholdCode.find("$(" + egp.name + ")") != std::string::npos) {
             const std::string prefix = backend.getDeviceVarPrefix();
-            addField(parseNumericPtr(egp.type), egp.name + "EventThresh" + std::to_string(i),
+            addField(egp.type->getPointerType(), egp.name + "EventThresh" + std::to_string(i),
                      [eventThresholdSGs, prefix, egp, i](const auto &, size_t groupIndex)
                      {
                          return prefix + egp.name + eventThresholdSGs.at(groupIndex).at(i)->getName();
diff --git a/src/genn/genn/currentSource.cc b/src/genn/genn/currentSource.cc
index 0be714adb0..bea478983f 100644
--- a/src/genn/genn/currentSource.cc
+++ b/src/genn/genn/currentSource.cc
@@ -19,11 +19,7 @@ void CurrentSource::setVarLocation(const std::string &varName, VarLocation loc)
 //----------------------------------------------------------------------------
 void CurrentSource::setExtraGlobalParamLocation(const std::string &paramName, VarLocation loc)
 {
-    const size_t extraGlobalParamIndex = getCurrentSourceModel()->getExtraGlobalParamIndex(paramName);
-    if(!Utils::isTypePointer(getCurrentSourceModel()->getExtraGlobalParams()[extraGlobalParamIndex].type)) {
-        throw std::runtime_error("Only extra global parameters with a pointer type have a location");
-    }
-    m_ExtraGlobalParamLocation[extraGlobalParamIndex] = loc;
+    m_ExtraGlobalParamLocation[getCurrentSourceModel()->getExtraGlobalParamIndex(paramName)] = loc;
 }
 //----------------------------------------------------------------------------
 VarLocation CurrentSource::getVarLocation(const std::string &varName) const
@@ -116,4 +112,4 @@ boost::uuids::detail::sha1::digest_type CurrentSource::getVarLocationHashDigest(
     Utils::updateHash(m_ExtraGlobalParamLocation, hash);
     return hash.get_digest();
 }
-} // namespace GeNN
\ No newline at end of file
+} // namespace GeNN
diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc
index 60ad5c8bc6..e7f2642a29 100644
--- a/src/genn/genn/customConnectivityUpdate.cc
+++ b/src/genn/genn/customConnectivityUpdate.cc
@@ -11,6 +11,7 @@
 #include "customUpdateInternal.h"
 #include "neuronGroupInternal.h"
 #include "synapseGroupInternal.h"
+#include "type.h"
 using namespace GeNN;
diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc
index 5d7fc7af13..7d6d02b2c4 100644
--- a/src/genn/genn/customUpdate.cc
+++ b/src/genn/genn/customUpdate.cc
@@ -9,6 +9,7 @@
 #include "currentSource.h"
 #include "neuronGroupInternal.h"
 #include "synapseGroupInternal.h"
+#include "type.h"
 //------------------------------------------------------------------------
 // GeNN::CustomUpdateBase
 //------------------------------------------------------------------------
diff --git a/src/genn/genn/neuronGroup.cc b/src/genn/genn/neuronGroup.cc
index 98c2b65e23..6b9cf60cee 100644
--- a/src/genn/genn/neuronGroup.cc
+++ b/src/genn/genn/neuronGroup.cc
@@ -113,11 +113,7 @@ void NeuronGroup::setVarLocation(const std::string &varName, VarLocation loc)
 //----------------------------------------------------------------------------
 void NeuronGroup::setExtraGlobalParamLocation(const std::string &paramName, VarLocation loc)
 {
-    const size_t extraGlobalParamIndex = getNeuronModel()->getExtraGlobalParamIndex(paramName);
-    if(!Utils::isTypePointer(getNeuronModel()->getExtraGlobalParams()[extraGlobalParamIndex].type)) {
-        throw std::runtime_error("Only extra global parameters with a pointer type have a location");
-    }
-    m_ExtraGlobalParamLocation.at(extraGlobalParamIndex) = loc;
+    m_ExtraGlobalParamLocation.at(getNeuronModel()->getExtraGlobalParamIndex(paramName)) = loc;
 }
 //----------------------------------------------------------------------------
 VarLocation NeuronGroup::getVarLocation(const std::string &varName) const
diff --git a/src/genn/genn/snippet.cc b/src/genn/genn/snippet.cc
index 931ca3678e..c19dd0bab0 100644
--- a/src/genn/genn/snippet.cc
+++ b/src/genn/genn/snippet.cc
@@ -1,10 +1,31 @@
 #include "snippet.h"
+// GeNN includes
+#include "logging.h"
+#include "type.h"
+
 //----------------------------------------------------------------------------
-// GeNN::Snippet::Base
+// GeNN::Snippet::Base::EGP
 //----------------------------------------------------------------------------
 namespace GeNN::Snippet
 {
+Base::EGP::EGP(const std::string &n, const std::string &t)
+: name(n), type(Type::parseNumeric((t.back() == '*') ? t.substr(0, t.length() - 1) : t))
+{
+    // If type ends in a *, give warning as this is legacy syntax
+    if(t.back() == '*') {
+        LOGW_GENN << "Extra global parameters are now always arrays so * at end of type is no longer necessary";
+    }
+}
+//----------------------------------------------------------------------------
+bool Base::EGP::operator == (const EGP &other) const
+{
+    return ((name == other.name) && (type->getName() == other.type->getName()));
+}
+
+//----------------------------------------------------------------------------
+// GeNN::Snippet::Base
+//----------------------------------------------------------------------------
 void Base::updateHash(boost::uuids::detail::sha1 &hash) const
 {
     Utils::updateHash(getParamNames(), hash);
@@ -32,4 +53,25 @@ void Base::validate(const std::unordered_map<std::string, double> &paramValues,
         }
     }
 }
-} // namespace GeNN::Snippet
\ No newline at end of file
+
+//----------------------------------------------------------------------------
+// Free functions
+//----------------------------------------------------------------------------
+void updateHash(const Base::EGP &e, boost::uuids::detail::sha1 &hash)
+{
+    Utils::updateHash(e.name, hash);
+    Utils::updateHash(e.type->getName(), hash);
+}
+//----------------------------------------------------------------------------
+void updateHash(const Base::ParamVal &p, boost::uuids::detail::sha1 &hash)
+{
+    Utils::updateHash(p.name, hash);
+    Utils::updateHash(p.type, hash);
+    Utils::updateHash(p.value, hash);
+}
+//----------------------------------------------------------------------------
+void updateHash(const Base::DerivedParam &d, boost::uuids::detail::sha1 &hash)
+{
+    Utils::updateHash(d.name, hash);
+}
+} // namespace GeNN::Snippet
diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc
index fee9114cf2..35ed646961 100644
--- a/src/genn/genn/synapseGroup.cc
+++ b/src/genn/genn/synapseGroup.cc
@@ -10,6 +10,7 @@
 #include "gennUtils.h"
 #include "neuronGroupInternal.h"
 #include "synapseGroupInternal.h"
+#include "type.h"
 //----------------------------------------------------------------------------
 // Anonymous namespace
 //----------------------------------------------------------------------------
@@ -60,11 +61,7 @@ void SynapseGroup::setWUPostVarLocation(const std::string &varName, VarLocation
 //----------------------------------------------------------------------------
 void SynapseGroup::setWUExtraGlobalParamLocation(const std::string &paramName, VarLocation loc)
 {
-    const size_t extraGlobalParamIndex = getWUModel()->getExtraGlobalParamIndex(paramName);
-    if(!Utils::isTypePointer(getWUModel()->getExtraGlobalParams()[extraGlobalParamIndex].type)) {
-        throw std::runtime_error("Only extra global parameters with a pointer type have a location");
-    }
-    m_WUExtraGlobalParamLocation[extraGlobalParamIndex] = loc;
+    m_WUExtraGlobalParamLocation[getWUModel()->getExtraGlobalParamIndex(paramName)] = loc;
 }
 //----------------------------------------------------------------------------
 void SynapseGroup::setPSVarLocation(const std::string &varName, VarLocation loc)
@@ -104,20 +101,12 @@ void SynapseGroup::setPreTargetVar(const std::string &varName)
 //----------------------------------------------------------------------------
 void SynapseGroup::setPSExtraGlobalParamLocation(const std::string &paramName, VarLocation loc)
 {
-    const size_t extraGlobalParamIndex = getPSModel()->getExtraGlobalParamIndex(paramName);
-    if(!Utils::isTypePointer(getPSModel()->getExtraGlobalParams()[extraGlobalParamIndex].type)) {
-        throw std::runtime_error("Only extra global parameters with a pointer type have a location");
-    }
-    m_PSExtraGlobalParamLocation[extraGlobalParamIndex] = loc;
+    m_PSExtraGlobalParamLocation[getPSModel()->getExtraGlobalParamIndex(paramName)] = loc;
 }
 //----------------------------------------------------------------------------
 void SynapseGroup::setSparseConnectivityExtraGlobalParamLocation(const std::string &paramName, VarLocation loc)
 {
-    const size_t extraGlobalParamIndex = m_SparseConnectivityInitialiser.getSnippet()->getExtraGlobalParamIndex(paramName);
-    if(!Utils::isTypePointer(m_SparseConnectivityInitialiser.getSnippet()->getExtraGlobalParams()[extraGlobalParamIndex].type)) {
-        throw std::runtime_error("Only extra global parameters with a pointer type have a location");
-    }
-    m_ConnectivityExtraGlobalParamLocation[extraGlobalParamIndex] = loc;
+    m_ConnectivityExtraGlobalParamLocation[m_SparseConnectivityInitialiser.getSnippet()->getExtraGlobalParamIndex(paramName)] = loc;
 }
 //----------------------------------------------------------------------------
 void SynapseGroup::setSparseConnectivityLocation(VarLocation loc)
diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc
index 5b0f537840..e60e3a9abc 100644
--- a/src/genn/genn/type.cc
+++ b/src/genn/genn/type.cc
@@ -80,7 +80,7 @@ IMPLEMENT_TYPE(Sqrt);
 //----------------------------------------------------------------------------
 // GeNN::Type::Base
 //----------------------------------------------------------------------------
-const Base *Base::getPointerType(Qualifier qualifiers) const
+const Pointer *Base::getPointerType(Qualifier qualifiers) const
 {
     // **TODO** befriend constructor
     // **TODO** don't just leak these!

From f96b97565d9fe4c10d5980874d3ba2bbfea2fc73 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Wed, 18 Jan 2023 18:03:40 +0000
Subject: [PATCH 070/725] and additional input variables

---
 include/genn/genn/models.h                       |  3 ++-
 include/genn/genn/snippet.h                      | 12 ++++++------
 src/genn/backends/single_threaded_cpu/backend.cc |  2 +-
 src/genn/genn/code_generator/initGroupMerged.cc  |  2 +-
 .../code_generator/neuronUpdateGroupMerged.cc    |  2 +-
 .../code_generator/synapseUpdateGroupMerged.cc   |  2 +-
 src/genn/genn/models.cc                          |  5 +----
 src/genn/genn/snippet.cc                         | 14 +++++++++++++-
 8 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h
index 06db894612..cca64046c0 100644
--- a/include/genn/genn/models.h
+++ b/include/genn/genn/models.h
@@ -59,7 +59,8 @@ class GENN_EXPORT Base : public Snippet::Base
         Var(const std::string &n, const Type::NumericBase *t) : Var(n, t, VarAccess::READ_WRITE)
         {}
         Var(const std::string &n, const std::string &t, VarAccess a);
-        Var(const std::string &n, const std::string &t);
+        Var(const std::string &n, const std::string &t) : Var(n, t, VarAccess::READ_WRITE)
+        {}
         bool operator == (const Var &other) const;
diff --git a/include/genn/genn/snippet.h b/include/genn/genn/snippet.h
index b25fe99050..072f11c572 100644
--- a/include/genn/genn/snippet.h
+++ b/include/genn/genn/snippet.h
@@ -77,18 +77,18 @@ class GENN_EXPORT Base
     //!
Additional input variables, row state variables and other things have a name, a type and an initial value struct ParamVal { - ParamVal(const std::string &n, const std::string &t, const std::string &v) : name(n), type(t), value(v) + ParamVal(const std::string &n, const Type::NumericBase *t, const std::string &v) : name(n), type(t), value(v) {} + ParamVal(const std::string &n, const Type::NumericBase *t, double v) : ParamVal(n, t, Utils::writePreciseString(v)) + {} + ParamVal(const std::string &n, const std::string &t, const std::string &v); ParamVal(const std::string &n, const std::string &t, double v) : ParamVal(n, t, Utils::writePreciseString(v)) {} - bool operator == (const ParamVal &other) const - { - return ((name == other.name) && (type == other.type) && (value == other.value)); - } + bool operator == (const ParamVal &other) const; const std::string name; - const std::string type; + const Type::NumericBase *type; const std::string value; }; diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 43ab05ae3a..3db8f7647f 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1675,7 +1675,7 @@ void Backend::genPresynapticUpdate(CodeStream &os, const ModelSpecMerged &modelM connSubs.applyCheckUnreplaced(value, "toeplitz diagonal build state var : merged" + std::to_string(sg.getIndex())); value = ensureFtype(value, modelMerged.getModel().getPrecision()); - os << d.type << " " << d.name << " = " << value << ";" << std::endl; + os << d.type->getResolvedName(sg.getTypeContext()) << " " << d.name << " = " << value << ";" << std::endl; } // Detect spike events or spikes and do the update diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 6b175d5b31..dfa2f29822 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -692,7 +692,7 @@ void SynapseConnectivityInitGroupMerged::genInitConnectivity(CodeStream &os, Sub popSubs.applyCheckUnreplaced(value, "initSparseConnectivity state var : merged" + std::to_string(getIndex())); //value = ensureFtype(value, ftype); - os << a.type << " " << a.name << " = " << value << ";" << std::endl; + os << a.type->getResolvedName(getTypeContext()) << " " << a.name << " = " << value << ";" << std::endl; } os << "while(true)"; { diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 53c86480c4..4d37101d57 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -249,7 +249,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C neuronSubs.applyCheckUnreplaced(value, "neuron additional input var : merged" + std::to_string(getIndex())); //value = ensureFtype(value, modelMerged.getModel().getPrecision()); - os << a.type << " " << a.name << " = " << value << ";" << std::endl; + os << a.type->getResolvedName(getTypeContext()) << " " << a.name << " = " << value << ";" << std::endl; } // Loop through incoming synapse groups diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 853e8e67d4..8b7766455a 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -251,7 +251,7 @@ void 
PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendB std::string value = a.value; popSubs.applyCheckUnreplaced(value, "proceduralSparseConnectivity row build state var : merged" + std::to_string(getIndex())); //value = ensureFtype(value, modelMerged.getModel().getPrecision()); - os << a.type << " " << a.name << " = " << value << ";" << std::endl; + os << a.type->getResolvedName(getTypeContext()) << " " << a.name << " = " << value << ";" << std::endl; } // Loop through synapses in row diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index cba22f803f..9afc6a7ea7 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -16,9 +16,6 @@ namespace GeNN::Models Base::Var::Var(const std::string &n, const std::string &t, VarAccess a) : name(n), type(Type::parseNumeric(t)), access(a) {} //---------------------------------------------------------------------------- -Base::Var::Var(const std::string &n, const std::string &t) : Var(n, t, VarAccess::READ_WRITE) -{} -//---------------------------------------------------------------------------- bool Base::Var::operator == (const Var &other) const { return (std::make_tuple(name, type->getName(), access) == std::make_tuple(other.name, other.type->getName(), other.access)); @@ -266,4 +263,4 @@ void updateHash(const WUVarReference &v, boost::uuids::detail::sha1 &hash) Utils::updateHash(v.getTransposeVarIndex(), hash); } } -} // namespace GeNN::Models \ No newline at end of file +} // namespace GeNN::Models diff --git a/src/genn/genn/snippet.cc b/src/genn/genn/snippet.cc index c19dd0bab0..95d4e5e6f7 100644 --- a/src/genn/genn/snippet.cc +++ b/src/genn/genn/snippet.cc @@ -23,6 +23,18 @@ bool Base::EGP::operator == (const EGP &other) const return ((name == other.name) && (type->getName() == other.type->getName())); } +//---------------------------------------------------------------------------- +// GeNN::Snippet::Base::ParamVal +//---------------------------------------------------------------------------- +Base::ParamVal::ParamVal(const std::string &n, const std::string &t, const std::string &v) : name(n), type(Type::parseNumeric(t)), value(v) +{ +} +//---------------------------------------------------------------------------- +bool Base::ParamVal::operator == (const ParamVal &other) const +{ + return ((name == other.name) && (type->getName() == other.type->getName()) && (value == other.value)); +} + //---------------------------------------------------------------------------- // GeNN::Snippet::Base //---------------------------------------------------------------------------- @@ -66,7 +78,7 @@ void updateHash(const Base::EGP &e, boost::uuids::detail::sha1 &hash) void updateHash(const Base::ParamVal &p, boost::uuids::detail::sha1 &hash) { Utils::updateHash(p.name, hash); - Utils::updateHash(p.type, hash); + Utils::updateHash(p.type->getName(), hash); Utils::updateHash(p.value, hash); } //---------------------------------------------------------------------------- From 3c1aa768746a130fe7b433ba6fb007192cf5d4c9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 18 Jan 2023 18:19:35 +0000 Subject: [PATCH 071/725] forgot to save... 
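
The missed file is the same mechanical change as the previous commit: Toeplitz diagonal build state variables now carry a type object rather than a type string, so their declarations are emitted via getResolvedName() against the group's type context. A rough standalone sketch of the idea, using plain std:: types rather than GeNN's real Type classes (the names below are illustrative only):

#include <iostream>
#include <string>
#include <unordered_map>

// Stand-in for GeNN's type context: maps alias names to concrete C type names
using TypeContext = std::unordered_map<std::string, std::string>;

struct NumericType
{
    std::string name;   // e.g. "scalar" or "float"

    // Resolve an alias like "scalar" to the concrete type chosen for the model
    std::string getResolvedName(const TypeContext &context) const
    {
        const auto resolved = context.find(name);
        return (resolved == context.end()) ? name : resolved->second;
    }
};

int main()
{
    const TypeContext context{{"scalar", "float"}};
    const NumericType diagonalBuildStateVarType{"scalar"};

    // Mirrors the generator pattern used in these commits:
    // os << d.type->getResolvedName(sg.getTypeContext()) << " " << d.name << " = " << value << ";"
    std::cout << diagonalBuildStateVarType.getResolvedName(context) << " dState = 0;" << std::endl;
}
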
--- src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index 8d9f2b4476..5dc11aabe2 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -871,7 +871,7 @@ void PostSpanToeplitz::genUpdate(CodeStream &os, const ModelSpecMerged &modelMer connSubs.applyCheckUnreplaced(value, "toeplitz diagonal build state var : merged" + std::to_string(sg.getIndex())); //value = ensureFtype(value, modelMerged.getModel().getPrecision()); - os << d.type << " " << d.name << " = " << value << ";" << std::endl; + os << d.type->getResolvedName(sg.getTypeContext()) << " " << d.name << " = " << value << ";" << std::endl; } os << "const unsigned int numSpikes = group->srcSpkCnt" << eventSuffix << "[" << sg.getPreSlot(batchSize) << "];" << std::endl; From 9bebf710c8fe019b7cfac469afe7b0e5fc3c02a4 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 19 Jan 2023 11:16:41 +0000 Subject: [PATCH 072/725] tidied up worst excesses of backend interface and implemented for CPU --- .../backends/single_threaded_cpu/backend.h | 175 ++++++++------ .../backends/single_threaded_cpu/optimiser.h | 8 +- .../genn/genn/code_generator/backendBase.h | 139 ++++++----- .../genn/code_generator/modelSpecMerged.h | 29 +-- include/genn/genn/type.h | 13 +- src/genn/backends/cuda/backend.cc | 12 +- src/genn/backends/opencl/backend.cc | 13 +- .../backends/single_threaded_cpu/backend.cc | 143 +++++------ .../backends/single_threaded_cpu/optimiser.cc | 7 +- .../customConnectivityUpdateGroupMerged.cc | 18 +- .../genn/code_generator/generateRunner.cc | 225 +++++++++--------- .../genn/code_generator/initGroupMerged.cc | 13 +- 12 files changed, 416 insertions(+), 379 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 066dfbe3e2..ccc0811bb7 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -33,129 +33,156 @@ struct Preferences : public PreferencesBase class BACKEND_EXPORT Backend : public BackendBase { public: - Backend(const std::string &scalarType, const Preferences &preferences) - : BackendBase(scalarType, preferences) + Backend(const Preferences &preferences) + : BackendBase(preferences) { } //-------------------------------------------------------------------------- // CodeGenerator::BackendBase virtuals //-------------------------------------------------------------------------- - virtual void genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; + virtual void genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; - virtual void genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; + virtual void genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; - virtual void genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; + virtual void genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; - virtual void genInit(CodeStream &os, const ModelSpecMerged 
&modelMerged, HostHandler preambleHandler) const override; + virtual void genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; - virtual size_t getSynapticMatrixRowStride(const SynapseGroupInternal &sg) const override; + virtual size_t getSynapticMatrixRowStride(const SynapseGroupInternal &sg) const final; - virtual void genDefinitionsPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const override; - virtual void genDefinitionsInternalPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const override; - virtual void genRunnerPreamble(CodeStream &os, const ModelSpecMerged &modelMerged, const MemAlloc &memAlloc) const override; - virtual void genAllocateMemPreamble(CodeStream &os, const ModelSpecMerged &modelMerged, const MemAlloc &memAlloc) const override; - virtual void genFreeMemPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const override; - virtual void genStepTimeFinalisePreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const override; + virtual void genDefinitionsPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final; + virtual void genDefinitionsInternalPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final; + virtual void genRunnerPreamble(CodeStream &os, const ModelSpecMerged &modelMerged, const MemAlloc &memAlloc) const final; + virtual void genAllocateMemPreamble(CodeStream &os, const ModelSpecMerged &modelMerged, const MemAlloc &memAlloc) const final; + virtual void genFreeMemPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final; + virtual void genStepTimeFinalisePreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final; - virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, const std::string &type, const std::string &name, VarLocation loc) const override; - virtual void genVariableImplementation(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc) const override; - virtual void genVariableAllocation(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const override; - virtual void genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const override; - - virtual void genExtraGlobalParamDefinition(CodeStream &definitions, CodeStream &definitionsInternal, const std::string &type, const std::string &name, VarLocation loc) const override; - virtual void genExtraGlobalParamImplementation(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc) const override; - virtual void genExtraGlobalParamAllocation(CodeStream &os, const std::string &type, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const override; - virtual void genExtraGlobalParamPush(CodeStream &os, const std::string &type, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const override; - virtual void genExtraGlobalParamPull(CodeStream &os, const std::string &type, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const override; - - ///! 
Generate code for pushing an updated EGP value into the merged group structure on 'device' - virtual void genMergedExtraGlobalParamPush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, - const std::string &groupIdx, const std::string &fieldName, - const std::string &egpName) const override; + //! Generate code to define a variable in the appropriate header file + virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc) const final; + + //! Generate code to instantiate a variable in the provided stream + virtual void genVariableInstantiation(CodeStream &os, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc) const final; + + //! Generate code to allocate variable with a size known at compile-time + virtual void genVariableAllocation(CodeStream &os, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, size_t count, MemAlloc &memAlloc) const final; + + //! Generate code to allocate variable with a size known at runtime + virtual void genVariableDynamicAllocation(CodeStream &os, + const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; + + //! Generate code to free a variable + virtual void genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const final; + + //! Generate code for pushing a variable with a size known at compile-time to the 'device' + virtual void genVariablePush(CodeStream &os, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, bool autoInitialized, size_t count) const final; + + //! Generate code for pulling a variable with a size known at compile-time from the 'device' + virtual void genVariablePull(CodeStream &os, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, size_t count) const final; + + //! Generate code for pushing a variable's value in the current timestep to the 'device' + virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, unsigned int batchSize) const final; + + //! Generate code for pulling a variable's value in the current timestep from the 'device' + virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, unsigned int batchSize) const final; + + //! Generate code for pushing a variable with a size known at runtime to the 'device' + virtual void genVariableDynamicPush(CodeStream &os, + const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; + + //! Generate code for pulling a variable with a size known at runtime from the 'device' + virtual void genVariableDynamicPull(CodeStream &os, + const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; + + //!
Generate code for pushing a new pointer to a dynamic variable into the merged group structure on 'device' + virtual void genMergedDynamicVariablePush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, + const std::string &groupIdx, const std::string &fieldName, + const std::string &egpName) const final; //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const override; + virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const final; //! When generating merged structures what type to use for simulation RNGs - virtual const Type::Base *getMergedGroupSimRNGType() const override; + virtual const Type::ValueBase *getMergedGroupSimRNGType() const final; - virtual void genPopVariableInit(CodeStream &os,const Substitutions &kernelSubs, Handler handler) const override; + virtual void genPopVariableInit(CodeStream &os,const Substitutions &kernelSubs, Handler handler) const final; virtual void genVariableInit(CodeStream &os, const std::string &count, const std::string &indexVarName, - const Substitutions &kernelSubs, Handler handler) const override; - virtual void genSparseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const override; - virtual void genDenseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const override; + const Substitutions &kernelSubs, Handler handler) const final; + virtual void genSparseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const final; + virtual void genDenseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const final; virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const final; virtual void genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUpdateInitGroupMerged &cu, const Substitutions &kernelSubs, Handler handler) const final; - virtual void genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const override; - virtual void genVariablePull(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, size_t count) const override; - virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, const std::string &type, - const std::string &name, VarLocation loc, unsigned int batchSize) const override; - virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, const std::string &type, - const std::string &name, VarLocation loc, unsigned int batchSize) const override; - - virtual void genCurrentTrueSpikePush(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize) const override; - virtual void genCurrentTrueSpikePull(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize) const override; - virtual void genCurrentSpikeLikeEventPush(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize) const override; - virtual void genCurrentSpikeLikeEventPull(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize) const override; - virtual void genGlobalDeviceRNG(CodeStream &definitions, CodeStream 
&definitionsInternal, CodeStream &runner, - CodeStream &allocations, CodeStream &free, MemAlloc &memAlloc) const override; + CodeStream &allocations, CodeStream &free, MemAlloc &memAlloc) const final; virtual void genPopulationRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const std::string &name, size_t count, MemAlloc &memAlloc) const override; + const std::string &name, size_t count, MemAlloc &memAlloc) const final; virtual void genTimer(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - CodeStream &stepTimeFinalise, const std::string &name, bool updateInStepTime) const override; + CodeStream &stepTimeFinalise, const std::string &name, bool updateInStepTime) const final; //! Generate code to return amount of free 'device' memory in bytes - virtual void genReturnFreeDeviceMemoryBytes(CodeStream &os) const override; + virtual void genReturnFreeDeviceMemoryBytes(CodeStream &os) const final; //! On backends which support it, generate a runtime assert - virtual void genAssert(CodeStream &os, const std::string &condition) const override; + virtual void genAssert(CodeStream &os, const std::string &condition) const final; - virtual void genMakefilePreamble(std::ostream &os) const override; - virtual void genMakefileLinkRule(std::ostream &os) const override; - virtual void genMakefileCompileRule(std::ostream &os) const override; + virtual void genMakefilePreamble(std::ostream &os) const final; + virtual void genMakefileLinkRule(std::ostream &os) const final; + virtual void genMakefileCompileRule(std::ostream &os) const final; - virtual void genMSBuildConfigProperties(std::ostream &os) const override; - virtual void genMSBuildImportProps(std::ostream &os) const override; - virtual void genMSBuildItemDefinitions(std::ostream &os) const override; - virtual void genMSBuildCompileModule(const std::string &moduleName, std::ostream &os) const override; - virtual void genMSBuildImportTarget(std::ostream &os) const override; + virtual void genMSBuildConfigProperties(std::ostream &os) const final; + virtual void genMSBuildImportProps(std::ostream &os) const final; + virtual void genMSBuildItemDefinitions(std::ostream &os) const final; + virtual void genMSBuildCompileModule(const std::string &moduleName, std::ostream &os) const final; + virtual void genMSBuildImportTarget(std::ostream &os) const final; - virtual std::string getDeviceVarPrefix() const override{ return ""; } + virtual std::string getDeviceVarPrefix() const final{ return ""; } //! Should 'scalar' variables be implemented on device or can host variables be used directly? - virtual bool isDeviceScalarRequired() const override { return false; } + virtual bool isDeviceScalarRequired() const final { return false; } - virtual bool isGlobalHostRNGRequired(const ModelSpecMerged &modelMerged) const override; - virtual bool isGlobalDeviceRNGRequired(const ModelSpecMerged &modelMerged) const override; - virtual bool isPopulationRNGRequired() const override { return false; } + virtual bool isGlobalHostRNGRequired(const ModelSpecMerged &modelMerged) const final; + virtual bool isGlobalDeviceRNGRequired(const ModelSpecMerged &modelMerged) const final; + virtual bool isPopulationRNGRequired() const final { return false; } //! Different backends seed RNGs in different ways. Does this one initialise population RNGS on device? 
- virtual bool isPopulationRNGInitialisedOnDevice() const override { return false; } + virtual bool isPopulationRNGInitialisedOnDevice() const final { return false; } - virtual bool isPostsynapticRemapRequired() const override{ return true; } + virtual bool isPostsynapticRemapRequired() const final{ return true; } //! Backends which support batch-parallelism might require an additional host reduction phase after reduction kernels - virtual bool isHostReductionRequired() const override { return false; } + virtual bool isHostReductionRequired() const final { return false; } //! How many bytes of memory does 'device' have - virtual size_t getDeviceMemoryBytes() const override{ return 0; } + virtual size_t getDeviceMemoryBytes() const final{ return 0; } //! Some backends will have additional small, fast, memory spaces for read-only data which might //! Be well-suited to storing merged group structs. This method returns the prefix required to //! Place arrays in these and their size in preferential order - virtual MemorySpaces getMergedGroupMemorySpaces(const ModelSpecMerged &modelMerged) const override; + virtual MemorySpaces getMergedGroupMemorySpaces(const ModelSpecMerged &modelMerged) const final; - virtual bool supportsNamespace() const override { return true; }; + virtual bool supportsNamespace() const final { return true; }; //! Get hash digest of this backends identification and the preferences it has been configured with - virtual boost::uuids::detail::sha1::digest_type getHashDigest() const override; + virtual boost::uuids::detail::sha1::digest_type getHashDigest() const final; private: //-------------------------------------------------------------------------- diff --git a/include/genn/backends/single_threaded_cpu/optimiser.h b/include/genn/backends/single_threaded_cpu/optimiser.h index 570642a3f5..0cf261a68b 100644 --- a/include/genn/backends/single_threaded_cpu/optimiser.h +++ b/include/genn/backends/single_threaded_cpu/optimiser.h @@ -9,12 +9,6 @@ // Single-threaded CPU backend includes #include "backend.h" -// Forward declarations -namespace GeNN -{ -class ModelSpecInternal; -} - namespace plog { class IAppender; @@ -25,7 +19,7 @@ class IAppender; //-------------------------------------------------------------------------- namespace GeNN::CodeGenerator::SingleThreadedCPU::Optimiser { -BACKEND_EXPORT Backend createBackend(const ModelSpecInternal &model, const filesystem::path &outputPath, +BACKEND_EXPORT Backend createBackend(const filesystem::path &outputPath, plog::Severity backendLevel, plog::IAppender *backendAppender, const Preferences &preferences); } // namespace GeNN::CodeGenerator::SingleThreadedCPU::Optimiser diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 2a9daf2d80..9a41d8bd14 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -245,30 +245,69 @@ class GENN_EXPORT BackendBase //! 
After all timestep logic is complete virtual void genStepTimeFinalisePreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const = 0; - virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, const Type::Base *type, const std::string &name, VarLocation loc) const = 0; - virtual void genVariableImplementation(CodeStream &os, const Type::Base *type, const std::string &name, VarLocation loc) const = 0; - virtual void genVariableAllocation(CodeStream &os, const Type::Base *type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const = 0; + //! Generate code to define a variable in the appropriate header file + virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc) const = 0; + + //! Generate code to instantiate a variable in the provided stream + virtual void genVariableInstantiation(CodeStream &os, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc) const = 0; + + //! Generate code to allocate variable with a size known at compile-time + virtual void genVariableAllocation(CodeStream &os, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, size_t count, MemAlloc &memAlloc) const = 0; + + //! Generate code to allocate variable with a size known at runtime + virtual void genVariableDynamicAllocation(CodeStream &os, + const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0; + + //! Generate code to free a variable virtual void genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const = 0; - virtual void genExtraGlobalParamDefinition(CodeStream &definitions, CodeStream &definitionsInternal, const Type::Base *type, const std::string &name, VarLocation loc) const = 0; - virtual void genExtraGlobalParamImplementation(CodeStream &os, const Type::Pointer *type, const std::string &name, VarLocation loc) const = 0; - virtual void genExtraGlobalParamAllocation(CodeStream &os, const Type::Pointer *type, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0; - virtual void genExtraGlobalParamPush(CodeStream &os, const Type::Pointer *type, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0; - virtual void genExtraGlobalParamPull(CodeStream &os, const Type::Pointer *type, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0; + //! Generate code for pushing a variable with a size known at compile-time to the 'device' + virtual void genVariablePush(CodeStream &os, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, bool autoInitialized, size_t count) const = 0; + + //! Generate code for pulling a variable with a size known at compile-time from the 'device' + virtual void genVariablePull(CodeStream &os, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, size_t count) const = 0; - //! 
Generate code for pushing an updated EGP value into the merged group structure on 'device' - virtual void genMergedExtraGlobalParamPush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, - const std::string &groupIdx, const std::string &fieldName, - const std::string &egpName) const = 0; + //! Generate code for pushing a variable's value in the current timestep to the 'device' + virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, unsigned int batchSize) const = 0; + + //! Generate code for pulling a variable's value in the current timestep from the 'device' + virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, unsigned int batchSize) const = 0; + + //! Generate code for pushing a variable with a size known at runtime to the 'device' + virtual void genVariableDynamicPush(CodeStream &os, + const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0; + + //! Generate code for pulling a variable with a size known at runtime from the 'device' + virtual void genVariableDynamicPull(CodeStream &os, + const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0; + + //! Generate code for pushing a new pointer to a dynamic variable into the merged group structure on 'device' + virtual void genMergedDynamicVariablePush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, + const std::string &groupIdx, const std::string &fieldName, + const std::string &egpName) const = 0; //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const = 0; //! When generating merged structures what type to use for simulation RNGs - virtual const Type::Base *getMergedGroupSimRNGType() const = 0; + virtual const Type::ValueBase *getMergedGroupSimRNGType() const = 0; virtual void genPopVariableInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const = 0; virtual void genVariableInit(CodeStream &os, const std::string &count, const std::string &indexVarName, @@ -278,32 +317,6 @@ class GENN_EXPORT BackendBase virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const = 0; virtual void genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUpdateInitGroupMerged &cu, const Substitutions &kernelSubs, Handler handler) const = 0; - //! Generate code for pushing a variable to the 'device' - virtual void genVariablePush(CodeStream &os, const Type::Base *type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const = 0; - - //! Generate code for pulling a variable from the 'device' - virtual void genVariablePull(CodeStream &os, const Type::Base *type, const std::string &name, VarLocation loc, size_t count) const = 0; - - //!
Generate code for pushing a variable's value in the current timestep to the 'device' - virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, const Type::Base *type, - const std::string &name, VarLocation loc, unsigned int batchSize) const = 0; - - //! Generate code for pulling a variable's value in the current timestep from the 'device' - virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, const Type::Base *type, - const std::string &name, VarLocation loc, unsigned int batchSize) const = 0; - - //! Generate code for pushing true spikes emitted by a neuron group in the current timestep to the 'device' - virtual void genCurrentTrueSpikePush(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize) const = 0; - - //! Generate code for pulling true spikes emitted by a neuron group in the current timestep from the 'device' - virtual void genCurrentTrueSpikePull(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize) const = 0; - - //! Generate code for pushing spike-like events emitted by a neuron group in the current timestep to the 'device' - virtual void genCurrentSpikeLikeEventPush(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize) const = 0; - - //! Generate code for pulling spike-like events emitted by a neuron group in the current timestep from the 'device' - virtual void genCurrentSpikeLikeEventPull(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize) const = 0; - //! Generate a single RNG instance /*! On single-threaded platforms this can be a standard RNG like M.T. but, on parallel platforms, it is likely to be a counter-based RNG */ virtual void genGlobalDeviceRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, @@ -404,55 +417,59 @@ class GENN_EXPORT BackendBase //-------------------------------------------------------------------------- //! Helper function to generate matching push and pull functions for a variable void genVariablePushPull(CodeStream &push, CodeStream &pull, - const Type::Base *type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, bool autoInitialized, size_t count) const { - genVariablePush(push, type, name, loc, autoInitialized, count); - genVariablePull(pull, type, name, loc, count); + genVariablePush(push, type, typeContext, name, loc, autoInitialized, count); + genVariablePull(pull, type, typeContext, name, loc, count); } //! Templated version of helper function to generate matching push and pull functions for //! a variable when type is known at compile time template<typename T> void genVariablePushPull(CodeStream &push, CodeStream &pull, - const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const + const Type::TypeContext &typeContext, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const { - genVariablePushPull(push, pull, T::getInstance(), name, loc, autoInitialized, count); + genVariablePushPull(push, pull, T::getInstance(), typeContext, name, loc, autoInitialized, count); } //!
Helper function to generate matching push and pull functions for the current state of a variable - void genCurrentVariablePushPull(CodeStream &push, CodeStream &pull, const NeuronGroupInternal &ng, const Type::Base *type, - const std::string &name, VarLocation loc, unsigned int batchSize) const + void genCurrentVariablePushPull(CodeStream &push, CodeStream &pull, const NeuronGroupInternal &ng, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, unsigned int batchSize) const { - genCurrentVariablePush(push, ng, type, name, loc, batchSize); - genCurrentVariablePull(pull, ng, type, name, loc, batchSize); + genCurrentVariablePush(push, ng, type, typeContext, name, loc, batchSize); + genCurrentVariablePull(pull, ng, type, typeContext, name, loc, batchSize); } //! Templated version of helper function to generate matching push and pull functions //! for the current state of a variable when type is known at compile time template<typename T> void genCurrentVariablePushPull(CodeStream &push, CodeStream &pull, const NeuronGroupInternal &ng, - const std::string &name, VarLocation loc, unsigned int batchSize) const + const std::string &name, const Type::TypeContext &typeContext, + VarLocation loc, unsigned int batchSize) const { - genCurrentVariablePushPull(push, pull, ng, T::getInstance(), name, loc, batchSize); + genCurrentVariablePushPull(push, pull, ng, T::getInstance(), typeContext, name, loc, batchSize); } - //! Helper function to generate matching definition, declaration, allocation and free code for an array + //! Helper function to generate matching definition, declaration, allocation and free code for a statically-sized array void genArray(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const Type::NumericBase *type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, size_t count, MemAlloc &memAlloc) const { - genVariableDefinition(definitions, definitionsInternal, type->getPointerType(), name, loc); - genVariableImplementation(runner, type->getPointerType(), name, loc); + genVariableDefinition(definitions, definitionsInternal, type, typeContext, name, loc); + genVariableInstantiation(runner, type, typeContext, name, loc); genVariableFree(free, name, loc); - genVariableAllocation(allocations, type, name, loc, count, memAlloc); + genVariableAllocation(allocations, type, typeContext, name, loc, count, memAlloc); } //! Templated version of helper function to generate matching definition, declaration, - //! allocation and free code for an array when type is known at compile-time + //! allocation and free code for a statically-sized array when type is known at compile-time template<typename T> void genArray(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const + const Type::TypeContext &typeContext, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const { - genArray(definitions, definitionsInternal, runner, allocations, free, T::getInstance(), name, loc, count, memAlloc); + genArray(definitions, definitionsInternal, runner, allocations, free, T::getInstance(), typeContext, name, loc, count, memAlloc); } //!
Get the prefix for accessing the address of 'scalar' variables diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 3215481e10..82cd1d9201 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -40,11 +40,11 @@ class GENN_EXPORT ModelSpecMerged //! Immutable structure for tracking fields of merged group structure containing EGPs struct EGPField { - EGPField(size_t m, const Type::Base *t, const std::string &f, bool h) + EGPField(size_t m, const Type::Pointer *t, const std::string &f, bool h) : mergedGroupIndex(m), type(t), fieldName(f), hostGroup(h) {} const size_t mergedGroupIndex; - const Type::Base *type; + const Type::Pointer *type; const std::string fieldName; const bool hostGroup; @@ -63,7 +63,7 @@ class GENN_EXPORT ModelSpecMerged //! Immutable structure for tracking where an extra global variable ends up after merging struct MergedEGP : public EGPField { - MergedEGP(size_t m, size_t g, const Type::Base *t, const std::string &f, bool h) + MergedEGP(size_t m, size_t g, const Type::Pointer *t, const std::string &f, bool h) : EGPField(m, t, f, h), groupIndex(g) {} const size_t groupIndex; @@ -265,21 +265,14 @@ class GENN_EXPORT ModelSpecMerged os << "// ------------------------------------------------------------------------" << std::endl; os << "// merged extra global parameter functions" << std::endl; os << "// ------------------------------------------------------------------------" << std::endl; - // Loop through resultant fields and generate push function for pointer extra global parameters + // Loop through resultant fields and generate function to push updated pointers into group merged for(auto f : mergedGroupFields) { - // If EGP is a pointer - // **NOTE** this is common to all references! 
- if(dynamic_cast<const Type::Pointer*>(f.type)) { - os << "void pushMerged" << T::name << f.mergedGroupIndex << f.fieldName << "ToDevice(unsigned int idx, " << backend.getMergedGroupFieldHostTypeName(f.type, m_TypeContext) << " value)"; - { - CodeStream::Scope b(os); - backend.genMergedExtraGlobalParamPush(os, T::name, f.mergedGroupIndex, "idx", f.fieldName, "value"); - } - os << std::endl; - } - else { - assert(false); + os << "void pushMerged" << T::name << f.mergedGroupIndex << f.fieldName << "ToDevice(unsigned int idx, " << backend.getMergedGroupFieldHostTypeName(f.type, m_TypeContext) << " value)"; + { + CodeStream::Scope b(os); + backend.genMergedDynamicVariablePush(os, T::name, f.mergedGroupIndex, "idx", f.fieldName, "value"); } + os << std::endl; } } } @@ -330,10 +323,12 @@ class GENN_EXPORT ModelSpecMerged const auto &g = mergedGroups.back().getGroups()[groupIndex]; // Add reference to this group's variable to data structure + const auto *pointerType = dynamic_cast<const Type::Pointer*>(std::get<0>(f)); + assert(pointerType); m_MergedEGPs[std::get<2>(f)(g, groupIndex)].emplace( std::piecewise_construct, std::forward_as_tuple(MergedGroup::name), - std::forward_as_tuple(i, groupIndex, std::get<0>(f), std::get<1>(f), host)); + std::forward_as_tuple(i, groupIndex, pointerType, std::get<1>(f), host)); } } } diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 6d3d19cf5b..0ac324fe9d 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -165,13 +165,22 @@ class Pointer : public Base const Base *m_ValueType; }; +//---------------------------------------------------------------------------- +// GeNN::Type::ValueBase +//---------------------------------------------------------------------------- +class ValueBase : public Base +{ +public: + ValueBase(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} +}; + //---------------------------------------------------------------------------- // GeNN::Type::NumericBase //---------------------------------------------------------------------------- -class NumericBase : public Base +class NumericBase : public ValueBase { public: - NumericBase(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} + NumericBase(Qualifier qualifiers = Qualifier{0}) : ValueBase(qualifiers){} //------------------------------------------------------------------------ // Declared virtuals diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 716222589d..d51d339fd9 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -47,18 +47,18 @@ const std::vector cudaDoublePrecisionFunctions //-------------------------------------------------------------------------- // CURandState //-------------------------------------------------------------------------- -class CURandState : public Type::Base +class CURandState : public Type::ValueBase { public: DECLARE_TYPE(CURandState); - CURandState(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} + CURandState(Type::Qualifier qualifiers = Type::Qualifier{0}) : ValueBase(qualifiers){} //------------------------------------------------------------------------ // Base overloads //------------------------------------------------------------------------ virtual std::string getName() const final{ return "curandState"; } - virtual Base *getQualifiedType(Qualifier qualifiers) const { return new CURandState(qualifiers); } + virtual Base *getQualifiedType(Type::Qualifier qualifiers) const { return new CURandState(qualifiers); } virtual size_t getSizeBytes() const final{
return 44; } }; IMPLEMENT_TYPE(CURandState); @@ -66,18 +66,18 @@ IMPLEMENT_TYPE(CURandState); //-------------------------------------------------------------------------- // CURandStatePhilox43210 //-------------------------------------------------------------------------- -class CURandStatePhilox43210 : public Type::Base +class CURandStatePhilox43210 : public Type::ValueBase { public: DECLARE_TYPE(CURandStatePhilox43210); - CURandStatePhilox43210(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} + CURandStatePhilox43210(Type::Qualifier qualifiers = Type::Qualifier{0}) : ValueBase(qualifiers){} //------------------------------------------------------------------------ // Base overloads //------------------------------------------------------------------------ virtual std::string getName() const final{ return "curandStatePhilox4_32_10_t"; } - virtual Base *getQualifiedType(Qualifier qualifiers) const { return new CURandStatePhilox43210(qualifiers); } + virtual Base *getQualifiedType(Type::Qualifier qualifiers) const { return new CURandStatePhilox43210(qualifiers); } virtual size_t getSizeBytes() const final{ return 64; } }; IMPLEMENT_TYPE(CURandStatePhilox43210); diff --git a/src/genn/backends/opencl/backend.cc b/src/genn/backends/opencl/backend.cc index 69cc932dbc..fe2ea6e8c9 100644 --- a/src/genn/backends/opencl/backend.cc +++ b/src/genn/backends/opencl/backend.cc @@ -17,6 +17,7 @@ // OpenCL backend includes #include "utils.h" +using namespace GeNN; using namespace GeNN::CodeGenerator; //-------------------------------------------------------------------------- @@ -45,18 +46,18 @@ const std::vector openclPhilloxFunctions = { //-------------------------------------------------------------------------- // CLRRNGLFSR113Stream //-------------------------------------------------------------------------- -class CLRRNGLFSR113Stream : public Type::Base +class CLRRNGLFSR113Stream : public Type::ValueBase { public: DECLARE_TYPE(CLRRNGLFSR113Stream); - CLRRNGLFSR113Stream(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} + CLRRNGLFSR113Stream(Type::Qualifier qualifiers = Type::Qualifier{0}) : ValueBase(qualifiers){} //------------------------------------------------------------------------ // Base overloads //------------------------------------------------------------------------ virtual std::string getName() const final{ return "clrngLfsr113Stream"; } - virtual Base *getQualifiedType(Qualifier qualifiers) const { return new CLRRNGLFSR113Stream(qualifiers); } + virtual Base *getQualifiedType(Type::Qualifier qualifiers) const { return new CLRRNGLFSR113Stream(qualifiers); } virtual size_t getSizeBytes() const final{ return 48; } }; IMPLEMENT_TYPE(CLRRNGLFSR113Stream); @@ -64,18 +65,18 @@ IMPLEMENT_TYPE(CLRRNGLFSR113Stream); //-------------------------------------------------------------------------- // CLRRNGPhilox432Stream //-------------------------------------------------------------------------- -class CLRRNGPhilox432Stream : public Type::Base +class CLRRNGPhilox432Stream : public Type::ValueBase { public: DECLARE_TYPE(CLRRNGPhilox432Stream); - CLRRNGPhilox432Stream(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} + CLRRNGPhilox432Stream(Type::Qualifier qualifiers = Type::Qualifier{0}) : ValueBase(qualifiers){} //------------------------------------------------------------------------ // Base overloads //------------------------------------------------------------------------ virtual std::string getName() const final{ return "clrngPhilox432Stream"; } - virtual Base
*getQualifiedType(Qualifier qualifiers) const { return new CLRRNGLFSR113Stream(qualifiers); } + virtual Base *getQualifiedType(Type::Qualifier qualifiers) const { return new CLRRNGPhilox432Stream(qualifiers); } virtual size_t getSizeBytes() const final{ return 132; } }; IMPLEMENT_TYPE(CLRRNGPhilox432Stream); diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 3db8f7647f..ef24c18ac4 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -149,7 +149,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged { CodeStream::Scope b(os); - Substitutions funcSubs(getFunctionTemplates(model.getPrecision())); + Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName())); funcSubs.addVarSubstitution("t", "t"); funcSubs.addVarSubstitution("batch", "0"); @@ -316,7 +316,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerge os << "void updateSynapses(" << model.getTimePrecision() << " t)"; { CodeStream::Scope b(os); - Substitutions funcSubs(getFunctionTemplates(model.getPrecision())); + Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName())); funcSubs.addVarSubstitution("t", "t"); funcSubs.addVarSubstitution("batch", "0"); @@ -525,7 +525,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged { CodeStream::Scope b(os); - Substitutions funcSubs(getFunctionTemplates(model.getPrecision())); + Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName())); funcSubs.addVarSubstitution("t", "t"); funcSubs.addVarSubstitution("batch", "0"); @@ -574,7 +574,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged // Loop through reduction targets and generate reduction for (const auto &r : reductionTargets) { - os << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; + os << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; } } @@ -807,7 +807,7 @@ void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHa os << "void initialize()"; { CodeStream::Scope b(os); - Substitutions funcSubs(getFunctionTemplates(model.getPrecision())); + Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName())); Timer t(os, "init", model.isTimingEnabled()); @@ -1062,7 +1062,7 @@ void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHa os << "void initializeSparse()"; { CodeStream::Scope b(os); - Substitutions funcSubs(getFunctionTemplates(model.getPrecision())); + Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName())); Timer t(os, "initSparse", model.isTimingEnabled()); @@ -1264,9 +1264,9 @@ void Backend::genRunnerPreamble(CodeStream &os, const ModelSpecMerg // If a global RNG is required, implement standard host distributions as recreating them each call is slow if(isGlobalHostRNGRequired(modelMerged)) { - os << "std::uniform_real_distribution<" << model.getPrecision() << "> standardUniformDistribution(" << model.scalarExpr(0.0) << ", " << model.scalarExpr(1.0) << ");" << std::endl; - os << "std::normal_distribution<" << model.getPrecision() << "> standardNormalDistribution(" << model.scalarExpr(0.0) << ", " << model.scalarExpr(1.0) << ");" << std::endl; - os << "std::exponential_distribution<" << model.getPrecision() << ">
standardExponentialDistribution(" << model.scalarExpr(1.0) << ");" << std::endl; + os << "std::uniform_real_distribution<" << model.getPrecision()->getName() << "> standardUniformDistribution(" << modelMerged.scalarExpr(0.0) << ", " << modelMerged.scalarExpr(1.0) << ");" << std::endl; + os << "std::normal_distribution<" << model.getPrecision()->getName() << "> standardNormalDistribution(" << modelMerged.scalarExpr(0.0) << ", " << modelMerged.scalarExpr(1.0) << ");" << std::endl; + os << "std::exponential_distribution<" << model.getPrecision()->getName() << "> standardExponentialDistribution(" << modelMerged.scalarExpr(1.0) << ");" << std::endl; os << std::endl; } os << std::endl; @@ -1284,76 +1284,99 @@ void Backend::genStepTimeFinalisePreamble(CodeStream &, const ModelSpecMerged &) { } //-------------------------------------------------------------------------- -void Backend::genVariableDefinition(CodeStream &definitions, CodeStream &, const std::string &type, const std::string &name, VarLocation) const +void Backend::genVariableDefinition(CodeStream &definitions, CodeStream &, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc) const { - definitions << "EXPORT_VAR " << type << " " << name << ";" << std::endl; + definitions << "EXPORT_VAR " << type->getPointerType()->getResolvedName(typeContext) << " " << name << ";" << std::endl; } //-------------------------------------------------------------------------- -void Backend::genVariableImplementation(CodeStream &os, const std::string &type, const std::string &name, VarLocation) const +void Backend::genVariableInstantiation(CodeStream &os, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc) const { - os << type << " " << name << ";" << std::endl; + os << type->getPointerType()->getResolvedName(typeContext) << " " << name << ";" << std::endl; } //-------------------------------------------------------------------------- -void Backend::genVariableAllocation(CodeStream &os, const std::string &type, const std::string &name, VarLocation, size_t count, MemAlloc &memAlloc) const +void Backend::genVariableAllocation(CodeStream &os, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, size_t count, MemAlloc &memAlloc) const { - os << name << " = new " << type << "[" << count << "];" << std::endl; + os << name << " = new " << type->getResolvedName(typeContext) << "[" << count << "];" << std::endl; - memAlloc += MemAlloc::host(count * getSize(type)); + memAlloc += MemAlloc::host(count * type->getSizeBytes(typeContext)); } //-------------------------------------------------------------------------- -void Backend::genVariableFree(CodeStream &os, const std::string &name, VarLocation) const +void Backend::genVariableDynamicAllocation(CodeStream &os, + const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, const std::string &countVarName, const std::string &prefix) const +{ + const auto *pointerType = dynamic_cast<const Type::Pointer*>(type); + if (pointerType) { + os << "*" << prefix << name << " = new " << pointerType->getValueType()->getResolvedName(typeContext) << "[" << countVarName << "];" << std::endl; + } + else { + os << prefix << name << " = new " << type->getResolvedName(typeContext) << "[" << countVarName << "];" << std::endl; + } +} +//-------------------------------------------------------------------------- +void
Backend::genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const { os << "delete[] " << name << ";" << std::endl; } //-------------------------------------------------------------------------- -void Backend::genExtraGlobalParamDefinition(CodeStream &definitions, CodeStream &, - const std::string &type, const std::string &name, VarLocation) const +void Backend::genVariablePush(CodeStream&, const Type::ValueBase*, const Type::TypeContext&, const std::string&, VarLocation, bool, size_t) const { - definitions << "EXPORT_VAR " << type << " " << name << ";" << std::endl; + assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- -void Backend::genExtraGlobalParamImplementation(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc) const +void Backend::genVariablePull(CodeStream&, const Type::ValueBase*, const Type::TypeContext&, const std::string&, VarLocation, size_t) const { - genVariableImplementation(os, type, name, loc); + assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- -void Backend::genExtraGlobalParamAllocation(CodeStream &os, const std::string &type, const std::string &name, - VarLocation, const std::string &countVarName, const std::string &prefix) const +void Backend::genCurrentVariablePush(CodeStream&, const NeuronGroupInternal&, + const Type::ValueBase*, const Type::TypeContext&, const std::string&, + VarLocation, unsigned int) const { - // Get underlying type - const std::string underlyingType = Utils::getUnderlyingType(type); - const bool pointerToPointer = Utils::isTypePointerToPointer(type); - - const std::string pointer = pointerToPointer ? ("*" + prefix + name) : (prefix + name); - - os << pointer << " = new " << underlyingType << "[" << countVarName << "];" << std::endl; + assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- -void Backend::genExtraGlobalParamPush(CodeStream &, const std::string &, const std::string &, - VarLocation, const std::string &, const std::string &) const +void Backend::genCurrentVariablePull(CodeStream&, const NeuronGroupInternal&, + const Type::ValueBase*, const Type::TypeContext&, const std::string&, + VarLocation, unsigned int) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- -void Backend::genExtraGlobalParamPull(CodeStream &, const std::string &, const std::string &, - VarLocation, const std::string &, const std::string &) const +void Backend::genVariableDynamicPush(CodeStream&, + const Type::Base*, const Type::TypeContext&, const std::string&, + VarLocation, const std::string&, const std::string&) const +{ + assert(!getPreferences().automaticCopy); +} +//-------------------------------------------------------------------------- +void Backend::genVariableDynamicPull(CodeStream&, + const Type::Base*, const Type::TypeContext&, const std::string&, + VarLocation, const std::string&, const std::string&) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- -void Backend::genMergedExtraGlobalParamPush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, - const std::string &groupIdx, const std::string &fieldName, - const std::string &egpName) const +void Backend::genMergedDynamicVariablePush(CodeStream &os, const std::string &suffix, size_t 
mergedGroupIdx, + const std::string &groupIdx, const std::string &fieldName, + const std::string &egpName) const { os << "merged" << suffix << "Group" << mergedGroupIdx << "[" << groupIdx << "]." << fieldName << " = " << egpName << ";" << std::endl; } + //-------------------------------------------------------------------------- std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const { return type->getResolvedName(context); } //-------------------------------------------------------------------------- -const Type::Base *Backend::getMergedGroupSimRNGType() const +const Type::ValueBase *Backend::getMergedGroupSimRNGType() const { assert(false); return nullptr; @@ -1415,46 +1438,6 @@ void Backend::genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUp genKernelIteration(os, cu, cu.getArchetype().getSynapseGroup()->getKernelSize().size(), kernelSubs, handler); } //-------------------------------------------------------------------------- -void Backend::genVariablePush(CodeStream&, const std::string&, const std::string&, VarLocation, bool, size_t) const -{ - assert(!getPreferences().automaticCopy); -} -//-------------------------------------------------------------------------- -void Backend::genVariablePull(CodeStream&, const std::string&, const std::string&, VarLocation, size_t) const -{ - assert(!getPreferences().automaticCopy); -} -//-------------------------------------------------------------------------- -void Backend::genCurrentVariablePush(CodeStream &, const NeuronGroupInternal &, const std::string &, const std::string &, VarLocation, unsigned int) const -{ - assert(!getPreferences().automaticCopy); -} -//-------------------------------------------------------------------------- -void Backend::genCurrentVariablePull(CodeStream &, const NeuronGroupInternal &, const std::string &, const std::string &, VarLocation, unsigned int) const -{ - assert(!getPreferences().automaticCopy); -} -//-------------------------------------------------------------------------- -void Backend::genCurrentTrueSpikePush(CodeStream&, const NeuronGroupInternal&, unsigned int) const -{ - assert(!getPreferences().automaticCopy); -} -//-------------------------------------------------------------------------- -void Backend::genCurrentTrueSpikePull(CodeStream&, const NeuronGroupInternal&, unsigned int) const -{ - assert(!getPreferences().automaticCopy); -} -//-------------------------------------------------------------------------- -void Backend::genCurrentSpikeLikeEventPush(CodeStream&, const NeuronGroupInternal&, unsigned int) const -{ - assert(!getPreferences().automaticCopy); -} -//-------------------------------------------------------------------------- -void Backend::genCurrentSpikeLikeEventPull(CodeStream&, const NeuronGroupInternal&, unsigned int) const -{ - assert(!getPreferences().automaticCopy); -} -//-------------------------------------------------------------------------- void Backend::genGlobalDeviceRNG(CodeStream&, CodeStream&, CodeStream&, CodeStream&, CodeStream&, MemAlloc&) const { assert(false); @@ -1673,7 +1656,7 @@ void Backend::genPresynapticUpdate(CodeStream &os, const ModelSpecMerged &modelM // Apply substitutions to value std::string value = d.value; connSubs.applyCheckUnreplaced(value, "toeplitz diagonal build state var : merged" + std::to_string(sg.getIndex())); - value = ensureFtype(value, modelMerged.getModel().getPrecision()); + //value = ensureFtype(value, modelMerged.getModel().getPrecision()); os << 
d.type->getResolvedName(sg.getTypeContext()) << " " << d.name << " = " << value << ";" << std::endl; } diff --git a/src/genn/backends/single_threaded_cpu/optimiser.cc b/src/genn/backends/single_threaded_cpu/optimiser.cc index 3f904d1049..03c8ca69db 100644 --- a/src/genn/backends/single_threaded_cpu/optimiser.cc +++ b/src/genn/backends/single_threaded_cpu/optimiser.cc @@ -1,14 +1,11 @@ #include "optimiser.h" -// GeNN includes -#include "modelSpecInternal.h" - //-------------------------------------------------------------------------- // GeNN::CodeGenerator::SingleThreadedCPU::Optimiser //-------------------------------------------------------------------------- namespace GeNN::CodeGenerator::SingleThreadedCPU::Optimiser { -Backend createBackend(const ModelSpecInternal &model, const filesystem::path&, +Backend createBackend(const filesystem::path&, plog::Severity backendLevel, plog::IAppender *backendAppender, const Preferences &preferences) { @@ -21,6 +18,6 @@ Backend createBackend(const ModelSpecInternal &model, const filesystem::path&, plog::get()->setMaxSeverity(backendLevel); } - return Backend(model.getPrecision(), preferences); + return Backend(preferences); } } // namespace GeNN::CodeGenerator::SingleThreadedCPU::Optimiser diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index cf1f1896c8..c2c5738204 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -487,10 +487,9 @@ void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase & for(const auto &egp : cm->getExtraGlobalParams()) { // Generate code to push this EGP with count specified by $(0) std::stringstream pushStream; - const auto *pointerType = egp.type->getPointerType(); CodeStream push(pushStream); - backend.genExtraGlobalParamPush(push, pointerType, egp.name, - VarLocation::HOST_DEVICE, "$(0)", "group->"); + backend.genVariableDynamicPush(push, egp.type, getTypeContext(), egp.name, + VarLocation::HOST_DEVICE, "$(0)", "group->"); // Add substitution subs.addFuncSubstitution("push" + egp.name + "ToDevice", 1, pushStream.str()); @@ -498,8 +497,8 @@ void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase & // Generate code to pull this EGP with count specified by $(0) std::stringstream pullStream; CodeStream pull(pullStream); - backend.genExtraGlobalParamPull(pull, pointerType, egp.name, - VarLocation::HOST_DEVICE, "$(0)", "group->"); + backend.genVariableDynamicPull(pull, egp.type, getTypeContext(), egp.name, + VarLocation::HOST_DEVICE, "$(0)", "group->"); // Add substitution subs.addFuncSubstitution("pull" + egp.name + "FromDevice", 1, pullStream.str()); @@ -528,11 +527,10 @@ void CustomConnectivityHostUpdateGroupMerged::addVarPushPullFuncSubs(const Backe const auto loc = std::invoke(getVarLocationFn, getArchetype(), v.name); if (loc & VarLocation::HOST) { // Generate code to push this variable - // **YUCK** these EGP functions should probably just be called dynamic or something std::stringstream pushStream; CodeStream push(pushStream); - backend.genExtraGlobalParamPush(push, v.type->getPointerType(), v.name, - loc, count, "group->"); + backend.genVariableDynamicPush(push, v.type, getTypeContext(), v.name, + loc, count, "group->"); // Add substitution subs.addFuncSubstitution("push" + v.name + "ToDevice", 0, pushStream.str()); @@ -541,8 +539,8 @@ void 
CustomConnectivityHostUpdateGroupMerged::addVarPushPullFuncSubs(const Backe // **YUCK** these EGP functions should probably just be called dynamic or something std::stringstream pullStream; CodeStream pull(pullStream); - backend.genExtraGlobalParamPull(pull, v.type->getPointerType(), v.name, - loc, count, "group->"); + backend.genVariableDynamicPull(pull, v.type, getTypeContext(), v.name, + loc, count, "group->"); // Add substitution subs.addFuncSubstitution("pull" + v.name + "FromDevice", 0, pullStream.str()); diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 37da7b7ff9..d05479e616 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -96,8 +96,8 @@ void genHostScalar(CodeStream &definitionsVar, CodeStream &runnerVarDecl, } //-------------------------------------------------------------------------- template -void genHostDeviceScalar(const BackendBase &backend, CodeStream &definitionsVar, CodeStream &definitionsInternalVar, - CodeStream &runnerVarDecl, CodeStream &runnerVarAlloc, CodeStream &runnerVarFree, +void genHostDeviceScalar(const ModelSpecMerged &modelMerged, const BackendBase &backend, CodeStream &definitionsVar, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerVarAlloc, CodeStream &runnerVarFree, const std::string &name, const std::string &hostValue, MemAlloc &mem) { // Generate a host scalar @@ -106,7 +106,7 @@ void genHostDeviceScalar(const BackendBase &backend, CodeStream &definitionsVar, // Generate a single-element array on device if(backend.isDeviceScalarRequired()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - T::getInstance(), name, VarLocation::DEVICE, 1, mem); + T::getInstance(), modelMerged.getTypeContext(), name, VarLocation::DEVICE, 1, mem); } } //-------------------------------------------------------------------------- @@ -253,9 +253,10 @@ void genStatePushPull(CodeStream &definitionsFunc, CodeStream &runnerPushFunc, C } } //------------------------------------------------------------------------- -void genVariable(const BackendBase &backend, CodeStream &definitionsVar, CodeStream &definitionsFunc, - CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - CodeStream &push, CodeStream &pull, const Type::NumericBase *type, const std::string &name, +void genVariable(const ModelSpecMerged &modelMerged, const BackendBase &backend, + CodeStream &definitionsVar, CodeStream &definitionsFunc, CodeStream &definitionsInternal, + CodeStream &runner, CodeStream &allocations, CodeStream &free, CodeStream &push, CodeStream &pull, + const Type::ValueBase *type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count, MemAlloc &mem, std::vector &statePushPullFunction) { @@ -263,12 +264,12 @@ void genVariable(const BackendBase &backend, CodeStream &definitionsVar, CodeStr genVarPushPullScope(definitionsFunc, push, pull, loc, backend.getPreferences().automaticCopy, name, statePushPullFunction, [&]() { - backend.genVariablePushPull(push, pull, type, name, loc, autoInitialized, count); + backend.genVariablePushPull(push, pull, type, modelMerged.getTypeContext(), name, loc, autoInitialized, count); }); // Generate variables backend.genArray(definitionsVar, definitionsInternal, runner, allocations, free, - type, name, loc, count, mem); + type, modelMerged.getTypeContext(), name, loc, count, mem); } 
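The hunks above thread a Type::TypeContext through every runner-generation helper (genHostDeviceScalar, genVariable, genArray and the push/pull calls), so type names are resolved against the model's precision at the point code is emitted rather than being baked into the backend at construction. Below is a minimal standalone sketch of that pattern; the TypeContext alias and resolve() helper are hypothetical stand-ins for illustration only, not GeNN's actual classes.

#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in: GeNN's real TypeContext and type classes are richer
using TypeContext = std::map<std::string, std::string>;  // alias -> concrete type name

// Resolve a possibly-aliased type name against the context, falling back to
// the name itself when no alias is registered
std::string resolve(const std::string &name, const TypeContext &context)
{
    const auto alias = context.find(name);
    return (alias == context.end()) ? name : alias->second;
}

int main()
{
    // Example context: a model built with single-precision scalars and double-precision time
    const TypeContext context{{"scalar", "float"}, {"timepoint", "double"}};

    // Runner-style declarations resolve their aliases at generation time
    std::cout << resolve("scalar", context) << " inSynPop[100];" << std::endl;  // float inSynPop[100];
    std::cout << resolve("timepoint", context) << " t;" << std::endl;           // double t;
    return 0;
}

Passing the context per call, rather than storing a precision in the backend (as the removed Backend(model.getPrecision(), ...) constructors did), is what lets the same backend object serve models with different precisions.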
//------------------------------------------------------------------------- void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase &backend, CodeStream &definitionsVar, @@ -276,9 +277,8 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase & CodeStream &extraGlobalParam, const Type::NumericBase *type, const std::string &name, bool apiRequired, VarLocation loc) { // Generate variables - const auto *pointerType = type->getPointerType(); - backend.genExtraGlobalParamDefinition(definitionsVar, definitionsInternalVar, pointerType, name, loc); - backend.genExtraGlobalParamImplementation(runner, pointerType, name, loc); + backend.genVariableDefinition(definitionsVar, definitionsInternalVar, type, modelMerged.getTypeContext(), name, loc); + backend.genVariableInstantiation(runner, type, modelMerged.getTypeContext(), name, loc); // If API is required if(apiRequired) { @@ -290,9 +290,10 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase & extraGlobalParam << "void allocate" << name << "(unsigned int count)"; { CodeStream::Scope a(extraGlobalParam); - backend.genExtraGlobalParamAllocation(extraGlobalParam, pointerType, name, loc); + backend.genVariableDynamicAllocation(extraGlobalParam, type, modelMerged.getTypeContext(), name, loc); // Loop through destinations in merged structures, the device EGP needs to be copied to + // **TODO** rename to dynamic if(modelMerged.anyMergedEGPDestinations(backend.getDeviceVarPrefix() + name)) { const auto &mergedDestinations = modelMerged.getMergedEGPDestinations(backend.getDeviceVarPrefix() + name); for (const auto &v : mergedDestinations) { @@ -352,7 +353,7 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase & extraGlobalParam << "void push" << name << "ToDevice(unsigned int count)"; { CodeStream::Scope a(extraGlobalParam); - backend.genExtraGlobalParamPush(extraGlobalParam, pointerType, name, loc); + backend.genVariableDynamicPush(extraGlobalParam, type, modelMerged.getTypeContext(), name, loc); } if(backend.getPreferences().generateExtraGlobalParamPull) { @@ -363,7 +364,7 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase & extraGlobalParam << "void pull" << name << "FromDevice(unsigned int count)"; { CodeGenerator::CodeStream::Scope a(extraGlobalParam); - backend.genExtraGlobalParamPull(extraGlobalParam, pointerType, name, loc); + backend.genVariableDynamicPull(extraGlobalParam, type, modelMerged.getTypeContext(), name, loc); } } } @@ -414,7 +415,7 @@ void genRunnerVars(const ModelSpecMerged &modelMerged, const BackendBase &backen for(const auto &var : varAdaptor.getVars()) { const auto *varInitSnippet = varAdaptor.getVarInitialisers().at(var.name).getSnippet(); const bool autoInitialized = !varInitSnippet->getCode().empty(); - genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + genVariable(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, var.type, var.name + group.getName(), varAdaptor.getVarLocation(var.name), autoInitialized, getSizeFn(group, var), mem, statePushPullFunctions); @@ -438,7 +439,7 @@ void genRunnerFusedVars(const ModelSpecMerged &modelMerged, const BackendBase &b const V varAdaptor(group); for(const auto &var : varAdaptor.getVars()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, 
runnerVarFree, - var.type, var.name + varAdaptor.getFusedVarSuffix(), varAdaptor.getVarLocation(var.name), + var.type, modelMerged.getTypeContext(), var.name + varAdaptor.getFusedVarSuffix(), varAdaptor.getVarLocation(var.name), getSizeFn(group, var), mem); // Loop through EGPs required to initialize variable @@ -452,8 +453,9 @@ void genRunnerFusedVars(const ModelSpecMerged &modelMerged, const BackendBase &b } //------------------------------------------------------------------------- template -void genRunnerFusedVarPushPull(const BackendBase &backend, CodeStream &definitionsFunc, CodeStream &runnerPushFunc, CodeStream &runnerPullFunc, - const G &group, std::vector &groupStatePushPullFunctions, S getSizeFn) +void genRunnerFusedVarPushPull(const ModelSpecMerged &modelMerged, const BackendBase &backend, CodeStream &definitionsFunc, + CodeStream &runnerPushFunc, CodeStream &runnerPullFunc, const G &group, + std::vector &groupStatePushPullFunctions, S getSizeFn) { // Loop through variables const V varAdaptor(group); @@ -463,7 +465,8 @@ void genRunnerFusedVarPushPull(const BackendBase &backend, CodeStream &definitio backend.getPreferences().automaticCopy, var.name + group.getName(), groupStatePushPullFunctions, [&]() { - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, var.type, var.name + group.getName(), + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + var.type, modelMerged.getTypeContext(), var.name + group.getName(), varAdaptor.getVarLocation(var.name), autoInitialized, getSizeFn(group, var)); }); } @@ -902,9 +905,11 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const size_t numSpikeCounts = n.second.isTrueSpikeRequired() ? (batchSize * n.second.getNumDelaySlots()) : batchSize; const size_t numSpikes = n.second.isTrueSpikeRequired() ? 
numNeuronDelaySlots : (batchSize * n.second.getNumNeurons()); backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "glbSpkCnt" + n.first, n.second.getSpikeLocation(), numSpikeCounts, mem); + modelMerged.getTypeContext(), "glbSpkCnt" + n.first, + n.second.getSpikeLocation(), numSpikeCounts, mem); backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "glbSpk" + n.first, n.second.getSpikeLocation(), numSpikes, mem); + modelMerged.getTypeContext(), "glbSpk" + n.first, + n.second.getSpikeLocation(), numSpikes, mem); // True spike push and pull functions genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, n.second.getSpikeLocation(), @@ -912,18 +917,11 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - "glbSpkCnt" + n.first, n.second.getSpikeLocation(), true, numSpikeCounts); + modelMerged.getTypeContext(), "glbSpkCnt" + n.first, + n.second.getSpikeLocation(), true, numSpikeCounts); backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - "glbSpk" + n.first, n.second.getSpikeLocation(), true, numSpikes); - }); - - // Current true spike push and pull functions - genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, n.second.getSpikeLocation(), - backend.getPreferences().automaticCopy, n.first + "CurrentSpikes", currentSpikePullFunctions, - [&]() - { - backend.genCurrentTrueSpikePush(runnerPushFunc, n.second, batchSize); - backend.genCurrentTrueSpikePull(runnerPullFunc, n.second, batchSize); + modelMerged.getTypeContext(), "glbSpk" + n.first, + n.second.getSpikeLocation(), true, numSpikes); }); // Current true spike getter functions @@ -931,10 +929,14 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // If spike recording is enabled, define and declare variables and add free if(n.second.isSpikeRecordingEnabled()) { - const auto *uint32Pointer = Type::Uint32::getInstance()->getPointerType(); - backend.genVariableDefinition(definitionsVar, definitionsInternalVar, uint32Pointer, "recordSpk" + n.first, VarLocation::HOST_DEVICE); - backend.genVariableImplementation(runnerVarDecl, uint32Pointer, "recordSpk" + n.first, VarLocation::HOST_DEVICE); - backend.genVariableFree(runnerVarFree, "recordSpk" + n.first, VarLocation::HOST_DEVICE); + backend.genVariableDefinition(definitionsVar, definitionsInternalVar, + Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpk" + n.first, + VarLocation::HOST_DEVICE); + backend.genVariableInstantiation(runnerVarDecl, + Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpk" + n.first, + VarLocation::HOST_DEVICE); + backend.genVariableFree(runnerVarFree, + "recordSpk" + n.first, VarLocation::HOST_DEVICE); } // If neuron group needs to emit spike-like events @@ -946,10 +948,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Spike-like event variables backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "glbSpkCntEvnt" + n.first, n.second.getSpikeEventLocation(), + modelMerged.getTypeContext(), "glbSpkCntEvnt" + n.first, n.second.getSpikeEventLocation(), batchSize * n.second.getNumDelaySlots(), mem); backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "glbSpkEvnt" + n.first, n.second.getSpikeEventLocation(), + 
modelMerged.getTypeContext(), "glbSpkEvnt" + n.first, n.second.getSpikeEventLocation(), numNeuronDelaySlots, mem); // Spike-like event push and pull functions @@ -957,104 +959,100 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, backend.getPreferences().automaticCopy, n.first + "SpikeEvents", [&]() { - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, "glbSpkCntEvnt" + n.first, + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + modelMerged.getTypeContext(), "glbSpkCntEvnt" + n.first, n.second.getSpikeLocation(), true, batchSize * n.second.getNumDelaySlots()); - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, "glbSpkEvnt" + n.first, + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + modelMerged.getTypeContext(), "glbSpkEvnt" + n.first, n.second.getSpikeLocation(), true, numNeuronDelaySlots); }); - // Current spike-like event push and pull functions - genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, n.second.getSpikeEventLocation(), - backend.getPreferences().automaticCopy, n.first + "CurrentSpikeEvents", currentSpikeEventPullFunctions, - [&]() - { - backend.genCurrentSpikeLikeEventPush(runnerPushFunc, n.second, batchSize); - backend.genCurrentSpikeLikeEventPull(runnerPullFunc, n.second, batchSize); - }); - // Current true spike getter functions genSpikeGetters(definitionsFunc, runnerGetterFunc, n.second, false, batchSize); // If spike recording is enabled, define and declare variables and add free if(n.second.isSpikeEventRecordingEnabled()) { - const auto *uint32Pointer = Type::Uint32::getInstance()->getPointerType(); - backend.genVariableDefinition(definitionsVar, definitionsInternalVar, uint32Pointer, "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE); - backend.genVariableImplementation(runnerVarDecl, uint32Pointer, "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE); + backend.genVariableDefinition(definitionsVar, definitionsInternalVar, + Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpkEvent" + n.first, + VarLocation::HOST_DEVICE); + backend.genVariableInstantiation(runnerVarDecl, + Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpkEvent" + n.first, + VarLocation::HOST_DEVICE); backend.genVariableFree(runnerVarFree, "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE); } } // If neuron group has axonal delays if (n.second.isDelayRequired()) { - genHostDeviceScalar(backend, definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + genHostDeviceScalar(modelMerged, backend, definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, "spkQuePtr" + n.first, "0", mem); } // If neuron group needs to record its spike times if (n.second.isSpikeTimeRequired()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - model.getTimePrecision(), "sT" + n.first, n.second.getSpikeTimeLocation(), - numNeuronDelaySlots, mem); + model.getTimePrecision(), modelMerged.getTypeContext(), "sT" + n.first, + n.second.getSpikeTimeLocation(), numNeuronDelaySlots, mem); // Generate push and pull functions genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, n.second.getSpikeTimeLocation(), backend.getPreferences().automaticCopy, n.first + "SpikeTimes", [&]() { - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, model.getTimePrecision(), - "sT" + n.first, n.second.getSpikeTimeLocation(), true, - numNeuronDelaySlots); + 
backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + model.getTimePrecision(), modelMerged.getTypeContext(), "sT" + n.first, + n.second.getSpikeTimeLocation(), true, numNeuronDelaySlots); }); } // If neuron group needs to record its previous spike times if (n.second.isPrevSpikeTimeRequired()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - model.getTimePrecision(), "prevST" + n.first, n.second.getPrevSpikeTimeLocation(), - numNeuronDelaySlots, mem); + model.getTimePrecision(), modelMerged.getTypeContext(), "prevST" + n.first, + n.second.getPrevSpikeTimeLocation(), numNeuronDelaySlots, mem); // Generate push and pull functions genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, n.second.getPrevSpikeTimeLocation(), backend.getPreferences().automaticCopy, n.first + "PreviousSpikeTimes", [&]() { - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, model.getTimePrecision(), - "prevST" + n.first, n.second.getPrevSpikeTimeLocation(), true, - numNeuronDelaySlots); + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + model.getTimePrecision(), modelMerged.getTypeContext(), "prevST" + n.first, + n.second.getPrevSpikeTimeLocation(), true, numNeuronDelaySlots); }); } // If neuron group needs to record its spike-like-event times if (n.second.isSpikeEventTimeRequired()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - model.getTimePrecision(), "seT" + n.first, n.second.getSpikeEventTimeLocation(), - numNeuronDelaySlots, mem); + model.getTimePrecision(), modelMerged.getTypeContext(), "seT" + n.first, + n.second.getSpikeEventTimeLocation(), numNeuronDelaySlots, mem); // Generate push and pull functions genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, n.second.getSpikeTimeLocation(), backend.getPreferences().automaticCopy, n.first + "SpikeEventTimes", [&]() { - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, model.getTimePrecision(), - "seT" + n.first, n.second.getSpikeEventTimeLocation(), true, - numNeuronDelaySlots); + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + model.getTimePrecision(), modelMerged.getTypeContext(), "seT" + n.first, + n.second.getSpikeEventTimeLocation(), true, numNeuronDelaySlots); }); } // If neuron group needs to record its previous spike-like-event times if (n.second.isPrevSpikeEventTimeRequired()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - model.getTimePrecision(), "prevSET" + n.first, n.second.getPrevSpikeEventTimeLocation(), - numNeuronDelaySlots, mem); + model.getTimePrecision(), modelMerged.getTypeContext(), "prevSET" + n.first, + n.second.getPrevSpikeEventTimeLocation(), numNeuronDelaySlots, mem); // Generate push and pull functions genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, n.second.getPrevSpikeEventTimeLocation(), backend.getPreferences().automaticCopy, n.first + "PreviousSpikeEventTimes", [&]() { - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, model.getTimePrecision(), - "prevSET" + n.first, n.second.getPrevSpikeEventTimeLocation(), true, - numNeuronDelaySlots); + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + model.getTimePrecision(), modelMerged.getTypeContext(), "prevSET" + n.first, + n.second.getPrevSpikeEventTimeLocation(), true, numNeuronDelaySlots); }); } @@ -1073,8 +1071,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const 
filesystem::path &outputPath, const unsigned int numElements = getNumVarElements(var.access, n.second.getNumNeurons()); const size_t count = n.second.isVarQueueRequired(var.name) ? numCopies * numElements * n.second.getNumDelaySlots() : numCopies * n.second.getNumNeurons(); const bool autoInitialized = !varInitSnippet->getCode().empty(); - genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - runnerPushFunc, runnerPullFunc, var.type, var.name + n.first, + genVariable(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, + runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, var.type, var.name + n.first, n.second.getVarLocation(var.name), autoInitialized, count, mem, neuronStatePushPullFunctions); // Current variable push and pull functions @@ -1082,8 +1080,9 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, backend.getPreferences().automaticCopy, "Current" + var.name + n.first, [&]() { - backend.genCurrentVariablePushPull(runnerPushFunc, runnerPullFunc, n.second, var.type, - var.name, n.second.getVarLocation(var.name), numCopies); + backend.genCurrentVariablePushPull(runnerPushFunc, runnerPullFunc, n.second, + var.type, modelMerged.getTypeContext(), var.name, + n.second.getVarLocation(var.name), numCopies); }); // Write getter to get access to correct pointer @@ -1241,14 +1240,14 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Loop through merged postsynaptic models of incoming synaptic populations for(const auto *sg : n.second.getFusedPSMInSyn()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - model.getPrecision(), "inSyn" + sg->getFusedPSVarSuffix(), sg->getInSynLocation(), - sg->getTrgNeuronGroup()->getNumNeurons() * batchSize, mem); + model.getPrecision(), modelMerged.getTypeContext(), "inSyn" + sg->getFusedPSVarSuffix(), + sg->getInSynLocation(), sg->getTrgNeuronGroup()->getNumNeurons() * batchSize, mem); if (sg->isDendriticDelayRequired()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - model.getPrecision(), "denDelay" + sg->getFusedPSVarSuffix(), sg->getDendriticDelayLocation(), - (size_t)sg->getMaxDendriticDelayTimesteps() * (size_t)sg->getTrgNeuronGroup()->getNumNeurons() * batchSize, mem); - genHostDeviceScalar(backend, definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + model.getPrecision(), modelMerged.getTypeContext(), "denDelay" + sg->getFusedPSVarSuffix(), + sg->getDendriticDelayLocation(), (size_t)sg->getMaxDendriticDelayTimesteps() * (size_t)sg->getTrgNeuronGroup()->getNumNeurons() * batchSize, mem); + genHostDeviceScalar(modelMerged, backend, definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, "denDelayPtr" + sg->getFusedPSVarSuffix(), "0", mem); } @@ -1262,8 +1261,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Loop through fused outgoing synapse populations with weightupdate models that have presynaptic output for(const auto *sg : n.second.getFusedPreOutputOutSyn()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - model.getPrecision(), "revInSyn" + sg->getFusedPreOutputSuffix(), sg->getInSynLocation(), - sg->getSrcNeuronGroup()->getNumNeurons() * batchSize, mem); + model.getPrecision(), 
modelMerged.getTypeContext(), "revInSyn" + sg->getFusedPreOutputSuffix(), + sg->getInSynLocation(), sg->getSrcNeuronGroup()->getNumNeurons() * batchSize, mem); } // Loop through merged postsynaptic weight updates of incoming synaptic populations @@ -1303,7 +1302,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, if(s.second.getMatrixType() & SynapseMatrixConnectivity::BITMASK) { const size_t gpSize = ceilDivide((size_t)s.second.getSrcNeuronGroup()->getNumNeurons() * backend.getSynapticMatrixRowStride(s.second), 32); backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "gp" + s.second.getName(), s.second.getSparseConnectivityLocation(), gpSize, mem); + modelMerged.getTypeContext(), "gp" + s.second.getName(), + s.second.getSparseConnectivityLocation(), gpSize, mem); // Generate push and pull functions for bitmask genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, s.second.getSparseConnectivityLocation(), @@ -1311,7 +1311,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { // Row lengths - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, "gp" + s.second.getName(), + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + modelMerged.getTypeContext(), "gp" + s.second.getName(), s.second.getSparseConnectivityLocation(), autoInitialized, gpSize); }); } @@ -1325,11 +1326,13 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Row lengths backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "rowLength" + s.second.getName(), varLoc, s.second.getSrcNeuronGroup()->getNumNeurons(), mem); + modelMerged.getTypeContext(), "rowLength" + s.second.getName(), + varLoc, s.second.getSrcNeuronGroup()->getNumNeurons(), mem); // Target indices backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - s.second.getSparseIndType(), "ind" + s.second.getName(), varLoc, size, mem); + s.second.getSparseIndType(), modelMerged.getTypeContext(), "ind" + s.second.getName(), + varLoc, size, mem); // **TODO** remap is not always required if(backend.isPostsynapticRemapRequired() && !s.second.getWUModel()->getLearnPostCode().empty()) { @@ -1337,11 +1340,13 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Allocate column lengths backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "colLength" + s.second.getName(), VarLocation::DEVICE, s.second.getTrgNeuronGroup()->getNumNeurons(), mem); + modelMerged.getTypeContext(), "colLength" + s.second.getName(), + VarLocation::DEVICE, s.second.getTrgNeuronGroup()->getNumNeurons(), mem); // Allocate remap backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "remap" + s.second.getName(), VarLocation::DEVICE, postSize, mem); + modelMerged.getTypeContext(), "remap" + s.second.getName(), + VarLocation::DEVICE, postSize, mem); } // Generate push and pull functions for sparse connectivity @@ -1350,11 +1355,13 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { // Row lengths - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, "rowLength" + s.second.getName(), + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + modelMerged.getTypeContext(), "rowLength" + s.second.getName(), 
s.second.getSparseConnectivityLocation(), autoInitialized, s.second.getSrcNeuronGroup()->getNumNeurons()); // Target indices - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, s.second.getSparseIndType(), "ind" + s.second.getName(), + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + s.second.getSparseIndType(), modelMerged.getTypeContext(), "ind" + s.second.getName(), s.second.getSparseConnectivityLocation(), autoInitialized, size); }); } @@ -1379,7 +1386,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const bool autoInitialized = !varInitSnippet->getCode().empty(); if(individualWeights) { const size_t size = (size_t)s.second.getSrcNeuronGroup()->getNumNeurons() * (size_t)backend.getSynapticMatrixRowStride(s.second); - genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + genVariable(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, wuVar.type, wuVar.name + s.second.getName(), s.second.getWUVarLocation(wuVar.name), autoInitialized, size * getNumVarCopies(wuVar.access, batchSize), mem, synapseGroupStatePushPullFunctions); } @@ -1388,7 +1395,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const size_t size = s.second.getKernelSizeFlattened() * getNumVarCopies(wuVar.access, batchSize); // Generate variable - genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + genVariable(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, wuVar.type, wuVar.name + s.second.getName(), s.second.getWUVarLocation(wuVar.name), autoInitialized, size, mem, synapseGroupStatePushPullFunctions); } @@ -1411,11 +1418,12 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, backend.getPreferences().automaticCopy, "inSyn" + s.second.getName(), synapseGroupStatePushPullFunctions, [&]() { - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, model.getPrecision(), "inSyn" + s.second.getName(), s.second.getInSynLocation(), - true, s.second.getTrgNeuronGroup()->getNumNeurons() * batchSize); + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + model.getPrecision(), modelMerged.getTypeContext(), "inSyn" + s.second.getName(), + s.second.getInSynLocation(), true, s.second.getTrgNeuronGroup()->getNumNeurons() * batchSize); }); - genRunnerFusedVarPushPull(backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, + genRunnerFusedVarPushPull(modelMerged, backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, [batchSize](const SynapseGroupInternal &sg, const Models::Base::Var &var) { return getVarSize(var.access, sg.getTrgNeuronGroup()->getNumNeurons(), batchSize); @@ -1426,7 +1434,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // **NOTE** we generated initialisation and declaration code earlier - here we just generate push and pull as we want this per-synapse group if(!s.second.isWUPreModelFused()) { const unsigned int preDelaySlots = (s.second.getDelaySteps() == NO_DELAY) ? 
1 : s.second.getSrcNeuronGroup()->getNumDelaySlots(); - genRunnerFusedVarPushPull(backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, + genRunnerFusedVarPushPull(modelMerged, backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, [batchSize, preDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) { return getVarSize(var.access, sg.getSrcNeuronGroup()->getNumNeurons(), batchSize, preDelaySlots); @@ -1438,7 +1446,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // **NOTE** we generated initialisation and declaration code earlier - here we just generate push and pull as we want this per-synapse group if(!s.second.isWUPostModelFused()) { const unsigned int postDelaySlots = (s.second.getBackPropDelaySteps() == NO_DELAY) ? 1 : s.second.getTrgNeuronGroup()->getNumDelaySlots(); - genRunnerFusedVarPushPull(backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, + genRunnerFusedVarPushPull(modelMerged, backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, [batchSize, postDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) { return getVarSize(var.access, sg.getTrgNeuronGroup()->getNumNeurons(), batchSize, postDelaySlots); @@ -1590,10 +1598,11 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, } // Allocate spike array if required - // **YUCK** maybe this should be renamed genDynamicArray if(n.second.isSpikeRecordingEnabled()) { CodeStream::Scope b(runner); - backend.genExtraGlobalParamAllocation(runner, Type::Uint32::getInstance()->getPointerType(), "recordSpk" + n.first, VarLocation::HOST_DEVICE, "numWords"); + backend.genVariableDynamicAllocation(runner, + Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpk" + n.first, + VarLocation::HOST_DEVICE, "numWords"); // Get destinations in merged structures, this EGP // needs to be copied to and call push function @@ -1605,10 +1614,11 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, } // Allocate spike event array if required - // **YUCK** maybe this should be renamed genDynamicArray if(n.second.isSpikeEventRecordingEnabled()) { CodeStream::Scope b(runner); - backend.genExtraGlobalParamAllocation(runner, Type::Uint32::getInstance()->getPointerType(), "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE, "numWords"); + backend.genVariableDynamicAllocation(runner, + Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpkEvent" + n.first, + VarLocation::HOST_DEVICE, "numWords"); // Get destinations in merged structures, this EGP // needs to be copied to and call push function @@ -1644,16 +1654,19 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, } // Pull spike array if required - // **YUCK** maybe this should be renamed pullDynamicArray if(n.second.isSpikeRecordingEnabled()) { CodeStream::Scope b(runner); - backend.genExtraGlobalParamPull(runner, Type::Uint32::getInstance()->getPointerType(), "recordSpk" + n.first, VarLocation::HOST_DEVICE, "numWords"); + backend.genVariableDynamicPull(runner, + Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpk" + n.first, + VarLocation::HOST_DEVICE, "numWords"); } // Pull spike event array if required // **YUCK** maybe this should be renamed pullDynamicArray
if(n.second.isSpikeEventRecordingEnabled()) { CodeStream::Scope b(runner); - backend.genExtraGlobalParamPull(runner, Type::Uint32::getInstance()->getPointerType(), "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE, "numWords"); + backend.genVariableDynamicPull(runner, + Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpkEvent" + n.first, + VarLocation::HOST_DEVICE, "numWords"); } } } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index dfa2f29822..99c7d5b605 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -805,11 +805,13 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac const auto loc = getArchetype().getSparseConnectivityExtraGlobalParamLocation(egp.name); if(loc & VarLocation::HOST) { // Generate code to allocate this EGP with count specified by $(0) + // **NOTE** we generate these with a pointer type as the fields are pointer to pointer std::stringstream allocStream; - const auto *pointerToPointerToEGP = egp.type->getPointerType()->getPointerType(); + const auto *pointerToEGP = egp.type->getPointerType(); CodeGenerator::CodeStream alloc(allocStream); - backend.genExtraGlobalParamAllocation(alloc, pointerToPointerToEGP, egp.name, - loc, "$(0)", "group->"); + backend.genVariableDynamicAllocation(alloc, + pointerToEGP, getTypeContext(), egp.name, + loc, "$(0)", "group->"); // Add substitution subs.addFuncSubstitution("allocate" + egp.name, 1, allocStream.str()); @@ -817,8 +819,9 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac // Generate code to push this EGP with count specified by $(0) std::stringstream pushStream; CodeStream push(pushStream); - backend.genExtraGlobalParamPush(push, pointerToPointerToEGP, egp.name, - loc, "$(0)", "group->"); + backend.genVariableDynamicPush(push, + pointerToEGP, getTypeContext(), egp.name, + loc, "$(0)", "group->"); // Add substitution From 72f5202095871fc05698cf55061bff762d009d8e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 19 Jan 2023 11:32:27 +0000 Subject: [PATCH 073/725] parsing of numeric pointer types no longer necessary --- include/genn/genn/transpiler/parser.h | 2 +- include/genn/genn/type.h | 3 --- src/genn/genn/transpiler/parser.cc | 35 +++++---------------------- src/genn/genn/type.cc | 27 ++------------------- 4 files changed, 9 insertions(+), 58 deletions(-) diff --git a/include/genn/genn/transpiler/parser.h b/include/genn/genn/transpiler/parser.h index d612189c67..ae9010d871 100644 --- a/include/genn/genn/transpiler/parser.h +++ b/include/genn/genn/transpiler/parser.h @@ -29,6 +29,6 @@ Expression::ExpressionPtr parseExpression(const std::vector &tokens, Erro Statement::StatementList parseBlockItemList(const std::vector &tokens, ErrorHandlerBase &errorHandler); //! Parse type from tokens -const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, ErrorHandlerBase &errorHandler); +const GeNN::Type::NumericBase *parseNumericType(const std::vector &tokens, ErrorHandlerBase &errorHandler); } // MiniParse::MiniParse diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 0ac324fe9d..439fc40d3f 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -400,9 +400,6 @@ DECLARE_FOREIGN_FUNCTION_TYPE(Sqrt, Double, Double); //! Parse a numeric type const NumericBase *parseNumeric(std::string_view typeString); -//! 
Parse a numeric pointer type -const Pointer *parseNumericPtr(std::string_view typeString); - //! Look up numeric type based on set of type specifiers const NumericBase *getNumericType(const std::set &typeSpecifiers); diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index f4bb83a544..f3dc60a635 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -871,40 +871,17 @@ Statement::StatementList parseBlockItemList(const std::vector &tokens, Er return statements; } //--------------------------------------------------------------------------- -const GeNN::Type::Base *parseType(const std::vector &tokens, bool allowPointers, ErrorHandlerBase &errorHandler) +const GeNN::Type::NumericBase *parseNumericType(const std::vector &tokens, ErrorHandlerBase &errorHandler) { ParserState parserState(tokens, errorHandler); - bool pointerFound = false; std::set typeSpecifiers; - while(parserState.match({Token::Type::TYPE_SPECIFIER, Token::Type::STAR})) { - // If token is a star, set pointer found flag - if(parserState.previous().type == Token::Type::STAR) { - if (!allowPointers) { - parserState.error(parserState.previous(), "pointer type not valid in this context"); - } - pointerFound = true; - } - // Otherwise, if token is type specifier - else if(parserState.previous().type == Token::Type::TYPE_SPECIFIER) { - if(pointerFound) { - parserState.error(parserState.previous(), "invalid type specifier"); - } - else if(!typeSpecifiers.insert(parserState.previous().lexeme).second) { - parserState.error(parserState.previous(), "duplicate type specifier"); - } + while(parserState.match(Token::Type::TYPE_SPECIFIER)) { + if(!typeSpecifiers.insert(parserState.previous().lexeme).second) { + parserState.error(parserState.previous(), "duplicate type specifier"); } }; - // Lookup numeric type - const auto *numericType = GeNN::Type::getNumericType(typeSpecifiers); - - // If pointer, return pointer to numeric type - if (pointerFound) { - return numericType->getPointerType(); - } - // Otherwise, return numeric type directly - else { - return numericType; - } + // Return numeric type + return GeNN::Type::getNumericType(typeSpecifiers); } } diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index e60e3a9abc..adc806a0e9 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -172,31 +172,8 @@ const NumericBase *parseNumeric(std::string_view typeString) SingleLineErrorHandler errorHandler; const auto tokens = Scanner::scanSource(typeString, errorHandler); - // Parse type and cast to numeric - const auto *type = dynamic_cast(Parser::parseType(tokens, false, errorHandler)); - - // If an error was encountered while scanning or parsing, throw exception - if (errorHandler.hasError()) { - throw std::runtime_error("Error parsing type '" + std::string{typeString} + "'"); - } - - // If tokens did not contain a valid numeric type, throw exception - if (!type) { - throw std::runtime_error("Unable to parse type '" + std::string{typeString} + "'"); - } - return type; -} -//---------------------------------------------------------------------------- -const Pointer *parseNumericPtr(std::string_view typeString) { - using namespace Transpiler; - - // Scan type - SingleLineErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(typeString, errorHandler); - - // Parse type and cast to numeric pointer - const auto *type = dynamic_cast(Parser::parseType(tokens, true, errorHandler)); + // Parse numeric type + const auto *type =
Parser::parseNumericType(tokens, errorHandler); // If an error was encountered while scanning or parsing, throw exception if (errorHandler.hasError()) { From f667d5f4ed8b5aae7ee3c7224066fc0c262bab8d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 19 Jan 2023 11:32:35 +0000 Subject: [PATCH 074/725] started updating unit tests --- tests/unit/customConnectivityUpdate.cc | 2 +- tests/unit/customUpdate.cc | 20 +++--- tests/unit/modelSpecMerged.cc | 10 +-- tests/unit/neuronGroup.cc | 22 +++---- tests/unit/scanner.cc | 91 ++++++++++---------------- tests/unit/synapseGroup.cc | 18 ++--- tests/unit/typeChecker.cc | 54 +++++++-------- 7 files changed, 96 insertions(+), 121 deletions(-) diff --git a/tests/unit/customConnectivityUpdate.cc b/tests/unit/customConnectivityUpdate.cc index fad4251cd3..8dc4b64d8b 100644 --- a/tests/unit/customConnectivityUpdate.cc +++ b/tests/unit/customConnectivityUpdate.cc @@ -358,7 +358,7 @@ TEST(CustomConnectivityUpdate, CompareDifferentDependentVars) // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences); + CodeGenerator::SingleThreadedCPU::Backend backend(preferences); // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); diff --git a/tests/unit/customUpdate.cc b/tests/unit/customUpdate.cc index 9c5d9c17a3..da32220f7a 100644 --- a/tests/unit/customUpdate.cc +++ b/tests/unit/customUpdate.cc @@ -179,7 +179,7 @@ TEST(CustomUpdates, ConstantVarSum) // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences); + CodeGenerator::SingleThreadedCPU::Backend backend(preferences); // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); @@ -208,7 +208,7 @@ TEST(CustomUpdates, UninitialisedVarSum) // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences); + CodeGenerator::SingleThreadedCPU::Backend backend(preferences); // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); @@ -238,7 +238,7 @@ TEST(CustomUpdates, RandVarSum) // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences); + CodeGenerator::SingleThreadedCPU::Backend backend(preferences); // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); @@ -679,7 +679,7 @@ TEST(CustomUpdates, CompareDifferentModel) // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences); + CodeGenerator::SingleThreadedCPU::Backend backend(preferences); // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); @@ -719,7 +719,7 @@ TEST(CustomUpdates, CompareDifferentUpdateGroup) // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences); + CodeGenerator::SingleThreadedCPU::Backend backend(preferences); // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); @@ -794,7 +794,7 @@ TEST(CustomUpdates, CompareDifferentDelay) // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend 
backend(model.getPrecision(), preferences); + CodeGenerator::SingleThreadedCPU::Backend backend(preferences); // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); @@ -845,7 +845,7 @@ TEST(CustomUpdates, CompareDifferentBatched) // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences); + CodeGenerator::SingleThreadedCPU::Backend backend(preferences); // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); @@ -894,7 +894,7 @@ TEST(CustomUpdates, CompareDifferentWUTranspose) // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences); + CodeGenerator::SingleThreadedCPU::Backend backend(preferences); // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); @@ -946,7 +946,7 @@ TEST(CustomUpdates, CompareDifferentWUConnectivity) // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences); + CodeGenerator::SingleThreadedCPU::Backend backend(preferences); // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); @@ -1007,7 +1007,7 @@ TEST(CustomUpdates, CompareDifferentWUBatched) // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences); + CodeGenerator::SingleThreadedCPU::Backend backend(preferences); // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); diff --git a/tests/unit/modelSpecMerged.cc b/tests/unit/modelSpecMerged.cc index ed02938353..5c684c7ec3 100644 --- a/tests/unit/modelSpecMerged.cc +++ b/tests/unit/modelSpecMerged.cc @@ -182,8 +182,8 @@ void test(const std::pair (&modelModifiers)[N], M applyModifierFn) model.setName("test"); model.setDT(0.1); model.setTiming(false); - model.setPrecision(ScalarPrecision::FLOAT); - model.setTimePrecision(TimePrecision::DEFAULT); + model.setPrecision(Type::Float::getInstance()); + model.setTimePrecision(nullptr); model.setBatchSize(1); model.setSeed(0); @@ -194,7 +194,7 @@ void test(const std::pair (&modelModifiers)[N], M applyModifierFn) model.finalize(); // Create suitable backend to build model - CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences); + CodeGenerator::SingleThreadedCPU::Backend backend(preferences); // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); @@ -319,8 +319,8 @@ TEST(ModelSpecMerged, CompareModelChanges) {[](ModelSpecInternal &model) { model.setName("interesting_name"); }, false}, {[](ModelSpecInternal &model) { model.setDT(1.0); }, false}, {[](ModelSpecInternal &model) { model.setTiming(true); }, false}, - {[](ModelSpecInternal &model) { model.setPrecision(ScalarPrecision::DOUBLE); }, false}, - {[](ModelSpecInternal &model) { model.setTimePrecision(TimePrecision::DOUBLE); }, false}, + {[](ModelSpecInternal &model) { model.setPrecision(Type::Double::getInstance()); }, false}, + {[](ModelSpecInternal &model) { model.setTimePrecision(Type::Double::getInstance()); }, false}, {[](ModelSpecInternal &model) { model.setBatchSize(10); }, false}, {[](ModelSpecInternal &model) { model.setSeed(1234); }, false}}; diff --git a/tests/unit/neuronGroup.cc b/tests/unit/neuronGroup.cc index f42b834ce2..5a43a53fb6 100644 --- 
a/tests/unit/neuronGroup.cc
+++ b/tests/unit/neuronGroup.cc
@@ -227,7 +227,7 @@ TEST(NeuronGroup, ConstantVarIzhikevich)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -251,7 +251,7 @@ TEST(NeuronGroup, UninitialisedVarIzhikevich)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
    CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -276,7 +276,7 @@ TEST(NeuronGroup, RandVarIzhikevich)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -300,7 +300,7 @@ TEST(NeuronGroup, Poisson)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -582,7 +582,7 @@ TEST(NeuronGroup, CompareNeuronModels)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -632,7 +632,7 @@ TEST(NeuronGroup, CompareHeterogeneousParamVarState)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -731,7 +731,7 @@ TEST(NeuronGroup, CompareCurrentSources)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -849,7 +849,7 @@ TEST(NeuronGroup, ComparePostsynapticModels)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -951,7 +951,7 @@ TEST(NeuronGroup, ComparePreOutput)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -1040,7 +1040,7 @@ TEST(NeuronGroup, CompareWUPreUpdate)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -1141,7 +1141,7 @@ TEST(NeuronGroup, CompareWUPostUpdate)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
diff --git a/tests/unit/scanner.cc b/tests/unit/scanner.cc
index 52ca2f0b42..1f366c05f1 100644
--- a/tests/unit/scanner.cc
+++ b/tests/unit/scanner.cc
@@ -56,91 +56,66 @@ class TestErrorHandler : public ErrorHandlerBase
 TEST(Scanner, DecimalInt)
 {
     TestErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource("1234 4294967295U -2345 -2147483647", GeNN::Type::Float::getInstance(), errorHandler);
+    const auto tokens = Scanner::scanSource("1234 4294967295U -2345 -2147483647", errorHandler);
     ASSERT_FALSE(errorHandler.hasError());
 
     ASSERT_EQ(tokens.size(), 7);
-    ASSERT_EQ(tokens[0].type, Token::Type::NUMBER);
-    ASSERT_EQ(tokens[1].type, Token::Type::NUMBER);
+    ASSERT_EQ(tokens[0].type, Token::Type::INT32_NUMBER);
+    ASSERT_EQ(tokens[1].type, Token::Type::UINT32_NUMBER);
     ASSERT_EQ(tokens[2].type, Token::Type::MINUS);
-    ASSERT_EQ(tokens[3].type, Token::Type::NUMBER);
+    ASSERT_EQ(tokens[3].type, Token::Type::INT32_NUMBER);
     ASSERT_EQ(tokens[4].type, Token::Type::MINUS);
-    ASSERT_EQ(tokens[5].type, Token::Type::NUMBER);
+    ASSERT_EQ(tokens[5].type, Token::Type::INT32_NUMBER);
     ASSERT_EQ(tokens[6].type, Token::Type::END_OF_FILE);
 
-    ASSERT_EQ(std::get(tokens[0].literalValue), 1234);
-    ASSERT_EQ(std::get(tokens[1].literalValue), 4294967295U);
-    ASSERT_EQ(std::get(tokens[3].literalValue), 2345);
-    ASSERT_EQ(std::get(tokens[5].literalValue), 2147483647);
+    ASSERT_EQ(tokens[0].lexeme, "1234");
+    ASSERT_EQ(tokens[1].lexeme, "4294967295U");
+    ASSERT_EQ(tokens[3].lexeme, "2345");
+    ASSERT_EQ(tokens[5].lexeme, "2147483647");
 }
 //--------------------------------------------------------------------------
 TEST(Scanner, HexInt)
 {
     TestErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource("0x1234 0xFFFFFFFFU -0x1234 -0x7FFFFFFF", GeNN::Type::Float::getInstance(), errorHandler);
+    const auto tokens = Scanner::scanSource("0x1234 0xFFFFFFFFU -0x1234 -0x7FFFFFFF", errorHandler);
     ASSERT_FALSE(errorHandler.hasError());
 
     ASSERT_EQ(tokens.size(), 7);
-    ASSERT_EQ(tokens[0].type, Token::Type::NUMBER);
-    ASSERT_EQ(tokens[1].type, Token::Type::NUMBER);
+    ASSERT_EQ(tokens[0].type, Token::Type::INT32_NUMBER);
+    ASSERT_EQ(tokens[1].type, Token::Type::UINT32_NUMBER);
     ASSERT_EQ(tokens[2].type, Token::Type::MINUS);
-    ASSERT_EQ(tokens[3].type, Token::Type::NUMBER);
+    ASSERT_EQ(tokens[3].type, Token::Type::INT32_NUMBER);
     ASSERT_EQ(tokens[4].type, Token::Type::MINUS);
-    ASSERT_EQ(tokens[5].type, Token::Type::NUMBER);
+    ASSERT_EQ(tokens[5].type, Token::Type::INT32_NUMBER);
     ASSERT_EQ(tokens[6].type, Token::Type::END_OF_FILE);
 
-    ASSERT_EQ(std::get(tokens[0].literalValue), 0x1234);
-    ASSERT_EQ(std::get(tokens[1].literalValue), 0xFFFFFFFFU);
-    ASSERT_EQ(std::get(tokens[3].literalValue), 0x1234);
-    ASSERT_EQ(std::get(tokens[5].literalValue), 0x7FFFFFFF);
+    ASSERT_EQ(tokens[0].lexeme, "0x1234");
+    ASSERT_EQ(tokens[1].lexeme, "0xFFFFFFFFU");
+    ASSERT_EQ(tokens[3].lexeme, "0x1234");
+    ASSERT_EQ(tokens[5].lexeme, "0x7FFFFFFF");
 }
 //--------------------------------------------------------------------------
-TEST(Scanner, DecimalFloatFloatScalar)
+TEST(Scanner, DecimalFloat)
 {
     TestErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource("1.0 0.2 100.0f 0.2f -12.0d -0.0004f", GeNN::Type::Float::getInstance(), errorHandler);
+    const auto tokens = Scanner::scanSource("1.0 0.2 100.0f 0.2f -12.0d -0.0004f", errorHandler);
     ASSERT_FALSE(errorHandler.hasError());
 
     ASSERT_EQ(tokens.size(), 9);
-    ASSERT_EQ(tokens[0].type, Token::Type::NUMBER);
-    ASSERT_EQ(tokens[1].type, Token::Type::NUMBER);
-    ASSERT_EQ(tokens[2].type, Token::Type::NUMBER);
-    ASSERT_EQ(tokens[3].type, Token::Type::NUMBER);
+    ASSERT_EQ(tokens[0].type, Token::Type::SCALAR_NUMBER);
+    ASSERT_EQ(tokens[1].type, Token::Type::SCALAR_NUMBER);
+    ASSERT_EQ(tokens[2].type, Token::Type::FLOAT_NUMBER);
+    ASSERT_EQ(tokens[3].type, Token::Type::FLOAT_NUMBER);
     ASSERT_EQ(tokens[4].type, Token::Type::MINUS);
-    ASSERT_EQ(tokens[5].type, Token::Type::NUMBER);
+    ASSERT_EQ(tokens[5].type, Token::Type::DOUBLE_NUMBER);
     ASSERT_EQ(tokens[6].type, Token::Type::MINUS);
-    ASSERT_EQ(tokens[7].type, Token::Type::NUMBER);
+    ASSERT_EQ(tokens[7].type, Token::Type::FLOAT_NUMBER);
     ASSERT_EQ(tokens[8].type, Token::Type::END_OF_FILE);
 
-    ASSERT_EQ(std::get(tokens[0].literalValue), 1.0f);
-    ASSERT_EQ(std::get(tokens[1].literalValue), 0.2f);
-    ASSERT_EQ(std::get(tokens[2].literalValue), 100.0f);
-    ASSERT_EQ(std::get(tokens[3].literalValue), 0.2f);
-    ASSERT_EQ(std::get(tokens[5].literalValue), 12.0);
-    ASSERT_EQ(std::get(tokens[7].literalValue), 0.0004f);
-}
-//--------------------------------------------------------------------------
-TEST(Scanner, DecimalFloatDoubleScalar)
-{
-    TestErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource("1.0 0.2 100.0f 0.2f -12.0d -0.0004f", GeNN::Type::Double::getInstance(), errorHandler);
-    ASSERT_FALSE(errorHandler.hasError());
-
-    ASSERT_EQ(tokens.size(), 9);
-    ASSERT_EQ(tokens[0].type, Token::Type::NUMBER);
-    ASSERT_EQ(tokens[1].type, Token::Type::NUMBER);
-    ASSERT_EQ(tokens[2].type, Token::Type::NUMBER);
-    ASSERT_EQ(tokens[3].type, Token::Type::NUMBER);
-    ASSERT_EQ(tokens[4].type, Token::Type::MINUS);
-    ASSERT_EQ(tokens[5].type, Token::Type::NUMBER);
-    ASSERT_EQ(tokens[6].type, Token::Type::MINUS);
-    ASSERT_EQ(tokens[7].type, Token::Type::NUMBER);
-    ASSERT_EQ(tokens[8].type, Token::Type::END_OF_FILE);
-
-    ASSERT_EQ(std::get(tokens[0].literalValue), 1.0);
-    ASSERT_EQ(std::get(tokens[1].literalValue), 0.2);
-    ASSERT_EQ(std::get(tokens[2].literalValue), 100.0f);
-    ASSERT_EQ(std::get(tokens[3].literalValue), 0.2f);
-    ASSERT_EQ(std::get(tokens[5].literalValue), 12.0);
-    ASSERT_EQ(std::get(tokens[7].literalValue), 0.0004f);
-}
+    ASSERT_EQ(tokens[0].lexeme, "1.0f");
+    ASSERT_EQ(tokens[1].lexeme, "0.2f");
+    ASSERT_EQ(tokens[2].lexeme, "100.0f");
+    ASSERT_EQ(tokens[3].lexeme, "0.2f");
+    ASSERT_EQ(tokens[5].lexeme, "12.0");
+    ASSERT_EQ(tokens[7].lexeme, "0.0004f");
+}
\ No newline at end of file
diff --git a/tests/unit/synapseGroup.cc b/tests/unit/synapseGroup.cc
index 373f102dba..ee8b9fc815 100644
--- a/tests/unit/synapseGroup.cc
+++ b/tests/unit/synapseGroup.cc
@@ -373,7 +373,7 @@ TEST(SynapseGroup, CompareWUDifferentModel)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -423,7 +423,7 @@ TEST(SynapseGroup, CompareWUDifferentGlobalG)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -478,7 +478,7 @@ TEST(SynapseGroup, CompareWUDifferentProceduralConnectivity)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -544,7 +544,7 @@ TEST(SynapseGroup, CompareWUDifferentToeplitzConnectivity)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -610,7 +610,7 @@ TEST(SynapseGroup, CompareWUDifferentProceduralVars)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -664,7 +664,7 @@ TEST(SynapseGroup, CompareWUDifferentProceduralSnippet)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -723,7 +723,7 @@ TEST(SynapseGroup, InitCompareWUDifferentVars)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -864,7 +864,7 @@ TEST(SynapseGroup, InitCompareWUDifferentHeterogeneousParamVarState)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
@@ -924,7 +924,7 @@ TEST(SynapseGroup, InitCompareWUSynapseDynamicsPostLearn)
 
     // Create a backend
     CodeGenerator::SingleThreadedCPU::Preferences preferences;
-    CodeGenerator::SingleThreadedCPU::Backend backend(model.getPrecision(), preferences);
+    CodeGenerator::SingleThreadedCPU::Backend backend(preferences);
 
     // Merge model
     CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend);
diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index dbd8f7561c..945df8fb2f 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -138,35 +138,35 @@ std::string getPointerTypeName()
     return T::getInstance()->getPointerType()->getName();
 }
 
-void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment, const Type::NumericBase *scalarType = Type::Float::getInstance())
+void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment, const Type::TypeContext &typeContext)
 {
     // Scan
     TestErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource(code, scalarType, errorHandler);
+    const auto tokens = Scanner::scanSource(code, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());
 
     // Parse
-    const auto statements = Parser::parseBlockItemList(tokens, scalarType, errorHandler);
+    const auto statements = Parser::parseBlockItemList(tokens, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());
 
     // Typecheck
-    TypeChecker::typeCheck(statements, typeEnvironment, errorHandler);
+    TypeChecker::typeCheck(statements, typeEnvironment, typeContext, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());
 }
 
-const Type::Base *typeCheckExpression(std::string_view code, TestEnvironment &typeEnvironment, const Type::NumericBase *scalarType = Type::Float::getInstance())
+const Type::Base *typeCheckExpression(std::string_view code, TestEnvironment &typeEnvironment, const Type::TypeContext &typeContext)
 {
     // Scan
     TestErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource(code, scalarType, errorHandler);
+    const auto tokens = Scanner::scanSource(code, errorHandler);
     EXPECT_FALSE(errorHandler.hasError());
 
     // Parse
-    const auto expression = Parser::parseExpression(tokens, scalarType, errorHandler);
+    const auto expression = Parser::parseExpression(tokens, errorHandler);
     EXPECT_FALSE(errorHandler.hasError());
 
     // Typecheck
-    const auto *type = TypeChecker::typeCheck(expression.get(), typeEnvironment, errorHandler);
+    const auto *type = TypeChecker::typeCheck(expression.get(), typeEnvironment, typeContext, errorHandler);
     EXPECT_FALSE(errorHandler.hasError());
     return type;
 }
@@ -183,7 +183,7 @@ TEST(TypeChecker, ArraySubscript)
         typeEnvironment.definePointer("intArray");
         const auto *type = typeCheckExpression("intArray[4]", typeEnvironment);
         EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName());
-        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT));
     }
 
     // Float array indexing
@@ -225,7 +225,7 @@ TEST(TypeChecker, Assignment)
     {
         TestEnvironment typeEnvironment;
         typeEnvironment.definePointer("intArray");
-        typeEnvironment.definePointer("intArrayConst", Type::Qualifier::CONST);
+        typeEnvironment.definePointer("intArrayConst", Type::Qualifier::CONSTANT);
         typeCheckStatements(
             "int *x = intArray;\n"
             "const int *y = intArray;\n"
@@ -236,7 +236,7 @@ TEST(TypeChecker, Assignment)
     // Pointer assignement, attempt to remove const
     EXPECT_THROW({
         TestEnvironment typeEnvironment;
-        typeEnvironment.definePointer("intArray", Type::Qualifier::CONST);
+        typeEnvironment.definePointer("intArray", Type::Qualifier::CONSTANT);
         typeCheckStatements("int *x = intArray;", typeEnvironment);},
        TypeChecker::TypeCheckError);
 
@@ -266,7 +266,7 @@ TEST(TypeChecker, Cast)
         typeEnvironment.define("intVal");
         const auto *type = typeCheckExpression("(float)intVal", typeEnvironment);
         EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName());
-        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT));
     }
 
     // Numeric cast to const
@@ -275,7 +275,7 @@ TEST(TypeChecker, Cast)
         typeEnvironment.define("intVal");
         const auto *type = typeCheckExpression("(const int)intVal", typeEnvironment);
         EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName());
-        EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONSTANT));
     }
 
     // Pointer cast to value const
@@ -283,12 +283,12 @@ TEST(TypeChecker, Cast)
         TestEnvironment typeEnvironment;
         typeEnvironment.definePointer("intArray");
         const auto *type = typeCheckExpression("(const int*)intArray", typeEnvironment);
-        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT));
 
         const auto *pointerType = dynamic_cast(type);
         EXPECT_TRUE(pointerType);
         EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName());
-        EXPECT_TRUE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_TRUE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONSTANT));
     }
 
     // Pointer cast to pointer const
@@ -296,32 +296,32 @@ TEST(TypeChecker, Cast)
         TestEnvironment typeEnvironment;
         typeEnvironment.definePointer("intArray");
         const auto *type = typeCheckExpression("(int * const)intArray", typeEnvironment);
-        EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONSTANT));
 
         const auto *pointerType = dynamic_cast(type);
         EXPECT_TRUE(pointerType);
         EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName());
-        EXPECT_FALSE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_FALSE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONSTANT));
     }
 
     // Can't remove value const from numeric
     EXPECT_THROW({
         TestEnvironment typeEnvironment;
-        typeEnvironment.define("intVal", Type::Qualifier::CONST);
+        typeEnvironment.define("intVal", Type::Qualifier::CONSTANT);
         typeCheckExpression("(int)intVal", typeEnvironment);},
         TypeChecker::TypeCheckError);
 
     // Can't remove value const from pointer
     EXPECT_THROW({
         TestEnvironment typeEnvironment;
-        typeEnvironment.definePointer("intArray", Type::Qualifier::CONST);
+        typeEnvironment.definePointer("intArray", Type::Qualifier::CONSTANT);
         typeCheckExpression("(int*)intArray", typeEnvironment);},
         TypeChecker::TypeCheckError);
 
     // Can't remove pointer const from pointer
     EXPECT_THROW({
         TestEnvironment typeEnvironment;
-        typeEnvironment.definePointer("intArray", Type::Qualifier{0}, Type::Qualifier::CONST);
+        typeEnvironment.definePointer("intArray", Type::Qualifier{0}, Type::Qualifier::CONSTANT);
         typeCheckExpression("(int*)intArray", typeEnvironment);},
         TypeChecker::TypeCheckError);
 
@@ -359,7 +359,7 @@ TEST(TypeChecker, IncDec)
         typeEnvironment.define("intVal");
         const auto *type = typeCheckExpression("intVal++", typeEnvironment);
         EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName());
-        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT));
     }
 
     // Can increment pointer
@@ -368,33 +368,33 @@ TEST(TypeChecker, IncDec)
        typeEnvironment.definePointer("intArray");
         const auto *type = typeCheckExpression("intArray++", typeEnvironment);
         EXPECT_EQ(type->getName(), getPointerTypeName());
-        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT));
     }
 
     // Can increment pointer to const
     {
         TestEnvironment typeEnvironment;
-        typeEnvironment.definePointer("intArray", Type::Qualifier::CONST);
+        typeEnvironment.definePointer("intArray", Type::Qualifier::CONSTANT);
         const auto *type = typeCheckExpression("intArray++", typeEnvironment);
-        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT));
 
         const auto *pointerType = dynamic_cast(type);
         EXPECT_TRUE(pointerType);
         EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName());
-        EXPECT_TRUE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_TRUE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONSTANT));
     }
 
     // Can't increment const number
     EXPECT_THROW({
         TestEnvironment typeEnvironment;
-        typeEnvironment.define("intVal", Type::Qualifier::CONST);
+        typeEnvironment.define("intVal", Type::Qualifier::CONSTANT);
         typeCheckExpression("intVal++", typeEnvironment);},
         TypeChecker::TypeCheckError);
 
     // Can't increment const pointer
     EXPECT_THROW({
         TestEnvironment typeEnvironment;
-        typeEnvironment.definePointer("intArray", Type::Qualifier{0}, Type::Qualifier::CONST);
+        typeEnvironment.definePointer("intArray", Type::Qualifier{0}, Type::Qualifier::CONSTANT);
         typeCheckExpression("intArray++", typeEnvironment);},
         TypeChecker::TypeCheckError);
 }
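The tests above pin down the scanner behaviour this part of the series introduces: numeric literals are no longer scanned into a single NUMBER token carrying a parsed literalValue, but into suffix-specific token types (INT32_NUMBER, UINT32_NUMBER, SCALAR_NUMBER, FLOAT_NUMBER, and DOUBLE_NUMBER for the non-standard 'd' suffix) whose raw lexeme is kept as a string, with the sign always emitted as a separate MINUS token. A minimal standalone sketch of that suffix-driven classification (simplified; not GeNN's actual scanner API):

    #include <cctype>
    #include <string>

    enum class NumberType { Int32, Uint32, Scalar, Float, Double };

    // Classify an unsigned numeric lexeme by its suffix, matching what the tests above expect
    NumberType classifyNumber(const std::string &lexeme, bool isFloatingPoint)
    {
        const char suffix = static_cast<char>(std::tolower(static_cast<unsigned char>(lexeme.back())));
        if (isFloatingPoint) {
            if (suffix == 'f') return NumberType::Float;   // e.g. "100.0f"
            if (suffix == 'd') return NumberType::Double;  // e.g. "12.0d" - GeNN extension, not standard C
            return NumberType::Scalar;                     // e.g. "1.0" - resolved to float or double later
        }
        return (suffix == 'u') ? NumberType::Uint32 : NumberType::Int32;  // "4294967295U" vs "1234"
    }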
From 2672c4fe0004d19daeabe33ccf2fbd63b6a56ce7 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 19 Jan 2023 12:26:53 +0000
Subject: [PATCH 075/725] unit tests now compile - many failures :)

---
 tests/unit/typeChecker.cc | 34 +++++++++++++++++---------------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index 945df8fb2f..c08644cfbc 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -88,7 +88,8 @@ class TestEnvironment : public TypeChecker::EnvironmentBase
     }
 
     virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType,
-                                     ErrorHandlerBase &errorHandler, bool initializer = false) final
+                                     const Type::TypeContext &context, ErrorHandlerBase &errorHandler,
+                                     bool initializer = false) final
     {
         // If type isn't found
         auto existingType = m_Types.find(name.lexeme);
@@ -98,10 +99,11 @@ class TestEnvironment : public TypeChecker::EnvironmentBase
         }
 
         // Perform standard type-checking logic
-        return EnvironmentBase::assign(name, op, existingType->second, assignedType, errorHandler, initializer);
+        return EnvironmentBase::assign(name, op, existingType->second, assignedType, context, errorHandler, initializer);
     }
 
-    virtual const Type::Base *incDec(const Token &name, Token::Type op, ErrorHandlerBase &errorHandler) final
+    virtual const Type::Base *incDec(const Token &name, Token::Type op,
+                                     const Type::TypeContext&, ErrorHandlerBase &errorHandler) final
     {
         auto existingType = m_Types.find(name.lexeme);
         if(existingType == m_Types.end()) {
@@ -138,7 +140,7 @@ std::string getPointerTypeName()
     return T::getInstance()->getPointerType()->getName();
 }
 
-void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment, const Type::TypeContext &typeContext)
+void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment, const Type::NumericBase *scalarType = Type::Float::getInstance())
 {
     // Scan
     TestErrorHandler errorHandler;
@@ -150,11 +152,12 @@ void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment
     ASSERT_FALSE(errorHandler.hasError());
 
     // Typecheck
+    const Type::TypeContext typeContext{{"scalar", scalarType}};
     TypeChecker::typeCheck(statements, typeEnvironment, typeContext, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());
 }
 
-const Type::Base *typeCheckExpression(std::string_view code, TestEnvironment &typeEnvironment, const Type::TypeContext &typeContext)
+const Type::Base *typeCheckExpression(std::string_view code, TestEnvironment &typeEnvironment, const Type::NumericBase *scalarType = Type::Float::getInstance())
 {
     // Scan
     TestErrorHandler errorHandler;
     const auto tokens = Scanner::scanSource(code, errorHandler);
@@ -166,6 +169,7 @@ const Type::Base *typeCheckExpression(std::string_view code, TestEnvironment &ty
     EXPECT_FALSE(errorHandler.hasError());
 
     // Typecheck
+    const Type::TypeContext typeContext{{"scalar", scalarType}};
     const auto *type = TypeChecker::typeCheck(expression.get(), typeEnvironment, typeContext, errorHandler);
     EXPECT_FALSE(errorHandler.hasError());
     return type;
@@ -209,7 +213,7 @@ TEST(TypeChecker, Assignment)
     TestEnvironment typeEnvironment;
     typeEnvironment.define("intVal");
     typeEnvironment.define("floatVal");
-    typeEnvironment.define("intValConst", Type::Qualifier::CONST);
+    typeEnvironment.define("intValConst", Type::Qualifier::CONSTANT);
     typeCheckStatements(
         "int w = intVal;\n"
         "float x = floatVal;\n"
@@ -464,34 +468,34 @@ TEST(TypeChecker, Unary)
         typeEnvironment.definePointer("intArray");
         const auto *type = typeCheckExpression("*intArray", typeEnvironment);
         EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName());
-        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT));
     }
 
     // Dereference pointer to const
     {
         TestEnvironment typeEnvironment;
-        typeEnvironment.definePointer("intArray", Type::Qualifier::CONST);
+        typeEnvironment.definePointer("intArray", Type::Qualifier::CONSTANT);
         const auto *type = typeCheckExpression("*intArray", typeEnvironment);
         EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName());
-        EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONSTANT));
     }
 
     // Dereference const pointer
     {
         TestEnvironment typeEnvironment;
-        typeEnvironment.definePointer("intArray", Type::Qualifier{0}, Type::Qualifier::CONST);
+        typeEnvironment.definePointer("intArray", Type::Qualifier{0}, Type::Qualifier::CONSTANT);
         const auto *type = typeCheckExpression("*intArray", typeEnvironment);
         EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName());
-        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT));
     }
 
     // Dereference const pointer to const
     {
         TestEnvironment typeEnvironment;
-        typeEnvironment.definePointer("intArray", Type::Qualifier::CONST, Type::Qualifier::CONST);
+        typeEnvironment.definePointer("intArray", Type::Qualifier::CONSTANT, Type::Qualifier::CONSTANT);
         const auto *type = typeCheckExpression("*intArray", typeEnvironment);
         EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName());
-        EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONSTANT));
     }
 
     // Dereference numeric
@@ -506,12 +510,12 @@ TEST(TypeChecker, Unary)
         TestEnvironment typeEnvironment;
         typeEnvironment.define("intVal");
         const auto *type = typeCheckExpression("&intVal", typeEnvironment);
-        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT));
 
         const auto *pointerType = dynamic_cast(type);
         EXPECT_TRUE(pointerType);
         EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName());
-        EXPECT_FALSE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONST));
+        EXPECT_FALSE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONSTANT));
     }
 
     // Address of pointer
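From this patch on, the concrete scalar type is no longer threaded through the scanner and parser; instead the type checker receives a Type::TypeContext, which the tests build as {{"scalar", scalarType}}. The underlying idea is a lookup from placeholder type names to concrete types, consulted whenever a name needs resolving. A minimal sketch of that mechanism (simplified and string-based; not GeNN's real classes):

    #include <string>
    #include <unordered_map>

    // Placeholder name -> concrete type name, e.g. {"scalar", "float"}
    using TypeContext = std::unordered_map<std::string, std::string>;

    // Resolve a type name, falling back to the name itself if it is not a placeholder
    std::string resolveTypeName(const std::string &name, const TypeContext &context)
    {
        const auto resolved = context.find(name);
        return (resolved == context.end()) ? name : resolved->second;
    }

With {{"scalar", "float"}} a SCALAR_NUMBER literal such as 1.0 type-checks as float; with {{"scalar", "double"}} the same code type-checks as double, which is exactly what the Literal tests in the next patch verify.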
From 0fba83adb7301b43f9065e709f4dc756ddf808e3 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 19 Jan 2023 17:39:43 +0000
Subject: [PATCH 076/725] * Fixed small bugs found by unit tests * Improved
 pointer value type and constness type checking

---
 include/genn/genn/models.h              |  4 +-
 src/genn/genn/models.cc                 |  2 +-
 src/genn/genn/transpiler/scanner.cc     |  4 +-
 src/genn/genn/transpiler/typeChecker.cc | 64 +++++++++++++++++++++----
 tests/unit/scanner.cc                   |  6 +--
 tests/unit/typeChecker.cc               | 14 +++---
 6 files changed, 69 insertions(+), 25 deletions(-)

diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h
index cca64046c0..25f2d3b0c8 100644
--- a/include/genn/genn/models.h
+++ b/include/genn/genn/models.h
@@ -311,8 +311,8 @@ void checkVarReferences(const std::unordered_map &varRefs, const
         const auto varRef = varRefs.at(modelVarRef.name);
 
         // Check types of variable references against those specified in model
-        // **THINK** due to GeNN's current string-based type system this is rather conservative
-        if(varRef.getVar().type != modelVarRef.type) {
+        // **THINK** this is rather conservative but I think disallowing mixing of 'scalar' and whatever type happens to be scalar is OK
+        if(varRef.getVar().type->getName() != modelVarRef.type->getName()) {
             throw std::runtime_error("Incompatible type for variable reference '" + modelVarRef.name + "'");
         }
 
diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc
index 9afc6a7ea7..6b3b3b7115 100644
--- a/src/genn/genn/models.cc
+++ b/src/genn/genn/models.cc
@@ -195,7 +195,7 @@ WUVarReference::WUVarReference(SynapseGroup *sg, const std::string &varName,
     }
 
     // Check types
-    if(getVar().type != getTransposeVar().type) {
+    if(getVar().type->getName() != getTransposeVar().type->getName()) {
         throw std::runtime_error("Transpose updates can only be performed on variables with the same type");
     }
 
diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc
index 6f13db090c..ca25a294fd 100644
--- a/src/genn/genn/transpiler/scanner.cc
+++ b/src/genn/genn/transpiler/scanner.cc
@@ -226,14 +226,14 @@ void scanNumber(char c, ScanState &scanState, std::vector &tokens)
 
     // If number has an f suffix, emplace FLOAT_NUMBER token
     if (std::tolower(scanState.peek()) == 'f') {
-        emplaceToken(tokens, Token::Type::FLOAT_NUMBER, scanState);
         scanState.advance();
+        emplaceToken(tokens, Token::Type::FLOAT_NUMBER, scanState);
     }
     // Otherwise, if it has a d suffix, emplace DOUBLE_NUMBER token
     // **NOTE** 'd' is a GeNN extension not standard C
     else if (std::tolower(scanState.peek()) == 'd') {
-        emplaceToken(tokens, Token::Type::DOUBLE_NUMBER, scanState);
         scanState.advance();
+        emplaceToken(tokens, Token::Type::DOUBLE_NUMBER, scanState);
     }
     // Otherwise, emplace SCALAR_NUMBER token
     else {
diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index 68325e193b..dcc1c7ef0b 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -24,6 +24,49 @@ namespace Type = GeNN::Type;
 //---------------------------------------------------------------------------
 namespace
 {
+bool checkPointerTypeAssignement(const Type::Base *rightType, const Type::Base *leftType, const Type::TypeContext &typeContext)
+{
+    // If both are pointers, recurse through value type
+    auto rightPointerType = dynamic_cast(rightType);
+    auto leftPointerType = dynamic_cast(leftType);
+    if (rightPointerType && leftPointerType) {
+        return checkPointerTypeAssignement(rightPointerType->getValueType(), leftPointerType->getValueType(), typeContext);
+    }
+    // Otherwise, if we've hit the value type at the end of the chain, check resolved names match
+    else if (!rightPointerType && !leftPointerType) {
+        return (rightType->getResolvedName(typeContext) == leftType->getResolvedName(typeContext));
+    }
+    // Otherwise, pointers with different levels of indirection e.g. int* and int** are being compared
+    else {
+        return false;
+    }
+}
+
+bool checkForConstRemoval(const Type::Base *rightType, const Type::Base *leftType)
+{
+    // If const is being removed
+    if (rightType->hasQualifier(Type::Qualifier::CONSTANT) && !leftType->hasQualifier(Type::Qualifier::CONSTANT)) {
+        return false;
+    }
+
+    // If both are pointers, recurse through value type
+    auto rightPointerType = dynamic_cast(rightType);
+    auto leftPointerType = dynamic_cast(leftType);
+    if (rightPointerType && leftPointerType) {
+        return checkForConstRemoval(rightPointerType->getValueType(), leftPointerType->getValueType());
+    }
+    // Otherwise, if both are non-pointers, return true as const removal has been successfully checked
+    else if (!rightPointerType && !leftPointerType) {
+        return true;
+    }
+    // Otherwise, pointers with different levels of indirection e.g. int* and int** are being compared
+    else {
+        return false;
+    }
+}
+
 //---------------------------------------------------------------------------
 // EnvironmentInternal
 //---------------------------------------------------------------------------
@@ -285,9 +328,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     {
         // Evaluate type of expression we're casting
         const auto rightType = evaluateType(cast.getExpression());
-        
+
         // If const is being removed
-        if (rightType->hasQualifier(Type::Qualifier::CONSTANT) && !cast.getType()->hasQualifier(Type::Qualifier::CONSTANT)) {
+        if (!checkForConstRemoval(rightType, cast.getType())) {
            m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName());
             throw TypeCheckError();
         }
@@ -298,7 +341,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         auto leftNumericType = dynamic_cast(cast.getType());
         auto leftPointerType = dynamic_cast(cast.getType());
         if (rightPointerType && leftPointerType) {
-            if (rightPointerType->getResolvedName(m_Context) != leftPointerType->getResolvedName(m_Context)) {
+            // Check that value type at the end matches
+            if (!checkPointerTypeAssignement(rightPointerType->getValueType(), leftPointerType->getValueType(), m_Context)) {
                 m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName());
                 throw TypeCheckError();
             }
@@ -345,7 +389,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
             m_Type = Type::Double::getInstance();
         }
         else if (literal.getValue().type == Token::Type::FLOAT_NUMBER) {
-            m_Type = Type::Double::getInstance();
+            m_Type = Type::Float::getInstance();
         }
         else if (literal.getValue().type == Token::Type::SCALAR_NUMBER) {
             // **TODO** cache
@@ -587,7 +631,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         expression->accept(*this);
         return m_Type;
     }
-    
+
     //---------------------------------------------------------------------------
     // Members
     //---------------------------------------------------------------------------
@@ -598,7 +642,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     bool m_InLoop;
     bool m_InSwitch;
 };
-}
+}   // Anonymous namespace
 
 //---------------------------------------------------------------------------
 // GeNN::Transpiler::TypeChecker::EnvironmentBase
 //---------------------------------------------------------------------------
@@ -622,14 +666,14 @@ const Type::Base *EnvironmentBase::assign(const Token &name, Token::Type op,
     if(op == Token::Type::EQUAL) {
         // If we're initialising a pointer with another pointer
         if (pointerAssignedType && pointerExistingType) {
-            // If we're trying to assign a pointer to a const value to a pointer
-            if (assignedType->hasQualifier(Type::Qualifier::CONSTANT) && !existingType->hasQualifier(Type::Qualifier::CONSTANT)) {
+            // Check that value type at the end matches
+            if (!checkPointerTypeAssignement(pointerAssignedType->getValueType(), pointerExistingType->getValueType(), context)) {
                 errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName() + "' and '" + pointerAssignedType->getName());
                 throw TypeCheckError();
             }
 
-            // If pointer types aren't compatible
-            if (pointerExistingType->getResolvedName(context) != pointerAssignedType->getResolvedName(context)) {
+            // If we're trying to make type less const
+            if (!checkForConstRemoval(pointerAssignedType, pointerExistingType)) {
                 errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName() + "' and '" + pointerAssignedType->getName());
                 throw TypeCheckError();
             }
diff --git a/tests/unit/scanner.cc b/tests/unit/scanner.cc
index 1f366c05f1..3d5c102bf3 100644
--- a/tests/unit/scanner.cc
+++ b/tests/unit/scanner.cc
@@ -112,10 +112,10 @@ TEST(Scanner, DecimalFloat)
     ASSERT_EQ(tokens[7].type, Token::Type::FLOAT_NUMBER);
     ASSERT_EQ(tokens[8].type, Token::Type::END_OF_FILE);
 
-    ASSERT_EQ(tokens[0].lexeme, "1.0f");
-    ASSERT_EQ(tokens[1].lexeme, "0.2f");
+    ASSERT_EQ(tokens[0].lexeme, "1.0");
+    ASSERT_EQ(tokens[1].lexeme, "0.2");
     ASSERT_EQ(tokens[2].lexeme, "100.0f");
     ASSERT_EQ(tokens[3].lexeme, "0.2f");
-    ASSERT_EQ(tokens[5].lexeme, "12.0");
+    ASSERT_EQ(tokens[5].lexeme, "12.0d");
     ASSERT_EQ(tokens[7].lexeme, "0.0004f");
 }
\ No newline at end of file
diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index c08644cfbc..81cdc19cf1 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -140,7 +140,7 @@ std::string getPointerTypeName()
     return T::getInstance()->getPointerType()->getName();
 }
 
-void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment, const Type::NumericBase *scalarType = Type::Float::getInstance())
+void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment, const Type::TypeContext &typeContext = {})
 {
     // Scan
     TestErrorHandler errorHandler;
@@ -152,12 +152,11 @@ void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment
     ASSERT_FALSE(errorHandler.hasError());
 
     // Typecheck
-    const Type::TypeContext typeContext{{"scalar", scalarType}};
     TypeChecker::typeCheck(statements, typeEnvironment, typeContext, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());
 }
 
-const Type::Base *typeCheckExpression(std::string_view code, TestEnvironment &typeEnvironment, const Type::NumericBase *scalarType = Type::Float::getInstance())
+const Type::Base *typeCheckExpression(std::string_view code, TestEnvironment &typeEnvironment, const Type::TypeContext &typeContext = {})
 {
     // Scan
     TestErrorHandler errorHandler;
@@ -169,7 +168,6 @@ const Type::Base *typeCheckExpression(std::string_view code, TestEnvironment &ty
     EXPECT_FALSE(errorHandler.hasError());
 
     // Typecheck
-    const Type::TypeContext typeContext{{"scalar", scalarType}};
     const auto *type = TypeChecker::typeCheck(expression.get(), typeEnvironment, typeContext, errorHandler);
     EXPECT_FALSE(errorHandler.hasError());
     return type;
@@ -417,8 +415,9 @@ TEST(TypeChecker, Literal)
     // Scalar with single-precision
     {
         TestEnvironment typeEnvironment;
+        const Type::TypeContext typeContext{{"scalar", Type::Float::getInstance()}};
         const auto *type = typeCheckExpression("1.0", typeEnvironment);
-        EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName());
+        EXPECT_EQ(type->getResolvedName(typeContext), Type::Float::getInstance()->getName());
         //EXPECT_TRUE(type.constValue);
         //EXPECT_FALSE(type.constPointer);
     }
@@ -426,8 +425,9 @@ TEST(TypeChecker, Literal)
     // Scalar with double-precision
     {
         TestEnvironment typeEnvironment;
-        const auto *type = typeCheckExpression("1.0", typeEnvironment, Type::Double::getInstance());
-        EXPECT_EQ(type->getName(), Type::Double::getInstance()->getName());
+        const Type::TypeContext typeContext{{"scalar", Type::Double::getInstance()}};
+        const auto *type = typeCheckExpression("1.0", typeEnvironment, typeContext);
+        EXPECT_EQ(type->getResolvedName(typeContext), Type::Double::getInstance()->getName());
         //EXPECT_TRUE(type.constValue);
         //EXPECT_FALSE(type.constPointer);
     }
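The checkForConstRemoval / checkPointerTypeAssignement pair added in the patch above implements the usual C rule that qualifiers may be added but never dropped through assignment or casting, applied recursively through each level of pointer indirection. In plain C++ declarations the accepted and rejected cases look roughly like this (illustrative only):

    const int *constArray = nullptr;
    int *mutableArray = nullptr;

    const int *ok = mutableArray;    // adding const is allowed
    // int *bad = constArray;       // dropping const is rejected, as in the
    //                              // "attempt to remove const" test
    // int **deeper = &mutableArray; // fine in itself, but assigning int** to
    //                               // int* never matches: indirection levels
    //                               // must agree at every step of the recursion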
From cdbc7af99a970db70725735ff19654333fd1a3f1 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 19 Jan 2023 17:44:49 +0000
Subject: [PATCH 077/725] Fixed some warnings

---
 include/genn/genn/gennUtils.h                 | 20 ------------
 include/genn/genn/models.h                    |  7 ++--
 .../backends/single_threaded_cpu/backend.cc   | 10 +++---
 src/genn/genn/gennUtils.cc                    | 32 -------------------
 tests/unit/modelSpec.cc                       |  2 +-
 5 files changed, 8 insertions(+), 63 deletions(-)

diff --git a/include/genn/genn/gennUtils.h b/include/genn/genn/gennUtils.h
index 94c29249c6..3aabbdb2e6 100644
--- a/include/genn/genn/gennUtils.h
+++ b/include/genn/genn/gennUtils.h
@@ -40,26 +40,6 @@ GENN_EXPORT bool isRNGRequired(const std::string &code);
 //--------------------------------------------------------------------------
 GENN_EXPORT bool isRNGRequired(const std::unordered_map &varInitialisers);
 
-//--------------------------------------------------------------------------
-//! \brief Function to determine whether a string containing a type is a pointer
-//--------------------------------------------------------------------------
-GENN_EXPORT bool isTypePointer(const std::string &type);
-
-//--------------------------------------------------------------------------
-//! \brief Function to determine whether a string containing a type is a pointer to a pointer
-//--------------------------------------------------------------------------
-GENN_EXPORT bool isTypePointerToPointer(const std::string &type);
-
-//--------------------------------------------------------------------------
-//! \brief Function to determine whether a string containing a type is floating point
-//--------------------------------------------------------------------------
-GENN_EXPORT bool isTypeFloatingPoint(const std::string &type);
-
-//--------------------------------------------------------------------------
-//! \brief Assuming type is a string containing a pointer type, function to return the underlying type
-//--------------------------------------------------------------------------
-GENN_EXPORT std::string getUnderlyingType(const std::string &type);
-
 //--------------------------------------------------------------------------
 //! \brief Is the variable name valid? GeNN variable names must obey C variable naming rules
 //--------------------------------------------------------------------------
diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h
index 25f2d3b0c8..3c4bc42398 100644
--- a/include/genn/genn/models.h
+++ b/include/genn/genn/models.h
@@ -8,8 +8,9 @@
 #include 
 
 // GeNN includes
-#include "snippet.h"
 #include "initVarSnippet.h"
+#include "snippet.h"
+#include "type.h"
 #include "varAccess.h"
 
 // Forward declarations
@@ -24,10 +25,6 @@ class CurrentSource;
 class NeuronGroupInternal;
 class SynapseGroupInternal;
 class CurrentSourceInternal;
-namespace Type
-{
-class NumericBase;
-}
 }
 
diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc
index ef24c18ac4..995811a87a 100644
--- a/src/genn/backends/single_threaded_cpu/backend.cc
+++ b/src/genn/backends/single_threaded_cpu/backend.cc
@@ -1286,21 +1286,21 @@ void Backend::genStepTimeFinalisePreamble(CodeStream &, const ModelSpecMerged &)
 //--------------------------------------------------------------------------
 void Backend::genVariableDefinition(CodeStream &definitions, CodeStream &,
                                     const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name,
-                                    VarLocation loc) const
+                                    VarLocation) const
 {
     definitions << "EXPORT_VAR " << type->getPointerType()->getResolvedName(typeContext) << " " << name << ";" << std::endl;
 }
 //--------------------------------------------------------------------------
 void Backend::genVariableInstantiation(CodeStream &os,
                                        const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name,
-                                       VarLocation loc) const
+                                       VarLocation) const
 {
     os << type->getPointerType()->getResolvedName(typeContext) << " " << name << ";" << std::endl;
 }
 //--------------------------------------------------------------------------
 void Backend::genVariableAllocation(CodeStream &os,
                                     const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name,
-                                    VarLocation loc, size_t count, MemAlloc &memAlloc) const
+                                    VarLocation, size_t count, MemAlloc &memAlloc) const
 {
     os << name << " = new " << type->getResolvedName(typeContext) << "[" << count << "];" << std::endl;
 
@@ -1309,7 +1309,7 @@ void Backend::genVariableAllocation(CodeStream &os,
 //--------------------------------------------------------------------------
 void Backend::genVariableDynamicAllocation(CodeStream &os,
                                            const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name,
-                                           VarLocation loc, const std::string &countVarName, const std::string &prefix) const
+                                           VarLocation, const std::string &countVarName, const std::string &prefix) const
 {
     const auto *pointerType = dynamic_cast(type);
     if (pointerType) {
@@ -1320,7 +1320,7 @@ void Backend::genVariableDynamicAllocation(CodeStream &os,
     }
 }
 //--------------------------------------------------------------------------
-void Backend::genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const
+void Backend::genVariableFree(CodeStream &os, const std::string &name, VarLocation) const
 {
     os << "delete[] " << name << ";" << std::endl;
 }
diff --git a/src/genn/genn/gennUtils.cc b/src/genn/genn/gennUtils.cc
index 830d78969a..e944c97369 100644
--- a/src/genn/genn/gennUtils.cc
+++ b/src/genn/genn/gennUtils.cc
@@ -72,38 +72,6 @@ bool isRNGRequired(const std::unordered_map &varIn
     });
 }
 //--------------------------------------------------------------------------
-bool isTypePointer(const std::string &type)
-{
-    return (type.back() == '*');
-}
-//--------------------------------------------------------------------------
-bool isTypePointerToPointer(const std::string &type)
-{
-    const size_t len = type.length();
-    return (type[len - 1] == '*' && type[len - 2] == '*');
-}
-//--------------------------------------------------------------------------
-bool isTypeFloatingPoint(const std::string &type)
-{
-    assert(!isTypePointer(type));
-    return ((type == "float") || (type == "double") || (type == "half") || (type == "scalar"));
-}
-//--------------------------------------------------------------------------
-std::string getUnderlyingType(const std::string &type)
-{
-    // Check that type is a pointer type
-    assert(isTypePointer(type));
-
-    // if type is actually a pointer to a pointer, return string without last 2 characters
-    if(isTypePointerToPointer(type)) {
-        return type.substr(0, type.length() - 2);
-    }
-    // Otherwise, return string without last character
-    else {
-        return type.substr(0, type.length() - 1);
-    }
-}
-//--------------------------------------------------------------------------
 void validateVarName(const std::string &name, const std::string &description)
 {
     // Empty names aren't valid
diff --git a/tests/unit/modelSpec.cc b/tests/unit/modelSpec.cc
index 0f378f7e7e..e8e60e8f5d 100644
--- a/tests/unit/modelSpec.cc
+++ b/tests/unit/modelSpec.cc
@@ -156,7 +156,7 @@ TEST(ModelSpec, CustomConnectivityUpdateZeroCopy)
     model.addNeuronPopulation("Neurons0", 10, paramVals, varVals);
     model.addNeuronPopulation("Neurons1", 10, paramVals, varVals);
 
-    SynapseGroup *sg = model.addSynapsePopulation(
+    model.addSynapsePopulation(
         "Synapse", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY,
         "Neurons0", "Neurons1",
         {}, {{"g", 1.0}, {"d", 1}},
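The helpers deleted above (isTypePointer, isTypeFloatingPoint, getUnderlyingType and friends) all worked by parsing type *strings*; with the type-object hierarchy introduced by this series, the same queries become structural. A hypothetical minimal shape of the replacement idiom (names are illustrative, not GeNN's actual classes):

    struct Base { virtual ~Base() = default; };
    struct Pointer : Base { const Base *valueType = nullptr; };

    // A type is a pointer if it is an instance of the Pointer subclass -
    // no string parsing, and the underlying value type is reached through a
    // member rather than by trimming '*' characters off the end of a string
    bool isPointer(const Base *type)
    {
        return dynamic_cast<const Pointer*>(type) != nullptr;
    }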
From cadf97bcce7a13a5fa11fed5a8ba21651c17954 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 19 Jan 2023 18:27:11 +0000
Subject: [PATCH 078/725] started fixing up CUDA backend

---
 include/genn/backends/cuda/backend.h          | 111 ++--
 .../backends/single_threaded_cpu/optimiser.h  |   2 +-
 .../genn/genn/code_generator/backendSIMT.h    |   3 +-
 src/genn/backends/cuda/backend.cc             | 531 ++++++------
 src/genn/backends/cuda/optimiser.cc           |   6 +-
 .../backends/single_threaded_cpu/optimiser.cc |   2 +-
 6 files changed, 249 insertions(+), 406 deletions(-)

diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h
index 964ee8a77b..7fca5702f5 100644
--- a/include/genn/backends/cuda/backend.h
+++ b/include/genn/backends/cuda/backend.h
@@ -127,8 +127,7 @@ struct Preferences : public PreferencesBase
 class BACKEND_EXPORT Backend : public BackendSIMT
 {
 public:
-    Backend(const KernelBlockSize &kernelBlockSizes, const Preferences &preferences,
-            const std::string &scalarType, int device);
+    Backend(const KernelBlockSize &kernelBlockSizes, const Preferences &preferences, int device);
 
     //--------------------------------------------------------------------------
     // CodeGenerator::BackendSIMT virtuals
     //--------------------------------------------------------------------------
@@ -149,7 +148,8 @@ class BACKEND_EXPORT Backend : public BackendSIMT
     virtual std::string getCLZ() const override { return "__clz"; }
 
     //! Get name of atomic operation
-    virtual std::string getAtomic(const std::string &type, AtomicOperation op = AtomicOperation::ADD,
+    virtual std::string getAtomic(const Type::NumericBase *type, const Type::TypeContext &typeContext,
+                                  AtomicOperation op = AtomicOperation::ADD,
                                   AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const override;
 
     //! Generate a shared memory barrier
@@ -186,55 +186,69 @@ class BACKEND_EXPORT Backend : public BackendSIMT
     virtual void genFreeMemPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const override;
     virtual void genStepTimeFinalisePreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const override;
 
-    virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, const std::string &type, const std::string &name, VarLocation loc) const override;
-    virtual void genVariableImplementation(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc) const override;
-    virtual void genVariableAllocation(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const override;
-    virtual void genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const override;
-
-    virtual void genExtraGlobalParamDefinition(CodeStream &definitions, CodeStream &definitionsInternal, const std::string &type, const std::string &name, VarLocation loc) const override;
-    virtual void genExtraGlobalParamImplementation(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc) const override;
-    virtual void genExtraGlobalParamAllocation(CodeStream &os, const std::string &type, const std::string &name,
-                                               VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const override;
-    virtual void genExtraGlobalParamPush(CodeStream &os, const std::string &type, const std::string &name,
-                                         VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const override;
-    virtual void genExtraGlobalParamPull(CodeStream &os, const std::string &type, const std::string &name,
-                                         VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const override;
-
-    //! Generate code for pushing an updated EGP value into the merged group structure on 'device'
-    virtual void genMergedExtraGlobalParamPush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx,
-                                               const std::string &groupIdx, const std::string &fieldName,
-                                               const std::string &egpName) const override;
+    //! Generate code to define a variable in the appropriate header file
+    virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal,
+                                       const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name,
+                                       VarLocation loc) const final;
+
+    //! Generate code to instantiate a variable in the provided stream
+    virtual void genVariableInstantiation(CodeStream &os,
+                                          const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name,
+                                          VarLocation loc) const final;
+
+    //! Generate code to allocate variable with a size known at compile-time
+    virtual void genVariableAllocation(CodeStream &os,
+                                       const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name,
+                                       VarLocation loc, size_t count, MemAlloc &memAlloc) const final;
+
+    //! Generate code to allocate variable with a size known at runtime
+    virtual void genVariableDynamicAllocation(CodeStream &os,
+                                              const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name,
+                                              VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final;
+
+    //! Generate code to free a variable
+    virtual void genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const final;
+
+    //! Generate code for pushing a variable with a size known at compile-time to the 'device'
+    virtual void genVariablePush(CodeStream &os,
+                                 const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name,
+                                 VarLocation loc, bool autoInitialized, size_t count) const final;
+
+    //! Generate code for pulling a variable with a size known at compile-time from the 'device'
+    virtual void genVariablePull(CodeStream &os,
+                                 const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name,
+                                 VarLocation loc, size_t count) const final;
+
+    //! Generate code for pushing a variable's value in the current timestep to the 'device'
+    virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng,
+                                        const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name,
+                                        VarLocation loc, unsigned int batchSize) const final;
+
+    //! Generate code for pulling a variable's value in the current timestep from the 'device'
+    virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng,
+                                        const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name,
+                                        VarLocation loc, unsigned int batchSize) const final;
+
+    //! Generate code for pushing a variable with a size known at runtime to the 'device'
+    virtual void genVariableDynamicPush(CodeStream &os,
+                                        const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name,
+                                        VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final;
+
+    //! Generate code for pulling a variable with a size known at runtime from the 'device'
+    virtual void genVariableDynamicPull(CodeStream &os,
+                                        const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name,
+                                        VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final;
+
+    //! Generate code for pushing a new pointer to a dynamic variable into the merged group structure on 'device'
+    virtual void genMergedDynamicVariablePush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx,
+                                              const std::string &groupIdx, const std::string &fieldName,
+                                              const std::string &egpName) const final;
 
     //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host
     virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const override;
 
     //! When generating merged structures what type to use for simulation RNGs
-    virtual const Type::Base *getMergedGroupSimRNGType() const override;
-
-    virtual void genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const override;
-    virtual void genVariablePull(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, size_t count) const override;
-
-    virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, const std::string &type,
-                                        const std::string &name, VarLocation loc, unsigned int batchSize) const override;
-    virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, const std::string &type,
-                                        const std::string &name, VarLocation loc, unsigned int batchSize) const override;
-
-    virtual void genCurrentTrueSpikePush(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize) const override
-    {
-        genCurrentSpikePush(os, ng, batchSize, false);
-    }
-    virtual void genCurrentTrueSpikePull(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize) const override
-    {
-        genCurrentSpikePull(os, ng, batchSize, false);
-    }
-    virtual void genCurrentSpikeLikeEventPush(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize) const override
-    {
-        genCurrentSpikePush(os, ng, batchSize, true);
-    }
-    virtual void genCurrentSpikeLikeEventPull(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize) const override
-    {
-        genCurrentSpikePull(os, ng, batchSize, true);
-    }
+    virtual const Type::ValueBase *getMergedGroupSimRNGType() const override;
 
     virtual void genGlobalDeviceRNG(CodeStream &definitions, CodeStream &definitionsInternal,
                                     CodeStream &runner, CodeStream &allocations, CodeStream &free,
@@ -336,9 +350,6 @@ class BACKEND_EXPORT Backend : public BackendSIMT
         return m_ChosenDevice.totalConstMem - getPreferences().constantCacheOverhead;
     }
 
-    void genCurrentSpikePush(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize, bool spikeEvent) const;
-    void genCurrentSpikePull(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize, bool spikeEvent) const;
-
     void genKernelDimensions(CodeStream &os, Kernel kernel, size_t numThreadsX, size_t batchSize, size_t numBlockThreadsY = 1) const;
 
     //--------------------------------------------------------------------------
diff --git a/include/genn/backends/single_threaded_cpu/optimiser.h b/include/genn/backends/single_threaded_cpu/optimiser.h
index 0cf261a68b..112602553e 100644
--- a/include/genn/backends/single_threaded_cpu/optimiser.h
+++ b/include/genn/backends/single_threaded_cpu/optimiser.h
@@ -19,7 +19,7 @@ class IAppender;
 //--------------------------------------------------------------------------
 namespace GeNN::CodeGenerator::SingleThreadedCPU::Optimiser
 {
-BACKEND_EXPORT Backend createBackend(const filesystem::path &outputPath,
+BACKEND_EXPORT Backend createBackend(const ModelSpecInternal &model, const filesystem::path &outputPath,
                                      plog::Severity backendLevel, plog::IAppender *backendAppender,
                                      const Preferences &preferences);
 }   // namespace GeNN::CodeGenerator::SingleThreadedCPU::Optimiser
diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h
index 8449d85753..8d7fbeb640 100644
--- a/include/genn/genn/code_generator/backendSIMT.h
+++ b/include/genn/genn/code_generator/backendSIMT.h
@@ -91,7 +91,8 @@ class GENN_EXPORT BackendSIMT : public BackendBase
     virtual std::string getCLZ() const = 0;
 
     //! Get name of atomic operation
-    virtual std::string getAtomic(const Type::NumericBase *type, AtomicOperation op = AtomicOperation::ADD,
+    virtual std::string getAtomic(const Type::NumericBase *type, const Type::TypeContext &typeContext,
+                                  AtomicOperation op = AtomicOperation::ADD,
                                   AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const = 0;
 
     //! Generate a shared memory barrier
diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc
index d51d339fd9..40d4d76aea 100644
--- a/src/genn/backends/cuda/backend.cc
+++ b/src/genn/backends/cuda/backend.cc
@@ -44,10 +44,18 @@ const std::vector cudaDoublePrecisionFunctions
     {"gennrand_binomial", 2, "binomialDistDouble($(rng), $(0), $(1))"}
 };
 
+//--------------------------------------------------------------------------
+// CUDADeviceType
+//--------------------------------------------------------------------------
+//! Tag class used to mark types which are only usable on device
+struct CUDADeviceType
+{
+};
+
 //--------------------------------------------------------------------------
 // CURandState
 //--------------------------------------------------------------------------
-class CURandState : public Type::ValueBase
+class CURandState : public Type::ValueBase, public CUDADeviceType
 {
 public:
     DECLARE_TYPE(CURandState);
@@ -58,27 +66,29 @@ class CURandState : public Type::ValueBase
     // Base overloads
     //------------------------------------------------------------------------
     virtual std::string getName() const final{ return "curandState"; }
+    virtual std::string getResolvedName(const Type::TypeContext&) const final{ return "curandState"; }
     virtual Base *getQualifiedType(Type::Qualifier qualifiers) const { return new CURandState(qualifiers); }
-    virtual size_t getSizeBytes() const final{ return 44; }
+    virtual size_t getSizeBytes(const Type::TypeContext&) const final{ return 44; }
 };
 IMPLEMENT_TYPE(CURandState);
 
 //--------------------------------------------------------------------------
 // CURandStatePhilox43210
 //--------------------------------------------------------------------------
-class CURandStatePhilox43210 : public Type::ValueBase
+class CURandStatePhilox43210 : public Type::ValueBase, public CUDADeviceType
 {
 public:
     DECLARE_TYPE(CURandStatePhilox43210);
-    CURandStatePhilox43210(Type::Qualifier qualifiers = Type::Qualifier{0}) : ValueBasese(qualifiers){}
+    CURandStatePhilox43210(Type::Qualifier qualifiers = Type::Qualifier{0}) : ValueBase(qualifiers){}
 
     //------------------------------------------------------------------------
     // Base overloads
     //------------------------------------------------------------------------
     virtual std::string getName() const final{ return "curandStatePhilox4_32_10_t"; }
+    virtual std::string getResolvedName(const Type::TypeContext&) const final{ return "curandStatePhilox4_32_10_t"; }
     virtual Base *getQualifiedType(Type::Qualifier qualifiers) const { return new CURandStatePhilox43210(qualifiers); }
-    virtual size_t getSizeBytes() const final{ return 64; }
+    virtual size_t getSizeBytes(const Type::TypeContext&) const final{ return 64; }
 };
 IMPLEMENT_TYPE(CURandStatePhilox43210);
 
@@ -233,9 +243,9 @@ size_t getGroupStartIDSize(const std::vector &mergedGroups)
     });
 }
 //-----------------------------------------------------------------------
-const std::vector &getFunctionTemplates(const std::string &precision)
+const std::vector &getFunctionTemplates(const Type::NumericBase *precision)
 {
-    return (precision == "double") ? cudaDoublePrecisionFunctions : cudaSinglePrecisionFunctions;
+    return (precision->getName() == Type::Double::getInstance()->getName()) ? cudaDoublePrecisionFunctions : cudaSinglePrecisionFunctions;
 }
 //-----------------------------------------------------------------------
 std::string getNCCLReductionType(VarAccessMode mode)
@@ -322,9 +332,8 @@ void genNCCLReduction(CodeStream &os, const G &cg, const std::string &precision)
 //--------------------------------------------------------------------------
 namespace GeNN::CodeGenerator::CUDA
 {
-Backend::Backend(const KernelBlockSize &kernelBlockSizes, const Preferences &preferences,
-                 const std::string &scalarType, int device)
-:   BackendSIMT(kernelBlockSizes, preferences, scalarType), m_ChosenDeviceID(device)
+Backend::Backend(const KernelBlockSize &kernelBlockSizes, const Preferences &preferences, int device)
+:   BackendSIMT(kernelBlockSizes, preferences), m_ChosenDeviceID(device)
 {
     // Set device
     CHECK_CUDA_ERRORS(cudaSetDevice(device));
@@ -388,12 +397,14 @@ std::string Backend::getBlockID(unsigned int axis) const
     }
 }
 //--------------------------------------------------------------------------
-std::string Backend::getAtomic(const std::string &type, AtomicOperation op, AtomicMemSpace) const
+std::string Backend::getAtomic(const Type::NumericBase *type, const Type::TypeContext &typeContext,
+                               AtomicOperation op, AtomicMemSpace) const
 {
     // If operation is an atomic add
+    const std::string typeName = type->getResolvedName(typeContext);
     if(op == AtomicOperation::ADD) {
-        if(((getChosenCUDADevice().major < 2) && (type == "float"))
-           || (((getChosenCUDADevice().major < 6) || (getRuntimeVersion() < 8000)) && (type == "double")))
+        if(((getChosenCUDADevice().major < 2) && (typeName == Type::Float::getInstance()->getName()))
+           || (((getChosenCUDADevice().major < 6) || (getRuntimeVersion() < 8000)) && (typeName == Type::Double::getInstance()->getName())))
         {
             return "atomicAddSW";
         }
@@ -403,7 +414,7 @@ std::string Backend::getAtomic(const std::string &type, AtomicOperation op, Atom
     // Otherwise, it's an atomic or
     else {
         assert(op == AtomicOperation::OR);
-        assert(type == "unsigned int" || type == "int");
+        assert(typeName == Type::Uint32::getInstance()->getName() || typeName == Type::Int32::getInstance()->getName());
         return "atomicOr";
     }
 }
@@ -604,11 +615,11 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge
     // If there are any presynaptic update groups
     size_t idPresynapticStart = 0;
     if(!modelMerged.getMergedPresynapticUpdateGroups().empty()) {
-        os << "extern \"C\" __global__ void " << KernelNames[KernelPresynapticUpdate] << "(" << model.getTimePrecision() << " t)" << std::endl; // end of synapse kernel header
+        os << "extern \"C\" __global__ void " << KernelNames[KernelPresynapticUpdate] << "(" << model.getTimePrecision()->getName() << " t)" << std::endl; // end of synapse kernel header
         {
             CodeStream::Scope b(os);
 
-            Substitutions kernelSubs((model.getPrecision() == "double") ? cudaDoublePrecisionFunctions : cudaSinglePrecisionFunctions);
+            Substitutions kernelSubs(getFunctionTemplates(model.getPrecision()));
             kernelSubs.addVarSubstitution("t", "t");
 
             os << "const unsigned int id = " << getKernelBlockSize(KernelPresynapticUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl;
@@ -626,11 +637,11 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge
     // If any synapse groups require postsynaptic learning
     size_t idPostsynapticStart = 0;
     if(!modelMerged.getMergedPostsynapticUpdateGroups().empty()) {
-        os << "extern \"C\" __global__ void " << KernelNames[KernelPostsynapticUpdate] << "(" << model.getTimePrecision() << " t)" << std::endl;
+        os << "extern \"C\" __global__ void " << KernelNames[KernelPostsynapticUpdate] << "(" << model.getTimePrecision()->getName() << " t)" << std::endl;
         {
             CodeStream::Scope b(os);
 
-            Substitutions kernelSubs((model.getPrecision() == "double") ? cudaDoublePrecisionFunctions : cudaSinglePrecisionFunctions);
+            Substitutions kernelSubs(getFunctionTemplates(model.getPrecision()));
             kernelSubs.addVarSubstitution("t", "t");
 
             os << "const unsigned int id = " << getKernelBlockSize(KernelPostsynapticUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl;
@@ -647,7 +658,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge
     size_t idSynapseDynamicsStart = 0;
     if(!modelMerged.getMergedSynapseDynamicsGroups().empty()) {
-        os << "extern \"C\" __global__ void " << KernelNames[KernelSynapseDynamicsUpdate] << "(" << model.getTimePrecision() << " t)" << std::endl; // end of synapse kernel header
+        os << "extern \"C\" __global__ void " << KernelNames[KernelSynapseDynamicsUpdate] << "(" << model.getTimePrecision()->getName() << " t)" << std::endl; // end of synapse kernel header
         {
             CodeStream::Scope b(os);
 
@@ -666,7 +677,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge
         }
     }
 
-    os << "void updateSynapses(" << model.getTimePrecision() << " t)";
+    os << "void updateSynapses(" << model.getTimePrecision()->getName() << " t)";
     {
         CodeStream::Scope b(os);
 
@@ -778,7 +789,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
         {
             CodeStream::Scope b(os);
 
-            Substitutions kernelSubs((model.getPrecision() == "double") ? cudaDoublePrecisionFunctions : cudaSinglePrecisionFunctions);
+            Substitutions kernelSubs(getFunctionTemplates(model.getPrecision()));
             kernelSubs.addVarSubstitution("t", "t");
 
             os << "const unsigned int id = " << getKernelBlockSize(KernelCustomUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl;
@@ -809,7 +820,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
         {
             CodeStream::Scope b(os);
 
-            Substitutions kernelSubs((model.getPrecision() == "double") ?
cudaDoublePrecisionFunctions : cudaSinglePrecisionFunctions); + Substitutions kernelSubs(getFunctionTemplates(model.getPrecision())); kernelSubs.addVarSubstitution("t", "t"); os << "const unsigned int id = " << getKernelBlockSize(KernelCustomTransposeUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; @@ -1555,74 +1566,66 @@ void Backend::genStepTimeFinalisePreamble(CodeStream &os, const ModelSpecMerged } } //-------------------------------------------------------------------------- -void Backend::genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, const std::string &type, const std::string &name, VarLocation loc) const +void Backend::genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc) const { - const bool deviceType = isDeviceType(type); - - if(getPreferences().automaticCopy && ::Utils::isTypePointer(type)) { + const bool deviceType = dynamic_cast<const CUDADeviceType*>(type); + CodeStream &d = deviceType ? definitionsInternal : definitions; + const std::string pointerTypeName = type->getPointerType()->getResolvedName(typeContext); + if(getPreferences().automaticCopy) { // Export pointer, either in definitionsInternal if variable has a device type // or to definitions if it should be accessible on host - CodeStream &d = deviceType ? definitionsInternal : definitions; - d << "EXPORT_VAR " << type << " " << name << ";" << std::endl; + d << "EXPORT_VAR " << pointerTypeName << " " << name << ";" << std::endl; } else { if(loc & VarLocation::HOST) { if(deviceType) { - throw std::runtime_error("Variable '" + name + "' is of device-only type '" + type + "' but is located on the host"); + throw std::runtime_error("Variable '" + name + "' is of device-only type '" + pointerTypeName + "' but is located on the host"); } - definitions << "EXPORT_VAR " << type << " " << name << ";" << std::endl; + definitions << "EXPORT_VAR " << pointerTypeName << " " << name << ";" << std::endl; } if(loc & VarLocation::DEVICE) { - // If the type is a pointer type we need a device pointer - if(::Utils::isTypePointer(type)) { - // Write host definition to internal definitions stream if type is device only - CodeStream &d = deviceType ?
definitionsInternal : definitions; - d << "EXPORT_VAR " << type << " d_" << name << ";" << std::endl; - } - // Otherwise we just need a device variable, made volatile for safety - else { - definitionsInternal << "EXPORT_VAR __device__ volatile " << type << " d_" << name << ";" << std::endl; - } + // Write host definition to internal definitions stream if type is device only + d << "EXPORT_VAR " << pointerTypeName << " d_" << name << ";" << std::endl; + } } - - } //-------------------------------------------------------------------------- -void Backend::genVariableImplementation(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc) const +void Backend::genVariableInstantiation(CodeStream &os, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc) const { - if(getPreferences().automaticCopy && ::Utils::isTypePointer(type)) { - os << type << " " << name << ";" << std::endl; + const std::string pointerTypeName = type->getPointerType()->getResolvedName(typeContext); + if(getPreferences().automaticCopy) { + os << pointerTypeName << " " << name << ";" << std::endl; } else { if(loc & VarLocation::HOST) { - os << type << " " << name << ";" << std::endl; + os << pointerTypeName << " " << name << ";" << std::endl; } if(loc & VarLocation::DEVICE) { - // If the type is a pointer type we need a host and a device pointer - if(::Utils::isTypePointer(type)) { - os << type << " d_" << name << ";" << std::endl; - } - // Otherwise we just need a device variable, made volatile for safety - else { - os << "__device__ volatile " << type << " d_" << name << ";" << std::endl; - } + os << pointerTypeName << " d_" << name << ";" << std::endl; } } } //-------------------------------------------------------------------------- -void Backend::genVariableAllocation(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const +void Backend::genVariableAllocation(CodeStream &os, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, size_t count, MemAlloc &memAlloc) const { + const std::string typeName = type->getResolvedName(typeContext); if(getPreferences().automaticCopy) { - os << "CHECK_CUDA_ERRORS(cudaMallocManaged(&" << name << ", " << count << " * sizeof(" << type << ")));" << std::endl; - memAlloc += MemAlloc::device(count * getSize(type)); + os << "CHECK_CUDA_ERRORS(cudaMallocManaged(&" << name << ", " << count << " * sizeof(" << typeName << ")));" << std::endl; + memAlloc += MemAlloc::device(count * type->getSizeBytes(typeContext)); } else { if(loc & VarLocation::HOST) { const char *flags = (loc & VarLocation::ZERO_COPY) ?
"cudaHostAllocMapped" : "cudaHostAllocPortable"; - os << "CHECK_CUDA_ERRORS(cudaHostAlloc(&" << name << ", " << count << " * sizeof(" << type << "), " << flags << "));" << std::endl; - memAlloc += MemAlloc::host(count * getSize(type)); + os << "CHECK_CUDA_ERRORS(cudaHostAlloc(&" << name << ", " << count << " * sizeof(" << typeName << "), " << flags << "));" << std::endl; + memAlloc += MemAlloc::host(count * type->getSizeBytes(typeContext)); } // If variable is present on device at all @@ -1630,82 +1633,33 @@ void Backend::genVariableAllocation(CodeStream &os, const std::string &type, con // Insert call to correct helper depending on whether variable should be allocated in zero-copy mode or not if(loc & VarLocation::ZERO_COPY) { os << "CHECK_CUDA_ERRORS(cudaHostGetDevicePointer((void **)&d_" << name << ", (void *)" << name << ", 0));" << std::endl; - memAlloc += MemAlloc::zeroCopy(count * getSize(type)); + memAlloc += MemAlloc::zeroCopy(count * type->getSizeBytes(typeContext)); } else { - os << "CHECK_CUDA_ERRORS(cudaMalloc(&d_" << name << ", " << count << " * sizeof(" << type << ")));" << std::endl; - memAlloc += MemAlloc::device(count * getSize(type)); + os << "CHECK_CUDA_ERRORS(cudaMalloc(&d_" << name << ", " << count << " * sizeof(" << typeName << ")));" << std::endl; + memAlloc += MemAlloc::device(count * type->getSizeBytes(typeContext)); } } } } //-------------------------------------------------------------------------- -void Backend::genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const -{ - if(getPreferences().automaticCopy) { - os << "CHECK_CUDA_ERRORS(cudaFree(" << name << "));" << std::endl; - } - else { - // **NOTE** because we pinned the variable we need to free it with cudaFreeHost rather than use the host code generator - if(loc & VarLocation::HOST) { - os << "CHECK_CUDA_ERRORS(cudaFreeHost(" << name << "));" << std::endl; - } - - // If this variable wasn't allocated in zero-copy mode, free it - if((loc & VarLocation::DEVICE) && !(loc & VarLocation::ZERO_COPY)) { - os << "CHECK_CUDA_ERRORS(cudaFree(d_" << name << "));" << std::endl; - } - } -} -//-------------------------------------------------------------------------- -void Backend::genExtraGlobalParamDefinition(CodeStream &definitions, CodeStream &, - const std::string &type, const std::string &name, VarLocation loc) const -{ - if(getPreferences().automaticCopy) { - definitions << "EXPORT_VAR " << type << " " << name << ";" << std::endl; - } - else { - if(loc & VarLocation::HOST) { - definitions << "EXPORT_VAR " << type << " " << name << ";" << std::endl; - } - if(loc & VarLocation::DEVICE && ::Utils::isTypePointer(type)) { - definitions << "EXPORT_VAR " << type << " d_" << name << ";" << std::endl; - } - } -} -//-------------------------------------------------------------------------- -void Backend::genExtraGlobalParamImplementation(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc) const -{ - if(getPreferences().automaticCopy) { - os << type << " " << name << ";" << std::endl; - } - else { - if(loc & VarLocation::HOST) { - os << type << " " << name << ";" << std::endl; - } - if(loc & VarLocation::DEVICE && ::Utils::isTypePointer(type)) { - os << type << " d_" << name << ";" << std::endl; - } - } -} -//-------------------------------------------------------------------------- -void Backend::genExtraGlobalParamAllocation(CodeStream &os, const std::string &type, const std::string &name, - VarLocation loc, const std::string &countVarName, const std::string &prefix) 
const +void Backend::genVariableDynamicAllocation(CodeStream &os, + const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, const std::string &countVarName, const std::string &prefix) const { - // Get underlying type - const std::string underlyingType = ::Utils::getUnderlyingType(type); - const bool pointerToPointer = ::Utils::isTypePointerToPointer(type); - - const std::string hostPointer = pointerToPointer ? ("*" + prefix + name) : (prefix + name); - const std::string hostPointerToPointer = pointerToPointer ? (prefix + name) : ("&" + prefix + name); - const std::string devicePointerToPointer = pointerToPointer ? (prefix + "d_" + name) : ("&" + prefix + "d_" + name); + const auto *pointerType = dynamic_cast<const Type::Pointer*>(type); + const auto *underlyingType = pointerType ? pointerType->getValueType() : type; + const std::string underlyingTypeName = underlyingType->getResolvedName(typeContext); + const std::string hostPointer = pointerType ? ("*" + prefix + name) : (prefix + name); + const std::string hostPointerToPointer = pointerType ? (prefix + name) : ("&" + prefix + name); + const std::string devicePointerToPointer = pointerType ? (prefix + "d_" + name) : ("&" + prefix + "d_" + name); if(getPreferences().automaticCopy) { - os << "CHECK_CUDA_ERRORS(cudaMallocManaged(" << hostPointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType << ")));" << std::endl; + os << "CHECK_CUDA_ERRORS(cudaMallocManaged(" << hostPointerToPointer << ", " << countVarName << " * sizeof(" << underlyingTypeName << ")));" << std::endl; } else { if(loc & VarLocation::HOST) { const char *flags = (loc & VarLocation::ZERO_COPY) ? "cudaHostAllocMapped" : "cudaHostAllocPortable"; - os << "CHECK_CUDA_ERRORS(cudaHostAlloc(" << hostPointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType << "), " << flags << "));" << std::endl; + os << "CHECK_CUDA_ERRORS(cudaHostAlloc(" << hostPointerToPointer << ", " << countVarName << " * sizeof(" << underlyingTypeName << "), " << flags << "));" << std::endl; } // If variable is present on device at all @@ -1714,71 +1668,33 @@ void Backend::genExtraGlobalParamAllocation(CodeStream &os, const std::string &t os << "CHECK_CUDA_ERRORS(cudaHostGetDevicePointer((void**)" << devicePointerToPointer << ", (void*)" << hostPointer << ", 0));" << std::endl; } else { - os << "CHECK_CUDA_ERRORS(cudaMalloc(" << devicePointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType << ")));" << std::endl; + os << "CHECK_CUDA_ERRORS(cudaMalloc(" << devicePointerToPointer << ", " << countVarName << " * sizeof(" << underlyingTypeName << ")));" << std::endl; } } } } //-------------------------------------------------------------------------- -void Backend::genExtraGlobalParamPush(CodeStream &os, const std::string &type, const std::string &name, - VarLocation loc, const std::string &countVarName, const std::string &prefix) const +void Backend::genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const { - assert(!getPreferences().automaticCopy); - - if(!(loc & VarLocation::ZERO_COPY)) { - // Get underlying type - const std::string underlyingType = ::Utils::getUnderlyingType(type); - const bool pointerToPointer = ::Utils::isTypePointerToPointer(type); - - const std::string hostPointer = pointerToPointer ? ("*" + prefix + name) : (prefix + name); - const std::string devicePointer = pointerToPointer ?
("*" + prefix + "d_" + name) : (prefix + "d_" + name); - - os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << devicePointer; - os << ", " << hostPointer; - os << ", " << countVarName << " * sizeof(" << underlyingType << "), cudaMemcpyHostToDevice));" << std::endl; + if(getPreferences().automaticCopy) { + os << "CHECK_CUDA_ERRORS(cudaFree(" << name << "));" << std::endl; } -} -//-------------------------------------------------------------------------- -void Backend::genExtraGlobalParamPull(CodeStream &os, const std::string &type, const std::string &name, - VarLocation loc, const std::string &countVarName, const std::string &prefix) const -{ - assert(!getPreferences().automaticCopy); - - if(!(loc & VarLocation::ZERO_COPY)) { - // Get underlying type - const std::string underlyingType = ::Utils::getUnderlyingType(type); - const bool pointerToPointer = ::Utils::isTypePointerToPointer(type); - - const std::string hostPointer = pointerToPointer ? ("*" + prefix + name) : (prefix + name); - const std::string devicePointer = pointerToPointer ? ("*" + prefix + "d_" + name) : (prefix + "d_" + name); + else { + // **NOTE** because we pinned the variable we need to free it with cudaFreeHost rather than use the host code generator + if(loc & VarLocation::HOST) { + os << "CHECK_CUDA_ERRORS(cudaFreeHost(" << name << "));" << std::endl; + } - os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << hostPointer; - os << ", " << devicePointer; - os << ", " << countVarName << " * sizeof(" << underlyingType << "), cudaMemcpyDeviceToHost));" << std::endl; + // If this variable wasn't allocated in zero-copy mode, free it + if((loc & VarLocation::DEVICE) && !(loc & VarLocation::ZERO_COPY)) { + os << "CHECK_CUDA_ERRORS(cudaFree(d_" << name << "));" << std::endl; + } } } //-------------------------------------------------------------------------- -void Backend::genMergedExtraGlobalParamPush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, - const std::string &groupIdx, const std::string &fieldName, - const std::string &egpName) const -{ - const std::string structName = "Merged" + suffix + "Group" + std::to_string(mergedGroupIdx); - os << "CHECK_CUDA_ERRORS(cudaMemcpyToSymbolAsync(d_merged" << suffix << "Group" << mergedGroupIdx; - os << ", &" << egpName << ", sizeof(" << egpName << ")"; - os << ", (sizeof(" << structName << ") * (" << groupIdx << ")) + offsetof(" << structName << ", " << fieldName << ")));" << std::endl; -} -//-------------------------------------------------------------------------- -std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const -{ - return type->getResolvedName(context); -} -//-------------------------------------------------------------------------- -const Type::Base *Backend::getMergedGroupSimRNGType() const -{ - return CLRRNGLFSR113Stream::getInstance(); -} -//-------------------------------------------------------------------------- -void Backend::genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const +void Backend::genVariablePush(CodeStream &os, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, bool autoInitialized, size_t count) const { assert(!getPreferences().automaticCopy); @@ -1790,7 +1706,7 @@ void Backend::genVariablePush(CodeStream &os, const std::string &type, const std os << "CHECK_CUDA_ERRORS(cudaMemcpy(d_" << name; os << ", " << name; - os << ", " << count << " * 
sizeof(" << type << "), cudaMemcpyHostToDevice));" << std::endl; + os << ", " << count << " * sizeof(" << type->getResolvedName(typeContext) << "), cudaMemcpyHostToDevice));" << std::endl; if(autoInitialized) { os << CodeStream::CB(1101); @@ -1798,74 +1714,144 @@ void Backend::genVariablePush(CodeStream &os, const std::string &type, const std } } //-------------------------------------------------------------------------- -void Backend::genVariablePull(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, size_t count) const +void Backend::genVariablePull(CodeStream &os, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, size_t count) const { assert(!getPreferences().automaticCopy); if(!(loc & VarLocation::ZERO_COPY)) { os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << name; os << ", d_" << name; - os << ", " << count << " * sizeof(" << type << "), cudaMemcpyDeviceToHost));" << std::endl; + os << ", " << count << " * sizeof(" << type->getResolvedName(typeContext) << "), cudaMemcpyDeviceToHost));" << std::endl; } } //-------------------------------------------------------------------------- -void Backend::genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, const std::string &type, - const std::string &name, VarLocation loc, unsigned int batchSize) const +void Backend::genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, unsigned int batchSize) const { assert(!getPreferences().automaticCopy); // If this variable requires queuing and isn't zero-copy if(ng.isVarQueueRequired(name) && ng.isDelayRequired() && !(loc & VarLocation::ZERO_COPY)) { // If batch size is one, generate 1D memcpy to copy current timestep's data + const std::string typeName = type->getResolvedName(typeContext); if(batchSize == 1) { os << "CHECK_CUDA_ERRORS(cudaMemcpy(d_" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; os << ", " << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() << " * sizeof(" << type << "), cudaMemcpyHostToDevice));" << std::endl; + os << ", " << ng.getNumNeurons() << " * sizeof(" << typeName << "), cudaMemcpyHostToDevice));" << std::endl; } // Otherwise, perform a 2D memcpy to copy current timestep's data from each batch else { os << "CHECK_CUDA_ERRORS(cudaMemcpy2D(d_" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type << ")"; + os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << typeName << ")"; os << ", " << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type << ")"; - os << ", " << ng.getNumNeurons() << " * sizeof(" << type << ")"; + os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << typeName << ")"; + os << ", " << ng.getNumNeurons() << " * sizeof(" << typeName << ")"; os << ", " << batchSize << ", cudaMemcpyHostToDevice));" << std::endl; } } // Otherwise, generate standard push else { - genVariablePush(os, type, name + ng.getName(), loc, false, ng.getNumNeurons() * batchSize); + genVariablePush(os, type, typeContext, name + ng.getName(), loc, 
false, ng.getNumNeurons() * batchSize); } } //-------------------------------------------------------------------------- -void Backend::genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, const std::string &type, - const std::string &name, VarLocation loc, unsigned int batchSize) const +void Backend::genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, + const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, unsigned int batchSize) const { assert(!getPreferences().automaticCopy); // If this variable requires queuing and isn't zero-copy if(ng.isVarQueueRequired(name) && ng.isDelayRequired() && !(loc & VarLocation::ZERO_COPY)) { // If batch size is one, generate 1D memcpy to copy current timestep's data + const std::string typeName = type->getResolvedName(typeContext); if(batchSize == 1) { os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; os << ", d_" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() << " * sizeof(" << type << "), cudaMemcpyDeviceToHost));" << std::endl; + os << ", " << ng.getNumNeurons() << " * sizeof(" << typeName << "), cudaMemcpyDeviceToHost));" << std::endl; } else { os << "CHECK_CUDA_ERRORS(cudaMemcpy2D(" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type << ")"; + os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << typeName << ")"; os << ", d_" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type << ")"; - os << ", " << ng.getNumNeurons() << " * sizeof(" << type << ")"; + os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << typeName << ")"; + os << ", " << ng.getNumNeurons() << " * sizeof(" << typeName << ")"; os << ", " << batchSize << ", cudaMemcpyDeviceToHost));" << std::endl; } } // Otherwise, generate standard pull else { - genVariablePull(os, type, name + ng.getName(), loc, ng.getNumNeurons() * batchSize); + genVariablePull(os, type, typeContext, name + ng.getName(), loc, ng.getNumNeurons() * batchSize); } } //-------------------------------------------------------------------------- +void Backend::genVariableDynamicPush(CodeStream &os, + const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, const std::string &countVarName, const std::string &prefix) const +{ + assert(!getPreferences().automaticCopy); + + if(!(loc & VarLocation::ZERO_COPY)) { + const auto *pointerType = dynamic_cast<const Type::Pointer*>(type); + if (pointerType) { + os << "CHECK_CUDA_ERRORS(cudaMemcpy(*" << prefix << "d_" << name; + os << ", *" << prefix << name; + os << ", " << countVarName << " * sizeof(" << pointerType->getValueType()->getResolvedName(typeContext) << "), cudaMemcpyHostToDevice));" << std::endl; + } + else { + os << prefix << name << " = new " << type->getResolvedName(typeContext) << "[" << countVarName << "];" << std::endl; + os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << prefix << "d_" << name; + os << ", " << prefix << name; + os << ", " << countVarName << " * sizeof(" << type->getResolvedName(typeContext) << "), cudaMemcpyHostToDevice));" << std::endl; + } + } +}
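// A minimal standalone sketch of the pointer/value dispatch used by
// genVariableDynamicPush above (and genVariableDynamicPull below): the element
// size for the generated cudaMemcpy comes from the pointee type when the
// dynamic variable is an array, and from the type itself otherwise. The Type
// classes here are illustrative stand-ins, not the real GeNN::Type hierarchy:
#include <cstddef>
struct Type { virtual ~Type() = default; virtual size_t getSizeBytes() const = 0; };
struct Float : Type { virtual size_t getSizeBytes() const final { return 4; } };
struct Pointer : Type {
    explicit Pointer(const Type *valueType) : m_ValueType(valueType) {}
    virtual size_t getSizeBytes() const final { return sizeof(void*); }
    const Type *getValueType() const { return m_ValueType; }
private:
    const Type *m_ValueType;
};
// Element size to use when copying a dynamic variable between host and device
inline size_t getCopyElementSize(const Type *type) {
    // Arrays copy their pointees; scalar dynamic variables copy the value itself
    if(const auto *pointerType = dynamic_cast<const Pointer*>(type)) {
        return pointerType->getValueType()->getSizeBytes();
    }
    return type->getSizeBytes();
}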
+//-------------------------------------------------------------------------- +void Backend::genVariableDynamicPull(CodeStream &os, + const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, + VarLocation loc, const std::string &countVarName, const std::string &prefix) const +{ + assert(!getPreferences().automaticCopy); + + if(!(loc & VarLocation::ZERO_COPY)) { + const auto *pointerType = dynamic_cast<const Type::Pointer*>(type); + if (pointerType) { + os << "CHECK_CUDA_ERRORS(cudaMemcpy(*" << prefix << name; + os << ", *" << prefix << "d_" << name; + os << ", " << countVarName << " * sizeof(" << pointerType->getValueType()->getResolvedName(typeContext) << "), cudaMemcpyDeviceToHost));" << std::endl; + } + else { + os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << prefix << name; + os << ", " << prefix << "d_" << name; + os << ", " << countVarName << " * sizeof(" << type->getResolvedName(typeContext) << "), cudaMemcpyDeviceToHost));" << std::endl; + } + + } +} +//-------------------------------------------------------------------------- +void Backend::genMergedDynamicVariablePush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, + const std::string &groupIdx, const std::string &fieldName, + const std::string &egpName) const +{ + const std::string structName = "Merged" + suffix + "Group" + std::to_string(mergedGroupIdx); + os << "CHECK_CUDA_ERRORS(cudaMemcpyToSymbolAsync(d_merged" << suffix << "Group" << mergedGroupIdx; + os << ", &" << egpName << ", sizeof(" << egpName << ")"; + os << ", (sizeof(" << structName << ") * (" << groupIdx << ")) + offsetof(" << structName << ", " << fieldName << ")));" << std::endl; +} +//-------------------------------------------------------------------------- +std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const +{ + return type->getResolvedName(context); +} +//-------------------------------------------------------------------------- +const Type::ValueBase *Backend::getMergedGroupSimRNGType() const +{ + return CURandState::getInstance(); +} +//-------------------------------------------------------------------------- void Backend::genGlobalDeviceRNG(CodeStream &, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &, CodeStream &, MemAlloc &memAlloc) const { // Define global Philox RNG @@ -1875,11 +1861,11 @@ void Backend::genGlobalDeviceRNG(CodeStream &, CodeStream &definitionsInternal, // Implement global Philox RNG runner << "__device__ curandStatePhilox4_32_10_t d_rng;" << std::endl; - memAlloc += MemAlloc::device(getSize("curandStatePhilox4_32_10_t")); + memAlloc += MemAlloc::device(CURandStatePhilox43210::getInstance()->getSizeBytes({})); } //-------------------------------------------------------------------------- void Backend::genPopulationRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const std::string &name, size_t count, MemAlloc &memAlloc) const + const std::string &name, size_t count, MemAlloc &memAlloc) const { // Create an array of XORWOW RNGs genArray(definitions, definitionsInternal, runner, allocations, free, "curandState", name, VarLocation::DEVICE, count, memAlloc); @@ -2115,161 +2101,6 @@ std::string Backend::getNVCCFlags() const return nvccFlags; } //-------------------------------------------------------------------------- -void Backend::genCurrentSpikePush(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize, bool spikeEvent) const -{ -
assert(!getPreferences().automaticCopy); - - if(!(ng.getSpikeLocation() & VarLocation::ZERO_COPY)) { - // Is delay required - const bool delayRequired = spikeEvent ? - ng.isDelayRequired() : - (ng.isTrueSpikeRequired() && ng.isDelayRequired()); - - const char *spikeCntPrefix = spikeEvent ? "glbSpkCntEvnt" : "glbSpkCnt"; - const char *spikePrefix = spikeEvent ? "glbSpkEvnt" : "glbSpk"; - - if (delayRequired) { - // If there's only a single batch - if(batchSize == 1) { - // Copy spike count for current timestep - os << "CHECK_CUDA_ERRORS(cudaMemcpy(d_" << spikeCntPrefix << ng.getName() << " + spkQuePtr" << ng.getName(); - os << ", " << spikeCntPrefix << ng.getName() << " + spkQuePtr" << ng.getName(); - os << ", sizeof(unsigned int), cudaMemcpyHostToDevice));" << std::endl; - - // Copy this many spikes from current timestep - os << "CHECK_CUDA_ERRORS(cudaMemcpy(d_" << spikePrefix << ng.getName() << " + (spkQuePtr" << ng.getName() << "*" << ng.getNumNeurons() << ")"; - os << ", " << spikePrefix << ng.getName(); - os << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << spikeCntPrefix << ng.getName() << "[spkQuePtr" << ng.getName() << "] * sizeof(unsigned int), cudaMemcpyHostToDevice));" << std::endl; - } - else { - // Copy spike count for current timestep from each batch using 2D memcpy - os << "CHECK_CUDA_ERRORS(cudaMemcpy2D(d_" << spikeCntPrefix << ng.getName() << " + spkQuePtr" << ng.getName(); - os << ", " << ng.getNumDelaySlots() << " * sizeof(unsigned int)"; - os << ", " << spikeCntPrefix << ng.getName() << " + spkQuePtr" << ng.getName(); - os << ", " << ng.getNumDelaySlots() << " * sizeof(unsigned int)"; - os << ", sizeof(unsigned int), " << batchSize << ", cudaMemcpyHostToDevice));" << std::endl; - - // Loop through batches and launch asynchronous memcpys to copy spikes from each one - os << "for(unsigned int b = 0; b < " << batchSize << "; b++)"; - { - CodeStream::Scope b(os); - os << "const unsigned int spikeOffset = (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ") + (b * " << (ng.getNumNeurons() * ng.getNumDelaySlots()) << ");" << std::endl; - os << "CHECK_CUDA_ERRORS(cudaMemcpyAsync(d_" << spikePrefix << ng.getName() << " + spikeOffset"; - os << ", " << spikePrefix << ng.getName() << " + spikeOffset"; - os << ", " << spikeCntPrefix << ng.getName() << "[spkQuePtr" << ng.getName() << " + (b * " << ng.getNumDelaySlots() << ")] * sizeof(unsigned int)"; - os << ", cudaMemcpyHostToDevice)); " << std::endl; - } - - // Wait until queued copies have completed - os << "CHECK_CUDA_ERRORS(cudaStreamSynchronize(0));" << std::endl; - } - } - else { - // Copy the spike count for each batch - os << "CHECK_CUDA_ERRORS(cudaMemcpy(d_" << spikeCntPrefix << ng.getName(); - os << ", " << spikeCntPrefix << ng.getName(); - os << ", " << batchSize << " * sizeof(unsigned int), cudaMemcpyHostToDevice));" << std::endl; - - // If there's only a single batch, copy spikes - if(batchSize == 1) { - os << "CHECK_CUDA_ERRORS(cudaMemcpy(d_" << spikePrefix << ng.getName(); - os << ", " << spikePrefix << ng.getName(); - os << ", " << spikeCntPrefix << ng.getName() << "[0] * sizeof(unsigned int), cudaMemcpyHostToDevice));" << std::endl; - } - // Otherwise, loop through batches and launch asynchronous memcpys to copy spikes from each one - else { - os << "for(unsigned int b = 0; b < " << batchSize << "; b++)"; - { - CodeStream::Scope b(os); - os << "CHECK_CUDA_ERRORS(cudaMemcpyAsync(d_" << spikePrefix << ng.getName() << " + (b * " << ng.getNumNeurons() << ")"; - os << 
", " << spikePrefix << ng.getName() << " + (b * " << ng.getNumNeurons() << ")"; - os << ", " << spikeCntPrefix << ng.getName() << "[b] * sizeof(unsigned int), cudaMemcpyHostToDevice));" << std::endl; - } - - // Wait until queued copies have completed - os << "CHECK_CUDA_ERRORS(cudaStreamSynchronize(0));" << std::endl; - } - } - } -} -//-------------------------------------------------------------------------- -void Backend::genCurrentSpikePull(CodeStream &os, const NeuronGroupInternal &ng, unsigned int batchSize, bool spikeEvent) const -{ - if(!(ng.getSpikeLocation() & VarLocation::ZERO_COPY)) { - // Is delay required - const bool delayRequired = spikeEvent ? - ng.isDelayRequired() : - (ng.isTrueSpikeRequired() && ng.isDelayRequired()); - - const char *spikeCntPrefix = spikeEvent ? "glbSpkCntEvnt" : "glbSpkCnt"; - const char *spikePrefix = spikeEvent ? "glbSpkEvnt" : "glbSpk"; - - if (delayRequired) { - // If there's only a single batch - if(batchSize == 1) { - // Copy spike count for current timestep - os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << spikeCntPrefix << ng.getName() << " + spkQuePtr" << ng.getName(); - os << ", d_" << spikeCntPrefix << ng.getName() << " + spkQuePtr" << ng.getName(); - os << ", sizeof(unsigned int), cudaMemcpyDeviceToHost));" << std::endl; - - // Copy this many spikes from current timestep - os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << spikePrefix << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", d_" << spikePrefix << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << spikeCntPrefix << ng.getName() << "[spkQuePtr" << ng.getName() << "] * sizeof(unsigned int), cudaMemcpyDeviceToHost));" << std::endl; - } - else { - // Copy spike count for current timestep from each batch using 2D memcpy - os << "CHECK_CUDA_ERRORS(cudaMemcpy2D(" << spikeCntPrefix << ng.getName() << " + spkQuePtr" << ng.getName(); - os << ", " << ng.getNumDelaySlots() << " * sizeof(unsigned int)"; - os << ", d_" << spikeCntPrefix << ng.getName() << " + spkQuePtr" << ng.getName(); - os << ", " << ng.getNumDelaySlots() << " * sizeof(unsigned int)"; - os << ", sizeof(unsigned int), " << batchSize << ", cudaMemcpyDeviceToHost));" << std::endl; - - // Loop through batches and launch asynchronous memcpys to copy spikes from each one - os << "for(unsigned int b = 0; b < " << batchSize << "; b++)"; - { - CodeStream::Scope b(os); - os << "const unsigned int spikeOffset = (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ") + (b * " << (ng.getNumNeurons() * ng.getNumDelaySlots()) << ");" << std::endl; - os << "CHECK_CUDA_ERRORS(cudaMemcpyAsync(" << spikePrefix << ng.getName() << " + spikeOffset"; - os << ", d_" << spikePrefix << ng.getName() << " + spikeOffset"; - os << ", " << spikeCntPrefix << ng.getName() << "[spkQuePtr" << ng.getName() << " + (b * " << ng.getNumDelaySlots() << ")] * sizeof(unsigned int)"; - os << ", cudaMemcpyDeviceToHost)); " << std::endl; - } - - // Wait until queued copies have completed - os << "CHECK_CUDA_ERRORS(cudaStreamSynchronize(0));" << std::endl; - } - } - else { - // Copy the spike count for each batch - os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << spikeCntPrefix << ng.getName(); - os << ", d_" << spikeCntPrefix << ng.getName(); - os << ", " << batchSize << " * sizeof(unsigned int), cudaMemcpyDeviceToHost));" << std::endl; - - // If there's only a single batch, copy spikes - if(batchSize == 1) { - os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << spikePrefix << ng.getName(); - os << 
", d_" << spikePrefix << ng.getName(); - os << ", " << spikeCntPrefix << ng.getName() << "[0] * sizeof(unsigned int), cudaMemcpyDeviceToHost));" << std::endl; - } - // Otherwise, loop through batches and launch asynchronous memcpys to copy spikes from each one - else { - os << "for(unsigned int b = 0; b < " << batchSize << "; b++)"; - { - CodeStream::Scope b(os); - os << "CHECK_CUDA_ERRORS(cudaMemcpyAsync(" << spikePrefix << ng.getName() << " + (b * " << ng.getNumNeurons() << ")"; - os << ", d_" << spikePrefix << ng.getName() << " + (b * " << ng.getNumNeurons() << ")"; - os << ", " << spikeCntPrefix << ng.getName() << "[b] * sizeof(unsigned int), cudaMemcpyDeviceToHost));" << std::endl; - } - - // Wait until queued copies have completed - os << "CHECK_CUDA_ERRORS(cudaStreamSynchronize(0));" << std::endl; - } - } - } -} -//-------------------------------------------------------------------------- void Backend::genKernelDimensions(CodeStream &os, Kernel kernel, size_t numThreadsX, size_t batchSize, size_t numBlockThreadsY) const { // Calculate grid size diff --git a/src/genn/backends/cuda/optimiser.cc b/src/genn/backends/cuda/optimiser.cc index 0917d5743c..f13d8bef99 100644 --- a/src/genn/backends/cuda/optimiser.cc +++ b/src/genn/backends/cuda/optimiser.cc @@ -763,7 +763,7 @@ Backend createBackend(const ModelSpecInternal &model, const filesystem::path &ou const int deviceID = chooseOptimalDevice(model, cudaBlockSize, preferences, outputPath); // Create backend - return Backend(cudaBlockSize, preferences, model.getPrecision(), deviceID); + return Backend(cudaBlockSize, preferences, deviceID); } // Otherwise else { @@ -782,11 +782,11 @@ Backend createBackend(const ModelSpecInternal &model, const filesystem::path &ou optimizeBlockSize(deviceID, deviceProps, model, cudaBlockSize, preferences, outputPath); // Create backend - return Backend(cudaBlockSize, preferences, model.getPrecision(), deviceID); + return Backend(cudaBlockSize, preferences, deviceID); } // Otherwise, create backend using manual block sizes specified in preferences else { - return Backend(preferences.manualBlockSizes, preferences, model.getPrecision(), deviceID); + return Backend(preferences.manualBlockSizes, preferences, deviceID); } } diff --git a/src/genn/backends/single_threaded_cpu/optimiser.cc b/src/genn/backends/single_threaded_cpu/optimiser.cc index 03c8ca69db..288e0ba45a 100644 --- a/src/genn/backends/single_threaded_cpu/optimiser.cc +++ b/src/genn/backends/single_threaded_cpu/optimiser.cc @@ -5,7 +5,7 @@ //-------------------------------------------------------------------------- namespace GeNN::CodeGenerator::SingleThreadedCPU::Optimiser { -Backend createBackend(const filesystem::path&, +Backend createBackend(const ModelSpecInternal&,const filesystem::path&, plog::Severity backendLevel, plog::IAppender *backendAppender, const Preferences &preferences) { From 69adaa1c21b64eba382bfeee454e8a0bcc0a3b81 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 19 Jan 2023 22:13:25 +0000 Subject: [PATCH 079/725] fixed atomic implementation --- .../genn/genn/code_generator/backendSIMT.h | 7 ++-- src/genn/genn/code_generator/backendSIMT.cc | 38 +++++++++---------- .../presynapticUpdateStrategySIMT.cc | 34 ++++++++--------- 3 files changed, 40 insertions(+), 39 deletions(-) diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 8d7fbeb640..824921bd27 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ 
-161,10 +161,11 @@ class GENN_EXPORT BackendSIMT : public BackendBase //! Helper to get name of atomic operation template<typename T> - std::string getAtomic(AtomicOperation op = AtomicOperation::ADD, + std::string getAtomic(const Type::TypeContext &typeContext, + AtomicOperation op = AtomicOperation::ADD, AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const { - return getAtomic(T::getInstance(), op, memSpace); + return getAtomic(T::getInstance(), typeContext, op, memSpace); } //-------------------------------------------------------------------------- @@ -453,7 +454,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase } } - void genEmitSpike(CodeStream &os, const Substitutions &subs, const std::string &suffix, bool recordingEnabled) const; + void genEmitSpike(const ModelSpecMerged &modelMerged, CodeStream &os, const Substitutions &subs, const std::string &suffix, bool recordingEnabled) const; void genRecordingSharedMemInit(CodeStream &os, const std::string &suffix) const; diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index b8b081600e..863334e187 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -498,14 +498,14 @@ void BackendSIMT::genNeuronUpdateKernel(CodeStream &os, const Substitutions &ker ng.generateNeuronUpdate(*this, os, modelMerged, popSubs, // Emit true spikes - [this](CodeStream &neuronUpdateKernelsBody, const NeuronUpdateGroupMerged &ng, Substitutions &subs) + [&modelMerged, this](CodeStream &neuronUpdateKernelsBody, const NeuronUpdateGroupMerged &ng, Substitutions &subs) { - genEmitSpike(neuronUpdateKernelsBody, subs, "", ng.getArchetype().isSpikeRecordingEnabled()); + genEmitSpike(modelMerged, neuronUpdateKernelsBody, subs, "", ng.getArchetype().isSpikeRecordingEnabled()); }, // Emit spike-like events - [this](CodeStream &neuronUpdateKernelsBody, const NeuronUpdateGroupMerged &ng, Substitutions &subs) + [&modelMerged, this](CodeStream &neuronUpdateKernelsBody, const NeuronUpdateGroupMerged &ng, Substitutions &subs) { - genEmitSpike(neuronUpdateKernelsBody, subs, "Evnt", ng.getArchetype().isSpikeEventRecordingEnabled()); + genEmitSpike(modelMerged, neuronUpdateKernelsBody, subs, "Evnt", ng.getArchetype().isSpikeEventRecordingEnabled()); }); // Copy local stream back to local @@ -523,7 +523,7 @@ void BackendSIMT::genNeuronUpdateKernel(CodeStream &os, const Substitutions &ker os << "if (shSpkEvntCount > 0)"; { CodeStream::Scope b(os); - os << "shPosSpkEvnt = " << getAtomic<Type::Uint32>() << "(&group->spkCntEvnt"; + os << "shPosSpkEvnt = " << getAtomic<Type::Uint32>(modelMerged.getTypeContext()) << "(&group->spkCntEvnt"; if(ng.getArchetype().isDelayRequired()) { os << "[*group->spkQuePtr"; if(batchSize > 1) { @@ -546,7 +546,7 @@ void BackendSIMT::genNeuronUpdateKernel(CodeStream &os, const Substitutions &ker os << "if (shSpkCount > 0)"; { CodeStream::Scope b(os); - os << "shPosSpk = " << getAtomic<Type::Uint32>() << "(&group->spkCnt"; + os << "shPosSpk = " << getAtomic<Type::Uint32>(modelMerged.getTypeContext()) << "(&group->spkCnt"; if(ng.getArchetype().isDelayRequired() && ng.getArchetype().isTrueSpikeRequired()) { os << "[*group->spkQuePtr"; if(batchSize > 1) { @@ -814,7 +814,7 @@ void BackendSIMT::genPostsynapticUpdateKernel(CodeStream &os, const Substitution if(sg.getArchetype().isPresynapticOutputRequired()) { synSubs.addFuncSubstitution("addToPre", 1, - getAtomic(modelMerged.getModel().getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); +
getAtomic(modelMerged.getModel().getPrecision(), modelMerged.getTypeContext()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); } sg.generateSynapseUpdate(*this, os, modelMerged, synSubs); @@ -875,16 +875,16 @@ void BackendSIMT::genSynapseDynamicsKernel(CodeStream &os, const Substitutions & // If dendritic delay is required, always use atomic operation to update dendritic delay buffer // **TODO** once synapse dynamics gets refactored into update strategy classes, move the index building code elsewhere if(sg.getArchetype().isDendriticDelayRequired()) { - synSubs.addFuncSubstitution("addToInSynDelay", 2, getAtomic(modelMerged.getModel().getPrecision()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, synSubs["id_post"], "$(1)") + "], $(0))"); + synSubs.addFuncSubstitution("addToInSynDelay", 2, getAtomic(modelMerged.getModel().getPrecision(), modelMerged.getTypeContext()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, synSubs["id_post"], "$(1)") + "], $(0))"); } // Otherwise else { - synSubs.addFuncSubstitution("addToInSyn", 1, getAtomic(modelMerged.getModel().getPrecision()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, synSubs["id_post"]) + "], $(0))"); + synSubs.addFuncSubstitution("addToInSyn", 1, getAtomic(modelMerged.getModel().getPrecision(), modelMerged.getTypeContext()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, synSubs["id_post"]) + "], $(0))"); } if(sg.getArchetype().isPresynapticOutputRequired()) { synSubs.addFuncSubstitution("addToPre", 1, - getAtomic(modelMerged.getModel().getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); + getAtomic(modelMerged.getModel().getPrecision(), modelMerged.getTypeContext()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); } sg.generateSynapseUpdate(*this, os, modelMerged, synSubs); @@ -1564,7 +1564,7 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne } // Otherwise else { - kernelInit << "group->ind[(($(0)) * group->rowStride) + " << getAtomic<Type::Uint32>() << +"(&group->rowLength[$(0)], 1)] = " << popSubs["id_post"] << ";"; + kernelInit << "group->ind[(($(0)) * group->rowStride) + " << getAtomic<Type::Uint32>(modelMerged.getTypeContext()) << +"(&group->rowLength[$(0)], 1)] = " << popSubs["id_post"] << ";"; } } // Otherwise, if it's bitmask @@ -1575,12 +1575,12 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne // If there is row-building code in this snippet if(!snippet->getRowBuildCode().empty()) { kernelInit << "const " << indexType << " rowStartGID = " << popSubs["id"] << " * (" << indexType << ")group->rowStride;" << std::endl; - kernelInit << getAtomic<Type::Uint32>(AtomicOperation::OR) << "(&group->gp[(rowStartGID + ($(0))) / 32], 0x80000000 >> ((rowStartGID + ($(0))) & 31));" << std::endl; + kernelInit << getAtomic<Type::Uint32>(modelMerged.getTypeContext(), AtomicOperation::OR) << "(&group->gp[(rowStartGID + ($(0))) / 32], 0x80000000 >> ((rowStartGID + ($(0))) & 31));" << std::endl; } // Otherwise else { kernelInit << "const " << indexType << " colStartGID = " << popSubs["id"] << ";" << std::endl; - kernelInit << getAtomic<Type::Uint32>(AtomicOperation::OR) << "(&group->gp[(colStartGID + (($(0)) * group->rowStride)) / 32], 0x80000000 >> ((colStartGID + (($(0)) * group->rowStride)) & 31));" << std::endl; + kernelInit << getAtomic<Type::Uint32>(modelMerged.getTypeContext(), AtomicOperation::OR) << "(&group->gp[(colStartGID + (($(0)) * group->rowStride)) / 32], 0x80000000 >> ((colStartGID
+ (($(0)) * group->rowStride)) & 31));" << std::endl; } } } @@ -1645,7 +1645,7 @@ void BackendSIMT::genInitializeSparseKernel(CodeStream &os, const Substitutions // Generate sparse synapse variable initialisation code genSparseSynapseVarInit( os, modelMerged, sg, popSubs, sg.getArchetype().isWUVarInitRequired(), - [this](CodeStream &os, const SynapseSparseInitGroupMerged &sg, Substitutions&) + [&modelMerged, this](CodeStream &os, const SynapseSparseInitGroupMerged &sg, Substitutions&) { // If postsynaptic learning is required if(!sg.getArchetype().getWUModel()->getLearnPostCode().empty()) { @@ -1656,7 +1656,7 @@ void BackendSIMT::genInitializeSparseKernel(CodeStream &os, const Substitutions // Atomically increment length of column of connectivity associated with this target // **NOTE** this returns previous length i.e. where to insert new entry - os << "const unsigned int colLocation = " << getAtomic<Type::Uint32>() << "(&group->colLength[postIndex], 1);" << std::endl; + os << "const unsigned int colLocation = " << getAtomic<Type::Uint32>(modelMerged.getTypeContext()) << "(&group->colLength[postIndex], 1);" << std::endl; // From this calculate index into column-major matrix os << "const unsigned int colMajorIndex = (postIndex * group->colStride) + colLocation;" << std::endl; @@ -1709,18 +1709,18 @@ size_t BackendSIMT::padKernelSize(size_t size, Kernel kernel) const return padSize(size, getKernelBlockSize(kernel)); } //-------------------------------------------------------------------------- -void BackendSIMT::genEmitSpike(CodeStream &os, const Substitutions &subs, const std::string &suffix, bool recordingEnabled) const +void BackendSIMT::genEmitSpike(const ModelSpecMerged &modelMerged, CodeStream &os, const Substitutions &subs, const std::string &suffix, bool recordingEnabled) const { - os << "const unsigned int spk" << suffix << "Idx = " << getAtomic<Type::Uint32>(AtomicOperation::ADD, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Count, 1);" << std::endl; + os << "const unsigned int spk" << suffix << "Idx = " << getAtomic<Type::Uint32>(modelMerged.getTypeContext(), AtomicOperation::ADD, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Count, 1);" << std::endl; os << "shSpk" << suffix << "[spk" << suffix << "Idx] = " << subs["id"] << ";" << std::endl; // If recording is enabled, set bit in recording word if(recordingEnabled) { if(m_KernelBlockSizes[KernelNeuronUpdate] == 32) { - os << getAtomic<Type::Uint32>(AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record, 1 << " << getThreadID() << ");" << std::endl; + os << getAtomic<Type::Uint32>(modelMerged.getTypeContext(), AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record, 1 << " << getThreadID() << ");" << std::endl; } else { - os << getAtomic<Type::Uint32>(AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record[" << getThreadID() << " / 32], 1 << (" << getThreadID() << " % 32));" << std::endl; + os << getAtomic<Type::Uint32>(modelMerged.getTypeContext(), AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record[" << getThreadID() << " / 32], 1 << (" << getThreadID() << " % 32));" << std::endl; } } } diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index 5dc11aabe2..56b56fc816 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -156,12 +156,12 @@ void PreSpan::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, cons // If dendritic delay is required, use
atomic operation to update dendritic delay buffer if(sg.getArchetype().isDendriticDelayRequired()) { synSubs.addFuncSubstitution("addToInSynDelay", 2, - backend.getAtomic(model.getPrecision()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, "ipost", "$(1)") + "], $(0))"); + backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, "ipost", "$(1)") + "], $(0))"); } // Otherwise, substitute global memory array for $(inSyn) else { synSubs.addFuncSubstitution("addToInSyn", 1, - backend.getAtomic(model.getPrecision()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, "ipost") + "], $(0))"); + backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, "ipost") + "], $(0))"); } if(sg.getArchetype().isPresynapticOutputRequired()) { @@ -184,7 +184,7 @@ void PreSpan::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, cons // Should this be in the Postamble? if(sg.getArchetype().isPresynapticOutputRequired()) { // write lrevInSyn to global memory if not 0 - os << "if(lrevInSyn != 0.0) " << backend.getAtomic(model.getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, "preInd") + "], lrevInSyn);" << std::endl; + os << "if(lrevInSyn != 0.0) " << backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, "preInd") + "], lrevInSyn);" << std::endl; } } @@ -343,7 +343,7 @@ void PostSpan::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, con // If dendritic delay is required, always use atomic operation to update dendritic delay buffer if(sg.getArchetype().isDendriticDelayRequired()) { synSubs.addFuncSubstitution("addToInSynDelay", 2, - backend.getAtomic(model.getPrecision()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, synSubs["id_post"], "$(1)") + "], $(0))"); + backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, synSubs["id_post"], "$(1)") + "], $(0))"); } // Otherwise else { @@ -359,13 +359,13 @@ void PostSpan::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, con // Otherwise, use global memory atomic else { synSubs.addFuncSubstitution("addToInSyn", 1, - backend.getAtomic(model.getPrecision()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, synSubs["id_post"]) + "], $(0))"); + backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, synSubs["id_post"]) + "], $(0))"); } } if(sg.getArchetype().isPresynapticOutputRequired()) { synSubs.addFuncSubstitution("addToPre", 1, - backend.getAtomic(model.getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); + backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); } if(trueSpike) { @@ -403,7 +403,7 @@ void PostSpan::genPostamble(CodeStream &os, const ModelSpecMerged &modelMerged, CodeStream::Scope b(os); const std::string inSyn = "group->inSyn[" + sg.getPostISynIndex(batchSize, popSubs["id"]) + "]"; if(sg.getArchetype().isPSModelFused()) { - os << backend.getAtomic(model.getPrecision()) << "(&" << inSyn << ", linSyn);" << std::endl; + os << backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) << "(&" << inSyn << ", linSyn);" << std::endl; } else { os << inSyn << " += 
linSyn;" << std::endl; @@ -416,7 +416,7 @@ void PostSpan::genPostamble(CodeStream &os, const ModelSpecMerged &modelMerged, os << "if(" << backend.getThreadID() << " < group->numTrgNeurons)"; { CodeGenerator::CodeStream::Scope b(os); - os << backend.getAtomic(model.getPrecision()) << "(&group->inSyn[" << sg.getPostISynIndex(batchSize, backend.getThreadID()) << "], "; + os << backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) << "(&group->inSyn[" << sg.getPostISynIndex(batchSize, backend.getThreadID()) << "], "; os << "shLg[" << backend.getThreadID() << "]); " << std::endl; } } @@ -577,12 +577,12 @@ void PreSpanProcedural::genUpdate(CodeStream &os, const ModelSpecMerged &modelMe // If dendritic delay is required, use atomic operation to update dendritic delay buffer if(sg.getArchetype().isDendriticDelayRequired()) { presynapticUpdateSubs.addFuncSubstitution("addToInSynDelay", 2, - backend.getAtomic(model.getPrecision()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, "$(id_post)", "$(1)") + "], $(0))"); + backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, "$(id_post)", "$(1)") + "], $(0))"); } // Otherwise, substitute global memory array for $(inSyn) else { presynapticUpdateSubs.addFuncSubstitution("addToInSyn", 1, - backend.getAtomic(model.getPrecision()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, "$(id_post)") + "], $(0))"); + backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, "$(id_post)") + "], $(0))"); } if(sg.getArchetype().isPresynapticOutputRequired()) { @@ -612,7 +612,7 @@ void PreSpanProcedural::genUpdate(CodeStream &os, const ModelSpecMerged &modelMe // Should this be in the Postamble? 
if(sg.getArchetype().isPresynapticOutputRequired()) { // write lrevInSyn to global memory if not 0 - os << "if(lrevInSyn != 0.0) " << backend.getAtomic(model.getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, "preInd") + "], lrevInSyn);" << std::endl; + os << "if(lrevInSyn != 0.0) " << backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, "preInd") + "], lrevInSyn);" << std::endl; } } @@ -753,7 +753,7 @@ void PostSpanBitmask::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerg if(sg.getArchetype().isPresynapticOutputRequired()) { synSubs.addFuncSubstitution("addToPre", 1, - backend.getAtomic(modelMerged.getModel().getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); + backend.getAtomic(modelMerged.getModel().getPrecision(), modelMerged.getTypeContext()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); } if(trueSpike) { @@ -793,7 +793,7 @@ void PostSpanBitmask::genPostamble(CodeStream &os, const ModelSpecMerged &modelM CodeStream::Scope b(os); const std::string inSyn = "group->inSyn[" + sg.getPostISynIndex(modelMerged.getModel().getBatchSize(), "glbIdx") +"]"; if(sg.getArchetype().isPSModelFused()) { - os << backend.getAtomic(modelMerged.getModel().getPrecision()) << "(&" << inSyn << ", shLg[shIdx]);" << std::endl; + os << backend.getAtomic(modelMerged.getModel().getPrecision(), modelMerged.getTypeContext()) << "(&" << inSyn << ", shLg[shIdx]);" << std::endl; } else { os << inSyn << " += shLg[shIdx];" << std::endl; @@ -936,7 +936,7 @@ void PostSpanToeplitz::genUpdate(CodeStream &os, const ModelSpecMerged &modelMer // If dendritic delay is required, always use atomic operation to update dendritic delay buffer if(sg.getArchetype().isDendriticDelayRequired()) { presynapticUpdateSubs.addFuncSubstitution("addToInSynDelay", 2, - backend.getAtomic(model.getPrecision()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, "$(id_post)", "$(1)") + "], $(0))"); + backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, "$(id_post)", "$(1)") + "], $(0))"); } // Otherwise else { @@ -947,13 +947,13 @@ void PostSpanToeplitz::genUpdate(CodeStream &os, const ModelSpecMerged &modelMer // Otherwise, use global memory atomic else { presynapticUpdateSubs.addFuncSubstitution("addToInSyn", 1, - backend.getAtomic(model.getPrecision()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, "$(id_post)") + "], $(0))"); + backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, "$(id_post)") + "], $(0))"); } } if(sg.getArchetype().isPresynapticOutputRequired()) { presynapticUpdateSubs.addFuncSubstitution("addToPre", 1, - backend.getAtomic(model.getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, presynapticUpdateSubs["id_pre"]) + "], $(0))"); + backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, presynapticUpdateSubs["id_pre"]) + "], $(0))"); } // Generate presynaptic simulation code into new stringstream-backed code stream @@ -989,7 +989,7 @@ void PostSpanToeplitz::genPostamble(CodeStream &os, const ModelSpecMerged &model os << "if(" << backend.getThreadID() << " < group->numTrgNeurons)"; { CodeGenerator::CodeStream::Scope b(os); - os << backend.getAtomic(modelMerged.getModel().getPrecision()); + 
os << backend.getAtomic(modelMerged.getModel().getPrecision(), modelMerged.getTypeContext()); os << "(&group->inSyn[" << sg.getPostISynIndex(modelMerged.getModel().getBatchSize(), backend.getThreadID()) << "], "; os << "shLg[" << backend.getThreadID() << "]); " << std::endl; }
From ad32d643a88a3f5855de3f2899c6a1b84f696972 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 19 Jan 2023 22:24:04 +0000 Subject: [PATCH 080/725] way nicer expression visitor implementation --- include/genn/genn/transpiler/expression.h | 65 +++++++++-------------- src/genn/genn/genn.vcxproj | 1 - src/genn/genn/transpiler/expression.cc | 22 -------- 3 files changed, 26 insertions(+), 62 deletions(-) delete mode 100644 src/genn/genn/transpiler/expression.cc
diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h index 53b77fa472..d37c49d6ee 100644 --- a/include/genn/genn/transpiler/expression.h +++ b/include/genn/genn/transpiler/expression.h @@ -27,21 +27,32 @@ class Base virtual void accept(Visitor &visitor) const = 0; }; +//--------------------------------------------------------------------------- +// GeNN::Transpiler::Expression::Acceptable +//--------------------------------------------------------------------------- +template<class T> +class Acceptable : public Base +{ +public: + void accept(Visitor &visitor) const final + { + visitor.visit(static_cast<const T&>(*this)); + } +}; + typedef std::unique_ptr<Base const> ExpressionPtr; typedef std::vector<ExpressionPtr> ExpressionList; //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::ArraySubscript //--------------------------------------------------------------------------- -class ArraySubscript : public Base +class ArraySubscript : public Acceptable<ArraySubscript> { public: ArraySubscript(Token pointerName, ExpressionPtr index) : m_PointerName(pointerName), m_Index(std::move(index)) {} - virtual void accept(Visitor &visitor) const final; - const Token &getPointerName() const { return m_PointerName; } const ExpressionPtr &getIndex() const { return m_Index; } @@ -53,15 +64,13 @@ class ArraySubscript : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::Assignment //--------------------------------------------------------------------------- -class Assignment : public Base +class Assignment : public Acceptable<Assignment> { public: Assignment(Token varName, Token op, ExpressionPtr value) : m_VarName(varName), m_Operator(op), m_Value(std::move(value)) {} - virtual void accept(Visitor &visitor) const final; - const Token &getVarName() const { return m_VarName; } const Token &getOperator() const { return m_Operator; } const Base *getValue() const { return m_Value.get(); } @@ -75,15 +84,13 @@ class Assignment : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::Binary //--------------------------------------------------------------------------- -class Binary : public Base +class Binary : public Acceptable<Binary> { public: Binary(ExpressionPtr left, Token op, ExpressionPtr right) : m_Left(std::move(left)), m_Operator(op), m_Right(std::move(right)) {} - virtual void accept(Visitor &visitor) const final; - const Base *getLeft() const { return m_Left.get(); } const Token &getOperator() const { return m_Operator; } const Base *getRight() const { return m_Right.get(); } @@ -97,15 +104,13 @@ class Binary : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::Call //--------------------------------------------------------------------------- -class Call : public Base +class Call : public Acceptable<Call> { public: Call(ExpressionPtr callee, Token closingParen, ExpressionList arguments) : m_Callee(std::move(callee)), m_ClosingParen(closingParen), m_Arguments(std::move(arguments)) {} - virtual void accept(Visitor &visitor) const final; - const Base *getCallee() const { return m_Callee.get(); } const Token &getClosingParen() const { return m_ClosingParen; } const ExpressionList &getArguments() const { return m_Arguments; } @@ -119,15 +124,13 @@ class Call : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::Cast //--------------------------------------------------------------------------- -class Cast : public Base +class Cast : public Acceptable<Cast> { public: Cast(const Type::Base *type, ExpressionPtr expression, Token closingParen) : m_Type(type), m_Expression(std::move(expression)), m_ClosingParen(closingParen) {} - virtual void accept(Visitor &visitor) const final; - const Type::Base *getType() const{ return m_Type; } const Base *getExpression() const { return m_Expression.get(); } const Token &getClosingParen() const { return m_ClosingParen; } @@ -141,15 +144,13 @@ class Cast : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::Conditional //--------------------------------------------------------------------------- -class Conditional : public Base +class Conditional : public Acceptable<Conditional> { public: Conditional(ExpressionPtr condition, Token question, ExpressionPtr trueExpression, ExpressionPtr falseExpression) : m_Condition(std::move(condition)), m_Question(question), m_True(std::move(trueExpression)), m_False(std::move(falseExpression)) {} - virtual void accept(Visitor &visitor) const final; - const Base *getCondition() const { return m_Condition.get(); } const Token &getQuestion() const { return m_Question; } const Base *getTrue() const { return m_True.get(); } @@ -165,15 +166,13 @@ class Conditional : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::Grouping //--------------------------------------------------------------------------- -class Grouping : public Base +class Grouping : public Acceptable<Grouping> { public: Grouping(ExpressionPtr expression) : m_Expression(std::move(expression)) {} - virtual void accept(Visitor &visitor) const final; - const Base *getExpression() const { return m_Expression.get(); } private: @@ -183,15 +182,13 @@ class Grouping : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::Literal //--------------------------------------------------------------------------- -class Literal : public Base +class Literal : public Acceptable<Literal> { public: Literal(Token value) : m_Value(value) {} - virtual void accept(Visitor &visitor) const final; - Token getValue() const { return m_Value; } private: @@ -201,15 +198,13 @@ class Literal : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::Logical //--------------------------------------------------------------------------- -class Logical : public Base +class Logical : public Acceptable<Logical> { public: Logical(ExpressionPtr left, Token op, ExpressionPtr right) : m_Left(std::move(left)), m_Operator(op), m_Right(std::move(right)) {} - virtual void accept(Visitor &visitor) const final; - const Base *getLeft() const { return m_Left.get(); } const Token &getOperator() const { return m_Operator; } const Base *getRight() const { return m_Right.get(); } @@ -223,15 +218,13 @@ class Logical : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::PostfixIncDec //--------------------------------------------------------------------------- -class PostfixIncDec : public Base +class PostfixIncDec : public Acceptable<PostfixIncDec> { public: PostfixIncDec(Token varName, Token op) : m_VarName(varName), m_Operator(op) {} - virtual void accept(Visitor &visitor) const final; - const Token &getVarName() const { return m_VarName; } const Token &getOperator() const { return m_Operator; } @@ -243,15 +236,13 @@ class PostfixIncDec : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::PrefixIncDec //--------------------------------------------------------------------------- -class PrefixIncDec : public Base +class PrefixIncDec : public Acceptable<PrefixIncDec> { public: PrefixIncDec(Token varName, Token op) : m_VarName(varName), m_Operator(op) {} - virtual void accept(Visitor &visitor) const final; - const Token &getVarName() const { return m_VarName; } const Token &getOperator() const { return m_Operator; } @@ -263,15 +254,13 @@ class PrefixIncDec : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::Variable //--------------------------------------------------------------------------- -class Variable : public Base +class Variable : public Acceptable<Variable> { public: Variable(Token name) : m_Name(name) {} - virtual void accept(Visitor &visitor) const final; - const Token &getName() const { return m_Name; } private: @@ -281,15 +270,13 @@ class Variable : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::Unary //--------------------------------------------------------------------------- -class Unary : public Base +class Unary : public Acceptable<Unary> { public: Unary(Token op, ExpressionPtr right) : m_Operator(op), m_Right(std::move(right)) {} - virtual void accept(Visitor &visitor) const final; - const Token &getOperator() const { return m_Operator; } const Base *getRight() const { return m_Right.get(); }
diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj index aa133b54ce..3dbb61b56a 100644 --- a/src/genn/genn/genn.vcxproj +++ b/src/genn/genn/genn.vcxproj @@ -56,7 +56,6 @@ - <ClCompile Include="transpiler\expression.cc" />
diff --git a/src/genn/genn/transpiler/expression.cc b/src/genn/genn/transpiler/expression.cc deleted file mode 100644 index e3ea06e688..0000000000 --- a/src/genn/genn/transpiler/expression.cc +++ /dev/null @@ -1,22 +0,0 @@ -#include "transpiler/expression.h" - -#define IMPLEMENT_ACCEPT(CLASS_NAME) \ - void GeNN::Transpiler::Expression::CLASS_NAME::accept(Visitor &visitor) const \ - { \ - visitor.visit(*this); \ - } - - -IMPLEMENT_ACCEPT(ArraySubscript) -IMPLEMENT_ACCEPT(Assignment) -IMPLEMENT_ACCEPT(Binary) -IMPLEMENT_ACCEPT(Call) -IMPLEMENT_ACCEPT(Cast) -IMPLEMENT_ACCEPT(Conditional) -IMPLEMENT_ACCEPT(Grouping) -IMPLEMENT_ACCEPT(Literal) -IMPLEMENT_ACCEPT(Logical) -IMPLEMENT_ACCEPT(PrefixIncDec) -IMPLEMENT_ACCEPT(PostfixIncDec) -IMPLEMENT_ACCEPT(Variable) -IMPLEMENT_ACCEPT(Unary) \ No newline at end of file
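//---------------------------------------------------------------------------
// Aside (editor's sketch, not part of the patch series): the patch above
// replaces thirteen hand-written accept() overrides (and the IMPLEMENT_ACCEPT
// macro in the deleted .cc file) with one CRTP "Acceptable" template; the
// next patch applies the same approach to statements. A minimal,
// self-contained illustration of the pattern, using hypothetical
// Shape/Circle/Square names rather than GeNN's AST classes:
//---------------------------------------------------------------------------
#include <iostream>

class Circle;   // forward declarations so the visitor can name each subclass
class Square;

struct Visitor
{
    virtual void visit(const Circle&) = 0;
    virtual void visit(const Square&) = 0;
};

struct Shape
{
    virtual ~Shape() = default;
    virtual void accept(Visitor &visitor) const = 0;
};

// One class template generates every accept() override; the static_cast is
// safe because T is always the concrete class deriving from Acceptable<T>
template<class T>
struct Acceptable : public Shape
{
    virtual void accept(Visitor &visitor) const final
    {
        visitor.visit(static_cast<const T&>(*this));
    }
};

struct Circle : public Acceptable<Circle> {};
struct Square : public Acceptable<Square> {};

struct Printer : public Visitor
{
    virtual void visit(const Circle&) final { std::cout << "circle" << std::endl; }
    virtual void visit(const Square&) final { std::cout << "square" << std::endl; }
};

int main()
{
    Circle circle;
    Printer printer;
    circle.accept(printer);   // dispatches to Printer::visit(const Circle&)
}
// Design note: each accept() body is stamped out by the compiler, so adding a
// new node class needs only "class X : public Acceptable<X>" plus a visit()
// overload - there is no separate implementation file to keep in sync.
//---------------------------------------------------------------------------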
From 537feae92ac04f5fb660b541d7a1ac98acc5dcfb Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 19 Jan 2023 22:28:21 +0000 Subject: [PATCH 081/725] applied the same approach to statements --- include/genn/genn/transpiler/statement.h | 62 ++++++++++-------- src/genn/genn/genn.vcxproj | 1 - src/genn/genn/transpiler/statement.cc | 21 -------- 3 files changed, 26 insertions(+), 58 deletions(-) delete mode 100644 src/genn/genn/transpiler/statement.cc
diff --git a/include/genn/genn/transpiler/statement.h b/include/genn/genn/transpiler/statement.h index d0b3bdbe87..c0ba7c4c1a 100644 --- a/include/genn/genn/transpiler/statement.h +++ b/include/genn/genn/transpiler/statement.h @@ -27,21 +27,33 @@ class Base virtual void accept(Visitor &visitor) const = 0; }; +//--------------------------------------------------------------------------- +// GeNN::Transpiler::Statement::Acceptable +//--------------------------------------------------------------------------- +template<class T> +class Acceptable : public Base +{ +public: + void accept(Visitor &visitor) const final + { + visitor.visit(static_cast<const T&>(*this)); + } +}; + typedef std::unique_ptr<Base const> StatementPtr; typedef std::vector<StatementPtr> StatementList; + //--------------------------------------------------------------------------- // GeNN::Transpiler::Statement::Break //--------------------------------------------------------------------------- -class Break : public Base +class Break : public Acceptable<Break> { public: Break(Token token) : m_Token(token) {} - virtual void accept(Visitor &visitor) const override; - const Token &getToken() const { return m_Token; } private: @@ -51,15 +63,13 @@ class Break : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Statement::Compound //--------------------------------------------------------------------------- -class Compound : public Base +class Compound : public Acceptable<Compound> { public: Compound(StatementList statements) : m_Statements(std::move(statements)) {} - virtual void accept(Visitor &visitor) const override; - const StatementList &getStatements() const { return m_Statements; } private: @@ -69,15 +79,13 @@ class Compound : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Statement::Continue //--------------------------------------------------------------------------- -class Continue : public Base +class Continue : public Acceptable<Continue> { public: Continue(Token token) : m_Token(token) {} - virtual void accept(Visitor &visitor) const override; - const Token &getToken() const { return m_Token; } private: @@ -87,7 +95,7 @@ class Continue : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Statement::Do //--------------------------------------------------------------------------- -class Do : public Base +class Do : public Acceptable<Do> { using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: @@ -95,8 +103,6 @@ class Do : public Base : m_Condition(std::move(condition)), m_Body(std::move(body)) {} - virtual void accept(Visitor &visitor) const override; - const ExpressionPtr::element_type *getCondition() const { return m_Condition.get(); } const Base *getBody() const { return m_Body.get(); } @@ -108,7 +114,7 @@ class Do : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Statement::Expression //--------------------------------------------------------------------------- -class Expression : public Base +class Expression : public Acceptable<Expression> { using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: @@ -116,8 +122,6 @@ class Expression : public Base : m_Expression(std::move(expression)) {} - virtual void accept(Visitor &visitor) const override; - const ExpressionPtr::element_type *getExpression() const { return m_Expression.get(); } private: @@ -127,7 +131,7 @@ class Expression : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Statement::For //--------------------------------------------------------------------------- -class For : public Base +class For : public Acceptable<For> { using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: @@ -135,8 +139,6 @@ class For : public Base : m_Initialiser(std::move(initialiser)), m_Condition(std::move(condition)), m_Increment(std::move(increment)), m_Body(std::move(body)) {} - virtual void accept(Visitor &visitor) const override; - const Base *getInitialiser() const { return m_Initialiser.get(); } const ExpressionPtr::element_type *getCondition() const { return m_Condition.get(); } const ExpressionPtr::element_type *getIncrement() const { return m_Increment.get(); } @@ -152,7 +154,7 @@ class For : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Statement::If //--------------------------------------------------------------------------- -class If : public Base +class If : public Acceptable<If> { using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: @@ -160,8 +162,6 @@ class If : public Base : m_Condition(std::move(condition)), m_ThenBranch(std::move(thenBranch)), m_ElseBranch(std::move(elseBranch)) {} - virtual void accept(Visitor &visitor) const override; - const ExpressionPtr::element_type *getCondition() const { return m_Condition.get(); } const Base *getThenBranch() const { return m_ThenBranch.get(); } const Base *getElseBranch() const { return m_ElseBranch.get(); } @@ -175,7 +175,7 @@ class If : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Statement::Labelled //--------------------------------------------------------------------------- -class Labelled : public Base +class Labelled : public Acceptable<Labelled> { using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: @@ -183,8 +183,6 @@ class Labelled : public Base : m_Keyword(keyword), m_Value(std::move(value)), m_Body(std::move(body)) {} - virtual void accept(Visitor &visitor) const override; - const Token &getKeyword() const { return m_Keyword; } const ExpressionPtr::element_type *getValue() const { return m_Value.get(); } const Base *getBody() const { return m_Body.get(); } @@ -199,7 +197,7 @@ class Labelled : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Statement::Switch //--------------------------------------------------------------------------- -class Switch : public Base +class Switch : public Acceptable<Switch> { using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: @@ -207,8 +205,6 @@ class Switch : public Base : m_Switch(switchToken), m_Condition(std::move(condition)), m_Body(std::move(body)) {} - virtual void accept(Visitor &visitor) const override; - const Token &getSwitch() const { return m_Switch; } const ExpressionPtr::element_type *getCondition() const { return m_Condition.get(); } const Base *getBody() const { return m_Body.get(); } @@ -223,7 +219,7 @@ class Switch : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Statement::VarDeclaration //--------------------------------------------------------------------------- -class VarDeclaration : public Base +class VarDeclaration : public Acceptable<VarDeclaration> { public: typedef std::vector<std::tuple<Token, Expression::ExpressionPtr>> InitDeclaratorList; @@ -232,8 +228,6 @@ class VarDeclaration : public Base : m_Type(type), m_InitDeclaratorList(std::move(initDeclaratorList)) {} - virtual void accept(Visitor &visitor) const override; - const Type::Base *getType() const{ return m_Type; } const InitDeclaratorList &getInitDeclaratorList() const { return m_InitDeclaratorList; } @@ -246,7 +240,7 @@ class VarDeclaration : public Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Statement::While //--------------------------------------------------------------------------- -class While : public Base +class While : public Acceptable<While> { using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: @@ -254,8 +248,6 @@ class While : public Base : m_Condition(std::move(condition)), m_Body(std::move(body)) {} - virtual void accept(Visitor &visitor) const override; - const ExpressionPtr::element_type *getCondition() const { return m_Condition.get(); } const Base *getBody() const { return m_Body.get(); } @@ -268,7 +260,7 @@ class While : public Base // GeNN::Transpiler::Statement::Print //--------------------------------------------------------------------------- // **HACK** temporary until function calling is working -class Print : public Base +class Print : public Acceptable<Print> { using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; public: @@ -276,8 +268,6 @@ class Print : public Base : m_Expression(std::move(expression)) {} - virtual void accept(Visitor &visitor) const override; - const ExpressionPtr::element_type *getExpression() const { return m_Expression.get(); } private:
diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj index 3dbb61b56a..10b53f21ae 100644 --- a/src/genn/genn/genn.vcxproj +++ b/src/genn/genn/genn.vcxproj @@ -59,7 +59,6 @@ - <ClCompile Include="transpiler\statement.cc" />
diff --git a/src/genn/genn/transpiler/statement.cc b/src/genn/genn/transpiler/statement.cc deleted file mode 100644 index 19ca9459c2..0000000000 --- a/src/genn/genn/transpiler/statement.cc +++ /dev/null @@ -1,21 +0,0 @@ -#include "transpiler/statement.h" - -#define IMPLEMENT_ACCEPT(CLASS_NAME) \ - void GeNN::Transpiler::Statement::CLASS_NAME::accept(Visitor &visitor) const \ - { \ - visitor.visit(*this); \ - } - -// Implement accept methods -IMPLEMENT_ACCEPT(Break) -IMPLEMENT_ACCEPT(Compound) -IMPLEMENT_ACCEPT(Continue) -IMPLEMENT_ACCEPT(Do) -IMPLEMENT_ACCEPT(Expression) -IMPLEMENT_ACCEPT(For) -IMPLEMENT_ACCEPT(If) -IMPLEMENT_ACCEPT(Labelled) -IMPLEMENT_ACCEPT(Switch) -IMPLEMENT_ACCEPT(VarDeclaration) -IMPLEMENT_ACCEPT(While) -IMPLEMENT_ACCEPT(Print) \ No newline at end of file
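//---------------------------------------------------------------------------
// Aside (editor's sketch, not part of the patch series): the next commit only
// swaps the keyword used to declare the template parameter; in C++ the two
// spellings are exactly equivalent and accept any type argument:
//---------------------------------------------------------------------------
template<class T>    struct A { T value; };   // older, C++98-style spelling
template<typename T> struct B { T value; };   // identical meaning; arguably clearer since T need not be a class (e.g. B<int>)
//---------------------------------------------------------------------------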
From 307a3a4730b8f7fc811799e16cbb8c408f319f4f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 19 Jan 2023 22:30:57 +0000 Subject: [PATCH 082/725] we're a typename family... --- include/genn/genn/transpiler/expression.h | 2 +- include/genn/genn/transpiler/statement.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h index d37c49d6ee..2633385d21 100644 --- a/include/genn/genn/transpiler/expression.h +++ b/include/genn/genn/transpiler/expression.h @@ -30,7 +30,7 @@ class Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Expression::Acceptable //--------------------------------------------------------------------------- -template<class T> +template<typename T> class Acceptable : public Base { public: diff --git a/include/genn/genn/transpiler/statement.h b/include/genn/genn/transpiler/statement.h index c0ba7c4c1a..c70ceaf52a 100644 --- a/include/genn/genn/transpiler/statement.h +++ b/include/genn/genn/transpiler/statement.h @@ -30,7 +30,7 @@ class Base //--------------------------------------------------------------------------- // GeNN::Transpiler::Statement::Acceptable //--------------------------------------------------------------------------- -template<class T> +template<typename T> class Acceptable : public Base { public:
From d26d27f7472ea83be54149c539b3c061cac154f5 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 19 Jan 2023 22:39:24 +0000 Subject: [PATCH 083/725] fixed another typo --- include/genn/backends/single_threaded_cpu/optimiser.h | 6 ++++++ 1 file changed, 6 insertions(+)
diff --git a/include/genn/backends/single_threaded_cpu/optimiser.h b/include/genn/backends/single_threaded_cpu/optimiser.h index 112602553e..570642a3f5 100644 --- a/include/genn/backends/single_threaded_cpu/optimiser.h +++ b/include/genn/backends/single_threaded_cpu/optimiser.h @@ -9,6 +9,12 @@ // Single-threaded CPU backend includes #include "backend.h" +// Forward declarations +namespace GeNN +{ +class ModelSpecInternal; +} + namespace plog { class IAppender;
From 8446e6926798afb6f804e80b1105b43ef24d9d5c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 19 Jan 2023 22:39:32 +0000 Subject: [PATCH 084/725] explicit virtual --- include/genn/genn/transpiler/expression.h | 2 +- include/genn/genn/transpiler/statement.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h index 2633385d21..1d83bbd27d 100644 --- a/include/genn/genn/transpiler/expression.h +++ b/include/genn/genn/transpiler/expression.h @@ -34,7 +34,7 @@ template<typename T> class Acceptable : public Base { public: - void accept(Visitor &visitor) const final + virtual void accept(Visitor &visitor) const final { visitor.visit(static_cast<const T&>(*this)); } diff --git a/include/genn/genn/transpiler/statement.h b/include/genn/genn/transpiler/statement.h index c70ceaf52a..0af0888416 100644 --- a/include/genn/genn/transpiler/statement.h +++ b/include/genn/genn/transpiler/statement.h @@ -34,7 +34,7 @@ template<typename T> class Acceptable : public Base { public: - void accept(Visitor &visitor) const final + virtual void accept(Visitor &visitor) const final { visitor.visit(static_cast<const T&>(*this)); }
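//---------------------------------------------------------------------------
// Aside (editor's sketch, not part of the patch series): the next commit,
// reverted immediately afterwards, tries to replace the dynamic_cast chains
// in the type checker with visitor-style dispatch over the *type* hierarchy.
// A minimal illustration of that mechanism, with hypothetical
// TypeBase/PointerType/NumericType names rather than GeNN's type classes:
//---------------------------------------------------------------------------
#include <iostream>

struct PointerType;   // concrete type classes, forward-declared for the visitor
struct NumericType;

struct TypeVisitor
{
    virtual void visit(const PointerType&) = 0;
    virtual void visit(const NumericType&) = 0;
};

struct TypeBase
{
    virtual ~TypeBase() = default;
    virtual void accept(TypeVisitor&) const = 0;
};

struct PointerType : public TypeBase
{
    virtual void accept(TypeVisitor &v) const final { v.visit(*this); }
};

struct NumericType : public TypeBase
{
    virtual void accept(TypeVisitor &v) const final { v.visit(*this); }
};

// Instead of `if(auto p = dynamic_cast<const PointerType*>(type)) {...}`,
// behaviour is selected through the virtual accept()/visit() round-trip:
struct Dereference : public TypeVisitor
{
    virtual void visit(const PointerType&) final { std::cout << "ok: dereference" << std::endl; }
    virtual void visit(const NumericType&) final { std::cout << "error: not a pointer" << std::endl; }
};

int main()
{
    PointerType pointer;
    Dereference dereference;
    pointer.accept(dereference);   // prints "ok: dereference"
}
// The trade-off, visible in the commit message, is that every ad-hoc piece of
// per-type logic must become a named visitor struct - which is why the
// dynamic_cast version is restored two patches below.
//---------------------------------------------------------------------------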
From 369a1fcb98833308a378adf339e91b2fa71a89b3 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 19 Jan 2023 23:23:15 +0000 Subject: [PATCH 085/725] type visitor system - unsure how good an idea this is --- include/genn/genn/type.h | 78 +++++++++++---- src/genn/genn/transpiler/typeChecker.cc | 126 ++++++++++++++++-------- 2 files changed, 145 insertions(+), 59 deletions(-)
diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 439fc40d3f..953a469b3d 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -18,6 +18,13 @@ // GeNN includes #include "gennExport.h" +// Forward declarations +namespace GeNN::Type +{ +struct UnaryVisitor; +struct BinaryVisitor; +} + //---------------------------------------------------------------------------- // Macros //---------------------------------------------------------------------------- @@ -43,12 +50,7 @@ virtual std::string getResolvedName(const TypeContext&) const final{ return #UNDERLYING_TYPE; } \ virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new TYPE(qualifiers); } \ virtual std::string getLiteralSuffix(const TypeContext&) const final{ return LITERAL_SUFFIX; } \ - }; \ - template<> \ - struct TypeTraits<TYPE> \ - { \ - using NumericType = TYPE; \ - } + } #define DECLARE_FOREIGN_FUNCTION_TYPE(TYPE, RETURN_TYPE, ...) \ class TYPE : public ForeignFunction<RETURN_TYPE, __VA_ARGS__> \ @@ -62,16 +64,11 @@ #define IMPLEMENT_NUMERIC_TYPE(TYPE) IMPLEMENT_TYPE(TYPE) //---------------------------------------------------------------------------- -// GeNN::Type::TypeTraits +// GeNN::Type::TypeContext //---------------------------------------------------------------------------- +//! Map of 'typedef' names to concrete classes namespace GeNN::Type { -//! Empty type trait structure -template<typename T> -struct TypeTraits -{ -}; - typedef std::unordered_map<std::string, const class NumericBase*> TypeContext; //---------------------------------------------------------------------------- @@ -104,6 +101,8 @@ class Base //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ + virtual const Type::Base *accept(UnaryVisitor &visitor) const = 0; + //! Get the (unqualified) name of this type virtual std::string getName() const = 0; @@ -133,15 +132,32 @@ class Base const Qualifier m_Qualifiers; }; +//--------------------------------------------------------------------------- +// GeNN::Type::UnaryAcceptable +//--------------------------------------------------------------------------- +template<typename T, typename B> +class UnaryAcceptable : public B +{ +public: + UnaryAcceptable(Qualifier qualifiers = Qualifier{0}) : B(qualifiers) + { + } + + virtual const Type::Base *accept(UnaryVisitor &visitor) const final + { + return visitor.visit(static_cast<const T*>(this)); + } +}; + //---------------------------------------------------------------------------- // GeNN::Type::Pointer //---------------------------------------------------------------------------- //! Type representing a pointer -class Pointer : public Base +class Pointer : public UnaryAcceptable<Pointer, Base> { public: Pointer(const Base *valueType, Qualifier qualifiers = Qualifier{0}) - : Base(qualifiers), m_ValueType(valueType) + : UnaryAcceptable<Pointer, Base>(qualifiers), m_ValueType(valueType) { } @@ -177,10 +193,10 @@ class ValueBase : public Base //---------------------------------------------------------------------------- // GeNN::Type::NumericBase //---------------------------------------------------------------------------- -class NumericBase : public ValueBase +class NumericBase : public UnaryAcceptable<NumericBase, ValueBase> { public: - NumericBase(Qualifier qualifiers = Qualifier{0}) : ValueBase(qualifiers){} + NumericBase(Qualifier qualifiers = Qualifier{0}) : UnaryAcceptable<NumericBase, ValueBase>(qualifiers){} //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ @@ -270,10 +286,10 @@ class NumericTypedef : public NumericBase //---------------------------------------------------------------------------- // GeNN::Type::ForeignFunctionBase //---------------------------------------------------------------------------- -class ForeignFunctionBase : public Base +class ForeignFunctionBase : public UnaryAcceptable<ForeignFunctionBase, Base> { public: - ForeignFunctionBase(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} + ForeignFunctionBase(Qualifier qualifiers = Qualifier{0}) : UnaryAcceptable<ForeignFunctionBase, Base>(qualifiers){} //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ @@ -373,7 +389,31 @@ class ForeignFunction : public ForeignFunctionBase updateArgumentTypes(args); } } +}; + +//---------------------------------------------------------------------------- +// GeNN::Type::UnaryVisitor +//---------------------------------------------------------------------------- +//! Visitor class used for implementing logic on a single type +struct UnaryVisitor +{ + virtual const Type::Base *visit(const ForeignFunctionBase *function) = 0; + virtual const Type::Base *visit(const Pointer *pointer) = 0; + virtual const Type::Base *visit(const NumericBase *numeric) = 0; +}; + + +//----------------------------------------------------------------------------
// GeNN::Type::BinaryTypeVisitor +//---------------------------------------------------------------------------- +//! Visitor class used for implementing logic on pairs of types +struct BinaryTypeVisitor +{ + virtual void visit(const Pointer *pointerLeft, const Pointer &pointerRight) = 0; + virtual void visit(const Pointer *pointerLeft, const NumericBase &numericRight) = 0; + virtual void visit(const NumericBase *numericLeft, const NumericBase &numericRight) = 0; + virtual void visit(const NumericBase *numericLeft, const Pointer &pointerRight) = 0; }; //----------------------------------------------------------------------------
diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index dcc1c7ef0b..3fcd9a1c75 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -24,7 +24,7 @@ namespace Type = GeNN::Type; //--------------------------------------------------------------------------- namespace { - bool checkPointerTypeAssignement(const Type::Base *rightType, const Type::Base *leftType, const Type::TypeContext &typeContext) +bool checkPointerTypeAssignement(const Type::Base *rightType, const Type::Base *leftType, const Type::TypeContext &typeContext) { // If both are pointers, recurse through value type auto rightPointerType = dynamic_cast<const Type::Pointer *>(rightType); @@ -66,6 +66,45 @@ bool checkForConstRemoval(const Type::Base *rightType, const Type::Base *leftTyp } +//--------------------------------------------------------------------------- +// UnaryVisitor +//--------------------------------------------------------------------------- +struct UnaryVisitor : public Type::UnaryVisitor +{ + //------------------------------------------------------------------------ + // Type::UnaryVisitor virtuals + //------------------------------------------------------------------------ + virtual const Type::Base *visit(const Type::ForeignFunctionBase*) override + { + throw TypeCheckError(); + } + + virtual const Type::Base *visit(const Type::Pointer*) override + { + throw TypeCheckError(); + } + + virtual const Type::Base *visit(const Type::NumericBase *) override + { + throw TypeCheckError(); + } + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + const Type::Base *getType(const Type::Base *type, ErrorHandlerBase &errorHandler, + const Token &errorToken, std::string_view errorMessage) + { + try { + return type->accept(*this); + } + catch (const TypeCheckError&) { + errorHandler.error(errorToken, errorMessage); + throw TypeCheckError(); + } + } +}; + //--------------------------------------------------------------------------- // EnvironmentInternal @@ -432,56 +471,63 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Unary &unary) final { - const auto rightType = evaluateType(unary.getRight()); + const auto type = evaluateType(unary.getRight()); // If operator is pointer de-reference if (unary.getOperator().type == Token::Type::STAR) { - auto rightPointerType = dynamic_cast<const Type::Pointer *>(rightType); - if (!rightPointerType) { - m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getName() + "'"); - throw TypeCheckError(); - } + struct Visitor : public UnaryVisitor + { + virtual const Type::Base *visit(const Type::Pointer *pointerType) final + { + return pointerType->getValueType(); + } + }; - // Return value type - m_Type = rightPointerType->getValueType(); + m_Type = Visitor().getType(type, m_ErrorHandler, unary.getOperator(), + "Invalid operand type '" + type->getName() + "'"); } // Otherwise else { - auto rightNumericType = dynamic_cast<const Type::NumericBase *>(rightType); - if (rightNumericType) { - // If operator is arithmetic, return promoted type - if (unary.getOperator().type == Token::Type::PLUS || unary.getOperator().type == Token::Type::MINUS) { - // **THINK** const through these? - m_Type = Type::getPromotedType(rightNumericType, m_Context); - } - // Otherwise, if operator is bitwise - else if (unary.getOperator().type == Token::Type::TILDA) { - // If type is integer, return promoted type - if (rightNumericType->isIntegral(m_Context)) { + struct Visitor : public UnaryVisitor + { + Visitor(const Type::TypeContext &c, Token::Type t) : context(c), opType(t) {} + + virtual const Type::Base *visit(const Type::NumericBase *numericType) final + { + // If operator is arithmetic, return promoted type + if (opType == Token::Type::PLUS || opType == Token::Type::MINUS) { // **THINK** const through these? - m_Type = Type::getPromotedType(rightNumericType, m_Context); + return Type::getPromotedType(numericType, context); } - else { - m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getName() + "'"); - throw TypeCheckError(); + // Otherwise, if operator is bitwise + else if (opType == Token::Type::TILDA) { + // If type is integer, return promoted type + if (numericType->isIntegral(context)) { + // **THINK** const through these? + return Type::getPromotedType(numericType, context); + } + else { + throw TypeCheckError(); + } + } + // Otherwise, if operator is logical + else if (opType == Token::Type::NOT) { + return Type::Int32::getInstance();; } + // Otherwise, if operator is address of, return pointer type + else if (opType == Token::Type::AMPERSAND) { + return numericType->getPointerType(); } } - // Otherwise, if operator is logical - else if (unary.getOperator().type == Token::Type::NOT) { - m_Type = Type::Int32::getInstance();; - } - // Otherwise, if operator is address of, return pointer type - else if (unary.getOperator().type == Token::Type::AMPERSAND) { - m_Type = rightType->getPointerType(); - } - } - else { - m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getName() + "'"); - throw TypeCheckError(); - } + + const Token::Type opType; + const Type::TypeContext &context; + }; + + m_Type = Visitor(m_Context, unary.getOperator().type).getType( + type, m_ErrorHandler, unary.getOperator(), + "Invalid operand type '" + type->getName() + "'"); + } }
From a488188bc1f75f0e7258ef5b664be52d6ce106a3 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 20 Jan 2023 10:16:11 +0000 Subject: [PATCH 086/725] Revert "type visitor system - unsure how good an idea this is" This reverts commit 369a1fcb98833308a378adf339e91b2fa71a89b3.
--- include/genn/genn/type.h | 78 ++++----------- src/genn/genn/transpiler/typeChecker.cc | 126 ++++++++++------------ 2 files changed, 59 insertions(+), 145 deletions(-)
diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 953a469b3d..439fc40d3f 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -18,13 +18,6 @@ // GeNN includes #include "gennExport.h" -// Forward declarations -namespace GeNN::Type -{ -struct UnaryVisitor; -struct BinaryVisitor; -} - //---------------------------------------------------------------------------- // Macros //---------------------------------------------------------------------------- @@ -50,7 +43,12 @@ virtual std::string getResolvedName(const TypeContext&) const final{ return #UNDERLYING_TYPE; } \ virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new TYPE(qualifiers); } \ virtual std::string getLiteralSuffix(const TypeContext&) const final{ return LITERAL_SUFFIX; } \ - } + }; \ + template<> \ + struct TypeTraits<TYPE> \ + { \ + using NumericType = TYPE; \ + } #define DECLARE_FOREIGN_FUNCTION_TYPE(TYPE, RETURN_TYPE, ...) \ class TYPE : public ForeignFunction<RETURN_TYPE, __VA_ARGS__> \ @@ -64,11 +62,16 @@ #define IMPLEMENT_NUMERIC_TYPE(TYPE) IMPLEMENT_TYPE(TYPE) //---------------------------------------------------------------------------- -// GeNN::Type::TypeContext +// GeNN::Type::TypeTraits //---------------------------------------------------------------------------- -//! Map of 'typedef' names to concrete classes namespace GeNN::Type { +//! Empty type trait structure +template<typename T> +struct TypeTraits +{ +}; + typedef std::unordered_map<std::string, const class NumericBase*> TypeContext; //---------------------------------------------------------------------------- @@ -101,8 +104,6 @@ class Base //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ - virtual const Type::Base *accept(UnaryVisitor &visitor) const = 0; - //! Get the (unqualified) name of this type virtual std::string getName() const = 0; @@ -132,32 +133,15 @@ class Base const Qualifier m_Qualifiers; }; -//--------------------------------------------------------------------------- -// GeNN::Type::UnaryAcceptable -//--------------------------------------------------------------------------- -template<typename T, typename B> -class UnaryAcceptable : public B -{ -public: - UnaryAcceptable(Qualifier qualifiers = Qualifier{0}) : B(qualifiers) - { - } - - virtual const Type::Base *accept(UnaryVisitor &visitor) const final - { - return visitor.visit(static_cast<const T*>(this)); - } -}; - //---------------------------------------------------------------------------- // GeNN::Type::Pointer //---------------------------------------------------------------------------- //! Type representing a pointer -class Pointer : public UnaryAcceptable<Pointer, Base> +class Pointer : public Base { public: Pointer(const Base *valueType, Qualifier qualifiers = Qualifier{0}) - : UnaryAcceptable<Pointer, Base>(qualifiers), m_ValueType(valueType) + : Base(qualifiers), m_ValueType(valueType) { } @@ -193,10 +177,10 @@ class ValueBase : public Base //---------------------------------------------------------------------------- // GeNN::Type::NumericBase //---------------------------------------------------------------------------- -class NumericBase : public UnaryAcceptable<NumericBase, ValueBase> +class NumericBase : public ValueBase { public: - NumericBase(Qualifier qualifiers = Qualifier{0}) : UnaryAcceptable<NumericBase, ValueBase>(qualifiers){} + NumericBase(Qualifier qualifiers = Qualifier{0}) : ValueBase(qualifiers){} //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ @@ -286,10 +270,10 @@ class NumericTypedef : public NumericBase //---------------------------------------------------------------------------- // GeNN::Type::ForeignFunctionBase //---------------------------------------------------------------------------- -class ForeignFunctionBase : public UnaryAcceptable<ForeignFunctionBase, Base> +class ForeignFunctionBase : public Base { public: - ForeignFunctionBase(Qualifier qualifiers = Qualifier{0}) : UnaryAcceptable<ForeignFunctionBase, Base>(qualifiers){} + ForeignFunctionBase(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ @@ -389,31 +373,7 @@ class ForeignFunction : public ForeignFunctionBase updateArgumentTypes(args); } } -}; - -//---------------------------------------------------------------------------- -// GeNN::Type::UnaryVisitor -//---------------------------------------------------------------------------- -//! Visitor class used for implementing logic on a single type -struct UnaryVisitor -{ - virtual const Type::Base *visit(const ForeignFunctionBase *function) = 0; - virtual const Type::Base *visit(const Pointer *pointer) = 0; - virtual const Type::Base *visit(const NumericBase *numeric) = 0; -}; - - -//---------------------------------------------------------------------------- -// GeNN::Type::BinaryTypeVisitor -//---------------------------------------------------------------------------- -//! Visitor class used for implementing logic on pairs of types -struct BinaryTypeVisitor -{ - virtual void visit(const Pointer *pointerLeft, const Pointer &pointerRight) = 0; - virtual void visit(const Pointer *pointerLeft, const NumericBase &numericRight) = 0; - virtual void visit(const NumericBase *numericLeft, const NumericBase &numericRight) = 0; - virtual void visit(const NumericBase *numericLeft, const Pointer &pointerRight) = 0; }; //----------------------------------------------------------------------------
diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 3fcd9a1c75..dcc1c7ef0b 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -24,7 +24,7 @@ namespace Type = GeNN::Type; //--------------------------------------------------------------------------- namespace { -bool checkPointerTypeAssignement(const Type::Base *rightType, const Type::Base *leftType, const Type::TypeContext &typeContext) + bool checkPointerTypeAssignement(const Type::Base *rightType, const Type::Base *leftType, const Type::TypeContext &typeContext) { // If both are pointers, recurse through value type auto rightPointerType = dynamic_cast<const Type::Pointer *>(rightType); @@ -66,45 +66,6 @@ bool checkForConstRemoval(const Type::Base *rightType, const Type::Base *leftTyp } -//--------------------------------------------------------------------------- -// UnaryVisitor -//--------------------------------------------------------------------------- -struct UnaryVisitor : public Type::UnaryVisitor -{ - //------------------------------------------------------------------------ - // Type::UnaryVisitor virtuals - //------------------------------------------------------------------------ - virtual const Type::Base *visit(const Type::ForeignFunctionBase*) override - { - throw TypeCheckError(); - } - - virtual const Type::Base *visit(const Type::Pointer*) override - { - throw TypeCheckError(); - } - - virtual const Type::Base *visit(const Type::NumericBase *) override - { - throw TypeCheckError(); - } - - //------------------------------------------------------------------------ - // Public API - //------------------------------------------------------------------------ - const Type::Base *getType(const Type::Base *type, ErrorHandlerBase &errorHandler, - const Token &errorToken, std::string_view errorMessage) - { - try { - return type->accept(*this); - } - catch (const TypeCheckError&) { - errorHandler.error(errorToken, errorMessage); - throw TypeCheckError(); - } - } -}; - //--------------------------------------------------------------------------- // EnvironmentInternal @@ -471,63 +432,56 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Unary &unary) final { - const auto type = evaluateType(unary.getRight()); + const auto rightType = evaluateType(unary.getRight()); // If operator is pointer de-reference if (unary.getOperator().type == Token::Type::STAR) { + auto rightPointerType = dynamic_cast<const Type::Pointer *>(rightType); + if (!rightPointerType) { + m_ErrorHandler.error(unary.getOperator(), + "Invalid operand type '" + rightType->getName() + "'"); + throw TypeCheckError(); + } - struct Visitor : public UnaryVisitor - { - virtual const Type::Base *visit(const Type::Pointer *pointerType) final - { - return pointerType->getValueType(); - } - }; + // Return value type + m_Type = rightPointerType->getValueType(); - m_Type = Visitor().getType(type, m_ErrorHandler, unary.getOperator(), - "Invalid operand type '" + type->getName() + "'"); } // Otherwise else { - struct Visitor : public UnaryVisitor - { - Visitor(const Type::TypeContext &c, Token::Type t) : context(c), opType(t) {} - - virtual const Type::Base *visit(const Type::NumericBase *numericType) final - { - // If operator is arithmetic, return promoted type - if (opType == Token::Type::PLUS || opType == Token::Type::MINUS) { + auto rightNumericType = dynamic_cast<const Type::NumericBase *>(rightType); + if (rightNumericType) { + // If operator is arithmetic, return promoted type + if (unary.getOperator().type == Token::Type::PLUS || unary.getOperator().type == Token::Type::MINUS) { + // **THINK** const through these? + m_Type = Type::getPromotedType(rightNumericType, m_Context); + } + // Otherwise, if operator is bitwise + else if (unary.getOperator().type == Token::Type::TILDA) { + // If type is integer, return promoted type + if (rightNumericType->isIntegral(m_Context)) { // **THINK** const through these? - return Type::getPromotedType(numericType, context); + m_Type = Type::getPromotedType(rightNumericType, m_Context); } - // Otherwise, if operator is bitwise - else if (opType == Token::Type::TILDA) { - // If type is integer, return promoted type - if (numericType->isIntegral(context)) { - // **THINK** const through these? - return Type::getPromotedType(numericType, context); - } - else { - throw TypeCheckError(); - } - } - // Otherwise, if operator is logical - else if (opType == Token::Type::NOT) { - return Type::Int32::getInstance();; + else { + m_ErrorHandler.error(unary.getOperator(), + "Invalid operand type '" + rightType->getName() + "'"); + throw TypeCheckError(); } - // Otherwise, if operator is address of, return pointer type - else if (opType == Token::Type::AMPERSAND) { - return numericType->getPointerType(); - } } - - const Token::Type opType; - const Type::TypeContext &context; - }; - - m_Type = Visitor(m_Context, unary.getOperator().type).getType( - type, m_ErrorHandler, unary.getOperator(), - "Invalid operand type '" + type->getName() + "'"); - + // Otherwise, if operator is logical + else if (unary.getOperator().type == Token::Type::NOT) { + m_Type = Type::Int32::getInstance();; + } + // Otherwise, if operator is address of, return pointer type + else if (unary.getOperator().type == Token::Type::AMPERSAND) { + m_Type = rightType->getPointerType(); + } + } + else { + m_ErrorHandler.error(unary.getOperator(), + "Invalid operand type '" + rightType->getName() + "'"); + throw TypeCheckError(); + } } } }
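//---------------------------------------------------------------------------
// Aside (editor's sketch, not part of the patch series): the next commit
// reorders the transpiler headers because Acceptable<T>::accept() calls
// visitor.visit(...) inline in the header, so Visitor must be a *complete*
// type at that point, while the node classes themselves only need forward
// declarations for Visitor's function signatures. A hypothetical minimal
// layout showing the required ordering:
//---------------------------------------------------------------------------
class Node;                                // incomplete type is fine here

struct Visitor
{
    virtual void visit(const Node&) = 0;   // only declares, never calls
};

class Node
{
public:
    // the inline call below needs the full Visitor definition above, which is
    // why Visitor is now defined before the node classes in both headers
    virtual void accept(Visitor &visitor) const { visitor.visit(*this); }
};
//---------------------------------------------------------------------------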
From 90f7d6594daf41021480460c4786b1f9e34e9f64 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 20 Jan 2023 10:24:02 +0000 Subject: [PATCH 087/725] * Disabled stuff so everything compiles on GCC 7.5.0 * Fixed forward declaration issues with Acceptable classes --- include/genn/genn/transpiler/expression.h | 59 +++++++++++-------- include/genn/genn/transpiler/statement.h | 55 ++++++++++------- .../genn/genn/transpiler/transpilerUtils.h | 6 +- src/genn/genn/transpiler/scanner.cc | 1 - 4 files changed, 71 insertions(+), 50 deletions(-)
diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h index 1d83bbd27d..27f27513a7 100644 --- a/include/genn/genn/transpiler/expression.h +++ b/include/genn/genn/transpiler/expression.h @@ -13,14 +13,47 @@ // Forward declarations namespace GeNN::Transpiler::Expression { -class Visitor; +class ArraySubscript; +class Assignment; +class Binary; +class Call; +class Cast; +class Conditional; +class Grouping; +class Literal; +class Logical; +class PostfixIncDec; +class PrefixIncDec; +class Variable; +class Unary; } //--------------------------------------------------------------------------- -// GeNN::Transpiler::Expression::Base +// GeNN::Transpiler::Expression::Visitor //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Expression { +class Visitor +{ +public: + virtual void visit(const ArraySubscript &arraySubscript) = 0; + virtual void visit(const Assignment &assignement) = 0; + virtual void visit(const Binary &binary) = 0; + virtual void visit(const Call &call) = 0; + virtual void visit(const Cast &cast) = 0; + virtual void visit(const Conditional &conditional) = 0; + virtual void visit(const Grouping &grouping) = 0; + virtual void visit(const Literal &literal) = 0; + virtual void visit(const Logical &logical) = 0; + virtual void visit(const PostfixIncDec &postfixIncDec) = 0; + virtual void visit(const PrefixIncDec &postfixIncDec) = 0; + virtual void visit(const Variable &variable) = 0; + virtual void visit(const Unary &unary) = 0; +}; + +//--------------------------------------------------------------------------- +// GeNN::Transpiler::Expression::Base +//--------------------------------------------------------------------------- class Base { public: @@ -284,26 +317,4 @@ class Unary : public Acceptable<Unary> const Token m_Operator; const ExpressionPtr m_Right; }; - - -//--------------------------------------------------------------------------- -// GeNN::Transpiler::Expression::Visitor -//--------------------------------------------------------------------------- -class Visitor -{ -public: - virtual void visit(const ArraySubscript &arraySubscript) = 0; - virtual void visit(const Assignment &assignement) = 0; - virtual void visit(const Binary &binary) = 0; - virtual void visit(const Call &call) = 0; - virtual void visit(const Cast &cast) = 0; - virtual void visit(const Conditional &conditional) = 0; - virtual void visit(const Grouping &grouping) = 0; - virtual void visit(const Literal &literal) = 0; - virtual void visit(const Logical &logical) = 0; - virtual void visit(const PostfixIncDec &postfixIncDec) = 0; - virtual void visit(const PrefixIncDec &postfixIncDec) = 0; - virtual void visit(const Variable &variable) = 0; - virtual void visit(const Unary &unary) = 0; -}; } // namespace GeNN::Transpiler::Expression
diff --git a/include/genn/genn/transpiler/statement.h b/include/genn/genn/transpiler/statement.h index 0af0888416..1dc454edd4 100644 --- a/include/genn/genn/transpiler/statement.h +++ b/include/genn/genn/transpiler/statement.h @@ -13,14 +13,45 @@ // Forward declarations namespace GeNN::Transpiler::Statement { -class Visitor; +class Break; +class Compound; +class Continue; +class Do; +class Expression; +class For; +class If; +class Labelled; +class Switch; +class VarDeclaration; +class While; +class Print; } //--------------------------------------------------------------------------- -// GeNN::Transpiler::Statement::Base +// GeNN::Transpiler::Statement::Visitor //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Statement { +class Visitor +{ +public: + virtual void visit(const Break &breakStatement) = 0; + virtual void visit(const Compound &compound) = 0; + virtual void visit(const Continue &continueStatement) = 0; + virtual void visit(const Do &doStatement) = 0; + virtual void visit(const Expression &expression) = 0; + virtual void visit(const For &forStatement) = 0; + virtual void visit(const If &ifStatement) = 0; + virtual void visit(const Labelled &labelled) = 0; + virtual void visit(const Switch &switchStatement) = 0; + virtual void visit(const VarDeclaration &varDeclaration) = 0; + virtual void visit(const While &whileStatement) = 0; + virtual void visit(const Print &print) = 0; +}; + +//--------------------------------------------------------------------------- +// GeNN::Transpiler::Statement::Base +//--------------------------------------------------------------------------- class Base { public: @@ -273,24 +304,4 @@ class Print : public Acceptable<Print> private: const ExpressionPtr m_Expression; }; - -//--------------------------------------------------------------------------- -// GeNN::Transpiler::Statement::Visitor -//--------------------------------------------------------------------------- -class Visitor -{ -public: - virtual void visit(const Break &breakStatement) = 0; - virtual void visit(const Compound &compound) = 0; - virtual void visit(const Continue &continueStatement) = 0; - virtual void visit(const Do &doStatement) = 0; - virtual void visit(const Expression &expression) = 0; - virtual void visit(const For &forStatement) = 0; - virtual void visit(const If &ifStatement) = 0; - virtual void visit(const Labelled &labelled) = 0; - virtual void visit(const Switch &switchStatement) = 0; - virtual void visit(const VarDeclaration &varDeclaration) = 0; - virtual void visit(const While &whileStatement) = 0; - virtual void visit(const Print &print) = 0; -}; } // namespace GeNN::Transpiler::Statement
diff --git a/include/genn/genn/transpiler/transpilerUtils.h b/include/genn/genn/transpiler/transpilerUtils.h index 5c6340dcd6..92f79059b3 100644 --- a/include/genn/genn/transpiler/transpilerUtils.h +++ b/include/genn/genn/transpiler/transpilerUtils.h @@ -1,7 +1,7 @@ #pragma once // Standard C++ includes -#include <charconv> +//#include <charconv> #include #include namespace GeNN::Transpiler::Utils { template<typename... Ts> struct Overload : Ts... { using Ts::operator()...; }; template<typename... Ts> Overload(Ts...) -> Overload<Ts...>; // line not needed in C++20 -template<typename T> +/*template<typename T> T toCharsThrow(std::string_view input, int base = 10) { T out; @@ -30,5 +30,5 @@ T toCharsThrow(std::string_view input, int base = 10) throw std::out_of_range("Unable to convert chars '" + std::string{input} + "'"); } return out; -} +}*/ } // namespace GeNN::Transpiler::Utils
diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc index ca25a294fd..f534456ade 100644 --- a/src/genn/genn/transpiler/scanner.cc +++ b/src/genn/genn/transpiler/scanner.cc @@ -1,7 +1,6 @@ #include "transpiler/scanner.h" // Standard C++ includes -#include <charconv> #include #include #include
From f84b420b4d2af380484fb20918bf30dd9a7f48a9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 20 Jan 2023 10:51:21 +0000 Subject: [PATCH 088/725] added cunning hack to detect writing of type pointers to stream via....linker errors --- include/genn/genn/type.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 439fc40d3f..b995a311fe 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -408,4 +408,7 @@ const NumericBase *getPromotedType(const NumericBase *type, const TypeContext &c //!
Apply C rules to get common type between numeric types a and b
 const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b, const TypeContext &context);
+
+// **YUCK** unimplemented stream operator so we get linker errors if you try to write types directly to an IO stream
+std::ostream& operator<<(std::ostream &stream, const Base* value);
 } // namespace GeNN::Type
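The hack works because overload resolution prefers the exact operator<<(std::ostream&, const Base*) over the built-in const void* conversion, so any accidental attempt to stream a type pointer now references a symbol that is never defined. A minimal runnable sketch of the same trick, using a stand-in Base class rather than GeNN's real one:

#include <iostream>

namespace Type
{
class Base {};

// Declared but deliberately never defined, as in the patch
std::ostream& operator<<(std::ostream &stream, const Base *value);
}

int main()
{
    const Type::Base base;
    const Type::Base *type = &base;

    // Fine: prints the pointer via the built-in const void* overload
    std::cout << static_cast<const void*>(type) << std::endl;

    // Uncommenting the next line still compiles - overload resolution now
    // prefers the exact const Base* overload - but linking fails with an
    // undefined reference, which is exactly the tripwire the patch wants:
    // std::cout << type << std::endl;

    return 0;
}

From 387950982980e7ddb5bafc70c621680fd4838118 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Fri, 20 Jan 2023 10:51:34 +0000
Subject: [PATCH 089/725] started fixing type writing issues

---
 .../backends/single_threaded_cpu/backend.cc   | 10 +--
 src/genn/genn/code_generator/backendSIMT.cc   |  2 +-
 .../genn/code_generator/generateRunner.cc     | 10 +--
 .../genn/code_generator/initGroupMerged.cc    | 61 ++++++++++---------
 .../code_generator/neuronUpdateGroupMerged.cc | 22 +++----
 5 files changed, 54 insertions(+), 51 deletions(-)

diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc
index 995811a87a..3f4e0982f8 100644
--- a/src/genn/backends/single_threaded_cpu/backend.cc
+++ b/src/genn/backends/single_threaded_cpu/backend.cc
@@ -141,7 +141,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
     // Generate preamble
     preambleHandler(os);

-    os << "void updateNeurons(" << model.getTimePrecision() << " t";
+    os << "void updateNeurons(" << model.getTimePrecision()->getName() << " t";
     if(model.isRecordingInUse()) {
         os << ", unsigned int recordingTimestep";
     }
@@ -313,7 +313,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerge
     // Generate preamble
     preambleHandler(os);

-    os << "void updateSynapses(" << model.getTimePrecision() << " t)";
+    os << "void updateSynapses(" << model.getTimePrecision()->getName() << " t)";
     {
         CodeStream::Scope b(os);
         Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName()));
@@ -1216,9 +1216,9 @@ void Backend::genDefinitionsPreamble(CodeStream &os, const ModelSpecMerged &mode

     // If a global RNG is required, define standard host distributions as recreating them each call is slow
     if(isGlobalHostRNGRequired(modelMerged)) {
-        os << "EXPORT_VAR " << "std::uniform_real_distribution<" << model.getPrecision() << "> standardUniformDistribution;" << std::endl;
-        os << "EXPORT_VAR " << "std::normal_distribution<" << model.getPrecision() << "> standardNormalDistribution;" << std::endl;
-        os << "EXPORT_VAR " << "std::exponential_distribution<" << model.getPrecision() << "> standardExponentialDistribution;" << std::endl;
+        os << "EXPORT_VAR " << "std::uniform_real_distribution<" << model.getPrecision()->getName() << "> standardUniformDistribution;" << std::endl;
+        os << "EXPORT_VAR " << "std::normal_distribution<" << model.getPrecision()->getName() << "> standardNormalDistribution;" << std::endl;
+        os << "EXPORT_VAR " << "std::exponential_distribution<" << model.getPrecision()->getName() << "> standardExponentialDistribution;" << std::endl;
         os << std::endl;
     }
 }
diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc
index 863334e187..e493e6e948 100644
--- a/src/genn/genn/code_generator/backendSIMT.cc
+++ b/src/genn/genn/code_generator/backendSIMT.cc
@@ -677,7 +677,7 @@ void BackendSIMT::genPresynapticUpdateKernel(CodeStream &os, const Substitutions

     // If any shared memory is required, declare array
     if(maxSharedMemPerThread > 0) {
-        os << getSharedPrefix() << modelMerged.getModel().getPrecision() << " shLg[" << maxSharedMemPerThread * 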
getKernelBlockSize(KernelPresynapticUpdate) << "];" << std::endl; + os << getSharedPrefix() << modelMerged.getModel().getPrecision()->getName() << " shLg[" << maxSharedMemPerThread * getKernelBlockSize(KernelPresynapticUpdate) << "];" << std::endl; } // If any of these synapse groups also have sparse connectivity, allocate shared memory for row length diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index d05479e616..878503afa0 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -560,7 +560,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, definitions << "#define DT " << Utils::writePreciseString(model.getDT()) << model.getTimePrecision()->getLiteralSuffix(modelMerged.getTypeContext()) << std::endl; // Typedefine scalar type - definitions << "typedef " << model.getPrecision() << " scalar;" << std::endl; + definitions << "typedef " << model.getPrecision()->getName() << " scalar;" << std::endl; // Write ranges of scalar and time types genTypeRange(definitions, model.getPrecision(), modelMerged.getTypeContext(), "SCALAR"); @@ -618,9 +618,9 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Define and declare time variables definitionsVar << "EXPORT_VAR unsigned long long iT;" << std::endl; - definitionsVar << "EXPORT_VAR " << model.getTimePrecision() << " t;" << std::endl; + definitionsVar << "EXPORT_VAR " << model.getTimePrecision()->getName() << " t;" << std::endl; runnerVarDecl << "unsigned long long iT;" << std::endl; - runnerVarDecl << model.getTimePrecision() << " t;" << std::endl; + runnerVarDecl << model.getTimePrecision()->getName() << " t;" << std::endl; if(model.isRecordingInUse()) { runnerVarDecl << "unsigned long long numRecordingTimesteps = 0;" << std::endl; @@ -1786,12 +1786,12 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, definitions << "EXPORT_FUNC void stepTime();" << std::endl; definitions << std::endl; definitions << "// Functions generated by backend" << std::endl; - definitions << "EXPORT_FUNC void updateNeurons(" << model.getTimePrecision() << " t"; + definitions << "EXPORT_FUNC void updateNeurons(" << model.getTimePrecision()->getName() << " t"; if(model.isRecordingInUse()) { definitions << ", unsigned int recordingTimestep"; } definitions << "); " << std::endl; - definitions << "EXPORT_FUNC void updateSynapses(" << model.getTimePrecision() << " t);" << std::endl; + definitions << "EXPORT_FUNC void updateSynapses(" << model.getTimePrecision()->getName() << " t);" << std::endl; definitions << "EXPORT_FUNC void initialize();" << std::endl; definitions << "EXPORT_FUNC void initializeSparse();" << std::endl; diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 99c7d5b605..c245f9f141 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -52,10 +52,10 @@ void genScalarFill(CodeStream &os, const std::string &fieldName, const std::stri } //------------------------------------------------------------------------ template -void genInitNeuronVarCode(CodeStream &os, const BackendBase &backend, const Substitutions &popSubs, +void genInitNeuronVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, const BackendBase &backend, const Substitutions &popSubs, const Models::Base::VarVec &vars, const std::unordered_map 
&varInitialisers, const std::string &fieldSuffix, const std::string &countMember, - size_t numDelaySlots, const size_t groupIndex, const Type::NumericBase *scalarType, unsigned int batchSize, + size_t numDelaySlots, const size_t groupIndex, unsigned int batchSize, Q isVarQueueRequired, P isParamHeterogeneousFn, D isDerivedParamHeterogeneousFn) { const std::string count = "group->" + countMember; @@ -82,11 +82,11 @@ void genInitNeuronVarCode(CodeStream &os, const BackendBase &backend, const Subs if (getVarAccessDuplication(var.access) == VarAccessDuplication::SHARED_NEURON) { backend.genPopVariableInit( os, varSubs, - [&var, &varInit, &fieldSuffix, scalarType, batchSize, groupIndex, numDelaySlots, isVarQueueRequired] + [&var, &varInit, &fieldSuffix, &modelMerged, batchSize, groupIndex, numDelaySlots, isVarQueueRequired] (CodeStream &os, Substitutions &varInitSubs) { // Generate initial value into temporary variable - os << var.type << " initVal;" << std::endl; + os << var.type->getResolvedName(modelMerged.getTypeContext()) << " initVal;" << std::endl; varInitSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); varInitSubs.applyCheckUnreplaced(code, "initVar : " + var.name + "merged" + std::to_string(groupIndex)); @@ -124,12 +124,12 @@ void genInitNeuronVarCode(CodeStream &os, const BackendBase &backend, const Subs } //------------------------------------------------------------------------ template -void genInitNeuronVarCode(CodeStream &os, const BackendBase &backend, const Substitutions &popSubs, +void genInitNeuronVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, const BackendBase &backend, const Substitutions &popSubs, const Models::Base::VarVec &vars, const std::unordered_map &varInitialisers, const std::string &fieldSuffix, const std::string &countMember, const size_t groupIndex, - const Type::NumericBase *scalarType, unsigned int batchSize, P isParamHeterogeneousFn, D isDerivedParamHeterogeneousFn) + unsigned int batchSize, P isParamHeterogeneousFn, D isDerivedParamHeterogeneousFn) { - genInitNeuronVarCode(os, backend, popSubs, vars, varInitialisers, fieldSuffix, countMember, 0, groupIndex, scalarType, batchSize, + genInitNeuronVarCode(os, modelMerged, backend, popSubs, vars, varInitialisers, fieldSuffix, countMember, 0, groupIndex, batchSize, [](const std::string&){ return false; }, isParamHeterogeneousFn, isDerivedParamHeterogeneousFn); @@ -300,8 +300,8 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, CodeStream } // Initialise neuron variables - genInitNeuronVarCode(os, backend, popSubs, getArchetype().getNeuronModel()->getVars(), getArchetype().getVarInitialisers(), - "", "numNeurons", getArchetype().getNumDelaySlots(), getIndex(), model.getPrecision(), model.getBatchSize(), + genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getNeuronModel()->getVars(), getArchetype().getVarInitialisers(), + "", "numNeurons", getArchetype().getNumDelaySlots(), getIndex(), model.getBatchSize(), [this](const std::string &v){ return getArchetype().isVarQueueRequired(v); }, [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); @@ -340,19 +340,20 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, CodeStream }); } - genInitNeuronVarCode(os, backend, popSubs, sg->getPSModel()->getVars(), sg->getPSVarInitialisers(), - "InSyn" + 
std::to_string(i), "numNeurons", i, model.getPrecision(), model.getBatchSize(), - [i, this](const std::string &v, const std::string &p) { return isPSMVarInitParamHeterogeneous(i, v, p); }, - [i, this](const std::string &v, const std::string &p) { return isPSMVarInitDerivedParamHeterogeneous(i, v, p); }); + // **TODO** adaptor + genInitNeuronVarCode(os, modelMerged, backend, popSubs, sg->getPSModel()->getVars(), sg->getPSVarInitialisers(), + "InSyn" + std::to_string(i), "numNeurons", i, model.getBatchSize(), + [i, this](const std::string &v, const std::string &p) { return isPSMVarInitParamHeterogeneous(i, v, p); }, + [i, this](const std::string &v, const std::string &p) { return isPSMVarInitDerivedParamHeterogeneous(i, v, p); }); } // Loop through incoming synaptic populations with postsynaptic variables // **NOTE** number of delay slots is based on the target neuron (for simplicity) but whether delay is required is based on the synapse group for(size_t i = 0; i < getSortedArchetypeInSynWithPostVars().size(); i++) { const auto *sg = getSortedArchetypeInSynWithPostVars().at(i); - genInitNeuronVarCode(os, backend, popSubs, sg->getWUModel()->getPostVars(), sg->getWUPostVarInitialisers(), - "WUPost" + std::to_string(i), "numNeurons", sg->getTrgNeuronGroup()->getNumDelaySlots(), - i, model.getPrecision(), model.getBatchSize(), + // **TODO** adaptor + genInitNeuronVarCode(os, modelMerged, backend, popSubs, sg->getWUModel()->getPostVars(), sg->getWUPostVarInitialisers(), + "WUPost" + std::to_string(i), "numNeurons", sg->getTrgNeuronGroup()->getNumDelaySlots(), i, model.getBatchSize(), [&sg](const std::string&){ return (sg->getBackPropDelaySteps() != NO_DELAY); }, [i, this](const std::string &v, const std::string &p) { return isInSynWUMVarInitParamHeterogeneous(i, v, p); }, [i, this](const std::string &v, const std::string &p) { return isInSynWUMVarInitDerivedParamHeterogeneous(i, v, p); }); @@ -362,9 +363,9 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, CodeStream // **NOTE** number of delay slots is based on the source neuron (for simplicity) but whether delay is required is based on the synapse group for(size_t i = 0; i < getSortedArchetypeOutSynWithPreVars().size(); i++) { const auto *sg = getSortedArchetypeOutSynWithPreVars().at(i); - genInitNeuronVarCode(os, backend, popSubs, sg->getWUModel()->getPreVars(), sg->getWUPreVarInitialisers(), - "WUPre" + std::to_string(i), "numNeurons", sg->getSrcNeuronGroup()->getNumDelaySlots(), - i, model.getPrecision(), model.getBatchSize(), + // **TODO** adaptor + genInitNeuronVarCode(os, modelMerged, backend, popSubs, sg->getWUModel()->getPreVars(), sg->getWUPreVarInitialisers(), + "WUPre" + std::to_string(i), "numNeurons", sg->getSrcNeuronGroup()->getNumDelaySlots(), i, model.getBatchSize(), [&sg](const std::string&){ return (sg->getDelaySteps() != NO_DELAY); }, [i, this](const std::string &v, const std::string &p) { return isOutSynWUMVarInitParamHeterogeneous(i, v, p); }, [i, this](const std::string &v, const std::string &p) { return isOutSynWUMVarInitDerivedParamHeterogeneous(i, v, p); }); @@ -385,9 +386,9 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, CodeStream os << "// current source variables" << std::endl; for(size_t i = 0; i < getSortedArchetypeCurrentSources().size(); i++) { const auto *cs = getSortedArchetypeCurrentSources().at(i); - - genInitNeuronVarCode(os, backend, popSubs, cs->getCurrentSourceModel()->getVars(), cs->getVarInitialisers(), - "CS" + std::to_string(i), "numNeurons", i, 
model.getPrecision(), model.getBatchSize(), + // **TODO** adaptor + genInitNeuronVarCode(os, modelMerged, backend, popSubs, cs->getCurrentSourceModel()->getVars(), cs->getVarInitialisers(), + "CS" + std::to_string(i), "numNeurons", i, model.getBatchSize(), [i, this](const std::string &v, const std::string &p) { return isCurrentSourceVarInitParamHeterogeneous(i, v, p); }, [i, this](const std::string &v, const std::string &p) { return isCurrentSourceVarInitDerivedParamHeterogeneous(i, v, p); }); } @@ -653,7 +654,7 @@ void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase&, "", "group->", var.name); // Generate initial value into temporary variable - os << var.type << " initVal;" << std::endl; + os << var.type->getResolvedName(getTypeContext()) << " initVal;" << std::endl; popSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); //popSubs.applyCheckUnreplaced(code, "initVar : merged" + vars[k].name + std::to_string(sg.getIndex())); @@ -885,8 +886,8 @@ boost::uuids::detail::sha1::digest_type CustomUpdateInitGroupMerged::getHashDige void CustomUpdateInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { // Initialise custom update variables - genInitNeuronVarCode(os, backend, popSubs, getArchetype().getCustomUpdateModel()->getVars(), getArchetype().getVarInitialisers(), - "", "size", getIndex(), modelMerged.getModel().getPrecision(), getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1, + genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getCustomUpdateModel()->getVars(), getArchetype().getVarInitialisers(), + "", "size", getIndex(), getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1, [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); } @@ -1115,8 +1116,9 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePreInitGroupMerg void CustomConnectivityUpdatePreInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { // Initialise presynaptic custom connectivity update variables - genInitNeuronVarCode(os, backend, popSubs, getArchetype().getCustomConnectivityUpdateModel()->getPreVars(), getArchetype().getPreVarInitialisers(), - "", "size", getIndex(), modelMerged.getModel().getPrecision(), 1, + // **TODO** adaptor + genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getCustomConnectivityUpdateModel()->getPreVars(), getArchetype().getPreVarInitialisers(), + "", "size", getIndex(), 1, [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); } @@ -1156,8 +1158,9 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePostInitGroupMer void CustomConnectivityUpdatePostInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { // Initialise presynaptic custom connectivity update variables - genInitNeuronVarCode(os, backend, popSubs, getArchetype().getCustomConnectivityUpdateModel()->getPostVars(), getArchetype().getPostVarInitialisers(), - "", "size", getIndex(), 
modelMerged.getModel().getPrecision(), 1, + // **TODO** adapter + genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getCustomConnectivityUpdateModel()->getPostVars(), getArchetype().getPostVarInitialisers(), + "", "size", getIndex(), 1, [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); } diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 4d37101d57..8f755e9986 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -187,26 +187,26 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C if(v.access & VarAccessMode::READ_ONLY) { os << "const "; } - os << v.type << " l" << v.name << " = group->" << v.name << "["; + os << v.type->getResolvedName(getTypeContext()) << " l" << v.name << " = group->" << v.name << "["; const bool delayed = (getArchetype().isVarQueueRequired(v.name) && getArchetype().isDelayRequired()); os << getReadVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), popSubs["id"]) << "];" << std::endl; } // Also read spike and spike-like-event times into local variables if required if(getArchetype().isSpikeTimeRequired()) { - os << "const " << model.getTimePrecision() << " lsT = group->sT["; + os << "const " << model.getTimePrecision()->getName() << " lsT = group->sT["; os << getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; } if(getArchetype().isPrevSpikeTimeRequired()) { - os << "const " << model.getTimePrecision() << " lprevST = group->prevST["; + os << "const " << model.getTimePrecision()->getName() << " lprevST = group->prevST["; os << getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; } if(getArchetype().isSpikeEventTimeRequired()) { - os << "const " << model.getTimePrecision() << " lseT = group->seT["; + os << "const " << model.getTimePrecision()->getName() << " lseT = group->seT["; os << getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; } if(getArchetype().isPrevSpikeEventTimeRequired()) { - os << "const " << model.getTimePrecision() << " lprevSET = group->prevSET["; + os << "const " << model.getTimePrecision()->getName() << " lprevSET = group->prevSET["; os << getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; } os << std::endl; @@ -221,7 +221,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C || sg->getPSModel()->getDecayCode().find("$(Isyn)") != std::string::npos); })) { - os << model.getPrecision() << " Isyn = 0;" << std::endl; + os << model.getPrecision()->getName() << " Isyn = 0;" << std::endl; } Substitutions neuronSubs(&popSubs); @@ -260,13 +260,13 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C const auto *psm = sg->getPSModel(); os << "// pull inSyn values in a coalesced access" << std::endl; - os << model.getPrecision() << " linSyn = group->inSynInSyn" << i << "["; + os << model.getPrecision()->getName() << " linSyn = group->inSynInSyn" << i << "["; os << getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, 
popSubs["id"]) << "];" << std::endl; // If dendritic delay is required if (sg->isDendriticDelayRequired()) { // Get reference to dendritic delay buffer input for this timestep - os << backend.getPointerPrefix() << model.getPrecision() << " *denDelayFront = "; + os << backend.getPointerPrefix() << model.getPrecision()->getName() << " *denDelayFront = "; os << "&group->denDelayInSyn" << i << "[(*group->denDelayPtrInSyn" << i << " * group->numNeurons) + "; os << getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; @@ -282,7 +282,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C if(v.access & VarAccessMode::READ_ONLY) { os << "const "; } - os << v.type << " lps" << v.name << " = group->" << v.name << "InSyn" << i << "["; + os << v.type->getResolvedName(getTypeContext()) << " lps" << v.name << " = group->" << v.name << "InSyn" << i << "["; os << getVarIndex(batchSize, getVarAccessDuplication(v.access), neuronSubs["id"]) << "];" << std::endl; } @@ -366,7 +366,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C if(v.access & VarAccessMode::READ_ONLY) { os << "const "; } - os << v.type << " lcs" << v.name << " = " << "group->" << v.name << "CS" << i << "["; + os << v.type->getResolvedName(getTypeContext()) << " lcs" << v.name << " = " << "group->" << v.name << "CS" << i << "["; os << getVarIndex(batchSize, getVarAccessDuplication(v.access), popSubs["id"]) << "];" << std::endl; } @@ -794,7 +794,7 @@ void NeuronUpdateGroupMerged::generateWUVarUpdate(CodeStream &os, const Substitu if(v.access & VarAccessMode::READ_ONLY) { os << "const "; } - os << v.type << " l" << v.name << " = group->" << v.name << fieldPrefixStem << i << "["; + os << v.type->getResolvedName(getTypeContext()) << " l" << v.name << " = group->" << v.name << fieldPrefixStem << i << "["; os << getReadVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), subs["id"]) << "];" << std::endl; } From be1076a8af012885278112f9d2dde6d785a38a1a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 20 Jan 2023 13:41:00 +0000 Subject: [PATCH 090/725] fixed some more type-writing issues --- .../customConnectivityUpdateGroupMerged.cc | 2 +- .../code_generator/customUpdateGroupMerged.cc | 4 +-- .../genn/code_generator/initGroupMerged.cc | 35 +++++++++---------- .../synapseUpdateGroupMerged.cc | 2 +- 4 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index c2c5738204..5f507b4df3 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -266,7 +266,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back if ((modelMerged.getModel().getBatchSize() > 1) && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) { // Copy parameter into a register (just incase it's e.g. 
a RNG call) and copy into all batches - addSynapse << "const " << ccuVarRefs[i].type << " _" << ccuVarRefs[i].name << "Val = $(" << (1 + ccuVars.size() + i) << ");" << std::endl; + addSynapse << "const " << ccuVarRefs[i].type->getResolvedName(getTypeContext()) << " _" << ccuVarRefs[i].name << "Val = $(" << (1 + ccuVars.size() + i) << ");" << std::endl; addSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { CodeStream::Scope b(addSynapse); diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 033328459f..238a3a5b8f 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -34,7 +34,7 @@ void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, const if(v.access & VarAccessMode::READ_ONLY) { os << "const "; } - os << v.type << " l" << v.name; + os << v.type->getResolvedName(cg.getTypeContext()) << " l" << v.name; // If this isn't a reduction, read value from memory // **NOTE** by not initialising these variables for reductions, @@ -54,7 +54,7 @@ void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, const os << "const "; } - os << v.type << " l" << v.name; + os << v.type->getResolvedName(cg.getTypeContext()) << " l" << v.name; // If this isn't a reduction, read value from memory // **NOTE** by not initialising these variables for reductions, diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index c245f9f141..45f3fe75e7 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -102,11 +102,11 @@ void genInitNeuronVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, co else { backend.genVariableInit( os, count, "id", varSubs, - [&var, &varInit, &fieldSuffix, batchSize, groupIndex, count, numDelaySlots, isVarQueueRequired] + [&var, &varInit, &modelMerged, &fieldSuffix, batchSize, groupIndex, count, numDelaySlots, isVarQueueRequired] (CodeStream &os, Substitutions &varInitSubs) { // Generate initial value into temporary variable - os << var.type << " initVal;" << std::endl; + os << var.type->getResolvedName(modelMerged.getTypeContext()) << " initVal;" << std::endl; varInitSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); varInitSubs.applyCheckUnreplaced(code, "initVar : " + var.name + "merged" + std::to_string(groupIndex)); @@ -137,9 +137,9 @@ void genInitNeuronVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, co //------------------------------------------------------------------------ // Initialise one row of weight update model variables template -void genInitWUVarCode(CodeStream &os, const Substitutions &popSubs, +void genInitWUVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, const Substitutions &popSubs, const Models::Base::VarVec &vars, const std::unordered_map &varInitialisers, - const std::string &stride, const size_t groupIndex, const Type::NumericBase *scalarType, unsigned int batchSize, + const std::string &stride, const size_t groupIndex, unsigned int batchSize, P isParamHeterogeneousFn, D isDerivedParamHeterogeneousFn, G genSynapseVariableRowInitFn) { for (const auto &var : vars) { @@ -151,7 +151,7 @@ void genInitWUVarCode(CodeStream &os, const Substitutions &popSubs, // Generate target-specific code to initialise variable genSynapseVariableRowInitFn(os, popSubs, - 
[&var, &varInit, &stride, batchSize, groupIndex, isParamHeterogeneousFn, isDerivedParamHeterogeneousFn, scalarType] + [&var, &varInit, &stride, &modelMerged, batchSize, groupIndex, isParamHeterogeneousFn, isDerivedParamHeterogeneousFn] (CodeStream &os, Substitutions &varSubs) { varSubs.addParamValueSubstitution(varInit.getSnippet()->getParamNames(), varInit.getParams(), @@ -164,7 +164,7 @@ void genInitWUVarCode(CodeStream &os, const Substitutions &popSubs, "", "group->", var.name); // Generate initial value into temporary variable - os << var.type << " initVal;" << std::endl; + os << var.type->getResolvedName(modelMerged.getTypeContext()) << " initVal;" << std::endl; varSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); varSubs.applyCheckUnreplaced(code, "initVar : merged" + var.name + std::to_string(groupIndex)); @@ -575,9 +575,8 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, CodeStream // Generate initialisation code const std::string stride = kernel ? "batchStride" : "group->numSrcNeurons * group->rowStride"; - genInitWUVarCode(os, popSubs, getArchetype().getWUModel()->getVars(), - getArchetype().getWUVarInitialisers(), stride, getIndex(), - modelMerged.getModel().getPrecision(), modelMerged.getModel().getBatchSize(), + genInitWUVarCode(os, modelMerged, popSubs, getArchetype().getWUModel()->getVars(), + getArchetype().getWUVarInitialisers(), stride, getIndex(), modelMerged.getModel().getBatchSize(), [this](const std::string &v, const std::string &p) { return isWUVarInitParamHeterogeneous(v, p); }, [this](const std::string &v, const std::string &p) { return isWUVarInitDerivedParamHeterogeneous(v, p); }, [&backend, kernel, this](CodeStream &os, const Substitutions &kernelSubs, BackendBase::Handler handler) @@ -603,9 +602,8 @@ const std::string SynapseSparseInitGroupMerged::name = "SynapseSparseInit"; //---------------------------------------------------------------------------- void SynapseSparseInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { - genInitWUVarCode(os, popSubs, getArchetype().getWUModel()->getVars(), - getArchetype().getWUVarInitialisers(), "group->numSrcNeurons * group->rowStride", getIndex(), - modelMerged.getModel().getPrecision(), modelMerged.getModel().getBatchSize(), + genInitWUVarCode(os, modelMerged, popSubs, getArchetype().getWUModel()->getVars(), + getArchetype().getWUVarInitialisers(), "group->numSrcNeurons * group->rowStride", getIndex(), modelMerged.getModel().getBatchSize(), [this](const std::string &v, const std::string &p) { return isWUVarInitParamHeterogeneous(v, p); }, [this](const std::string &v, const std::string &p) { return isWUVarInitDerivedParamHeterogeneous(v, p); }, [&backend](CodeStream &os, const Substitutions &kernelSubs, BackendBase::Handler handler) @@ -981,9 +979,9 @@ void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, Cod // Loop through rows const std::string stride = kernel ? "batchStride" : "group->numSrcNeurons * group->rowStride"; - genInitWUVarCode(os, popSubs, getArchetype().getCustomUpdateModel()->getVars(), + genInitWUVarCode(os, modelMerged, popSubs, getArchetype().getCustomUpdateModel()->getVars(), getArchetype().getVarInitialisers(), stride, getIndex(), - modelMerged.getModel().getPrecision(), getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1, + getArchetype().isBatched() ? 
modelMerged.getModel().getBatchSize() : 1, [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }, [&backend, kernel, this](CodeStream &os, const Substitutions &kernelSubs, BackendBase::Handler handler) @@ -1063,9 +1061,9 @@ boost::uuids::detail::sha1::digest_type CustomWUUpdateSparseInitGroupMerged::get // ---------------------------------------------------------------------------- void CustomWUUpdateSparseInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { - genInitWUVarCode(os, popSubs, getArchetype().getCustomUpdateModel()->getVars(), + genInitWUVarCode(os, modelMerged, popSubs, getArchetype().getCustomUpdateModel()->getVars(), getArchetype().getVarInitialisers(), "group->numSrcNeurons * group->rowStride", getIndex(), - modelMerged.getModel().getPrecision(), getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1, + getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1, [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }, [&backend](CodeStream &os, const Substitutions &kernelSubs, BackendBase::Handler handler) @@ -1227,9 +1225,8 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdateSparseInitGroupM void CustomConnectivityUpdateSparseInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { // Initialise custom connectivity update variables - genInitWUVarCode(os, popSubs, getArchetype().getCustomConnectivityUpdateModel()->getVars(), - getArchetype().getVarInitialisers(), "group->numSrcNeurons * group->rowStride", getIndex(), - modelMerged.getModel().getPrecision(), 1, + genInitWUVarCode(os, modelMerged, popSubs, getArchetype().getCustomConnectivityUpdateModel()->getVars(), + getArchetype().getVarInitialisers(), "group->numSrcNeurons * group->rowStride", getIndex(), 1, [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }, [&backend](CodeStream &os, const Substitutions &kernelSubs, BackendBase::Handler handler) diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 8b7766455a..5709ba6c29 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -90,7 +90,7 @@ void applySynapseSubstitutions(CodeStream &os, std::string code, const std::stri varSubs.applyCheckUnreplaced(code, "initVar : merged" + var.name + std::to_string(sg.getIndex())); // Declare local variable - os << var.type << " " << "l" << var.name << ";" << std::endl; + os << var.type->getResolvedName(sg.getTypeContext()) << " " << "l" << var.name << ";" << std::endl; // Insert code to initialize variable into scope { From 7dc843cd875874c8bff13e77ec88de40bc67ddf3 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 23 Jan 2023 09:19:22 +0000 Subject: [PATCH 091/725] fixed one more type-writing issue --- include/genn/genn/code_generator/backendBase.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 9a41d8bd14..7135dc6e35 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -543,7 +543,7 @@ class GENN_EXPORT BackendBase for (const auto &v : cm->getVars()) { // If variable is a reduction target, define variable initialised to correct initial value for reduction if (v.access & VarAccessModeAttribute::REDUCE) { - os << v.type << " lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(v.access), v.type, cg.getTypeContext()) << ";" << std::endl; + os << v.type->getResolvedName(cg.getTypeContext()) << " lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(v.access), v.type, cg.getTypeContext()) << ";" << std::endl; reductionTargets.emplace_back(v.name, v.type, getVarAccessMode(v.access), cg.getVarIndex(getVarAccessDuplication(v.access), idx)); } @@ -555,7 +555,7 @@ class GENN_EXPORT BackendBase // If variable reference is a reduction target, define variable initialised to correct initial value for reduction if (modelVarRef.access & VarAccessModeAttribute::REDUCE) { - os << modelVarRef.type << " lr" << modelVarRef.name << " = " << getReductionInitialValue(modelVarRef.access, modelVarRef.type, cg.getTypeContext()) << ";" << std::endl; + os << modelVarRef.type->getResolvedName(cg.getTypeContext()) << " lr" << modelVarRef.name << " = " << getReductionInitialValue(modelVarRef.access, modelVarRef.type, cg.getTypeContext()) << ";" << std::endl; reductionTargets.emplace_back(modelVarRef.name, modelVarRef.type, modelVarRef.access, getVarRefIndexFn(varRef, idx)); } From 2ca54a2abb7178a4c6a22c24b515cc218241fb3e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 23 Jan 2023 10:17:54 +0000 Subject: [PATCH 092/725] fixed up CUDA and CPU backends --- include/genn/backends/cuda/backend.h | 4 +- .../backends/single_threaded_cpu/backend.h | 4 +- .../genn/genn/code_generator/backendBase.h | 7 +- include/genn/genn/type.h | 6 +- src/genn/backends/cuda/backend.cc | 64 ++++++++++--------- src/genn/backends/cuda/optimiser.cc | 2 +- .../backends/single_threaded_cpu/backend.cc | 4 +- .../genn/code_generator/generateRunner.cc | 7 +- .../presynapticUpdateStrategySIMT.cc | 2 +- src/genn/genn/type.cc | 22 +++---- 10 files changed, 65 insertions(+), 57 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index 7fca5702f5..2309197e58 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -252,10 +252,10 @@ class BACKEND_EXPORT Backend : public BackendSIMT virtual void genGlobalDeviceRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - MemAlloc &memAlloc) const override; + const Type::TypeContext &typeContext, MemAlloc &memAlloc) const override; virtual void genPopulationRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const std::string &name, size_t count, MemAlloc &memAlloc) const override; + const Type::TypeContext &typeContext, const std::string &name, size_t count, MemAlloc &memAlloc) const override; virtual void genTimer(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, CodeStream &stepTimeFinalise, const std::string &name, bool updateInStepTime) const override; diff --git 
a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index ccc0811bb7..91ab53471f 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -131,10 +131,10 @@ class BACKEND_EXPORT Backend : public BackendBase virtual void genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUpdateInitGroupMerged &cu, const Substitutions &kernelSubs, Handler handler) const final; virtual void genGlobalDeviceRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, - CodeStream &allocations, CodeStream &free, MemAlloc &memAlloc) const final; + CodeStream &allocations, CodeStream &free, const Type::TypeContext &typeContext, MemAlloc &memAlloc) const final; virtual void genPopulationRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const std::string &name, size_t count, MemAlloc &memAlloc) const final; + const Type::TypeContext &typeContext, const std::string &name, size_t count, MemAlloc &memAlloc) const final; virtual void genTimer(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, CodeStream &stepTimeFinalise, const std::string &name, bool updateInStepTime) const final; diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 7135dc6e35..03192bbd28 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -320,11 +320,12 @@ class GENN_EXPORT BackendBase //! Generate a single RNG instance /*! On single-threaded platforms this can be a standard RNG like M.T. but, on parallel platforms, it is likely to be a counter-based RNG */ virtual void genGlobalDeviceRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, - CodeStream &allocations, CodeStream &free, MemAlloc &memAlloc) const = 0; + CodeStream &allocations, CodeStream &free, const Type::TypeContext &typeContext, MemAlloc &memAlloc) const = 0; //! 
Generate an RNG with a state per population member - virtual void genPopulationRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, - CodeStream &free, const std::string &name, size_t count, MemAlloc &memAlloc) const = 0; + virtual void genPopulationRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, + CodeStream &allocations, CodeStream &free, + const Type::TypeContext &typeContext, const std::string &name, size_t count, MemAlloc &memAlloc) const = 0; virtual void genTimer(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, CodeStream &stepTimeFinalise, const std::string &name, bool updateInStepTime) const = 0; diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index b995a311fe..213f517d63 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -255,12 +255,12 @@ class NumericTypedef : public NumericBase virtual std::string getLiteralSuffix(const TypeContext &context) const final; -private: //------------------------------------------------------------------------ - // Private methods + // Public API //------------------------------------------------------------------------ - const Type::NumericBase *getNumeric(const TypeContext &context) const; + const Type::NumericBase *getResolvedType(const TypeContext &context) const; +private: //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 40d4d76aea..82ce8aada5 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -262,41 +262,46 @@ std::string getNCCLReductionType(VarAccessMode mode) } } //----------------------------------------------------------------------- -std::string getNCCLType(const std::string &type, const std::string &precision) +std::string getNCCLType(const Type::NumericBase *type, const Type::TypeContext &context) { - // Convert GeNN types to NCCL types - // **YUCK** GeNN really needs a better type system - if(type == "scalar") { - return (precision == "float") ? 
"ncclFloat32" : "ncclFloat64"; + // If type is a numeric typedef, resolve it + const auto numericTypedef = dynamic_cast(type); + if (numericTypedef) { + type = numericTypedef->getResolvedType(context); } - else if(type == "char" || type == "signed char" || type == "int8_t") { + + // Convert GeNN types to NCCL types + // **YUCK** Visitor pattern would really help here + if(dynamic_cast(type)) { return "ncclInt8"; } - else if(type == "unsigned char" || type == "uint8_t") { + else if(dynamic_cast(type)) { return "ncclUint8"; } - else if(type == "int" || type == "signed int" || type == "signed" || type == "int32_t") { + else if(dynamic_cast(type)) { return "ncclInt32"; } - else if(type == "unsigned" || type == "unsigned int" || type == "uint32_t") { + else if(dynamic_cast(type)){ return "ncclUint32"; } - else if(type == "half") { + /*else if(type == "half") { return "ncclFloat16"; - } - else if(type == "float") { + }*/ + else if(dynamic_cast(type)){ return "ncclFloat32"; } - else if(type == "double") { + else if(dynamic_cast(type)) { return "ncclFloat64"; } + else if (dynamic_cast(type)) { + } else { - throw std::runtime_error("Data type '" + type + "' unsupported by NCCL"); + throw std::runtime_error("Data type '" + type->getResolvedName(context) + "' unsupported by NCCL"); } } //----------------------------------------------------------------------- template -void genNCCLReduction(CodeStream &os, const G &cg, const std::string &precision) +void genNCCLReduction(CodeStream &os, const G &cg) { CodeStream::Scope b(os); os << "// merged custom update host reduction group " << cg.getIndex() << std::endl; @@ -312,7 +317,7 @@ void genNCCLReduction(CodeStream &os, const G &cg, const std::string &precision) for(const auto &v : cm->getVars()) { if(v.access & VarAccessModeAttribute::REDUCE) { os << "CHECK_NCCL_ERRORS(ncclAllReduce(group->" << v.name << ", group->" << v.name << ", group->size"; - os << ", " << getNCCLType(v.type, precision) << ", " << getNCCLReductionType(getVarAccessMode(v.access)) << ", ncclCommunicator, 0)); " << std::endl; + os << ", " << getNCCLType(v.type, cg.getTypeContext()) << ", " << getNCCLReductionType(getVarAccessMode(v.access)) << ", ncclCommunicator, 0)); " << std::endl; } } @@ -320,7 +325,7 @@ void genNCCLReduction(CodeStream &os, const G &cg, const std::string &precision) for(const auto &v : cm->getVarRefs()) { if(v.access & VarAccessModeAttribute::REDUCE) { os << "CHECK_NCCL_ERRORS(ncclAllReduce(group->" << v.name << ", group->" << v.name << ", group->size"; - os << ", " << getNCCLType(v.type, precision) << ", " << getNCCLReductionType(v.access) << ", ncclCommunicator, 0));" << std::endl; + os << ", " << getNCCLType(v.type, cg.getTypeContext()) << ", " << getNCCLReductionType(v.access) << ", ncclCommunicator, 0));" << std::endl; } } } @@ -481,7 +486,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged // If any neuron groups require their previous spike times updating size_t idNeuronPrevSpikeTimeUpdate = 0; if(!modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups().empty()) { - os << "extern \"C\" __global__ void " << KernelNames[KernelNeuronPrevSpikeTimeUpdate] << "(" << model.getTimePrecision() << " t)"; + os << "extern \"C\" __global__ void " << KernelNames[KernelNeuronPrevSpikeTimeUpdate] << "(" << model.getTimePrecision()->getName() << " t)"; { CodeStream::Scope b(os); @@ -510,7 +515,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged os << std::endl; size_t idStart = 0; - os << "extern \"C\" 
__global__ void " << KernelNames[KernelNeuronUpdate] << "(" << model.getTimePrecision() << " t"; + os << "extern \"C\" __global__ void " << KernelNames[KernelNeuronUpdate] << "(" << model.getTimePrecision()->getName() << " t"; if(model.isRecordingInUse()) { os << ", unsigned int recordingTimestep"; } @@ -532,7 +537,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged genNeuronUpdateKernel(os, kernelSubs, modelMerged, idStart); } - os << "void updateNeurons(" << model.getTimePrecision() << " t"; + os << "void updateNeurons(" << model.getTimePrecision()->getName() << " t"; if(model.isRecordingInUse()) { os << ", unsigned int recordingTimestep"; } @@ -785,7 +790,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged [this](const CustomConnectivityUpdateInternal &cg){ return padKernelSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelCustomUpdate); }, [g](const CustomConnectivityUpdateGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; }); - os << "extern \"C\" __global__ void " << KernelNames[KernelCustomUpdate] << g << "(" << model.getTimePrecision() << " t)" << std::endl; + os << "extern \"C\" __global__ void " << KernelNames[KernelCustomUpdate] << g << "(" << model.getTimePrecision()->getName() << " t)" << std::endl; { CodeStream::Scope b(os); @@ -816,7 +821,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged [&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); }, [g](const CustomUpdateTransposeWUGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; }); - os << "extern \"C\" __global__ void " << KernelNames[KernelCustomTransposeUpdate] << g << "(" << model.getTimePrecision() << " t)" << std::endl; + os << "extern \"C\" __global__ void " << KernelNames[KernelCustomTransposeUpdate] << g << "(" << model.getTimePrecision()->getName() << " t)" << std::endl; { CodeStream::Scope b(os); @@ -866,7 +871,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged // generate reductions for those in this custom update group for(const auto &cg : modelMerged.getMergedCustomUpdateHostReductionGroups()) { if(cg.getArchetype().getUpdateGroupName() == g) { - genNCCLReduction(os, cg, model.getPrecision()); + genNCCLReduction(os, cg); } } @@ -874,7 +879,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged // generate reductions for those in this custom update group for(const auto &cg : modelMerged.getMergedCustomWUUpdateHostReductionGroups()) { if(cg.getArchetype().getUpdateGroupName() == g) { - genNCCLReduction(os, cg, model.getPrecision()); + genNCCLReduction(os, cg); } } } @@ -1607,7 +1612,7 @@ void Backend::genVariableInstantiation(CodeStream &os, os << pointerTypeName << " " << name << ";" << std::endl; } if(loc & VarLocation::DEVICE) { - os << pointerTypeName << " d_" << name < V< ";" << std::endl; + os << pointerTypeName << " d_" << name << ";" << std::endl; } } } @@ -1852,7 +1857,8 @@ const Type::ValueBase *Backend::getMergedGroupSimRNGType() const return CURandState::getInstance(); } //-------------------------------------------------------------------------- -void Backend::genGlobalDeviceRNG(CodeStream &, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &, CodeStream &, MemAlloc &memAlloc) const +void Backend::genGlobalDeviceRNG(CodeStream &, CodeStream &definitionsInternal, CodeStream &runner, CodeStream 
&, CodeStream &, + const Type::TypeContext &typeContext, MemAlloc &memAlloc) const { // Define global Phillox RNG // **NOTE** this is actually accessed as a global so, unlike other variables, needs device global @@ -1861,14 +1867,14 @@ void Backend::genGlobalDeviceRNG(CodeStream &, CodeStream &definitionsInternal, // Implement global Phillox RNG runner << "__device__ curandStatePhilox4_32_10_t d_rng;" << std::endl; - memAlloc += MemAlloc::device(CURandStatePhilox43210->getSizeBytes()); + memAlloc += MemAlloc::device(CURandStatePhilox43210::getInstance()->getSizeBytes(typeContext)); } //-------------------------------------------------------------------------- void Backend::genPopulationRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const std::string &name, size_t count, MemAlloc &memAlloc) const + const Type::TypeContext &typeContext, const std::string &name, size_t count, MemAlloc &memAlloc) const { // Create an array or XORWOW RNGs - genArray(definitions, definitionsInternal, runner, allocations, free, "curandState", name, VarLocation::DEVICE, count, memAlloc); + genArray(definitions, definitionsInternal, runner, allocations, free, typeContext, name, VarLocation::DEVICE, count, memAlloc); } //-------------------------------------------------------------------------- void Backend::genTimer(CodeStream &, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, diff --git a/src/genn/backends/cuda/optimiser.cc b/src/genn/backends/cuda/optimiser.cc index f13d8bef99..9c7a3403ea 100644 --- a/src/genn/backends/cuda/optimiser.cc +++ b/src/genn/backends/cuda/optimiser.cc @@ -452,7 +452,7 @@ KernelOptimisationOutput optimizeBlockSize(int deviceID, const cudaDeviceProp &d std::fill(blockSize.begin(), blockSize.end(), repBlockSizes[r]); // Create backend - Backend backend(blockSize, preferences, model.getPrecision(), deviceID); + Backend backend(blockSize, preferences, deviceID); // Create merged model ModelSpecMerged modelMerged(model, backend); diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 3f4e0982f8..1899f7079c 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1438,13 +1438,13 @@ void Backend::genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUp genKernelIteration(os, cu, cu.getArchetype().getSynapseGroup()->getKernelSize().size(), kernelSubs, handler); } //-------------------------------------------------------------------------- -void Backend::genGlobalDeviceRNG(CodeStream&, CodeStream&, CodeStream&, CodeStream&, CodeStream&, MemAlloc&) const +void Backend::genGlobalDeviceRNG(CodeStream&, CodeStream&, CodeStream&, CodeStream&, CodeStream&, const Type::TypeContext&, MemAlloc&) const { assert(false); } //-------------------------------------------------------------------------- void Backend::genPopulationRNG(CodeStream&, CodeStream&, CodeStream&, CodeStream&, CodeStream&, - const std::string&, size_t, MemAlloc&) const + const Type::TypeContext&, const std::string&, size_t, MemAlloc&) const { } //-------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 878503afa0..933702aaed 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -627,7 
+627,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, } // If backend requires a global device RNG to simulate (or initialize) this model if(backend.isGlobalDeviceRNGRequired(modelMerged)) { - backend.genGlobalDeviceRNG(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, mem); + backend.genGlobalDeviceRNG(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + modelMerged.getTypeContext(), mem); } // If backend required a global host RNG to simulate (or initialize) this model, generate a standard Mersenne Twister if(backend.isGlobalHostRNGRequired(modelMerged)) { @@ -1059,7 +1060,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // If neuron group needs per-neuron RNGs if(n.second.isSimRNGRequired()) { backend.genPopulationRNG(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "rng" + n.first, batchSize * n.second.getNumNeurons(), mem); + modelMerged.getTypeContext(), "rng" + n.first, batchSize * n.second.getNumNeurons(), mem); } // Neuron state variables @@ -1220,7 +1221,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // If custom connectivity update group needs per-row RNGs if(c.second.isRowSimRNGRequired()) { backend.genPopulationRNG(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "rowRNG" + c.first, c.second.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), mem); + modelMerged.getTypeContext(), "rowRNG" + c.first, c.second.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), mem); } diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index 56b56fc816..f21b89661a 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -232,7 +232,7 @@ void PostSpan::genPreamble(CodeStream &os, const ModelSpecMerged &modelMerged, c { // If data structure is dense, we can accumulate output directly into register if(shouldAccumulateInRegister(sg)) { - os << modelMerged.getModel().getPrecision() << " linSyn = 0;" << std::endl; + os << modelMerged.getModel().getPrecision()->getName() << " linSyn = 0;" << std::endl; } else if(isSmallSharedMemoryPop(sg, backend)) { os << "if(" << backend.getThreadID() << " < group->numTrgNeurons)"; diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index adc806a0e9..6c00a1fdc5 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -92,12 +92,12 @@ const Pointer *Base::getPointerType(Qualifier qualifiers) const //---------------------------------------------------------------------------- std::string NumericTypedef::getResolvedName(const TypeContext &context) const { - return getNumeric(context)->getResolvedName(context); + return getResolvedType(context)->getResolvedName(context); } //---------------------------------------------------------------------------- size_t NumericTypedef::getSizeBytes(const TypeContext &context) const { - return getNumeric(context)->getSizeBytes(context); + return getResolvedType(context)->getSizeBytes(context); } //---------------------------------------------------------------------------- Base *NumericTypedef::getQualifiedType(Qualifier qualifiers) const @@ -107,45 +107,45 @@ Base *NumericTypedef::getQualifiedType(Qualifier qualifiers) const 
//---------------------------------------------------------------------------- int NumericTypedef::getRank(const TypeContext &context) const { - return getNumeric(context)->getRank(context); + return getResolvedType(context)->getRank(context); } //---------------------------------------------------------------------------- double NumericTypedef::getMin(const TypeContext &context) const { - return getNumeric(context)->getMin(context); + return getResolvedType(context)->getMin(context); } //---------------------------------------------------------------------------- double NumericTypedef::getMax(const TypeContext &context) const { - return getNumeric(context)->getMax(context); + return getResolvedType(context)->getMax(context); } //---------------------------------------------------------------------------- double NumericTypedef::getLowest(const TypeContext &context) const { - return getNumeric(context)->getLowest(context); + return getResolvedType(context)->getLowest(context); } //---------------------------------------------------------------------------- int NumericTypedef::getMaxDigits10(const TypeContext &context) const { - return getNumeric(context)->getMaxDigits10(context); + return getResolvedType(context)->getMaxDigits10(context); } //---------------------------------------------------------------------------- bool NumericTypedef::isSigned(const TypeContext &context) const { - return getNumeric(context)->getSizeBytes(context); + return getResolvedType(context)->isSigned(context); } //---------------------------------------------------------------------------- bool NumericTypedef::isIntegral(const TypeContext &context) const { - return getNumeric(context)->isIntegral(context); + return getResolvedType(context)->isIntegral(context); } //---------------------------------------------------------------------------- std::string NumericTypedef::getLiteralSuffix(const TypeContext &context) const { - return getNumeric(context)->getLiteralSuffix(context); + return getResolvedType(context)->getLiteralSuffix(context); } //---------------------------------------------------------------------------- -const Type::NumericBase *NumericTypedef::getNumeric(const TypeContext &context) const +const Type::NumericBase *NumericTypedef::getResolvedType(const TypeContext &context) const { const auto t = context.find(m_Name); if (t == context.cend()) { From bc358a3c3e75c51dee3b9acb8ace4bb87c75fedd Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 23 Jan 2023 11:16:30 +0000 Subject: [PATCH 093/725] no need to resolve types so aggressively - just use the typedef (need to think of a less clashy name than "time" though) --- .../genn/genn/code_generator/backendBase.h | 4 +- .../genn/genn/code_generator/groupMerged.h | 4 +- src/genn/backends/cuda/backend.cc | 66 +++++++++---------- src/genn/backends/opencl/backend.cc | 4 +- .../backends/single_threaded_cpu/backend.cc | 14 ++-- .../customConnectivityUpdateGroupMerged.cc | 2 +- .../code_generator/customUpdateGroupMerged.cc | 4 +- .../genn/code_generator/generateRunner.cc | 8 ++- .../genn/code_generator/initGroupMerged.cc | 10 +-- .../code_generator/neuronUpdateGroupMerged.cc | 10 +-- .../presynapticUpdateStrategySIMT.cc | 2 +- .../synapseUpdateGroupMerged.cc | 4 +- 12 files changed, 64 insertions(+), 68 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h index 03192bbd28..a9cfc62a90 100644 --- a/include/genn/genn/code_generator/backendBase.h +++
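For orientation, the lookup that all of the delegating methods above funnel into is tiny: a NumericTypedef stores only a name such as "scalar" and resolves it against the model's type context on demand. A minimal self-contained sketch of that pattern, using a plain std::map and a simplified NumericBase (these stand-in definitions are illustrative, not GeNN's real headers):

    #include <map>
    #include <stdexcept>
    #include <string>

    struct NumericBase
    {
        virtual bool isSigned() const = 0;
        virtual ~NumericBase() = default;
    };
    using TypeContext = std::map<std::string, const NumericBase*>;

    // Mirrors NumericTypedef::getResolvedType: look the typedef's name up in
    // the context and fail loudly if the model never registered it
    const NumericBase *getResolvedType(const TypeContext &context, const std::string &name)
    {
        const auto t = context.find(name);
        if (t == context.cend()) {
            throw std::runtime_error("No entry for typedef '" + name + "' in type context");
        }
        return t->second;
    }

With this in place a query such as isSigned reduces to getResolvedType(context, "scalar")->isSigned().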
b/include/genn/genn/code_generator/backendBase.h @@ -544,7 +544,7 @@ class GENN_EXPORT BackendBase for (const auto &v : cm->getVars()) { // If variable is a reduction target, define variable initialised to correct initial value for reduction if (v.access & VarAccessModeAttribute::REDUCE) { - os << v.type->getResolvedName(cg.getTypeContext()) << " lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(v.access), v.type, cg.getTypeContext()) << ";" << std::endl; + os << v.type->getName() << " lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(v.access), v.type, cg.getTypeContext()) << ";" << std::endl; reductionTargets.emplace_back(v.name, v.type, getVarAccessMode(v.access), cg.getVarIndex(getVarAccessDuplication(v.access), idx)); } @@ -556,7 +556,7 @@ class GENN_EXPORT BackendBase // If variable reference is a reduction target, define variable initialised to correct initial value for reduction if (modelVarRef.access & VarAccessModeAttribute::REDUCE) { - os << modelVarRef.type->getResolvedName(cg.getTypeContext()) << " lr" << modelVarRef.name << " = " << getReductionInitialValue(modelVarRef.access, modelVarRef.type, cg.getTypeContext()) << ";" << std::endl; + os << modelVarRef.type->getName() << " lr" << modelVarRef.name << " = " << getReductionInitialValue(modelVarRef.access, modelVarRef.type, cg.getTypeContext()) << ";" << std::endl; reductionTargets.emplace_back(modelVarRef.name, modelVarRef.type, modelVarRef.access, getVarRefIndexFn(varRef, idx)); } diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index fdfe3bd620..24128b5666 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -124,12 +124,12 @@ class GroupMerged } // Otherwise, allow the backend to add a prefix else { - os << backend.getPointerPrefix() << type->getResolvedName(m_TypeContext); + os << backend.getPointerPrefix() << type->getName(); } } // Otherwise, leave the type alone else { - os << type->getResolvedName(m_TypeContext); + os << type->getName(); } os << " " << std::get<1>(f) << ";" << std::endl; } diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 82ce8aada5..1e5df16c80 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -296,7 +296,7 @@ std::string getNCCLType(const Type::NumericBase *type, const Type::TypeContext & else if (dynamic_cast(type)) { } else { - throw std::runtime_error("Data type '" + type->getResolvedName(context) + "' unsupported by NCCL"); + throw std::runtime_error("Data type '" + type->getName() + "' unsupported by NCCL"); } } //----------------------------------------------------------------------- @@ -1577,23 +1577,22 @@ void Backend::genVariableDefinition(CodeStream &definitions, CodeStream &definit { const bool deviceType = dynamic_cast(type); CodeStream &d = deviceType ? 
definitionsInternal : definitions; - const std::string pointerTypeName = type->getPointerType()->getResolvedName(typeContext); if(getPreferences().automaticCopy) { // Export pointer, either in definitionsInternal if variable has a device type // or to definitions if it should be accessible on host - d << "EXPORT_VAR " << pointerTypeName << " " << name << ";" << std::endl; + d << "EXPORT_VAR " << type->getPointerType()->getName() << " " << name << ";" << std::endl; } else { if(loc & VarLocation::HOST) { if(deviceType) { - throw std::runtime_error("Variable '" + name + "' is of device-only type '" + pointerTypeName + "' but is located on the host"); + throw std::runtime_error("Variable '" + name + "' is of device-only type '" + type->getPointerType()->getName() + "' but is located on the host"); } - definitions << "EXPORT_VAR " << pointerTypeName << " " << name << ";" << std::endl; + definitions << "EXPORT_VAR " << type->getPointerType()->getName() << " " << name << ";" << std::endl; } if(loc & VarLocation::DEVICE) { // Write host definition to internal definitions stream if type is device only - d << "EXPORT_VAR " << pointerTypeName << " d_" << name << ";" << std::endl; + d << "EXPORT_VAR " << type->getPointerType()->getName() << " d_" << name << ";" << std::endl; } } @@ -1603,16 +1602,15 @@ void Backend::genVariableInstantiation(CodeStream &os, const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, VarLocation loc) const { - const std::string pointerTypeName = type->getPointerType()->getResolvedName(typeContext); if(getPreferences().automaticCopy) { - os << pointerTypeName << " " << name << ";" << std::endl; + os << type->getPointerType()->getName() << " " << name << ";" << std::endl; } else { if(loc & VarLocation::HOST) { - os << pointerTypeName << " " << name << ";" << std::endl; + os << type->getPointerType()->getName() << " " << name << ";" << std::endl; } if(loc & VarLocation::DEVICE) { - os << pointerTypeName << " d_" << name << ";" << std::endl; + os << type->getPointerType()->getName() << " d_" << name << ";" << std::endl; } } } @@ -1621,15 +1619,14 @@ void Backend::genVariableAllocation(CodeStream &os, const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const { - const std::string typeName = type->getResolvedName(typeContext); if(getPreferences().automaticCopy) { - os << "CHECK_CUDA_ERRORS(cudaMallocManaged(&" << name << ", " << count << " * sizeof(" << typeName << ")));" << std::endl; + os << "CHECK_CUDA_ERRORS(cudaMallocManaged(&" << name << ", " << count << " * sizeof(" << type->getName() << ")));" << std::endl; memAlloc += MemAlloc::device(count * type->getSizeBytes(typeContext)); } else { if(loc & VarLocation::HOST) { const char *flags = (loc & VarLocation::ZERO_COPY) ?
"cudaHostAllocMapped" : "cudaHostAllocPortable"; - os << "CHECK_CUDA_ERRORS(cudaHostAlloc(&" << name << ", " << count << " * sizeof(" << typeName << "), " << flags << "));" << std::endl; + os << "CHECK_CUDA_ERRORS(cudaHostAlloc(&" << name << ", " << count << " * sizeof(" << type->getName() << "), " << flags << "));" << std::endl; memAlloc += MemAlloc::host(count * type->getSizeBytes(typeContext)); } @@ -1641,7 +1638,7 @@ void Backend::genVariableAllocation(CodeStream &os, memAlloc += MemAlloc::zeroCopy(count * type->getSizeBytes(typeContext)); } else { - os << "CHECK_CUDA_ERRORS(cudaMalloc(&d_" << name << ", " << count << " * sizeof(" << typeName << ")));" << std::endl; + os << "CHECK_CUDA_ERRORS(cudaMalloc(&d_" << name << ", " << count << " * sizeof(" << type->getName() << ")));" << std::endl; memAlloc += MemAlloc::device(count * type->getSizeBytes(typeContext)); } } @@ -1654,17 +1651,16 @@ void Backend::genVariableDynamicAllocation(CodeStream &os, { const auto *pointerType = dynamic_cast(type); const auto *underlyingType = pointerType ? pointerType->getValueType() : type; - const std::string underlyingTypeName = underlyingType->getResolvedName(typeContext); const std::string hostPointer = pointerType ? ("*" + prefix + name) : (prefix + name); const std::string hostPointerToPointer = pointerType ? (prefix + name) : ("&" + prefix + name); const std::string devicePointerToPointer = pointerType ? (prefix + "d_" + name) : ("&" + prefix + "d_" + name); if(getPreferences().automaticCopy) { - os << "CHECK_CUDA_ERRORS(cudaMallocManaged(" << hostPointerToPointer << ", " << countVarName << " * sizeof(" << underlyingTypeName << ")));" << std::endl; + os << "CHECK_CUDA_ERRORS(cudaMallocManaged(" << hostPointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType->getName() << ")));" << std::endl; } else { if(loc & VarLocation::HOST) { const char *flags = (loc & VarLocation::ZERO_COPY) ? 
"cudaHostAllocMapped" : "cudaHostAllocPortable"; - os << "CHECK_CUDA_ERRORS(cudaHostAlloc(" << hostPointerToPointer << ", " << countVarName << " * sizeof(" << underlyingTypeName << "), " << flags << "));" << std::endl; + os << "CHECK_CUDA_ERRORS(cudaHostAlloc(" << hostPointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType->getName() << "), " << flags << "));" << std::endl; } // If variable is present on device at all @@ -1673,7 +1669,7 @@ void Backend::genVariableDynamicAllocation(CodeStream &os, os << "CHECK_CUDA_ERRORS(cudaHostGetDevicePointer((void**)" << devicePointerToPointer << ", (void*)" << hostPointer << ", 0));" << std::endl; } else { - os << "CHECK_CUDA_ERRORS(cudaMalloc(" << devicePointerToPointer << ", " << countVarName << " * sizeof(" << underlyingTypeName << ")));" << std::endl; + os << "CHECK_CUDA_ERRORS(cudaMalloc(" << devicePointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType->getName() << ")));" << std::endl; } } } @@ -1711,7 +1707,7 @@ void Backend::genVariablePush(CodeStream &os, os << "CHECK_CUDA_ERRORS(cudaMemcpy(d_" << name; os << ", " << name; - os << ", " << count << " * sizeof(" << type->getResolvedName(typeContext) << "), cudaMemcpyHostToDevice));" << std::endl; + os << ", " << count << " * sizeof(" << type->getName() << "), cudaMemcpyHostToDevice));" << std::endl; if(autoInitialized) { os << CodeStream::CB(1101); @@ -1728,7 +1724,7 @@ void Backend::genVariablePull(CodeStream &os, if(!(loc & VarLocation::ZERO_COPY)) { os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << name; os << ", d_" << name; - os << ", " << count << " * sizeof(" << type->getResolvedName(typeContext) << "), cudaMemcpyDeviceToHost));" << std::endl; + os << ", " << count << " * sizeof(" << type->getName() << "), cudaMemcpyDeviceToHost));" << std::endl; } } //-------------------------------------------------------------------------- @@ -1741,19 +1737,18 @@ void Backend::genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal & // If this variable requires queuing and isn't zero-copy if(ng.isVarQueueRequired(name) && ng.isDelayRequired() && !(loc & VarLocation::ZERO_COPY)) { // If batch size is one, generate 1D memcpy to copy current timestep's data - const std::string typeName = type->getResolvedName(typeContext); if(batchSize == 1) { os << "CHECK_CUDA_ERRORS(cudaMemcpy(d_" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; os << ", " << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() << " * sizeof(" << typeName << "), cudaMemcpyHostToDevice));" << std::endl; + os << ", " << ng.getNumNeurons() << " * sizeof(" << type->getName() << "), cudaMemcpyHostToDevice));" << std::endl; } // Otherwise, perform a 2D memcpy to copy current timestep's data from each batch else { os << "CHECK_CUDA_ERRORS(cudaMemcpy2D(d_" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << typeName << ")"; + os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type->getName() << ")"; os << ", " << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << typeName << ")"; - os << ", " << ng.getNumNeurons() << " * sizeof(" << typeName << ")"; + os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * 
sizeof(" << type->getName() << ")"; + os << ", " << ng.getNumNeurons() << " * sizeof(" << type->getName() << ")"; os << ", " << batchSize << ", cudaMemcpyHostToDevice));" << std::endl; } } @@ -1772,18 +1767,17 @@ void Backend::genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal & // If this variable requires queuing and isn't zero-copy if(ng.isVarQueueRequired(name) && ng.isDelayRequired() && !(loc & VarLocation::ZERO_COPY)) { // If batch size is one, generate 1D memcpy to copy current timestep's data - const std::string typeName = type->getResolvedName(typeContext); if(batchSize == 1) { os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; os << ", d_" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() << " * sizeof(" << typeName << "), cudaMemcpyDeviceToHost));" << std::endl; + os << ", " << ng.getNumNeurons() << " * sizeof(" << type->getName() << "), cudaMemcpyDeviceToHost));" << std::endl; } else { os << "CHECK_CUDA_ERRORS(cudaMemcpy2D(" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << typeName << ")"; + os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type->getName() << ")"; os << ", d_" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << typeName << ")"; - os << ", " << ng.getNumNeurons() << " * sizeof(" << typeName << ")"; + os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type->getName() << ")"; + os << ", " << ng.getNumNeurons() << " * sizeof(" << type->getName() << ")"; os << ", " << batchSize << ", cudaMemcpyDeviceToHost));" << std::endl; } } @@ -1804,13 +1798,13 @@ void Backend::genVariableDynamicPush(CodeStream &os, if (pointerType) { os << "CHECK_CUDA_ERRORS(cudaMemcpy(*" << prefix << "d_" << name; os << ", *" << prefix << name; - os << ", " << countVarName << " * sizeof(" << pointerType->getValueType()->getResolvedName(typeContext) << "), cudaMemcpyHostToDevice));" << std::endl; + os << ", " << countVarName << " * sizeof(" << pointerType->getValueType()->getName() << "), cudaMemcpyHostToDevice));" << std::endl; } else { - os << prefix << name << " = new " << type->getResolvedName(typeContext) << "[" << countVarName << "];" << std::endl; + os << prefix << name << " = new " << type->getName() << "[" << countVarName << "];" << std::endl; os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << prefix << "d_" << name; os << ", " << prefix << name; - os << ", " << countVarName << " * sizeof(" << type->getResolvedName(typeContext) << "), cudaMemcpyHostToDevice));" << std::endl; + os << ", " << countVarName << " * sizeof(" << type->getName() << "), cudaMemcpyHostToDevice));" << std::endl; } } } @@ -1826,12 +1820,12 @@ void Backend::genVariableDynamicPull(CodeStream &os, if (pointerType) { os << "CHECK_CUDA_ERRORS(cudaMemcpy(*" << prefix << name; os << ", *" << prefix << "d_" << name; - os << ", " << countVarName << " * sizeof(" << pointerType->getValueType()->getResolvedName(typeContext) << "), cudaMemcpyDeviceToHost));" << std::endl; + os << ", " << countVarName << " * sizeof(" << pointerType->getValueType()->getName() << "), cudaMemcpyDeviceToHost));" << std::endl; } else { os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << prefix 
<< name; os << ", " << prefix << "d_" << name; - os << ", " << countVarName << " * sizeof(" << type->getResolvedName(typeContext) << "), cudaMemcpyDeviceToHost));" << std::endl; + os << ", " << countVarName << " * sizeof(" << type->getName() << "), cudaMemcpyDeviceToHost));" << std::endl; } } @@ -1849,7 +1843,7 @@ void Backend::genMergedDynamicVariablePush(CodeStream &os, const std::string &su //-------------------------------------------------------------------------- std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const { - return type->getResolvedName(context); + return type->getName(); } //-------------------------------------------------------------------------- const Type::ValueBase *Backend::getMergedGroupSimRNGType() const diff --git a/src/genn/backends/opencl/backend.cc b/src/genn/backends/opencl/backend.cc index fe2ea6e8c9..5b6f18c0d8 100644 --- a/src/genn/backends/opencl/backend.cc +++ b/src/genn/backends/opencl/backend.cc @@ -2080,12 +2080,12 @@ std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, con /*if(GeNN::Utils::isTypePointerToPointer(type)) { return "cl::Buffer*"; } - else */if(dynamic_cast(type)) { + else */if(dynamic_cast(type)) { return "cl::Buffer"; } // Otherwise, type remains the same else { - return type->getResolvedName(context); + return type->getName(); } } //-------------------------------------------------------------------------- diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 1899f7079c..f4cf846354 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1288,21 +1288,21 @@ void Backend::genVariableDefinition(CodeStream &definitions, CodeStream &, const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, VarLocation) const { - definitions << "EXPORT_VAR " << type->getPointerType()->getResolvedName(typeContext) << " " << name << ";" << std::endl; + definitions << "EXPORT_VAR " << type->getPointerType()->getName() << " " << name << ";" << std::endl; } //-------------------------------------------------------------------------- void Backend::genVariableInstantiation(CodeStream &os, const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, VarLocation) const { - os << type->getPointerType()->getResolvedName(typeContext) << " " << name << ";" << std::endl; + os << type->getPointerType()->getName() << " " << name << ";" << std::endl; } //-------------------------------------------------------------------------- void Backend::genVariableAllocation(CodeStream &os, const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, VarLocation, size_t count, MemAlloc &memAlloc) const { - os << name << " = new " << type->getResolvedName(typeContext) << "[" << count << "];" << std::endl; + os << name << " = new " << type->getName() << "[" << count << "];" << std::endl; memAlloc += MemAlloc::host(count * type->getSizeBytes(typeContext)); } @@ -1313,10 +1313,10 @@ void Backend::genVariableDynamicAllocation(CodeStream &os, { const auto *pointerType = dynamic_cast(type); if (pointerType) { - os << "*" << prefix << name << " = new " << pointerType->getValueType()->getResolvedName(typeContext) << "[" << countVarName << "];" << std::endl; + os << "*" << prefix << name << " = new " << pointerType->getValueType()->getName() << "[" << countVarName << "];" << std::endl; 
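// These identity implementations are why getMergedGroupFieldHostTypeName
// exists as a hook at all: CUDA and the single-threaded CPU backend share an
// address space with the generated code, so the host type is simply
// type->getName(), whereas a backend without unified addressing has to swap in
// its own handle type, as the OpenCL hunk above does. A minimal sketch of that
// override with stand-in types (illustrative, not GeNN's class hierarchy):
//
//     struct Base { virtual std::string getName() const = 0; virtual ~Base() = default; };
//     struct Pointer : Base { std::string getName() const override { return "scalar*"; } };
//
//     std::string getMergedGroupFieldHostTypeName(const Base *type)
//     {
//         if (dynamic_cast<const Pointer*>(type)) {
//             return "cl::Buffer";   // host code sees a buffer object, not a raw pointer
//         }
//         return type->getName();    // value types pass through unchanged
//     }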
} else { - os << prefix << name << " = new " << type->getResolvedName(typeContext) << "[" << countVarName << "];" << std::endl; + os << prefix << name << " = new " << type->getName() << "[" << countVarName << "];" << std::endl; } } //-------------------------------------------------------------------------- @@ -1373,7 +1373,7 @@ void Backend::genMergedDynamicVariablePush(CodeStream &os, const std::string &su //-------------------------------------------------------------------------- std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const { - return type->getResolvedName(context); + return type->getName(); } //-------------------------------------------------------------------------- const Type::ValueBase *Backend::getMergedGroupSimRNGType() const @@ -1658,7 +1658,7 @@ void Backend::genPresynapticUpdate(CodeStream &os, const ModelSpecMerged &modelM connSubs.applyCheckUnreplaced(value, "toeplitz diagonal build state var : merged" + std::to_string(sg.getIndex())); //value = ensureFtype(value, modelMerged.getModel().getPrecision()); - os << d.type->getResolvedName(sg.getTypeContext()) << " " << d.name << " = " << value << ";" << std::endl; + os << d.type->getName() << " " << d.name << " = " << value << ";" << std::endl; } // Detect spike events or spikes and do the update diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 5f507b4df3..54a90ce4e0 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -266,7 +266,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back if ((modelMerged.getModel().getBatchSize() > 1) && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) { // Copy parameter into a register (just in case it's e.g.
a RNG call) and copy into all batches - addSynapse << "const " << ccuVarRefs[i].type->getResolvedName(getTypeContext()) << " _" << ccuVarRefs[i].name << "Val = $(" << (1 + ccuVars.size() + i) << ");" << std::endl; + addSynapse << "const " << ccuVarRefs[i].type->getName() << " _" << ccuVarRefs[i].name << "Val = $(" << (1 + ccuVars.size() + i) << ");" << std::endl; addSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { CodeStream::Scope b(addSynapse); diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 238a3a5b8f..63593088c1 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -34,7 +34,7 @@ void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, const if(v.access & VarAccessMode::READ_ONLY) { os << "const "; } - os << v.type->getResolvedName(cg.getTypeContext()) << " l" << v.name; + os << v.type->getName() << " l" << v.name; // If this isn't a reduction, read value from memory // **NOTE** by not initialising these variables for reductions, @@ -54,7 +54,7 @@ void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, const os << "const "; } - os << v.type->getResolvedName(cg.getTypeContext()) << " l" << v.name; + os << v.type->getName() << " l" << v.name; // If this isn't a reduction, read value from memory // **NOTE** by not initialising these variables for reductions, diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 933702aaed..d7f1808fe1 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -559,8 +559,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const ModelSpecInternal &model = modelMerged.getModel(); definitions << "#define DT " << Utils::writePreciseString(model.getDT()) << model.getTimePrecision()->getLiteralSuffix(modelMerged.getTypeContext()) << std::endl; - // Typedefine scalar type - definitions << "typedef " << model.getPrecision()->getName() << " scalar;" << std::endl; + // Typedefine types in type context + for (const auto &t : modelMerged.getTypeContext()) { + definitions << "typedef " << t.second->getName() << " " << t.first << ";" << std::endl; + } // Write ranges of scalar and time types genTypeRange(definitions, model.getPrecision(), modelMerged.getTypeContext(), "SCALAR"); @@ -1089,7 +1091,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Write getter to get access to correct pointer const bool delayRequired = (n.second.isVarQueueRequired(var.name) && n.second.isDelayRequired()); genVarGetterScope(definitionsFunc, runnerGetterFunc, n.second.getVarLocation(var.name), - "Current" + var.name + n.first, var.type->getPointerType()->getResolvedName(modelMerged.getTypeContext()), + "Current" + var.name + n.first, var.type->getPointerType()->getName(), [&]() { runnerGetterFunc << "return " << var.name << n.first; diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 45f3fe75e7..84abb7fdaa 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -86,7 +86,7 @@ void genInitNeuronVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, co (CodeStream &os, Substitutions &varInitSubs) { // Generate initial value into temporary 
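// The typedef loop added to generateRunner above is what makes the bare
// type->getName() calls in this and the following files valid: for a model
// built at float precision, the {"scalar", "time"} context this patch still
// uses would make the generated definitions.h begin with
//
//     typedef float scalar;
//     typedef float time;
//
// The second line is the "clashy name" flagged in the commit message; on
// typical toolchains a global typedef named time collides with the ::time
// function from <ctime>, which is what motivates the rename in the next patch.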
variable - os << var.type->getResolvedName(modelMerged.getTypeContext()) << " initVal;" << std::endl; + os << var.type->getName() << " initVal;" << std::endl; varInitSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); varInitSubs.applyCheckUnreplaced(code, "initVar : " + var.name + "merged" + std::to_string(groupIndex)); @@ -106,7 +106,7 @@ void genInitNeuronVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, co (CodeStream &os, Substitutions &varInitSubs) { // Generate initial value into temporary variable - os << var.type->getResolvedName(modelMerged.getTypeContext()) << " initVal;" << std::endl; + os << var.type->getName() << " initVal;" << std::endl; varInitSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); varInitSubs.applyCheckUnreplaced(code, "initVar : " + var.name + "merged" + std::to_string(groupIndex)); @@ -164,7 +164,7 @@ void genInitWUVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, const "", "group->", var.name); // Generate initial value into temporary variable - os << var.type->getResolvedName(modelMerged.getTypeContext()) << " initVal;" << std::endl; + os << var.type->getName() << " initVal;" << std::endl; varSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); varSubs.applyCheckUnreplaced(code, "initVar : merged" + var.name + std::to_string(groupIndex)); @@ -652,7 +652,7 @@ void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase&, "", "group->", var.name); // Generate initial value into temporary variable - os << var.type->getResolvedName(getTypeContext()) << " initVal;" << std::endl; + os << var.type->getName() << " initVal;" << std::endl; popSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); //popSubs.applyCheckUnreplaced(code, "initVar : merged" + vars[k].name + std::to_string(sg.getIndex())); @@ -691,7 +691,7 @@ void SynapseConnectivityInitGroupMerged::genInitConnectivity(CodeStream &os, Sub popSubs.applyCheckUnreplaced(value, "initSparseConnectivity state var : merged" + std::to_string(getIndex())); //value = ensureFtype(value, ftype); - os << a.type->getResolvedName(getTypeContext()) << " " << a.name << " = " << value << ";" << std::endl; + os << a.type->getName() << " " << a.name << " = " << value << ";" << std::endl; } os << "while(true)"; { diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 8f755e9986..22faf11de3 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -187,7 +187,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C if(v.access & VarAccessMode::READ_ONLY) { os << "const "; } - os << v.type->getResolvedName(getTypeContext()) << " l" << v.name << " = group->" << v.name << "["; + os << v.type->getName() << " l" << v.name << " = group->" << v.name << "["; const bool delayed = (getArchetype().isVarQueueRequired(v.name) && getArchetype().isDelayRequired()); os << getReadVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), popSubs["id"]) << "];" << std::endl; } @@ -249,7 +249,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C neuronSubs.applyCheckUnreplaced(value, "neuron additional input var : merged" + std::to_string(getIndex())); //value = ensureFtype(value, modelMerged.getModel().getPrecision()); - 
os << a.type->getResolvedName(getTypeContext()) << " " << a.name << " = " << value << ";" << std::endl; + os << a.type->getName() << " " << a.name << " = " << value << ";" << std::endl; } // Loop through incoming synapse groups @@ -282,7 +282,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C if(v.access & VarAccessMode::READ_ONLY) { os << "const "; } - os << v.type->getResolvedName(getTypeContext()) << " lps" << v.name << " = group->" << v.name << "InSyn" << i << "["; + os << v.type->getName() << " lps" << v.name << " = group->" << v.name << "InSyn" << i << "["; os << getVarIndex(batchSize, getVarAccessDuplication(v.access), neuronSubs["id"]) << "];" << std::endl; } @@ -366,7 +366,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C if(v.access & VarAccessMode::READ_ONLY) { os << "const "; } - os << v.type->getResolvedName(getTypeContext()) << " lcs" << v.name << " = " << "group->" << v.name << "CS" << i << "["; + os << v.type->getName() << " lcs" << v.name << " = " << "group->" << v.name << "CS" << i << "["; os << getVarIndex(batchSize, getVarAccessDuplication(v.access), popSubs["id"]) << "];" << std::endl; } @@ -794,7 +794,7 @@ void NeuronUpdateGroupMerged::generateWUVarUpdate(CodeStream &os, const Substitu if(v.access & VarAccessMode::READ_ONLY) { os << "const "; } - os << v.type->getResolvedName(getTypeContext()) << " l" << v.name << " = group->" << v.name << fieldPrefixStem << i << "["; + os << v.type->getName() << " l" << v.name << " = group->" << v.name << fieldPrefixStem << i << "["; os << getReadVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), subs["id"]) << "];" << std::endl; } diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index f21b89661a..98ae3b58d6 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -871,7 +871,7 @@ void PostSpanToeplitz::genUpdate(CodeStream &os, const ModelSpecMerged &modelMer connSubs.applyCheckUnreplaced(value, "toeplitz diagonal build state var : merged" + std::to_string(sg.getIndex())); //value = ensureFtype(value, modelMerged.getModel().getPrecision()); - os << d.type->getResolvedName(sg.getTypeContext()) << " " << d.name << " = " << value << ";" << std::endl; + os << d.type->getName() << " " << d.name << " = " << value << ";" << std::endl; } os << "const unsigned int numSpikes = group->srcSpkCnt" << eventSuffix << "[" << sg.getPreSlot(batchSize) << "];" << std::endl; diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 5709ba6c29..75864fdc60 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -90,7 +90,7 @@ void applySynapseSubstitutions(CodeStream &os, std::string code, const std::stri varSubs.applyCheckUnreplaced(code, "initVar : merged" + var.name + std::to_string(sg.getIndex())); // Declare local variable - os << var.type->getResolvedName(sg.getTypeContext()) << " " << "l" << var.name << ";" << std::endl; + os << var.type->getName() << " " << "l" << var.name << ";" << std::endl; // Insert code to initialize variable into scope { @@ -251,7 +251,7 @@ void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendB std::string value = a.value; popSubs.applyCheckUnreplaced(value, 
"proceduralSparseConnectivity row build state var : merged" + std::to_string(getIndex())); //value = ensureFtype(value, modelMerged.getModel().getPrecision()); - os << a.type->getResolvedName(getTypeContext()) << " " << a.name << " = " << value << ";" << std::endl; + os << a.type->getName() << " " << a.name << " = " << value << ";" << std::endl; } // Loop through synapses in row From 2c17fdfa604a07b298f2ea5c77b5cc46488c232d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 23 Jan 2023 12:05:50 +0000 Subject: [PATCH 094/725] ``timepoint`` seems a sensible name for the time type --- include/genn/genn/code_generator/groupMerged.h | 2 +- src/genn/genn/code_generator/modelSpecMerged.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 24128b5666..cf4fdc74f0 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -203,7 +203,7 @@ class GroupMerged // Protected methods //------------------------------------------------------------------------ const Type::NumericBase *getScalarType() const{ return dynamic_cast(m_TypeContext.at("scalar")); } - const Type::NumericBase *getTimeType() const{ return dynamic_cast(m_TypeContext.at("time")); } + const Type::NumericBase *getTimeType() const{ return dynamic_cast(m_TypeContext.at("timepoint")); } //! Helper to test whether parameter is referenced in vector of codestrings bool isParamReferenced(const std::vector &codeStrings, const std::string ¶mName) const diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index 67180589af..4b36e0691f 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -32,7 +32,7 @@ void assignGroups(const BackendBase &backend, std::vector &groups, BackendBas ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBase &backend) : m_Model(model), m_NeuronUpdateSupportCode("NeuronUpdateSupportCode"), m_PostsynapticDynamicsSupportCode("PostsynapticDynamicsSupportCode"), m_PresynapticUpdateSupportCode("PresynapticUpdateSupportCode"), m_PostsynapticUpdateSupportCode("PostsynapticUpdateSupportCode"), - m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode"), m_TypeContext{{"scalar", model.getPrecision()}, {"time", model.getTimePrecision()}} + m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode"), m_TypeContext{{"scalar", model.getPrecision()}, {"timepoint", model.getTimePrecision()}} { LOGD_CODE_GEN << "Merging neuron update groups:"; createMergedGroupsHash(backend, model.getNeuronGroups(), m_MergedNeuronUpdateGroups, From a034f984d1c2f05e5f34192122821ab2cf734e17 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 23 Jan 2023 12:28:11 +0000 Subject: [PATCH 095/725] don't need to pass type context through to all backend calls --- include/genn/backends/cuda/backend.h | 28 +++++---- .../backends/single_threaded_cpu/backend.h | 30 +++++----- .../genn/genn/code_generator/backendBase.h | 53 ++++++++--------- .../genn/genn/code_generator/groupMerged.h | 6 +- .../genn/code_generator/modelSpecMerged.h | 2 +- src/genn/backends/cuda/backend.cc | 34 ++++++----- .../backends/single_threaded_cpu/backend.cc | 24 ++++---- .../customConnectivityUpdateGroupMerged.cc | 8 +-- .../genn/code_generator/generateRunner.cc | 57 +++++++++---------- .../genn/code_generator/initGroupMerged.cc | 4 +- 10 files changed, 117 
insertions(+), 129 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index 2309197e58..5161d60a4e 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -188,13 +188,11 @@ class BACKEND_EXPORT Backend : public BackendSIMT //! Generate code to define a variable in the appropriate header file virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc) const final; + const Type::ValueBase *type, const std::string &name, VarLocation loc) const final; //! Generate code to instantiate a variable in the provided stream virtual void genVariableInstantiation(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc) const final; + const Type::ValueBase *type, const std::string &name, VarLocation loc) const final; //! Generate code to allocate variable with a size known at compile-time virtual void genVariableAllocation(CodeStream &os, @@ -203,41 +201,41 @@ class BACKEND_EXPORT Backend : public BackendSIMT //! Generate code to allocate variable with a size known at runtime virtual void genVariableDynamicAllocation(CodeStream &os, - const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; + const Type::Base *type, const std::string &name, VarLocation loc, + const std::string &countVarName = "count", const std::string &prefix = "") const final; //! Generate code to free a variable virtual void genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const final; //! Generate code for pushing a variable with a size known at compile-time to the 'device' virtual void genVariablePush(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const final; //! Generate code for pulling a variable with a size known at compile-time from the 'device' virtual void genVariablePull(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, size_t count) const final; //! Generate code for pushing a variable's value in the current timestep to the 'device' virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, unsigned int batchSize) const final; //! Generate code for pulling a variable's value in the current timestep from the 'device' virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, unsigned int batchSize) const final; //! 
Generate code for pushing a variable with a size known at runtime to the 'device' virtual void genVariableDynamicPush(CodeStream &os, - const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; + const Type::Base *type, const std::string &name, VarLocation loc, + const std::string &countVarName = "count", const std::string &prefix = "") const final; //! Generate code for pulling a variable with a size known at runtime from the 'device' virtual void genVariableDynamicPull(CodeStream &os, - const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; + const Type::Base *type, const std::string &name, VarLocation loc, + const std::string &countVarName = "count", const std::string &prefix = "") const final; //! Generate code for pushing a new pointer to a dynamic variable into the merged group structure on 'device' virtual void genMergedDynamicVariablePush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, @@ -245,7 +243,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT const std::string &egpName) const final; //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const override; + virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const override; //! When generating merged structures what type to use for simulation RNGs virtual const Type::ValueBase *getMergedGroupSimRNGType() const override; diff --git a/include/genn/backends/single_threaded_cpu/backend.h index 91ab53471f..62a590c1bc 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -60,13 +60,11 @@ class BACKEND_EXPORT Backend : public BackendBase //! Generate code to define a variable in the appropriate header file virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc) const final; + const Type::ValueBase *type, const std::string &name, VarLocation loc) const final; //! Generate code to instantiate a variable in the provided stream virtual void genVariableInstantiation(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc) const final; //! Generate code to allocate variable with a size known at compile-time virtual void genVariableAllocation(CodeStream &os, @@ -75,41 +73,41 @@ class BACKEND_EXPORT Backend : public BackendBase //!
Generate code to allocate variable with a size known at runtime virtual void genVariableDynamicAllocation(CodeStream &os, - const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; + const Type::Base *type, const std::string &name, VarLocation loc, + const std::string &countVarName = "count", const std::string &prefix = "") const final; //! Generate code to free a variable virtual void genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const final; //! Generate code for pushing a variable with a size known at compile-time to the 'device' virtual void genVariablePush(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc, bool autoInitialized, size_t count) const final; + const Type::ValueBase *type, const std::string &name, VarLocation loc, + bool autoInitialized, size_t count) const final; //! Generate code for pulling a variable with a size known at compile-time from the 'device' virtual void genVariablePull(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, size_t count) const final; //! Generate code for pushing a variable's value in the current timestep to the 'device' virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, unsigned int batchSize) const final; //! Generate code for pulling a variable's value in the current timestep from the 'device' virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, unsigned int batchSize) const final; //! Generate code for pushing a variable with a size known at runtime to the 'device' virtual void genVariableDynamicPush(CodeStream &os, - const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; + const Type::Base *type, const std::string &name, VarLocation loc, + const std::string &countVarName = "count", const std::string &prefix = "") const final; //! Generate code for pulling a variable with a size known at runtime from the 'device' virtual void genVariableDynamicPull(CodeStream &os, - const Type::Base *type, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; + const Type::Base *type, const std::string &name, VarLocation loc, + const std::string &countVarName = "count", const std::string &prefix = "") const final; //! Generate code for pushing a new pointer to a dynamic variable into the merged group structure on 'device' virtual void genMergedDynamicVariablePush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, @@ -117,7 +115,7 @@ class BACKEND_EXPORT Backend : public BackendBase const std::string &egpName) const final; //!
When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const final; + virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const final; //! When generating merged structures what type to use for simulation RNGs virtual const Type::ValueBase *getMergedGroupSimRNGType() const final; diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index a9cfc62a90..cec6681eb0 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -247,13 +247,11 @@ class GENN_EXPORT BackendBase //! Generate code to define a variable in the appropriate header file virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc) const = 0; + const Type::ValueBase *type, const std::string &name, VarLocation loc) const = 0; //! Generate code to instantiate a variable in the provided stream virtual void genVariableInstantiation(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc) const = 0; + const Type::ValueBase *type, const std::string &name, VarLocation loc) const = 0; //! Generate code to allocate variable with a size known at compile-time virtual void genVariableAllocation(CodeStream &os, @@ -262,41 +260,41 @@ class GENN_EXPORT BackendBase //! Generate code to allocate variable with a size known at runtime virtual void genVariableDynamicAllocation(CodeStream &os, - const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0; + const Type::Base *type, const std::string &name, VarLocation loc, + const std::string &countVarName = "count", const std::string &prefix = "") const = 0; //! Generate code to free a variable virtual void genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const = 0; //! Generate code for pushing a variable with a size known at compile-time to the 'device' virtual void genVariablePush(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const = 0; //! Generate code for pulling a variable with a size known at compile-time from the 'device' virtual void genVariablePull(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, size_t count) const = 0; //! Generate code for pushing a variable's value in the current timestep to the 'device' virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, unsigned int batchSize) const = 0; //! 
Generate code for pulling a variable's value in the current timestep from the 'device' virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, unsigned int batchSize) const = 0; //! Generate code for pushing a variable with a size known at runtime to the 'device' virtual void genVariableDynamicPush(CodeStream &os, - const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0; + const Type::Base *type, const std::string &name, VarLocation loc, + const std::string &countVarName = "count", const std::string &prefix = "") const = 0; //! Generate code for pulling a variable with a size known at runtime from the 'device' virtual void genVariableDynamicPull(CodeStream &os, - const Type::Base *type, const std::string &name, - VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0; + const Type::Base *type, const std::string &name, VarLocation loc, + const std::string &countVarName = "count", const std::string &prefix = "") const = 0; //! Generate code for pushing a new pointer to a dynamic variable into the merged group structure on 'device' virtual void genMergedDynamicVariablePush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, @@ -304,7 +302,7 @@ class GENN_EXPORT BackendBase const std::string &egpName) const = 0; //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const = 0; + virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const = 0; //! When generating merged structures what type to use for simulation RNGs virtual const Type::ValueBase *getMergedGroupSimRNGType() const = 0; @@ -418,39 +416,38 @@ class GENN_EXPORT BackendBase //-------------------------------------------------------------------------- //! Helper function to generate matching push and pull functions for a variable void genVariablePushPull(CodeStream &push, CodeStream &pull, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const { - genVariablePush(push, type, typeContext, name, loc, autoInitialized, count); - genVariablePull(pull, type, typeContext, name, loc, count); + genVariablePush(push, type, name, loc, autoInitialized, count); + genVariablePull(pull, type, name, loc, count); } //! Templated version of helper function to generate matching push and pull functions for //! a variable when type is known at compile time template void genVariablePushPull(CodeStream &push, CodeStream &pull, - const Type::TypeContext &typeContext, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const + const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const { - genVariablePushPull(push, pull, T::getInstance(), typeContext, name, loc, autoInitialized, count); + genVariablePushPull(push, pull, T::getInstance(), name, loc, autoInitialized, count); } //!
Helper function to generate matching push and pull functions for the current state of a variable void genCurrentVariablePushPull(CodeStream &push, CodeStream &pull, const NeuronGroupInternal &ng, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, unsigned int batchSize) const { - genCurrentVariablePush(push, ng, type, typeContext, name, loc, batchSize); - genCurrentVariablePull(pull, ng, type, typeContext, name, loc, batchSize); + genCurrentVariablePush(push, ng, type, name, loc, batchSize); + genCurrentVariablePull(pull, ng, type, name, loc, batchSize); } //! Templated version of helper function to generate matching push and pull functions //! for the current state of a variable when type is known at compile time template void genCurrentVariablePushPull(CodeStream &push, CodeStream &pull, const NeuronGroupInternal &ng, - const std::string &name, const Type::TypeContext &typeContext, - VarLocation loc, unsigned int batchSize) const + const std::string &name, VarLocation loc, unsigned int batchSize) const { - genCurrentVariablePushPull(push, pull, ng, T::getInstance(), typeContext, name, loc, batchSize); + genCurrentVariablePushPull(push, pull, ng, T::getInstance(), name, loc, batchSize); } //! Helper function to generate matching definition, declaration, allocation and free code for a statically-sized array void genArray(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const { - genVariableDefinition(definitions, definitionsInternal, type, typeContext, name, loc); - genVariableInstantiation(runner, type, typeContext, name, loc); + genVariableDefinition(definitions, definitionsInternal, type, name, loc); + genVariableInstantiation(runner, type, name, loc); genVariableFree(free, name, loc); genVariableAllocation(allocations, type, typeContext, name, loc, count, memAlloc); } diff --git a/include/genn/genn/code_generator/groupMerged.h index cf4fdc74f0..3f42f3db23 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -120,7 +120,7 @@ class GroupMerged if(dynamic_cast(type) && !(std::get<3>(f) & GroupMergedFieldType::HOST)) { // If we are generating a host structure, allow the backend to override the type if(host) { - os << backend.getMergedGroupFieldHostTypeName(type, m_TypeContext); + os << backend.getMergedGroupFieldHostTypeName(type); } // Otherwise, allow the backend to add a prefix else { @@ -145,7 +145,7 @@ class GroupMerged const auto sortedFields = getSortedFields(backend); for(size_t fieldIndex = 0; fieldIndex < sortedFields.size(); fieldIndex++) { const auto &f = sortedFields[fieldIndex]; - os << backend.getMergedGroupFieldHostTypeName(std::get<0>(f), m_TypeContext) << " " << std::get<1>(f); + os << backend.getMergedGroupFieldHostTypeName(std::get<0>(f)) << " " << std::get<1>(f); if(fieldIndex != (sortedFields.size() - 1)) { os << ", "; } @@ -478,7 +478,7 @@ class GroupMerged // If this field is a dynamic pointer if((std::get<3>(f) & GroupMergedFieldType::DYNAMIC) && dynamic_cast(std::get<0>(f))) { definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << getIndex() << std::get<1>(f) << "ToDevice(unsigned int idx, "; definitionsInternalFunc << backend.getMergedGroupFieldHostTypeName(std::get<0>(f)) << "
value);" << std::endl; + definitionsInternalFunc << backend.getMergedGroupFieldHostTypeName(std::get<0>(f)) << " value);" << std::endl; } // Raise error if this field is a host field but this isn't a host structure diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 82cd1d9201..75d419b57e 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -267,7 +267,7 @@ class GENN_EXPORT ModelSpecMerged os << "// ------------------------------------------------------------------------" << std::endl; // Loop through resultant fields and generate function to push updated pointers into group merged for(auto f : mergedGroupFields) { - os << "void pushMerged" << T::name << f.mergedGroupIndex << f.fieldName << "ToDevice(unsigned int idx, " << backend.getMergedGroupFieldHostTypeName(f.type, m_TypeContext) << " value)"; + os << "void pushMerged" << T::name << f.mergedGroupIndex << f.fieldName << "ToDevice(unsigned int idx, " << backend.getMergedGroupFieldHostTypeName(f.type) << " value)"; { CodeStream::Scope b(os); backend.genMergedDynamicVariablePush(os, T::name, f.mergedGroupIndex, "idx", f.fieldName, "value"); diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 1e5df16c80..c0ba3c5269 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -1572,8 +1572,7 @@ void Backend::genStepTimeFinalisePreamble(CodeStream &os, const ModelSpecMerged } //-------------------------------------------------------------------------- void Backend::genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc) const + const Type::ValueBase *type, const std::string &name, VarLocation loc) const { const bool deviceType = dynamic_cast(type); CodeStream &d = deviceType ? definitionsInternal : definitions; @@ -1599,8 +1598,7 @@ void Backend::genVariableDefinition(CodeStream &definitions, CodeStream &definit } //-------------------------------------------------------------------------- void Backend::genVariableInstantiation(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc) const + const Type::ValueBase *type, const std::string &name, VarLocation loc) const { if(getPreferences().automaticCopy) { os << type->getPointerType()->getName() << " " << name << ";" << std::endl; @@ -1646,8 +1644,8 @@ void Backend::genVariableAllocation(CodeStream &os, } //-------------------------------------------------------------------------- void Backend::genVariableDynamicAllocation(CodeStream &os, - const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc, const std::string &countVarName, const std::string &prefix) const + const Type::Base *type, const std::string &name, VarLocation loc, + const std::string &countVarName, const std::string &prefix) const { const auto *pointerType = dynamic_cast(type); const auto *underlyingType = pointerType ? 
pointerType->getValueType() : type; @@ -1694,8 +1692,8 @@ void Backend::genVariableFree(CodeStream &os, const std::string &name, VarLocati } //-------------------------------------------------------------------------- void Backend::genVariablePush(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc, bool autoInitialized, size_t count) const + const Type::ValueBase *type, const std::string &name, VarLocation loc, + bool autoInitialized, size_t count) const { assert(!getPreferences().automaticCopy); @@ -1716,7 +1714,7 @@ void Backend::genVariablePush(CodeStream &os, } //-------------------------------------------------------------------------- void Backend::genVariablePull(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, size_t count) const { assert(!getPreferences().automaticCopy); @@ -1729,7 +1727,7 @@ void Backend::genVariablePull(CodeStream &os, } //-------------------------------------------------------------------------- void Backend::genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, unsigned int batchSize) const { assert(!getPreferences().automaticCopy); @@ -1754,12 +1752,12 @@ void Backend::genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal & } // Otherwise, generate standard push else { - genVariablePush(os, type, typeContext, name + ng.getName(), loc, false, ng.getNumNeurons() * batchSize); + genVariablePush(os, type, name + ng.getName(), loc, false, ng.getNumNeurons() * batchSize); } } //-------------------------------------------------------------------------- void Backend::genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ValueBase *type, const std::string &name, VarLocation loc, unsigned int batchSize) const { assert(!getPreferences().automaticCopy); @@ -1783,13 +1781,13 @@ void Backend::genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal & } // Otherwise, generate standard pull else { - genVariablePull(os, type, typeContext, name + ng.getName(), loc, ng.getNumNeurons() * batchSize); + genVariablePull(os, type, name + ng.getName(), loc, ng.getNumNeurons() * batchSize); } } //-------------------------------------------------------------------------- void Backend::genVariableDynamicPush(CodeStream &os, - const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc, const std::string &countVarName, const std::string &prefix) const + const Type::Base *type, const std::string &name, VarLocation loc, + const std::string &countVarName, const std::string &prefix) const { assert(!getPreferences().automaticCopy); @@ -1810,8 +1808,8 @@ void Backend::genVariableDynamicPush(CodeStream &os, } //-------------------------------------------------------------------------- void Backend::genVariableDynamicPull(CodeStream &os, - const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation loc, const std::string &countVarName, const std::string &prefix) const + const Type::Base *type, const std::string &name, VarLocation loc, + const std::string &countVarName, const 
std::string &prefix) const { assert(!getPreferences().automaticCopy); @@ -1841,7 +1839,7 @@ void Backend::genMergedDynamicVariablePush(CodeStream &os, const std::string &su os << ", (sizeof(" << structName << ") * (" << groupIdx << ")) + offsetof(" << structName << ", " << fieldName << ")));" << std::endl; } //-------------------------------------------------------------------------- -std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const +std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) const { return type->getName(); } diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index f4cf846354..28c22347d5 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1285,15 +1285,13 @@ void Backend::genStepTimeFinalisePreamble(CodeStream &, const ModelSpecMerged &) } //-------------------------------------------------------------------------- void Backend::genVariableDefinition(CodeStream &definitions, CodeStream &, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation) const + const Type::ValueBase *type, const std::string &name, VarLocation) const { definitions << "EXPORT_VAR " << type->getPointerType()->getName() << " " << name << ";" << std::endl; } //-------------------------------------------------------------------------- void Backend::genVariableInstantiation(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation) const + const Type::ValueBase *type, const std::string &name, VarLocation) const { os << type->getPointerType()->getName() << " " << name << ";" << std::endl; } @@ -1308,8 +1306,8 @@ void Backend::genVariableAllocation(CodeStream &os, } //-------------------------------------------------------------------------- void Backend::genVariableDynamicAllocation(CodeStream &os, - const Type::Base *type, const Type::TypeContext &typeContext, const std::string &name, - VarLocation, const std::string &countVarName, const std::string &prefix) const + const Type::Base *type, const std::string &name, VarLocation, + const std::string &countVarName, const std::string &prefix) const { const auto *pointerType = dynamic_cast(type); if (pointerType) { @@ -1325,39 +1323,39 @@ void Backend::genVariableFree(CodeStream &os, const std::string &name, VarLocati os << "delete[] " << name << ";" << std::endl; } //-------------------------------------------------------------------------- -void Backend::genVariablePush(CodeStream&, const Type::ValueBase*, const Type::TypeContext&, const std::string&, VarLocation, bool, size_t) const +void Backend::genVariablePush(CodeStream&, const Type::ValueBase*, const std::string&, VarLocation, bool, size_t) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- -void Backend::genVariablePull(CodeStream&, const Type::ValueBase*, const Type::TypeContext&, const std::string&, VarLocation, size_t) const +void Backend::genVariablePull(CodeStream&, const Type::ValueBase*, const std::string&, VarLocation, size_t) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- void Backend::genCurrentVariablePush(CodeStream&, const NeuronGroupInternal&, - const Type::ValueBase*, const Type::TypeContext&, const 
std::string&, + const Type::ValueBase*, const std::string&, VarLocation, unsigned int) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- void Backend::genCurrentVariablePull(CodeStream&, const NeuronGroupInternal&, - const Type::ValueBase*, const Type::TypeContext&, const std::string&, + const Type::ValueBase*, const std::string&, VarLocation, unsigned int) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- void Backend::genVariableDynamicPush(CodeStream&, - const Type::Base*, const Type::TypeContext&, const std::string&, + const Type::Base*, const std::string&, VarLocation, const std::string&, const std::string&) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- void Backend::genVariableDynamicPull(CodeStream&, - const Type::Base*, const Type::TypeContext&, const std::string&, + const Type::Base*, const std::string&, VarLocation, const std::string&, const std::string&) const { assert(!getPreferences().automaticCopy); @@ -1371,7 +1369,7 @@ void Backend::genMergedDynamicVariablePush(CodeStream &os, const std::string &su } //-------------------------------------------------------------------------- -std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const +std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) const { return type->getName(); } diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 54a90ce4e0..db6b3c1a9e 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -488,7 +488,7 @@ void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase & // Generate code to push this EGP with count specified by $(0) std::stringstream pushStream; CodeStream push(pushStream); - backend.genVariableDynamicPush(push, egp.type, getTypeContext(), egp.name, + backend.genVariableDynamicPush(push, egp.type, egp.name, VarLocation::HOST_DEVICE, "$(0)", "group->"); // Add substitution @@ -497,7 +497,7 @@ void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase & // Generate code to pull this EGP with count specified by $(0) std::stringstream pullStream; CodeStream pull(pullStream); - backend.genVariableDynamicPull(pull, egp.type, getTypeContext(), egp.name, + backend.genVariableDynamicPull(pull, egp.type, egp.name, VarLocation::HOST_DEVICE, "$(0)", "group->"); // Add substitution @@ -529,7 +529,7 @@ void CustomConnectivityHostUpdateGroupMerged::addVarPushPullFuncSubs(const Backe // Generate code to push this variable std::stringstream pushStream; CodeStream push(pushStream); - backend.genVariableDynamicPush(push, v.type, getTypeContext(), v.name, + backend.genVariableDynamicPush(push, v.type, v.name, loc, count, "group->"); // Add substitution @@ -539,7 +539,7 @@ void CustomConnectivityHostUpdateGroupMerged::addVarPushPullFuncSubs(const Backe // **YUCK** these EGP functions should probably just be called dynamic or something std::stringstream pullStream; CodeStream pull(pullStream); - backend.genVariableDynamicPull(pull, v.type, getTypeContext(), v.name, + backend.genVariableDynamicPull(pull, v.type, v.name, loc, count, "group->"); // Add substitution diff 
--git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index d7f1808fe1..c3e9d73023 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -264,7 +264,7 @@ void genVariable(const ModelSpecMerged &modelMerged, const BackendBase &backend, genVarPushPullScope(definitionsFunc, push, pull, loc, backend.getPreferences().automaticCopy, name, statePushPullFunction, [&]() { - backend.genVariablePushPull(push, pull, type, modelMerged.getTypeContext(), name, loc, autoInitialized, count); + backend.genVariablePushPull(push, pull, type, name, loc, autoInitialized, count); }); // Generate variables @@ -277,8 +277,8 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase & CodeStream &extraGlobalParam, const Type::NumericBase *type, const std::string &name, bool apiRequired, VarLocation loc) { // Generate variables - backend.genVariableDefinition(definitionsVar, definitionsInternalVar, type, modelMerged.getTypeContext(), name, loc); - backend.genVariableInstantiation(runner, type, modelMerged.getTypeContext(), name, loc); + backend.genVariableDefinition(definitionsVar, definitionsInternalVar, type, name, loc); + backend.genVariableInstantiation(runner, type, name, loc); // If API is required if(apiRequired) { @@ -290,7 +290,7 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase & extraGlobalParam << "void allocate" << name << "(unsigned int count)"; { CodeStream::Scope a(extraGlobalParam); - backend.genVariableDynamicAllocation(extraGlobalParam, type, modelMerged.getTypeContext(), name, loc); + backend.genVariableDynamicAllocation(extraGlobalParam, type, name, loc); // Loop through destinations in merged structures, the device EGP needs to be copied to // **TODO** rename to dynamic @@ -353,7 +353,7 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase & extraGlobalParam << "void push" << name << "ToDevice(unsigned int count)"; { CodeStream::Scope a(extraGlobalParam); - backend.genVariableDynamicPush(extraGlobalParam, type, modelMerged.getTypeContext(), name, loc); + backend.genVariableDynamicPush(extraGlobalParam, type, name, loc); } if(backend.getPreferences().generateExtraGlobalParamPull) { @@ -364,7 +364,7 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase & extraGlobalParam << "void pull" << name << "FromDevice(unsigned int count)"; { CodeGenerator::CodeStream::Scope a(extraGlobalParam); - backend.genVariableDynamicPull(extraGlobalParam, type, modelMerged.getTypeContext(), name, loc); + backend.genVariableDynamicPull(extraGlobalParam, type, name, loc); } } } @@ -466,7 +466,7 @@ void genRunnerFusedVarPushPull(const ModelSpecMerged &modelMerged, const Backend [&]() { backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - var.type, modelMerged.getTypeContext(), var.name + group.getName(), + var.type, var.name + group.getName(), varAdaptor.getVarLocation(var.name), autoInitialized, getSizeFn(group, var)); }); } @@ -920,10 +920,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - modelMerged.getTypeContext(), "glbSpkCnt" + n.first, + "glbSpkCnt" + n.first, n.second.getSpikeLocation(), true, numSpikeCounts); backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - modelMerged.getTypeContext(), "glbSpk" + n.first, + "glbSpk" + n.first, n.second.getSpikeLocation(), true, 
numSpikes); }); @@ -933,10 +933,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // If spike recording is enabled, define and declare variables and add free if(n.second.isSpikeRecordingEnabled()) { backend.genVariableDefinition(definitionsVar, definitionsInternalVar, - Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpk" + n.first, + Type::Uint32::getInstance(), "recordSpk" + n.first, VarLocation::HOST_DEVICE); backend.genVariableInstantiation(runnerVarDecl, - Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpk" + n.first, + Type::Uint32::getInstance(), "recordSpk" + n.first, VarLocation::HOST_DEVICE); backend.genVariableFree(runnerVarFree, "recordSpk" + n.first, VarLocation::HOST_DEVICE); @@ -963,10 +963,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - modelMerged.getTypeContext(), "glbSpkCntEvnt" + n.first, + "glbSpkCntEvnt" + n.first, n.second.getSpikeLocation(), true, batchSize * n.second.getNumDelaySlots()); backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - modelMerged.getTypeContext(), "glbSpkEvnt" + n.first, + "glbSpkEvnt" + n.first, n.second.getSpikeLocation(), true, numNeuronDelaySlots); }); @@ -976,10 +976,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // If spike recording is enabled, define and declare variables and add free if(n.second.isSpikeEventRecordingEnabled()) { backend.genVariableDefinition(definitionsVar, definitionsInternalVar, - Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpkEvent" + n.first, + Type::Uint32::getInstance(), "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE); backend.genVariableInstantiation(runnerVarDecl, - Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpkEvent" + n.first, + Type::Uint32::getInstance(), "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE); backend.genVariableFree(runnerVarFree, "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE); } @@ -1003,7 +1003,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - model.getTimePrecision(), modelMerged.getTypeContext(), "sT" + n.first, + model.getTimePrecision(), "sT" + n.first, n.second.getSpikeTimeLocation(), true, numNeuronDelaySlots); }); } @@ -1020,7 +1020,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - model.getTimePrecision(), modelMerged.getTypeContext(), "prevST" + n.first, + model.getTimePrecision(), "prevST" + n.first, n.second.getPrevSpikeTimeLocation(), true, numNeuronDelaySlots); }); } @@ -1037,7 +1037,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - model.getTimePrecision(), modelMerged.getTypeContext(), "seT" + n.first, + model.getTimePrecision(), "seT" + n.first, n.second.getSpikeEventTimeLocation(), true, numNeuronDelaySlots); }); } @@ -1054,7 +1054,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - model.getTimePrecision(), modelMerged.getTypeContext(), "prevSET" + n.first, + model.getTimePrecision(), "prevSET" + n.first, n.second.getPrevSpikeEventTimeLocation(), true, 
numNeuronDelaySlots); }); } @@ -1084,7 +1084,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { backend.genCurrentVariablePushPull(runnerPushFunc, runnerPullFunc, n.second, - var.type, modelMerged.getTypeContext(), var.name, + var.type, var.name, n.second.getVarLocation(var.name), numCopies); }); @@ -1313,9 +1313,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, backend.getPreferences().automaticCopy, s.second.getName() + "Connectivity", connectivityPushPullFunctions, [&]() { - // Row lengths backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - modelMerged.getTypeContext(), "gp" + s.second.getName(), + "gp" + s.second.getName(), s.second.getSparseConnectivityLocation(), autoInitialized, gpSize); }); } @@ -1359,12 +1358,12 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, { // Row lengths backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - modelMerged.getTypeContext(), "rowLength" + s.second.getName(), + "rowLength" + s.second.getName(), s.second.getSparseConnectivityLocation(), autoInitialized, s.second.getSrcNeuronGroup()->getNumNeurons()); // Target indices backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - s.second.getSparseIndType(), modelMerged.getTypeContext(), "ind" + s.second.getName(), + s.second.getSparseIndType(), "ind" + s.second.getName(), s.second.getSparseConnectivityLocation(), autoInitialized, size); }); } @@ -1422,7 +1421,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - model.getPrecision(), modelMerged.getTypeContext(), "inSyn" + s.second.getName(), + model.getPrecision(), "inSyn" + s.second.getName(), s.second.getInSynLocation(), true, s.second.getTrgNeuronGroup()->getNumNeurons() * batchSize); }); @@ -1604,7 +1603,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, if(n.second.isSpikeRecordingEnabled()) { CodeStream::Scope b(runner); backend.genVariableDynamicAllocation(runner, - Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpk" + n.first, + Type::Uint32::getInstance(), "recordSpk" + n.first, VarLocation::HOST_DEVICE, "numWords"); // Get destinations in merged structures, this EGP @@ -1620,7 +1619,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, if(n.second.isSpikeEventRecordingEnabled()) { CodeStream::Scope b(runner); backend.genVariableDynamicAllocation(runner, - Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpkEvent" + n.first, + Type::Uint32::getInstance(), "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE, "numWords"); // Get destinations in merged structures, this EGP @@ -1660,7 +1659,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, if(n.second.isSpikeRecordingEnabled()) { CodeStream::Scope b(runner); backend.genVariableDynamicPull(runner, - Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpk" + n.first, + Type::Uint32::getInstance(), "recordSpk" + n.first, VarLocation::HOST_DEVICE, "numWords"); } // Pull spike event array if required @@ -1668,7 +1667,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, if(n.second.isSpikeEventRecordingEnabled()) { CodeStream::Scope b(runner); backend.genVariableDynamicPull(runner, - Type::Uint32::getInstance(), modelMerged.getTypeContext(), "recordSpkEvent" + n.first, +
Type::Uint32::getInstance(), "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE, "numWords"); } } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 84abb7fdaa..50230d8b5d 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -809,7 +809,7 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac const auto *pointerToEGP = egp.type->getPointerType(); CodeGenerator::CodeStream alloc(allocStream); backend.genVariableDynamicAllocation(alloc, - pointerToEGP, getTypeContext(), egp.name, + pointerToEGP, egp.name, loc, "$(0)", "group->"); // Add substitution @@ -819,7 +819,7 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac std::stringstream pushStream; CodeStream push(pushStream); backend.genVariableDynamicPush(push, - pointerToEGP, getTypeContext(), egp.name, + pointerToEGP, egp.name, loc, "$(0)", "group->"); From adbaf8c3ceca3890659ff15d0338a82ca20d05ac Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 23 Jan 2023 12:46:57 +0000 Subject: [PATCH 096/725] removed unused modelSpec parameters --- .../customConnectivityUpdateGroupMerged.h | 4 ++-- .../code_generator/customUpdateGroupMerged.h | 6 ++--- .../code_generator/synapseUpdateGroupMerged.h | 4 ++-- src/genn/backends/cuda/backend.cc | 20 ++++++++--------- .../backends/single_threaded_cpu/backend.cc | 20 ++++++++--------- src/genn/genn/code_generator/backendSIMT.cc | 14 ++++++------ .../customConnectivityUpdateGroupMerged.cc | 22 +++++++++---------- .../code_generator/customUpdateGroupMerged.cc | 10 ++++----- .../genn/code_generator/generateRunner.cc | 16 +++++++------- .../code_generator/neuronUpdateGroupMerged.cc | 14 ++++++------ .../presynapticUpdateStrategySIMT.cc | 6 ++--- .../synapseUpdateGroupMerged.cc | 4 ++-- src/genn/genn/transpiler/errorHandler.cc | 4 ++-- 13 files changed, 72 insertions(+), 72 deletions(-) diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index 4af7204d66..6ed6483cac 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -49,7 +49,7 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public CustomConnectivit runnerVarDecl, runnerMergedStructAlloc, name); } - void generateUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generateUpdate(const BackendBase &backend, CodeStream &os, unsigned int batchSize, Substitutions &popSubs) const; //! Get sorted vector of variable names, types and duplication modes which //! 
need updating when synapses are added and removed, belonging to archetype group @@ -89,7 +89,7 @@ class GENN_EXPORT CustomConnectivityHostUpdateGroupMerged : public CustomConnect runnerVarDecl, runnerMergedStructAlloc, name, true); } - void generateUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged) const; + void generateUpdate(const BackendBase &backend, CodeStream &os) const; //---------------------------------------------------------------------------- // Static constants diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index 890286d051..d2b8c5566d 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -32,7 +32,7 @@ class GENN_EXPORT CustomUpdateGroupMerged : public GroupMergedgetName() << " t)"; + os << "extern \"C\" __global__ void " << KernelNames[KernelNeuronPrevSpikeTimeUpdate] << "(timepoint t)"; { CodeStream::Scope b(os); @@ -515,7 +515,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged os << std::endl; size_t idStart = 0; - os << "extern \"C\" __global__ void " << KernelNames[KernelNeuronUpdate] << "(" << model.getTimePrecision()->getName() << " t"; + os << "extern \"C\" __global__ void " << KernelNames[KernelNeuronUpdate] << "(timepoint t"; if(model.isRecordingInUse()) { os << ", unsigned int recordingTimestep"; } @@ -537,7 +537,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged genNeuronUpdateKernel(os, kernelSubs, modelMerged, idStart); } - os << "void updateNeurons(" << model.getTimePrecision()->getName() << " t"; + os << "void updateNeurons(timepoint t"; if(model.isRecordingInUse()) { os << ", unsigned int recordingTimestep"; } @@ -620,7 +620,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge // If there are any presynaptic update groups size_t idPresynapticStart = 0; if(!modelMerged.getMergedPresynapticUpdateGroups().empty()) { - os << "extern \"C\" __global__ void " << KernelNames[KernelPresynapticUpdate] << "(" << model.getTimePrecision()->getName() << " t)" << std::endl; // end of synapse kernel header + os << "extern \"C\" __global__ void " << KernelNames[KernelPresynapticUpdate] << "(timepoint t)" << std::endl; // end of synapse kernel header { CodeStream::Scope b(os); @@ -642,7 +642,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge // If any synapse groups require postsynaptic learning size_t idPostsynapticStart = 0; if(!modelMerged.getMergedPostsynapticUpdateGroups().empty()) { - os << "extern \"C\" __global__ void " << KernelNames[KernelPostsynapticUpdate] << "(" << model.getTimePrecision()->getName() << " t)" << std::endl; + os << "extern \"C\" __global__ void " << KernelNames[KernelPostsynapticUpdate] << "(timepoint t)" << std::endl; { CodeStream::Scope b(os); @@ -663,7 +663,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge size_t idSynapseDynamicsStart = 0; if(!modelMerged.getMergedSynapseDynamicsGroups().empty()) { - os << "extern \"C\" __global__ void " << KernelNames[KernelSynapseDynamicsUpdate] << "(" << model.getTimePrecision()->getName() << " t)" << std::endl; // end of synapse kernel header + os << "extern \"C\" __global__ void " << KernelNames[KernelSynapseDynamicsUpdate] << "(timepoint t)" << std::endl; // end of synapse kernel header { CodeStream::Scope b(os); @@ -682,7 +682,7 
@@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge } } - os << "void updateSynapses(" << model.getTimePrecision()->getName() << " t)"; + os << "void updateSynapses(timepoint t)"; { CodeStream::Scope b(os); @@ -790,7 +790,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged [this](const CustomConnectivityUpdateInternal &cg){ return padKernelSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelCustomUpdate); }, [g](const CustomConnectivityUpdateGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; }); - os << "extern \"C\" __global__ void " << KernelNames[KernelCustomUpdate] << g << "(" << model.getTimePrecision()->getName() << " t)" << std::endl; + os << "extern \"C\" __global__ void " << KernelNames[KernelCustomUpdate] << g << "(timepoint t)" << std::endl; { CodeStream::Scope b(os); @@ -821,7 +821,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged [&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); }, [g](const CustomUpdateTransposeWUGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; }); - os << "extern \"C\" __global__ void " << KernelNames[KernelCustomTransposeUpdate] << g << "(" << model.getTimePrecision()->getName() << " t)" << std::endl; + os << "extern \"C\" __global__ void " << KernelNames[KernelCustomTransposeUpdate] << g << "(timepoint t)" << std::endl; { CodeStream::Scope b(os); @@ -842,7 +842,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged // Loop through host update groups and generate code for those in this custom update group for (const auto &cg : modelMerged.getMergedCustomConnectivityHostUpdateGroups()) { if (cg.getArchetype().getUpdateGroupName() == g) { - cg.generateUpdate(*this, os, modelMerged); + cg.generateUpdate(*this, os); } } diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 28c22347d5..6df47e6ffe 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -141,7 +141,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged // Generate preamble preambleHandler(os); - os << "void updateNeurons(" << model.getTimePrecision()->getName() << " t"; + os << "void updateNeurons(timepoint t"; if(model.isRecordingInUse()) { os << ", unsigned int recordingTimestep"; } @@ -313,7 +313,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge // Generate preamble preambleHandler(os); - os << "void updateSynapses(" << model.getTimePrecision()->getName() << " t)"; + os << "void updateSynapses(timepoint t)"; { CodeStream::Scope b(os); Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName())); @@ -532,7 +532,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged // Loop through host update groups and generate code for those in this custom update group for (const auto &cg : modelMerged.getMergedCustomConnectivityHostUpdateGroups()) { if (cg.getArchetype().getUpdateGroupName() == g) { - cg.generateUpdate(*this, os, modelMerged); + cg.generateUpdate(*this, os); } } @@ -570,7 +570,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged popSubs.addVarSubstitution("id", "i"); // Generate custom update - c.generateCustomUpdate(*this, os, modelMerged, popSubs); + 
c.generateCustomUpdate(*this, os, popSubs); // Loop through reduction targets and generate reduction for (const auto &r : reductionTargets) { @@ -593,7 +593,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged popSubs.addVarSubstitution("id", "i"); // Generate custom update - c.generateCustomUpdate(*this, os, modelMerged, popSubs); + c.generateCustomUpdate(*this, os, popSubs); // Write back reductions genWriteBackReductions(os, c, popSubs["id"]); @@ -625,7 +625,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged (CodeStream &os, Substitutions &subs) { // Call custom update handler - c.generateCustomUpdate(*this, os, modelMerged, subs); + c.generateCustomUpdate(*this, os, subs); // Write back reductions genWriteBackReductions(os, c, subs["id_syn"]); @@ -667,7 +667,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged synSubs.addVarSubstitution("id_post", "j"); // Call custom update handler - c.generateCustomUpdate(*this, os, modelMerged, synSubs); + c.generateCustomUpdate(*this, os, synSubs); // Write back reductions genWriteBackReductions(os, c, synSubs["id_syn"]); @@ -708,7 +708,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged if(c.getArchetype().isRowSimRNGRequired()) { popSubs.addVarSubstitution("rng", "hostRNG"); } - c.generateUpdate(*this, os, modelMerged, popSubs); + c.generateUpdate(*this, os, model.getBatchSize(), popSubs); } } } @@ -756,7 +756,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged synSubs.addVarSubstitution("id_post", "j"); // Call custom update handler - c.generateCustomUpdate(*this, os, modelMerged, synSubs); + c.generateCustomUpdate(*this, os, synSubs); // Update transpose variable os << "group->" << transposeVarName << "Transpose[(j * group->numSrcNeurons) + i] = l" << transposeVarName << ";" << std::endl; @@ -1728,7 +1728,7 @@ void Backend::genPresynapticUpdate(CodeStream &os, const ModelSpecMerged &modelM connSubs.addFuncSubstitution("addSynapse", 1 + (unsigned int)sg.getArchetype().getKernelSize().size(), presynapticUpdateStream.str()); // Generate toeplitz connectivity code - sg.generateToeplitzConnectivity(*this, os, modelMerged, connSubs); + sg.generateToeplitzConnectivity(*this, os, connSubs); if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { os << CodeStream::CB(130); // end if (eCode) diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index e493e6e948..92d55cbaf2 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -677,7 +677,7 @@ void BackendSIMT::genPresynapticUpdateKernel(CodeStream &os, const Substitutions // If any shared memory is required, declare array if(maxSharedMemPerThread > 0) { - os << getSharedPrefix() << modelMerged.getModel().getPrecision()->getName() << " shLg[" << maxSharedMemPerThread * getKernelBlockSize(KernelPresynapticUpdate) << "];" << std::endl; + os << getSharedPrefix() <<" scalar shLg[" << maxSharedMemPerThread * getKernelBlockSize(KernelPresynapticUpdate) << "];" << std::endl; } // If any of these synapse groups also have sparse connectivity, allocate shared memory for row length @@ -933,7 +933,7 @@ void BackendSIMT::genCustomUpdateKernel(CodeStream &os, const Substitutions &ker genCustomUpdateIndexCalculation(os, cg); // **THINK** it would be great to 'lift' reads of SHARED variables out of this loop - 
cg.generateCustomUpdate(*this, os, modelMerged, cuSubs); + cg.generateCustomUpdate(*this, os, cuSubs); // Loop through reduction targets and generate reduction for(const auto &r : reductionTargets) { @@ -976,7 +976,7 @@ void BackendSIMT::genCustomUpdateKernel(CodeStream &os, const Substitutions &ker reductionSubs.addVarSubstitution("id", "idx", true); // **THINK** it would be great to 'lift' reads of NEURON_SHARED variables out of this loop - cg.generateCustomUpdate(*this, os, modelMerged, reductionSubs); + cg.generateCustomUpdate(*this, os, reductionSubs); // Loop through reduction targets and generate reduction for (const auto &r : reductionTargets) { @@ -1027,7 +1027,7 @@ void BackendSIMT::genCustomUpdateKernel(CodeStream &os, const Substitutions &ker CodeStream::Scope b(os); genCustomUpdateIndexCalculation(os, cg); - cg.generateCustomUpdate(*this, os, modelMerged, cuSubs); + cg.generateCustomUpdate(*this, os, cuSubs); } } @@ -1140,7 +1140,7 @@ void BackendSIMT::genCustomUpdateWUKernel(CodeStream &os, const Substitutions &k os << "const unsigned int batchOffset = size * batch;" << std::endl; } - cg.generateCustomUpdate(*this, os, modelMerged, cuSubs); + cg.generateCustomUpdate(*this, os, cuSubs); // If this is a reduction if(cg.getArchetype().isBatchReduction()) { @@ -1245,7 +1245,7 @@ void BackendSIMT::genCustomTransposeUpdateWUKernel(CodeStream &os, const Substit synSubs.addVarSubstitution("id_pre", "y"); synSubs.addVarSubstitution("id_post", "x"); synSubs.addVarSubstitution("id_syn", "idx"); - cg.generateCustomUpdate(*this, os, modelMerged, synSubs); + cg.generateCustomUpdate(*this, os, synSubs); // Write forward weight to shared memory os << "shTile[" << getThreadID(1) << " + j][" << getThreadID(0) << "] = l" << transposeVarName << ";" << std::endl; @@ -1312,7 +1312,7 @@ void BackendSIMT::genCustomConnectivityUpdateKernel(CodeStream &os, const Substi genPopulationRNGPreamble(os, popSubs, "group->rng[" + popSubs["id"] + "]"); } - cg.generateUpdate(*this, os, modelMerged, popSubs); + cg.generateUpdate(*this, os, modelMerged.getModel().getBatchSize(), popSubs); // Copy local stream back to local if(cg.getArchetype().isRowSimRNGRequired()) { diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index db6b3c1a9e..04257d158f 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -203,7 +203,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdateGroupMerged::get return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &backend, CodeStream &os, unsigned int batchSize, Substitutions &popSubs) const { Substitutions updateSubs(&popSubs); @@ -226,7 +226,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back const auto &dependentVars = getSortedArchetypeDependentVars(); // Determine if any - const bool modelBatched = (modelMerged.getModel().getBatchSize() > 1); + const bool modelBatched = (batchSize > 1); const bool anyBatched = (modelBatched && (std::any_of(getArchetype().getVarReferences().cbegin(), getArchetype().getVarReferences().cend(), [](const auto &v){ 
return v.second.isDuplicated(); }) || std::any_of(dependentVars.cbegin(), dependentVars.cend(), @@ -263,11 +263,11 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Use subsequent parameters to initialise new synapse's variables referenced via the custom connectivity update for (size_t i = 0; i < ccuVarRefs.size(); i++) { // If model is batched and this variable is duplicated - if ((modelMerged.getModel().getBatchSize() > 1) && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) + if ((batchSize > 1) && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) { // Copy parameter into a register (just in case it's e.g. a RNG call) and copy into all batches addSynapse << "const " << ccuVarRefs[i].type->getName() << " _" << ccuVarRefs[i].name << "Val = $(" << (1 + ccuVars.size() + i) << ");" << std::endl; - addSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; + addSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; { CodeStream::Scope b(addSynapse); addSynapse << "group->" << ccuVarRefs[i].name << "[(b * synStride) + newIdx] = _" << ccuVarRefs[i].name << "Val;" << std::endl; @@ -282,10 +282,10 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Loop through any other dependent variables for (size_t i = 0; i < dependentVars.size(); i++) { // If model is batched and this dependent variable is duplicated - if ((modelMerged.getModel().getBatchSize() > 1) && dependentVars.at(i).isDuplicated()) + if ((batchSize > 1) && dependentVars.at(i).isDuplicated()) { // Loop through all batches and zero - addSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; + addSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; { CodeStream::Scope b(addSynapse); addSynapse << "group->_dependentVar" << i << "[(b * synStride) + newIdx] = 0;" << std::endl; @@ -325,11 +325,11 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Loop through variable references for (size_t i = 0; i < ccuVarRefs.size(); i++) { // If model is batched and this variable is duplicated - if ((modelMerged.getModel().getBatchSize() > 1) + if ((batchSize > 1) && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) { // Loop through all batches and copy custom connectivity update variable references from end of row over synapse to be deleted - removeSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; + removeSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; { CodeStream::Scope b(removeSynapse); removeSynapse << "group->" << ccuVarRefs[i].name << "[(b * synStride) + idx] = group->" << ccuVarRefs[i].name << "[(b * synStride) + lastIdx];" << std::endl; @@ -344,9 +344,9 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Loop through any other dependent variables for (size_t i = 0; i < dependentVars.size(); i++) { // If model is batched and this dependent variable is duplicated - if ((modelMerged.getModel().getBatchSize() > 1) && dependentVars.at(i).isDuplicated()) { + if ((batchSize > 1) && dependentVars.at(i).isDuplicated()) { // Loop through all batches and copy dependent variable from end of row over synapse to be deleted - removeSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; + removeSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; { CodeStream::Scope b(removeSynapse); removeSynapse <<
"group->_dependentVar" << i << "[(b * synStride) + idx] = group->_dependentVar" << i << "[(b * synStride) + lastIdx];" << std::endl; @@ -456,7 +456,7 @@ CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged } //---------------------------------------------------------------------------- -void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged) const +void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase &backend, CodeStream &os) const { CodeStream::Scope b(os); os << "// merged custom connectivity host update group " << getIndex() << std::endl; diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 63593088c1..eca5d6a161 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -190,7 +190,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateGroupMerged::getHashDigest() return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, Substitutions &popSubs) const { genCustomUpdate(os, popSubs, *this, "id", [this](const auto &varRef, const std::string &index) @@ -394,10 +394,10 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const //---------------------------------------------------------------------------- const std::string CustomUpdateWUGroupMerged::name = "CustomUpdateWU"; //---------------------------------------------------------------------------- -void CustomUpdateWUGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void CustomUpdateWUGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, Substitutions &popSubs) const { genCustomUpdate(os, popSubs, *this, "id_syn", - [this, &modelMerged](const auto &varRef, const std::string &index) + [this](const auto &varRef, const std::string &index) { return getVarRefIndex(getVarAccessDuplication(varRef.getVar().access), index); @@ -409,10 +409,10 @@ void CustomUpdateWUGroupMerged::generateCustomUpdate(const BackendBase&, CodeStr //---------------------------------------------------------------------------- const std::string CustomUpdateTransposeWUGroupMerged::name = "CustomUpdateTransposeWU"; //---------------------------------------------------------------------------- -void CustomUpdateTransposeWUGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void CustomUpdateTransposeWUGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, Substitutions &popSubs) const { genCustomUpdate(os, popSubs, *this, "id_syn", - [this, &modelMerged](const auto &varRef, const std::string &index) + [this](const auto &varRef, const std::string &index) { return getVarRefIndex(getVarAccessDuplication(varRef.getVar().access), index); diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index c3e9d73023..c75eaae3d7 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ 
b/src/genn/genn/code_generator/generateRunner.cc @@ -453,7 +453,7 @@ void genRunnerFusedVars(const ModelSpecMerged &modelMerged, const BackendBase &b } //------------------------------------------------------------------------- template -void genRunnerFusedVarPushPull(const ModelSpecMerged &modelMerged, const BackendBase &backend, CodeStream &definitionsFunc, +void genRunnerFusedVarPushPull(const BackendBase &backend, CodeStream &definitionsFunc, CodeStream &runnerPushFunc, CodeStream &runnerPullFunc, const G &group, std::vector &groupStatePushPullFunctions, S getSizeFn) { @@ -620,9 +620,9 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Define and declare time variables definitionsVar << "EXPORT_VAR unsigned long long iT;" << std::endl; - definitionsVar << "EXPORT_VAR " << model.getTimePrecision()->getName() << " t;" << std::endl; + definitionsVar << "EXPORT_VAR timepoint t;" << std::endl; runnerVarDecl << "unsigned long long iT;" << std::endl; - runnerVarDecl << model.getTimePrecision()->getName() << " t;" << std::endl; + runnerVarDecl << "timepoint t;" << std::endl; if(model.isRecordingInUse()) { runnerVarDecl << "unsigned long long numRecordingTimesteps = 0;" << std::endl; @@ -1425,7 +1425,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, s.second.getInSynLocation(), true, s.second.getTrgNeuronGroup()->getNumNeurons() * batchSize); }); - genRunnerFusedVarPushPull(modelMerged, backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, + genRunnerFusedVarPushPull(backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, [batchSize](const SynapseGroupInternal &sg, const Models::Base::Var &var) { return getVarSize(var.access, sg.getTrgNeuronGroup()->getNumNeurons(), batchSize); @@ -1436,7 +1436,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // **NOTE** we generated initialisation and declaration code earlier - here we just generate push and pull as we want this per-synapse group if(!s.second.isWUPreModelFused()) { const unsigned int preDelaySlots = (s.second.getDelaySteps() == NO_DELAY) ? 1 : s.second.getSrcNeuronGroup()->getNumDelaySlots(); - genRunnerFusedVarPushPull(modelMerged, backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, + genRunnerFusedVarPushPull(backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, [batchSize, preDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) { return getVarSize(var.access, sg.getSrcNeuronGroup()->getNumNeurons(), batchSize, preDelaySlots); @@ -1448,7 +1448,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // **NOTE** we generated initialisation and declaration code earlier - here we just generate push and pull as we want this per-synapse group if(!s.second.isWUPostModelFused()) { const unsigned int postDelaySlots = (s.second.getBackPropDelaySteps() == NO_DELAY) ? 
1 : s.second.getTrgNeuronGroup()->getNumDelaySlots(); - genRunnerFusedVarPushPull(modelMerged, backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, + genRunnerFusedVarPushPull(backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, [batchSize, postDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) { return getVarSize(var.access, sg.getTrgNeuronGroup()->getNumNeurons(), batchSize, postDelaySlots); @@ -1788,12 +1788,12 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, definitions << "EXPORT_FUNC void stepTime();" << std::endl; definitions << std::endl; definitions << "// Functions generated by backend" << std::endl; - definitions << "EXPORT_FUNC void updateNeurons(" << model.getTimePrecision()->getName() << " t"; + definitions << "EXPORT_FUNC void updateNeurons(timepoint t"; if(model.isRecordingInUse()) { definitions << ", unsigned int recordingTimestep"; } definitions << "); " << std::endl; - definitions << "EXPORT_FUNC void updateSynapses(" << model.getTimePrecision()->getName() << " t);" << std::endl; + definitions << "EXPORT_FUNC void updateSynapses(timepoint t);" << std::endl; definitions << "EXPORT_FUNC void initialize();" << std::endl; definitions << "EXPORT_FUNC void initializeSparse();" << std::endl; diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 22faf11de3..40ead7cb6b 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -194,19 +194,19 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C // Also read spike and spike-like-event times into local variables if required if(getArchetype().isSpikeTimeRequired()) { - os << "const " << model.getTimePrecision()->getName() << " lsT = group->sT["; + os << "const timepoint lsT = group->sT["; os << getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; } if(getArchetype().isPrevSpikeTimeRequired()) { - os << "const " << model.getTimePrecision()->getName() << " lprevST = group->prevST["; + os << "const timepoint lprevST = group->prevST["; os << getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; } if(getArchetype().isSpikeEventTimeRequired()) { - os << "const " << model.getTimePrecision()->getName() << " lseT = group->seT["; + os << "const timepoint lseT = group->seT["; os << getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; } if(getArchetype().isPrevSpikeEventTimeRequired()) { - os << "const " << model.getTimePrecision()->getName() << " lprevSET = group->prevSET["; + os << "const timepoint lprevSET = group->prevSET["; os << getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; } os << std::endl; @@ -221,7 +221,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C || sg->getPSModel()->getDecayCode().find("$(Isyn)") != std::string::npos); })) { - os << model.getPrecision()->getName() << " Isyn = 0;" << std::endl; + os << "scalar Isyn = 0;" << std::endl; } Substitutions neuronSubs(&popSubs); @@ -260,13 +260,13 @@ void 
NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C const auto *psm = sg->getPSModel(); os << "// pull inSyn values in a coalesced access" << std::endl; - os << model.getPrecision()->getName() << " linSyn = group->inSynInSyn" << i << "["; + os << "scalar linSyn = group->inSynInSyn" << i << "["; os << getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; // If dendritic delay is required if (sg->isDendriticDelayRequired()) { // Get reference to dendritic delay buffer input for this timestep - os << backend.getPointerPrefix() << model.getPrecision()->getName() << " *denDelayFront = "; + os << backend.getPointerPrefix() << "scalar *denDelayFront = "; os << "&group->denDelayInSyn" << i << "[(*group->denDelayPtrInSyn" << i << " * group->numNeurons) + "; os << getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index 98ae3b58d6..f8d08635bf 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -232,7 +232,7 @@ void PostSpan::genPreamble(CodeStream &os, const ModelSpecMerged &modelMerged, c { // If data structure is dense, we can accumulate output directly into register if(shouldAccumulateInRegister(sg)) { - os << modelMerged.getModel().getPrecision()->getName() << " linSyn = 0;" << std::endl; + os << "scalar linSyn = 0;" << std::endl; } else if(isSmallSharedMemoryPop(sg, backend)) { os << "if(" << backend.getThreadID() << " < group->numTrgNeurons)"; @@ -603,7 +603,7 @@ void PreSpanProcedural::genUpdate(CodeStream &os, const ModelSpecMerged &modelMe connSubs.addFuncSubstitution("addSynapse", 1 + (unsigned int)sg.getArchetype().getKernelSize().size(), presynapticUpdateStream.str()); // Generate procedural connectivity code - sg.generateProceduralConnectivity(backend, os, modelMerged, connSubs); + sg.generateProceduralConnectivity(backend, os, connSubs); if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { os << CodeStream::CB(130); @@ -970,7 +970,7 @@ void PostSpanToeplitz::genUpdate(CodeStream &os, const ModelSpecMerged &modelMer connSubs.addFuncSubstitution("addSynapse", 1 + (unsigned int)sg.getArchetype().getKernelSize().size(), presynapticUpdateStream.str()); // Generate toeplitz connectivity code - sg.generateToeplitzConnectivity(backend, os, modelMerged, connSubs); + sg.generateToeplitzConnectivity(backend, os, connSubs); if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { os << CodeStream::CB(130); // end if (eCode) diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 75864fdc60..004e7a3459 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -230,7 +230,7 @@ void PresynapticUpdateGroupMerged::generateSpikeUpdate(const BackendBase &backen *this, popSubs, modelMerged, backend.supportsNamespace()); } //---------------------------------------------------------------------------- -void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendBase&, CodeStream &os, Substitutions 
&popSubs) const { const auto &connectInit = getArchetype().getConnectivityInitialiser(); @@ -270,7 +270,7 @@ void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendB } } //---------------------------------------------------------------------------- -void PresynapticUpdateGroupMerged::generateToeplitzConnectivity(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void PresynapticUpdateGroupMerged::generateToeplitzConnectivity(const BackendBase&, CodeStream &os, Substitutions &popSubs) const { const auto &connectInit = getArchetype().getToeplitzConnectivityInitialiser(); diff --git a/src/genn/genn/transpiler/errorHandler.cc b/src/genn/genn/transpiler/errorHandler.cc index 92c8339cf0..0c90a298bd 100644 --- a/src/genn/genn/transpiler/errorHandler.cc +++ b/src/genn/genn/transpiler/errorHandler.cc @@ -32,7 +32,7 @@ void ErrorHandler::report(size_t line, std::string_view where, std::string_view //---------------------------------------------------------------------------- // GeNN::Transpiler::SingleLineErrorHandler //---------------------------------------------------------------------------- -void SingleLineErrorHandler::error(size_t line, std::string_view message) +void SingleLineErrorHandler::error(size_t, std::string_view message) { report("", message); } @@ -52,4 +52,4 @@ void SingleLineErrorHandler::report(std::string_view where, std::string_view mes LOGE_TRANSPILER << "Error" << where << ": " << message; m_Error = true; } -} // namespace GeNN::Transpiler \ No newline at end of file +} // namespace GeNN::Transpiler From 1b4e8e5933e659f6c44572c419d0de50c0f4d926 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 24 Jan 2023 09:14:09 +0000 Subject: [PATCH 097/725] started integrating code stream in pretty printer --- .../code_generator/customUpdateGroupMerged.h | 10 ++ include/genn/genn/transpiler/prettyPrinter.h | 10 +- .../code_generator/customUpdateGroupMerged.cc | 5 +- src/genn/genn/transpiler/prettyPrinter.cc | 129 ++++++++---------- 4 files changed, 79 insertions(+), 75 deletions(-) diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index d2b8c5566d..515cabd5a1 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -4,6 +4,9 @@ #include "code_generator/codeGenUtils.h" #include "code_generator/groupMerged.h" +// GeNN transpiler includes +#include "statement.h" + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::CustomUpdateGroupMerged //---------------------------------------------------------------------------- @@ -41,6 +44,13 @@ class GENN_EXPORT CustomUpdateGroupMerged : public GroupMergedgetUpdateCode()); const auto tokens = Transpiler::Scanner::scanSource(code, errorHandler); - const auto statements = Transpiler::Parser::parseBlockItemList(tokens, errorHandler); - Transpiler::TypeChecker::typeCheck(statements, typeEnvironment, typeContext, errorHandler); - + m_UpdateStatements = Transpiler::Parser::parseBlockItemList(tokens, errorHandler); + Transpiler::TypeChecker::typeCheck(m_UpdateStatements, typeEnvironment, typeContext, errorHandler); } //---------------------------------------------------------------------------- bool CustomUpdateGroupMerged::isParamHeterogeneous(const std::string ¶mName) const diff --git a/src/genn/genn/transpiler/prettyPrinter.cc 
b/src/genn/genn/transpiler/prettyPrinter.cc index ed2d80c845..0bd7f4058b 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -6,8 +6,8 @@ #include #include -// GeNN includes -#include "type.h" +// GeNN code generator includes +#include "code_generator/codeStream.h" // Transpiler includes #include "transpiler/transpilerUtils.h" @@ -27,23 +27,13 @@ namespace class Visitor : public Expression::Visitor, public Statement::Visitor { public: - Visitor(const Type::TypeContext &context) : m_Context(context) {} - - //--------------------------------------------------------------------------- - // Public API - //--------------------------------------------------------------------------- - std::string print(const Statement::StatementList &statements) + Visitor(CodeGenerator::CodeStream &codeStream, const Statement::StatementList &statements, const Type::TypeContext &context) + : m_CodeStream(codeStream), m_Context(context) { - // Clear string stream - m_StringStream.str(""); - - for(auto &s : statements) { + for(auto &s : statements) { s.get()->accept(*this); - m_StringStream << std::endl; + m_CodeStream << std::endl; } - - // Return string stream contents - return m_StringStream.str(); } //--------------------------------------------------------------------------- @@ -51,56 +41,56 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- virtual void visit(const Expression::ArraySubscript &arraySubscript) final { - m_StringStream << arraySubscript.getPointerName().lexeme << "["; + m_CodeStream << arraySubscript.getPointerName().lexeme << "["; arraySubscript.getIndex()->accept(*this); - m_StringStream << "]"; + m_CodeStream << "]"; } virtual void visit(const Expression::Assignment &assignement) final { - m_StringStream << assignement.getVarName().lexeme << " " << assignement.getOperator().lexeme << " "; + m_CodeStream << assignement.getVarName().lexeme << " " << assignement.getOperator().lexeme << " "; assignement.getValue()->accept(*this); } virtual void visit(const Expression::Binary &binary) final { binary.getLeft()->accept(*this); - m_StringStream << " " << binary.getOperator().lexeme << " "; + m_CodeStream << " " << binary.getOperator().lexeme << " "; binary.getRight()->accept(*this); } virtual void visit(const Expression::Call &call) final { call.getCallee()->accept(*this); - m_StringStream << "("; + m_CodeStream << "("; for(const auto &a : call.getArguments()) { a->accept(*this); } - m_StringStream << ")"; + m_CodeStream << ")"; } virtual void visit(const Expression::Cast &cast) final { - m_StringStream << "("; + m_CodeStream << "("; printType(cast.getType()); - m_StringStream << ")"; + m_CodeStream << ")"; cast.getExpression()->accept(*this); } virtual void visit(const Expression::Conditional &conditional) final { conditional.getCondition()->accept(*this); - m_StringStream << " ? "; + m_CodeStream << " ? 
"; conditional.getTrue()->accept(*this); - m_StringStream << " : "; + m_CodeStream << " : "; conditional.getFalse()->accept(*this); } virtual void visit(const Expression::Grouping &grouping) final { - m_StringStream << "("; + m_CodeStream << "("; grouping.getExpression()->accept(*this); - m_StringStream << ")"; + m_CodeStream << ")"; } virtual void visit(const Expression::Literal &literal) final @@ -108,44 +98,44 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If literal is a double, we want to remove the d suffix in generated code std::string_view lexeme = literal.getValue().lexeme; if (literal.getValue().type == Token::Type::DOUBLE_NUMBER){ - m_StringStream << lexeme.substr(0, literal.getValue().lexeme.size() - 1); + m_CodeStream << lexeme.substr(0, literal.getValue().lexeme.size() - 1); } // Otherwise, if literal is a scalar, we want to add appropriate suffix for scalar type else if (literal.getValue().type == Token::Type::SCALAR_NUMBER) { const Type::NumericBase *scalar = dynamic_cast(m_Context.at("scalar")); - m_StringStream << lexeme << scalar->getLiteralSuffix(m_Context); + m_CodeStream << lexeme << scalar->getLiteralSuffix(m_Context); } // Otherwise, just write out original lexeme directly else { - m_StringStream << lexeme; + m_CodeStream << lexeme; } } virtual void visit(const Expression::Logical &logical) final { logical.getLeft()->accept(*this); - m_StringStream << " " << logical.getOperator().lexeme << " "; + m_CodeStream << " " << logical.getOperator().lexeme << " "; logical.getRight()->accept(*this); } virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final { - m_StringStream << postfixIncDec.getVarName().lexeme << postfixIncDec.getOperator().lexeme; + m_CodeStream << postfixIncDec.getVarName().lexeme << postfixIncDec.getOperator().lexeme; } virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final { - m_StringStream << prefixIncDec.getOperator().lexeme << prefixIncDec.getVarName().lexeme; + m_CodeStream << prefixIncDec.getOperator().lexeme << prefixIncDec.getVarName().lexeme; } virtual void visit(const Expression::Variable &variable) final { - m_StringStream << variable.getName().lexeme; + m_CodeStream << variable.getName().lexeme; } virtual void visit(const Expression::Unary &unary) final { - m_StringStream << unary.getOperator().lexeme; + m_CodeStream << unary.getOperator().lexeme; unary.getRight()->accept(*this); } @@ -154,89 +144,88 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- virtual void visit(const Statement::Break&) final { - m_StringStream << "break;"; + m_CodeStream << "break;"; } virtual void visit(const Statement::Compound &compound) final { - m_StringStream << "{" << std::endl; + CodeGenerator::CodeStream::Scope b(m_CodeStream); for(auto &s : compound.getStatements()) { s->accept(*this); - m_StringStream << std::endl; + m_CodeStream << std::endl; } - m_StringStream << "}" << std::endl; } virtual void visit(const Statement::Continue&) final { - m_StringStream << "continue;"; + m_CodeStream << "continue;"; } virtual void visit(const Statement::Do &doStatement) final { - m_StringStream << "do"; + m_CodeStream << "do"; doStatement.getBody()->accept(*this); - m_StringStream << "while("; + m_CodeStream << "while("; doStatement.getCondition()->accept(*this); - m_StringStream << ");" << std::endl; + m_CodeStream << ");" << std::endl; } virtual void visit(const Statement::Expression &expression) final { 
expression.getExpression()->accept(*this); - m_StringStream << ";"; + m_CodeStream << ";"; } virtual void visit(const Statement::For &forStatement) final { - m_StringStream << "for("; + m_CodeStream << "for("; if(forStatement.getInitialiser()) { forStatement.getInitialiser()->accept(*this); } else { - m_StringStream << ";"; + m_CodeStream << ";"; } - m_StringStream << " "; + m_CodeStream << " "; if(forStatement.getCondition()) { forStatement.getCondition()->accept(*this); } - m_StringStream << "; "; + m_CodeStream << "; "; if(forStatement.getIncrement()) { forStatement.getIncrement()->accept(*this); } - m_StringStream << ")"; + m_CodeStream << ")"; forStatement.getBody()->accept(*this); } virtual void visit(const Statement::If &ifStatement) final { - m_StringStream << "if("; + m_CodeStream << "if("; ifStatement.getCondition()->accept(*this); - m_StringStream << ")" << std::endl; + m_CodeStream << ")" << std::endl; ifStatement.getThenBranch()->accept(*this); if(ifStatement.getElseBranch()) { - m_StringStream << "else" << std::endl; + m_CodeStream << "else" << std::endl; ifStatement.getElseBranch()->accept(*this); } } virtual void visit(const Statement::Labelled &labelled) final { - m_StringStream << labelled.getKeyword().lexeme << " "; + m_CodeStream << labelled.getKeyword().lexeme << " "; if(labelled.getValue()) { labelled.getValue()->accept(*this); } - m_StringStream << " : "; + m_CodeStream << " : "; labelled.getBody()->accept(*this); } virtual void visit(const Statement::Switch &switchStatement) final { - m_StringStream << "switch("; + m_CodeStream << "switch("; switchStatement.getCondition()->accept(*this); - m_StringStream << ")" << std::endl; + m_CodeStream << ")" << std::endl; switchStatement.getBody()->accept(*this); } @@ -245,29 +234,29 @@ class Visitor : public Expression::Visitor, public Statement::Visitor printType(varDeclaration.getType()); for(const auto &var : varDeclaration.getInitDeclaratorList()) { - m_StringStream << std::get<0>(var).lexeme; + m_CodeStream << std::get<0>(var).lexeme; if(std::get<1>(var)) { - m_StringStream << " = "; + m_CodeStream << " = "; std::get<1>(var)->accept(*this); } - m_StringStream << ", "; + m_CodeStream << ", "; } - m_StringStream << ";"; + m_CodeStream << ";"; } virtual void visit(const Statement::While &whileStatement) final { - m_StringStream << "while("; + m_CodeStream << "while("; whileStatement.getCondition()->accept(*this); - m_StringStream << ")" << std::endl; + m_CodeStream << ")" << std::endl; whileStatement.getBody()->accept(*this); } virtual void visit(const Statement::Print &print) final { - m_StringStream << "print "; + m_CodeStream << "print "; print.getExpression()->accept(*this); - m_StringStream << ";"; + m_CodeStream << ";"; } private: @@ -306,13 +295,13 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // Copy tokens backwards into string stream, separating with spaces - std::copy(tokens.rbegin(), tokens.rend(), std::ostream_iterator<std::string>(m_StringStream, " ")); + std::copy(tokens.rbegin(), tokens.rend(), std::ostream_iterator<std::string>(m_CodeStream, " ")); } //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- - std::ostringstream m_StringStream; + CodeGenerator::CodeStream &m_CodeStream; const Type::TypeContext &m_Context; }; } // Anonymous namespace @@ -320,8 +309,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor 
//--------------------------------------------------------------------------- // GeNN::Transpiler::PrettyPrinter //--------------------------------------------------------------------------- -std::string GeNN::Transpiler::PrettyPrinter::print(const Statement::StatementList &statements, const Type::TypeContext &context) +void GeNN::Transpiler::PrettyPrinter::print(CodeGenerator::CodeStream &os, const Statement::StatementList &statements, + const Type::TypeContext &context) { - Visitor visitor(context); - return visitor.print(statements); + Visitor(os, statements, context); } From b1f935e96a1af8daa9d0995b340ff42c175a98a9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 24 Jan 2023 10:20:44 +0000 Subject: [PATCH 098/725] fixed dumb typo - left with awful unique_ptr mess --- include/genn/genn/code_generator/customUpdateGroupMerged.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index 515cabd5a1..b0d9a6f6fe 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -5,7 +5,7 @@ #include "code_generator/groupMerged.h" // GeNN transpiler includes -#include "statement.h" +#include "transpiler/statement.h" //---------------------------------------------------------------------------- // GeNN::CodeGenerator::CustomUpdateGroupMerged //---------------------------------------------------------------------------- @@ -50,7 +50,7 @@ class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged Date: Thu, 26 Jan 2023 13:51:48 +0000 Subject: [PATCH 099/725] found and removed unused copy of group merged --- include/genn/genn/code_generator/backendSIMT.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 824921bd27..13a87f949e 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -244,7 +244,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase // Sum padded sizes of each group within merged group const size_t paddedSize = std::accumulate( gMerge.getGroups().cbegin(), gMerge.getGroups().cend(), size_t{0}, - [gMerge, getPaddedSizeFunc](size_t acc, std::reference_wrapper g) + [getPaddedSizeFunc](size_t acc, std::reference_wrapper g) { return (acc + getPaddedSizeFunc(g.get())); }); From 5561ef85a9086be4c0edc5a1ff451dbdc36b3225 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 26 Jan 2023 17:22:34 +0000 Subject: [PATCH 100/725] * tidied up interfaces to PrettyPrinter and TypeChecker - constructors are the only entry points * fixed bug with unclosed environments * pretty printer now has environments * outline pretty printer environments for basic substitution and caching of referenced variables to registers --- include/genn/genn/transpiler/prettyPrinter.h | 24 +- .../code_generator/customUpdateGroupMerged.cc | 231 +++++++++++++++++- src/genn/genn/code_generator/substitutions.cc | 2 +- src/genn/genn/transpiler/prettyPrinter.cc | 181 +++++++++----- src/genn/genn/transpiler/typeChecker.cc | 102 ++++---- 5 files changed, 431 insertions(+), 109 deletions(-) diff --git a/include/genn/genn/transpiler/prettyPrinter.h b/include/genn/genn/transpiler/prettyPrinter.h index f3c21830e0..5427dc00ca 100644 --- a/include/genn/genn/transpiler/prettyPrinter.h +++ b/include/genn/genn/transpiler/prettyPrinter.h @@ -16,9 +16,29 @@ class CodeStream; } 
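// For illustration, a minimal concrete implementation of the EnvironmentBase
// interface declared just below might perform no renaming at all and write
// straight to a single CodeStream. This is a hypothetical sketch (the name
// PassthroughEnvironment does not exist in GeNN):
class PassthroughEnvironment : public GeNN::Transpiler::PrettyPrinter::EnvironmentBase
{
public:
    explicit PassthroughEnvironment(GeNN::CodeGenerator::CodeStream &os) : m_OS(os) {}

    // Declared variables keep their original names
    virtual std::string define(const GeNN::Transpiler::Token &name) final { return std::string{name.lexeme}; }

    // References also resolve to the original names
    virtual std::string getName(const GeNN::Transpiler::Token &name) final { return std::string{name.lexeme}; }

    // All pretty-printed code goes to the wrapped stream
    virtual GeNN::CodeGenerator::CodeStream &getStream() final { return m_OS; }

private:
    GeNN::CodeGenerator::CodeStream &m_OS;
};
// Such an environment could then be passed to the free print() function declared below, e.g.
//     GeNN::Transpiler::PrettyPrinter::print(statements, env, typeContext);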
//--------------------------------------------------------------------------- -// GeNN::Transpiler::PrettyPrinter +// GeNN::Transpiler::PrettyPrinter::EnvironmentBase //--------------------------------------------------------------------------- namespace GeNN::Transpiler::PrettyPrinter { -void print(CodeGenerator::CodeStream &os, const Statement::StatementList &statements, const Type::TypeContext &context); +class EnvironmentBase +{ +public: + //------------------------------------------------------------------------ + // Declared virtuals + //------------------------------------------------------------------------ + //! Define variable named by token and return the name as it should be used in code + virtual std::string define(const Token &name) = 0; + + //! Get the name to use in code for the variable named by token + virtual std::string getName(const Token &name) = 0; + + //! Get stream to write code within this environment to + virtual CodeGenerator::CodeStream &getStream() = 0; +}; + +//--------------------------------------------------------------------------- +// Free functions +//--------------------------------------------------------------------------- +void print(const Statement::StatementList &statements, EnvironmentBase &environment, + const Type::TypeContext &context); } diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index a7ef05104c..1e0214f469 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -1,5 +1,8 @@ #include "code_generator/customUpdateGroupMerged.h" +// Standard C++ includes +#include + // GeNN code generator includes #include "code_generator/groupMergedTypeEnvironment.h" #include "code_generator/modelSpecMerged.h" @@ -7,19 +10,214 @@ // GeNN transpiler includes #include "transpiler/errorHandler.h" #include "transpiler/parser.h" +#include "transpiler/prettyPrinter.h" #include "transpiler/scanner.h" #include "transpiler/typeChecker.h" +#include "transpiler/transpilerUtils.h" using namespace GeNN; using namespace GeNN::CodeGenerator; using namespace GeNN::Transpiler; + //-------------------------------------------------------------------------- // Anonymous namespace //-------------------------------------------------------------------------- namespace { +class EnvironmentExternal : public PrettyPrinter::EnvironmentBase +{ +public: + EnvironmentExternal(PrettyPrinter::EnvironmentBase &enclosing) + : m_Context(enclosing) + { + } + + EnvironmentExternal(CodeStream &os) + : m_Context(os) + { + } + + //------------------------------------------------------------------------ + // PrettyPrinter::EnvironmentBase virtuals + //------------------------------------------------------------------------ + virtual std::string define(const Token&) + { + throw std::runtime_error("Cannot declare variable in external environment"); + } + +protected: + //------------------------------------------------------------------------ + // Protected API + //------------------------------------------------------------------------ + auto &getContext() const{ return m_Context; } + + CodeStream &getContextStream() const + { + return std::visit( + Transpiler::Utils::Overload{ + [](std::reference_wrapper enclosing)->CodeStream& { return enclosing.get().getStream(); }, + [](std::reference_wrapper os)->CodeStream& { return os.get(); }}, + getContext()); + } + + std::string getContextName(const Token &name) const + { + return std::visit( + 
Transpiler::Utils::Overload{ + [&name](std::reference_wrapper enclosing)->std::string { return enclosing.get().getName(name); }, + [&name](std::reference_wrapper)->std::string { throw std::runtime_error("Variable '" + std::string{name.lexeme} + "' undefined"); }}, + getContext()); + } + +private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + std::variant, std::reference_wrapper> m_Context; +}; + +//! Standard pretty printing environment simply allowing substitutions to be implemented +class EnvironmentSubstitute : public EnvironmentExternal +{ +public: + EnvironmentSubstitute(PrettyPrinter::EnvironmentBase &enclosing) : EnvironmentExternal(enclosing){} + EnvironmentSubstitute(CodeStream &os) : EnvironmentExternal(os){} + + //------------------------------------------------------------------------ + // PrettyPrinter::EnvironmentBase virtuals + //------------------------------------------------------------------------ + virtual std::string getName(const Token &name) final + { + // If there isn't a substitution for this name, try and get name from context + auto sub = m_VarSubstitutions.find(std::string{name.lexeme}); + if(sub == m_VarSubstitutions.end()) { + return getContextName(name); + } + // Otherwise, return substitution + else { + return sub->second; + } + } + + virtual CodeStream &getStream() final + { + return getContextStream(); + } + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + void addSubstitution(const std::string &source, const std::string &destination) + { + if(!m_VarSubstitutions.emplace(source, destination).second) { + throw std::runtime_error("Redeclaration of substitution '" + source + "'"); + } + } + +private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + std::unordered_map m_VarSubstitutions; +}; + +//! 
Pretty printing environment which caches used variables in local variables +template +class EnvironmentLocalVarCache : public EnvironmentExternal +{ + typedef std::function GetIndexFn; +public: + EnvironmentLocalVarCache(const std::vector &vars, PrettyPrinter::EnvironmentBase &enclosing, GetIndexFn getIndex, const std::string &localPrefix = "l") + : EnvironmentExternal(enclosing), m_Vars(vars), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) + { + // Add variables to map, with referenced flag initially set to false + std::transform(m_Vars.cbegin(), m_Vars.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()), + [](const auto &v){ return std::make_pair(v.name, false); }); + } + + EnvironmentLocalVarCache(const std::vector &vars, CodeStream &os, GetIndexFn getIndex, const std::string &localPrefix = "l") + : EnvironmentExternal(os), m_Vars(vars), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) + { + // Add variables to map, with referenced flag initially set to false + std::transform(m_Vars.cbegin(), m_Vars.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()), + [](const auto &v){ return std::make_pair(v.name, false); }); + } + + ~EnvironmentLocalVarCache() + { + // Copy variables which have been referenced into new vector + std::vector referencedVars; + std::copy_if(m_Vars.cbegin(), m_Vars.cend(), std::back_inserter(referencedVars), + [this](const auto &v){ return m_VariablesReferenced.at(v.name); }); + + // Loop through referenced variables + for(const auto &v : referencedVars) { + if(v.access & VarAccessMode::READ_ONLY) { + getContextStream() << "const "; + } + getContextStream() << v.type->getName() << " " << m_LocalPrefix << v.name; + + // If this isn't a reduction, read value from memory + // **NOTE** by not initialising these variables for reductions, + // compilers SHOULD emit a warning if user code doesn't set it to something + if(!(v.access & VarAccessModeAttribute::REDUCE)) { + getContextStream() << " = group->" << v.name << "[" << m_GetIndex(v.access) << "]"; + } + getContextStream() << ";" << std::endl; + } + + // Write contents to context stream + getContextStream() << m_ContentsStream.str(); + + // Loop through referenced variables again + for(const auto &v : referencedVars) { + // If variable is read-write + if(v.access & VarAccessMode::READ_WRITE) { + getContextStream() << "group->" << v.name << "[" << m_GetIndex(v.access) << "]"; + getContextStream() << " = " << m_LocalPrefix << v.name << ";" << std::endl; + } + } + } + + //------------------------------------------------------------------------ + // PrettyPrinter::EnvironmentBase virtuals + //------------------------------------------------------------------------ + virtual std::string getName(const Token &name) final + { + // If variable with this name isn't found, try and get name from context + auto var = m_VariablesReferenced.find(std::string{name.lexeme}); + if(var == m_VariablesReferenced.end()) { + return getContextName(name); + } + // Otherwise + else { + // Set flag to indicate that variable has been referenced + var->second = true; + + // Add local prefix to variable name + return m_LocalPrefix + std::string{name.lexeme}; + } + } + + virtual CodeStream &getStream() final + { + return m_Contents; + } + +private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + const std::vector &m_Vars; + std::ostringstream 
m_ContentsStream; + CodeStream m_Contents; + const std::string m_LocalPrefix; + const GetIndexFn m_GetIndex; + std::unordered_map m_VariablesReferenced; +}; + template void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, const std::string &index, R getVarRefIndex) @@ -55,7 +253,7 @@ void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, const } os << v.type->getName() << " l" << v.name; - + // If this isn't a reduction, read value from memory // **NOTE** by not initialising these variables for reductions, // compilers SHOULD emit a warning if user code doesn't set it to something @@ -154,11 +352,11 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const Type::TypeC typeEnvironment.defineEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); // Scan, parse and type-check update code - Transpiler::ErrorHandler errorHandler; + ErrorHandler errorHandler; const std::string code = upgradeCodeString(cm->getUpdateCode()); - const auto tokens = Transpiler::Scanner::scanSource(code, errorHandler); - m_UpdateStatements = Transpiler::Parser::parseBlockItemList(tokens, errorHandler); - Transpiler::TypeChecker::typeCheck(m_UpdateStatements, typeEnvironment, typeContext, errorHandler); + const auto tokens = Scanner::scanSource(code, errorHandler); + m_UpdateStatements = Parser::parseBlockItemList(tokens, errorHandler); + TypeChecker::typeCheck(m_UpdateStatements, typeEnvironment, typeContext, errorHandler); } //---------------------------------------------------------------------------- bool CustomUpdateGroupMerged::isParamHeterogeneous(const std::string ¶mName) const @@ -191,13 +389,32 @@ boost::uuids::detail::sha1::digest_type CustomUpdateGroupMerged::getHashDigest() //---------------------------------------------------------------------------- void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, Substitutions &popSubs) const { - genCustomUpdate(os, popSubs, *this, "id", + const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); + + EnvironmentSubstitute subs(os); + subs.addSubstitution("id", popSubs["id"]); + + EnvironmentLocalVarCache varSubs(cm->getVars(), subs, + [this](VarAccess a) + { + return getVarIndex(getVarAccessDuplication(a), "id"); + }); + + /*EnvironmentLocalVarCache varRefSubs(cm->getVarRefs(), subs, + [this](VarAccessMode a) + { + return getVarRefIndex(a, "id"); + });*/ + + /*genCustomUpdate(os, popSubs, *this, "id", [this](const auto &varRef, const std::string &index) { return getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, getVarAccessDuplication(varRef.getVar().access), index); - }); + });*/ + // Pretty print code + PrettyPrinter::print(m_UpdateStatements, varSubs, getTypeContext()); } //---------------------------------------------------------------------------- std::string CustomUpdateGroupMerged::getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const diff --git a/src/genn/genn/code_generator/substitutions.cc b/src/genn/genn/code_generator/substitutions.cc index d31792466a..b51904401c 100644 --- a/src/genn/genn/code_generator/substitutions.cc +++ b/src/genn/genn/code_generator/substitutions.cc @@ -107,4 +107,4 @@ void Substitutions::applyVars(std::string &code) const m_Parent->applyVars(code); } } -} // namespace GeNN::CodeGenerator \ No newline at end of file +} // namespace GeNN::CodeGenerator diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index 0bd7f4058b..afd1d09b7e 100644 
--- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -5,6 +5,7 @@ #include #include #include +#include // GeNN code generator includes #include "code_generator/codeStream.h" @@ -13,6 +14,7 @@ #include "transpiler/transpilerUtils.h" using namespace GeNN; +using namespace GeNN::CodeGenerator; using namespace GeNN::Transpiler; using namespace GeNN::Transpiler::PrettyPrinter; @@ -21,76 +23,124 @@ using namespace GeNN::Transpiler::PrettyPrinter; //--------------------------------------------------------------------------- namespace { +//--------------------------------------------------------------------------- +// EnvironmentInternal +//--------------------------------------------------------------------------- +class EnvironmentInternal : public EnvironmentBase +{ +public: + EnvironmentInternal(EnvironmentBase &enclosing) + : m_Enclosing(enclosing) + { + } + + //--------------------------------------------------------------------------- + // EnvironmentBase virtuals + //--------------------------------------------------------------------------- + virtual std::string define(const Token &name) final + { + if(!m_LocalVariables.emplace(name.lexeme).second) { + throw std::runtime_error("Redeclaration of variable"); + } + + return "_" + std::string{name.lexeme}; + } + + virtual std::string getName(const Token &name) final + { + if(m_LocalVariables.find(name.lexeme) == m_LocalVariables.end()) { + return m_Enclosing.getName(name); + } + else { + return "_" + std::string{name.lexeme}; + } + } + + virtual CodeStream &getStream() + { + return m_Enclosing.getStream(); + } + +private: + //--------------------------------------------------------------------------- + // Members + //--------------------------------------------------------------------------- + EnvironmentBase &m_Enclosing; + std::unordered_set m_LocalVariables; +}; + //--------------------------------------------------------------------------- // Visitor //--------------------------------------------------------------------------- class Visitor : public Expression::Visitor, public Statement::Visitor { public: - Visitor(CodeGenerator::CodeStream &codeStream, const Statement::StatementList &statements, const Type::TypeContext &context) - : m_CodeStream(codeStream), m_Context(context) + Visitor(const Statement::StatementList &statements, + EnvironmentInternal &environment, const Type::TypeContext &context) + : m_Environment(environment), m_Context(context) { for(auto &s : statements) { s.get()->accept(*this); - m_CodeStream << std::endl; + m_Environment.get().getStream() << std::endl; } } +private: //--------------------------------------------------------------------------- // Expression::Visitor virtuals //--------------------------------------------------------------------------- virtual void visit(const Expression::ArraySubscript &arraySubscript) final { - m_CodeStream << arraySubscript.getPointerName().lexeme << "["; + m_Environment.get().getStream() << m_Environment.get().getName(arraySubscript.getPointerName()) << "["; arraySubscript.getIndex()->accept(*this); - m_CodeStream << "]"; + m_Environment.get().getStream() << "]"; } virtual void visit(const Expression::Assignment &assignement) final { - m_CodeStream << assignement.getVarName().lexeme << " " << assignement.getOperator().lexeme << " "; + m_Environment.get().getStream() << m_Environment.get().getName(assignement.getVarName()) << " " << assignement.getOperator().lexeme << " "; assignement.getValue()->accept(*this); } virtual void 
visit(const Expression::Binary &binary) final { binary.getLeft()->accept(*this); - m_CodeStream << " " << binary.getOperator().lexeme << " "; + m_Environment.get().getStream() << " " << binary.getOperator().lexeme << " "; binary.getRight()->accept(*this); } virtual void visit(const Expression::Call &call) final { call.getCallee()->accept(*this); - m_CodeStream << "("; + m_Environment.get().getStream() << "("; for(const auto &a : call.getArguments()) { a->accept(*this); } - m_CodeStream << ")"; + m_Environment.get().getStream() << ")"; } virtual void visit(const Expression::Cast &cast) final { - m_CodeStream << "("; + m_Environment.get().getStream() << "("; printType(cast.getType()); - m_CodeStream << ")"; + m_Environment.get().getStream() << ")"; cast.getExpression()->accept(*this); } virtual void visit(const Expression::Conditional &conditional) final { conditional.getCondition()->accept(*this); - m_CodeStream << " ? "; + m_Environment.get().getStream() << " ? "; conditional.getTrue()->accept(*this); - m_CodeStream << " : "; + m_Environment.get().getStream() << " : "; conditional.getFalse()->accept(*this); } virtual void visit(const Expression::Grouping &grouping) final { - m_CodeStream << "("; + m_Environment.get().getStream() << "("; grouping.getExpression()->accept(*this); - m_CodeStream << ")"; + m_Environment.get().getStream() << ")"; } virtual void visit(const Expression::Literal &literal) final @@ -98,44 +148,44 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If literal is a double, we want to remove the d suffix in generated code std::string_view lexeme = literal.getValue().lexeme; if (literal.getValue().type == Token::Type::DOUBLE_NUMBER){ - m_CodeStream << lexeme.substr(0, literal.getValue().lexeme.size() - 1); + m_Environment.get().getStream() << lexeme.substr(0, literal.getValue().lexeme.size() - 1); } // Otherwise, if literal is a scalar, we want to add appropriate suffix for scalar type else if (literal.getValue().type == Token::Type::SCALAR_NUMBER) { const Type::NumericBase *scalar = dynamic_cast<const Type::NumericBase*>(m_Context.at("scalar")); - m_CodeStream << lexeme << scalar->getLiteralSuffix(m_Context); + m_Environment.get().getStream() << lexeme << scalar->getLiteralSuffix(m_Context); } // Otherwise, just write out original lexeme directly else { - m_CodeStream << lexeme; + m_Environment.get().getStream() << lexeme; } } virtual void visit(const Expression::Logical &logical) final { logical.getLeft()->accept(*this); - m_CodeStream << " " << logical.getOperator().lexeme << " "; + m_Environment.get().getStream() << " " << logical.getOperator().lexeme << " "; logical.getRight()->accept(*this); } virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final { - m_CodeStream << postfixIncDec.getVarName().lexeme << postfixIncDec.getOperator().lexeme; + m_Environment.get().getStream() << m_Environment.get().getName(postfixIncDec.getVarName()) << postfixIncDec.getOperator().lexeme; } virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final { - m_CodeStream << prefixIncDec.getOperator().lexeme << prefixIncDec.getVarName().lexeme; + m_Environment.get().getStream() << prefixIncDec.getOperator().lexeme << m_Environment.get().getName(prefixIncDec.getVarName()); } virtual void visit(const Expression::Variable &variable) final { - m_CodeStream << variable.getName().lexeme; + m_Environment.get().getStream() << m_Environment.get().getName(variable.getName()); } virtual void visit(const Expression::Unary &unary) final { - m_CodeStream << 
unary.getOperator().lexeme; + m_Environment.get().getStream() << unary.getOperator().lexeme; unary.getRight()->accept(*this); } @@ -144,88 +194,108 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- virtual void visit(const Statement::Break&) final { - m_CodeStream << "break;"; + m_Environment.get().getStream() << "break;"; } virtual void visit(const Statement::Compound &compound) final { - CodeGenerator::CodeStream::Scope b(m_CodeStream); + // Cache reference to current environment + std::reference_wrapper oldEnvironment = m_Environment; + + // Create new environment and set to current + EnvironmentInternal environment(m_Environment); + m_Environment = environment; + + CodeGenerator::CodeStream::Scope b(m_Environment.get().getStream()); for(auto &s : compound.getStatements()) { s->accept(*this); - m_CodeStream << std::endl; + m_Environment.get().getStream() << std::endl; } + + // Restore old environment + m_Environment = oldEnvironment; } virtual void visit(const Statement::Continue&) final { - m_CodeStream << "continue;"; + m_Environment.get().getStream() << "continue;"; } virtual void visit(const Statement::Do &doStatement) final { - m_CodeStream << "do"; + m_Environment.get().getStream() << "do"; doStatement.getBody()->accept(*this); - m_CodeStream << "while("; + m_Environment.get().getStream() << "while("; doStatement.getCondition()->accept(*this); - m_CodeStream << ");" << std::endl; + m_Environment.get().getStream() << ");" << std::endl; } virtual void visit(const Statement::Expression &expression) final { expression.getExpression()->accept(*this); - m_CodeStream << ";"; + m_Environment.get().getStream() << ";"; } virtual void visit(const Statement::For &forStatement) final { - m_CodeStream << "for("; + // Cache reference to current environment + std::reference_wrapper oldEnvironment = m_Environment; + + // Create new environment and set to current + EnvironmentInternal environment(m_Environment); + m_Environment = environment; + + m_Environment.get().getStream() << "for("; if(forStatement.getInitialiser()) { forStatement.getInitialiser()->accept(*this); } else { - m_CodeStream << ";"; + m_Environment.get().getStream() << ";"; } - m_CodeStream << " "; + m_Environment.get().getStream() << " "; if(forStatement.getCondition()) { forStatement.getCondition()->accept(*this); } - m_CodeStream << "; "; + m_Environment.get().getStream() << "; "; if(forStatement.getIncrement()) { forStatement.getIncrement()->accept(*this); } - m_CodeStream << ")"; + m_Environment.get().getStream() << ")"; forStatement.getBody()->accept(*this); + + // Restore old environment + m_Environment = oldEnvironment; } virtual void visit(const Statement::If &ifStatement) final { - m_CodeStream << "if("; + m_Environment.get().getStream() << "if("; ifStatement.getCondition()->accept(*this); - m_CodeStream << ")" << std::endl; + m_Environment.get().getStream() << ")" << std::endl; ifStatement.getThenBranch()->accept(*this); if(ifStatement.getElseBranch()) { - m_CodeStream << "else" << std::endl; + m_Environment.get().getStream() << "else" << std::endl; ifStatement.getElseBranch()->accept(*this); } } virtual void visit(const Statement::Labelled &labelled) final { - m_CodeStream << labelled.getKeyword().lexeme << " "; + m_Environment.get().getStream() << labelled.getKeyword().lexeme << " "; if(labelled.getValue()) { labelled.getValue()->accept(*this); } - m_CodeStream << " : "; + m_Environment.get().getStream() << " : "; 
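// Finally, print the body of the labelled statement (e.g. the code following a case label)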
labelled.getBody()->accept(*this); } virtual void visit(const Statement::Switch &switchStatement) final { - m_CodeStream << "switch("; + m_Environment.get().getStream() << "switch("; switchStatement.getCondition()->accept(*this); - m_CodeStream << ")" << std::endl; + m_Environment.get().getStream() << ")" << std::endl; switchStatement.getBody()->accept(*this); } @@ -234,29 +304,29 @@ class Visitor : public Expression::Visitor, public Statement::Visitor printType(varDeclaration.getType()); for(const auto &var : varDeclaration.getInitDeclaratorList()) { - m_CodeStream << std::get<0>(var).lexeme; + m_Environment.get().getStream() << m_Environment.get().define(std::get<0>(var)); if(std::get<1>(var)) { - m_CodeStream << " = "; + m_Environment.get().getStream() << " = "; std::get<1>(var)->accept(*this); } - m_CodeStream << ", "; + m_Environment.get().getStream() << ", "; } - m_CodeStream << ";"; + m_Environment.get().getStream() << ";"; } virtual void visit(const Statement::While &whileStatement) final { - m_CodeStream << "while("; + m_Environment.get().getStream() << "while("; whileStatement.getCondition()->accept(*this); - m_CodeStream << ")" << std::endl; + m_Environment.get().getStream() << ")" << std::endl; whileStatement.getBody()->accept(*this); } virtual void visit(const Statement::Print &print) final { - m_CodeStream << "print "; + m_Environment.get().getStream() << "print "; print.getExpression()->accept(*this); - m_CodeStream << ";"; + m_Environment.get().getStream() << ";"; } private: @@ -295,13 +365,13 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // Copy tokens backwards into string stream, separating with spaces - std::copy(tokens.rbegin(), tokens.rend(), std::ostream_iterator<std::string>(m_CodeStream, " ")); + std::copy(tokens.rbegin(), tokens.rend(), std::ostream_iterator<std::string>(m_Environment.get().getStream(), " ")); } //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- - CodeGenerator::CodeStream &m_CodeStream; + std::reference_wrapper m_Environment; const Type::TypeContext &m_Context; }; } // Anonymous namespace @@ -309,8 +379,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- // GeNN::Transpiler::PrettyPrinter //--------------------------------------------------------------------------- -void GeNN::Transpiler::PrettyPrinter::print(CodeGenerator::CodeStream &os, const Statement::StatementList &statements, +void GeNN::Transpiler::PrettyPrinter::print(const Statement::StatementList &statements, EnvironmentBase &environment, const Type::TypeContext &context) { - Visitor(os, statements, context); + EnvironmentInternal internalEnvironment(environment); + Visitor(statements, internalEnvironment, context); } diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index dcc1c7ef0b..17d0c983ea 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -24,7 +24,7 @@ namespace Type = GeNN::Type; //--------------------------------------------------------------------------- namespace { - bool checkPointerTypeAssignement(const Type::Base *rightType, const Type::Base *leftType, const Type::TypeContext &typeContext) +bool checkPointerTypeAssignement(const Type::Base *rightType, const Type::Base *leftType, const Type::TypeContext &typeContext) { // If both are pointers, recurse through 
value type auto rightPointerType = dynamic_cast(rightType); @@ -41,7 +41,7 @@ namespace return false; } } - +//--------------------------------------------------------------------------- bool checkForConstRemoval(const Type::Base *rightType, const Type::Base *leftType) { // If const is being removed @@ -66,7 +66,6 @@ bool checkForConstRemoval(const Type::Base *rightType, const Type::Base *leftTyp } - //--------------------------------------------------------------------------- // EnvironmentInternal //--------------------------------------------------------------------------- @@ -143,29 +142,32 @@ class EnvironmentInternal : public EnvironmentBase class Visitor : public Expression::Visitor, public Statement::Visitor { public: - Visitor(const Type::TypeContext &context, ErrorHandlerBase &errorHandler) - : m_Environment(nullptr), m_Type(nullptr), m_Context(context), m_ErrorHandler(errorHandler), - m_InLoop(false), m_InSwitch(false) + Visitor(const Statement::StatementList &statements, const Type::TypeContext &context, + EnvironmentInternal &environment, ErrorHandlerBase &errorHandler) + : Visitor(context, environment, errorHandler) { - } - - //--------------------------------------------------------------------------- - // Public API - //--------------------------------------------------------------------------- - // **THINK** make constructors? - void typeCheck(const Statement::StatementList &statements, EnvironmentInternal &environment) - { - m_Environment = &environment; for (auto &s : statements) { s.get()->accept(*this); } } - - const Type::Base *typeCheck(const Expression::Base *expression, EnvironmentInternal &environment) + + Visitor(const Expression::Base *expression, const Type::TypeContext &context, + EnvironmentInternal &environment, ErrorHandlerBase &errorHandler) + : Visitor(context, environment, errorHandler) + { + expression->accept(*this); + } + + //--------------------------------------------------------------------------- + // Public API + //--------------------------------------------------------------------------- + const Type::Base *getType() const{ return m_Type; } + +private: + Visitor(const Type::TypeContext &context, EnvironmentInternal &environment, ErrorHandlerBase &errorHandler) + : m_Environment(environment), m_Type(nullptr), m_Context(context), m_ErrorHandler(errorHandler), + m_InLoop(false), m_InSwitch(false) { - - m_Environment = &environment; - return evaluateType(expression); } //--------------------------------------------------------------------------- @@ -174,7 +176,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::ArraySubscript &arraySubscript) final { // Get pointer type - auto arrayType = m_Environment->getType(arraySubscript.getPointerName(), m_ErrorHandler); + auto arrayType = m_Environment.get().getType(arraySubscript.getPointerName(), m_ErrorHandler); auto pointerType = dynamic_cast(arrayType); // If pointer is indeed a pointer @@ -201,8 +203,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Assignment &assignment) final { const auto rhsType = evaluateType(assignment.getValue()); - m_Type = m_Environment->assign(assignment.getVarName(), assignment.getOperator().type, rhsType, - m_Context, m_ErrorHandler); + m_Type = m_Environment.get().assign(assignment.getVarName(), assignment.getOperator().type, rhsType, + m_Context, m_ErrorHandler); } virtual void visit(const Expression::Binary &binary) final @@ -415,19 +417,19 @@ class 
Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final { - m_Type = m_Environment->incDec(postfixIncDec.getVarName(), postfixIncDec.getOperator().type, - m_Context, m_ErrorHandler); + m_Type = m_Environment.get().incDec(postfixIncDec.getVarName(), postfixIncDec.getOperator().type, + m_Context, m_ErrorHandler); } virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final { - m_Type = m_Environment->incDec(prefixIncDec.getVarName(), prefixIncDec.getOperator().type, - m_Context, m_ErrorHandler); + m_Type = m_Environment.get().incDec(prefixIncDec.getVarName(), prefixIncDec.getOperator().type, + m_Context, m_ErrorHandler); } virtual void visit(const Expression::Variable &variable) { - m_Type = m_Environment->getType(variable.getName(), m_ErrorHandler); + m_Type = m_Environment.get().getType(variable.getName(), m_ErrorHandler); } virtual void visit(const Expression::Unary &unary) final @@ -497,8 +499,19 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::Compound &compound) final { - EnvironmentInternal environment(*m_Environment); - typeCheck(compound.getStatements(), environment); + // Cache reference to current environment + std::reference_wrapper oldEnvironment = m_Environment; + + // Create new environment and set to current + EnvironmentInternal environment(m_Environment); + m_Environment = environment; + + for (auto &s : compound.getStatements()) { + s.get()->accept(*this); + } + + // Restore old environment + m_Environment = oldEnvironment; } virtual void visit(const Statement::Continue &continueStatement) final @@ -523,10 +536,12 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::For &forStatement) final { - // Create new environment for loop initialisation - EnvironmentInternal *previous = m_Environment; - EnvironmentInternal environment(*m_Environment); - m_Environment = &environment; + // Cache reference to current environment + std::reference_wrapper oldEnvironment = m_Environment; + + // Create new environment and set to current + EnvironmentInternal environment(m_Environment); + m_Environment = environment; // Interpret initialiser if statement present if (forStatement.getInitialiser()) { @@ -545,8 +560,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor forStatement.getBody()->accept(*this); m_InLoop = false; - // Restore environment - m_Environment = previous; + // Restore old environment + m_Environment = oldEnvironment; } virtual void visit(const Statement::If &ifStatement) final @@ -595,7 +610,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::VarDeclaration &varDeclaration) final { for (const auto &var : varDeclaration.getInitDeclaratorList()) { - m_Environment->define(std::get<0>(var), varDeclaration.getType(), m_ErrorHandler); + m_Environment.get().define(std::get<0>(var), varDeclaration.getType(), m_ErrorHandler); // If variable has an initialiser expression if (std::get<1>(var)) { @@ -603,8 +618,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const auto initialiserType = evaluateType(std::get<1>(var).get()); // Assign initialiser expression to variable - m_Environment->assign(std::get<0>(var), Token::Type::EQUAL, initialiserType, - m_Context, m_ErrorHandler, true); + m_Environment.get().assign(std::get<0>(var), Token::Type::EQUAL, initialiserType, + m_Context, 
m_ErrorHandler, true); } } } @@ -635,7 +650,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- - EnvironmentInternal *m_Environment; + std::reference_wrapper m_Environment; const Type::Base *m_Type; const Type::TypeContext &m_Context; ErrorHandlerBase &m_ErrorHandler; @@ -749,15 +764,14 @@ const Type::Base *EnvironmentBase::incDec(const Token &name, Token::Type, void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) { - Visitor visitor(context, errorHandler); EnvironmentInternal internalEnvironment(environment); - visitor.typeCheck(statements, internalEnvironment); + Visitor(statements, context, internalEnvironment, errorHandler); } //--------------------------------------------------------------------------- const Type::Base *GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, EnvironmentBase &environment, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) { - Visitor visitor(context, errorHandler); EnvironmentInternal internalEnvironment(environment); - return visitor.typeCheck(expression, internalEnvironment); + Visitor visitor(expression, context, internalEnvironment, errorHandler); + return visitor.getType(); } From 56a62d73a3c600cb2c27cb4ed3be52eb66e14e9c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 26 Jan 2023 17:38:13 +0000 Subject: [PATCH 101/725] made names of adapter methods more generic --- .../genn/genn/code_generator/groupMerged.h | 28 +++++++-------- .../genn/code_generator/initGroupMerged.h | 12 +++---- include/genn/genn/currentSourceInternal.h | 10 +++--- .../genn/customConnectivityUpdateInternal.h | 24 ++++++------- include/genn/genn/customUpdate.h | 12 +++---- include/genn/genn/neuronGroupInternal.h | 12 +++---- include/genn/genn/synapseGroupInternal.h | 34 +++++++++---------- .../genn/code_generator/generateRunner.cc | 26 +++++++------- 8 files changed, 78 insertions(+), 80 deletions(-) diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 3f42f3db23..0ee25dee84 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -354,14 +354,14 @@ class GroupMerged { // Loop through weight update model variables const A archetypeAdaptor(getArchetype()); - for(const auto &v : archetypeAdaptor.getVars()) { + for(const auto &v : archetypeAdaptor.getDefs()) { // Loop through parameters - for(const auto &p : archetypeAdaptor.getVarInitialisers().at(v.name).getParams()) { + for(const auto &p : archetypeAdaptor.getInitialisers().at(v.name).getParams()) { if((static_cast(this)->*isHeterogeneous)(v.name, p.first)) { addScalarField(p.first + v.name, [p, v](const G &g, size_t) { - return A(g).getVarInitialisers().at(v.name).getParams().at(p.first); + return A(g).getInitialisers().at(v.name).getParams().at(p.first); }); } } @@ -373,14 +373,14 @@ class GroupMerged { // Loop through weight update model variables const A archetypeAdaptor(getArchetype()); - for(const auto &v : archetypeAdaptor.getVars()) { + for(const auto &v : archetypeAdaptor.getDefs()) { // Loop through parameters - for(const auto &p : archetypeAdaptor.getVarInitialisers().at(v.name).getDerivedParams()) { + for(const auto &p : 
archetypeAdaptor.getInitialisers().at(v.name).getDerivedParams()) { if((static_cast(this)->*isHeterogeneous)(v.name, p.first)) { addScalarField(p.first + v.name, [p, v](const G &g, size_t) { - return A(g).getVarInitialisers().at(v.name).getDerivedParams().at(p.first); + return A(g).getInitialisers().at(v.name).getDerivedParams().at(p.first); }); } } @@ -417,7 +417,7 @@ class GroupMerged void updateVarInitParamHash(R isParamReferencedFn, boost::uuids::detail::sha1 &hash) const { // Loop through variables - const auto &archetypeVarInitialisers = A(getArchetype()).getVarInitialisers(); + const auto &archetypeVarInitialisers = A(getArchetype()).getInitialisers(); for(const auto &varInit : archetypeVarInitialisers) { // Loop through parameters for(const auto &p : varInit.second.getParams()) { @@ -425,7 +425,7 @@ class GroupMerged if((static_cast(this)->*isParamReferencedFn)(varInit.first, p.first)) { // Loop through groups for(const auto &g : getGroups()) { - const auto &values = A(g.get()).getVarInitialisers().at(varInit.first).getParams(); + const auto &values = A(g.get()).getInitialisers().at(varInit.first).getParams(); // Update hash with parameter value Utils::updateHash(values.at(p.first), hash); @@ -439,7 +439,7 @@ class GroupMerged void updateVarInitDerivedParamHash(R isDerivedParamReferencedFn, boost::uuids::detail::sha1 &hash) const { // Loop through variables - const auto &archetypeVarInitialisers = A(getArchetype()).getVarInitialisers(); + const auto &archetypeVarInitialisers = A(getArchetype()).getInitialisers(); for(const auto &varInit : archetypeVarInitialisers) { // Loop through parameters for(const auto &d : varInit.second.getDerivedParams()) { @@ -447,7 +447,7 @@ class GroupMerged if((static_cast(this)->*isDerivedParamReferencedFn)(varInit.first, d.first)) { // Loop through groups for(const auto &g : getGroups()) { - const auto &values = A(g.get()).getVarInitialisers().at(varInit.first).getDerivedParams(); + const auto &values = A(g.get()).getInitialisers().at(varInit.first).getDerivedParams(); // Update hash with parameter value Utils::updateHash(values.at(d.first), hash); @@ -889,7 +889,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged { // Loop through variables A archetypeAdaptor(this->getArchetype()); - for (const auto &var : archetypeAdaptor.getVars()) { + for (const auto &var : archetypeAdaptor.getDefs()) { // If we're not initialising or if there is initialization code for this variable - const auto &varInit = archetypeAdaptor.getVarInitialisers().at(var.name); + const auto &varInit = archetypeAdaptor.getInitialisers().at(var.name); if (!varInit.getSnippet()->getCode().empty()) { this->addPointerField(var.type, var.name, backend.getDeviceVarPrefix() + var.name); } @@ -284,7 +284,7 @@ class CustomUpdateInitGroupMergedBase : public GroupMerged [&varName](const G &cg) { A archetypeAdaptor(cg); - return archetypeAdaptor.getVarInitialisers().at(varName).getParams(); + return archetypeAdaptor.getInitialisers().at(varName).getParams(); })); } @@ -296,7 +296,7 @@ class CustomUpdateInitGroupMergedBase : public GroupMerged [&varName](const G &cg) { A archetypeAdaptor(cg); - return archetypeAdaptor.getVarInitialisers().at(varName).getDerivedParams(); + return archetypeAdaptor.getInitialisers().at(varName).getDerivedParams(); })); } @@ -321,7 +321,7 @@ class CustomUpdateInitGroupMergedBase : public GroupMerged bool isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const { A archetypeAdaptor(this->getArchetype()); - const auto 
*varInitSnippet = archetypeAdaptor.getVarInitialisers().at(varName).getSnippet(); + const auto *varInitSnippet = archetypeAdaptor.getInitialisers().at(varName).getSnippet(); return this->isParamReferenced({varInitSnippet->getCode()}, paramName); } @@ -329,7 +329,7 @@ class CustomUpdateInitGroupMergedBase : public GroupMerged bool isVarInitDerivedParamReferenced(const std::string &varName, const std::string ¶mName) const { A archetypeAdaptor(this->getArchetype()); - const auto *varInitSnippet = archetypeAdaptor.getVarInitialisers().at(varName).getSnippet(); + const auto *varInitSnippet = archetypeAdaptor.getInitialisers().at(varName).getSnippet(); return this->isParamReferenced({varInitSnippet->getCode()}, paramName); } }; diff --git a/include/genn/genn/currentSourceInternal.h b/include/genn/genn/currentSourceInternal.h index ca6e7abe48..06973292b2 100644 --- a/include/genn/genn/currentSourceInternal.h +++ b/include/genn/genn/currentSourceInternal.h @@ -43,11 +43,11 @@ class CurrentSourceVarAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getVarLocation(const std::string &varName) const{ return m_CS.getVarLocation(varName); } + VarLocation getLoc(const std::string &varName) const{ return m_CS.getVarLocation(varName); } - Models::Base::VarVec getVars() const{ return m_CS.getCurrentSourceModel()->getVars(); } + Models::Base::VarVec getDefs() const{ return m_CS.getCurrentSourceModel()->getVars(); } - const std::unordered_map &getVarInitialisers() const{ return m_CS.getVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_CS.getVarInitialisers(); } private: //---------------------------------------------------------------------------- @@ -68,9 +68,9 @@ class CurrentSourceEGPAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getEGPLocation(const std::string &varName) const{ return m_CS.getExtraGlobalParamLocation(varName); } + VarLocation getLoc(const std::string &varName) const{ return m_CS.getExtraGlobalParamLocation(varName); } - Snippet::Base::EGPVec getEGPs() const{ return m_CS.getCurrentSourceModel()->getExtraGlobalParams(); } + Snippet::Base::EGPVec getDefs() const{ return m_CS.getCurrentSourceModel()->getExtraGlobalParams(); } private: //---------------------------------------------------------------------------- diff --git a/include/genn/genn/customConnectivityUpdateInternal.h b/include/genn/genn/customConnectivityUpdateInternal.h index 54edd43732..04a218cc9e 100644 --- a/include/genn/genn/customConnectivityUpdateInternal.h +++ b/include/genn/genn/customConnectivityUpdateInternal.h @@ -53,11 +53,11 @@ class CustomConnectivityUpdateVarAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getVarLocation(const std::string &varName) const{ return m_CU.getVarLocation(varName); } + VarLocation getLoc(const std::string &varName) const{ return m_CU.getVarLocation(varName); } - Models::Base::VarVec getVars() const{ return m_CU.getCustomConnectivityUpdateModel()->getVars(); } + Models::Base::VarVec getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getVars(); } - const std::unordered_map &getVarInitialisers() const{ return 
m_CU.getVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_CU.getVarInitialisers(); } private: //---------------------------------------------------------------------------- @@ -78,11 +78,11 @@ class CustomConnectivityUpdatePreVarAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getVarLocation(const std::string &varName) const{ return m_CU.getPreVarLocation(varName); } + VarLocation getLoc(const std::string &varName) const{ return m_CU.getPreVarLocation(varName); } - Models::Base::VarVec getVars() const{ return m_CU.getCustomConnectivityUpdateModel()->getPreVars(); } + Models::Base::VarVec getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getPreVars(); } - const std::unordered_map &getVarInitialisers() const{ return m_CU.getPreVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_CU.getPreVarInitialisers(); } private: //---------------------------------------------------------------------------- @@ -103,11 +103,11 @@ class CustomConnectivityUpdatePostVarAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getVarLocation(const std::string &varName) const{ return m_CU.getPostVarLocation(varName); } + VarLocation getLoc(const std::string &varName) const{ return m_CU.getPostVarLocation(varName); } - Models::Base::VarVec getVars() const{ return m_CU.getCustomConnectivityUpdateModel()->getPostVars(); } + Models::Base::VarVec getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getPostVars(); } - const std::unordered_map &getVarInitialisers() const{ return m_CU.getPostVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_CU.getPostVarInitialisers(); } private: //---------------------------------------------------------------------------- @@ -129,9 +129,9 @@ class CustomConnectivityUpdateEGPAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getEGPLocation(const std::string&) const{ return VarLocation::HOST_DEVICE; } + VarLocation getLoc(const std::string&) const{ return VarLocation::HOST_DEVICE; } - Snippet::Base::EGPVec getEGPs() const{ return m_CU.getCustomConnectivityUpdateModel()->getExtraGlobalParams(); } + Snippet::Base::EGPVec getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getExtraGlobalParams(); } private: //---------------------------------------------------------------------------- @@ -139,4 +139,4 @@ class CustomConnectivityUpdateEGPAdapter //---------------------------------------------------------------------------- const CustomConnectivityUpdateInternal &m_CU; }; -} // namespace GeNN \ No newline at end of file +} // namespace GeNN diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index 40e4464b09..21c258387e 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -179,11 +179,11 @@ class CustomUpdateVarAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getVarLocation(const std::string &varName) const{ return 
m_CU.getVarLocation(varName); } + VarLocation getLoc(const std::string &varName) const{ return m_CU.getVarLocation(varName); } - Models::Base::VarVec getVars() const{ return m_CU.getCustomUpdateModel()->getVars(); } + Models::Base::VarVec getDefs() const{ return m_CU.getCustomUpdateModel()->getVars(); } - const std::unordered_map &getVarInitialisers() const{ return m_CU.getVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_CU.getVarInitialisers(); } private: //---------------------------------------------------------------------------- @@ -204,11 +204,9 @@ class CustomUpdateEGPAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getEGPLocation(const std::string&) const{ return VarLocation::HOST_DEVICE; } + VarLocation getLoc(const std::string&) const{ return VarLocation::HOST_DEVICE; } - VarLocation getEGPLocation(size_t) const{ return VarLocation::HOST_DEVICE; } - - Snippet::Base::EGPVec getEGPs() const{ return m_CU.getCustomUpdateModel()->getExtraGlobalParams(); } + Snippet::Base::EGPVec getDefs() const{ return m_CU.getCustomUpdateModel()->getExtraGlobalParams(); } private: //---------------------------------------------------------------------------- diff --git a/include/genn/genn/neuronGroupInternal.h b/include/genn/genn/neuronGroupInternal.h index 8e1b23a30f..86ec247516 100644 --- a/include/genn/genn/neuronGroupInternal.h +++ b/include/genn/genn/neuronGroupInternal.h @@ -60,11 +60,11 @@ class NeuronVarAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getVarLocation(const std::string &varName) const{ return m_NG.getVarLocation(varName); } + VarLocation getLoc(const std::string &varName) const{ return m_NG.getVarLocation(varName); } - Models::Base::VarVec getVars() const{ return m_NG.getNeuronModel()->getVars(); } + Models::Base::VarVec getDefs() const{ return m_NG.getNeuronModel()->getVars(); } - const std::unordered_map &getVarInitialisers() const{ return m_NG.getVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_NG.getVarInitialisers(); } private: //---------------------------------------------------------------------------- @@ -85,9 +85,9 @@ class NeuronEGPAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getEGPLocation(const std::string &varName) const{ return m_NG.getExtraGlobalParamLocation(varName); } + VarLocation getLoc(const std::string &varName) const{ return m_NG.getExtraGlobalParamLocation(varName); } - Snippet::Base::EGPVec getEGPs() const{ return m_NG.getNeuronModel()->getExtraGlobalParams(); } + Snippet::Base::EGPVec getDefs() const{ return m_NG.getNeuronModel()->getExtraGlobalParams(); } private: //---------------------------------------------------------------------------- @@ -95,4 +95,4 @@ class NeuronEGPAdapter //---------------------------------------------------------------------------- const NeuronGroupInternal &m_NG; }; -} // namespace GeNN \ No newline at end of file +} // namespace GeNN diff --git a/include/genn/genn/synapseGroupInternal.h b/include/genn/genn/synapseGroupInternal.h index 2307e889dd..ab5529346c 100644 --- a/include/genn/genn/synapseGroupInternal.h 
+++ b/include/genn/genn/synapseGroupInternal.h @@ -89,13 +89,13 @@ class SynapsePSMVarAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getVarLocation(const std::string &varName) const{ return m_SG.getPSVarLocation(varName); } + VarLocation getLoc(const std::string &varName) const{ return m_SG.getPSVarLocation(varName); } - Models::Base::VarVec getVars() const{ return m_SG.getPSModel()->getVars(); } + Models::Base::VarVec getDefs() const{ return m_SG.getPSModel()->getVars(); } - const std::unordered_map &getVarInitialisers() const{ return m_SG.getPSVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_SG.getPSVarInitialisers(); } - const std::string &getFusedVarSuffix() const{ return m_SG.getFusedPSVarSuffix(); } + const std::string &getFusedSuffix() const{ return m_SG.getFusedPSVarSuffix(); } private: //---------------------------------------------------------------------------- @@ -116,11 +116,11 @@ class SynapseWUVarAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getVarLocation(const std::string &varName) const{ return m_SG.getWUVarLocation(varName); } + VarLocation getLoc(const std::string &varName) const{ return m_SG.getWUVarLocation(varName); } - Models::Base::VarVec getVars() const{ return m_SG.getWUModel()->getVars(); } + Models::Base::VarVec getDefs() const{ return m_SG.getWUModel()->getVars(); } - const std::unordered_map &getVarInitialisers() const{ return m_SG.getWUVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_SG.getWUVarInitialisers(); } private: //---------------------------------------------------------------------------- @@ -141,13 +141,13 @@ class SynapseWUPreVarAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getVarLocation(const std::string &varName) const{ return m_SG.getWUPreVarLocation(varName); } + VarLocation getLoc(const std::string &varName) const{ return m_SG.getWUPreVarLocation(varName); } - Models::Base::VarVec getVars() const{ return m_SG.getWUModel()->getPreVars(); } + Models::Base::VarVec getDefs() const{ return m_SG.getWUModel()->getPreVars(); } - const std::unordered_map &getVarInitialisers() const{ return m_SG.getWUPreVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_SG.getWUPreVarInitialisers(); } - const std::string &getFusedVarSuffix() const{ return m_SG.getFusedWUPreVarSuffix(); } + const std::string &getFusedSuffix() const{ return m_SG.getFusedWUPreVarSuffix(); } private: //---------------------------------------------------------------------------- @@ -168,13 +168,13 @@ class SynapseWUPostVarAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getVarLocation(const std::string &varName) const{ return m_SG.getWUPostVarLocation(varName); } + VarLocation getLoc(const std::string &varName) const{ return m_SG.getWUPostVarLocation(varName); } - Models::Base::VarVec getVars() const{ return m_SG.getWUModel()->getPostVars(); } + Models::Base::VarVec getDefs() const{ return 
m_SG.getWUModel()->getPostVars(); } - const std::unordered_map<std::string, Models::VarInit> &getVarInitialisers() const{ return m_SG.getWUPostVarInitialisers(); } + const std::unordered_map<std::string, Models::VarInit> &getInitialisers() const{ return m_SG.getWUPostVarInitialisers(); } - const std::string &getFusedVarSuffix() const{ return m_SG.getFusedWUPostVarSuffix(); } + const std::string &getFusedSuffix() const{ return m_SG.getFusedWUPostVarSuffix(); } private: //---------------------------------------------------------------------------- @@ -196,9 +196,9 @@ class SynapseWUEGPAdapter //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- - VarLocation getEGPLocation(const std::string &varName) const{ return m_SG.getWUExtraGlobalParamLocation(varName); } + VarLocation getLoc(const std::string &varName) const{ return m_SG.getWUExtraGlobalParamLocation(varName); } - Snippet::Base::EGPVec getEGPs() const{ return m_SG.getWUModel()->getExtraGlobalParams(); } + Snippet::Base::EGPVec getDefs() const{ return m_SG.getWUModel()->getExtraGlobalParams(); } private: //---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index c75eaae3d7..c8c7f73d39 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -412,11 +412,11 @@ void genRunnerVars(const ModelSpecMerged &modelMerged, const BackendBase &backen { // Loop through variables const V varAdaptor(group); - for(const auto &var : varAdaptor.getVars()) { - const auto *varInitSnippet = varAdaptor.getVarInitialisers().at(var.name).getSnippet(); + for(const auto &var : varAdaptor.getDefs()) { + const auto *varInitSnippet = varAdaptor.getInitialisers().at(var.name).getSnippet(); const bool autoInitialized = !varInitSnippet->getCode().empty(); genVariable(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - runnerPushFunc, runnerPullFunc, var.type, var.name + group.getName(), varAdaptor.getVarLocation(var.name), + runnerPushFunc, runnerPullFunc, var.type, var.name + group.getName(), varAdaptor.getLoc(var.name), autoInitialized, getSizeFn(group, var), mem, statePushPullFunctions); // Loop through EGPs required to initialize variable
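
What the renaming buys: every adapter now exposes the same getDefs()/getInitialisers()/getLoc() surface, so runner helpers like genRunnerVars above can be written once and instantiated for any kind of group. A minimal sketch of the pattern, assuming only the adapter interface shown in this patch; the helper countAutoInitialisedVars and the NeuronVarAdapter instantiation are illustrative, not repository code:

template<typename A, typename G>
size_t countAutoInitialisedVars(const G &group)
{
    size_t count = 0;
    const A adaptor(group);
    // A variable is auto-initialised if its initialiser snippet contains code
    for(const auto &def : adaptor.getDefs()) {
        if(!adaptor.getInitialisers().at(def.name).getSnippet()->getCode().empty()) {
            count++;
        }
    }
    return count;
}

// Usage (illustrative): countAutoInitialisedVars<NeuronVarAdapter>(neuronGroup);
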
@@ -437,16 +437,16 @@ void genRunnerFusedVars(const ModelSpecMerged &modelMerged, const BackendBase &b { // Loop through variables const V varAdaptor(group); - for(const auto &var : varAdaptor.getVars()) { + for(const auto &var : varAdaptor.getDefs()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - var.type, modelMerged.getTypeContext(), var.name + varAdaptor.getFusedVarSuffix(), varAdaptor.getVarLocation(var.name), + var.type, modelMerged.getTypeContext(), var.name + varAdaptor.getFusedSuffix(), varAdaptor.getLoc(var.name), getSizeFn(group, var), mem); // Loop through EGPs required to initialize variable - for(const auto &egp : varAdaptor.getVarInitialisers().at(var.name).getSnippet()->getExtraGlobalParams()) { + for(const auto &egp : varAdaptor.getInitialisers().at(var.name).getSnippet()->getExtraGlobalParams()) { genExtraGlobalParam(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerExtraGlobalParamFunc, - egp.type, egp.name + var.name + varAdaptor.getFusedVarSuffix(), + egp.type, egp.name + var.name + varAdaptor.getFusedSuffix(), true, VarLocation::HOST_DEVICE); } } @@ -459,15 +459,15 @@ void genRunnerFusedVarPushPull(const BackendBase &backend, CodeStream &definitio { // Loop through variables const V varAdaptor(group); - for(const auto &var : varAdaptor.getVars()) { - const bool autoInitialized = !varAdaptor.getVarInitialisers().at(var.name).getSnippet()->getCode().empty(); - genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, varAdaptor.getVarLocation(var.name), + for(const auto &var : varAdaptor.getDefs()) { + const bool autoInitialized = !varAdaptor.getInitialisers().at(var.name).getSnippet()->getCode().empty(); + genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, varAdaptor.getLoc(var.name), backend.getPreferences().automaticCopy, var.name + group.getName(), groupStatePushPullFunctions, [&]() { backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, var.type, var.name + group.getName(), - varAdaptor.getVarLocation(var.name), autoInitialized, getSizeFn(group, var)); + varAdaptor.getLoc(var.name), autoInitialized, getSizeFn(group, var)); }); } } @@ -479,11 +479,11 @@ void genRunnerEGPs(const ModelSpecMerged &modelMerged, const BackendBase &backen { // Loop through EGPs const E egpAdaptor(group); - for(const auto &egp: egpAdaptor.getEGPs()) { + for(const auto &egp: egpAdaptor.getDefs()) { genExtraGlobalParam(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerExtraGlobalParamFunc, egp.type, egp.name + group.getName(), - true, egpAdaptor.getEGPLocation(egp.name)); + true, egpAdaptor.getLoc(egp.name)); } } //------------------------------------------------------------------------- From 2c4561af6326e473fada86980f3d3beac09cc611 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 26 Jan 2023 17:40:48 +0000 Subject: [PATCH 102/725] added adapters for custom update var references --- include/genn/genn/customUpdateInternal.h | 46 ++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/include/genn/genn/customUpdateInternal.h b/include/genn/genn/customUpdateInternal.h index 5093e7a99d..983c6e245a 100644 --- a/include/genn/genn/customUpdateInternal.h +++ b/include/genn/genn/customUpdateInternal.h @@ -36,6 +36,29 @@ class CustomUpdateInternal : public CustomUpdate using CustomUpdate::isNeuronReduction; }; +//---------------------------------------------------------------------------- +// CustomUpdateVarRefAdapter +//---------------------------------------------------------------------------- +class CustomUpdateVarRefAdapter +{ +public: + CustomUpdateVarRefAdapter(const CustomUpdateInternal &cu) : m_CU(cu) + {} + + //---------------------------------------------------------------------------- + // Public methods + //---------------------------------------------------------------------------- + Models::Base::VarRefVec getDefs() const{ return m_CU.getCustomUpdateModel()->getVarRefs(); } + + const std::unordered_map<std::string, Models::VarReference> &getInitialisers() const{ return m_CU.getVarReferences(); } + +private: + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + const CustomUpdateInternal &m_CU; +}; + //------------------------------------------------------------------------ // CustomUpdateInternal //------------------------------------------------------------------------ @@ -67,4 +90,27 @@ class CustomUpdateWUInternal : public CustomUpdateWU using CustomUpdateWU::isBatchReduction; using CustomUpdateWU::isTransposeOperation; }; + +//---------------------------------------------------------------------------- +// CustomUpdateWUVarRefAdapter +//---------------------------------------------------------------------------- +class CustomUpdateWUVarRefAdapter +{ +public: + CustomUpdateWUVarRefAdapter(const CustomUpdateWUInternal &cu) : m_CU(cu) + {} + + //---------------------------------------------------------------------------- + // Public methods + //---------------------------------------------------------------------------- + Models::Base::VarRefVec getDefs() const{ return m_CU.getCustomUpdateModel()->getVarRefs(); } + + const std::unordered_map<std::string, Models::WUVarReference> &getInitialisers() const{ return m_CU.getVarReferences(); } + +private: + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + const CustomUpdateWUInternal &m_CU; +}; } // namespace GeNN
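
The var-reference adapters added above deliberately mirror the variable adapters: getDefs() returns the model's reference definitions and getInitialisers() the bound targets, so the same generic machinery can walk both. A rough sketch under that assumption (the surrounding function and its names are illustrative only):

void listVarRefs(const GeNN::CustomUpdateInternal &customUpdate)
{
    GeNN::CustomUpdateVarRefAdapter refAdaptor(customUpdate);
    for(const auto &ref : refAdaptor.getDefs()) {
        // The definition carries the reference's name; the initialiser map
        // carries the Models::VarReference it was bound to at model build time
        const auto &varRef = refAdaptor.getInitialisers().at(ref.name);
        (void)varRef;   // e.g. inspect the target variable here
    }
}
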
From e35aa0f1665210ed1a2f64907b89c9d011bdbd04 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 26 Jan 2023 18:01:57 +0000 Subject: [PATCH 103/725] some MAYBE unnecessary gnarliness to handle variables and variable references with caching environment --- include/genn/genn/varAccess.h | 6 ++ .../code_generator/customUpdateGroupMerged.cc | 71 +++++++++++-------- 2 files changed, 49 insertions(+), 28 deletions(-) diff --git a/include/genn/genn/varAccess.h b/include/genn/genn/varAccess.h index 5fabb846e3..43b14378cc 100644 --- a/include/genn/genn/varAccess.h +++ b/include/genn/genn/varAccess.h @@ -69,6 +69,12 @@ inline bool operator & (VarAccessMode mode, VarAccessModeAttribute modeAttribute return (static_cast<unsigned int>(mode) & static_cast<unsigned int>(modeAttribute)) != 0; } +inline bool operator & (VarAccessMode a, VarAccessMode b) +{ + return (static_cast<unsigned int>(a) & static_cast<unsigned int>(b)) != 0; +} + + //---------------------------------------------------------------------------- // Helpers //---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 1e0214f469..4fa6113f44 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -124,35 +124,49 @@ class EnvironmentSubstitute : public EnvironmentExternal }; //! Pretty printing environment which caches used variables in local variables -template +template<typename A, typename G> class EnvironmentLocalVarCache : public EnvironmentExternal { - typedef std::function GetIndexFn; + //! Type of a single definition + typedef typename std::invoke_result_t<decltype(&A::getDefs), A>::value_type DefType; + + //! Type of a single initialiser + typedef typename std::remove_reference_t<std::invoke_result_t<decltype(&A::getInitialisers), A>>::mapped_type InitialiserType; +
+ //! Function used to provide index strings based on initialiser and access type + typedef std::function<std::string(const InitialiserType&, decltype(DefType::access))> GetIndexFn; + public: - EnvironmentLocalVarCache(const std::vector &vars, PrettyPrinter::EnvironmentBase &enclosing, GetIndexFn getIndex, const std::string &localPrefix = "l") - : EnvironmentExternal(enclosing), m_Vars(vars), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) + EnvironmentLocalVarCache(const G &group, PrettyPrinter::EnvironmentBase &enclosing, GetIndexFn getIndex, const std::string &localPrefix = "l") + : EnvironmentExternal(enclosing), m_Group(group), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) { - // Add variables to map, initially with value set to value - std::transform(m_Vars.cbegin(), m_Vars.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()), + // Add name of each definition to map, initially with value set to false + const auto defs = A(m_Group).getDefs(); + std::transform(defs.cbegin(), defs.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()), [](const auto &v){ return std::make_pair(v.name, false); }); } - EnvironmentLocalVarCache(const std::vector &vars, CodeStream &os, GetIndexFn getIndex, const std::string &localPrefix = "l") - : EnvironmentExternal(os), m_Vars(vars), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) + EnvironmentLocalVarCache(const G &group, CodeStream &os, GetIndexFn getIndex, const std::string &localPrefix = "l") + : EnvironmentExternal(os), m_Group(group), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) { - // Add variables to map, initially with value set to value - std::transform(m_Vars.cbegin(), m_Vars.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()), + // Add name of each definition to map, initially with value set to false + const auto defs = A(m_Group).getDefs(); + std::transform(defs.cbegin(), defs.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()), [](const auto &v){ return std::make_pair(v.name, false); }); } ~EnvironmentLocalVarCache() { - // Copy variables which have been referenced into new vector - std::vector referencedVars; - std::copy_if(m_Vars.cbegin(), m_Vars.cend(), std::back_inserter(referencedVars), + A adapter(m_Group); + + // Copy definitions which have been referenced into new vector + const auto defs = adapter.getDefs(); + std::remove_const_t<decltype(defs)> referencedVars; + std::copy_if(defs.cbegin(), defs.cend(), std::back_inserter(referencedVars), [this](const auto &v){ return m_VariablesReferenced.at(v.name); }); // Loop through referenced variables + const auto &initialisers = adapter.getInitialisers(); for(const auto &v : referencedVars) { if(v.access & VarAccessMode::READ_ONLY) { getContextStream() << "const "; } @@ -163,7 +177,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal // **NOTE** by not initialising these variables for reductions, // compilers SHOULD emit a warning if user code doesn't set it to something if(!(v.access & VarAccessModeAttribute::REDUCE)) { - getContextStream() << " = group->" << v.name << "[" << m_GetIndex(v.access) << "]"; + getContextStream() << " = group->" << v.name << "[" << m_GetIndex(initialisers.at(v.name), v.access) << "]"; } getContextStream() << ";" << std::endl; } @@ -175,7 +189,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal for(const auto &v : referencedVars) { // If variables are read-write if(v.access & VarAccessMode::READ_WRITE) { -
getContextStream() << "group->" << v.name << "[" << m_GetIndex(v.access) << "]"; + getContextStream() << "group->" << v.name << "[" << m_GetIndex(initialisers.at(v.name), v.access) << "]"; getContextStream() << " = " << m_LocalPrefix << v.name << ";" << std::endl; } } @@ -210,7 +224,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - const std::vector &m_Vars; + const G &m_Group; std::ostringstream m_ContentsStream; CodeStream m_Contents; const std::string m_LocalPrefix; @@ -389,22 +403,23 @@ boost::uuids::detail::sha1::digest_type CustomUpdateGroupMerged::getHashDigest() //---------------------------------------------------------------------------- void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, Substitutions &popSubs) const { - const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); - EnvironmentSubstitute subs(os); subs.addSubstitution("id", popSubs["id"]); - EnvironmentLocalVarCache varSubs(cm->getVars(), subs, - [this](VarAccess a) - { - return getVarIndex(getVarAccessDuplication(a), "id"); - }); + EnvironmentLocalVarCache varSubs( + getArchetype(), subs, + [this](const Models::VarInit&, VarAccess a) + { + return getVarIndex(getVarAccessDuplication(a), "id"); + }); - /*EnvironmentLocalVarCache varRefSubs(cm->getVarRefs(), subs, - [this](VarAccessMode a) + EnvironmentLocalVarCache varRefSubs(getArchetype(), subs, + [this](const Models::VarReference &v, VarAccessMode) { - return getVarRefIndex(a, "id"); - });*/ + return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, + getVarAccessDuplication(v.getVar().access), + "id"); + }); /*genCustomUpdate(os, popSubs, *this, "id", [this](const auto &varRef, const std::string &index) @@ -414,7 +429,7 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase&, CodeStrea index); });*/ // Pretty print code - PrettyPrinter::print(m_UpdateStatements, varSubs, getTypeContext()); + PrettyPrinter::print(m_UpdateStatements, varRefSubs, getTypeContext()); } //---------------------------------------------------------------------------- std::string CustomUpdateGroupMerged::getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const From 9016dae3c9517b1070cc1a67aae86228bcb86ecd Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 26 Jan 2023 18:19:59 +0000 Subject: [PATCH 104/725] finally given up on std::string_view for Tokens * We don't want to keep original source knocking around for ever * Most tokens are thrown away anyway * Code strings are small --- .../genn/genn/code_generator/groupMergedTypeEnvironment.h | 6 +++--- include/genn/genn/transpiler/token.h | 7 ++----- src/genn/genn/code_generator/customUpdateGroupMerged.cc | 8 ++++---- src/genn/genn/transpiler/errorHandler.cc | 4 ++-- src/genn/genn/transpiler/parser.cc | 4 ++-- src/genn/genn/transpiler/prettyPrinter.cc | 6 +++--- src/genn/genn/transpiler/typeChecker.cc | 2 +- tests/unit/typeChecker.cc | 4 ++-- 8 files changed, 19 insertions(+), 22 deletions(-) diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index 5af65bc19e..a35155decb 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -43,7 +43,7 @@ class GroupMergedTypeEnvironment : public 
Transpiler::TypeChecker::EnvironmentBa bool initializer) final { // If type isn't found - auto existingType = m_Types.find(std::string{name.lexeme}); + auto existingType = m_Types.find(name.lexeme); if(existingType == m_Types.end()) { if(m_Enclosing) { return m_Enclosing->assign(name, op, assignedType, @@ -66,7 +66,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa virtual const Type::Base *incDec(const Token &name, Token::Type op, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) final { - auto existingType = m_Types.find(std::string{name.lexeme}); + auto existingType = m_Types.find(name.lexeme); if(existingType == m_Types.end()) { if(m_Enclosing) { return m_Enclosing->incDec(name, op, context, errorHandler); @@ -86,7 +86,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa virtual const Type::Base *getType(const Token &name, ErrorHandlerBase &errorHandler) final { - auto type = m_Types.find(std::string{name.lexeme}); + auto type = m_Types.find(name.lexeme); if(type == m_Types.end()) { if(m_Enclosing) { return m_Enclosing->getType(name, errorHandler); diff --git a/include/genn/genn/transpiler/token.h b/include/genn/genn/transpiler/token.h index ed8022d05b..e3878e2382 100644 --- a/include/genn/genn/transpiler/token.h +++ b/include/genn/genn/transpiler/token.h @@ -1,11 +1,8 @@ #pragma once // Standard C++ includes +#include #include -#include - -// Standard C includes -#include // **YUCK** on Windows undefine TRUE and FALSE macros #ifdef _WIN32 @@ -58,7 +55,7 @@ struct Token } const Type type; - const std::string_view lexeme; + const std::string lexeme; const size_t line; }; } // namespace GeNN::Transpiler diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 4fa6113f44..5f7d9fc09d 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -67,7 +67,7 @@ class EnvironmentExternal : public PrettyPrinter::EnvironmentBase return std::visit( Transpiler::Utils::Overload{ [&name](std::reference_wrapper enclosing)->std::string { return enclosing.get().getName(name); }, - [&name](std::reference_wrapper)->std::string { throw std::runtime_error("Variable '" + std::string{name.lexeme} + "' undefined"); }}, + [&name](std::reference_wrapper)->std::string { throw std::runtime_error("Variable '" + name.lexeme + "' undefined"); }}, getContext()); } @@ -91,7 +91,7 @@ class EnvironmentSubstitute : public EnvironmentExternal virtual std::string getName(const Token &name) final { // If there isn't a substitution for this name, try and get name from context - auto sub = m_VarSubstitutions.find(std::string{name.lexeme}); + auto sub = m_VarSubstitutions.find(name.lexeme); if(sub == m_VarSubstitutions.end()) { return getContextName(name); } @@ -201,7 +201,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal virtual std::string getName(const Token &name) final { // If variable with this name isn't found, try and get name from context - auto var = m_VariablesReferenced.find(std::string{name.lexeme}); + auto var = m_VariablesReferenced.find(name.lexeme); if(var == m_VariablesReferenced.end()) { return getContextName(name); } @@ -211,7 +211,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal var->second = true; // Add local prefix to variable name - return m_LocalPrefix + std::string{name.lexeme}; + return m_LocalPrefix + name.lexeme; } } diff --git 
a/src/genn/genn/transpiler/errorHandler.cc b/src/genn/genn/transpiler/errorHandler.cc index 0c90a298bd..11f38cc315 100644 --- a/src/genn/genn/transpiler/errorHandler.cc +++ b/src/genn/genn/transpiler/errorHandler.cc @@ -19,7 +19,7 @@ void ErrorHandler::error(const Token &token, std::string_view message) report(token.line, " at end", message); } else { - report(token.line, " at '" + std::string{token.lexeme} + "'", message); + report(token.line, " at '" + token.lexeme + "'", message); } } //---------------------------------------------------------------------------- @@ -43,7 +43,7 @@ void SingleLineErrorHandler::error(const Token &token, std::string_view message) report(" at end", message); } else { - report(" at '" + std::string{token.lexeme} + "'", message); + report(" at '" + token.lexeme + "'", message); } } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index f3dc60a635..81863cc0e4 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -616,9 +616,9 @@ Statement::StatementPtr parseSelectionStatement(ParserState &parserState) // "if" "(" expression ")" statement "else" statement // "switch" "(" expression ")" compound-statement const auto keyword = parserState.previous(); - parserState.consume(Token::Type::LEFT_PAREN, "Expect '(' after '" + std::string{keyword.lexeme} + "'"); + parserState.consume(Token::Type::LEFT_PAREN, "Expect '(' after '" + keyword.lexeme + "'"); auto condition = parseExpression(parserState); - parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after '" + std::string{keyword.lexeme} + "'"); + parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after '" + keyword.lexeme + "'"); // If this is an if statement if(keyword.type == Token::Type::IF) { diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index afd1d09b7e..372657ea7b 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -43,7 +43,7 @@ class EnvironmentInternal : public EnvironmentBase throw std::runtime_error("Redeclaration of variable"); } - return "_" + std::string{name.lexeme}; + return "_" + name.lexeme; } virtual std::string getName(const Token &name) final @@ -52,7 +52,7 @@ class EnvironmentInternal : public EnvironmentBase return m_Enclosing.getName(name); } else { - return "_" + std::string{name.lexeme}; + return "_" + name.lexeme; } } @@ -66,7 +66,7 @@ class EnvironmentInternal : public EnvironmentBase // Members //--------------------------------------------------------------------------- EnvironmentBase &m_Enclosing; - std::unordered_set m_LocalVariables; + std::unordered_set m_LocalVariables; }; //--------------------------------------------------------------------------- diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 17d0c983ea..8558def5d0 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -133,7 +133,7 @@ class EnvironmentInternal : public EnvironmentBase // Members //--------------------------------------------------------------------------- EnvironmentBase &m_Enclosing; - std::unordered_map m_Types; + std::unordered_map m_Types; }; //--------------------------------------------------------------------------- diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index 81cdc19cf1..6406b50acc 100644 --- 
a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -37,7 +37,7 @@ class TestErrorHandler : public ErrorHandlerBase report(token.line, " at end", message); } else { - report(token.line, " at '" + std::string{token.lexeme} + "'", message); + report(token.line, " at '" + token.lexeme + "'", message); } } @@ -131,7 +131,7 @@ class TestEnvironment : public TypeChecker::EnvironmentBase //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- - std::unordered_map m_Types; + std::unordered_map m_Types; }; template From 0df78db5e57fbfb6b296fa0f962f3cd68b771006 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 26 Jan 2023 18:26:44 +0000 Subject: [PATCH 105/725] tidy --- .../code_generator/customUpdateGroupMerged.cc | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 5f7d9fc09d..9131f68562 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -403,9 +403,12 @@ boost::uuids::detail::sha1::digest_type CustomUpdateGroupMerged::getHashDigest() //---------------------------------------------------------------------------- void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, Substitutions &popSubs) const { + // Build initial environment with ID etc + // **TODO** this should happen in backend EnvironmentSubstitute subs(os); subs.addSubstitution("id", popSubs["id"]); + // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varSubs( getArchetype(), subs, [this](const Models::VarInit&, VarAccess a) @@ -413,22 +416,17 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase&, CodeStrea return getVarIndex(getVarAccessDuplication(a), "id"); }); - EnvironmentLocalVarCache varRefSubs(getArchetype(), subs, - [this](const Models::VarReference &v, VarAccessMode) - { - return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, - getVarAccessDuplication(v.getVar().access), - "id"); - }); - - /*genCustomUpdate(os, popSubs, *this, "id", - [this](const auto &varRef, const std::string &index) - { - return getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, - getVarAccessDuplication(varRef.getVar().access), - index); - });*/ - // Pretty print code + // Create an environment which caches variable references in local variables if they are accessed + EnvironmentLocalVarCache varRefSubs( + getArchetype(), subs, + [this](const Models::VarReference &v, VarAccessMode) + { + return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, + getVarAccessDuplication(v.getVar().access), + "id"); + }); + + // Pretty print previously parsed update statements PrettyPrinter::print(m_UpdateStatements, varRefSubs, getTypeContext()); } //---------------------------------------------------------------------------- From 4fff06c97a957d59a6a6b466345ad0e6c86f8219 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 27 Jan 2023 10:52:15 +0000 Subject: [PATCH 106/725] * Removed "foreign" from function types (foreignness is irrelevant to TYPE) * Added types for all supported transcendentals --- include/genn/genn/type.h | 141 +++++++++++++++++++----- src/genn/genn/transpiler/typeChecker.cc | 2 +- src/genn/genn/type.cc | 87 +++++++++++++-- 3 files changed, 192 insertions(+), 38 
deletions(-) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 213f517d63..29a6d9fb44 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -50,16 +50,35 @@ using NumericType = TYPE; \ } -#define DECLARE_FOREIGN_FUNCTION_TYPE(TYPE, RETURN_TYPE, ...) \ - class TYPE : public ForeignFunction<RETURN_TYPE, __VA_ARGS__> \ +#define DECLARE_FUNCTION_TYPE(TYPE, RETURN_TYPE, ...) \ + class TYPE : public Function<RETURN_TYPE, __VA_ARGS__> \ { \ DECLARE_TYPE(TYPE) \ - TYPE(Qualifier qualifiers = Qualifier{0}) : ForeignFunction(qualifiers){} \ + TYPE(Qualifier qualifiers = Qualifier{0}) : Function(qualifiers){} \ virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new TYPE(qualifiers); } \ } #define IMPLEMENT_TYPE(TYPE) TYPE *TYPE::s_Instance = NULL -#define IMPLEMENT_NUMERIC_TYPE(TYPE) IMPLEMENT_TYPE(TYPE) + +//! Helper macro to declare single and double precision one argument function types +#define DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(TYPE) \ + DECLARE_FUNCTION_TYPE(TYPE##F, Float, Float); \ + DECLARE_FUNCTION_TYPE(TYPE##D, Double, Double) + +//! Helper macro to declare single and double precision two argument function types +#define DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(TYPE) \ + DECLARE_FUNCTION_TYPE(TYPE##F, Float, Float, Float); \ + DECLARE_FUNCTION_TYPE(TYPE##D, Double, Double, Double) + +//! Helper macro to declare single and double precision three argument function types +#define DECLARE_THREE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(TYPE) \ + DECLARE_FUNCTION_TYPE(TYPE##F, Float, Float, Float, Float); \ + DECLARE_FUNCTION_TYPE(TYPE##D, Double, Double, Double, Double) + +//! Helper macro to implement single and double precision function types +#define IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(TYPE) \ + IMPLEMENT_TYPE(TYPE##F); \ + IMPLEMENT_TYPE(TYPE##D) //---------------------------------------------------------------------------- // GeNN::Type::TypeTraits //---------------------------------------------------------------------------- @@ -268,28 +287,28 @@ class NumericTypedef : public NumericBase }; //---------------------------------------------------------------------------- -// GeNN::Type::ForeignFunctionBase +// GeNN::Type::FunctionBase //---------------------------------------------------------------------------- -class ForeignFunctionBase : public Base +class FunctionBase : public Base { public: - ForeignFunctionBase(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} + FunctionBase(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ - virtual const NumericBase *getReturnType() const = 0; - virtual std::vector<const NumericBase*> getArgumentTypes() const = 0; + virtual const Base *getReturnType() const = 0; + virtual std::vector<const Base*> getArgumentTypes() const = 0; }; //---------------------------------------------------------------------------- -// GeNN::Type::ForeignFunction +// GeNN::Type::Function //---------------------------------------------------------------------------- template<typename ReturnType, typename... ArgTypes> -class ForeignFunction : public ForeignFunctionBase +class Function : public FunctionBase { public: - ForeignFunction(Qualifier qualifiers = Qualifier{0}) : ForeignFunctionBase(qualifiers){} + Function(Qualifier qualifiers = Qualifier{0}) : FunctionBase(qualifiers){} //------------------------------------------------------------------------ // Base virtuals //------------------------------------------------------------------------ @@ -317,16 +336,16 @@ class ForeignFunction : public ForeignFunctionBase } //------------------------------------------------------------------------ - // ForeignFunctionBase virtuals
+ // FunctionBase virtuals //------------------------------------------------------------------------ - virtual const NumericBase *getReturnType() const final + virtual const Base *getReturnType() const final { return ReturnType::getInstance(); } - virtual std::vector<const NumericBase*> getArgumentTypes() const final + virtual std::vector<const Base*> getArgumentTypes() const final { std::vector<const Base*> args; args.reserve(sizeof...(ArgTypes)); updateArgumentTypes(args); return args; } private: //------------------------------------------------------------------------ // Private methods //------------------------------------------------------------------------ template @@ -343,7 +362,7 @@ class ForeignFunction : public ForeignFunctionBase typeName += T::getInstance()->getName(); // If there are more arguments left in pack, add comma and recurse - if constexpr (sizeof...(Args)) { + if constexpr (sizeof...(Args) > 0) { typeName += ", "; updateTypeName(typeName); } @@ -356,24 +375,23 @@ class ForeignFunction : public ForeignFunctionBase typeName += T::getInstance()->getResolvedName(context); // If there are more arguments left in pack, add comma and recurse - if constexpr (sizeof...(Args)) { + if constexpr (sizeof...(Args) > 0) { typeName += ", "; updateResolvedTypeName(context, typeName); } } template<typename T, typename... Args> - static void updateArgumentTypes(std::vector<const NumericBase*> &args) + static void updateArgumentTypes(std::vector<const Base*> &args) { // Add argument typename to string args.push_back(T::getInstance()); // If there are more arguments left in pack, recurse - if constexpr (sizeof...(Args)) { + if constexpr (sizeof...(Args) > 0) { updateArgumentTypes(args); } } - }; //---------------------------------------------------------------------------- @@ -392,11 +410,84 @@ DECLARE_NUMERIC_TYPE(Float, float, 50, "f"); DECLARE_NUMERIC_TYPE(Double, double, 60, ""); //---------------------------------------------------------------------------- -// Declare standard library foreign function types +// Declare standard library function types //---------------------------------------------------------------------------- -DECLARE_FOREIGN_FUNCTION_TYPE(Exp, Double, Double); -DECLARE_FOREIGN_FUNCTION_TYPE(Sqrt, Double, Double); - +// Trigonometric functions +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Cos); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Sin); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Tan); + +// Inverse trigonometric functions +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Acos); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Asin); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Atan); +DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Atan2); + +// Hyperbolic functions +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Cosh); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Sinh); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Tanh); + +// Inverse Hyperbolic functions +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Acosh); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Asinh); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Atanh); + +// Exponential functions +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Exp); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(ExpM1); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Exp2); +DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Pow); +DECLARE_FUNCTION_TYPE(ScalBNF, Float, Float, Int32); +DECLARE_FUNCTION_TYPE(ScalBND, Double, Double, Int32); + +// Logarithm functions +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Log); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Log1P); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Log2); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Log10); +DECLARE_FUNCTION_TYPE(LdExpF, Float, Float, Int32); +DECLARE_FUNCTION_TYPE(LdExpD, Double, Double, Int32); +DECLARE_FUNCTION_TYPE(ILogBF, Int32, Float); 
+DECLARE_FUNCTION_TYPE(ILogBD, Int32, Double); + +// Root functions +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Sqrt); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Cbrt); +DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Hypot); + +// Rounding functions +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Ceil); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Floor); +DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Fmod); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Round); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Rint); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Trunc); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(NearbyInt); +DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(NextAfter); +DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Remainder); + +// Range functions +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(FAbs); +DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(FDim); +DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(FMax); +DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(FMin); + +// Other functions +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Erf); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(ErfC); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(TGamma); +DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(LGamma); +DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(CopySign); +DECLARE_THREE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(FMA); +/*{, +{"frexp", "frexpf"}, // pointer arguments +{"modf", "modff"}, // pointer arguments +{"scalbln", "scalblnf"}, // long type +{"lround", "lroundf"}, // long return type +{"lrint", "lrintf"}, // long return type +{"remquo", "remquof"}, // pointer arguments +*/ //! Parse a numeric type const NumericBase *parseNumeric(std::string_view typeString); //! Look up numeric type based on set of type specifiers const NumericBase *getNumericType(const std::set<std::string_view> &typeSpecifiers); //! Apply C type promotion rules to numeric type const NumericBase *getPromotedType(const NumericBase *type, const TypeContext &context);
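
Each DECLARE_FUNCTION_TYPE above expands to a singleton class whose Function<ReturnType, ArgTypes...> base exposes the signature at runtime. A rough usage sketch, assuming the ExpF type declared in this hunk and the accessors defined earlier in the patch (the inspect function itself is illustrative, not repository code):

#include <cassert>

void inspectExpF()
{
    // ExpF models expf: a function taking and returning Float
    const auto *expF = GeNN::Type::ExpF::getInstance();
    const auto *returnType = expF->getReturnType();      // Float singleton
    const auto argumentTypes = expF->getArgumentTypes(); // one entry: Float
    assert(argumentTypes.size() == 1);
    (void)returnType;
}
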
functions +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Acosh); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Asinh); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Atanh); + +// Implement exponential functions +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Exp); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(ExpM1); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Exp2); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Pow); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(ScalBN); + +// Implement logarithm functions +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Log); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Log1P); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Log2); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Log10); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(LdExp); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(ILogB); + +// Implement root functions +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Sqrt); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Cbrt); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Hypot); + +// Implement rounding functions +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Ceil); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Floor); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Fmod); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Round); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Rint); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Trunc); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(NearbyInt); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(NextAfter); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Remainder); + +// Implement range functions +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(FAbs); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(FDim); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(FMax); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(FMin); + +// Implement other functions +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Erf); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(ErfC); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(TGamma); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(LGamma); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(CopySign); +IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(FMA); -// Implement foreign function types -IMPLEMENT_TYPE(Exp); -IMPLEMENT_TYPE(Sqrt); //---------------------------------------------------------------------------- // GeNN::Type::Base From 8bc03a46502e9f54846a9c439fc72be832db13d3 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 27 Jan 2023 11:46:13 +0000 Subject: [PATCH 107/725] removed some more std::string_view --- include/genn/genn/type.h | 4 ++-- src/genn/genn/transpiler/parser.cc | 10 +++++----- src/genn/genn/type.cc | 8 ++++---- tests/unit/typeChecker.cc | 6 +++--- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 29a6d9fb44..a94a48a143 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -489,10 +489,10 @@ DECLARE_THREE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(FMA); {"remquo", "remquof"}, // pointer arguments */ //! Parse a numeric type -const NumericBase *parseNumeric(std::string_view typeString); +const NumericBase *parseNumeric(const std::string &typeString); //! Look up numeric type based on set of type specifiers -const NumericBase *getNumericType(const std::set &typeSpecifiers); +const NumericBase *getNumericType(const std::set &typeSpecifiers); //! 
Apply C type promotion rules to numeric type const NumericBase *getPromotedType(const NumericBase *type, const TypeContext &context); diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 81863cc0e4..6002a13311 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -187,9 +187,9 @@ const GeNN::Type::Base *parseDeclarationSpecifiers(ParserState &parserState) { using namespace GeNN::Type; - std::set typeSpecifiers; - std::set typeQualifiers; - std::vector> pointerTypeQualifiers; + std::set typeSpecifiers; + std::set typeQualifiers; + std::vector> pointerTypeQualifiers; do { // If token is a star, add new set of pointer type qualifiers @@ -199,7 +199,7 @@ const GeNN::Type::Base *parseDeclarationSpecifiers(ParserState &parserState) // Otherwise, if type is a qualifier else if(parserState.previous().type == Token::Type::TYPE_QUALIFIER) { // Add qualifier lexeme to correct list - std::set &qualifiers = pointerTypeQualifiers.empty() ? typeQualifiers : pointerTypeQualifiers.back(); + std::set &qualifiers = pointerTypeQualifiers.empty() ? typeQualifiers : pointerTypeQualifiers.back(); if(!qualifiers.insert(parserState.previous().lexeme).second) { parserState.error(parserState.previous(), "duplicate type qualifier"); } @@ -874,7 +874,7 @@ Statement::StatementList parseBlockItemList(const std::vector &tokens, Er const GeNN::Type::NumericBase *parseNumericType(const std::vector &tokens, ErrorHandlerBase &errorHandler) { ParserState parserState(tokens, errorHandler); - std::set typeSpecifiers; + std::set typeSpecifiers; while(parserState.match(Token::Type::TYPE_SPECIFIER)) { if(!typeSpecifiers.insert(parserState.previous().lexeme).second) { parserState.error(parserState.previous(), "duplicate type specifier"); diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index fe1ec91e1d..e28046c049 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -17,7 +17,7 @@ using namespace GeNN; // Anonymous namespace namespace { -const std::map, const Type::NumericBase*> numericTypeSpecifiers{ +const std::map, const Type::NumericBase*> numericTypeSpecifiers{ {{"char"}, Type::Int8::getInstance()}, {{"int8_t"}, Type::Int8::getInstance()}, @@ -47,7 +47,7 @@ const std::map, const Type::NumericBase*> numericType {{"double"}, Type::Double::getInstance()}, }; //---------------------------------------------------------------------------- -const std::set scalarTypeSpecifier{{"scalar"}}; +const std::set scalarTypeSpecifier{{"scalar"}}; //---------------------------------------------------------------------------- // Mapping of signed integer numericTypeSpecifiers to their unsigned equivalents const std::unordered_map unsignedType{ @@ -227,7 +227,7 @@ const Type::NumericBase *NumericTypedef::getResolvedType(const TypeContext &cont //---------------------------------------------------------------------------- // Free functions //---------------------------------------------------------------------------- -const NumericBase *parseNumeric(std::string_view typeString) +const NumericBase *parseNumeric(const std::string &typeString) { using namespace Transpiler; @@ -250,7 +250,7 @@ const NumericBase *parseNumeric(std::string_view typeString) return type; } //---------------------------------------------------------------------------- -const NumericBase *getNumericType(const std::set &typeSpecifiers) +const NumericBase *getNumericType(const std::set &typeSpecifiers) { // If type matches scalar type specifiers if(typeSpecifiers == 
scalarTypeSpecifier) {
diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index 6406b50acc..9956894ebf 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -57,7 +57,7 @@ class TestEnvironment : public TypeChecker::EnvironmentBase
     //---------------------------------------------------------------------------
     // Public API
     //---------------------------------------------------------------------------
-    void define(std::string_view name, const Type::Base *type)
+    void define(const std::string &name, const Type::Base *type)
     {
         if(!m_Types.try_emplace(name, type).second) {
             throw std::runtime_error("Redeclaration of '" + std::string{name} + "'");
@@ -65,13 +65,13 @@
     }
 
     template
-    void define(std::string_view name, Type::Qualifier qualifiers = Type::Qualifier{0})
+    void define(const std::string &name, Type::Qualifier qualifiers = Type::Qualifier{0})
    {
         define(name, T::getInstance()->getQualifiedType(qualifiers));
     }
 
     template
-    void definePointer(std::string_view name, Type::Qualifier valueQualifiers = Type::Qualifier{0},
+    void definePointer(const std::string &name, Type::Qualifier valueQualifiers = Type::Qualifier{0},
                        Type::Qualifier pointerQualifiers = Type::Qualifier{0})
     {
         define(name, T::getInstance()->getQualifiedType(valueQualifiers)->getPointerType(pointerQualifiers));

From 12da2dd74752f1cce6bc2d17ca2cbbaf599d9db4 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Fri, 27 Jan 2023 11:46:45 +0000
Subject: [PATCH 108/725] type checker now builds a map of expressions to
 their types rather than only keeping types temporarily

---
 include/genn/genn/transpiler/typeChecker.h |   2 +
 src/genn/genn/transpiler/typeChecker.cc    | 112 ++++++++++++---------
 2 files changed, 66 insertions(+), 48 deletions(-)

diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h
index e5df2098a0..0951d7cf06 100644
--- a/include/genn/genn/transpiler/typeChecker.h
+++ b/include/genn/genn/transpiler/typeChecker.h
@@ -31,6 +31,8 @@ class TypeCheckError : public std::runtime_error
     }
 };
 
+typedef std::unordered_map<const Expression::Base*, const Type::Base*> ResolvedTypeMap;
+
 //---------------------------------------------------------------------------
 // GeNN::Transpiler::TypeChecker::EnvironmentBase
 //---------------------------------------------------------------------------
diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index 06b521fd92..da38b9d3c1 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -143,8 +143,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
 {
 public:
     Visitor(const Statement::StatementList &statements, const Type::TypeContext &context,
-            EnvironmentInternal &environment, ErrorHandlerBase &errorHandler)
-    :   Visitor(context, environment, errorHandler)
+            EnvironmentInternal &environment, ResolvedTypeMap &resolvedTypes, ErrorHandlerBase &errorHandler)
+    :   Visitor(context, environment, resolvedTypes, errorHandler)
     {
         for (auto &s : statements) {
             s.get()->accept(*this);
         }
     }
 
     Visitor(const Expression::Base *expression, const Type::TypeContext &context,
-            EnvironmentInternal &environment, ErrorHandlerBase &errorHandler)
-    :   Visitor(context, environment, errorHandler)
+            EnvironmentInternal &environment, ResolvedTypeMap &resolvedTypes, ErrorHandlerBase &errorHandler)
+    :   Visitor(context, environment, resolvedTypes, 
errorHandler) { expression->accept(*this); } - //--------------------------------------------------------------------------- - // Public API - //--------------------------------------------------------------------------- - const Type::Base *getType() const{ return m_Type; } - private: - Visitor(const Type::TypeContext &context, EnvironmentInternal &environment, ErrorHandlerBase &errorHandler) - : m_Environment(environment), m_Type(nullptr), m_Context(context), m_ErrorHandler(errorHandler), - m_InLoop(false), m_InSwitch(false) + Visitor(const Type::TypeContext &context, EnvironmentInternal &environment, + ResolvedTypeMap &resolvedTypes, ErrorHandlerBase &errorHandler) + : m_Environment(environment), m_Context(context), m_ErrorHandler(errorHandler), + m_ResolvedTypes(resolvedTypes), m_InLoop(false), m_InSwitch(false) { } @@ -191,7 +187,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // Use value type of array - m_Type = pointerType->getValueType(); + setExpressionType(&arraySubscript, pointerType->getValueType()); } // Otherwise else { @@ -203,8 +199,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Assignment &assignment) final { const auto rhsType = evaluateType(assignment.getValue()); - m_Type = m_Environment.get().assign(assignment.getVarName(), assignment.getOperator().type, rhsType, - m_Context, m_ErrorHandler); + setExpressionType(&assignment, + m_Environment.get().assign(assignment.getVarName(), assignment.getOperator().type, rhsType, + m_Context, m_ErrorHandler)); } virtual void visit(const Expression::Binary &binary) final @@ -212,7 +209,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const auto opType = binary.getOperator().type; const auto rightType = evaluateType(binary.getRight()); if (opType == Token::Type::COMMA) { - m_Type = rightType; + setExpressionType(&binary, rightType); } else { // If we're subtracting two pointers @@ -229,7 +226,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // **TODO** should be std::ptrdiff/Int64 - m_Type = Type::Int32::getInstance(); + setExpressionType(&binary); } // Otherwise, if we're adding to or subtracting from pointers else if (leftPointerType && rightNumericType && (opType == Token::Type::PLUS || opType == Token::Type::MINUS)) // P + n or P - n @@ -241,7 +238,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // Use left type - m_Type = leftType; + setExpressionType(&binary, leftType); } // Otherwise, if we're adding a number to a pointer else if (leftNumericType && rightPointerType && opType == Token::Type::PLUS) // n + P @@ -253,7 +250,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // Use right type - m_Type = leftType; + setExpressionType(&binary, rightType); } // Otherwise, if both operands are numeric else if (leftNumericType && rightNumericType) { @@ -271,16 +268,16 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If operator is a shift, promote left type if (opType == Token::Type::SHIFT_LEFT || opType == Token::Type::SHIFT_RIGHT) { - m_Type = Type::getPromotedType(leftNumericType, m_Context); + setExpressionType(&binary, Type::getPromotedType(leftNumericType, m_Context)); } // Otherwise, take common type else { - m_Type = Type::getCommonType(leftNumericType, rightNumericType, m_Context); + setExpressionType(&binary, Type::getCommonType(leftNumericType, rightNumericType, m_Context)); } } // Otherwise, any 
numeric type will do, take common type else { - m_Type = Type::getCommonType(leftNumericType, rightNumericType, m_Context); + setExpressionType(&binary, Type::getCommonType(leftNumericType, rightNumericType, m_Context)); } } else { @@ -316,7 +313,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto callArgType = evaluateType(call.getArguments().at(i).get()); }*/ // Type is return type of function - m_Type = calleeFunctionType->getReturnType(); + setExpressionType(&call, calleeFunctionType->getReturnType()); } } // Otherwise @@ -355,7 +352,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor throw TypeCheckError(); } - m_Type = cast.getType(); + setExpressionType(&cast, cast.getType()); } virtual void visit(const Expression::Conditional &conditional) final @@ -366,10 +363,11 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto falseNumericType = dynamic_cast(falseType); if (trueNumericType && falseNumericType) { // **TODO** check behaviour - m_Type = Type::getCommonType(trueNumericType, falseNumericType, m_Context); + const Type::Base *type = Type::getCommonType(trueNumericType, falseNumericType, m_Context); if(trueType->hasQualifier(Type::Qualifier::CONSTANT) || falseType->hasQualifier(Type::Qualifier::CONSTANT)) { - m_Type = m_Type->getQualifiedType(Type::Qualifier::CONSTANT); + type = type->getQualifiedType(Type::Qualifier::CONSTANT); } + setExpressionType(&conditional, type); } else { m_ErrorHandler.error(conditional.getQuestion(), @@ -380,7 +378,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Grouping &grouping) final { - m_Type = evaluateType(grouping.getExpression()); + setExpressionType(&grouping, evaluateType(grouping.getExpression())); } virtual void visit(const Expression::Literal &literal) final @@ -388,20 +386,20 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // Convert number token type to type // **THINK** is it better to use typedef for scalar or resolve from m_Context if (literal.getValue().type == Token::Type::DOUBLE_NUMBER) { - m_Type = Type::Double::getInstance(); + setExpressionType(&literal); } else if (literal.getValue().type == Token::Type::FLOAT_NUMBER) { - m_Type = Type::Float::getInstance(); + setExpressionType(&literal); } else if (literal.getValue().type == Token::Type::SCALAR_NUMBER) { // **TODO** cache - m_Type = new Type::NumericTypedef("scalar"); + setExpressionType(&literal, new Type::NumericTypedef("scalar")); } else if (literal.getValue().type == Token::Type::INT32_NUMBER) { - m_Type = Type::Int32::getInstance(); + setExpressionType(&literal); } else if (literal.getValue().type == Token::Type::UINT32_NUMBER) { - m_Type = Type::Uint32::getInstance(); + setExpressionType(&literal); } else { assert(false); @@ -412,24 +410,26 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { logical.getLeft()->accept(*this); logical.getRight()->accept(*this); - m_Type = Type::Int32::getInstance(); + setExpressionType(&logical); } virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final { - m_Type = m_Environment.get().incDec(postfixIncDec.getVarName(), postfixIncDec.getOperator().type, - m_Context, m_ErrorHandler); + setExpressionType(&postfixIncDec, + m_Environment.get().incDec(postfixIncDec.getVarName(), postfixIncDec.getOperator().type, + m_Context, m_ErrorHandler)); } virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final { - m_Type = 
m_Environment.get().incDec(prefixIncDec.getVarName(), prefixIncDec.getOperator().type, - m_Context, m_ErrorHandler); + setExpressionType(&prefixIncDec, + m_Environment.get().incDec(prefixIncDec.getVarName(), prefixIncDec.getOperator().type, + m_Context, m_ErrorHandler)); } virtual void visit(const Expression::Variable &variable) { - m_Type = m_Environment.get().getType(variable.getName(), m_ErrorHandler); + setExpressionType(&variable, m_Environment.get().getType(variable.getName(), m_ErrorHandler)); } virtual void visit(const Expression::Unary &unary) final @@ -446,7 +446,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // Return value type - m_Type = rightPointerType->getValueType(); + setExpressionType(&unary, rightPointerType->getValueType()); } // Otherwise else { @@ -455,14 +455,14 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If operator is arithmetic, return promoted type if (unary.getOperator().type == Token::Type::PLUS || unary.getOperator().type == Token::Type::MINUS) { // **THINK** const through these? - m_Type = Type::getPromotedType(rightNumericType, m_Context); + setExpressionType(&unary, Type::getPromotedType(rightNumericType, m_Context)); } // Otherwise, if operator is bitwise else if (unary.getOperator().type == Token::Type::TILDA) { // If type is integer, return promoted type if (rightNumericType->isIntegral(m_Context)) { // **THINK** const through these? - m_Type = Type::getPromotedType(rightNumericType, m_Context); + setExpressionType(&unary, Type::getPromotedType(rightNumericType, m_Context)); } else { m_ErrorHandler.error(unary.getOperator(), @@ -472,11 +472,11 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // Otherwise, if operator is logical else if (unary.getOperator().type == Token::Type::NOT) { - m_Type = Type::Int32::getInstance();; + setExpressionType(&unary); } // Otherwise, if operator is address of, return pointer type else if (unary.getOperator().type == Token::Type::AMPERSAND) { - m_Type = rightType->getPointerType(); + setExpressionType(&unary, rightType->getPointerType()); } } else { @@ -644,16 +644,30 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const Type::Base *evaluateType(const Expression::Base *expression) { expression->accept(*this); - return m_Type; + return m_ResolvedTypes.at(expression); } + void setExpressionType(const Expression::Base *expression, const Type::Base *type) + { + if (!m_ResolvedTypes.emplace(expression, type).second) { + throw std::runtime_error("Expression type resolved multiple times"); + } + } + + template + void setExpressionType(const Expression::Base *expression) + { + if (!m_ResolvedTypes.emplace(expression, T::getInstance()).second) { + throw std::runtime_error("Expression type resolved multiple times"); + } + } //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- std::reference_wrapper m_Environment; - const Type::Base *m_Type; const Type::TypeContext &m_Context; ErrorHandlerBase &m_ErrorHandler; + ResolvedTypeMap &m_ResolvedTypes; bool m_InLoop; bool m_InSwitch; }; @@ -764,14 +778,16 @@ const Type::Base *EnvironmentBase::incDec(const Token &name, Token::Type, void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) { + ResolvedTypeMap expressionTypes; 
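+    // The visitor below fills expressionTypes as a side-effect, mapping each
+    // visited expression node to its resolved type; this overload simply
+    // discards the map, so it is only used to report type-check errors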
EnvironmentInternal internalEnvironment(environment); - Visitor(statements, context, internalEnvironment, errorHandler); + Visitor(statements, context, internalEnvironment, expressionTypes, errorHandler); } //--------------------------------------------------------------------------- const Type::Base *GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, EnvironmentBase &environment, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) { + ResolvedTypeMap expressionTypes; EnvironmentInternal internalEnvironment(environment); - Visitor visitor(expression, context, internalEnvironment, errorHandler); - return visitor.getType(); + Visitor visitor(expression, context, internalEnvironment, expressionTypes, errorHandler); + return expressionTypes.at(expression); } From 697a6ec8c7bf617f675b8f60be7cdec57a602d8f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 27 Jan 2023 12:06:12 +0000 Subject: [PATCH 109/725] added some binary operator type-checking tests --- tests/unit/typeChecker.cc | 80 +++++++++++++++++++++++++++++++++------ 1 file changed, 68 insertions(+), 12 deletions(-) diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index 9956894ebf..925bdb387e 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -254,6 +254,74 @@ TEST(TypeChecker, Assignment) //-------------------------------------------------------------------------- TEST(TypeChecker, Binary) { + // Pointer difference + { + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray1"); + typeEnvironment.definePointer("intArray2"); + const auto *type = typeCheckExpression("intArray1 - intArray2", typeEnvironment); + EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); + } + + // **TODO** different pointer types + + + // Pointer + integer + { + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray"); + typeEnvironment.define("offset"); + const auto *type = typeCheckExpression("intArray + offset", typeEnvironment); + const auto *pointerType = dynamic_cast(type); + EXPECT_TRUE(pointerType); + EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); + } + + // **TODO** constness and + + // Pointer + non-integer + EXPECT_THROW({ + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray"); + typeEnvironment.define("offset"); + typeCheckExpression("intArray + offset", typeEnvironment);}, + TypeChecker::TypeCheckError); + + // Pointer + pointer + EXPECT_THROW({ + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray1"); + typeEnvironment.definePointer("intArray2"); + typeCheckExpression("intArray1 + intArray2", typeEnvironment);}, + TypeChecker::TypeCheckError); + + + // Pointer - integer + { + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray"); + typeEnvironment.define("offset"); + const auto *type = typeCheckExpression("intArray - offset", typeEnvironment); + const auto *pointerType = dynamic_cast(type); + EXPECT_TRUE(pointerType); + EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); + } + + // Integer + pointer + { + TestEnvironment typeEnvironment; + typeEnvironment.definePointer("intArray"); + typeEnvironment.define("offset"); + const auto *type = typeCheckExpression("offset + intArray", typeEnvironment); + const auto *pointerType = dynamic_cast(type); + EXPECT_TRUE(pointerType); + EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); 
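+
+        // A sketch for the "different pointer types" TODO above (hypothetical,
+        // not covered by this patch): arithmetic mixing pointers to different
+        // value types, e.g.
+        //   typeEnvironment.definePointer("floatArray");
+        //   typeCheckExpression("floatArray - intArray", typeEnvironment);
+        // would be expected to throw TypeChecker::TypeCheckError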
+ } + + /*integer only (opType == Token::Type::PERCENT || opType == Token::Type::SHIFT_LEFT + || opType == Token::Type::SHIFT_RIGHT || opType == Token::Type::CARET + || opType == Token::Type::AMPERSAND || opType == Token::Type::PIPE)*/ + } //-------------------------------------------------------------------------- TEST(TypeChecker, Call) @@ -408,8 +476,6 @@ TEST(TypeChecker, Literal) TestEnvironment typeEnvironment; const auto *type = typeCheckExpression("1.0f", typeEnvironment); EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName()); - //EXPECT_TRUE(type.constValue); - //EXPECT_FALSE(type.constPointer); } // Scalar with single-precision @@ -418,8 +484,6 @@ TEST(TypeChecker, Literal) const Type::TypeContext typeContext{{"scalar", Type::Float::getInstance()}}; const auto *type = typeCheckExpression("1.0", typeEnvironment); EXPECT_EQ(type->getResolvedName(typeContext), Type::Float::getInstance()->getName()); - //EXPECT_TRUE(type.constValue); - //EXPECT_FALSE(type.constPointer); } // Scalar with double-precision @@ -428,8 +492,6 @@ TEST(TypeChecker, Literal) const Type::TypeContext typeContext{{"scalar", Type::Double::getInstance()}}; const auto *type = typeCheckExpression("1.0", typeEnvironment, typeContext); EXPECT_EQ(type->getResolvedName(typeContext), Type::Double::getInstance()->getName()); - //EXPECT_TRUE(type.constValue); - //EXPECT_FALSE(type.constPointer); } // Double @@ -437,8 +499,6 @@ TEST(TypeChecker, Literal) TestEnvironment typeEnvironment; const auto *type = typeCheckExpression("1.0d", typeEnvironment); EXPECT_EQ(type->getName(), Type::Double::getInstance()->getName()); - //EXPECT_TRUE(type.constValue); - //EXPECT_FALSE(type.constPointer); } // Integer @@ -446,8 +506,6 @@ TEST(TypeChecker, Literal) TestEnvironment typeEnvironment; const auto *type = typeCheckExpression("100", typeEnvironment); EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); - //EXPECT_TRUE(type.constValue); - //EXPECT_FALSE(type.constPointer); } // Unsigned integer @@ -455,8 +513,6 @@ TEST(TypeChecker, Literal) TestEnvironment typeEnvironment; const auto *type = typeCheckExpression("100U", typeEnvironment); EXPECT_EQ(type->getName(), Type::Uint32::getInstance()->getName()); - //EXPECT_TRUE(type.constValue); - //EXPECT_FALSE(type.constPointer); } } //-------------------------------------------------------------------------- From 933c8603f4ea164e6e0f063d15b5bb1227567af8 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 27 Jan 2023 12:06:46 +0000 Subject: [PATCH 110/725] tidy --- src/genn/genn/transpiler/typeChecker.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index da38b9d3c1..7e6606a833 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -229,8 +229,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor setExpressionType(&binary); } // Otherwise, if we're adding to or subtracting from pointers - else if (leftPointerType && rightNumericType && (opType == Token::Type::PLUS || opType == Token::Type::MINUS)) // P + n or P - n - { + else if (leftPointerType && rightNumericType && (opType == Token::Type::PLUS || opType == Token::Type::MINUS)) { // P + n or P - n // Check that numeric operand is integer if (!rightNumericType->isIntegral(m_Context)) { m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); @@ -241,8 +240,7 @@ class 
Visitor : public Expression::Visitor, public Statement::Visitor
             setExpressionType(&binary, leftType);
         }
         // Otherwise, if we're adding a number to a pointer
-        else if (leftNumericType && rightPointerType && opType == Token::Type::PLUS) // n + P
-        {
+        else if (leftNumericType && rightPointerType && opType == Token::Type::PLUS) {   // n + P
             // Check that numeric operand is integer
             if (!leftNumericType->isIntegral(m_Context)) {
                 m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName());

From 630a457f4bec01fce3be517985b420832b658eee Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 30 Jan 2023 10:38:12 +0000
Subject: [PATCH 111/725] moved new environment classes into separate
 compilation unit

---
 .../genn/genn/code_generator/environment.h    | 201 +++++++++++++++++
 .../code_generator/customUpdateGroupMerged.cc | 207 +-----------------
 src/genn/genn/code_generator/environment.cc   |  52 +++++
 src/genn/genn/genn.vcxproj                    |   2 +
 4 files changed, 256 insertions(+), 206 deletions(-)
 create mode 100644 include/genn/genn/code_generator/environment.h
 create mode 100644 src/genn/genn/code_generator/environment.cc

diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h
new file mode 100644
index 0000000000..2477c01edf
--- /dev/null
+++ b/include/genn/genn/code_generator/environment.h
@@ -0,0 +1,201 @@
+#pragma once
+
+// Standard C++ includes
+#include 
+#include 
+#include 
+
+// GeNN code generator includes
+#include "code_generator/codeStream.h"
+
+// GeNN transpiler includes
+#include "transpiler/prettyPrinter.h"
+#include "transpiler/token.h"
+#include "transpiler/transpilerUtils.h"
+
+//----------------------------------------------------------------------------
+// GeNN::CodeGenerator::EnvironmentExternal
+//----------------------------------------------------------------------------
+namespace GeNN::CodeGenerator
+{
+class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase
+{
+    using EnvironmentBase = Transpiler::PrettyPrinter::EnvironmentBase;
+public:
+    EnvironmentExternal(EnvironmentBase &enclosing)
+    :   m_Context(enclosing)
+    {
+    }
+
+    EnvironmentExternal(CodeStream &os)
+    :   m_Context(os)
+    {
+    }
+
+    //------------------------------------------------------------------------
+    // PrettyPrinter::EnvironmentBase virtuals
+    //------------------------------------------------------------------------
+    virtual std::string define(const Transpiler::Token&);
+
+protected:
+    //------------------------------------------------------------------------
+    // Protected API
+    //------------------------------------------------------------------------
+    auto &getContext() const{ return m_Context; }
+
+    CodeStream &getContextStream() const;
+
+    std::string getContextName(const Transpiler::Token &name) const;
+
+private:
+    //------------------------------------------------------------------------
+    // Members
+    //------------------------------------------------------------------------
+    std::variant, std::reference_wrapper> m_Context;
+};
+
+//----------------------------------------------------------------------------
+// GeNN::CodeGenerator::EnvironmentSubstitute
+//----------------------------------------------------------------------------
+//! 
Standard pretty printing environment simply allowing substitutions to be implemented +class EnvironmentSubstitute : public EnvironmentExternal +{ + using EnvironmentBase = Transpiler::PrettyPrinter::EnvironmentBase; +public: + EnvironmentSubstitute(EnvironmentBase &enclosing) : EnvironmentExternal(enclosing){} + EnvironmentSubstitute(CodeStream &os) : EnvironmentExternal(os){} + + //------------------------------------------------------------------------ + // PrettyPrinter::EnvironmentBase virtuals + //------------------------------------------------------------------------ + virtual std::string getName(const Transpiler::Token &name) final; + + virtual CodeStream &getStream() final + { + return getContextStream(); + } + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + void addSubstitution(const std::string &source, const std::string &destination); + +private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + std::unordered_map m_VarSubstitutions; +}; + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::EnvironmentLocalVarCache +//---------------------------------------------------------------------------- +//! Pretty printing environment which caches used variables in local variables +template +class EnvironmentLocalVarCache : public EnvironmentExternal +{ + //! Type of a single definition + typedef typename std::invoke_result_t::value_type DefType; + + //! Type of a single initialiser + typedef typename std::remove_reference_t>::mapped_type InitialiserType; + + //! 
Function used to provide index strings based on initialiser and access type
+    typedef std::function GetIndexFn;
+
+public:
+    EnvironmentLocalVarCache(const G &group, Transpiler::PrettyPrinter::EnvironmentBase &enclosing, GetIndexFn getIndex, const std::string &localPrefix = "l")
+    :   EnvironmentExternal(enclosing), m_Group(group), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex)
+    {
+        // Add name of each definition to map, initially with value set to false
+        const auto defs = A(m_Group).getDefs();
+        std::transform(defs.cbegin(), defs.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()),
+                       [](const auto &v){ return std::make_pair(v.name, false); });
+    }
+
+    EnvironmentLocalVarCache(const G &group, CodeStream &os, GetIndexFn getIndex, const std::string &localPrefix = "l")
+    :   EnvironmentExternal(os), m_Group(group), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex)
+    {
+        // Add name of each definition to map, initially with value set to false
+        const auto defs = A(m_Group).getDefs();
+        std::transform(defs.cbegin(), defs.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()),
+                       [](const auto &v){ return std::make_pair(v.name, false); });
+    }
+
+    ~EnvironmentLocalVarCache()
+    {
+        A adapter(m_Group);
+
+        // Copy definitions which have been referenced into new vector
+        const auto defs = adapter.getDefs();
+        std::remove_const_t referencedVars;
+        std::copy_if(defs.cbegin(), defs.cend(), std::back_inserter(referencedVars),
+                     [this](const auto &v){ return m_VariablesReferenced.at(v.name); });
+
+        // Loop through referenced variables
+        const auto &initialisers = adapter.getInitialisers();
+        for(const auto &v : referencedVars) {
+            if(v.access & VarAccessMode::READ_ONLY) {
+                getContextStream() << "const ";
+            }
+            getContextStream() << v.type->getName() << " " << m_LocalPrefix << v.name;
+
+            // If this isn't a reduction, read value from memory
+            // **NOTE** by not initialising these variables for reductions,
+            // compilers SHOULD emit a warning if user code doesn't set it to something
+            if(!(v.access & VarAccessModeAttribute::REDUCE)) {
+                getContextStream() << " = group->" << v.name << "[" << m_GetIndex(initialisers.at(v.name), v.access) << "]";
+            }
+            getContextStream() << ";" << std::endl;
+        }
+
+        // Write contents to context stream
+        getContextStream() << m_ContentsStream.str();
+
+        // Loop through referenced variables again
+        for(const auto &v : referencedVars) {
+            // If variables are read-write
+            if(v.access & VarAccessMode::READ_WRITE) {
+                getContextStream() << "group->" << v.name << "[" << m_GetIndex(initialisers.at(v.name), v.access) << "]";
+                getContextStream() << " = " << m_LocalPrefix << v.name << ";" << std::endl;
+            }
+        }
+    }
+
+    //------------------------------------------------------------------------
+    // PrettyPrinter::EnvironmentBase virtuals
+    //------------------------------------------------------------------------
+    virtual std::string getName(const Transpiler::Token &name) final
+    {
+        // If variable with this name isn't found, try and get name from context
+        auto var = m_VariablesReferenced.find(name.lexeme);
+        if(var == m_VariablesReferenced.end()) {
+            return getContextName(name);
+        }
+        // Otherwise
+        else {
+            // Set flag to indicate that variable has been referenced
+            var->second = true;
+
+            // Add local prefix to variable name
+            return m_LocalPrefix + name.lexeme;
+        }
+    }
+
+    virtual CodeStream &getStream() final
+    {
+        return m_Contents;
+    }
+
+private:
+    
//------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + const G &m_Group; + std::ostringstream m_ContentsStream; + CodeStream m_Contents; + const std::string m_LocalPrefix; + const GetIndexFn m_GetIndex; + std::unordered_map m_VariablesReferenced; +}; +} // namespace GeNN::CodeGenerator diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 9131f68562..17d4acd839 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -4,6 +4,7 @@ #include // GeNN code generator includes +#include "code_generator/environment.h" #include "code_generator/groupMergedTypeEnvironment.h" #include "code_generator/modelSpecMerged.h" @@ -26,212 +27,6 @@ using namespace GeNN::Transpiler; //-------------------------------------------------------------------------- namespace { -class EnvironmentExternal : public PrettyPrinter::EnvironmentBase -{ -public: - EnvironmentExternal(PrettyPrinter::EnvironmentBase &enclosing) - : m_Context(enclosing) - { - } - - EnvironmentExternal(CodeStream &os) - : m_Context(os) - { - } - - //------------------------------------------------------------------------ - // PrettyPrinter::EnvironmentBase virtuals - //------------------------------------------------------------------------ - virtual std::string define(const Token&) - { - throw std::runtime_error("Cannot declare variable in external environment"); - } - -protected: - //------------------------------------------------------------------------ - // Protected API - //------------------------------------------------------------------------ - auto &getContext() const{ return m_Context; } - - CodeStream &getContextStream() const - { - return std::visit( - Transpiler::Utils::Overload{ - [](std::reference_wrapper enclosing)->CodeStream& { return enclosing.get().getStream(); }, - [](std::reference_wrapper os)->CodeStream& { return os.get(); }}, - getContext()); - } - - std::string getContextName(const Token &name) const - { - return std::visit( - Transpiler::Utils::Overload{ - [&name](std::reference_wrapper enclosing)->std::string { return enclosing.get().getName(name); }, - [&name](std::reference_wrapper)->std::string { throw std::runtime_error("Variable '" + name.lexeme + "' undefined"); }}, - getContext()); - } - -private: - //------------------------------------------------------------------------ - // Members - //------------------------------------------------------------------------ - std::variant, std::reference_wrapper> m_Context; -}; - -//! 
Standard pretty printing environment simply allowing substitutions to be implemented -class EnvironmentSubstitute : public EnvironmentExternal -{ -public: - EnvironmentSubstitute(PrettyPrinter::EnvironmentBase &enclosing) : EnvironmentExternal(enclosing){} - EnvironmentSubstitute(CodeStream &os) : EnvironmentExternal(os){} - - //------------------------------------------------------------------------ - // PrettyPrinter::EnvironmentBase virtuals - //------------------------------------------------------------------------ - virtual std::string getName(const Token &name) final - { - // If there isn't a substitution for this name, try and get name from context - auto sub = m_VarSubstitutions.find(name.lexeme); - if(sub == m_VarSubstitutions.end()) { - return getContextName(name); - } - // Otherwise, return substitution - else { - return sub->second; - } - } - - virtual CodeStream &getStream() final - { - return getContextStream(); - } - - //------------------------------------------------------------------------ - // Public API - //------------------------------------------------------------------------ - void addSubstitution(const std::string &source, const std::string &destination) - { - if(!m_VarSubstitutions.emplace(source, destination).second) { - throw std::runtime_error("Redeclaration of substitution '" + source + "'"); - } - } - -private: - //------------------------------------------------------------------------ - // Members - //------------------------------------------------------------------------ - std::unordered_map m_VarSubstitutions; -}; - -//! Pretty printing environment which caches used variables in local variables -template -class EnvironmentLocalVarCache : public EnvironmentExternal -{ - //! Type of a single definition - typedef typename std::invoke_result_t::value_type DefType; - - //! Type of a single initialiser - typedef typename std::remove_reference_t>::mapped_type InitialiserType; - - //! 
Function used to provide index strings based on initialiser and access type - typedef std::function GetIndexFn; - -public: - EnvironmentLocalVarCache(const G &group, PrettyPrinter::EnvironmentBase &enclosing, GetIndexFn getIndex, const std::string &localPrefix = "l") - : EnvironmentExternal(enclosing), m_Group(group), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) - { - // Add name of each definition to map, initially with value set to value - const auto defs = A(m_Group).getDefs(); - std::transform(defs.cbegin(), defs.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()), - [](const auto &v){ return std::make_pair(v.name, false); }); - } - - EnvironmentLocalVarCache(const G &group, CodeStream &os, GetIndexFn getIndex, const std::string &localPrefix = "l") - : EnvironmentExternal(os), m_Group(group), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) - { - // Add name of each definition to map, initially with value set to value - const auto defs = A(m_Group).getDefs(); - std::transform(defs.cbegin(), defs.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()), - [](const auto &v){ return std::make_pair(v.name, false); }); - } - - ~EnvironmentLocalVarCache() - { - A adapter(m_Group); - - // Copy definitions which have been referenced into new vector - const auto defs = adapter.getDefs(); - std::remove_const_t referencedVars; - std::copy_if(defs.cbegin(), defs.cend(), std::back_inserter(referencedVars), - [this](const auto &v){ return m_VariablesReferenced.at(v.name); }); - - // Loop through referenced variables - const auto &initialisers = adapter.getInitialisers(); - for(const auto &v : referencedVars) { - if(v.access & VarAccessMode::READ_ONLY) { - getContextStream() << "const "; - } - getContextStream() << v.type->getName() << " " << m_LocalPrefix << v.name; - - // If this isn't a reduction, read value from memory - // **NOTE** by not initialising these variables for reductions, - // compilers SHOULD emit a warning if user code doesn't set it to something - if(!(v.access & VarAccessModeAttribute::REDUCE)) { - getContextStream() << " = group->" << v.name << "[" << m_GetIndex(initialisers.at(v.name), v.access) << "]"; - } - getContextStream() << ";" << std::endl; - } - - // Write contents to context stream - getContextStream() << m_ContentsStream.str(); - - // Loop through referenced variables again - for(const auto &v : referencedVars) { - // If variables are read-write - if(v.access & VarAccessMode::READ_WRITE) { - getContextStream() << "group->" << v.name << "[" << m_GetIndex(initialisers.at(v.name), v.access) << "]"; - getContextStream() << " = " << m_LocalPrefix << v.name << ";" << std::endl; - } - } - } - - //------------------------------------------------------------------------ - // PrettyPrinter::EnvironmentBase virtuals - //------------------------------------------------------------------------ - virtual std::string getName(const Token &name) final - { - // If variable with this name isn't found, try and get name from context - auto var = m_VariablesReferenced.find(name.lexeme); - if(var == m_VariablesReferenced.end()) { - return getContextName(name); - } - // Otherwise - else { - // Set flag to indicate that variable has been referenced - var->second = true; - - // Add local prefix to variable name - return m_LocalPrefix + name.lexeme; - } - } - - virtual CodeStream &getStream() final - { - return m_Contents; - } - -private: - 
//------------------------------------------------------------------------ - // Members - //------------------------------------------------------------------------ - const G &m_Group; - std::ostringstream m_ContentsStream; - CodeStream m_Contents; - const std::string m_LocalPrefix; - const GetIndexFn m_GetIndex; - std::unordered_map m_VariablesReferenced; -}; - template void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, const std::string &index, R getVarRefIndex) diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc new file mode 100644 index 0000000000..fc621462c3 --- /dev/null +++ b/src/genn/genn/code_generator/environment.cc @@ -0,0 +1,52 @@ +#include "code_generator/environment.h" + +using namespace GeNN::CodeGenerator; + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::EnvironmentExternal +//---------------------------------------------------------------------------- +std::string EnvironmentExternal::define(const Transpiler::Token&) +{ + throw std::runtime_error("Cannot declare variable in external environment"); +} +//---------------------------------------------------------------------------- +CodeStream &EnvironmentExternal::getContextStream() const +{ + return std::visit( + Transpiler::Utils::Overload{ + [](std::reference_wrapper enclosing)->CodeStream& { return enclosing.get().getStream(); }, + [](std::reference_wrapper os)->CodeStream& { return os.get(); }}, + getContext()); +} +//---------------------------------------------------------------------------- +std::string EnvironmentExternal::getContextName(const Transpiler::Token &name) const +{ + return std::visit( + Transpiler::Utils::Overload{ + [&name](std::reference_wrapper enclosing)->std::string { return enclosing.get().getName(name); }, + [&name](std::reference_wrapper)->std::string { throw std::runtime_error("Variable '" + name.lexeme + "' undefined"); }}, + getContext()); +} + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::EnvironmentSubstitute +//---------------------------------------------------------------------------- +std::string EnvironmentSubstitute::getName(const Transpiler::Token &name) +{ + // If there isn't a substitution for this name, try and get name from context + auto sub = m_VarSubstitutions.find(name.lexeme); + if(sub == m_VarSubstitutions.end()) { + return getContextName(name); + } + // Otherwise, return substitution + else { + return sub->second; + } +} +//------------------------------------------------------------------------ +void EnvironmentSubstitute::addSubstitution(const std::string &source, const std::string &destination) +{ + if(!m_VarSubstitutions.emplace(source, destination).second) { + throw std::runtime_error("Redeclaration of substitution '" + source + "'"); + } +} \ No newline at end of file diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj index 10b53f21ae..4f62e42564 100644 --- a/src/genn/genn/genn.vcxproj +++ b/src/genn/genn/genn.vcxproj @@ -25,6 +25,7 @@ + @@ -77,6 +78,7 @@ + From 139cb09b1604fc30e571ded1b059f3406d150e88 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 30 Jan 2023 11:28:55 +0000 Subject: [PATCH 112/725] started hooking up environments into CPU backend --- .../code_generator/customUpdateGroupMerged.h | 8 +- .../genn/genn/code_generator/environment.h | 92 ++++++++++++- .../backends/single_threaded_cpu/backend.cc | 60 +++++---- 
.../code_generator/customUpdateGroupMerged.cc | 127 +++++------------- src/genn/genn/code_generator/environment.cc | 43 ++++++ tests/unit/typeChecker.cc | 2 + 6 files changed, 201 insertions(+), 131 deletions(-) diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index b0d9a6f6fe..ca4139b8e7 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -35,11 +35,13 @@ class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged + void addVarNameSubstitution(const std::vector &variables) + { + for(const auto &v : variables) { + addSubstitution(v.name, "group->" + v.name); + } + } + + template + void addParamValueSubstitution(const std::vector ¶mNames, const std::unordered_map &values, G isHeterogeneousFn) + { + if(paramNames.size() != values.size()) { + throw std::runtime_error("Number of parameters does not match number of values"); + } + + for(const auto &p : paramNames) { + if(isHeterogeneousFn(p)) { + addSubstitution(p, "group->" + p); + } + else { + // **TODO** scalar suffix + addSubstitution(p, Utils::writePreciseString(values.at(p))); + } + } + } + + template + void addVarValueSubstitution(const std::vector &variables, const std::unordered_map &values, G isHeterogeneousFn) + { + if(variables.size() != values.size()) { + throw std::runtime_error("Number of variables does not match number of values"); + } + + for(const auto &v : variables) { + if(isHeterogeneousFn(v.name)) { + addVarSubstitution(v.name, "group->" + v.name); + } + else { + addVarSubstitution(v.name, Utils::writePreciseString(values.at(v.name))); + } + } + } + private: //------------------------------------------------------------------------ // Members @@ -87,6 +128,51 @@ class EnvironmentSubstitute : public EnvironmentExternal std::unordered_map m_VarSubstitutions; }; +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::EnvironmentSubstituteCondInit +//---------------------------------------------------------------------------- +//! 
Pretty printing environment simply allowing substitutions to be implemented +class EnvironmentSubstituteCondInit : public EnvironmentExternal +{ + using EnvironmentBase = Transpiler::PrettyPrinter::EnvironmentBase; +public: + EnvironmentSubstituteCondInit(EnvironmentBase &enclosing) + : EnvironmentExternal(enclosing), m_Contents(m_ContentsStream) + { + } + + EnvironmentSubstituteCondInit(CodeStream &os) + : EnvironmentExternal(os), m_Contents(m_ContentsStream) + { + } + ~EnvironmentSubstituteCondInit(); + + //------------------------------------------------------------------------ + // PrettyPrinter::EnvironmentBase virtuals + //------------------------------------------------------------------------ + virtual std::string getName(const Transpiler::Token &name) final; + + virtual CodeStream &getStream() final + { + return m_Contents; + } + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + void addSubstitution(const std::string &source, const std::string &destination, + const std::string &initialiser); + +private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + std::ostringstream m_ContentsStream; + CodeStream m_Contents; + std::unordered_map> m_VarSubstitutions; + +}; + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentLocalVarCache //---------------------------------------------------------------------------- diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 6df47e6ffe..78237afe4b 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -6,6 +6,7 @@ // GeNN code generator includes #include "code_generator/codeGenUtils.h" #include "code_generator/codeStream.h" +#include "code_generator/environment.h" #include "code_generator/modelSpecMerged.h" #include "code_generator/substitutions.h" @@ -566,11 +567,10 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged { CodeStream::Scope b(os); - Substitutions popSubs(&funcSubs); - popSubs.addVarSubstitution("id", "i"); - // Generate custom update - c.generateCustomUpdate(*this, os, popSubs); + EnvironmentSubstitute env(os); + env.addSubstitution("id", "i"); + c.generateCustomUpdate(*this, env); // Loop through reduction targets and generate reduction for (const auto &r : reductionTargets) { @@ -589,14 +589,13 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged { CodeStream::Scope b(os); - Substitutions popSubs(&funcSubs); - popSubs.addVarSubstitution("id", "i"); - // Generate custom update - c.generateCustomUpdate(*this, os, popSubs); + EnvironmentSubstitute env(os); + env.addSubstitution("id", "i"); + c.generateCustomUpdate(*this, env); // Write back reductions - genWriteBackReductions(os, c, popSubs["id"]); + genWriteBackReductions(os, c, "i"); } } } @@ -650,27 +649,29 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged { CodeStream::Scope b(os); - Substitutions synSubs(&funcSubs); + // Add pre and postsynaptic indices to substitutions + EnvironmentSubstitute synEnv(os); + synEnv.addSubstitution("id_pre", "i"); + synEnv.addSubstitution("id_post", "j"); + + // **TODO** DEPENDENCIES! 
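+                    // EnvironmentSubstituteCondInit only writes a substitution's
+                    // initialiser statement out if the substituted name is actually
+                    // referenced; a sketch mirroring the transpose case below:
+                    //   synEnvCond.addSubstitution("id_syn", "idSyn",
+                    //                              "const unsigned int idSyn = ...;");
+                    // would only emit the idSyn declaration when "id_syn" is used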
+                    EnvironmentSubstituteCondInit synEnvCond(synEnv);
                     if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
                         // Calculate index of synapse and use it to look up postsynaptic index
                         os << "const unsigned int n = (i * group->rowStride) + s;" << std::endl;
                         os << "const unsigned int j = group->ind[n];" << std::endl;
 
-                        synSubs.addVarSubstitution("id_syn", "n");
+                        synEnv.addSubstitution("id_syn", "n");
                     }
                     else {
-                        synSubs.addVarSubstitution("id_syn", "(i * group->numTrgNeurons) + j");
+                        synEnv.addSubstitution("id_syn", "(i * group->numTrgNeurons) + j");
                     }
 
-                    // Add pre and postsynaptic indices to substitutions
-                    synSubs.addVarSubstitution("id_pre", "i");
-                    synSubs.addVarSubstitution("id_post", "j");
-
-                    // Call custom update handler
-                    c.generateCustomUpdate(*this, os, synSubs);
+                    // Generate custom update
+                    c.generateCustomUpdate(*this, synEnvCond);
 
                     // Write back reductions
-                    genWriteBackReductions(os, c, synSubs["id_syn"]);
+                    genWriteBackReductions(os, c, synEnv["id_syn"]);
                 }
             }
         }
@@ -748,15 +749,18 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
                 {
                     CodeStream::Scope b(os);
 
-                    Substitutions synSubs(&funcSubs);
-                    synSubs.addVarSubstitution("id_syn", "(i * group->numTrgNeurons) + j");
-
-                    // Add pre and postsynaptic indices to substitutions
-                    synSubs.addVarSubstitution("id_pre", "i");
-                    synSubs.addVarSubstitution("id_post", "j");
-
-                    // Call custom update handler
-                    c.generateCustomUpdate(*this, os, synSubs);
+                    // Add pre and postsynaptic indices to environment
+                    EnvironmentSubstitute synEnv(os);
+                    synEnv.addSubstitution("id_pre", "i");
+                    synEnv.addSubstitution("id_post", "j");
+
+                    // Add conditional initialisation code to calculate synapse index
+                    EnvironmentSubstituteCondInit synCachedEnv(synEnv);
+                    synCachedEnv.addSubstitution("id_syn", "idSyn",
+                                                 "const unsigned int idSyn = (i * group->numTrgNeurons) + j;");
+
+                    // Generate custom update
+                    c.generateCustomUpdate(*this, synCachedEnv);
 
                     // Update transpose variable
                     os << "group->" << transposeVarName << "Transpose[(j * group->numSrcNeurons) + i] = l" << transposeVarName << ";" << std::endl;
diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc
index 17d4acd839..714695288c 100644
--- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc
@@ -28,87 +28,33 @@
 namespace
 {
 template
-void genCustomUpdate(CodeStream &os, Substitutions &baseSubs, const C &cg, const std::string &index,
-                     R getVarRefIndex)
+void genCustomUpdate(Transpiler::PrettyPrinter::EnvironmentBase &envBase, const C &cg,
+                     const std::string &index, R getVarRefIndex)
 {
-    Substitutions updateSubs(&baseSubs);
-
+    EnvironmentSubstitute envSubs(envBase);
     const CustomUpdateModels::Base *cm = cg.getArchetype().getCustomUpdateModel();
-    const auto varRefs = cm->getVarRefs();
+
+    envSubs.addParamValueSubstitution(cm->getParamNames(), cg.getArchetype().getParams(),
+                                      [&cg](const std::string &p) { return cg.isParamHeterogeneous(p); });
+    envSubs.addVarValueSubstitution(cm->getDerivedParams(), cg.getArchetype().getDerivedParams(),
+                                    [&cg](const std::string &p) { return cg.isDerivedParamHeterogeneous(p); });
+    envSubs.addVarNameSubstitution(cm->getExtraGlobalParams());
 
-    // Loop through variables
-    for(const auto &v : cm->getVars()) {
-        if(v.access & VarAccessMode::READ_ONLY) {
-            os << "const ";
-        }
-        os << v.type->getName() << " l" << v.name;
-
-        // If this isn't a reduction, read value from memory
-        
// **NOTE** by not initialising these variables for reductions, - // compilers SHOULD emit a warning if user code doesn't set it to something - if(!(v.access & VarAccessModeAttribute::REDUCE)) { - os << " = group->" << v.name << "["; - os << cg.getVarIndex(getVarAccessDuplication(v.access), - updateSubs[index]); - os << "]"; - } - os << ";" << std::endl; - } - // Loop through variable references - for(const auto &v : varRefs) { - if(v.access == VarAccessMode::READ_ONLY) { - os << "const "; - } - - os << v.type->getName() << " l" << v.name; - - // If this isn't a reduction, read value from memory - // **NOTE** by not initialising these variables for reductions, - // compilers SHOULD emit a warning if user code doesn't set it to something - if(!(v.access & VarAccessModeAttribute::REDUCE)) { - os << " = " << "group->" << v.name << "["; - os << getVarRefIndex(cg.getArchetype().getVarReferences().at(v.name), - updateSubs[index]); - os << "]"; - } - os << ";" << std::endl; - } + // Create an environment which caches variables in local variables if they are accessed + EnvironmentLocalVarCache varSubs( + cg, envSubs, + [index, &cg](const Models::VarInit&, VarAccess a) + { + return cg.getVarIndex(getVarAccessDuplication(a), index); + }); - updateSubs.addVarNameSubstitution(cm->getVars(), "", "l"); - updateSubs.addVarNameSubstitution(cm->getVarRefs(), "", "l"); - updateSubs.addParamValueSubstitution(cm->getParamNames(), cg.getArchetype().getParams(), - [&cg](const std::string &p) { return cg.isParamHeterogeneous(p); }, - "", "group->"); - updateSubs.addVarValueSubstitution(cm->getDerivedParams(), cg.getArchetype().getDerivedParams(), - [&cg](const std::string &p) { return cg.isDerivedParamHeterogeneous(p); }, - "", "group->"); - updateSubs.addVarNameSubstitution(cm->getExtraGlobalParams(), "", "group->"); - - std::string code = cm->getUpdateCode(); - updateSubs.applyCheckUnreplaced(code, "custom update : merged" + std::to_string(cg.getIndex())); - //code = ensureFtype(code, modelMerged.getModel().getPrecision()); - os << code; - - // Write read/write variables back to global memory - for(const auto &v : cm->getVars()) { - if(v.access & VarAccessMode::READ_WRITE) { - os << "group->" << v.name << "["; - os << cg.getVarIndex(getVarAccessDuplication(v.access), - updateSubs[index]); - os << "] = l" << v.name << ";" << std::endl; - } - } + // Create an environment which caches variable references in local variables if they are accessed + EnvironmentLocalVarCache varRefSubs( + cg, envSubs, [](const Models::VarReference &v, VarAccessMode){ return getVarRefIndex(v); }); - // Write read/write variable references back to global memory - for(const auto &v : varRefs) { - if(v.access == VarAccessMode::READ_WRITE) { - os << "group->" << v.name << "["; - os << getVarRefIndex(cg.getArchetype().getVarReferences().at(v.name), - updateSubs[index]); - os << "] = l" << v.name << ";" << std::endl; - } - } + // Pretty print previously parsed update statements + PrettyPrinter::print(cg.getUpdateStatements(), varRefSubs, cg.getTypeContext()); } } // Anonymous namespace @@ -200,29 +146,16 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase&, CodeStrea { // Build initial environment with ID etc // **TODO** this should happen in backend - EnvironmentSubstitute subs(os); - subs.addSubstitution("id", popSubs["id"]); - - // Create an environment which caches variables in local variables if they are accessed - EnvironmentLocalVarCache varSubs( - getArchetype(), subs, - [this](const Models::VarInit&, VarAccess a) 
- { - return getVarIndex(getVarAccessDuplication(a), "id"); - }); + EnvironmentSubstitute envBase(os); + envBase.addSubstitution("id", popSubs["id"]); - // Create an environment which caches variable references in local variables if they are accessed - EnvironmentLocalVarCache varRefSubs( - getArchetype(), subs, - [this](const Models::VarReference &v, VarAccessMode) - { - return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, - getVarAccessDuplication(v.getVar().access), - "id"); - }); - - // Pretty print previously parsed update statements - PrettyPrinter::print(m_UpdateStatements, varRefSubs, getTypeContext()); + genCustomUpdate(envBase, *this, "id", + [this](const Models::VarReference &v) + { + return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, + getVarAccessDuplication(v.getVar().access), + "id"); + }); } //---------------------------------------------------------------------------- std::string CustomUpdateGroupMerged::getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index fc621462c3..9a5fc3acd1 100644 --- a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -49,4 +49,47 @@ void EnvironmentSubstitute::addSubstitution(const std::string &source, const std if(!m_VarSubstitutions.emplace(source, destination).second) { throw std::runtime_error("Redeclaration of substitution '" + source + "'"); } +} + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::EnvironmentSubstituteCondInit +//---------------------------------------------------------------------------- +EnvironmentSubstituteCondInit::~EnvironmentSubstituteCondInit() +{ + // Loop through substitututions + for(const auto &v : m_VarSubstitutions) { + // If variable has been referenced, write out initialiser + if (std::get<0>(v.second)) { + getContextStream() << std::get<2>(v.second) << std::endl; + } + } + + // Write contents to context stream + getContextStream() << m_ContentsStream.str(); +} +//------------------------------------------------------------------------ +std::string EnvironmentSubstituteCondInit::getName(const Transpiler::Token &name) +{ + // If variable with this name isn't found, try and get name from context + auto var = m_VarSubstitutions.find(name.lexeme); + if(var == m_VarSubstitutions.end()) { + return getContextName(name); + } + // Otherwise + else { + // Set flag to indicate that variable has been referenced + std::get<0>(var->second) = true; + + // Add local prefix to variable name + return std::get<1>(var->second); + } +} + +//------------------------------------------------------------------------ +void EnvironmentSubstituteCondInit::addSubstitution(const std::string &source, const std::string &destination, + const std::string &initialiser) +{ + if(!m_VarSubstitutions.try_emplace(source, false, destination, initialiser).second) { + throw std::runtime_error("Redeclaration of substitution '" + source + "'"); + } } \ No newline at end of file diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index 925bdb387e..4f8901d541 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -187,6 +187,8 @@ TEST(TypeChecker, ArraySubscript) EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT)); } + + // Pointer to pointer, double indexing // Float array indexing EXPECT_THROW({ From 
f29cfcc0eafc21abb80d89945d76f14bb426f7fd Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 6 Feb 2023 09:21:42 +0000 Subject: [PATCH 113/725] WIP --- .../genn/genn/code_generator/environment.h | 72 +++++---------- .../backends/single_threaded_cpu/backend.cc | 87 +++++++++---------- src/genn/genn/code_generator/environment.cc | 75 +++++++--------- 3 files changed, 99 insertions(+), 135 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index aabdc8bec0..eae797aa27 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -60,8 +60,18 @@ class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase //! Standard pretty printing environment simply allowing substitutions to be implemented class EnvironmentSubstitute : public EnvironmentExternal { + using EnvironmentBase = Transpiler::PrettyPrinter::EnvironmentBase; public: - using EnvironmentExternal::EnvironmentExternal; + EnvironmentSubstitute(EnvironmentBase &enclosing) + : EnvironmentExternal(enclosing), m_Contents(m_ContentsStream) + { + } + + EnvironmentSubstitute(CodeStream &os) + : EnvironmentExternal(os), m_Contents(m_ContentsStream) + { + } + ~EnvironmentSubstitute(); //------------------------------------------------------------------------ // PrettyPrinter::EnvironmentBase virtuals @@ -70,14 +80,17 @@ class EnvironmentSubstitute : public EnvironmentExternal virtual CodeStream &getStream() final { - return getContextStream(); + return m_Contents; } //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - void addSubstitution(const std::string &source, const std::string &destination); + void addSubstitution(const std::string &source, const std::string &destination, + std::vector initialisers = {}); + size_t addInitialiser(const std::string &initialiser); + template void addVarNameSubstitution(const std::vector &variables) { @@ -87,7 +100,8 @@ class EnvironmentSubstitute : public EnvironmentExternal } template - void addParamValueSubstitution(const std::vector ¶mNames, const std::unordered_map &values, G isHeterogeneousFn) + void addParamValueSubstitution(const std::vector ¶mNames, const std::unordered_map &values, + G isHeterogeneousFn) { if(paramNames.size() != values.size()) { throw std::runtime_error("Number of parameters does not match number of values"); @@ -105,7 +119,8 @@ class EnvironmentSubstitute : public EnvironmentExternal } template - void addVarValueSubstitution(const std::vector &variables, const std::unordered_map &values, G isHeterogeneousFn) + void addVarValueSubstitution(const std::vector &variables, const std::unordered_map &values, + G isHeterogeneousFn) { if(variables.size() != values.size()) { throw std::runtime_error("Number of variables does not match number of values"); @@ -121,56 +136,15 @@ class EnvironmentSubstitute : public EnvironmentExternal } } -private: - //------------------------------------------------------------------------ - // Members - //------------------------------------------------------------------------ - std::unordered_map m_VarSubstitutions; -}; - -//---------------------------------------------------------------------------- -// GeNN::CodeGenerator::EnvironmentSubstituteCondInit -//---------------------------------------------------------------------------- -//! 
Pretty printing environment simply allowing substitutions to be implemented -class EnvironmentSubstituteCondInit : public EnvironmentExternal -{ - using EnvironmentBase = Transpiler::PrettyPrinter::EnvironmentBase; -public: - EnvironmentSubstituteCondInit(EnvironmentBase &enclosing) - : EnvironmentExternal(enclosing), m_Contents(m_ContentsStream) - { - } - - EnvironmentSubstituteCondInit(CodeStream &os) - : EnvironmentExternal(os), m_Contents(m_ContentsStream) - { - } - ~EnvironmentSubstituteCondInit(); - - //------------------------------------------------------------------------ - // PrettyPrinter::EnvironmentBase virtuals - //------------------------------------------------------------------------ - virtual std::string getName(const Transpiler::Token &name) final; - - virtual CodeStream &getStream() final - { - return m_Contents; - } - - //------------------------------------------------------------------------ - // Public API - //------------------------------------------------------------------------ - void addSubstitution(const std::string &source, const std::string &destination, - const std::string &initialiser); - private: //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ std::ostringstream m_ContentsStream; CodeStream m_Contents; - std::unordered_map> m_VarSubstitutions; - + std::unordered_map>> m_VarSubstitutions; + std::vector> m_Initialisers; + }; //---------------------------------------------------------------------------- diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 78237afe4b..c38a27dd72 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -11,6 +11,7 @@ #include "code_generator/substitutions.h" using namespace GeNN::CodeGenerator; +using namespace GeNN::Transpiler; //-------------------------------------------------------------------------- // Anonymous namespace @@ -74,36 +75,35 @@ const std::vector &getFunctionTemplates(const s } //----------------------------------------------------------------------- template -void genKernelIteration(CodeStream &os, const G &g, size_t numKernelDims, const Substitutions &kernelSubs, BackendBase::Handler handler) +void genKernelIteration(PrettyPrinter::EnvironmentBase &env, const G &g, size_t numKernelDims, BackendBase::Handler handler) { - Substitutions varSubs(&kernelSubs); + EnvironmentSubstitute varEnv(env); // Define recursive function to generate nested kernel initialisation loops // **NOTE** this is a std::function as type of auto lambda couldn't be determined inside for recursive call std::function generateRecursive = - [&handler, &os, &g, &varSubs, &generateRecursive, numKernelDims] + [&handler, &varEnv, &g, &generateRecursive, numKernelDims] (size_t depth) { // Loop through this kernel dimensions const std::string idxVar = "k" + std::to_string(depth); - os << "for(unsigned int " << idxVar << " = 0; " << idxVar << " < " << g.getKernelSize(depth) << "; " << idxVar << "++)"; + varEnv.getStream() << "for(unsigned int " << idxVar << " = 0; " << idxVar << " < " << g.getKernelSize(depth) << "; " << idxVar << "++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(varEnv.getStream()); + EnvironmentSubstitute loopEnv(varEnv); // Add substitution for this kernel index - varSubs.addVarSubstitution("id_kernel_" + std::to_string(depth), idxVar); + loopEnv.addSubstitution("id_kernel_" + 
std::to_string(depth), idxVar); // If we've recursed through all dimensions if (depth == (numKernelDims - 1)) { // Generate kernel index and use as "synapse" index // **TODO** rename - os << "const unsigned int kernelInd = "; - g.genKernelIndex(os, varSubs); - os << ";" << std::endl; - varSubs.addVarSubstitution("id_syn", "kernelInd"); + const size_t kernelInit = loopEnv.addInitialiser("const unsigned int kernelInd = " + g.genKernelIndex(loopEnv) + ";"); + loopEnv.addVarSubstitution("id_syn", "kernelInd", kernelInit); // Call handler - handler(os, varSubs); + handler(loopEnv); } // Otherwise, recurse else { @@ -526,9 +526,9 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged { CodeStream::Scope b(os); - Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName())); - funcSubs.addVarSubstitution("t", "t"); - funcSubs.addVarSubstitution("batch", "0"); + EnvironmentSubstitute funcEnv(os); + funcEnv.addSubstitution("t", "t"); + funcEnv.addSubstitution("batch", "0"); // Loop through host update groups and generate code for those in this custom update group for (const auto &cg : modelMerged.getMergedCustomConnectivityHostUpdateGroups()) { @@ -568,7 +568,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged CodeStream::Scope b(os); // Generate custom update - EnvironmentSubstitute env(os); + EnvironmentSubstitute env(funcEnv); env.addSubstitution("id", "i"); c.generateCustomUpdate(*this, env); @@ -590,7 +590,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged CodeStream::Scope b(os); // Generate custom update - EnvironmentSubstitute env(os); + EnvironmentSubstitute env(funcEnv); env.addSubstitution("id", "i"); c.generateCustomUpdate(*this, env); @@ -618,16 +618,17 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged os << "const auto *group = &mergedCustomUpdateWUGroup" << c.getIndex() << "[g]; " << std::endl; const SynapseGroupInternal *sg = c.getArchetype().getSynapseGroup(); + EnvironmentSubstitute synEnv(funcEnv); if (sg->getMatrixType() & SynapseMatrixWeight::KERNEL) { - genKernelIteration(os, c, c.getArchetype().getSynapseGroup()->getKernelSize().size(), funcSubs, + genKernelIteration(c, c.getArchetype().getSynapseGroup()->getKernelSize().size(), synEnv, [&c, &modelMerged, this] - (CodeStream &os, Substitutions &subs) + (PrettyPrinter::EnvironmentBase &env) { // Call custom update handler - c.generateCustomUpdate(*this, os, subs); + c.generateCustomUpdate(*this, env); // Write back reductions - genWriteBackReductions(os, c, subs["id_syn"]); + genWriteBackReductions(env, c); }); } else { @@ -649,26 +650,28 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged { CodeStream::Scope b(os); - // Add pre and postsynaptic indices to substitutions - EnvironmentSubstitute synEnv(os); + // Add presynaptic index to substitutions synEnv.addSubstitution("id_pre", "i"); - synEnv.addSubstitution("id_post", "j"); - - // **TODO** DEPENDENCIES! 
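The DEPENDENCIES note being removed above marks exactly the problem the initialiser indices in this patch solve: a substitution such as id_post is only usable if the statements that compute it are emitted first, and should only be emitted at all when the name is actually referenced. The following self-contained sketch illustrates that mechanism under assumed, illustrative names (LazySubstituter, rowStride and ind are not GeNN identifiers):

    // Standalone illustration of lazy, dependency-ordered initialisers
    // (illustrative only -- not the GeNN classes themselves)
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    class LazySubstituter
    {
    public:
        // Register an initialiser statement and return its index
        size_t addInitialiser(const std::string &statement)
        {
            m_Initialisers.emplace_back(false, statement);
            return (m_Initialisers.size() - 1);
        }

        // Map a source name to a destination, depending on zero or more initialisers
        void addSubstitution(const std::string &source, const std::string &destination,
                             std::vector<size_t> deps = {})
        {
            m_Substitutions.emplace(source, std::make_pair(destination, std::move(deps)));
        }

        // Look up a name, marking its initialisers as required
        std::string getName(const std::string &name)
        {
            auto &sub = m_Substitutions.at(name);
            for(size_t i : sub.second) {
                m_Initialisers.at(i).first = true;
            }
            return sub.first;
        }

        // Emit only the initialisers that were actually needed, in registration order
        void writeInitialisers(std::ostream &os) const
        {
            for(const auto &i : m_Initialisers) {
                if(i.first) {
                    os << i.second << std::endl;
                }
            }
        }

    private:
        std::vector<std::pair<bool, std::string>> m_Initialisers;
        std::unordered_map<std::string, std::pair<std::string, std::vector<size_t>>> m_Substitutions;
    };

    int main()
    {
        LazySubstituter env;
        const size_t idSyn = env.addInitialiser("const unsigned int idSyn = (i * rowStride) + s;");
        const size_t j = env.addInitialiser("const unsigned int j = ind[idSyn];");
        env.addSubstitution("id_syn", "idSyn", {idSyn});
        env.addSubstitution("id_post", "j", {j, idSyn});

        // Referencing "id_post" pulls in both of its initialisers
        std::cout << env.getName("id_post") << std::endl;
        env.writeInitialisers(std::cout);
    }

Because initialisers are emitted in registration order, idSyn is declared before j uses it, so registering dependencies in declaration order is sufficient to keep the generated code well-formed.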
- EnvironmentSubstituteCondInit synEnvCond(synEnv); + + // If connectivity is sparse if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - // Calculate index of synapse and use it to look up postsynaptic index - os << "const unsigned int n = (i * group->rowStride) + s;" << std::endl; - os << "const unsigned int j = group->ind[n];" << std::endl; + // Add initialisers to calculate synaptic index and thus lookup postsynaptic index + const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * group->rowStride) + s;"); + const size_t jInit = synEnv.addInitialiser("const unsigned int j = group->ind[idSyn];"); - synEnv.addSubstitution("id_syn", "n"); + // Add substitutions + synEnv.addSubstitution("id_syn", "idSyn", {idSynInit}); + synEnv.addSubstitution("id_post", "j", {jInit, idSynInit}); } else { - synEnv.addSubstitution("id_syn", "(i * group->numTrgNeurons) + j"); + synEnv.addSubstitution("id_post", "j"); + + const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * group->numTrgNeurons) + j;"); + synEnv.addSubstitution("id_syn", "idSyn", {idSynInit}); } // Generate custom update - c.generateCustomUpdate(*this, synEnvCond); + c.generateCustomUpdate(*this, synEnv); // Write back reductions genWriteBackReductions(os, c, synEnv["id_syn"]); @@ -702,13 +705,10 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged CodeStream::Scope b(os); // Configure substitutions - Substitutions popSubs(&funcSubs); - popSubs.addVarSubstitution("id_pre", "i"); - - // If this neuron group requires a simulation RNG, substitute in global RNG - if(c.getArchetype().isRowSimRNGRequired()) { - popSubs.addVarSubstitution("rng", "hostRNG"); - } + EnvironmentSubstitute cuEnv(funcEnv); + cuEnv.addSubstitution("id_pre", "i"); + cuEnv.addSubstitution("rng", "hostRNG"); + c.generateUpdate(*this, os, model.getBatchSize(), popSubs); } } @@ -750,17 +750,16 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged CodeStream::Scope b(os); // Add pre and postsynaptic indices to environment - EnvironmentSubstitute synEnv(os); + EnvironmentSubstitute synEnv(funcEnv); synEnv.addSubstitution("id_pre", "i"); synEnv.addSubstitution("id_post", "j"); // Add conditional initialisation code to calculate synapse index - EnvironmentSubstituteCondInit synCachedEnv(synEnv); - synCachedEnv.addSubstitution("id_syn", "idSyn", - "const unsigned int idSyn = (i * group->numTrgNeurons) + j;"); + const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * group->numTrgNeurons) + j;"); + synEnv.addSubstitution("id_syn", "idSyn", {idSynInit}); // Generate custom update - c.generateCustomUpdate(*this, synCachedEnv); + c.generateCustomUpdate(*this, synEnv); // Update transpose variable os << "group->" << transposeVarName << "Transpose[(j * group->numSrcNeurons) + i] = l" << transposeVarName << ";" << std::endl; diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index 9a5fc3acd1..dd4f3eed9c 100644 --- a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -1,5 +1,11 @@ #include "code_generator/environment.h" +// Standard C++ includes +#include + +// Standard C includes +#include + using namespace GeNN::CodeGenerator; //---------------------------------------------------------------------------- @@ -31,65 +37,50 @@ std::string EnvironmentExternal::getContextName(const Transpiler::Token &name) c 
 //----------------------------------------------------------------------------
 // GeNN::CodeGenerator::EnvironmentSubstitute
 //----------------------------------------------------------------------------
-std::string EnvironmentSubstitute::getName(const Transpiler::Token &name)
-{
-    // If there isn't a substitution for this name, try and get name from context
-    auto sub = m_VarSubstitutions.find(name.lexeme);
-    if(sub == m_VarSubstitutions.end()) {
-        return getContextName(name);
-    }
-    // Otherwise, return substitution
-    else {
-        return sub->second;
-    }
-}
-//------------------------------------------------------------------------
-void EnvironmentSubstitute::addSubstitution(const std::string &source, const std::string &destination)
+EnvironmentSubstitute::~EnvironmentSubstitute()
 {
-    if(!m_VarSubstitutions.emplace(source, destination).second) {
-        throw std::runtime_error("Redeclaration of substitution '" + source + "'");
-    }
-}
-
-//----------------------------------------------------------------------------
-// GeNN::CodeGenerator::EnvironmentSubstituteCondInit
-//----------------------------------------------------------------------------
-EnvironmentSubstituteCondInit::~EnvironmentSubstituteCondInit()
-{
-    // Loop through substitutions
-    for(const auto &v : m_VarSubstitutions) {
-        // If variable has been referenced, write out initialiser
-        if (std::get<0>(v.second)) {
-            getContextStream() << std::get<2>(v.second) << std::endl;
+    // Loop through initialisers
+    for(const auto &i : m_Initialisers) {
+        // If variable requiring initialiser has been referenced, write out initialiser
+        if (i.first) {
+            getContextStream() << i.second << std::endl;
         }
     }
 
     // Write contents to context stream
     getContextStream() << m_ContentsStream.str();
 }
-//------------------------------------------------------------------------
-std::string EnvironmentSubstituteCondInit::getName(const Transpiler::Token &name)
+//----------------------------------------------------------------------------
+std::string EnvironmentSubstitute::getName(const Transpiler::Token &name)
 {
-    // If variable with this name isn't found, try and get name from context
+    // If there isn't a substitution for this name, try and get name from context
     auto var = m_VarSubstitutions.find(name.lexeme);
     if(var == m_VarSubstitutions.end()) {
         return getContextName(name);
     }
-    // Otherwise
+    // Otherwise, return substitution
     else {
-        // Set flag to indicate that variable has been referenced
-        std::get<0>(var->second) = true;
-
-        // Return destination name
-        return std::get<1>(var->second);
+        // If this variable relies on any initialiser statements, mark these initialisers as required
+        for(const auto i : var->second.second) {
+            m_Initialisers.at(i).first = true;
+        }
+
+        return var->second.first;
     }
 }
-
 //------------------------------------------------------------------------
-void EnvironmentSubstituteCondInit::addSubstitution(const std::string &source, const std::string &destination,
-                                                    const std::string &initialiser)
+void EnvironmentSubstitute::addSubstitution(const std::string &source, const std::string &destination,
+                                            std::vector<size_t> initialisers)
 {
-    if(!m_VarSubstitutions.try_emplace(source, false, destination, initialiser).second) {
+    assert(std::all_of(initialisers.cbegin(), initialisers.cend(),
+                       [this](size_t i) { return i < m_Initialisers.size(); }));
+
+    if(!m_VarSubstitutions.try_emplace(source, destination, initialisers).second) {
         throw std::runtime_error("Redeclaration of substitution '" + source + "'");
     }
+}
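The destructor above is what makes the merged design work: code generated inside the environment is buffered in m_ContentsStream, so that when the environment goes out of scope the initialisers that were actually referenced can be written to the enclosing stream ahead of the buffered body. A minimal sketch of this buffer-and-flush-on-destruction pattern, under the assumed name DeferredScope (not part of GeNN):

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    // RAII scope which buffers generated code and, on destruction,
    // writes any required declarations before the buffered body
    class DeferredScope
    {
    public:
        explicit DeferredScope(std::ostream &context) : m_Context(context) {}

        ~DeferredScope()
        {
            for(const auto &d : m_RequiredDeclarations) {
                m_Context << d << '\n';
            }
            m_Context << m_Body.str();
        }

        void require(const std::string &declaration) { m_RequiredDeclarations.push_back(declaration); }
        std::ostream &body() { return m_Body; }

    private:
        std::ostream &m_Context;
        std::ostringstream m_Body;
        std::vector<std::string> m_RequiredDeclarations;
    };

    int main()
    {
        {
            DeferredScope scope(std::cout);
            scope.body() << "x += idSyn;\n";  // the body is generated first...
            scope.require("const unsigned int idSyn = (i * rowStride) + s;");  // ...but this is emitted above it
        }   // destructor writes the declaration, then the body
    }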
+//------------------------------------------------------------------------
+size_t EnvironmentSubstitute::addInitialiser(const std::string &initialiser)
+{
+    m_Initialisers.emplace_back(false, initialiser);
+    return (m_Initialisers.size() - 1);
 }
\ No newline at end of file

From 4b60b2debdf2bc375f72b698a40809c70dea3dc5 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 21 Feb 2023 11:36:19 +0000
Subject: [PATCH 114/725] custom update group merged now entirely using new environments and pretty printing

---
 .../code_generator/customUpdateGroupMerged.h  |  15 +-
 .../genn/genn/code_generator/environment.h    |   4 +-
 .../code_generator/customUpdateGroupMerged.cc | 136 +++++++++---------
 3 files changed, 81 insertions(+), 74 deletions(-)

diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h
index ca4139b8e7..286a466686 100644
--- a/include/genn/genn/code_generator/customUpdateGroupMerged.h
+++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h
@@ -2,6 +2,7 @@
 // GeNN code generator includes
 #include "code_generator/codeGenUtils.h"
+#include "code_generator/environment.h"
 #include "code_generator/groupMerged.h"
 
 // GeNN transpiler includes
@@ -69,6 +70,8 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged
> &groups);
@@ -99,6 +104,12 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged
getKernelSize(); }
+
+    //----------------------------------------------------------------------------
+    // Members
+    //----------------------------------------------------------------------------
+    //! List of statements parsed and type-checked in constructor and used to generate code
+    Transpiler::Statement::StatementList m_UpdateStatements;
 };
 
 // ----------------------------------------------------------------------------
@@ -125,8 +136,6 @@ class GENN_EXPORT CustomUpdateWUGroupMerged : public CustomUpdateWUGroupMergedBa
                              runnerVarDecl, runnerMergedStructAlloc, name);
     }
 
-    void generateCustomUpdate(const BackendBase &backend, EnvironmentExternal &env) const;
-
     //----------------------------------------------------------------------------
     // Static constants
     //----------------------------------------------------------------------------
@@ -157,8 +166,6 @@ class GENN_EXPORT CustomUpdateTransposeWUGroupMerged : public CustomUpdateWUGrou
                              runnerVarDecl, runnerMergedStructAlloc, name);
     }
 
-    void generateCustomUpdate(const BackendBase &backend, EnvironmentExternal &env) const;
-
     //----------------------------------------------------------------------------
     // Static constants
     //----------------------------------------------------------------------------
diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h
index eae797aa27..a3636ad1a8 100644
--- a/include/genn/genn/code_generator/environment.h
+++ b/include/genn/genn/code_generator/environment.h
@@ -128,10 +128,10 @@ class EnvironmentSubstitute : public EnvironmentExternal
 
         for(const auto &v : variables) {
             if(isHeterogeneousFn(v.name)) {
-                addVarSubstitution(v.name, "group->" + v.name);
+                addSubstitution(v.name, "group->" + v.name);
             }
             else {
-                addVarSubstitution(v.name, Utils::writePreciseString(values.at(v.name)));
+                addSubstitution(v.name, Utils::writePreciseString(values.at(v.name)));
             }
         }
     }
diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc
index 714695288c..936a1cad52 100644
--- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc
+++ 
b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -22,42 +22,6 @@ using namespace GeNN::CodeGenerator; using namespace GeNN::Transpiler; -//-------------------------------------------------------------------------- -// Anonymous namespace -//-------------------------------------------------------------------------- -namespace -{ -template -void genCustomUpdate(Transpiler::PrettyPrinter::EnvironmentBase &envBase, const C &cg, - const std::string &index, R getVarRefIndex) -{ - EnvironmentSubstitute envSubs(envBase); - const CustomUpdateModels::Base *cm = cg.getArchetype().getCustomUpdateModel(); - - subs.addParamValueSubstitution(cm->getParamNames(), cg.getArchetype().getParams(), - [&cg](const std::string &p) { return cg.isParamHeterogeneous(p); }); - subs.addVarValueSubstitution(cm->getDerivedParams(), cg.getArchetype().getDerivedParams(), - [&cg](const std::string &p) { return cg.isDerivedParamHeterogeneous(p); }); - subs.addVarNameSubstitution(cm->getExtraGlobalParams()); - - - // Create an environment which caches variables in local variables if they are accessed - EnvironmentLocalVarCache varSubs( - cg, envSubs, - [index, &cg](const Models::VarInit&, VarAccess a) - { - return cg.getVarIndex(getVarAccessDuplication(a), index); - }); - - // Create an environment which caches variable references in local variables if they are accessed - EnvironmentLocalVarCache varRefSubs( - cg, envSubs, [](const Models::VarReference &v, VarAccessMode){ return getVarRefIndex(v); }); - - // Pretty print previously parsed update statements - PrettyPrinter::print(cg.getUpdateStatements(), varRefSubs, cg.getTypeContext()); -} -} // Anonymous namespace - //---------------------------------------------------------------------------- // GeNN::CodeGenerator::CustomUpdateGroupMerged //---------------------------------------------------------------------------- @@ -142,20 +106,37 @@ boost::uuids::detail::sha1::digest_type CustomUpdateGroupMerged::getHashDigest() return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, Substitutions &popSubs) const +void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, EnvironmentExternal &env) const { - // Build initial environment with ID etc - // **TODO** this should happen in backend - EnvironmentSubstitute envBase(os); - envBase.addSubstitution("id", popSubs["id"]); + // Add parameters, derived parameters and EGPs to environment + EnvironmentSubstitute envSubs(env); + const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); + envSubs.addParamValueSubstitution(cm->getParamNames(), getArchetype().getParams(), + [this](const std::string &p) { return isParamHeterogeneous(p); }); + envSubs.addVarValueSubstitution(cm->getDerivedParams(), getArchetype().getDerivedParams(), + [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); + envSubs.addVarNameSubstitution(cm->getExtraGlobalParams()); + + // Create an environment which caches variables in local variables if they are accessed + EnvironmentLocalVarCache varSubs( + getArchetype(), envSubs, + [this](const Models::VarInit&, VarAccess a) + { + return getVarIndex(getVarAccessDuplication(a), "id"); + }); - genCustomUpdate(envBase, *this, "id", - [this](const Models::VarReference &v) - { - return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, - getVarAccessDuplication(v.getVar().access), - "id"); - }); + // Create an 
environment which caches variable references in local variables if they are accessed + EnvironmentLocalVarCache varRefSubs( + getArchetype(), varSubs, + [this](const Models::VarReference &v, VarAccessMode) + { + return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, + getVarAccessDuplication(v.getVar().access), + "id"); + }); + + // Pretty print previously parsed update statements + PrettyPrinter::print(getUpdateStatements(), varRefSubs, getTypeContext()); } //---------------------------------------------------------------------------- std::string CustomUpdateGroupMerged::getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const @@ -234,6 +215,38 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWUGroupMergedBase::getHashDi return hash.get_digest(); } //---------------------------------------------------------------------------- +void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &backend, EnvironmentExternal &env) const +{ + // Add parameters, derived parameters and EGPs to environment + EnvironmentSubstitute envSubs(env); + const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); + envSubs.addParamValueSubstitution(cm->getParamNames(), getArchetype().getParams(), + [this](const std::string &p) { return isParamHeterogeneous(p); }); + envSubs.addVarValueSubstitution(cm->getDerivedParams(), getArchetype().getDerivedParams(), + [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); + envSubs.addVarNameSubstitution(cm->getExtraGlobalParams()); + + // Create an environment which caches variables in local variables if they are accessed + EnvironmentLocalVarCache varSubs( + getArchetype(), envSubs, + [this](const Models::VarInit&, VarAccess a) + { + return getVarIndex(getVarAccessDuplication(a), "id_syn"); + }); + + // Create an environment which caches variable references in local variables if they are accessed + EnvironmentLocalVarCache varRefSubs( + getArchetype(), varSubs, + [this](const Models::WUVarReference &v, VarAccessMode) + { + return getVarRefIndex(getVarAccessDuplication(v.getVar().access), + "id_syn"); + }); + + // Pretty print previously parsed update statements + PrettyPrinter::print(getUpdateStatements(), varRefSubs, getTypeContext()); +} +//---------------------------------------------------------------------------- std::string CustomUpdateWUGroupMergedBase::getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? 
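The generateCustomUpdate implementations above layer several environments — parameter and derived-parameter substitutions, then a cache of model variables, then a cache of variable references — and a name is resolved by the innermost layer that knows it, falling through to the enclosing layer otherwise. A stripped-down sketch of that chained lookup, using simplified stand-in classes (Environment and MapEnvironment are illustrative, not the real hierarchy):

    #include <iostream>
    #include <stdexcept>
    #include <string>
    #include <unordered_map>
    #include <utility>

    // Base: resolve a name or delegate to the enclosing environment
    class Environment
    {
    public:
        explicit Environment(Environment *enclosing = nullptr) : m_Enclosing(enclosing) {}
        virtual ~Environment() = default;

        virtual std::string getName(const std::string &name)
        {
            if(m_Enclosing) {
                return m_Enclosing->getName(name);
            }
            throw std::runtime_error("Variable '" + name + "' undefined");
        }

    private:
        Environment *m_Enclosing;
    };

    // One layer: a simple map of substitutions
    class MapEnvironment : public Environment
    {
    public:
        MapEnvironment(std::unordered_map<std::string, std::string> subs, Environment *enclosing)
        :   Environment(enclosing), m_Subs(std::move(subs)) {}

        std::string getName(const std::string &name) override
        {
            const auto s = m_Subs.find(name);
            return (s == m_Subs.end()) ? Environment::getName(name) : s->second;
        }

    private:
        std::unordered_map<std::string, std::string> m_Subs;
    };

    int main()
    {
        MapEnvironment params({{"tau", "group->tau"}}, nullptr);
        MapEnvironment vars({{"V", "lV"}}, &params);   // innermost layer wins
        std::cout << vars.getName("V") << ", " << vars.getName("tau") << std::endl;
    }

Looking up "V" is answered by the inner variable layer, while "tau" falls through to the parameter layer, mirroring how the variable caches above delegate to envSubs via getContextName.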
@@ -344,37 +357,24 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const } // Add EGPs to struct typeEnvironment.defineEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); + + // Scan, parse and type-check update code + ErrorHandler errorHandler; + const std::string code = upgradeCodeString(cm->getUpdateCode()); + const auto tokens = Scanner::scanSource(code, errorHandler); + m_UpdateStatements = Parser::parseBlockItemList(tokens, errorHandler); + TypeChecker::typeCheck(m_UpdateStatements, typeEnvironment, typeContext, errorHandler); } // ---------------------------------------------------------------------------- // GeNN::CodeGenerator::CustomUpdateWUGroupMerged //---------------------------------------------------------------------------- const std::string CustomUpdateWUGroupMerged::name = "CustomUpdateWU"; -//---------------------------------------------------------------------------- -void CustomUpdateWUGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, Substitutions &popSubs) const -{ - genCustomUpdate(os, popSubs, *this, "id_syn", - [this](const auto &varRef, const std::string &index) - { - return getVarRefIndex(getVarAccessDuplication(varRef.getVar().access), - index); - }); -} //---------------------------------------------------------------------------- // CustomUpdateTransposeWUGroupMerged //---------------------------------------------------------------------------- const std::string CustomUpdateTransposeWUGroupMerged::name = "CustomUpdateTransposeWU"; -//---------------------------------------------------------------------------- -void CustomUpdateTransposeWUGroupMerged::generateCustomUpdate(const BackendBase&, CodeStream &os, Substitutions &popSubs) const -{ - genCustomUpdate(os, popSubs, *this, "id_syn", - [this](const auto &varRef, const std::string &index) - { - return getVarRefIndex(getVarAccessDuplication(varRef.getVar().access), - index); - }); -} // ---------------------------------------------------------------------------- // CustomUpdateHostReductionGroupMerged From b9998e70983e9b4e0ff4f5ce284611e14ceee1db Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 21 Feb 2023 13:54:42 +0000 Subject: [PATCH 115/725] pretty print environment takes strings rather than tokens for easier integration --- .../genn/genn/code_generator/environment.h | 31 ++++++++----------- include/genn/genn/transpiler/prettyPrinter.h | 6 ++-- src/genn/genn/code_generator/environment.cc | 10 +++--- src/genn/genn/transpiler/prettyPrinter.cc | 24 +++++++------- 4 files changed, 33 insertions(+), 38 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index a3636ad1a8..c515398590 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -10,7 +10,6 @@ // GeNN transpiler includes #include "transpiler/prettyPrinter.h" -#include "transpiler/token.h" #include "transpiler/transpilerUtils.h" //---------------------------------------------------------------------------- @@ -35,7 +34,7 @@ class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase //------------------------------------------------------------------------ // PrettyPrinter::EnvironmentBase virtuals //------------------------------------------------------------------------ - virtual std::string define(const Transpiler::Token&); + virtual std::string define(const std::string &name); protected: 
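Taking plain strings rather than Transpiler::Token lets the same environments be driven from code-generator strings as well as scanned tokens; the pretty printer now simply passes token.lexeme at each call site. A minimal conforming implementation of the reworked interface might look like the sketch below (LocalEnvironment is an assumed stand-in, though the '_' prefixing mirrors the transpiler's EnvironmentInternal):

    #include <iostream>
    #include <set>
    #include <stdexcept>
    #include <string>

    // Simplified version of the reworked interface: names are plain strings,
    // so both the transpiler (via token.lexeme) and the code generator can use it
    class EnvironmentBase
    {
    public:
        virtual ~EnvironmentBase() = default;
        virtual std::string define(const std::string &name) = 0;
        virtual std::string getName(const std::string &name) = 0;
    };

    class LocalEnvironment : public EnvironmentBase
    {
    public:
        std::string define(const std::string &name) override
        {
            if(!m_Locals.insert(name).second) {
                throw std::runtime_error("Redeclaration of variable '" + name + "'");
            }
            return "_" + name;  // prefix to avoid clashes with generated code
        }

        std::string getName(const std::string &name) override
        {
            // Real environments delegate to an enclosing environment here;
            // this simplified version just fails for unknown names
            if(m_Locals.count(name) == 0) {
                throw std::runtime_error("Variable '" + name + "' undefined");
            }
            return "_" + name;
        }

    private:
        std::set<std::string> m_Locals;
    };

    int main()
    {
        LocalEnvironment env;
        std::cout << env.define("idSyn") << std::endl;   // _idSyn
        std::cout << env.getName("idSyn") << std::endl;  // _idSyn
    }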
//------------------------------------------------------------------------ @@ -45,7 +44,7 @@ class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase CodeStream &getContextStream() const; - std::string getContextName(const Transpiler::Token &name) const; + std::string getContextName(const std::string &name) const; private: //------------------------------------------------------------------------ @@ -62,7 +61,12 @@ class EnvironmentSubstitute : public EnvironmentExternal { using EnvironmentBase = Transpiler::PrettyPrinter::EnvironmentBase; public: - EnvironmentSubstitute(EnvironmentBase &enclosing) + EnvironmentSubstitute(EnvironmentSubstitute &enclosing) + : EnvironmentExternal(enclosing), m_Contents(m_ContentsStream) + { + } + + EnvironmentSubstitute(EnvironmentExternal &enclosing) : EnvironmentExternal(enclosing), m_Contents(m_ContentsStream) { } @@ -76,7 +80,7 @@ class EnvironmentSubstitute : public EnvironmentExternal //------------------------------------------------------------------------ // PrettyPrinter::EnvironmentBase virtuals //------------------------------------------------------------------------ - virtual std::string getName(const Transpiler::Token &name) final; + virtual std::string getName(const std::string &name) final; virtual CodeStream &getStream() final { @@ -164,7 +168,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal typedef std::function GetIndexFn; public: - EnvironmentLocalVarCache(const G &group, Transpiler::PrettyPrinter::EnvironmentBase &enclosing, GetIndexFn getIndex, const std::string &localPrefix = "l") + EnvironmentLocalVarCache(const G &group, EnvironmentSubstitute &enclosing, GetIndexFn getIndex, const std::string &localPrefix = "l") : EnvironmentExternal(enclosing), m_Group(group), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) { // Add name of each definition to map, initially with value set to value @@ -173,15 +177,6 @@ class EnvironmentLocalVarCache : public EnvironmentExternal [](const auto &v){ return std::make_pair(v.name, false); }); } - EnvironmentLocalVarCache(const G &group, CodeStream &os, GetIndexFn getIndex, const std::string &localPrefix = "l") - : EnvironmentExternal(os), m_Group(group), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) - { - // Add name of each definition to map, initially with value set to value - const auto defs = A(m_Group).getDefs(); - std::transform(defs.cbegin(), defs.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()), - [](const auto &v){ return std::make_pair(v.name, false); }); - } - ~EnvironmentLocalVarCache() { A adapter(m_Group); @@ -225,10 +220,10 @@ class EnvironmentLocalVarCache : public EnvironmentExternal //------------------------------------------------------------------------ // PrettyPrinter::EnvironmentBase virtuals //------------------------------------------------------------------------ - virtual std::string getName(const Transpiler::Token &name) final + virtual std::string getName(const std::string &name) final { // If variable with this name isn't found, try and get name from context - auto var = m_VariablesReferenced.find(name.lexeme); + auto var = m_VariablesReferenced.find(name); if(var == m_VariablesReferenced.end()) { return getContextName(name); } @@ -238,7 +233,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal var->second = true; // Add local prefix to variable name - return m_LocalPrefix + name.lexeme; + return m_LocalPrefix + name; } } diff --git 
a/include/genn/genn/transpiler/prettyPrinter.h b/include/genn/genn/transpiler/prettyPrinter.h index 5427dc00ca..b9d8d8a74a 100644 --- a/include/genn/genn/transpiler/prettyPrinter.h +++ b/include/genn/genn/transpiler/prettyPrinter.h @@ -26,11 +26,11 @@ class EnvironmentBase //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ - //! Define variable named by token and return the name as it should be used in code - virtual std::string define(const Token &name) = 0; + //! Define named variable and return the name as it should be used in code + virtual std::string define(const std::string &name) = 0; //! Get the name to use in code for the variable named by token - virtual std::string getName(const Token &name) = 0; + virtual std::string getName(const std::string &name) = 0; //! Get stream to write code within this environment to virtual CodeGenerator::CodeStream &getStream() = 0; diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index dd4f3eed9c..adc32d3677 100644 --- a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -11,7 +11,7 @@ using namespace GeNN::CodeGenerator; //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentExternal //---------------------------------------------------------------------------- -std::string EnvironmentExternal::define(const Transpiler::Token&) +std::string EnvironmentExternal::define(const std::string&) { throw std::runtime_error("Cannot declare variable in external environment"); } @@ -25,12 +25,12 @@ CodeStream &EnvironmentExternal::getContextStream() const getContext()); } //---------------------------------------------------------------------------- -std::string EnvironmentExternal::getContextName(const Transpiler::Token &name) const +std::string EnvironmentExternal::getContextName(const std::string &name) const { return std::visit( Transpiler::Utils::Overload{ [&name](std::reference_wrapper enclosing)->std::string { return enclosing.get().getName(name); }, - [&name](std::reference_wrapper)->std::string { throw std::runtime_error("Variable '" + name.lexeme + "' undefined"); }}, + [&name](std::reference_wrapper)->std::string { throw std::runtime_error("Variable '" + name + "' undefined"); }}, getContext()); } @@ -51,10 +51,10 @@ EnvironmentSubstitute::~EnvironmentSubstitute() getContextStream() << m_ContentsStream.str(); } //---------------------------------------------------------------------------- -std::string EnvironmentSubstitute::getName(const Transpiler::Token &name) +std::string EnvironmentSubstitute::getName(const std::string &name) { // If there isn't a substitution for this name, try and get name from context - auto var = m_VarSubstitutions.find(name.lexeme); + auto var = m_VarSubstitutions.find(name); if(var == m_VarSubstitutions.end()) { return getContextName(name); } diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index 372657ea7b..98720e440c 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -37,22 +37,22 @@ class EnvironmentInternal : public EnvironmentBase //--------------------------------------------------------------------------- // EnvironmentBase virtuals //--------------------------------------------------------------------------- - virtual std::string 
define(const Token &name) final + virtual std::string define(const std::string &name) final { - if(!m_LocalVariables.emplace(name.lexeme).second) { + if(!m_LocalVariables.emplace(name).second) { throw std::runtime_error("Redeclaration of variable"); } - return "_" + name.lexeme; + return "_" + name; } - virtual std::string getName(const Token &name) final + virtual std::string getName(const std::string &name) final { - if(m_LocalVariables.find(name.lexeme) == m_LocalVariables.end()) { + if(m_LocalVariables.find(name) == m_LocalVariables.end()) { return m_Enclosing.getName(name); } else { - return "_" + name.lexeme; + return "_" + name; } } @@ -91,14 +91,14 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- virtual void visit(const Expression::ArraySubscript &arraySubscript) final { - m_Environment.get().getStream() << m_Environment.get().getName(arraySubscript.getPointerName()) << "["; + m_Environment.get().getStream() << m_Environment.get().getName(arraySubscript.getPointerName().lexeme) << "["; arraySubscript.getIndex()->accept(*this); m_Environment.get().getStream() << "]"; } virtual void visit(const Expression::Assignment &assignement) final { - m_Environment.get().getStream() << m_Environment.get().getName(assignement.getVarName()) << " " << assignement.getOperator().lexeme << " "; + m_Environment.get().getStream() << m_Environment.get().getName(assignement.getVarName().lexeme) << " " << assignement.getOperator().lexeme << " "; assignement.getValue()->accept(*this); } @@ -170,17 +170,17 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final { - m_Environment.get().getStream() << m_Environment.get().getName(postfixIncDec.getVarName()) << postfixIncDec.getOperator().lexeme; + m_Environment.get().getStream() << m_Environment.get().getName(postfixIncDec.getVarName().lexeme) << postfixIncDec.getOperator().lexeme; } virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final { - m_Environment.get().getStream() << m_Environment.get().getName(prefixIncDec.getOperator()) << prefixIncDec.getVarName().lexeme; + m_Environment.get().getStream() << prefixIncDec.getOperator().lexeme << m_Environment.get().getName(prefixIncDec.getVarName().lexeme); } virtual void visit(const Expression::Variable &variable) final { - m_Environment.get().getStream() << m_Environment.get().getName(variable.getName()); + m_Environment.get().getStream() << m_Environment.get().getName(variable.getName().lexeme); } virtual void visit(const Expression::Unary &unary) final @@ -304,7 +304,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor printType(varDeclaration.getType()); for(const auto &var : varDeclaration.getInitDeclaratorList()) { - m_Environment.get().getStream() << m_Environment.get().define(std::get<0>(var)); + m_Environment.get().getStream() << m_Environment.get().define(std::get<0>(var).lexeme); if(std::get<1>(var)) { m_Environment.get().getStream() << " = "; std::get<1>(var)->accept(*this); From 268b58d3a294102dea5555209f9d3eb891fc4a7f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 21 Feb 2023 13:55:09 +0000 Subject: [PATCH 116/725] single-threaded backend compiling again --- .../backends/single_threaded_cpu/backend.h | 13 +- .../backends/single_threaded_cpu/backend.cc | 150 +++++++++--------- 2 files changed, 85 insertions(+), 78 deletions(-) diff --git 
a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 62a590c1bc..7b3c215269 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -11,6 +11,7 @@ // GeNN code generator includes #include "code_generator/backendBase.h" +#include "code_generator/environment.h" // Forward declarations namespace filesystem @@ -219,20 +220,21 @@ class BACKEND_EXPORT Backend : public BackendBase //! Helper to generate code to copy reduced custom update group variables back to memory /*! Because reduction operations are unnecessary in unbatched single-threaded CPU models so there's no need to actually reduce */ - void genWriteBackReductions(CodeStream &os, const CustomUpdateGroupMerged &cg, const std::string &idx) const; + void genWriteBackReductions(EnvironmentExternal &env, const CustomUpdateGroupMerged &cg, const std::string &idxName) const; //! Helper to generate code to copy reduced custom weight update group variables back to memory /*! Because reduction operations are unnecessary in unbatched single-threaded CPU models so there's no need to actually reduce */ - void genWriteBackReductions(CodeStream &os, const CustomUpdateWUGroupMerged &cg, const std::string &idx) const; + void genWriteBackReductions(EnvironmentExternal &env, const CustomUpdateWUGroupMerged &cg, const std::string &idxName) const; template - void genWriteBackReductions(CodeStream &os, const G &cg, const std::string &idx, R getVarRefIndexFn) const + void genWriteBackReductions(EnvironmentExternal &env, const G &cg, const std::string &idxName, R getVarRefIndexFn) const { const auto *cm = cg.getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { // If variable is a reduction target, copy value from register straight back into global memory if(v.access & VarAccessModeAttribute::REDUCE) { - os << "group->" << v.name << "[" << cg.getVarIndex(getVarAccessDuplication(v.access), idx) << "] = l" << v.name << ";" << std::endl; + const std::string idx = env.getName(idxName); + env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(getVarAccessDuplication(v.access), idx) << "] = l" << v.name << ";" << std::endl; } } @@ -242,7 +244,8 @@ class BACKEND_EXPORT Backend : public BackendBase // If variable reference is a reduction target, copy value from register straight back into global memory if(modelVarRef.access & VarAccessModeAttribute::REDUCE) { - os << "group->" << modelVarRef.name << "[" << getVarRefIndexFn(varRef, idx) << "] = l" << modelVarRef.name << ";" << std::endl; + const std::string idx = env.getName(idxName); + env.getStream() << "group->" << modelVarRef.name << "[" << getVarRefIndexFn(varRef, idx) << "] = l" << modelVarRef.name << ";" << std::endl; } } } diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index c38a27dd72..506ecb0368 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -75,7 +75,7 @@ const std::vector &getFunctionTemplates(const s } //----------------------------------------------------------------------- template -void genKernelIteration(PrettyPrinter::EnvironmentBase &env, const G &g, size_t numKernelDims, BackendBase::Handler handler) +void genKernelIteration(EnvironmentExternal &env, const G &g, size_t numKernelDims, std::function/*BackendBase::Handler*/ handler) { EnvironmentSubstitute varEnv(env); @@ -99,8 +99,9 @@ void 
genKernelIteration(PrettyPrinter::EnvironmentBase &env, const G &g, size_t if (depth == (numKernelDims - 1)) { // Generate kernel index and use as "synapse" index // **TODO** rename - const size_t kernelInit = loopEnv.addInitialiser("const unsigned int kernelInd = " + g.genKernelIndex(loopEnv) + ";"); - loopEnv.addVarSubstitution("id_syn", "kernelInd", kernelInit); + assert(false); + //const size_t kernelInit = loopEnv.addInitialiser("const unsigned int kernelInd = " + g.genKernelIndex(loopEnv) + ";"); + //loopEnv.addVarSubstitution("id_syn", "kernelInd", kernelInit); // Call handler handler(loopEnv); @@ -489,24 +490,24 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge } } //-------------------------------------------------------------------------- -void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genCustomUpdate(CodeStream &os_, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); // Generate struct definitions - modelMerged.genMergedCustomUpdateStructs(os, *this); - modelMerged.genMergedCustomUpdateWUStructs(os, *this); - modelMerged.genMergedCustomUpdateTransposeWUStructs(os, *this); - modelMerged.genMergedCustomConnectivityUpdateStructs(os, *this); + modelMerged.genMergedCustomUpdateStructs(os_, *this); + modelMerged.genMergedCustomUpdateWUStructs(os_, *this); + modelMerged.genMergedCustomUpdateTransposeWUStructs(os_, *this); + modelMerged.genMergedCustomConnectivityUpdateStructs(os_, *this); // Generate arrays of merged structs and functions to set them - genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateWUGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateTransposeWUGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdateGroups()); + genMergedStructArrayPush(os_, modelMerged.getMergedCustomUpdateGroups()); + genMergedStructArrayPush(os_, modelMerged.getMergedCustomUpdateWUGroups()); + genMergedStructArrayPush(os_, modelMerged.getMergedCustomUpdateTransposeWUGroups()); + genMergedStructArrayPush(os_, modelMerged.getMergedCustomConnectivityUpdateGroups()); // Generate preamble - preambleHandler(os); + preambleHandler(os_); // Build set containing union of all custom update groupsnames std::set customUpdateGroups; @@ -522,23 +523,24 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged // Loop through custom update groups for(const auto &g : customUpdateGroups) { - os << "void update" << g << "()"; + os_ << "void update" << g << "()"; { - CodeStream::Scope b(os); + CodeStream::Scope b(os_); - EnvironmentSubstitute funcEnv(os); + EnvironmentSubstitute funcEnv(os_); funcEnv.addSubstitution("t", "t"); funcEnv.addSubstitution("batch", "0"); // Loop through host update groups and generate code for those in this custom update group for (const auto &cg : modelMerged.getMergedCustomConnectivityHostUpdateGroups()) { if (cg.getArchetype().getUpdateGroupName() == g) { - cg.generateUpdate(*this, os); + assert(false); + //cg.generateUpdate(*this, os); } } { - Timer t(os, "customUpdate" + g, model.isTimingEnabled()); + Timer t(funcEnv.getStream(), "customUpdate" + g, model.isTimingEnabled()); // Loop through merged custom update groups for(const auto &c : modelMerged.getMergedCustomUpdateGroups()) { @@ -547,25 +549,25 @@ void 
Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged continue; } - CodeStream::Scope b(os); - os << "// merged custom update group " << c.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged custom update group " << c.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); // Get reference to group - os << "const auto *group = &mergedCustomUpdateGroup" << c.getIndex() << "[g]; " << std::endl; + funcEnv.getStream() << "const auto *group = &mergedCustomUpdateGroup" << c.getIndex() << "[g]; " << std::endl; - genCustomUpdateIndexCalculation(os, c); + genCustomUpdateIndexCalculation(funcEnv.getStream(), c); if (c.getArchetype().isNeuronReduction()) { // Initialise reduction targets - const auto reductionTargets = genInitReductionTargets(os, c); + const auto reductionTargets = genInitReductionTargets(funcEnv.getStream(), c); // Loop through group members - os << "for(unsigned int i = 0; i < group->size; i++)"; + funcEnv.getStream() << "for(unsigned int i = 0; i < group->size; i++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); // Generate custom update EnvironmentSubstitute env(funcEnv); @@ -574,20 +576,20 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged // Loop through reduction targets and generate reduction for (const auto &r : reductionTargets) { - os << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; + env.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; } } // Write back reductions for (const auto &r : reductionTargets) { - os << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; + funcEnv.getStream() << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; } } else { // Loop through group members - os << "for(unsigned int i = 0; i < group->size; i++)"; + funcEnv.getStream() << "for(unsigned int i = 0; i < group->size; i++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); // Generate custom update EnvironmentSubstitute env(funcEnv); @@ -595,7 +597,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged c.generateCustomUpdate(*this, env); // Write back reductions - genWriteBackReductions(os, c, "i"); + genWriteBackReductions(env, c, "id"); } } } @@ -608,47 +610,46 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged continue; } - CodeStream::Scope b(os); - os << "// merged custom WU update group " << c.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged custom WU update group " << c.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); // Get reference to group - os << "const auto *group = &mergedCustomUpdateWUGroup" << c.getIndex() << "[g]; " << std::endl; + funcEnv.getStream() << "const auto *group = &mergedCustomUpdateWUGroup" << c.getIndex() << "[g]; " << std::endl; const 
SynapseGroupInternal *sg = c.getArchetype().getSynapseGroup(); EnvironmentSubstitute synEnv(funcEnv); if (sg->getMatrixType() & SynapseMatrixWeight::KERNEL) { - genKernelIteration(c, c.getArchetype().getSynapseGroup()->getKernelSize().size(), synEnv, - [&c, &modelMerged, this] - (PrettyPrinter::EnvironmentBase &env) + genKernelIteration(synEnv, c, c.getArchetype().getSynapseGroup()->getKernelSize().size(), + [&c, this](EnvironmentExternal &env) { // Call custom update handler c.generateCustomUpdate(*this, env); // Write back reductions - genWriteBackReductions(env, c); + genWriteBackReductions(env, c, "id_syn"); }); } else { // Loop through presynaptic neurons - os << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; + synEnv.getStream() << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; { // If this synapse group has sparse connectivity, loop through length of this row - CodeStream::Scope b(os); + CodeStream::Scope b(synEnv.getStream()); if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << "for(unsigned int s = 0; s < group->rowLength[i]; s++)"; + synEnv.getStream() << "for(unsigned int s = 0; s < group->rowLength[i]; s++)"; } // Otherwise, if it's dense, loop through each postsynaptic neuron else if (sg->getMatrixType() & SynapseMatrixConnectivity::DENSE) { - os << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)"; + synEnv.getStream() << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)"; } else { throw std::runtime_error("Only DENSE and SPARSE format connectivity can be used for custom updates"); } { - CodeStream::Scope b(os); + CodeStream::Scope b(synEnv.getStream()); // Add presynaptic index to substitutions synEnv.addSubstitution("id_pre", "i"); @@ -674,7 +675,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged c.generateCustomUpdate(*this, synEnv); // Write back reductions - genWriteBackReductions(os, c, synEnv["id_syn"]); + genWriteBackReductions(synEnv, c, "id_syn"); } } } @@ -688,28 +689,29 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged continue; } - CodeStream::Scope b(os); - os << "// merged custom connectivity update group " << c.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged custom connectivity update group " << c.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); // Get reference to group - os << "const auto *group = &mergedCustomConnectivityUpdateGroup" << c.getIndex() << "[g]; " << std::endl; + funcEnv.getStream() << "const auto *group = &mergedCustomConnectivityUpdateGroup" << c.getIndex() << "[g]; " << std::endl; - genCustomConnectivityUpdateIndexCalculation(os, c); + genCustomConnectivityUpdateIndexCalculation(funcEnv.getStream(), c); // Loop through presynaptic neurons - os << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; + funcEnv.getStream() << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); // Configure substitutions EnvironmentSubstitute cuEnv(funcEnv); cuEnv.addSubstitution("id_pre", "i"); cuEnv.addSubstitution("rng", "hostRNG"); - c.generateUpdate(*this, os, model.getBatchSize(), popSubs); + assert(false); + //c.generateUpdate(*this, cuEnv, model.getBatchSize()); } } } @@ -717,21 
+719,21 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged // Loop through merged custom WU transpose update groups { - Timer t(os, "customUpdate" + g + "Transpose", model.isTimingEnabled()); + Timer t(funcEnv.getStream(), "customUpdate" + g + "Transpose", model.isTimingEnabled()); for(const auto &c : modelMerged.getMergedCustomUpdateTransposeWUGroups()) { // If this update group isn't for current group, skip if(c.getArchetype().getUpdateGroupName() != g) { continue; } - CodeStream::Scope b(os); - os << "// merged custom WU transpose update group " << c.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged custom WU transpose update group " << c.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); // Get reference to group - os << "const auto *group = &mergedCustomUpdateTransposeWUGroup" << c.getIndex() << "[g]; " << std::endl; + funcEnv.getStream() << "const auto *group = &mergedCustomUpdateTransposeWUGroup" << c.getIndex() << "[g]; " << std::endl; // Get index of variable being transposed const size_t transposeVarIdx = std::distance(c.getArchetype().getVarReferences().cbegin(), @@ -740,14 +742,14 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged const std::string transposeVarName = c.getArchetype().getCustomUpdateModel()->getVarRefs().at(transposeVarIdx).name; // Loop through presynaptic neurons - os << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; + funcEnv.getStream() << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); // Loop through each postsynaptic neuron - os << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)"; + funcEnv.getStream() << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); // Add pre and postsynaptic indices to environment EnvironmentSubstitute synEnv(funcEnv); @@ -762,7 +764,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged c.generateCustomUpdate(*this, synEnv); // Update transpose variable - os << "group->" << transposeVarName << "Transpose[(j * group->numSrcNeurons) + i] = l" << transposeVarName << ";" << std::endl; + synEnv.getStream() << "group->" << transposeVarName << "Transpose[(j * group->numSrcNeurons) + i] = l" << transposeVarName << ";" << std::endl; } } @@ -1431,12 +1433,14 @@ void Backend::genDenseSynapseVariableRowInit(CodeStream &os, const Substitutions //-------------------------------------------------------------------------- void Backend::genKernelSynapseVariableInit(CodeStream &os, const SynapseInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const { - genKernelIteration(os, sg, sg.getArchetype().getKernelSize().size(), kernelSubs, handler); + assert(false); + //genKernelIteration(os, sg, sg.getArchetype().getKernelSize().size(), kernelSubs, handler); } //-------------------------------------------------------------------------- void Backend::genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUpdateInitGroupMerged &cu, const Substitutions &kernelSubs, Handler handler) const { - genKernelIteration(os, cu, cu.getArchetype().getSynapseGroup()->getKernelSize().size(), 
kernelSubs, handler); + assert(false); + //genKernelIteration(os, cu, cu.getArchetype().getSynapseGroup()->getKernelSize().size(), kernelSubs, handler); } //-------------------------------------------------------------------------- void Backend::genGlobalDeviceRNG(CodeStream&, CodeStream&, CodeStream&, CodeStream&, CodeStream&, const Type::TypeContext&, MemAlloc&) const @@ -1920,9 +1924,9 @@ void Backend::genEmitSpike(CodeStream &os, const NeuronUpdateGroupMerged &ng, co } } //-------------------------------------------------------------------------- -void Backend::genWriteBackReductions(CodeStream &os, const CustomUpdateGroupMerged &cg, const std::string &idx) const +void Backend::genWriteBackReductions(EnvironmentExternal &env, const CustomUpdateGroupMerged &cg, const std::string &idxName) const { - genWriteBackReductions(os, cg, idx, + genWriteBackReductions(env, cg, idxName, [&cg](const Models::VarReference &varRef, const std::string &index) { return cg.getVarRefIndex(varRef.getDelayNeuronGroup() != nullptr, @@ -1931,9 +1935,9 @@ void Backend::genWriteBackReductions(CodeStream &os, const CustomUpdateGroupMerg }); } //-------------------------------------------------------------------------- -void Backend::genWriteBackReductions(CodeStream &os, const CustomUpdateWUGroupMerged &cg, const std::string &idx) const +void Backend::genWriteBackReductions(EnvironmentExternal &env, const CustomUpdateWUGroupMerged &cg, const std::string &idxName) const { - genWriteBackReductions(os, cg, idx, + genWriteBackReductions(env, cg, idxName, [&cg](const Models::WUVarReference &varRef, const std::string &index) { return cg.getVarRefIndex(getVarAccessDuplication(varRef.getVar().access), From 2f7af8308216a522b0369453633434564ac5d136 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 21 Feb 2023 17:49:14 +0000 Subject: [PATCH 117/725] CPU only GeNN now completely working with environments for custom updates --- .../genn/genn/code_generator/backendBase.h | 4 + .../genn/genn/code_generator/backendSIMT.h | 93 +++++- .../genn/genn/code_generator/environment.h | 2 +- src/genn/genn/code_generator/backendSIMT.cc | 280 +++++++++--------- src/genn/genn/code_generator/environment.cc | 1 + 5 files changed, 236 insertions(+), 144 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index cec6681eb0..c36505c0f7 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -34,6 +34,7 @@ class SynapseGroupInternal; namespace CodeGenerator { +class EnvironmentExternal; class ModelSpecMerged; class NeuronUpdateGroupMerged; class Substitutions; @@ -187,6 +188,9 @@ class GENN_EXPORT BackendBase template using GroupHandler = std::function ; + + template + using GroupHandlerEnv = std::function ; //! 
Vector of prefixes required to allocate in memory space and size of memory space typedef std::vector> MemorySpaces; diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 13a87f949e..68669dd8ec 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -201,13 +201,13 @@ class GENN_EXPORT BackendSIMT : public BackendBase void genPostsynapticUpdateKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t &idStart) const; void genSynapseDynamicsKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t &idStart) const; - void genCustomUpdateKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, + void genCustomUpdateKernel(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const; - void genCustomUpdateWUKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, + void genCustomUpdateWUKernel(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const; - void genCustomTransposeUpdateWUKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, + void genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const; void genCustomConnectivityUpdateKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, @@ -320,6 +320,93 @@ class GENN_EXPORT BackendSIMT : public BackendBase genParallelGroup(os, kernelSubs, groups, idStart, getPaddedSizeFunc, [](const T &) { return true; }, handler); } + + template + void genParallelGroup(EnvironmentExternal &env, const std::vector &groups, size_t &idStart, + S getPaddedSizeFunc, F filter, GroupHandlerEnv handler) const + { + // Loop through groups + for(const auto &gMerge : groups) { + if(filter(gMerge)) { + // Sum padded sizes of each group within merged group + const size_t paddedSize = std::accumulate( + gMerge.getGroups().cbegin(), gMerge.getGroups().cend(), size_t{0}, + [getPaddedSizeFunc](size_t acc, std::reference_wrapper g) + { + return (acc + getPaddedSizeFunc(g.get())); + }); + + env.getStream() << "// merged" << gMerge.getIndex() << std::endl; + + // If this is the first group + if(idStart == 0) { + env.getStream() << "if(id < " << paddedSize << ")"; + } + else { + env.getStream() << "if(id >= " << idStart << " && id < " << idStart + paddedSize << ")"; + } + { + CodeStream::Scope b(env.getStream()); + EnvironmentSubstitute popEnv(env); + + if(gMerge.getGroups().size() == 1) { + popEnv.getStream() << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group"; + popEnv.getStream() << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[0]; " << std::endl; + popEnv.getStream() << "const unsigned int lid = id - " << idStart << ";" << std::endl; + + // Use the starting thread ID of the whole merged group as group_start_id + popEnv.addSubstitution("group_start_id", std::to_string(idStart)); + } + else { + // Perform bisect operation to get index of merged struct + popEnv.getStream() << "unsigned int lo = 0;" << std::endl; + popEnv.getStream() << "unsigned int hi = " << gMerge.getGroups().size() << ";" << std::endl; + popEnv.getStream() << "while(lo < hi)" << std::endl; + { + 
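+                        // A sketch of the device code the statements below emit (with "X"
+                        // standing in for the merged group name): each thread binary-searches
+                        // the sorted d_mergedXGroupStartID array for its global id:
+                        //     while(lo < hi) {
+                        //         const unsigned int mid = (lo + hi) / 2;
+                        //         if(id < d_mergedXGroupStartID[mid]) hi = mid;
+                        //         else lo = mid + 1;
+                        //     }
+                        // leaving lo - 1 as the index of the group whose start ID is <= id.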
CodeStream::Scope b(popEnv.getStream()); + popEnv.getStream() << "const unsigned int mid = (lo + hi) / 2;" << std::endl; + + popEnv.getStream() << "if(id < d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[mid])"; + { + CodeStream::Scope b(popEnv.getStream()); + popEnv.getStream() << "hi = mid;" << std::endl; + } + popEnv.getStream() << "else"; + { + CodeStream::Scope b(popEnv.getStream()); + popEnv.getStream() << "lo = mid + 1;" << std::endl; + } + } + + // Use this to get reference to merged group structure + popEnv.getStream() << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group"; + popEnv.getStream() << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[lo - 1]; " << std::endl; + + // Get group start thread ID and use as group_start_id + popEnv.getStream() << "const unsigned int groupStartID = d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[lo - 1];" << std::endl; + popEnv.addSubstitution("group_start_id", "groupStartID"); + + // Use this to calculate local id within group + popEnv.getStream() << "const unsigned int lid = id - groupStartID;" << std::endl; + } + popEnv.addSubstitution("id", "lid"); + + handler(popEnv, gMerge); + + idStart += paddedSize; + } + } + } + } + + + template + void genParallelGroup(EnvironmentExternal &env, const std::vector &groups, size_t &idStart, + S getPaddedSizeFunc, GroupHandlerEnv handler) const + { + genParallelGroup(env, groups, idStart, getPaddedSizeFunc, + [](const T &) { return true; }, handler); + } // Helper function to generate kernel code to initialise variables associated with synapse group or custom WU update with dense/kernel connectivity template diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index c515398590..f01fb51a3f 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -168,7 +168,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal typedef std::function GetIndexFn; public: - EnvironmentLocalVarCache(const G &group, EnvironmentSubstitute &enclosing, GetIndexFn getIndex, const std::string &localPrefix = "l") + EnvironmentLocalVarCache(const G &group, EnvironmentExternal &enclosing, GetIndexFn getIndex, const std::string &localPrefix = "l") : EnvironmentExternal(enclosing), m_Group(group), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) { // Add name of each definition to map, initially with value set to value diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 92d55cbaf2..0543ed30f4 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -896,91 +896,91 @@ void BackendSIMT::genSynapseDynamicsKernel(CodeStream &os, const Substitutions & }); } //-------------------------------------------------------------------------- -void BackendSIMT::genCustomUpdateKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, +void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const { genParallelGroup( - os, kernelSubs, modelMerged.getMergedCustomUpdateGroups(), idStart, + env, modelMerged.getMergedCustomUpdateGroups(), idStart, [&modelMerged, this](const CustomUpdateInternal &cu) { return getPaddedNumCustomUpdateThreads(cu, modelMerged.getModel().getBatchSize()); 
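            // Note: the padded size summed here is, presumably, each group's size rounded
            // up to a whole number of kernel blocks, so every merged group starts on a
            // block boundary and the id-range checks emitted by genParallelGroup stay cheap.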
}, [&updateGroup](const CustomUpdateGroupMerged &cg) { return (cg.getArchetype().getUpdateGroupName() == updateGroup); }, - [&modelMerged, this](CodeStream &os, const CustomUpdateGroupMerged &cg, Substitutions &popSubs) + [&modelMerged, this](EnvironmentExternal &env, const CustomUpdateGroupMerged &cg) { const size_t blockSize = getKernelBlockSize(KernelCustomUpdate); const unsigned int batchSize = modelMerged.getModel().getBatchSize(); // If update is a batch reduction - Substitutions cuSubs(&popSubs); + EnvironmentSubstitute cuEnv(env); if(cg.getArchetype().isBatchReduction()) { - os << "// only do this for existing neurons" << std::endl; - os << "if(" << cuSubs["id"] << " < group->size)"; + cuEnv.getStream() << "// only do this for existing neurons" << std::endl; + cuEnv.getStream() << "if(" << cuEnv.getName("id") << " < group->size)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(cuEnv.getStream()); // Initialise reduction targets - const auto reductionTargets = genInitReductionTargets(os, cg, cuSubs["id"]); + const auto reductionTargets = genInitReductionTargets(cuEnv.getStream(), cg, cuEnv.getName("id")); // Loop through batches // **TODO** this naive approach is good for reduction when there are lots of neurons/synapses but, // if this isn't the case (TF uses a threshold of 4096), we should do something smarter - os << "for(unsigned int batch = 0; batch < " << batchSize << "; batch++)"; + cuEnv.getStream() << "for(unsigned int batch = 0; batch < " << batchSize << "; batch++)"; { - CodeStream::Scope b(os); - cuSubs.addVarSubstitution("batch", "batch"); + CodeStream::Scope b(cuEnv.getStream()); + cuEnv.addSubstitution("batch", "batch"); - genCustomUpdateIndexCalculation(os, cg); + genCustomUpdateIndexCalculation(cuEnv.getStream(), cg); // **THINK** it would be great to 'lift' reads of SHARED variables out of this loop - cg.generateCustomUpdate(*this, os, cuSubs); + cg.generateCustomUpdate(*this, cuEnv); // Loop through reduction targets and generate reduction for(const auto &r : reductionTargets) { - os << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; + cuEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; } } // Loop through reduction targets and write reduced value back to memory for(const auto &r : reductionTargets) { - os << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; + cuEnv.getStream() << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; } } } // Otherwise, if this is a neuron reduction else if (cg.getArchetype().isNeuronReduction()) { - os << "// only do this for existing neurons" << std::endl; - os << "if(" << cuSubs["id"] << " < " << (32 * modelMerged.getModel().getBatchSize()) << ")"; + cuEnv.getStream() << "// only do this for existing neurons" << std::endl; + cuEnv.getStream() << "if(" << cuEnv.getName("id") << " < " << (32 * modelMerged.getModel().getBatchSize()) << ")"; { - CodeStream::Scope b(os); + CodeStream::Scope b(cuEnv.getStream()); // Split ID into lane and batch - os << "const unsigned int lane = " << cuSubs["id"] << " % 32;" << std::endl; - os << "const unsigned int batch = " << cuSubs["id"] << " / 32;" << std::endl; - cuSubs.addVarSubstitution("batch", "batch"); + cuEnv.getStream() << "const unsigned int lane = " << cuEnv.getName("id") << " % 32;" << std::endl; + cuEnv.getStream() << "const unsigned int batch = " << 
cuEnv.getName("id") << " / 32;" << std::endl; + cuEnv.addSubstitution("batch", "batch"); - genCustomUpdateIndexCalculation(os, cg); + genCustomUpdateIndexCalculation(cuEnv.getStream(), cg); // Initialise reduction targets - const auto reductionTargets = genInitReductionTargets(os, cg); + const auto reductionTargets = genInitReductionTargets(cuEnv.getStream(), cg); // Loop through warps of data // **TODO** this approach is good for reductions where there are small numbers of neurons but large batches sizes but, // if this isn't the case (TF uses a threshold of 1024), we should do something smarter - os << "for(unsigned int idx = lane; idx < group->size; idx += 32)"; + cuEnv.getStream() << "for(unsigned int idx = lane; idx < group->size; idx += 32)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(cuEnv.getStream()); // Re-substitute id with loop index - Substitutions reductionSubs(&cuSubs); - reductionSubs.addVarSubstitution("id", "idx", true); + EnvironmentSubstitute reductionEnv(cuEnv); + reductionEnv.addSubstitution("id", "idx"); // **THINK** it would be great to 'lift' reads of NEURON_SHARED variables out of this loop - cg.generateCustomUpdate(*this, os, reductionSubs); + cg.generateCustomUpdate(*this, reductionEnv); // Loop through reduction targets and generate reduction for (const auto &r : reductionTargets) { - os << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; + reductionEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; } } @@ -988,17 +988,17 @@ void BackendSIMT::genCustomUpdateKernel(CodeStream &os, const Substitutions &ker // **YUCK** CUDA-specific for (unsigned int i = 16; i > 0; i /= 2) { for (const auto &r : reductionTargets) { - os << getReductionOperation("lr" + r.name, "__shfl_down_sync(0xFFFFFFFF, lr" + r.name + ", " + std::to_string(i) + ")", - r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; + cuEnv.getStream() << getReductionOperation("lr" + r.name, "__shfl_down_sync(0xFFFFFFFF, lr" + r.name + ", " + std::to_string(i) + ")", + r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; } } // In first lane, loop through reduction targets and write reduced value back to memory - os << "if(lane == 0)"; + cuEnv.getStream() << "if(lane == 0)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(cuEnv.getStream()); for (const auto &r : reductionTargets) { - os << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; + cuEnv.getStream() << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; } } } @@ -1008,26 +1008,26 @@ void BackendSIMT::genCustomUpdateKernel(CodeStream &os, const Substitutions &ker if(cg.getArchetype().isBatched()) { // Split ID into intra-batch ID and batch // **TODO** fast-divide style optimisations here - os << "const unsigned int paddedSize = " << blockSize << " * ((group->size + " << blockSize << " - 1) / " << blockSize << ");" << std::endl; - os << "const unsigned int bid = " << cuSubs["id"] << " % paddedSize;" << std::endl; - os << "const unsigned int batch = " << cuSubs["id"] << " / paddedSize;" << std::endl; + cuEnv.getStream() << "const unsigned int paddedSize = " << blockSize << " * ((group->size + " << blockSize << " - 1) / " << blockSize << ");" << std::endl; + cuEnv.getStream() << "const unsigned int bid = " << cuEnv.getName("id") << " % paddedSize;" << std::endl; + cuEnv.getStream() << "const 
unsigned int batch = " << cuEnv.getName("id") << " / paddedSize;" << std::endl; // Replace id in substitution with intra-batch ID and add batch - cuSubs.addVarSubstitution("id", "bid", true); - cuSubs.addVarSubstitution("batch", "batch"); + cuEnv.addSubstitution("id", "bid"); + cuEnv.addSubstitution("batch", "batch"); } // Otherwise, just substitute "batch" for 0 else { - cuSubs.addVarSubstitution("batch", "0"); + cuEnv.addSubstitution("batch", "0"); } - os << "// only do this for existing neurons" << std::endl; - os << "if(" << cuSubs["id"] << " < group->size)"; + cuEnv.getStream() << "// only do this for existing neurons" << std::endl; + cuEnv.getStream() << "if(" << cuEnv.getName("id") << " < group->size)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(cuEnv.getStream()); - genCustomUpdateIndexCalculation(os, cg); - cg.generateCustomUpdate(*this, os, cuSubs); + genCustomUpdateIndexCalculation(cuEnv.getStream(), cg); + cg.generateCustomUpdate(*this, cuEnv); } } @@ -1035,17 +1035,17 @@ void BackendSIMT::genCustomUpdateKernel(CodeStream &os, const Substitutions &ker }); } //-------------------------------------------------------------------------- -void BackendSIMT::genCustomUpdateWUKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, +void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const { genParallelGroup( - os, kernelSubs, modelMerged.getMergedCustomUpdateWUGroups(), idStart, + env, modelMerged.getMergedCustomUpdateWUGroups(), idStart, [&modelMerged, this](const CustomUpdateWUInternal &cg) { return getPaddedNumCustomUpdateWUThreads(cg, modelMerged.getModel().getBatchSize()); }, [&updateGroup](const CustomUpdateWUGroupMerged &cg) { return (cg.getArchetype().getUpdateGroupName() == updateGroup); }, - [&modelMerged, this](CodeStream &os, const CustomUpdateWUGroupMerged &cg, Substitutions &popSubs) + [&modelMerged, this](EnvironmentExternal &env, const CustomUpdateWUGroupMerged &cg) { const SynapseGroupInternal *sg = cg.getArchetype().getSynapseGroup(); const size_t blockSize = getKernelBlockSize(KernelCustomUpdate); @@ -1054,132 +1054,132 @@ void BackendSIMT::genCustomUpdateWUKernel(CodeStream &os, const Substitutions &k // Calculate size of each batch to update if (sg->getMatrixType() & SynapseMatrixWeight::KERNEL) { // Loop through kernel dimensions and multiply together - os << "const unsigned int size = "; + env.getStream() << "const unsigned int size = "; for (size_t i = 0; i < sg->getKernelSize().size(); i++) { - os << cg.getKernelSize(i); + env.getStream() << cg.getKernelSize(i); if (i != (sg->getKernelSize().size() - 1)) { - os << " * "; + env.getStream() << " * "; } } - os << ";" << std::endl; + env.getStream() << ";" << std::endl; } else { - os << "const unsigned int size = group->numSrcNeurons * group->rowStride;" << std::endl; + env.getStream() << "const unsigned int size = group->numSrcNeurons * group->rowStride;" << std::endl; } // If update isn't a batch reduction - Substitutions cuSubs(&popSubs); + EnvironmentSubstitute cuEnv(env); if(!cg.getArchetype().isBatchReduction()) { // If it's batched if(cg.getArchetype().isBatched()) { - os << "const unsigned int paddedSize = " << blockSize << " * ((size + " << blockSize << " - 1) / " << blockSize << ");" << std::endl; + cuEnv.getStream() << "const unsigned int paddedSize = " << blockSize << " * ((size + " << blockSize << " - 1) / " << blockSize << ");" << std::endl; // Split ID into 
intra-batch ID and batch // **TODO** fast-divide style optimisations here - os << "const unsigned int bid = " << cuSubs["id"] << " % paddedSize;" << std::endl; - os << "const unsigned int batch = " << cuSubs["id"] << " / paddedSize;" << std::endl; + cuEnv.getStream() << "const unsigned int bid = " << cuEnv.getName("id") << " % paddedSize;" << std::endl; + cuEnv.getStream() << "const unsigned int batch = " << cuEnv.getName("id") << " / paddedSize;" << std::endl; // Replace id in substitution with intra-batch ID and add batch - cuSubs.addVarSubstitution("id", "bid", true); - cuSubs.addVarSubstitution("batch", "batch"); + cuEnv.addSubstitution("id", "bid"); + cuEnv.addSubstitution("batch", "batch"); // Calculate batch offset - os << "const unsigned int batchOffset = size * batch;" << std::endl; + cuEnv.getStream() << "const unsigned int batchOffset = size * batch;" << std::endl; } // Otherwise, just substitute "batch" for 0 else { - cuSubs.addVarSubstitution("batch", "0"); + cuEnv.addSubstitution("batch", "0"); } } // if this isn't a padding thread - os << "if (" << cuSubs["id"] << " < size)"; + cuEnv.getStream() << "if (" << cuEnv.getName("id") << " < size)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(cuEnv.getStream()); if (sg->getMatrixType() & SynapseMatrixWeight::KERNEL) { - cuSubs.addVarSubstitution("id_syn", cuSubs["id"]); - cuSubs.addVarSubstitution("id_kernel", cuSubs["id"]); + cuEnv.addSubstitution("id_syn", cuEnv.getName("id")); + cuEnv.addSubstitution("id_kernel", cuEnv.getName("id")); } else { if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // **OPTIMIZE** we can do a fast constant divide optimization here and use the result to calculate the remainder - os << "const unsigned int row = " << cuSubs["id"] << " / group->rowStride;" << std::endl; - os << "const unsigned int col = " << cuSubs["id"] << " % group->rowStride;" << std::endl; + cuEnv.getStream() << "const unsigned int row = " << cuEnv.getName("id") << " / group->rowStride;" << std::endl; + cuEnv.getStream() << "const unsigned int col = " << cuEnv.getName("id") << " % group->rowStride;" << std::endl; - cuSubs.addVarSubstitution("id_pre", "row"); - cuSubs.addVarSubstitution("id_post", "group->ind[" + cuSubs["id"] + "]"); - cuSubs.addVarSubstitution("id_syn", cuSubs["id"]); + cuEnv.addSubstitution("id_pre", "row"); + cuEnv.addSubstitution("id_post", "group->ind[" + cuEnv.getName("id") + "]"); + cuEnv.addSubstitution("id_syn", cuEnv.getName("id")); - os << "if(col < group->rowLength[row])"; - os << CodeStream::OB(2); + cuEnv.getStream() << "if(col < group->rowLength[row])"; + cuEnv.getStream() << CodeStream::OB(2); } else { // **OPTIMIZE** we can do a fast constant divide optimization here and use the result to calculate the remainder - cuSubs.addVarSubstitution("id_pre", "(" + cuSubs["id"] + " / group->rowStride)"); - cuSubs.addVarSubstitution("id_post", "(" + cuSubs["id"] + " % group->rowStride)"); - cuSubs.addVarSubstitution("id_syn", cuSubs["id"]); + cuEnv.addSubstitution("id_pre", "(" + cuEnv.getName("id") + " / group->rowStride)"); + cuEnv.addSubstitution("id_post", "(" + cuEnv.getName("id") + " % group->rowStride)"); + cuEnv.addSubstitution("id_syn", cuEnv.getName("id")); } } // Initialise reduction targets - const auto reductionTargets = genInitReductionTargets(os, cg, cuSubs["id_syn"]); + const auto reductionTargets = genInitReductionTargets(cuEnv.getStream(), cg, cuEnv.getName("id_syn")); // If this is a reduction if(cg.getArchetype().isBatchReduction()) { // Loop through batches // **TODO**
this naive approach is good for reduction when there are lots of neurons/synapses but, // if this isn't the case (TF uses a threshold of 4096), we should do something smarter - os << "for(unsigned int batch = 0; batch < " << batchSize << "; batch++)"; - os << CodeStream::OB(1); - cuSubs.addVarSubstitution("batch", "batch"); + cuEnv.getStream() << "for(unsigned int batch = 0; batch < " << batchSize << "; batch++)"; + cuEnv.getStream() << CodeStream::OB(1); + cuEnv.addSubstitution("batch", "batch"); } // Calculate batch offset if required if(cg.getArchetype().isBatched()) { - os << "const unsigned int batchOffset = size * batch;" << std::endl; + cuEnv.getStream() << "const unsigned int batchOffset = size * batch;" << std::endl; } - cg.generateCustomUpdate(*this, os, cuSubs); + cg.generateCustomUpdate(*this, cuEnv); // If this is a reduction if(cg.getArchetype().isBatchReduction()) { // Loop through reduction targets and generate reduction for(const auto &r : reductionTargets) { - os << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; + cuEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; } // End for loop through batches - os << CodeStream::CB(1); + cuEnv.getStream() << CodeStream::CB(1); // Loop through reduction targets and write reduced value back to memory for(const auto &r : reductionTargets) { - os << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; + cuEnv.getStream() << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; } } if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << CodeStream::CB(2); + cuEnv.getStream() << CodeStream::CB(2); } } }); } //-------------------------------------------------------------------------- -void BackendSIMT::genCustomTransposeUpdateWUKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, +void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const { // Generate 2D array const size_t blockSize = getKernelBlockSize(KernelCustomTransposeUpdate); - os << getSharedPrefix() << " float shTile[" << blockSize << "][" << (blockSize + 1) << "];" << std::endl; + env.getStream() << getSharedPrefix() << " float shTile[" << blockSize << "][" << (blockSize + 1) << "];" << std::endl; genParallelGroup( - os, kernelSubs, modelMerged.getMergedCustomUpdateTransposeWUGroups(), idStart, + env, modelMerged.getMergedCustomUpdateTransposeWUGroups(), idStart, [&modelMerged, this](const CustomUpdateWUInternal &cg) { return getPaddedNumCustomUpdateTransposeWUThreads(cg, modelMerged.getModel().getBatchSize()); }, [&updateGroup](const CustomUpdateTransposeWUGroupMerged &cg) { return (cg.getArchetype().getUpdateGroupName() == updateGroup); }, - [&modelMerged, this, blockSize](CodeStream &os, const CustomUpdateTransposeWUGroupMerged &cg, Substitutions &popSubs) + [&modelMerged, this, blockSize](EnvironmentExternal &env, const CustomUpdateTransposeWUGroupMerged &cg) { // Get index of variable being transposed const size_t transposeVarIdx = std::distance(cg.getArchetype().getVarReferences().cbegin(), @@ -1188,95 +1188,95 @@ void BackendSIMT::genCustomTransposeUpdateWUKernel(CodeStream &os, const Substit const std::string transposeVarName = 
cg.getArchetype().getCustomUpdateModel()->getVarRefs().at(transposeVarIdx).name; // To allow these kernels to be batched, we turn 2D grid into wide 1D grid of 2D block so calculate size - os << "const unsigned int numXBlocks = (group->numTrgNeurons + " << (blockSize - 1) << ") / " << blockSize << ";" << std::endl; + env.getStream() << "const unsigned int numXBlocks = (group->numTrgNeurons + " << (blockSize - 1) << ") / " << blockSize << ";" << std::endl; // Calculate what block this kernel starts at (because of kernel merging, it may not start at block 0) - os << "const unsigned int blockStart = " << popSubs["group_start_id"] << " / " << blockSize << ";" << std::endl; + env.getStream() << "const unsigned int blockStart = " << env.getName("group_start_id") << " / " << blockSize << ";" << std::endl; - Substitutions synSubs(&popSubs); + EnvironmentSubstitute synEnv(env); if(cg.getArchetype().isBatched()) { // If there's multiple batches we also need to know how many Y blocks and hence total blocks there are - os << "const unsigned int numYBlocks = (group->numSrcNeurons + " << (blockSize - 1) << ") / " << blockSize << ";" << std::endl; - os << "const unsigned int numBlocks = numXBlocks * numYBlocks;" << std::endl; + synEnv.getStream() << "const unsigned int numYBlocks = (group->numSrcNeurons + " << (blockSize - 1) << ") / " << blockSize << ";" << std::endl; + synEnv.getStream() << "const unsigned int numBlocks = numXBlocks * numYBlocks;" << std::endl; // Therefore determine block and batch - os << "const unsigned int batchBlock = " << getBlockID(0) << " - blockStart;" << std::endl; - os << "const unsigned int block = batchBlock % numBlocks;" << std::endl; - os << "const unsigned int batch = batchBlock / numBlocks;" << std::endl; + synEnv.getStream() << "const unsigned int batchBlock = " << getBlockID(0) << " - blockStart;" << std::endl; + synEnv.getStream() << "const unsigned int block = batchBlock % numBlocks;" << std::endl; + synEnv.getStream() << "const unsigned int batch = batchBlock / numBlocks;" << std::endl; // Finally, calculate batch offset into arrays etc - os << "const unsigned int batchOffset = batch * group->numSrcNeurons * group->numTrgNeurons;" << std::endl; + synEnv.getStream() << "const unsigned int batchOffset = batch * group->numSrcNeurons * group->numTrgNeurons;" << std::endl; // Add batch to substitutions - synSubs.addVarSubstitution("batch", "batch"); + synEnv.addSubstitution("batch", "batch"); } // Otherwise, just substitute "batch" for 0 else { - os << "const unsigned int block = " << getBlockID(0) << " - blockStart;" << std::endl; - synSubs.addVarSubstitution("batch", "0"); + synEnv.getStream() << "const unsigned int block = " << getBlockID(0) << " - blockStart;" << std::endl; + synEnv.addSubstitution("batch", "0"); } // Divide block index into x and y // **TODO** fast-divide style optimisations here - os << "const unsigned int blockX = (block % numXBlocks);" << std::endl; - os << "const unsigned int blockY = (block / numXBlocks);" << std::endl; + synEnv.getStream() << "const unsigned int blockX = (block % numXBlocks);" << std::endl; + synEnv.getStream() << "const unsigned int blockY = (block / numXBlocks);" << std::endl; { - CodeStream::Scope b(os); - os << "// Calculate coordinate of thread in input matrix" << std::endl; - os << "const unsigned int x = (blockX * " << blockSize << ") + " << getThreadID(0) << ";" << std::endl; - os << "const unsigned int y = (blockY * " << blockSize << ") + " << getThreadID(1) << ";" << std::endl; + CodeStream::Scope 
b(synEnv.getStream()); + synEnv.getStream() << "// Calculate coordinate of thread in input matrix" << std::endl; + synEnv.getStream() << "const unsigned int x = (blockX * " << blockSize << ") + " << getThreadID(0) << ";" << std::endl; + synEnv.getStream() << "const unsigned int y = (blockY * " << blockSize << ") + " << getThreadID(1) << ";" << std::endl; - os << "// If thread isn't off the 'right' edge of the input matrix" << std::endl; - os << "if(x < group->numTrgNeurons)"; + synEnv.getStream() << "// If thread isn't off the 'right' edge of the input matrix" << std::endl; + synEnv.getStream() << "if(x < group->numTrgNeurons)"; { - CodeStream::Scope b(os); - os << "// Loop through input rows " << std::endl; - os << "for (unsigned int j = 0; j < " << blockSize << "; j += 8)"; + CodeStream::Scope b(synEnv.getStream()); + synEnv.getStream() << "// Loop through input rows " << std::endl; + synEnv.getStream() << "for (unsigned int j = 0; j < " << blockSize << "; j += 8)"; { - CodeStream::Scope b(os); - os << "// If thread isn't off the 'bottom' edge of the input matrix" << std::endl; - os << "if((y + j) < group->numSrcNeurons)"; + CodeStream::Scope b(synEnv.getStream()); + synEnv.getStream() << "// If thread isn't off the 'bottom' edge of the input matrix" << std::endl; + synEnv.getStream() << "if((y + j) < group->numSrcNeurons)"; { - CodeStream::Scope b(os); - os << "// Read forward weight from global memory" << std::endl; - os << "const unsigned int idx = ((y + j) * group->numTrgNeurons) + x;" << std::endl; + CodeStream::Scope b(synEnv.getStream()); + synEnv.getStream() << "// Read forward weight from global memory" << std::endl; + synEnv.getStream() << "const unsigned int idx = ((y + j) * group->numTrgNeurons) + x;" << std::endl; - synSubs.addVarSubstitution("id_pre", "y"); - synSubs.addVarSubstitution("id_post", "x"); - synSubs.addVarSubstitution("id_syn", "idx"); - cg.generateCustomUpdate(*this, os, synSubs); + synEnv.addSubstitution("id_pre", "y"); + synEnv.addSubstitution("id_post", "x"); + synEnv.addSubstitution("id_syn", "idx"); + cg.generateCustomUpdate(*this, synEnv); // Write forward weight to shared memory - os << "shTile[" << getThreadID(1) << " + j][" << getThreadID(0) << "] = l" << transposeVarName << ";" << std::endl; + synEnv.getStream() << "shTile[" << getThreadID(1) << " + j][" << getThreadID(0) << "] = l" << transposeVarName << ";" << std::endl; } } } } - genSharedMemBarrier(os); + genSharedMemBarrier(env.getStream()); { - CodeStream::Scope b(os); - os << "// Calculate (transposed) coordinate of thread in output matrix" << std::endl; - os << "const unsigned int x = (blockY * " << blockSize << ") + " << getThreadID(0) << ";" << std::endl; - os << "const unsigned int y = (blockX * " << blockSize << ") + " << getThreadID(1) << ";" << std::endl; + CodeStream::Scope b(synEnv.getStream()); + synEnv.getStream() << "// Calculate (transposed) coordinate of thread in output matrix" << std::endl; + synEnv.getStream() << "const unsigned int x = (blockY * " << blockSize << ") + " << getThreadID(0) << ";" << std::endl; + synEnv.getStream() << "const unsigned int y = (blockX * " << blockSize << ") + " << getThreadID(1) << ";" << std::endl; - os << "// If thread isn't off the 'right' edge of the output matrix" << std::endl; - os << "if(x < group->numSrcNeurons)"; + synEnv.getStream() << "// If thread isn't off the 'right' edge of the output matrix" << std::endl; + synEnv.getStream() << "if(x < group->numSrcNeurons)"; { - CodeStream::Scope b(os); - os << "// Loop through output rows" <<
std::endl; - os << "for(unsigned int j = 0; j < " << blockSize << "; j += 8)"; + CodeStream::Scope b(synEnv.getStream()); + synEnv.getStream() << "// Loop through output rows" << std::endl; + synEnv.getStream() << "for(unsigned int j = 0; j < " << blockSize << "; j += 8)"; { - CodeStream::Scope b(os); - os << "// If thread isn't off the 'bottom' edge of the output matrix" << std::endl; - os << "if((y + j) < group->numTrgNeurons)"; + CodeStream::Scope b(synEnv.getStream()); + synEnv.getStream() << "// If thread isn't off the 'bottom' edge of the output matrix" << std::endl; + synEnv.getStream() << "if((y + j) < group->numTrgNeurons)"; { - CodeStream::Scope b(os); - os << "group->" << transposeVarName << "Transpose["; + CodeStream::Scope b(synEnv.getStream()); + synEnv.getStream() << "group->" << transposeVarName << "Transpose["; if(cg.getArchetype().isBatched()) { - os << "batchOffset + "; + synEnv.getStream() << "batchOffset + "; } - os << "((y + j) * group->numSrcNeurons) + x] = shTile[" << getThreadID(0) << "][" << getThreadID(1) << " + j];" << std::endl; + synEnv.getStream() << "((y + j) * group->numSrcNeurons) + x] = shTile[" << getThreadID(0) << "][" << getThreadID(1) << " + j];" << std::endl; } } } diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index adc32d3677..74d7783abf 100644 --- a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -83,4 +83,5 @@ void EnvironmentSubstitute::addSubstitution(const std::string &source, const std size_t EnvironmentSubstitute::addInitialiser(const std::string &initialiser) { m_Initialisers.emplace_back(false, initialiser); + return (m_Initialisers.size() - 1); } \ No newline at end of file From 247f92f156794024a5398863816c11af90534139 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 31 Mar 2023 13:29:08 +0100 Subject: [PATCH 118/725] Not very nice fixes for Environment constructor-overloading issues - basically copy constructor was being used instead of actual constructors * Deleted copy constructor so, at least, compiler errors will be emitted rather than broken behaviour * Added static casts so correct ``EnvironmentExternal`` constructor called --- .../genn/genn/code_generator/environment.h | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index f01fb51a3f..93b45b5e34 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -19,6 +19,7 @@ namespace GeNN::CodeGenerator { class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase { +protected: using EnvironmentBase = Transpiler::PrettyPrinter::EnvironmentBase; public: EnvironmentExternal(EnvironmentBase &enclosing) @@ -30,6 +31,8 @@ class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase : m_Context(os) { } + + EnvironmentExternal(const EnvironmentExternal&) = delete; //------------------------------------------------------------------------ // PrettyPrinter::EnvironmentBase virtuals @@ -59,15 +62,14 @@ class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase //! 
Standard pretty printing environment simply allowing substitutions to be implemented class EnvironmentSubstitute : public EnvironmentExternal { - using EnvironmentBase = Transpiler::PrettyPrinter::EnvironmentBase; public: EnvironmentSubstitute(EnvironmentSubstitute &enclosing) - : EnvironmentExternal(enclosing), m_Contents(m_ContentsStream) + : EnvironmentExternal(static_cast<EnvironmentBase&>(enclosing)), m_Contents(m_ContentsStream) { } EnvironmentSubstitute(EnvironmentExternal &enclosing) - : EnvironmentExternal(enclosing), m_Contents(m_ContentsStream) + : EnvironmentExternal(static_cast<EnvironmentBase&>(enclosing)), m_Contents(m_ContentsStream) { } @@ -75,6 +77,9 @@ class EnvironmentSubstitute : public EnvironmentExternal : EnvironmentExternal(os), m_Contents(m_ContentsStream) { } + + EnvironmentSubstitute(const EnvironmentSubstitute&) = delete; + ~EnvironmentSubstitute(); //------------------------------------------------------------------------ @@ -159,23 +164,25 @@ template class EnvironmentLocalVarCache : public EnvironmentExternal { //! Type of a single definition - typedef typename std::invoke_result_t::value_type DefType; + using DefType = typename std::invoke_result_t::value_type; //! Type of a single initialiser - typedef typename std::remove_reference_t>::mapped_type InitialiserType; + using InitialiserType = typename std::remove_reference_t>::mapped_type; //! Function used to provide index strings based on initialiser and access type - typedef std::function GetIndexFn; + using GetIndexFn = std::function; public: EnvironmentLocalVarCache(const G &group, EnvironmentExternal &enclosing, GetIndexFn getIndex, const std::string &localPrefix = "l") - : EnvironmentExternal(enclosing), m_Group(group), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) + : EnvironmentExternal(static_cast<EnvironmentBase&>(enclosing)), m_Group(group), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) { // Add name of each definition to map, initially with value set to value const auto defs = A(m_Group).getDefs(); std::transform(defs.cbegin(), defs.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()), [](const auto &v){ return std::make_pair(v.name, false); }); } + + EnvironmentLocalVarCache(const EnvironmentLocalVarCache&) = delete; ~EnvironmentLocalVarCache() { From d4db86ea2336d55e5016623e8916b7dde97f2202 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 31 Mar 2023 16:20:47 +0100 Subject: [PATCH 119/725] ``GeNN::Transpiler::TypeChecker::EnvironmentBase::getTypes`` now returns vector to support overloaded types --- .../groupMergedTypeEnvironment.h | 9 +-- include/genn/genn/transpiler/typeChecker.h | 31 +++++++- src/genn/genn/transpiler/typeChecker.cc | 78 ++++++++++++++++++- 3 files changed, 108 insertions(+), 10 deletions(-) diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index a35155decb..735490b02a 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -84,12 +84,12 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa return EnvironmentBase::incDec(name, op, existingType->second.first, errorHandler); } - virtual const Type::Base *getType(const Token &name, ErrorHandlerBase &errorHandler) final + virtual std::vector<const Type::Base*> getTypes(const Token &name, ErrorHandlerBase &errorHandler) final { auto type = m_Types.find(name.lexeme); if(type == m_Types.end()) {
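            // The name is not defined at this level, so the lookup falls through to the
            // enclosing type environment when one was supplied (for example the standard
            // library environment hooked up a couple of commits later), mirroring the
            // chained, lexically-scoped lookup used by the pretty-printer environments.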
if(m_Enclosing) { - return m_Enclosing->getType(name, errorHandler); + return m_Enclosing->getTypes(name, errorHandler); } else { errorHandler.error(name, "Undefined variable"); @@ -100,7 +100,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa // Add field to merged group if required addField(type->second); - return type->second.first; + return {type->second.first}; } } @@ -109,8 +109,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa //--------------------------------------------------------------------------- void defineField(const Type::Base *type, const std::string &name) { - if(!m_Types.try_emplace(name, type, std::nullopt).second) - { + if(!m_Types.try_emplace(name, type, std::nullopt).second) { throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); } } diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index 0951d7cf06..96d40d5eed 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -4,6 +4,7 @@ #include #include #include +#include // GeNN includes #include "type.h" @@ -48,7 +49,12 @@ class EnvironmentBase bool initializer = false) = 0; virtual const Type::Base *incDec(const Token &name, Token::Type op, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) = 0; - virtual const Type::Base *getType(const Token &name, ErrorHandlerBase &errorHandler) = 0; + virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) = 0; + + //--------------------------------------------------------------------------- + // Public API + //--------------------------------------------------------------------------- + const Type::Base *getType(const Token &name, ErrorHandlerBase &errorHandler); protected: //--------------------------------------------------------------------------- @@ -62,6 +68,29 @@ class EnvironmentBase const Type::Base *existingType, ErrorHandlerBase &errorHandler) const; }; +//--------------------------------------------------------------------------- +// GeNN::Transpiler::TypeChecker::StandardLibraryFunctionEnvironment +//--------------------------------------------------------------------------- +class StandardLibraryFunctionEnvironment : public EnvironmentBase +{ +public: + StandardLibraryFunctionEnvironment(); + + //------------------------------------------------------------------------ + // EnvironmentBase virtuals + //------------------------------------------------------------------------ + virtual void define(const Token &name, const Type::Base *type, ErrorHandlerBase &errorHandler) final; + virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType, + const Type::TypeContext &context, ErrorHandlerBase &errorHandler, + bool initializer = false) final; + virtual const Type::Base *incDec(const Token &name, Token::Type op, + const Type::TypeContext &context, ErrorHandlerBase &errorHandler) final; + virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final; + +private: + std::unordered_multimap m_Types; +}; + //--------------------------------------------------------------------------- // Free functions //--------------------------------------------------------------------------- diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 7e6606a833..b5ebed7b9d 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ 
b/src/genn/genn/transpiler/typeChecker.cc @@ -2,6 +2,7 @@ // Standard C++ includes #include +#include #include // Standard C includes @@ -116,15 +117,15 @@ class EnvironmentInternal : public EnvironmentBase // Perform standard type-checking logic return EnvironmentBase::incDec(name, op, existingType->second, errorHandler); } - - virtual const Type::Base *getType(const Token &name, ErrorHandlerBase &errorHandler) final + + virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final { auto type = m_Types.find(name.lexeme); if(type == m_Types.end()) { - return m_Enclosing.getType(name, errorHandler); + return m_Enclosing.getTypes(name, errorHandler); } else { - return type->second; + return {type->second}; } } @@ -674,6 +675,17 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker::EnvironmentBase //--------------------------------------------------------------------------- +const Type::Base *EnvironmentBase::getType(const Token &name, ErrorHandlerBase &errorHandler) +{ + const auto types = getTypes(name, errorHandler); + if (types.size() == 1) { + return types.front(); + } + else { + errorHandler.error(name, "Unambiguous type expected"); + } +} +//--------------------------------------------------------------------------- const Type::Base *EnvironmentBase::assign(const Token &name, Token::Type op, const Type::Base *existingType, const Type::Base *assignedType, const Type::TypeContext &context, ErrorHandlerBase &errorHandler, @@ -770,6 +782,64 @@ const Type::Base *EnvironmentBase::incDec(const Token &name, Token::Type, } } +//--------------------------------------------------------------------------- +// GeNN::Transpiler::TypeChecker::StandardLibraryFunctionEnvironment +//--------------------------------------------------------------------------- +#define ADD_FLOAT_DOUBLE(NAME, CLASS_PREFIX) {#NAME, Type::CLASS_PREFIX##F::getInstance()}, {#NAME, Type::CLASS_PREFIX##D::getInstance()} +StandardLibraryFunctionEnvironment::StandardLibraryFunctionEnvironment() + : m_Types{ADD_FLOAT_DOUBLE(cos, Cos), ADD_FLOAT_DOUBLE(sin, Sin), ADD_FLOAT_DOUBLE(tan, Tan), + ADD_FLOAT_DOUBLE(acos, Acos), ADD_FLOAT_DOUBLE(asin, Asin), ADD_FLOAT_DOUBLE(atan, Atan), ADD_FLOAT_DOUBLE(atan2, Atan2), + ADD_FLOAT_DOUBLE(cosh, Cosh), ADD_FLOAT_DOUBLE(sinh, Sinh), ADD_FLOAT_DOUBLE(tanh, Tanh), + ADD_FLOAT_DOUBLE(exp, Exp), ADD_FLOAT_DOUBLE(expm1, ExpM1), ADD_FLOAT_DOUBLE(exp2, Exp2), ADD_FLOAT_DOUBLE(pow, Pow), + ADD_FLOAT_DOUBLE(scalbn, ScalBN), ADD_FLOAT_DOUBLE(log, Log), ADD_FLOAT_DOUBLE(log1p, Log1P), ADD_FLOAT_DOUBLE(log2, Log2), + ADD_FLOAT_DOUBLE(log10, Log10), ADD_FLOAT_DOUBLE(ldexp, LdExp), ADD_FLOAT_DOUBLE(ilogb, ILogB), + ADD_FLOAT_DOUBLE(sqrt, Sqrt), ADD_FLOAT_DOUBLE(cbrt, Cbrt), ADD_FLOAT_DOUBLE(hypot, Hypot), + ADD_FLOAT_DOUBLE(ceil, Ceil), ADD_FLOAT_DOUBLE(floor, Floor), ADD_FLOAT_DOUBLE(fmod, Fmod), + ADD_FLOAT_DOUBLE(round, Round), ADD_FLOAT_DOUBLE(rint, Rint), ADD_FLOAT_DOUBLE(trunc, Trunc), + ADD_FLOAT_DOUBLE(nearbyint, NearbyInt), ADD_FLOAT_DOUBLE(nextafter, NextAfter),ADD_FLOAT_DOUBLE(remainder, Remainder), + ADD_FLOAT_DOUBLE(fabs, FAbs), ADD_FLOAT_DOUBLE(fdim, FDim), ADD_FLOAT_DOUBLE(fmax, FMax), ADD_FLOAT_DOUBLE(fmin, FMin), + ADD_FLOAT_DOUBLE(erf, Erf), ADD_FLOAT_DOUBLE(erfc, ErfC), ADD_FLOAT_DOUBLE(tgamma, TGamma), ADD_FLOAT_DOUBLE(lgamma, LGamma), + ADD_FLOAT_DOUBLE(copysign, CopySign), ADD_FLOAT_DOUBLE(fma, FMA)} +{ +} +#undef ADD_FLOAT_DOUBLE 
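+// Sketch of how the multimap above supports overloading: each name is inserted
+// once per signature (a float and a double variant here), so a lookup such as
+//     auto [typeBegin, typeEnd] = m_Types.equal_range("sin");
+// sees both sin overloads, and getTypes() below returns all of them for the
+// type checker to resolve against the call's argument types.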
+//------------------------------------------------------------------------ +void StandardLibraryFunctionEnvironment::define(const Token &name, const Type::Base*, ErrorHandlerBase &errorHandler) +{ + errorHandler.error(name, "Cannot declare variable in external environment"); + throw TypeCheckError(); +} +//--------------------------------------------------------------------------- +const Type::Base *StandardLibraryFunctionEnvironment::assign(const Token &name, Token::Type, const Type::Base*, + const Type::TypeContext&, ErrorHandlerBase &errorHandler, bool) +{ + errorHandler.error(name, "Cannot assign variable in external environment"); + throw TypeCheckError(); +} +//--------------------------------------------------------------------------- +const Type::Base *StandardLibraryFunctionEnvironment::incDec(const Token &name, Token::Type, const Type::TypeContext&, + ErrorHandlerBase &errorHandler) +{ + errorHandler.error(name, "Cannot increment/decrement variable in external environment"); + throw TypeCheckError(); +} +//--------------------------------------------------------------------------- +std::vector StandardLibraryFunctionEnvironment::getTypes(const Token &name, ErrorHandlerBase &errorHandler) +{ + auto [typeBegin, typeEnd] = m_Types.equal_range(name.lexeme); + if (typeBegin == typeEnd) { + errorHandler.error(name, "Undefined variable"); + throw TypeCheckError(); + } + else { + std::vector types; + types.reserve(std::distance(typeBegin, typeEnd)); + std::transform(typeBegin, typeEnd, std::back_inserter(types), + [](auto t) { return t.second; }); + return types; + } +} + //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker //--------------------------------------------------------------------------- From fcb41962bf1beefa6767c18785c986dda78d1850 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 31 Mar 2023 16:28:30 +0100 Subject: [PATCH 120/725] Hooked up standard library type environment --- .../genn/code_generator/customUpdateGroupMerged.cc | 6 ++++-- src/genn/genn/transpiler/typeChecker.cc | 10 +++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 936a1cad52..0d274416fd 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -34,7 +34,8 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const Type::TypeC using namespace Type; // Create type environment - GroupMergedTypeEnvironment typeEnvironment(*this); + TypeChecker::StandardLibraryFunctionEnvironment stdLibraryEnv; + GroupMergedTypeEnvironment typeEnvironment(*this, &stdLibraryEnv); addField("size", [](const auto &c, size_t) { return std::to_string(c.getSize()); }); @@ -266,7 +267,8 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const using namespace Type; // Create type environment - GroupMergedTypeEnvironment typeEnvironment(*this); + TypeChecker::StandardLibraryFunctionEnvironment stdLibraryEnv; + GroupMergedTypeEnvironment typeEnvironment(*this, &stdLibraryEnv); // If underlying synapse group has kernel weights if (getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL) { diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index b5ebed7b9d..fd83c35f88 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ 
b/src/genn/genn/transpiler/typeChecker.cc @@ -428,7 +428,14 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Variable &variable) { - setExpressionType(&variable, m_Environment.get().getType(variable.getName(), m_ErrorHandler)); + const auto varTypes = m_Environment.get().getTypes(variable.getName(), m_ErrorHandler); + if (varTypes.size() == 1) { + setExpressionType(&variable, varTypes.front()); + } + else { + // **TODO** handle overload resolution + assert(false); + } } virtual void visit(const Expression::Unary &unary) final @@ -683,6 +690,7 @@ const Type::Base *EnvironmentBase::getType(const Token &name, ErrorHandlerBase & } else { errorHandler.error(name, "Unambiguous type expected"); + throw TypeCheckError(); } } //--------------------------------------------------------------------------- From 5cbac3591a5a069cdbbdbbc46846824db0f1e2aa Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 31 Mar 2023 17:09:43 +0100 Subject: [PATCH 121/725] started implementing function overload resolution --- src/genn/genn/transpiler/typeChecker.cc | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index fd83c35f88..d59a279101 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -288,10 +288,19 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Call &call) final { + // **TODO** think about nested calls + assert(m_CallArguments.empty()); + + // Evaluate argument types and store in class + m_CallArguments.clear(); + std::transform(call.getArguments().cbegin(), call.getArguments().cend(), std::back_inserter(m_CallArguments), + [this](const auto &a){ return evaluateType(a.get()); }); + // Evaluate callee type auto calleeType = evaluateType(call.getCallee()); auto calleeFunctionType = dynamic_cast(calleeType); + m_CallArguments.clear(); // If callee's a function if (calleeFunctionType) { // If argument count doesn't match @@ -428,11 +437,23 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Variable &variable) { + // If type of variable is unambiguous, const auto varTypes = m_Environment.get().getTypes(variable.getName(), m_ErrorHandler); if (varTypes.size() == 1) { setExpressionType(&variable, varTypes.front()); } + // Otherwise else { + // If there are no call arguments to disambiguate, give error + if (m_CallArguments.empty()) { + m_ErrorHandler.error(variable.getName(), + "Ambiguous identifier '" + variable.getName().lexeme + "'"); + throw TypeCheckError(); + } + else { + // 1) Viable - same number of arguments + // 2) Overload resolution + } // **TODO** handle overload resolution assert(false); } } @@ -674,6 +695,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const Type::TypeContext &m_Context; ErrorHandlerBase &m_ErrorHandler; ResolvedTypeMap &m_ResolvedTypes; + std::vector<const Type::Base*> m_CallArguments; bool m_InLoop; bool m_InSwitch; }; From 33973608a3c3b4efc0927f1d50e7f92cb347258b Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 17 Apr 2023 14:34:22 +0100 Subject: [PATCH 122/725] fixed some GCC compile errors --- .../genn/genn/code_generator/backendSIMT.h | 1 + .../genn/genn/code_generator/environment.h | 20 +++++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/include/genn/genn/code_generator/backendSIMT.h
From 33973608a3c3b4efc0927f1d50e7f92cb347258b Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 17 Apr 2023 14:34:22 +0100
Subject: [PATCH 122/725] fixed some GCC compile errors

---
 include/genn/genn/code_generator/backendSIMT.h |  1 +
 include/genn/genn/code_generator/environment.h | 20 +++++++++++--------
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h
index 68669dd8ec..f742e259cc 100644
--- a/include/genn/genn/code_generator/backendSIMT.h
+++ b/include/genn/genn/code_generator/backendSIMT.h
@@ -12,6 +12,7 @@
 // GeNN code generator includes
 #include "code_generator/backendBase.h"
 #include "code_generator/codeStream.h"
+#include "code_generator/environment.h"
 #include "code_generator/presynapticUpdateStrategySIMT.h"
 #include "code_generator/substitutions.h"

diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h
index 93b45b5e34..cab6882808 100644
--- a/include/genn/genn/code_generator/environment.h
+++ b/include/genn/genn/code_generator/environment.h
@@ -5,6 +5,10 @@
 #include
 #include

+// GeNN includes
+#include "gennUtils.h"
+#include "varAccess.h"
+
 // GeNN code generator includes
 #include "code_generator/codeStream.h"

@@ -26,14 +30,14 @@ class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase
     :   m_Context(enclosing)
     {
     }
-    
+
     EnvironmentExternal(CodeStream &os)
     :   m_Context(os)
     {
     }

     EnvironmentExternal(const EnvironmentExternal&) = delete;
-    
+
     //------------------------------------------------------------------------
     // PrettyPrinter::EnvironmentBase virtuals
     //------------------------------------------------------------------------
@@ -44,9 +48,9 @@ class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase
     // Protected API
     //------------------------------------------------------------------------
     auto &getContext() const{ return m_Context; }
-    
+
     CodeStream &getContextStream() const;
-    
+
     std::string getContextName(const std::string &name) const;

 private:
@@ -81,23 +85,23 @@ class EnvironmentSubstitute : public EnvironmentExternal
     EnvironmentSubstitute(const EnvironmentSubstitute&) = delete;

     ~EnvironmentSubstitute();
-    
+
     //------------------------------------------------------------------------
     // PrettyPrinter::EnvironmentBase virtuals
     //------------------------------------------------------------------------
     virtual std::string getName(const std::string &name) final;
-    
+
     virtual CodeStream &getStream() final
     {
         return m_Contents;
     }
-    
+
     //------------------------------------------------------------------------
     // Public API
     //------------------------------------------------------------------------
     void addSubstitution(const std::string &source, const std::string &destination,
                          std::vector initialisers = {});
-    
+
     size_t addInitialiser(const std::string &initialiser);

     template

From fea14c3733c68867d1737a961ce14c7ba25e85aa Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 17 Apr 2023 14:34:41 +0100
Subject: [PATCH 123/725] started adding some type checker unit test for Call

---
 tests/unit/typeChecker.cc | 58 +++++++++++++++++++++++++++------------
 1 file changed, 41 insertions(+), 17 deletions(-)

diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index 4f8901d541..e5cf5ab4ec 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -63,20 +63,20 @@ class TestEnvironment : public TypeChecker::EnvironmentBase
             throw std::runtime_error("Redeclaration of '" + std::string{name} + "'");
         }
     }
-    
+
     template
     void define(const std::string &name, Type::Qualifier qualifiers = Type::Qualifier{0})
     {
         define(name, T::getInstance()->getQualifiedType(qualifiers));
     }
-    
+
     template
     void definePointer(const std::string &name, Type::Qualifier valueQualifiers = Type::Qualifier{0},
                        Type::Qualifier pointerQualifiers = Type::Qualifier{0})
     {
         define(name, T::getInstance()->getQualifiedType(valueQualifiers)->getPointerType(pointerQualifiers));
     }
-    
+
     //---------------------------------------------------------------------------
     // EnvironmentBase virtuals
@@ -86,7 +86,7 @@ class TestEnvironment : public TypeChecker::EnvironmentBase
         errorHandler.error(name, "Cannot declare variable in external environment");
         throw TypeChecker::TypeCheckError();
     }
-    
+
     virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType,
                                      const Type::TypeContext &context, ErrorHandlerBase &errorHandler,
                                      bool initializer = false) final
@@ -97,11 +97,11 @@ class TestEnvironment : public TypeChecker::EnvironmentBase
             errorHandler.error(name, "Undefined variable");
             throw TypeChecker::TypeCheckError();
         }
-        
+
         // Perform standard type-checking logic
         return EnvironmentBase::assign(name, op, existingType->second, assignedType, context, errorHandler, initializer);
     }
-    
+
     virtual const Type::Base *incDec(const Token &name, Token::Type op,
                                      const Type::TypeContext&, ErrorHandlerBase &errorHandler) final
     {
@@ -110,12 +110,12 @@ class TestEnvironment : public TypeChecker::EnvironmentBase
             errorHandler.error(name, "Undefined variable");
             throw TypeChecker::TypeCheckError();
         }
-        
+
         // Perform standard type-checking logic
         return EnvironmentBase::incDec(name, op, existingType->second, errorHandler);
     }
-    
-    virtual const Type::Base *getType(const Token &name, ErrorHandlerBase &errorHandler) final
+
+    virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final
     {
         auto type = m_Types.find(std::string{name.lexeme});
         if(type == m_Types.end()) {
@@ -123,10 +123,10 @@ class TestEnvironment : public TypeChecker::EnvironmentBase
             throw TypeChecker::TypeCheckError();
         }
         else {
-            return type->second;
+            return {type->second};
         }
     }
-    
+
 private:
     //---------------------------------------------------------------------------
     // Members
@@ -150,7 +150,7 @@ void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment
     // Parse
     const auto statements = Parser::parseBlockItemList(tokens, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());
-    
+
     // Typecheck
     TypeChecker::typeCheck(statements, typeEnvironment, typeContext, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());
 }
@@ -166,7 +166,7 @@ const Type::Base *typeCheckExpression(std::string_view code, TestEnvironment &ty
     // Parse
     const auto expression = Parser::parseExpression(tokens, errorHandler);
     EXPECT_FALSE(errorHandler.hasError());
-    
+
     // Typecheck
     const auto *type = TypeChecker::typeCheck(expression.get(), typeEnvironment, typeContext, errorHandler);
     EXPECT_FALSE(errorHandler.hasError());
@@ -189,7 +189,7 @@ TEST(TypeChecker, ArraySubscript)
     }

     // Pointer to pointer, double indexing
-    
+
     // Float array indexing
     EXPECT_THROW({
         TestEnvironment typeEnvironment;
@@ -280,7 +280,7 @@ TEST(TypeChecker, Binary)
     }

     // **TODO** constness and
-    
+
     // Pointer + non-integer
     EXPECT_THROW({
         TestEnvironment typeEnvironment;
@@ -328,6 +328,30 @@ TEST(TypeChecker, Binary)
 //--------------------------------------------------------------------------
 TEST(TypeChecker, Call)
 {
+    // Too few arguments
+    TypeChecker::StandardLibraryFunctionEnvironment stdLibraryEnv;
+    EXPECT_THROW({
+        typeCheckExpression("sin()", stdLibraryEnv);},
+        TypeChecker::TypeCheckError);
+
+    // Too many arguments
+    EXPECT_THROW({
+        typeCheckExpression("sin(1.0f, 2.0f)", stdLibraryEnv);},
+        TypeChecker::TypeCheckError);
+
+    // Floating point trascendental function
+    {
+        const auto *type = typeCheckExpression("sin(1.0f)", stdLibraryEnv);
+        EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName());
+        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT));
+    }
+
+    // Double trascendental function
+    {
+        const auto *type = typeCheckExpression("sin(1.0d)", stdLibraryEnv);
+        EXPECT_EQ(type->getName(), Type::Double::getInstance()->getName());
+        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT));
+    }
 }
 //--------------------------------------------------------------------------
 TEST(TypeChecker, Cast)
 {
@@ -356,7 +380,7 @@ TEST(TypeChecker, Cast)
         typeEnvironment.definePointer("intArray");
         const auto *type = typeCheckExpression("(const int*)intArray", typeEnvironment);
         EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT));
-        
+
         const auto *pointerType = dynamic_cast(type);
         EXPECT_TRUE(pointerType);
         EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName());
@@ -369,7 +393,7 @@ TEST(TypeChecker, Cast)
         typeEnvironment.definePointer("intArray");
         const auto *type = typeCheckExpression("(int * const)intArray", typeEnvironment);
         EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONSTANT));
-        
+
         const auto *pointerType = dynamic_cast(type);
         EXPECT_TRUE(pointerType);
         EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName());

From c7c5003decbe72356776752e9b26f33b95d5eb6e Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 17 Apr 2023 14:54:43 +0100
Subject: [PATCH 124/725] more typos

---
 tests/unit/typeChecker.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index e5cf5ab4ec..509c026d53 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -140,7 +140,7 @@ std::string getPointerTypeName()
     return T::getInstance()->getPointerType()->getName();
 }

-void typeCheckStatements(std::string_view code, TestEnvironment &typeEnvironment, const Type::TypeContext &typeContext = {})
+void typeCheckStatements(std::string_view code, TypeChecker::EnvironmentBase &typeEnvironment, const Type::TypeContext &typeContext = {})
 {
     // Scan
     TestErrorHandler errorHandler;
@@ -156,7 +156,7 @@
-const Type::Base *typeCheckExpression(std::string_view code, TestEnvironment &typeEnvironment, const Type::TypeContext &typeContext = {})
+const Type::Base *typeCheckExpression(std::string_view code, TypeChecker::EnvironmentBase &typeEnvironment, const Type::TypeContext &typeContext = {})
 {
     // Scan
     TestErrorHandler errorHandler;

From 5f3ae2ed0d484899a05a201b730e9c5fe88f0d0a Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 17 Apr 2023 16:14:03 +0100
Subject: [PATCH 125/725] whitespace tidy

---
 include/genn/genn/code_generator/environment.h | 22 +++++++++----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h
index cab6882808..bc79df2a85 100644
--- a/include/genn/genn/code_generator/environment.h
+++ b/include/genn/genn/code_generator/environment.h
@@ -169,10 +169,10 @@ class EnvironmentLocalVarCache : public EnvironmentExternal
 {
     //! Type of a single definition
     using DefType = typename std::invoke_result_t::value_type;
-    
+
    //! Type of a single initialiser
     using InitialiserType = typename std::remove_reference_t>::mapped_type;
-    
+
     //! Function used to provide index strings based on initialiser and access type
     using GetIndexFn = std::function;

@@ -187,17 +187,17 @@ class EnvironmentLocalVarCache : public EnvironmentExternal
     }

     EnvironmentLocalVarCache(const EnvironmentLocalVarCache&) = delete;
-    
+
     ~EnvironmentLocalVarCache()
     {
         A adapter(m_Group);
-        
+
         // Copy definitions which have been referenced into new vector
         const auto defs = adapter.getDefs();
         std::remove_const_t referencedVars;
         std::copy_if(defs.cbegin(), defs.cend(), std::back_inserter(referencedVars),
                      [this](const auto &v){ return m_VariablesReferenced.at(v.name); });
-        
+
         // Loop through referenced variables
         const auto &initialisers = adapter.getInitialisers();
         for(const auto &v : referencedVars) {
@@ -205,7 +205,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal
                 getContextStream() << "const ";
             }
             getContextStream() << v.type->getName() << " " << m_LocalPrefix << v.name;
-            
+
             // If this isn't a reduction, read value from memory
             // **NOTE** by not initialising these variables for reductions,
             // compilers SHOULD emit a warning if user code doesn't set it to something
@@ -214,10 +214,10 @@ class EnvironmentLocalVarCache : public EnvironmentExternal
             }
             getContextStream() << ";" << std::endl;
         }
-        
+
         // Write contents to context stream
         getContextStream() << m_ContentsStream.str();
-        
+
         // Loop through referenced variables again
         for(const auto &v : referencedVars) {
             // If variables are read-write
@@ -242,17 +242,17 @@ class EnvironmentLocalVarCache : public EnvironmentExternal
         else {
             // Set flag to indicate that variable has been referenced
             var->second = true;
-            
+
             // Add local prefix to variable name
             return m_LocalPrefix + name;
         }
     }
-    
+
     virtual CodeStream &getStream() final
     {
         return m_Contents;
     }
-    
+
 private:
     //------------------------------------------------------------------------
     // Members

From 46178b98e1cce7a7f9244231d9d32fee00341fec Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 17 Apr 2023 16:14:28 +0100
Subject: [PATCH 126/725] first pass at viable function overload identification
 logic

---
 src/genn/genn/transpiler/typeChecker.cc | 79 ++++++++++++++++++++++---
 1 file changed, 72 insertions(+), 7 deletions(-)

diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index d59a279101..1cdbcac7ca 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -4,7 +4,7 @@
 #include
 #include
 #include
-
+#include

 // Standard C includes
 #include
@@ -266,7 +266,6 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         // If operator is a shift, promote left type
         if (opType == Token::Type::SHIFT_LEFT || opType == Token::Type::SHIFT_RIGHT) {
-
             setExpressionType(&binary, Type::getPromotedType(leftNumericType, m_Context));
         }
         // Otherwise, take common type
@@ -444,16 +443,82 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         }
         // Otherwise
         else {
-            // If there are no call arguments to disambiguate, give error
-            if (m_CallArguments.empty()) {
+            // Loop through variable types
+            std::vector>> viableFunctions;
+            for(const auto *type : varTypes) {
+                // Cast to function (only functions should be overloaded)
+                const auto *func = dynamic_cast(type);
+                assert(func);
+
+                // If number of arguments match
+                const auto argumentTypes = func->getArgumentTypes();
+                if(argumentTypes.size() == m_CallArguments.size()) {
+                    // Create vector to hold argument conversion rank
+                    std::vector argumentConversionRank;
+                    argumentConversionRank.reserve(m_CallArguments.size());
+
+                    // Loop through arguments
+                    bool viable = true;
+                    auto c = m_CallArguments.cbegin();
+                    auto a = argumentTypes.cbegin();
+                    for(;c != m_CallArguments.cend(); c++, a++) {
+                        auto cNumericType = dynamic_cast(*c);
+                        auto aNumericType = dynamic_cast(*a);
+
+                        // If both are numeric
+                        if(cNumericType && aNumericType) {
+                            // If names are identical (we don't care about qualifiers), match is exact
+                            if(cNumericType->getName() == aNumericType->getName()) {
+                                argumentConversionRank.push_back(0);
+                            }
+                            // Integer promotion
+                            else if(aNumericType->getName() == Type::Int32::getInstance()->getName()
+                                    && cNumericType->isIntegral(m_Context)
+                                    && cNumericType->getRank(m_Context) < Type::Int32::getInstance()->getRank(m_Context))
+                            {
+                                argumentConversionRank.push_back(1);
+                            }
+                            // Float promotion
+                            else if(aNumericType->getName() == Type::Double::getInstance()->getName()
+                                    && cNumericType->getName() == Type::Float::getInstance()->getName())
+                            {
+                                argumentConversionRank.push_back(1);
+                            }
+                            // Otherwise, numeric conversion
+                            else {
+                                argumentConversionRank.push_back(2);
+                            }
+                        }
+                        // Otherwise, if they are matching pointers
+                        // **TODO** some more nuance here
+                        else if(checkPointerTypeAssignement(*c, *a, m_Context)) {
+                            argumentConversionRank.push_back(0);
+                        }
+                        // Otherwise, this function is not viable
+                        else {
+                            viable = false;
+                            break;
+                        }
+                    }
+
+                    // If function is viable, add to vector along with vector of conversion ranks
+                    if(viable) {
+                        assert(argumentConversionRank.size() == m_CallArguments.size());
+                        viableFunctions.emplace_back(func, argumentConversionRank);
+                    }
+                }
+            }
+
+            if(viableFunctions.empty()) {
                 m_ErrorHandler.error(variable.getName(),
-                                     "Ambiguous identifier '" + variable.getName().lexeme + "'");
+                                     "No viable function candidates for '" + variable.getName().lexeme + "'");
                 throw TypeCheckError();
             }
             else {
-                // 1) Viable - same number of arguments
-                // 2) Overload resolution
+                std::cout << viableFunctions.size() << " function candidates" << std::endl;;
             }
+
             // **TODO** handler overload resolution
             assert(false);
         }

From d7e20c9242c872baa90461e7dbabb87e44ed2d9c Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 17 Apr 2023 16:28:45 +0100
Subject: [PATCH 127/725] fixed parser bug

---
 src/genn/genn/transpiler/parser.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc
index 6002a13311..7be5ddda84 100644
--- a/src/genn/genn/transpiler/parser.cc
+++ b/src/genn/genn/transpiler/parser.cc
@@ -279,7 +279,7 @@ Expression::ExpressionPtr parsePostfix(ParserState &parserState)
     if(!parserState.check(Token::Type::RIGHT_PAREN)) {
         do {
             arguments.emplace_back(parseAssignment(parserState));
-        } while(parserState.check(Token::Type::COMMA));
+        } while(parserState.match(Token::Type::COMMA));
     }

     Token closingParen = parserState.consume(Token::Type::RIGHT_PAREN,
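The one-character fix above is easy to miss: `check` only peeks at the current token, while `match` also consumes it, so the old loop never advanced past a comma and kept re-parsing from the same position. A minimal sketch of the two primitives (hypothetical ParserState, not GeNN's actual class):

#include <cassert>
#include <cstddef>
#include <vector>

enum class Tok { COMMA, NUMBER, RPAREN };

struct ParserState
{
    std::vector<Tok> tokens;
    size_t pos = 0;

    // Peek: true if the current token has the given type, position unchanged
    bool check(Tok t) const { return pos < tokens.size() && tokens[pos] == t; }

    // Peek-and-consume: like check, but advances past the token on success
    bool match(Tok t)
    {
        if(!check(t)) { return false; }
        ++pos;
        return true;
    }
};

int main()
{
    ParserState p{{Tok::COMMA, Tok::NUMBER}};
    assert(p.check(Tok::COMMA) && p.pos == 0);  // check leaves the comma in place
    assert(p.match(Tok::COMMA) && p.pos == 1);  // match consumes it
}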
From 267312b0484f115f5b21b8f7d77470e9956e4373 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 17 Apr 2023 17:08:38 +0100
Subject: [PATCH 128/725] unit tests for nested function calls and scalar
 parameters

---
 tests/unit/typeChecker.cc | 30 ++++++++++++++++++++++------
 1 file changed, 24 insertions(+), 6 deletions(-)

diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index 509c026d53..3aaab1d510 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -306,7 +306,7 @@ TEST(TypeChecker, Binary)
         const auto *type = typeCheckExpression("intArray - offset", typeEnvironment);
         const auto *pointerType = dynamic_cast(type);
         EXPECT_TRUE(pointerType);
-        EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); 
+        EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName());
     }

     // Integer + pointer
@@ -339,18 +339,36 @@ TEST(TypeChecker, Call)
         typeCheckExpression("sin(1.0f, 2.0f)", stdLibraryEnv);},
         TypeChecker::TypeCheckError);

-    // Floating point trascendental function
+    // Floating point transcendental function
     {
         const auto *type = typeCheckExpression("sin(1.0f)", stdLibraryEnv);
         EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName());
-        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT));
     }

-    // Double trascendental function
+    // Double transcendental function
     {
         const auto *type = typeCheckExpression("sin(1.0d)", stdLibraryEnv);
         EXPECT_EQ(type->getName(), Type::Double::getInstance()->getName());
-        EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT));
+    }
+
+    // Float scalar transcendental function
+    {
+        const Type::TypeContext typeContext{{"scalar", Type::Float::getInstance()}};
+        const auto *type = typeCheckExpression("sin(1.0)", stdLibraryEnv, typeContext);
+        EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName());
+    }
+
+    // Double scalar transcendental function
+    {
+        const Type::TypeContext typeContext{{"scalar", Type::Double::getInstance()}};
+        const auto *type = typeCheckExpression("sin(1.0)", stdLibraryEnv, typeContext);
+        EXPECT_EQ(type->getName(), Type::Double::getInstance()->getName());
+    }
+
+    // Nested transcendental function
+    {
+        const auto *type = typeCheckExpression("sin(fmax(0.0f, 1.0f))", stdLibraryEnv);
+        EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName());
     }
 }
 //--------------------------------------------------------------------------
 TEST(TypeChecker, Cast)
@@ -508,7 +526,7 @@ TEST(TypeChecker, Literal)
     {
         TestEnvironment typeEnvironment;
         const Type::TypeContext typeContext{{"scalar", Type::Float::getInstance()}};
-        const auto *type = typeCheckExpression("1.0", typeEnvironment);
+        const auto *type = typeCheckExpression("1.0", typeEnvironment, typeContext);
         EXPECT_EQ(type->getResolvedName(typeContext), Type::Float::getInstance()->getName());
     }

From 73121316e7595d8ffdfb3908fb7651834b006c36 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 17 Apr 2023 17:11:05 +0100
Subject: [PATCH 129/725] Small tweaks

* Use resolved name when type checking function calls
* Remove function call argument checking in ``Expression::Call`` visitor and
  just ensure that functions always get checked for validity in
  ``Expression::Variable`` visitor

---
 src/genn/genn/transpiler/typeChecker.cc | 58 ++++++++++---------------
 1 file changed, 23 insertions(+), 35 deletions(-)

diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index 1cdbcac7ca..1ef31e15b9 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -77,7 +77,7 @@ class EnvironmentInternal : public EnvironmentBase
     :   m_Enclosing(enclosing)
     {
     }
-    
+
     //---------------------------------------------------------------------------
     // EnvironmentBase virtuals
     //---------------------------------------------------------------------------
@@ -88,7 +88,7 @@ class EnvironmentInternal : public EnvironmentBase
             throw TypeCheckError();
         }
     }
-    
+
     virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType,
                                      const Type::TypeContext &context, ErrorHandlerBase &errorHandler,
                                      bool initializer = false) final
@@ -99,12 +99,12 @@ class EnvironmentInternal : public EnvironmentBase
             return m_Enclosing.assign(name, op, assignedType, context, errorHandler, initializer);
         }
-        
+
         // Perform standard type-checking logic
         return EnvironmentBase::assign(name, op, existingType->second, assignedType, context, errorHandler, initializer);
     }
-    
+
     virtual const Type::Base *incDec(const Token &name, Token::Type op,
                                      const Type::TypeContext &context, ErrorHandlerBase &errorHandler) final
     {
@@ -113,7 +113,7 @@ class EnvironmentInternal : public EnvironmentBase
         if(existingType == m_Types.end()) {
             return m_Enclosing.incDec(name, op, context, errorHandler);
         }
-        
+
         // Perform standard type-checking logic
         return EnvironmentBase::incDec(name, op, existingType->second, errorHandler);
     }
@@ -302,26 +302,12 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         m_CallArguments.clear();

         // If callee's a function
         if (calleeFunctionType) {
-            // If argument count doesn't match
-            const auto argTypes = calleeFunctionType->getArgumentTypes();
-            if (call.getArguments().size() < argTypes.size()) {
-                m_ErrorHandler.error(call.getClosingParen(), "Too many arguments to function");
-                throw TypeCheckError();
-            }
-            else if (call.getArguments().size() > argTypes.size()) {
-                m_ErrorHandler.error(call.getClosingParen(), "Too few arguments to function");
-                throw TypeCheckError();
-            }
-            else {
-                // Loop through arguments
-                // **TODO** check
-                /*for(size_t i = 0; i < argTypes.size(); i++) {
-                    // Evaluate argument type
-                    auto callArgType = evaluateType(call.getArguments().at(i).get());
-                }*/
-                // Type is return type of function
-                setExpressionType(&call, calleeFunctionType->getReturnType());
-            }
+            // Assert that argument count matches
+            // **NOTE** this should have been handled when visiting Expression::Variable
+            assert(call.getArguments().size() == calleeFunctionType->getArgumentTypes().size());
+
+            // Type is return type of function
+            setExpressionType(&call, calleeFunctionType->getReturnType());
         }
         // Otherwise
         else {
@@ -436,9 +422,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     virtual void visit(const Expression::Variable &variable)
     {
-        // If type of variable is unambiguous,
+        // If type is unambiguous and not a function
         const auto varTypes = m_Environment.get().getTypes(variable.getName(), m_ErrorHandler);
-        if (varTypes.size() == 1) {
+        if (varTypes.size() == 1 && dynamic_cast(varTypes.front()) == nullptr) {
            setExpressionType(&variable, varTypes.front());
         }
         // Otherwise
@@ -468,7 +454,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
                         // If both are numeric
                         if(cNumericType && aNumericType) {
                             // If names are identical (we don't care about qualifiers), match is exact
-                            if(cNumericType->getName() == aNumericType->getName()) {
+                            if(cNumericType->getResolvedName(m_Context) == aNumericType->getResolvedName(m_Context)) {
                                 argumentConversionRank.push_back(0);
                             }
                             // Integer promotion
@@ -479,12 +465,13 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
                                 argumentConversionRank.push_back(1);
                             }
                             // Float promotion
-                            else if(aNumericType->getName() == Type::Double::getInstance()->getName()
-                                    && cNumericType->getName() == Type::Float::getInstance()->getName())
+                            else if(aNumericType->getResolvedName(m_Context) == Type::Double::getInstance()->getName()
+                                    && cNumericType->getResolvedName(m_Context) == Type::Float::getInstance()->getName())
                             {
                                 argumentConversionRank.push_back(1);
                             }
                             // Otherwise, numeric conversion
+                            // **TODO** integer to scalar promotion should be lower ranked than general conversion
                             else {
                                 argumentConversionRank.push_back(2);
                             }
@@ -509,18 +496,19 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
                 }
             }

+            // If there are no viable candidates, give error
             if(viableFunctions.empty()) {
                 m_ErrorHandler.error(variable.getName(),
                                      "No viable function candidates for '" + variable.getName().lexeme + "'");
                 throw TypeCheckError();
             }
+            // Otherwise, sort lexicographically by conversion rank and return type of lowest
+            // **TODO** handle case when best is ambiguous
             else {
-                std::cout << viableFunctions.size() << " function candidates" << std::endl;;
+                std::sort(viableFunctions.begin(), viableFunctions.end(),
+                          [](auto &f1, auto &f2){ return (f1.second < f2.second); });
+                setExpressionType(&variable, viableFunctions.front().first);
             }
-
-
-            // **TODO** handler overload resolution
-            assert(false);
         }
     }
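For readers puzzling over `f1.second < f2.second`: the second element of each pair is the vector of per-argument conversion ranks (0 = exact match, 1 = promotion, 2 = numeric conversion), and std::vector compares lexicographically, so candidates needing cheaper conversions on earlier arguments sort first. A standalone illustration of that behaviour (illustrative code, not from the patch):

#include <algorithm>
#include <cassert>
#include <string>
#include <utility>
#include <vector>

int main()
{
    // (function name, per-argument conversion ranks) pairs, as in the patch:
    // 0 = exact match, 1 = promotion, 2 = numeric conversion
    std::vector<std::pair<std::string, std::vector<int>>> candidates{
        {"sin(double)", {1}},   // float argument promoted to double
        {"sin(float)",  {0}},   // float argument matches exactly
    };

    // std::vector's operator< is lexicographic, so {0} sorts before {1}
    std::sort(candidates.begin(), candidates.end(),
              [](const auto &f1, const auto &f2){ return f1.second < f2.second; });

    assert(candidates.front().first == "sin(float)");
}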
From 80eba6c9b0cbc89e26a7ad020d0ab4f438876675 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 17 Apr 2023 17:38:20 +0100
Subject: [PATCH 130/725] string scanning

---
 include/genn/genn/transpiler/token.h |  2 +-
 src/genn/genn/transpiler/scanner.cc  | 24 +++++++++++++++++++-----
 tests/unit/scanner.cc                | 17 ++++++++++++++++-
 3 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/include/genn/genn/transpiler/token.h b/include/genn/genn/transpiler/token.h
index e3878e2382..9cc5b53881 100644
--- a/include/genn/genn/transpiler/token.h
+++ b/include/genn/genn/transpiler/token.h
@@ -37,7 +37,7 @@ struct Token
         SHIFT_LEFT_EQUAL, SHIFT_RIGHT_EQUAL,

         // Literals
-        IDENTIFIER, UINT32_NUMBER, INT32_NUMBER, FLOAT_NUMBER, DOUBLE_NUMBER, SCALAR_NUMBER,
+        IDENTIFIER, UINT32_NUMBER, INT32_NUMBER, FLOAT_NUMBER, DOUBLE_NUMBER, SCALAR_NUMBER, STRING,

         // Types
         TYPE_SPECIFIER,

diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc
index f534456ade..476138b596 100644
--- a/src/genn/genn/transpiler/scanner.cc
+++ b/src/genn/genn/transpiler/scanner.cc
@@ -167,7 +167,7 @@ Token::Type scanIntegerSuffix(ScanState &scanState)
     return integerLiteralTokenTypes.at(suffix);
 }
 //---------------------------------------------------------------------------
-void scanNumber(char c, ScanState &scanState, std::vector &tokens) 
+void scanNumber(char c, ScanState &scanState, std::vector &tokens)
 {
     // If this is a hexadecimal literal
     if(c == '0' && (scanState.match('x') || scanState.match('X'))) {
@@ -222,7 +222,7 @@ void scanNumber(char c, ScanState &scanState, std::vector &tokens)
             scanState.advance();
         }
     }
-    
+
     // If number has an f suffix, emplace FLOAT_NUMBER token
     if (std::tolower(scanState.peek()) == 'f') {
         scanState.advance();
@@ -246,6 +246,17 @@ void scanNumber(char c, ScanState &scanState, std::vector &tokens)
     }
 }
 //---------------------------------------------------------------------------
+void scanString(ScanState &scanState, std::vector &tokens)
+{
+    // Read until end of string
+    // **TODO** more complex logic here
+    while(scanState.peek() != '"') {
+        scanState.advance();
+    }
+    scanState.match('"');
+    emplaceToken(tokens, Token::Type::STRING, scanState);
+}
+//---------------------------------------------------------------------------
 void scanIdentifier(ScanState &scanState, std::vector &tokens)
 {
     // Read subsequent alphanumeric characters and underscores
@@ -289,13 +300,13 @@ void scanToken(ScanState &scanState, std::vector &tokens)
     // Operators
     case '!': emplaceToken(tokens, scanState.match('=') ? Token::Type::NOT_EQUAL : Token::Type::NOT, scanState); break;
     case '=': emplaceToken(tokens, scanState.match('=') ? Token::Type::EQUAL_EQUAL : Token::Type::EQUAL, scanState); break;
-    
+
     // Assignment operators
     case '*': emplaceToken(tokens, scanState.match('=') ? Token::Type::STAR_EQUAL : Token::Type::STAR, scanState); break;
     //case '/': emplaceToken(tokens, scanState.match('=') ? Token::Type::SLASH_EQUAL : Token::Type::SLASH, scanState); break;
     case '%': emplaceToken(tokens, scanState.match('=') ? Token::Type::PERCENT_EQUAL : Token::Type::PERCENT, scanState); break;
     case '^': emplaceToken(tokens, scanState.match('=') ? Token::Type::CARET_EQUAL : Token::Type::CARET, scanState); break;
-    
+
     case '<':
     {
         if(scanState.match('=')) {
@@ -389,7 +400,7 @@ void scanToken(ScanState &scanState, std::vector &tokens)
         }
         break;
     }
-    
+
     case '/':
     {
         // Line comment
@@ -404,6 +415,9 @@ void scanToken(ScanState &scanState, std::vector &tokens)
         break;
     }

+    // String
+    case '"': scanString(scanState, tokens); break;
+
     // Whitespace
     case ' ':
     case '\r':

diff --git a/tests/unit/scanner.cc b/tests/unit/scanner.cc
index 3d5c102bf3..351a3836b7 100644
--- a/tests/unit/scanner.cc
+++ b/tests/unit/scanner.cc
@@ -118,4 +118,19 @@ TEST(Scanner, DecimalFloat)
     ASSERT_EQ(tokens[3].lexeme, "0.2f");
     ASSERT_EQ(tokens[5].lexeme, "12.0d");
     ASSERT_EQ(tokens[7].lexeme, "0.0004f");
-}
\ No newline at end of file
+}
+//--------------------------------------------------------------------------
+TEST(Scanner, String)
+{
+    TestErrorHandler errorHandler;
+    const auto tokens = Scanner::scanSource("\"hello world\" \"pre-processor\"", errorHandler);
+    ASSERT_FALSE(errorHandler.hasError());
+
+    ASSERT_EQ(tokens.size(), 3);
+    ASSERT_EQ(tokens[0].type, Token::Type::STRING);
+    ASSERT_EQ(tokens[1].type, Token::Type::STRING);
+    ASSERT_EQ(tokens[2].type, Token::Type::END_OF_FILE);
+
+    ASSERT_EQ(tokens[0].lexeme, "\"hello world\"");
+    ASSERT_EQ(tokens[1].lexeme, "\"pre-processor\"");
+}

From 36497f6956fb1f1b62b58f83c665b57610981447 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 17 Apr 2023 17:42:18 +0100
Subject: [PATCH 131/725] string parsing, type checking and type checker unit
 test

---
 src/genn/genn/transpiler/parser.cc      |  6 +++---
 src/genn/genn/transpiler/typeChecker.cc |  3 +++
 tests/unit/typeChecker.cc               | 11 +++++++++--
 3 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc
index 7be5ddda84..36a9db7b05 100644
--- a/src/genn/genn/transpiler/parser.cc
+++ b/src/genn/genn/transpiler/parser.cc
@@ -237,9 +237,9 @@ Expression::ExpressionPtr parsePrimary(ParserState &parserState)
     // identifier
     // constant
     // "(" expression ")"
-    if (parserState.match({Token::Type::FALSE, Token::Type::TRUE, Token::Type::DOUBLE_NUMBER,
-                           Token::Type::FLOAT_NUMBER, Token::Type::SCALAR_NUMBER,
-                           Token::Type::INT32_NUMBER, Token::Type::UINT32_NUMBER})) {
+    if (parserState.match({Token::Type::FALSE, Token::Type::TRUE, Token::Type::STRING,
+                           Token::Type::DOUBLE_NUMBER, Token::Type::FLOAT_NUMBER, Token::Type::SCALAR_NUMBER,
+                           Token::Type::INT32_NUMBER, Token::Type::UINT32_NUMBER})) {
         return std::make_unique(parserState.previous());
     }
     else if(parserState.match(Token::Type::IDENTIFIER)) {

diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index 1ef31e15b9..5782eb7599 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -394,6 +394,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         else if (literal.getValue().type == Token::Type::UINT32_NUMBER) {
             setExpressionType(&literal);
         }
+        else if(literal.getValue().type == Token::Type::STRING) {
+            setExpressionType(&literal, Type::Int8::getInstance()->getPointerType());
+        }
         else {
             assert(false);
         }

diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index 3aaab1d510..dea0fc7e5c 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -529,7 +529,7 @@ TEST(TypeChecker, Literal)
         const auto *type = typeCheckExpression("1.0", typeEnvironment, typeContext);
         EXPECT_EQ(type->getResolvedName(typeContext), Type::Float::getInstance()->getName());
     }
-    
+
     // Scalar with double-precision
     {
         TestEnvironment typeEnvironment;
@@ -537,7 +537,7 @@ TEST(TypeChecker, Literal)
         const auto *type = typeCheckExpression("1.0", typeEnvironment, typeContext);
         EXPECT_EQ(type->getResolvedName(typeContext), Type::Double::getInstance()->getName());
     }
-    
+
     // Double
     {
         TestEnvironment typeEnvironment;
@@ -558,6 +558,13 @@ TEST(TypeChecker, Literal)
         const auto *type = typeCheckExpression("100U", typeEnvironment);
         EXPECT_EQ(type->getName(), Type::Uint32::getInstance()->getName());
     }
+
+    // String
+    {
+        TestEnvironment typeEnvironment;
+        const auto *type = typeCheckExpression("\"hello world\"", typeEnvironment);
+        EXPECT_EQ(type->getName(), Type::Int8::getInstance()->getPointerType()->getName());
+    }
 }
 //--------------------------------------------------------------------------
 TEST(TypeChecker, Unary)

From e5adf11977bc1c8a91ffddb3b2c780fd0b1ceead Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Mon, 17 Apr 2023 18:07:02 +0100
Subject: [PATCH 132/725] start of variadic printf implementation

---
 include/genn/genn/type.h                | 36 ++++++++++++++++++++++++-
 src/genn/genn/transpiler/typeChecker.cc | 21 +++++++--------
 src/genn/genn/type.cc                   | 22 +++++++++++++++
 tests/unit/typeChecker.cc               | 19 +++++++++++++
 4 files changed, 86 insertions(+), 12 deletions(-)

diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h
index a94a48a143..db698fdb75 100644
--- a/include/genn/genn/type.h
+++ b/include/genn/genn/type.h
@@ -299,6 +299,11 @@ class FunctionBase : public Base
     //------------------------------------------------------------------------
     virtual const Base *getReturnType() const = 0;
     virtual std::vector getArgumentTypes() const = 0;
+
+    //----------------------------------------------------------------------------
+    // Public API
+    //----------------------------------------------------------------------------
+    bool isVariadic() const;
 };

 //----------------------------------------------------------------------------
@@ -309,7 +314,7 @@ class Function : public FunctionBase
 {
 public:
     Function(Qualifier qualifiers = Qualifier{0}) : FunctionBase(qualifiers){}
-    
+
     //------------------------------------------------------------------------
     // Base virtuals
     //------------------------------------------------------------------------
@@ -409,6 +414,35 @@
 DECLARE_NUMERIC_TYPE(Uint32, uint32_t, 30, "u");
 DECLARE_NUMERIC_TYPE(Float, float, 50, "f");
 DECLARE_NUMERIC_TYPE(Double, double, 60, "");

+//----------------------------------------------------------------------------
+// GeNN::Type::PrintF
+//----------------------------------------------------------------------------
+class PrintF : public FunctionBase
+{
+public:
+    DECLARE_TYPE(PrintF);
+
+    PrintF(Qualifier qualifiers = Qualifier{0}) : FunctionBase(qualifiers){}
+
+    //------------------------------------------------------------------------
+    // Base virtuals
+    //------------------------------------------------------------------------
+    virtual std::string getName() const final{ return "PrintF"; }
+    virtual std::string getResolvedName(const TypeContext&) const final{ return "PrintF"; }
+    virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new PrintF(qualifiers); }
+    virtual size_t getSizeBytes(const TypeContext&) const final
+    {
+        assert(false);
+        return 0;
+    }
+
+    //------------------------------------------------------------------------
+    // FunctionBase virtuals
+    //------------------------------------------------------------------------
+    virtual const Base *getReturnType() const final { return Int32::getInstance(); };
+    virtual std::vector getArgumentTypes() const final{ return {Int8::getInstance()->getPointerType(), nullptr}; }
+};
+
 //----------------------------------------------------------------------------
 // Declare standard library function types
 //----------------------------------------------------------------------------

diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index 5782eb7599..a984994675 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -300,13 +300,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         auto calleeFunctionType = dynamic_cast(calleeType);
         m_CallArguments.clear();

-        // If callee's a function
+        // If callee's a function, type is return type of function
         if (calleeFunctionType) {
-            // Assert that argument count matches
-            // **NOTE** this should have been handled when visiting Expression::Variable
-            assert(call.getArguments().size() == calleeFunctionType->getArgumentTypes().size());
-
-            // Type is return type of function
             setExpressionType(&call, calleeFunctionType->getReturnType());
         }
         // Otherwise
@@ -439,9 +434,13 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
             const auto *func = dynamic_cast(type);
             assert(func);

-            // If number of arguments match
+            // If function is variadic and there are at least as many call arguments as actual (last is nullptr)
+            // function parameters or function is non-variadic and number of arguments match
             const auto argumentTypes = func->getArgumentTypes();
-            if(argumentTypes.size() == m_CallArguments.size()) {
+            const bool variadic = func->isVariadic();
+            if((variadic && m_CallArguments.size() >= (argumentTypes.size() - 1))
+               || (!variadic && m_CallArguments.size() == argumentTypes.size()))
+            {
                 // Create vector to hold argument conversion rank
                 std::vector argumentConversionRank;
                 argumentConversionRank.reserve(m_CallArguments.size());
@@ -450,7 +449,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
                 bool viable = true;
                 auto c = m_CallArguments.cbegin();
                 auto a = argumentTypes.cbegin();
-                for(;c != m_CallArguments.cend(); c++, a++) {
+                for(;c != m_CallArguments.cend() && *a != nullptr; c++, a++) {
                     auto cNumericType = dynamic_cast(*c);
                     auto aNumericType = dynamic_cast(*a);
@@ -493,7 +492,6 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
                 // If function is viable, add to vector along with vector of conversion ranks
                 if(viable) {
-                    assert(argumentConversionRank.size() == m_CallArguments.size());
                     viableFunctions.emplace_back(func, argumentConversionRank);
                 }
             }
@@ -885,7 +883,8 @@ StandardLibraryFunctionEnvironment::StandardLibraryFunctionEnvironment()
       ADD_FLOAT_DOUBLE(nearbyint, NearbyInt), ADD_FLOAT_DOUBLE(nextafter, NextAfter), ADD_FLOAT_DOUBLE(remainder, Remainder),
       ADD_FLOAT_DOUBLE(fabs, FAbs), ADD_FLOAT_DOUBLE(fdim, FDim), ADD_FLOAT_DOUBLE(fmax, FMax), ADD_FLOAT_DOUBLE(fmin, FMin),
      ADD_FLOAT_DOUBLE(erf, Erf), ADD_FLOAT_DOUBLE(erfc, ErfC), ADD_FLOAT_DOUBLE(tgamma, TGamma), ADD_FLOAT_DOUBLE(lgamma, LGamma),
-      ADD_FLOAT_DOUBLE(copysign, CopySign), ADD_FLOAT_DOUBLE(fma, FMA)}
+      ADD_FLOAT_DOUBLE(copysign, CopySign), ADD_FLOAT_DOUBLE(fma, FMA),
+      {"printf", Type::PrintF::getInstance()}}
 {
 }
 #undef ADD_FLOAT_DOUBLE

diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc
index e28046c049..51729de710 100644
--- a/src/genn/genn/type.cc
+++ b/src/genn/genn/type.cc
@@ -73,6 +73,8 @@
 IMPLEMENT_TYPE(Uint32);
 IMPLEMENT_TYPE(Float);
 IMPLEMENT_TYPE(Double);

+IMPLEMENT_TYPE(PrintF);
+
 // Implement trigonometric functions
 IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Cos);
 IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Sin);
@@ -224,6 +226,26 @@ const Type::NumericBase *NumericTypedef::getResolvedType(const TypeContext &cont
         }
     }
 }
+
+//----------------------------------------------------------------------------
+// GeNN::Type::FunctionBase
+//----------------------------------------------------------------------------
+bool FunctionBase::isVariadic() const
+{
+    // If variadic marker (nullptr) isn't found, function isn't variadic
+    const auto argTypes = getArgumentTypes();
+    const auto variadicMarker = std::find(argTypes.cbegin(), argTypes.cend(), nullptr);
+    if(variadicMarker == argTypes.cend()) {
+        return false;
+    }
+    // Otherwise, after checking variadic marker is last argument, return true
+    else {
+        assert(argTypes.back() == nullptr);
+        return true;
+    }
+
+}
+
 //----------------------------------------------------------------------------
 // Free functions
 //----------------------------------------------------------------------------

diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index dea0fc7e5c..f5a8d7fee8 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -370,6 +370,25 @@ TEST(TypeChecker, Call)
         const auto *type = typeCheckExpression("sin(fmax(0.0f, 1.0f))", stdLibraryEnv);
         EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName());
     }
+
+
+    // Variadic with too few arguments
+    EXPECT_THROW({
+        typeCheckExpression("printf()", stdLibraryEnv);},
+        TypeChecker::TypeCheckError);
+
+    // Variadic function with no extra arguments
+    {
+        const auto *type = typeCheckExpression("printf(\"hello world\")", stdLibraryEnv);
+        EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName());
+    }
+
+    // Variadic function with extra arguments
+    {
+        const auto *type = typeCheckExpression("printf(\"hello world %d, %f\", 12, cos(5.0f))", stdLibraryEnv);
+        EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName());
+    }
+
 }
 //--------------------------------------------------------------------------
 TEST(TypeChecker, Cast)
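The variadic convention introduced above is worth spelling out: a trailing nullptr in the argument-type list marks "anything may follow", so printf's signature is effectively {char pointer, nullptr}, a call is viable when it supplies at least the fixed arguments, and only the fixed ones are rank-checked. A minimal sketch of the same counting rule (hypothetical helper, not GeNN code):

#include <cassert>
#include <cstddef>
#include <vector>

// Argument list where a trailing nullptr means "variadic from here on";
// the element type stands in for GeNN's type pointers (illustrative only)
using ArgTypes = std::vector<const char *>;

bool argumentCountViable(const ArgTypes &params, size_t numCallArgs)
{
    const bool variadic = !params.empty() && params.back() == nullptr;
    return variadic ? numCallArgs >= (params.size() - 1)
                    : numCallArgs == params.size();
}

int main()
{
    const ArgTypes printfArgs{"const char*", nullptr};  // like Type::PrintF
    assert(!argumentCountViable(printfArgs, 0));  // printf() - too few
    assert(argumentCountViable(printfArgs, 1));   // printf("hi")
    assert(argumentCountViable(printfArgs, 3));   // printf("%d %f", 12, 1.0f)
}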
From 12a0c544838bae7e2f36b314b7b1aac64b3789e5 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 18 Apr 2023 11:21:49 +0100
Subject: [PATCH 133/725] stack for properly handling evaluated arguments to
 nested calls

---
 src/genn/genn/transpiler/typeChecker.cc | 30 ++++++++++++++-----------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index a984994675..c8f2f3d812 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -3,8 +3,9 @@
 // Standard C++ includes
 #include
 #include
+#include
 #include
-#include

 // Standard C includes
 #include
@@ -287,19 +288,19 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     virtual void visit(const Expression::Call &call) final
     {
-        // **TODO** think about nested calls
-        assert(m_CallArguments.empty());
-
-        // Evaluate argument types and store in class
-        m_CallArguments.clear();
-        std::transform(call.getArguments().cbegin(), call.getArguments().cend(), std::back_inserter(m_CallArguments),
+        // Evaluate argument types and store in top of stack
+        m_CallArguments.emplace();
+        std::transform(call.getArguments().cbegin(), call.getArguments().cend(), std::back_inserter(m_CallArguments.top()),
                        [this](const auto &a){ return evaluateType(a.get()); });

         // Evaluate callee type
         auto calleeType = evaluateType(call.getCallee());
         auto calleeFunctionType = dynamic_cast(calleeType);
-        m_CallArguments.clear();

+        // Pop stack
+        m_CallArguments.pop();
+
         // If callee's a function, type is return type of function
         if (calleeFunctionType) {
             setExpressionType(&call, calleeFunctionType->getReturnType());
@@ -427,6 +428,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         }
         // Otherwise
         else {
+            // Check that there are call arguments on the stack
+            assert(!m_CallArguments.empty());
+
             // Loop through variable types
             std::vector>> viableFunctions;
             for(const auto *type : varTypes) {
@@ -438,18 +442,18 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
             // function parameters or function is non-variadic and number of arguments match
             const auto argumentTypes = func->getArgumentTypes();
             const bool variadic = func->isVariadic();
-            if((variadic && m_CallArguments.size() >= (argumentTypes.size() - 1))
-               || (!variadic && m_CallArguments.size() == argumentTypes.size()))
+            if((variadic && m_CallArguments.top().size() >= (argumentTypes.size() - 1))
+               || (!variadic && m_CallArguments.top().size() == argumentTypes.size()))
             {
                 // Create vector to hold argument conversion rank
                 std::vector argumentConversionRank;
-                argumentConversionRank.reserve(m_CallArguments.size());
+                argumentConversionRank.reserve(m_CallArguments.top().size());

                 // Loop through arguments
                 bool viable = true;
-                auto c = m_CallArguments.cbegin();
+                auto c = m_CallArguments.top().cbegin();
                 auto a = argumentTypes.cbegin();
-                for(;c != m_CallArguments.cend() && *a != nullptr; c++, a++) {
+                for(;c != m_CallArguments.top().cend() && *a != nullptr; c++, a++) {
                     auto cNumericType = dynamic_cast(*c);
                     auto aNumericType = dynamic_cast(*a);
@@ -749,7 +753,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     const Type::TypeContext &m_Context;
     ErrorHandlerBase &m_ErrorHandler;
     ResolvedTypeMap &m_ResolvedTypes;
-    std::vector m_CallArguments;
+    std::stack> m_CallArguments;
     bool m_InLoop;
     bool m_InSwitch;
 };
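The move from a single vector to a stack matters for nested calls: while type-checking sin(fmax(0.0f, 1.0f)), the checker starts collecting sin's argument types, then recurses into fmax and must collect its arguments without clobbering sin's. A stack of argument-type vectors, pushed on entry to each Call and popped once the callee is resolved, keeps the frames separate; the earlier assert(m_CallArguments.empty()) forbade exactly this nesting. A toy illustration of the push/pop discipline (illustrative code, not from the patch):

#include <cassert>
#include <stack>
#include <string>
#include <vector>

int main()
{
    std::stack<std::vector<std::string>> callArguments;

    callArguments.push({"float"});           // entering sin(...): outer frame
    callArguments.push({"float", "float"});  // entering fmax(...): inner frame

    assert(callArguments.top().size() == 2); // fmax resolved against 2 args
    callArguments.pop();                     // leaving fmax

    assert(callArguments.top().size() == 1); // sin's frame is intact
    callArguments.pop();                     // leaving sin
    assert(callArguments.empty());
}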
From 3829d050a3294eceffc1782646a82cd7397973f9 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 18 Apr 2023 17:38:54 +0100
Subject: [PATCH 134/725] whitespace

---
 src/genn/genn/transpiler/prettyPrinter.cc | 33 +++++++++++------------
 1 file changed, 15 insertions(+), 18 deletions(-)

diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc
index 98720e440c..4714af27a8 100644
--- a/src/genn/genn/transpiler/prettyPrinter.cc
+++ b/src/genn/genn/transpiler/prettyPrinter.cc
@@ -33,7 +33,7 @@ class EnvironmentInternal : public EnvironmentBase
     :   m_Enclosing(enclosing)
     {
     }
-    
+
     //---------------------------------------------------------------------------
     // EnvironmentBase virtuals
     //---------------------------------------------------------------------------
@@ -42,10 +42,10 @@ class EnvironmentInternal : public EnvironmentBase
         if(!m_LocalVariables.emplace(name).second) {
             throw std::runtime_error("Redeclaration of variable");
         }
-        
+
         return "_" + name;
     }
-    
+
     virtual std::string getName(const std::string &name) final
     {
         if(m_LocalVariables.find(name) == m_LocalVariables.end()) {
@@ -55,7 +55,7 @@ class EnvironmentInternal : public EnvironmentBase
             return "_" + name;
         }
     }
-    
+
     virtual CodeStream &getStream()
     {
         return m_Enclosing.getStream();
@@ -188,7 +188,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         m_Environment.get().getStream() << unary.getOperator().lexeme;
         unary.getRight()->accept(*this);
     }
-    
+
     //---------------------------------------------------------------------------
     // Statement::Visitor virtuals
     //---------------------------------------------------------------------------
@@ -201,17 +201,17 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     {
         // Cache reference to current reference
         std::reference_wrapper oldEnvironment = m_Environment;
-        
+
         // Create new environment and set to current
         EnvironmentInternal environment(m_Environment);
         m_Environment = environment;
-        
+
         CodeGenerator::CodeStream::Scope b(m_Environment.get().getStream());
         for(auto &s : compound.getStatements()) {
             s->accept(*this);
             m_Environment.get().getStream() << std::endl;
         }
-        
+
         // Restore old environment
         m_Environment = oldEnvironment;
     }
@@ -240,11 +240,11 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     {
         // Cache reference to current reference
         std::reference_wrapper oldEnvironment = m_Environment;
-        
+
         // Create new environment and set to current
         EnvironmentInternal environment(m_Environment);
         m_Environment = environment;
-        
+
         m_Environment.get().getStream() << "for(";
         if(forStatement.getInitialiser()) {
             forStatement.getInitialiser()->accept(*this);
@@ -264,7 +264,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         }
         m_Environment.get().getStream() << ")";
         forStatement.getBody()->accept(*this);
-        
+
         // Restore old environment
         m_Environment = oldEnvironment;
     }
@@ -333,7 +333,6 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     void printType(const GeNN::Type::Base *type)
     {
         // **THINK** this should be Type::getName!
-
         // Loop, building reversed list of tokens
         std::vector tokens;
         while(true) {
@@ -344,10 +343,10 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
                 if(pointerType->hasQualifier(GeNN::Type::Qualifier::CONSTANT)) {
                     tokens.push_back("const");
                 }
-                
+
                 // Add *
                 tokens.push_back("*");
-                
+
                 // Go to value type
                 type = pointerType->getValueType();
             }
@@ -355,19 +354,17 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
             else {
                 // Add type specifier
                 tokens.push_back(type->getName());
-                
-                
+
                 if(pointerType->hasQualifier(GeNN::Type::Qualifier::CONSTANT)) {
                     tokens.push_back("const");
                 }
                 break;
             }
         }
-
         // Copy tokens backwards into string stream, seperating with spaces
         std::copy(tokens.rbegin(), tokens.rend(), std::ostream_iterator(m_Environment.get().getStream(), " "));
-    }
+    }

     //---------------------------------------------------------------------------
     // Members
     //---------------------------------------------------------------------------

From 12dfeb0ab08d74f30e4b6827469f3406640766ff Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 18 Apr 2023 17:40:35 +0100
Subject: [PATCH 135/725] new "standard library" module - used to provide type
 information for standard library

---
 include/genn/genn/transpiler/standardLibrary.h | 31 ++++++++++
 src/genn/genn/transpiler/standardLibrary.cc   | 61 +++++++++++++++++++
 2 files changed, 92 insertions(+)
 create mode 100644 include/genn/genn/transpiler/standardLibrary.h
 create mode 100644 src/genn/genn/transpiler/standardLibrary.cc

diff --git a/include/genn/genn/transpiler/standardLibrary.h b/include/genn/genn/transpiler/standardLibrary.h
new file mode 100644
index 0000000000..a4b6bc5fd2
--- /dev/null
+++ b/include/genn/genn/transpiler/standardLibrary.h
@@ -0,0 +1,31 @@
+#pragma once
+
+// Standard C++ includes
+#include
+#include
+
+// Transpiler includes
+#include "transpiler/typeChecker.h"
+
+//---------------------------------------------------------------------------
+// GeNN::Transpiler::StandardLibrary::TypeEnvironment
+//---------------------------------------------------------------------------
+namespace GeNN::Transpiler::StandardLibrary
+{
+class TypeEnvironment : public TypeChecker::EnvironmentBase
+{
+public:
+    TypeEnvironment();
+
+    //------------------------------------------------------------------------
+    // EnvironmentBase virtuals
+    //------------------------------------------------------------------------
+    virtual void define(const Token &name, const Type::Base *type, ErrorHandlerBase &errorHandler) final;
+    virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType,
+                                     const Type::TypeContext &context, ErrorHandlerBase &errorHandler,
+                                     bool initializer = false) final;
+    virtual const Type::Base *incDec(const Token &name, Token::Type op,
+                                     const Type::TypeContext &context, ErrorHandlerBase &errorHandler) final;
+    virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final;
+};
+} // namespace GeNN::Transpiler::StandardLibrary

diff --git a/src/genn/genn/transpiler/standardLibrary.cc b/src/genn/genn/transpiler/standardLibrary.cc
new file mode 100644
index 0000000000..9388558e34
--- /dev/null
+++ b/src/genn/genn/transpiler/standardLibrary.cc
@@ -0,0 +1,61 @@
+#include "transpiler/standardLibrary.h"
+
+// Standard C++ library
+#include
+
+using namespace GeNN::Transpiler::standardLibrary;
+
+//#define ADD_FLOAT_DOUBLE(NAME, CLASS_PREFIX) {#NAME, Type::CLASS_PREFIX##F::getInstance()}, {#NAME, Type::CLASS_PREFIX##D::getInstance()}
+
+//---------------------------------------------------------------------------
+// Anonymous namespace
+//---------------------------------------------------------------------------
+namespace
+{
+const std::unordered_multimap, std::string>> libraryTypes{
+};
+}
+
+
+//---------------------------------------------------------------------------
+// GeNN::Transpiler::TypeChecker::TypeEnvironment
+//---------------------------------------------------------------------------
+TypeEnvironment::TypeEnvironment()
+{
+}
+//------------------------------------------------------------------------
+void TypeEnvironment::define(const Token &name, const Type::Base*, ErrorHandlerBase &errorHandler)
+{
+    errorHandler.error(name, "Cannot declare variable in external environment");
+    throw TypeCheckError();
+}
+//---------------------------------------------------------------------------
+const Type::Base *StandardLibraryFunctionEnvironment::assign(const Token &name, Token::Type, const Type::Base*,
+                                                             const Type::TypeContext&, ErrorHandlerBase &errorHandler, bool)
+{
+    errorHandler.error(name, "Cannot assign variable in external environment");
+    throw TypeCheckError();
+}
+//---------------------------------------------------------------------------
+const Type::Base *StandardLibraryFunctionEnvironment::incDec(const Token &name, Token::Type, const Type::TypeContext&,
+                                                             ErrorHandlerBase &errorHandler)
+{
+    errorHandler.error(name, "Cannot increment/decrement variable in external environment");
+    throw TypeCheckError();
+}
+//---------------------------------------------------------------------------
+std::vector StandardLibraryFunctionEnvironment::getTypes(const Token &name, ErrorHandlerBase &errorHandler)
+{
+    auto [typeBegin, typeEnd] = libraryTypes.equal_range(name.lexeme);
+    if (typeBegin == typeEnd) {
+        errorHandler.error(name, "Undefined variable");
+        throw TypeCheckError();
+    }
+    else {
+        std::vector types;
+        types.reserve(std::distance(typeBegin, typeEnd));
+        std::transform(typeBegin, typeEnd, std::back_inserter(types),
+                       [](auto t) { return t.second.first.get(); });
+        return types;
+    }
+}
From 67d3611d71ad9a067c0ef13c00c3e4bf7526fcd0 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 18 Apr 2023 18:30:42 +0100
Subject: [PATCH 136/725] updated windows project, implemented standard library
 data structure

---
 src/genn/genn/genn.vcxproj                  |   2 +
 src/genn/genn/transpiler/standardLibrary.cc | 111 ++++++++++++++++++--
 2 files changed, 107 insertions(+), 6 deletions(-)

diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj
index 4f62e42564..e2d1a220f3 100644
--- a/src/genn/genn/genn.vcxproj
+++ b/src/genn/genn/genn.vcxproj
@@ -60,6 +60,7 @@
+
@@ -124,6 +125,7 @@
+

diff --git a/src/genn/genn/transpiler/standardLibrary.cc b/src/genn/genn/transpiler/standardLibrary.cc
index 9388558e34..90f94efa91 100644
--- a/src/genn/genn/transpiler/standardLibrary.cc
+++ b/src/genn/genn/transpiler/standardLibrary.cc
@@ -1,19 +1,118 @@
 #include "transpiler/standardLibrary.h"

 // Standard C++ library
+#include
+#include
 #include

-using namespace GeNN::Transpiler::standardLibrary;
+// GeNN includes
+#include "type.h"

-//#define ADD_FLOAT_DOUBLE(NAME, CLASS_PREFIX) {#NAME, Type::CLASS_PREFIX##F::getInstance()}, {#NAME, Type::CLASS_PREFIX##D::getInstance()}
+// Transpiler includes
+#include "transpiler/errorHandler.h"
+#include "transpiler/typeChecker.h"
+
+using namespace GeNN::Transpiler::StandardLibrary;
+using namespace GeNN::Transpiler::TypeChecker;
+namespace Type = GeNN::Type;
+
+//---------------------------------------------------------------------------
+// Macros
+//---------------------------------------------------------------------------
+#define ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(NAME) \
+    std::make_pair(#NAME, std::make_pair(std::make_unique>(), #NAME"($(0))")), \
+    std::make_pair(#NAME, std::make_pair(std::make_unique>(), #NAME"($(0))"))
+
+#define ADD_TWO_ARG_FLOAT_DOUBLE_FUNC(NAME) \
+    std::make_pair(#NAME, std::make_pair(std::make_unique>(), #NAME"($(0), $(1))")), \
+    std::make_pair(#NAME, std::make_pair(std::make_unique>(), #NAME"($(0), $(1))"))
+
+#define ADD_THREE_ARG_FLOAT_DOUBLE_FUNC(NAME) \
+    std::make_pair(#NAME, std::make_pair(std::make_unique>(), #NAME"($(0), $(1), $(2))")), \
+    std::make_pair(#NAME, std::make_pair(std::make_unique>(), #NAME"($(0), $(1), $(2))"))

 //---------------------------------------------------------------------------
 // Anonymous namespace
 //---------------------------------------------------------------------------
 namespace
 {
-const std::unordered_multimap, std::string>> libraryTypes{
-};
+template
+auto initLibraryTypes(Args&&... args)
+{
+    std::unordered_multimap, std::string>> map;
+    (map.emplace(std::forward(args)), ...);
+    return map;
+}
+
+const auto libraryTypes = initLibraryTypes(
+    // Trigonometric functions
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(cos),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(sin),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(tan),
+
+    // Inverse trigonometric functions
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(acos),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(asin),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(atan),
+    ADD_TWO_ARG_FLOAT_DOUBLE_FUNC(atan2),
+
+    // Hyperbolic functions
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(cosh),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(sinh),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(tanh),
+
+    // Inverse Hyperbolic functions
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(acosh),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(asinh),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(atanh),
+
+    // Exponential functions
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(exp),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(expm1),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(exp2),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(pow),
+    std::make_pair("scalbn", std::make_pair(std::make_unique>(), "scalbn($(0), $(1))")),
+    std::make_pair("scalbn", std::make_pair(std::make_unique>(), "scalbn($(0), $(1))")),
+
+    // Logarithm functions
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(log),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(log1p),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(log2),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(log10),
+    std::make_pair("ldexp", std::make_pair(std::make_unique>(), "ldexp($(0), $(1))")),
+    std::make_pair("ldexp", std::make_pair(std::make_unique>(), "ldexp($(0), $(1))")),
+    std::make_pair("ilogb", std::make_pair(std::make_unique>(), "ilogb($(0))")),
+    std::make_pair("ilogb", std::make_pair(std::make_unique>(), "ilogb($(0))")),
+
+    // Root functions
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(sqrt),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(cbrt),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(hypot),
+
+    // Rounding functions
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(ceil),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(floor),
+    ADD_TWO_ARG_FLOAT_DOUBLE_FUNC(fmod),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(round),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(rint),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(trunc),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(nearbyint),
+    ADD_TWO_ARG_FLOAT_DOUBLE_FUNC(nextafter),
+    ADD_TWO_ARG_FLOAT_DOUBLE_FUNC(remainder),
+
+    // Range functions
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(fabs),
+    ADD_TWO_ARG_FLOAT_DOUBLE_FUNC(fdim),
+    ADD_TWO_ARG_FLOAT_DOUBLE_FUNC(fmax),
+    ADD_TWO_ARG_FLOAT_DOUBLE_FUNC(fmin),
+
+    // Other functions
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(erf),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(erfc),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(tgamma),
+    ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(lgamma),
+    ADD_TWO_ARG_FLOAT_DOUBLE_FUNC(copysign),
+    ADD_THREE_ARG_FLOAT_DOUBLE_FUNC(fma));
 }

@@ -46,7 +145,7 @@ const Type::Base *StandardLibraryFunctionEnvironment::incDec(const Token &name,
 //---------------------------------------------------------------------------
 std::vector StandardLibraryFunctionEnvironment::getTypes(const Token &name, ErrorHandlerBase &errorHandler)
 {
-    auto [typeBegin, typeEnd] = libraryTypes.equal_range(name.lexeme);
+    const auto [typeBegin, typeEnd] = libraryTypes.equal_range(name.lexeme);
     if (typeBegin == typeEnd) {
         errorHandler.error(name, "Undefined variable");
         throw TypeCheckError();
     }
@@ -55,7 +154,7 @@ std::vector StandardLibraryFunctionEnvironment::getTypes(cons
         std::vector types;
         types.reserve(std::distance(typeBegin, typeEnd));
         std::transform(typeBegin, typeEnd, std::back_inserter(types),
-                       [](auto t) { return t.second.first.get(); });
+                       [](const auto &t) { return t.second.first.get(); });
         return types;
     }
 }
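Two design points in the table above are worth noting: the std::unordered_multimap keys every overload under the same name, so getTypes can return the whole overload set via equal_range, and each entry also carries a substitution template like "sin($(0))", presumably for later expansion by the code generator. A small standalone analogue of the lookup (simplified value type, not the patch's):

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

int main()
{
    // Overloads share a key; values here are just display strings
    const std::unordered_multimap<std::string, std::string> library{
        {"sin", "float sin(float)"},
        {"sin", "double sin(double)"},
    };

    // equal_range yields every entry for the name - the full overload set
    const auto [begin, end] = library.equal_range("sin");
    std::vector<std::string> overloads;
    for(auto it = begin; it != end; ++it) {
        overloads.push_back(it->second);
    }
    assert(overloads.size() == 2);

    // An unknown name yields an empty range - the "Undefined variable" path
    const auto [b2, e2] = library.equal_range("sinh");
    assert(b2 == e2);
}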
void define(const Token &name, const Type::Base *type, ErrorHandlerBase &errorHandler) final; - virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler, - bool initializer = false) final; - virtual const Type::Base *incDec(const Token &name, Token::Type op, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler) final; - virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final; - -private: - std::unordered_multimap m_Types; -}; - //--------------------------------------------------------------------------- // Free functions //--------------------------------------------------------------------------- diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index db698fdb75..59e5341175 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -50,36 +50,8 @@ using NumericType = TYPE; \ } -#define DECLARE_FUNCTION_TYPE(TYPE, RETURN_TYPE, ...) \ - class TYPE : public Function \ - { \ - DECLARE_TYPE(TYPE) \ - TYPE(Qualifier qualifiers = Qualifier{0}) : Function(qualifiers){} \ - virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new TYPE(qualifiers); } \ - } - #define IMPLEMENT_TYPE(TYPE) TYPE *TYPE::s_Instance = NULL -//! Helper macro to declare single and double precision one argument function types -#define DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(TYPE) \ - DECLARE_FUNCTION_TYPE(TYPE##F, Float, Float); \ - DECLARE_FUNCTION_TYPE(TYPE##D, Double, Double) - -//! Helper macro to declare single and double precision two argument function types -#define DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(TYPE) \ - DECLARE_FUNCTION_TYPE(TYPE##F, Float, Float, Float); \ - DECLARE_FUNCTION_TYPE(TYPE##D, Double, Double, Double) - -//! Helper macro to declare single and double precision three argument function types -#define DECLARE_THREE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(TYPE) \ - DECLARE_FUNCTION_TYPE(TYPE##F, Float, Float, Float, Float); \ - DECLARE_FUNCTION_TYPE(TYPE##D, Double, Double, Double, Double) - -//! 
Helper macro to implement single and double precision function types -#define IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(TYPE) \ - IMPLEMENT_TYPE(TYPE##F); \ - IMPLEMENT_TYPE(TYPE##D) - //---------------------------------------------------------------------------- // GeNN::Type::TypeTraits //---------------------------------------------------------------------------- @@ -340,6 +312,11 @@ class Function : public FunctionBase return 0; } + virtual Base *getQualifiedType(Qualifier qualifiers) const override + { + return new Function(qualifiers); + } + //------------------------------------------------------------------------ // FunctionBase virtuals //------------------------------------------------------------------------ @@ -414,114 +391,6 @@ DECLARE_NUMERIC_TYPE(Uint32, uint32_t, 30, "u"); DECLARE_NUMERIC_TYPE(Float, float, 50, "f"); DECLARE_NUMERIC_TYPE(Double, double, 60, ""); -//---------------------------------------------------------------------------- -// GeNN::Type::PrintF -//---------------------------------------------------------------------------- -class PrintF : public FunctionBase -{ -public: - DECLARE_TYPE(PrintF); - - PrintF(Qualifier qualifiers = Qualifier{0}) : FunctionBase(qualifiers){} - - //------------------------------------------------------------------------ - // Base virtuals - //------------------------------------------------------------------------ - virtual std::string getName() const final{ return "PrintF"; } - virtual std::string getResolvedName(const TypeContext&) const final{ return "PrintF"; } - virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new PrintF(qualifiers); } - virtual size_t getSizeBytes(const TypeContext&) const final - { - assert(false); - return 0; - } - - //------------------------------------------------------------------------ - // FunctionBase virtuals - //------------------------------------------------------------------------ - virtual const Base *getReturnType() const final { return Int32::getInstance(); }; - virtual std::vector getArgumentTypes() const final{ return {Int8::getInstance()->getPointerType(), nullptr}; } -}; - -//---------------------------------------------------------------------------- -// Declare standard library function types -//---------------------------------------------------------------------------- -// Trigonometric functions -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Cos); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Sin); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Tan); - -// Inverse trigonometric functions -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Acos); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Asin); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Atan); -DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Atan2); - -// Hyperbolic functions -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Cosh); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Sinh); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Tanh); - -// Inverse Hyperbolic functions -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Acosh); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Asinh); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Atanh); - -// Exponential functions -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Exp); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(ExpM1); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Exp2); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Pow); -DECLARE_FUNCTION_TYPE(ScalBNF, Float, Float, Int32); -DECLARE_FUNCTION_TYPE(ScalBND, Double, Double, Int32); - -// Logarithm functions -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Log); 
-DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Log1P); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Log2); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Log10); -DECLARE_FUNCTION_TYPE(LdExpF, Float, Float, Int32); -DECLARE_FUNCTION_TYPE(LdExpD, Double, Double, Int32); -DECLARE_FUNCTION_TYPE(ILogBF, Int32, Float); -DECLARE_FUNCTION_TYPE(ILogBD, Int32, Double); - -// Root functions -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Sqrt); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Cbrt); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Hypot); - -// Rounding functions -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Ceil); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Floor); -DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Fmod); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Round); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Rint); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Trunc); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(NearbyInt); -DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(NextAfter); -DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Remainder); - -// Range functions -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(FAbs); -DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(FDim); -DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(FMax); -DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(FMin); - -// Other functions -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(Erf); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(ErfC); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(TGamma); -DECLARE_ONE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(LGamma); -DECLARE_TWO_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(CopySign); -DECLARE_THREE_ARG_FLOAT_DOUBLE_FUNCTION_TYPE(FMA); -/*{, -{"frexp", "frexpf"}, // pointer arguments -{"modf", "modff"}, // pointer arguments -{"scalbln", "scalblnf"}, // long type -{"lround", "lroundf"}, // long return type -{"lrint", "lrintf"}, // long return type -{"remquo", "remquof"}, // pointer arguments -*/ //! 
Parse a numeric type const NumericBase *parseNumeric(const std::string &typeString); diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 0d274416fd..cf61c365fd 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -13,6 +13,7 @@ #include "transpiler/parser.h" #include "transpiler/prettyPrinter.h" #include "transpiler/scanner.h" +#include "transpiler/standardLibrary.h" #include "transpiler/typeChecker.h" #include "transpiler/transpilerUtils.h" @@ -34,7 +35,7 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const Type::TypeC using namespace Type; // Create type environment - TypeChecker::StandardLibraryFunctionEnvironment stdLibraryEnv; + StandardLibrary::FunctionTypes stdLibraryEnv; GroupMergedTypeEnvironment typeEnvironment(*this, &stdLibraryEnv); addField("size", [](const auto &c, size_t) { return std::to_string(c.getSize()); }); @@ -267,7 +268,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const using namespace Type; // Create type environment - TypeChecker::StandardLibraryFunctionEnvironment stdLibraryEnv; + StandardLibrary::FunctionTypes stdLibraryEnv; GroupMergedTypeEnvironment typeEnvironment(*this, &stdLibraryEnv); // If underlying synapse group has kernel weights diff --git a/src/genn/genn/transpiler/standardLibrary.cc b/src/genn/genn/transpiler/standardLibrary.cc index 90f94efa91..51ec7c3780 100644 --- a/src/genn/genn/transpiler/standardLibrary.cc +++ b/src/genn/genn/transpiler/standardLibrary.cc @@ -116,34 +116,43 @@ const auto libraryTypes = initLibraryTypes( } +/*{, +{"frexp", "frexpf"}, // pointer arguments +{"modf", "modff"}, // pointer arguments +{"scalbln", "scalblnf"}, // long type +{"lround", "lroundf"}, // long return type +{"lrint", "lrintf"}, // long return type +{"remquo", "remquof"}, // pointer arguments +*/ +//min, max, printf //--------------------------------------------------------------------------- -// GeNN::Transpiler::TypeChecker::TypeEnvironment +// GeNN::Transpiler::StandardLibrary::FunctionTypes //--------------------------------------------------------------------------- -TypeEnvironment::TypeEnvironment() +FunctionTypes::FunctionTypes() { } //------------------------------------------------------------------------ -void TypeEnvironment::define(const Token &name, const Type::Base*, ErrorHandlerBase &errorHandler) +void FunctionTypes::define(const Token &name, const Type::Base*, ErrorHandlerBase &errorHandler) { errorHandler.error(name, "Cannot declare variable in external environment"); throw TypeCheckError(); } //--------------------------------------------------------------------------- -const Type::Base *StandardLibraryFunctionEnvironment::assign(const Token &name, Token::Type, const Type::Base*, - const Type::TypeContext&, ErrorHandlerBase &errorHandler, bool) +const Type::Base *FunctionTypes::assign(const Token &name, Token::Type, const Type::Base*, + const Type::TypeContext&, ErrorHandlerBase &errorHandler, bool) { errorHandler.error(name, "Cannot assign variable in external environment"); throw TypeCheckError(); } //--------------------------------------------------------------------------- -const Type::Base *StandardLibraryFunctionEnvironment::incDec(const Token &name, Token::Type, const Type::TypeContext&, +const Type::Base *FunctionTypes::incDec(const Token &name, Token::Type, const Type::TypeContext&, ErrorHandlerBase &errorHandler) { 
errorHandler.error(name, "Cannot increment/decrement variable in external environment"); throw TypeCheckError(); } //--------------------------------------------------------------------------- -std::vector StandardLibraryFunctionEnvironment::getTypes(const Token &name, ErrorHandlerBase &errorHandler) +std::vector FunctionTypes::getTypes(const Token &name, ErrorHandlerBase &errorHandler) { const auto [typeBegin, typeEnd] = libraryTypes.equal_range(name.lexeme); if (typeBegin == typeEnd) { diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index c8f2f3d812..261e17138e 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -870,65 +870,6 @@ const Type::Base *EnvironmentBase::incDec(const Token &name, Token::Type, } } -//--------------------------------------------------------------------------- -// GeNN::Transpiler::TypeChecker::StandardLibraryFunctionEnvironment -//--------------------------------------------------------------------------- -#define ADD_FLOAT_DOUBLE(NAME, CLASS_PREFIX) {#NAME, Type::CLASS_PREFIX##F::getInstance()}, {#NAME, Type::CLASS_PREFIX##D::getInstance()} -StandardLibraryFunctionEnvironment::StandardLibraryFunctionEnvironment() - : m_Types{ADD_FLOAT_DOUBLE(cos, Cos), ADD_FLOAT_DOUBLE(sin, Sin), ADD_FLOAT_DOUBLE(tan, Tan), - ADD_FLOAT_DOUBLE(acos, Acos), ADD_FLOAT_DOUBLE(asin, Asin), ADD_FLOAT_DOUBLE(atan, Atan), ADD_FLOAT_DOUBLE(atan2, Atan2), - ADD_FLOAT_DOUBLE(cosh, Cosh), ADD_FLOAT_DOUBLE(sinh, Sinh), ADD_FLOAT_DOUBLE(tanh, Tanh), - ADD_FLOAT_DOUBLE(exp, Exp), ADD_FLOAT_DOUBLE(expm1, ExpM1), ADD_FLOAT_DOUBLE(exp2, Exp2), ADD_FLOAT_DOUBLE(pow, Pow), - ADD_FLOAT_DOUBLE(scalbn, ScalBN), ADD_FLOAT_DOUBLE(log, Log), ADD_FLOAT_DOUBLE(log1p, Log1P), ADD_FLOAT_DOUBLE(log2, Log2), - ADD_FLOAT_DOUBLE(log10, Log10), ADD_FLOAT_DOUBLE(ldexp, LdExp), ADD_FLOAT_DOUBLE(ilogb, ILogB), - ADD_FLOAT_DOUBLE(sqrt, Sqrt), ADD_FLOAT_DOUBLE(cbrt, Cbrt), ADD_FLOAT_DOUBLE(hypot, Hypot), - ADD_FLOAT_DOUBLE(ceil, Ceil), ADD_FLOAT_DOUBLE(floor, Floor), ADD_FLOAT_DOUBLE(fmod, Fmod), - ADD_FLOAT_DOUBLE(round, Round), ADD_FLOAT_DOUBLE(rint, Rint), ADD_FLOAT_DOUBLE(trunc, Trunc), - ADD_FLOAT_DOUBLE(nearbyint, NearbyInt), ADD_FLOAT_DOUBLE(nextafter, NextAfter),ADD_FLOAT_DOUBLE(remainder, Remainder), - ADD_FLOAT_DOUBLE(fabs, FAbs), ADD_FLOAT_DOUBLE(fdim, FDim), ADD_FLOAT_DOUBLE(fmax, FMax), ADD_FLOAT_DOUBLE(fmin, FMin), - ADD_FLOAT_DOUBLE(erf, Erf), ADD_FLOAT_DOUBLE(erfc, ErfC), ADD_FLOAT_DOUBLE(tgamma, TGamma), ADD_FLOAT_DOUBLE(lgamma, LGamma), - ADD_FLOAT_DOUBLE(copysign, CopySign), ADD_FLOAT_DOUBLE(fma, FMA), - {"printf", Type::PrintF::getInstance()}} -{ -} -#undef ADD_FLOAT_DOUBLE -//------------------------------------------------------------------------ -void StandardLibraryFunctionEnvironment::define(const Token &name, const Type::Base*, ErrorHandlerBase &errorHandler) -{ - errorHandler.error(name, "Cannot declare variable in external environment"); - throw TypeCheckError(); -} -//--------------------------------------------------------------------------- -const Type::Base *StandardLibraryFunctionEnvironment::assign(const Token &name, Token::Type, const Type::Base*, - const Type::TypeContext&, ErrorHandlerBase &errorHandler, bool) -{ - errorHandler.error(name, "Cannot assign variable in external environment"); - throw TypeCheckError(); -} -//--------------------------------------------------------------------------- -const Type::Base *StandardLibraryFunctionEnvironment::incDec(const Token &name, 
Token::Type, const Type::TypeContext&, - ErrorHandlerBase &errorHandler) -{ - errorHandler.error(name, "Cannot increment/decrement variable in external environment"); - throw TypeCheckError(); -} -//--------------------------------------------------------------------------- -std::vector StandardLibraryFunctionEnvironment::getTypes(const Token &name, ErrorHandlerBase &errorHandler) -{ - auto [typeBegin, typeEnd] = m_Types.equal_range(name.lexeme); - if (typeBegin == typeEnd) { - errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); - } - else { - std::vector types; - types.reserve(std::distance(typeBegin, typeEnd)); - std::transform(typeBegin, typeEnd, std::back_inserter(types), - [](auto t) { return t.second; }); - return types; - } -} - //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker //--------------------------------------------------------------------------- diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index 51729de710..406047e3e7 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -73,75 +73,6 @@ IMPLEMENT_TYPE(Uint32); IMPLEMENT_TYPE(Float); IMPLEMENT_TYPE(Double); -IMPLEMENT_TYPE(PrintF); - -// Implement trigonometric functions -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Cos); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Sin); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Tan); - -// Implement inverse trigonometric functions -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Acos); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Asin); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Atan); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Atan2); - -// Implement hyperbolic functions -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Cosh); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Sinh); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Tanh); - -// Implement inverse hyperbolic functions -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Acosh); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Asinh); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Atanh); - -// Implement exponential functions -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Exp); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(ExpM1); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Exp2); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Pow); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(ScalBN); - -// Implement logarithm functions -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Log); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Log1P); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Log2); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Log10); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(LdExp); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(ILogB); - -// Implement root functions -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Sqrt); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Cbrt); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Hypot); - -// Implement rounding functions -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Ceil); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Floor); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Fmod); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Round); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Rint); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Trunc); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(NearbyInt); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(NextAfter); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Remainder); - -// Implement range functions -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(FAbs); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(FDim); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(FMax); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(FMin); - -// Implement other functions -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(Erf); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(ErfC); 
-IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(TGamma); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(LGamma); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(CopySign); -IMPLEMENT_FLOAT_DOUBLE_FUNCTION_TYPE(FMA); - - //---------------------------------------------------------------------------- // GeNN::Type::Base //---------------------------------------------------------------------------- diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index f5a8d7fee8..c9ea1490fb 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -8,6 +8,7 @@ #include "transpiler/errorHandler.h" #include "transpiler/parser.h" #include "transpiler/scanner.h" +#include "transpiler/standardLibrary.h" #include "transpiler/typeChecker.h" using namespace GeNN; @@ -329,7 +330,7 @@ TEST(TypeChecker, Binary) TEST(TypeChecker, Call) { // Too few arguments - TypeChecker::StandardLibraryFunctionEnvironment stdLibraryEnv; + StandardLibrary::FunctionTypes stdLibraryEnv; EXPECT_THROW({ typeCheckExpression("sin()", stdLibraryEnv);}, TypeChecker::TypeCheckError); @@ -373,7 +374,7 @@ TEST(TypeChecker, Call) // Variadic with too few arguments - EXPECT_THROW({ + /*EXPECT_THROW({ typeCheckExpression("printf()", stdLibraryEnv);}, TypeChecker::TypeCheckError); @@ -387,7 +388,7 @@ TEST(TypeChecker, Call) { const auto *type = typeCheckExpression("printf(\"hello world %d, %f\", 12, cos(5.0f))", stdLibraryEnv); EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); - } + }*/ } //-------------------------------------------------------------------------- From 7626aaf63149d2e0c61aa3d90bac5bd9f2ee861e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 19 Apr 2023 11:48:31 +0100 Subject: [PATCH 138/725] type checker now returns resolved type map and custom update group merged stashes it in class **BROKEN IN VC++** --- include/genn/genn/code_generator/customUpdateGroupMerged.h | 7 +++++++ include/genn/genn/transpiler/typeChecker.h | 4 ++-- src/genn/genn/code_generator/customUpdateGroupMerged.cc | 4 ++-- src/genn/genn/transpiler/typeChecker.cc | 7 ++++--- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index 286a466686..923f29a21f 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -7,6 +7,7 @@ // GeNN transpiler includes #include "transpiler/statement.h" +#include "transpiler/typeChecker.h" //---------------------------------------------------------------------------- // GeNN::CodeGenerator::CustomUpdateGroupMerged @@ -54,6 +55,9 @@ class GENN_EXPORT CustomUpdateGroupMerged : public GroupMergedgetUpdateCode()); const auto tokens = Scanner::scanSource(code, errorHandler); m_UpdateStatements = Parser::parseBlockItemList(tokens, errorHandler); - TypeChecker::typeCheck(m_UpdateStatements, typeEnvironment, typeContext, errorHandler); + m_ResolvedTypes = TypeChecker::typeCheck(m_UpdateStatements, typeEnvironment, typeContext, errorHandler); } //---------------------------------------------------------------------------- bool CustomUpdateGroupMerged::isParamHeterogeneous(const std::string ¶mName) const @@ -366,7 +366,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const const std::string code = upgradeCodeString(cm->getUpdateCode()); const auto tokens = Scanner::scanSource(code, errorHandler); m_UpdateStatements = Parser::parseBlockItemList(tokens, errorHandler); - 
TypeChecker::typeCheck(m_UpdateStatements, typeEnvironment, typeContext, errorHandler);
+    m_ResolvedTypes = TypeChecker::typeCheck(m_UpdateStatements, typeEnvironment, typeContext, errorHandler);
 }
 
 // ----------------------------------------------------------------------------
diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index 261e17138e..6f1001fdbb 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -873,12 +873,13 @@
 //---------------------------------------------------------------------------
 // GeNN::Transpiler::TypeChecker
 //---------------------------------------------------------------------------
-void GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment,
-                                              const Type::TypeContext &context, ErrorHandlerBase &errorHandler)
+ResolvedTypeMap GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment,
+                                                         const Type::TypeContext &context, ErrorHandlerBase &errorHandler)
 {
     ResolvedTypeMap expressionTypes;
     EnvironmentInternal internalEnvironment(environment);
-    Visitor(statements, context, internalEnvironment, expressionTypes, errorHandler);
+    Visitor visitor(statements, context, internalEnvironment, expressionTypes, errorHandler);
+    return expressionTypes;
 }
 //---------------------------------------------------------------------------
 const Type::Base *GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, EnvironmentBase &environment,

From 5803c77b170aab906a57f70b7d9cd68fb9ddfa71 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Wed, 19 Apr 2023 11:48:47 +0100
Subject: [PATCH 139/725] start of standard library environment

---
 .../genn/genn/transpiler/standardLibrary.h    | 21 +++++++++++++++++++
 src/genn/genn/transpiler/standardLibrary.cc   | 17 ++++++++++++++++-
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/include/genn/genn/transpiler/standardLibrary.h b/include/genn/genn/transpiler/standardLibrary.h
index 3d0bb350f8..fa8d602610 100644
--- a/include/genn/genn/transpiler/standardLibrary.h
+++ b/include/genn/genn/transpiler/standardLibrary.h
@@ -4,6 +4,10 @@
 #include
 #include
 
+// Code generator includes
+#include "code_generator/codeStream.h"
+#include "code_generator/environment.h"
+
 // Transpiler includes
 #include "transpiler/typeChecker.h"
 
@@ -28,4 +32,21 @@ class FunctionTypes : public TypeChecker::EnvironmentBase
                                  const Type::TypeContext &context, ErrorHandlerBase &errorHandler) final;
     virtual std::vector<const Type::Base*> getTypes(const Token &name, ErrorHandlerBase &errorHandler) final;
 };
+
+//---------------------------------------------------------------------------
+// GeNN::Transpiler::StandardLibrary::FunctionEnvironment
+//---------------------------------------------------------------------------
+class FunctionEnvironment : public CodeGenerator::EnvironmentExternal
+{
+public:
+    FunctionEnvironment(CodeGenerator::CodeStream &os)
+    :   CodeGenerator::EnvironmentExternal(os)
+    {}
+
+    //------------------------------------------------------------------------
+    // PrettyPrinter::EnvironmentBase virtuals
+    //------------------------------------------------------------------------
+    virtual std::string getName(const std::string &name) final;
+    virtual CodeGenerator::CodeStream &getStream() final;
+};
 } // namespace GeNN::Transpiler::StandardLibrary
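Commit 138 above makes typeCheck() hand back the ResolvedTypeMap it builds, and the FunctionEnvironment declared here is the pretty-printer-side consumer of that information: given a function name and the overload the type checker resolved, it must recover the matching code pattern. A sketch of that lookup against the libraryTypes multimap in standardLibrary.cc; getFunctionPattern is a hypothetical free-function stand-in, and commit 141 later implements essentially this logic inside FunctionEnvironment::getName():

    #include <algorithm>
    #include <string>

    // Find the entry whose stored type object is the exact overload the
    // type checker resolved and return its code pattern, or "" if unknown
    std::string getFunctionPattern(const std::string &name, const Type::Base *resolvedType)
    {
        const auto [typeBegin, typeEnd] = libraryTypes.equal_range(name);
        const auto entry = std::find_if(typeBegin, typeEnd,
                                        [resolvedType](const auto &t){ return t.second.first.get() == resolvedType; });
        return (entry == typeEnd) ? std::string{} : entry->second.second;
    }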
diff --git a/src/genn/genn/transpiler/standardLibrary.cc b/src/genn/genn/transpiler/standardLibrary.cc
index 51ec7c3780..bff92d3e6c 100644
--- a/src/genn/genn/transpiler/standardLibrary.cc
+++ b/src/genn/genn/transpiler/standardLibrary.cc
@@ -12,6 +12,7 @@
 #include "transpiler/errorHandler.h"
 #include "transpiler/typeChecker.h"
 
+using namespace GeNN::CodeGenerator;
 using namespace GeNN::Transpiler::StandardLibrary;
 using namespace GeNN::Transpiler::TypeChecker;
 namespace Type = GeNN::Type;
@@ -39,7 +40,7 @@ namespace
 template<typename... Args>
 auto initLibraryTypes(Args&&... args)
 {
-    std::unordered_multimap, std::string>> map;
+    std::unordered_multimap, std::string>> map;
     (map.emplace(std::forward<Args>(args)), ...);
     return map;
 }
@@ -125,6 +126,7 @@ const auto libraryTypes = initLibraryTypes(
 {"remquo", "remquof"}, // pointer arguments
 */
 //min, max, printf
+
 //---------------------------------------------------------------------------
 // GeNN::Transpiler::StandardLibrary::FunctionTypes
 //---------------------------------------------------------------------------
@@ -167,3 +169,16 @@ std::vector<const Type::Base*> FunctionTypes::getTypes(const Token &name, ErrorH
     return types;
 }
 }
+
+//---------------------------------------------------------------------------
+// GeNN::Transpiler::StandardLibrary::FunctionEnvironment
+//---------------------------------------------------------------------------
+std::string FunctionEnvironment::getName(const std::string &name)
+{
+    return "";
+}
+//---------------------------------------------------------------------------
+CodeStream &FunctionEnvironment::getStream()
+{
+    return getContextStream();
+}
\ No newline at end of file

From 5258cf6564cfe99ea229fdd9f476f4d2598f46 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Wed, 19 Apr 2023 12:46:20 +0100
Subject: [PATCH 140/725] fix for VC++ issue

---
 include/genn/genn/code_generator/groupMerged.h | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h
index 0ee25dee84..5e06c3dc03 100644
--- a/include/genn/genn/code_generator/groupMerged.h
+++ b/include/genn/genn/code_generator/groupMerged.h
@@ -70,6 +70,9 @@ class GroupMerged
     :   m_Index(index), m_TypeContext(typeContext), m_Groups(std::move(groups))
     {}
 
+    GroupMerged(const GroupMerged&) = delete;
+    GroupMerged(GroupMerged&&) = default;
+
     //------------------------------------------------------------------------
     // Public API
     //------------------------------------------------------------------------

From f82ee52ec10d2ece5dfe0ac28343c621a876df96 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Wed, 19 Apr 2023 14:03:09 +0100
Subject: [PATCH 141/725] hooked pretty printer to standard library

---
 .../code_generator/customUpdateGroupMerged.h  |  3 ++-
 .../genn/genn/code_generator/environment.h    |  8 +++----
 include/genn/genn/transpiler/prettyPrinter.h  |  5 +++--
 .../genn/genn/transpiler/standardLibrary.h    |  2 +-
 .../backends/single_threaded_cpu/backend.cc   |  5 ++++-
 .../code_generator/customUpdateGroupMerged.cc |  4 ++--
 src/genn/genn/code_generator/environment.cc   |  8 +++----
 src/genn/genn/transpiler/prettyPrinter.cc     | 21 +++++++++++--------
 src/genn/genn/transpiler/standardLibrary.cc   | 13 ++++++++++--
 9 files changed, 43 insertions(+), 26 deletions(-)

diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h
index 923f29a21f..31df29b5a6 100644
--- a/include/genn/genn/code_generator/customUpdateGroupMerged.h
+++ 
b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -97,7 +97,8 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged enclosing)->std::string { return enclosing.get().getName(name); }, + [&name, type](std::reference_wrapper enclosing)->std::string { return enclosing.get().getName(name, type); }, [&name](std::reference_wrapper)->std::string { throw std::runtime_error("Variable '" + name + "' undefined"); }}, getContext()); } @@ -51,12 +51,12 @@ EnvironmentSubstitute::~EnvironmentSubstitute() getContextStream() << m_ContentsStream.str(); } //---------------------------------------------------------------------------- -std::string EnvironmentSubstitute::getName(const std::string &name) +std::string EnvironmentSubstitute::getName(const std::string &name, const Type::Base *type) { // If there isn't a substitution for this name, try and get name from context auto var = m_VarSubstitutions.find(name); if(var == m_VarSubstitutions.end()) { - return getContextName(name); + return getContextName(name, type); } // Otherwise, return substitution else { diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index 4714af27a8..b656fc0187 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -12,6 +12,7 @@ // Transpiler includes #include "transpiler/transpilerUtils.h" +#include "transpiler/typeChecker.h" using namespace GeNN; using namespace GeNN::CodeGenerator; @@ -46,10 +47,10 @@ class EnvironmentInternal : public EnvironmentBase return "_" + name; } - virtual std::string getName(const std::string &name) final + virtual std::string getName(const std::string &name, const Type::Base *type = nullptr) final { if(m_LocalVariables.find(name) == m_LocalVariables.end()) { - return m_Enclosing.getName(name); + return m_Enclosing.getName(name, type); } else { return "_" + name; @@ -75,9 +76,9 @@ class EnvironmentInternal : public EnvironmentBase class Visitor : public Expression::Visitor, public Statement::Visitor { public: - Visitor(const Statement::StatementList &statements, - EnvironmentInternal &environment, const Type::TypeContext &context) - : m_Environment(environment), m_Context(context) + Visitor(const Statement::StatementList &statements, EnvironmentInternal &environment, + const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes) + : m_Environment(environment), m_Context(context), m_ResolvedTypes(resolvedTypes) { for(auto &s : statements) { s.get()->accept(*this); @@ -155,7 +156,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const Type::NumericBase *scalar = dynamic_cast(m_Context.at("scalar")); m_Environment.get().getStream() << lexeme << scalar->getLiteralSuffix(m_Context); } - // Otherwise, just write out original lexeme directly + // Otherwise, just write out original lexeme directly (strings are already quoted) else { m_Environment.get().getStream() << lexeme; } @@ -180,7 +181,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Variable &variable) final { - m_Environment.get().getStream() << m_Environment.get().getName(variable.getName().lexeme); + const auto *type = m_ResolvedTypes.at(&variable); + m_Environment.get().getStream() << m_Environment.get().getName(variable.getName().lexeme, type); } virtual void visit(const Expression::Unary &unary) final @@ -370,6 +372,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor 
//--------------------------------------------------------------------------- std::reference_wrapper m_Environment; const Type::TypeContext &m_Context; + const TypeChecker::ResolvedTypeMap &m_ResolvedTypes; }; } // Anonymous namespace @@ -377,8 +380,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // GeNN::Transpiler::PrettyPrinter //--------------------------------------------------------------------------- void GeNN::Transpiler::PrettyPrinter::print(const Statement::StatementList &statements, EnvironmentBase &environment, - const Type::TypeContext &context) + const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes) { EnvironmentInternal internalEnvironment(environment); - Visitor(statements, internalEnvironment, context); + Visitor(statements, internalEnvironment, context, resolvedTypes); } diff --git a/src/genn/genn/transpiler/standardLibrary.cc b/src/genn/genn/transpiler/standardLibrary.cc index bff92d3e6c..c0d9ed9d5d 100644 --- a/src/genn/genn/transpiler/standardLibrary.cc +++ b/src/genn/genn/transpiler/standardLibrary.cc @@ -173,9 +173,18 @@ std::vector FunctionTypes::getTypes(const Token &name, ErrorH //--------------------------------------------------------------------------- // GeNN::Transpiler::StandardLibrary::FunctionEnvironment //--------------------------------------------------------------------------- -std::string FunctionEnvironment::getName(const std::string &name) +std::string FunctionEnvironment::getName(const std::string &name, const Type::Base *type) { - return ""; + const auto [libTypeBegin, libTypeEnd] = libraryTypes.equal_range(name); + if (libTypeBegin == libTypeEnd) { + return getContextName(name, type); + } + else { + const auto libType = std::find_if(libTypeBegin, libTypeEnd, + [type](const auto &t){ return t.second.first.get() == type; }); + assert(libType != libTypeEnd); + return libType->second.second; + } } //--------------------------------------------------------------------------- CodeStream &FunctionEnvironment::getStream() From dcd64284e4b0c25586b53a10df654314cbfadded Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 19 Apr 2023 14:47:52 +0100 Subject: [PATCH 142/725] get rid of accidental copies in unit tests --- tests/unit/neuronGroup.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/neuronGroup.cc b/tests/unit/neuronGroup.cc index 5a43a53fb6..147865d458 100644 --- a/tests/unit/neuronGroup.cc +++ b/tests/unit/neuronGroup.cc @@ -593,8 +593,8 @@ TEST(NeuronGroup, CompareNeuronModels) // Find which merged neuron init group is the one with the single population i.e. the one with constant initialisers const size_t constantInitIndex = (modelSpecMerged.getMergedNeuronInitGroups().at(0).getGroups().size() == 1) ? 0 : 1; - const auto constantInitMergedGroup = modelSpecMerged.getMergedNeuronInitGroups().at(constantInitIndex); - const auto uniformInitMergedGroup = modelSpecMerged.getMergedNeuronInitGroups().at(1 - constantInitIndex); + const auto &constantInitMergedGroup = modelSpecMerged.getMergedNeuronInitGroups().at(constantInitIndex); + const auto &uniformInitMergedGroup = modelSpecMerged.getMergedNeuronInitGroups().at(1 - constantInitIndex); // Check that only 'd' parameter is heterogeneous in neuron update group ASSERT_FALSE(modelSpecMerged.getMergedNeuronUpdateGroups().at(0).isParamHeterogeneous("a")); @@ -741,8 +741,8 @@ TEST(NeuronGroup, CompareCurrentSources) // Find which merged neuron group is the one with the single population i.e. 
the two DC current sources const size_t dcDCIndex = (modelSpecMerged.getMergedNeuronUpdateGroups().at(0).getGroups().size() == 4) ? 1 : 0; - const auto dcDCMergedGroup = modelSpecMerged.getMergedNeuronUpdateGroups().at(dcDCIndex); - const auto dcPoissonMergedGroup = modelSpecMerged.getMergedNeuronUpdateGroups().at(1 - dcDCIndex); + const auto &dcDCMergedGroup = modelSpecMerged.getMergedNeuronUpdateGroups().at(dcDCIndex); + const auto &dcPoissonMergedGroup = modelSpecMerged.getMergedNeuronUpdateGroups().at(1 - dcDCIndex); ASSERT_TRUE(dcDCMergedGroup.getGroups().size() == 1); // Find which child in the DC + poisson merged group is the poisson current source From 6e0836f12bcecace5f1435e03b8ec72dbf1dde9d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 19 Apr 2023 15:58:37 +0100 Subject: [PATCH 143/725] unified parsing of ``ArraySubscript`` and ``Call`` --- include/genn/genn/transpiler/expression.h | 12 +++++++----- src/genn/genn/transpiler/parser.cc | 19 ++++++------------- src/genn/genn/transpiler/prettyPrinter.cc | 3 ++- src/genn/genn/transpiler/typeChecker.cc | 10 +++++----- 4 files changed, 20 insertions(+), 24 deletions(-) diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h index 27f27513a7..d0e31802d9 100644 --- a/include/genn/genn/transpiler/expression.h +++ b/include/genn/genn/transpiler/expression.h @@ -82,15 +82,17 @@ typedef std::vector ExpressionList; class ArraySubscript : public Acceptable { public: - ArraySubscript(Token pointerName, ExpressionPtr index) - : m_PointerName(pointerName), m_Index(std::move(index)) + ArraySubscript(ExpressionPtr array, Token closingSquareBracket, ExpressionPtr index) + : m_Array(std::move(array)), m_ClosingSquareBracket(closingSquareBracket), m_Index(std::move(index)) {} - const Token &getPointerName() const { return m_PointerName; } - const ExpressionPtr &getIndex() const { return m_Index; } + const Base *getArray() const { return m_Array.get(); } + const Token &getClosingSquareBracket() const { return m_ClosingSquareBracket; } + const Base *getIndex() const { return m_Index.get(); } private: - const Token m_PointerName; + const ExpressionPtr m_Array; + const Token m_ClosingSquareBracket; const ExpressionPtr m_Index; }; diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 36a9db7b05..38b8c8346d 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -213,10 +213,10 @@ const GeNN::Type::Base *parseDeclarationSpecifiers(ParserState &parserState) } } } while(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER, Token::Type::STAR})); - + // Lookup numeric type const Base *type = getNumericType(typeSpecifiers); - + // If there are any type qualifiers, add const // **THINK** this relies of const being only qualifier if(!typeQualifiers.empty()) { @@ -285,7 +285,6 @@ Expression::ExpressionPtr parsePostfix(ParserState &parserState) Token closingParen = parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after arguments."); - // Create call expression expression = std::make_unique(std::move(expression), closingParen, std::move(arguments)); @@ -296,15 +295,9 @@ Expression::ExpressionPtr parsePostfix(ParserState &parserState) Token closingSquareBracket = parserState.consume(Token::Type::RIGHT_SQUARE_BRACKET, "Expect ']' after index."); - // **TODO** everything all the way up(?) 
from unary are l-value so can be used - not just variable - auto expressionVariable = dynamic_cast(expression.get()); - if(expressionVariable) { - expression = std::make_unique(expressionVariable->getName(), - std::move(index)); - } - else { - parserState.error(closingSquareBracket, "Invalid subscript target"); - } + expression = std::make_unique(std::move(expression), + closingSquareBracket, + std::move(index)); } // Otherwise if this is an increment or decrement else if(parserState.match({Token::Type::PLUS_PLUS, Token::Type::MINUS_MINUS})) { @@ -313,7 +306,7 @@ Expression::ExpressionPtr parsePostfix(ParserState &parserState) // **TODO** everything all the way up(?) from unary are l-value so can be used - not just variable auto expressionVariable = dynamic_cast(expression.get()); if(expressionVariable) { - return std::make_unique(expressionVariable->getName(), op); + expression = std::make_unique(expressionVariable->getName(), op); } else { parserState.error(op, "Invalid postfix target"); diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index b656fc0187..a4834bb17a 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -92,7 +92,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- virtual void visit(const Expression::ArraySubscript &arraySubscript) final { - m_Environment.get().getStream() << m_Environment.get().getName(arraySubscript.getPointerName().lexeme) << "["; + arraySubscript.getArray()->accept(*this); + m_Environment.get().getStream() << "["; arraySubscript.getIndex()->accept(*this); m_Environment.get().getStream() << "]"; } diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 6f1001fdbb..97d4b9dbbe 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -174,16 +174,16 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::ArraySubscript &arraySubscript) final { // Get pointer type - auto arrayType = m_Environment.get().getType(arraySubscript.getPointerName(), m_ErrorHandler); + auto arrayType = evaluateType(arraySubscript.getArray()); auto pointerType = dynamic_cast(arrayType); // If pointer is indeed a pointer if (pointerType) { // Evaluate pointer type - auto indexType = evaluateType(arraySubscript.getIndex().get()); - auto indexNumericType = dynamic_cast(indexType); + auto indexType = evaluateType(arraySubscript.getIndex()); + auto indexNumericType = dynamic_cast(indexType); if (!indexNumericType || !indexNumericType->isIntegral(m_Context)) { - m_ErrorHandler.error(arraySubscript.getPointerName(), + m_ErrorHandler.error(arraySubscript.getClosingSquareBracket(), "Invalid subscript index type '" + indexType->getName() + "'"); throw TypeCheckError(); } @@ -193,7 +193,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // Otherwise else { - m_ErrorHandler.error(arraySubscript.getPointerName(), "Subscripted object is not a pointer"); + m_ErrorHandler.error(arraySubscript.getClosingSquareBracket(), "Subscripted object is not a pointer"); throw TypeCheckError(); } } From 837706f88d5a9e2ceee722f9c1094209b405bff5 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 19 Apr 2023 16:17:25 +0100 Subject: [PATCH 144/725] added failing *x = 7 style test --- tests/unit/typeChecker.cc | 8 ++++++++ 1 file changed, 8 
insertions(+)

diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index c9ea1490fb..4cc7e732f7 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -252,6 +252,14 @@ TEST(TypeChecker, Assignment)
         typeCheckStatements("float *x = intArray;", typeEnvironment);},
         TypeChecker::TypeCheckError);
 
+    // Dereference assignment
+    {
+        TestEnvironment typeEnvironment;
+        typeEnvironment.definePointer("intArray");
+        typeCheckStatements(
+            "*intArray = 7;\n",
+            typeEnvironment);
+    }
     // **TODO** other assignments i.e. += -= %=
 }
 //--------------------------------------------------------------------------

From e92221c97c82d8c6d4c27d5e2f68b928d9762811 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Wed, 19 Apr 2023 16:17:45 +0100
Subject: [PATCH 145/725] added ``isLValue`` to ``Expression::Base``

---
 include/genn/genn/transpiler/expression.h | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h
index d0e31802d9..424cc68bbf 100644
--- a/include/genn/genn/transpiler/expression.h
+++ b/include/genn/genn/transpiler/expression.h
@@ -58,6 +58,7 @@ class Base
 {
 public:
     virtual void accept(Visitor &visitor) const = 0;
+    virtual bool isLValue() const{ return false; }
 };
 
 //---------------------------------------------------------------------------
@@ -86,6 +87,11 @@ class ArraySubscript : public Acceptable<ArraySubscript>
     :  m_Array(std::move(array)), m_ClosingSquareBracket(closingSquareBracket), m_Index(std::move(index))
     {}
 
+    //------------------------------------------------------------------------
+    // Expression::Base virtuals
+    //------------------------------------------------------------------------
+    virtual bool isLValue() const{ return true; }
+
     const Base *getArray() const { return m_Array.get(); }
     const Token &getClosingSquareBracket() const { return m_ClosingSquareBracket; }
     const Base *getIndex() const { return m_Index.get(); }
@@ -208,6 +214,11 @@ class Grouping : public Acceptable<Grouping>
     :  m_Expression(std::move(expression))
     {}
 
+    //------------------------------------------------------------------------
+    // Expression::Base virtuals
+    //------------------------------------------------------------------------
+    virtual bool isLValue() const{ return m_Expression->isLValue(); }
+
     const Base *getExpression() const { return m_Expression.get(); }
 
 private:
@@ -224,6 +235,11 @@ class Literal : public Acceptable<Literal>
     :  m_Value(value)
     {}
 
+    //------------------------------------------------------------------------
+    // Expression::Base virtuals
+    //------------------------------------------------------------------------
+    virtual bool isLValue() const{ return (m_Value.type == Token::Type::STRING); }
+
     Token getValue() const { return m_Value; }
 
 private:
@@ -296,6 +312,11 @@ class Variable : public Acceptable<Variable>
     :  m_Name(name)
     {}
 
+    //------------------------------------------------------------------------
+    // Expression::Base virtuals
+    //------------------------------------------------------------------------
+    virtual bool isLValue() const{ return true; }
+
     const Token &getName() const { return m_Name; }
 
 private:
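This commit gives every expression node a single answer to "may this appear on the left of an assignment?", and commit 146 below then lets the parser ask exactly that question instead of special-casing Variable. A toy model of the delegation pattern, using illustrative names rather than the real classes:

    #include <memory>

    struct Expr
    {
        virtual ~Expr() = default;
        virtual bool isLValue() const { return false; }   // expressions are r-values by default
    };

    struct Variable : Expr
    {
        bool isLValue() const override { return true; }   // a name is always assignable
    };

    struct Grouping : Expr
    {
        std::unique_ptr<Expr> inner;
        bool isLValue() const override { return inner->isLValue(); }  // "(x)" is as assignable as "x"
    };

    struct Binary : Expr {};  // "x + 1" keeps the default: not assignable

With this in place, a parser can reject "(x + 1) = 2" but accept "(x) = 2" using the single test if(!target->isLValue()) { error("Expression is not assignable"); }, which is what the next patch does for assignment, address-of and increment/decrement.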
From de3e4acd9683175c48a94e817ee3a0ea032d375b Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Fri, 21 Apr 2023 14:18:26 +0100
Subject: [PATCH 146/725] parser now accepts any l-value expression for
 address of, assignment and pre/post increment/decrement

---
 include/genn/genn/transpiler/expression.h | 24 ++++++-------
 src/genn/genn/transpiler/parser.cc        | 44 ++++++++++++++---------
 2 files changed, 39 insertions(+), 29 deletions(-)

diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h
index 424cc68bbf..4a20e02edb 100644
--- a/include/genn/genn/transpiler/expression.h
+++ b/include/genn/genn/transpiler/expression.h
@@ -108,16 +108,16 @@ class ArraySubscript : public Acceptable<ArraySubscript>
 class Assignment : public Acceptable<Assignment>
 {
 public:
-    Assignment(Token varName, Token op, ExpressionPtr value)
-    :  m_VarName(varName), m_Operator(op), m_Value(std::move(value))
+    Assignment(ExpressionPtr assignee, Token op, ExpressionPtr value)
+    :  m_Assignee(std::move(assignee)), m_Operator(op), m_Value(std::move(value))
     {}
 
-    const Token &getVarName() const { return m_VarName; }
+    const Base *getAssignee() const { return m_Assignee.get(); }
     const Token &getOperator() const { return m_Operator; }
     const Base *getValue() const { return m_Value.get(); }
 
 private:
-    const Token m_VarName;
+    const ExpressionPtr m_Assignee;
     const Token m_Operator;
     const ExpressionPtr m_Value;
 };
 
@@ -272,15 +272,15 @@ class Logical : public Acceptable<Logical>
 class PostfixIncDec : public Acceptable<PostfixIncDec>
 {
 public:
-    PostfixIncDec(Token varName, Token op)
-    :  m_VarName(varName), m_Operator(op)
+    PostfixIncDec(ExpressionPtr target, Token op)
+    :  m_Target(std::move(target)), m_Operator(op)
     {}
 
-    const Token &getVarName() const { return m_VarName; }
+    const Base *getTarget() const { return m_Target.get(); }
     const Token &getOperator() const { return m_Operator; }
 
 private:
-    const Token m_VarName;
+    const ExpressionPtr m_Target;
     const Token m_Operator;
 };
 
@@ -290,15 +290,15 @@ class PostfixIncDec : public Acceptable<PostfixIncDec>
 class PrefixIncDec : public Acceptable<PrefixIncDec>
 {
 public:
-    PrefixIncDec(Token varName, Token op)
-    :  m_VarName(varName), m_Operator(op)
+    PrefixIncDec(ExpressionPtr target, Token op)
+    :  m_Target(std::move(target)), m_Operator(op)
     {}
 
-    const Token &getVarName() const { return m_VarName; }
+    const Base *getTarget() const { return m_Target.get(); }
     const Token &getOperator() const { return m_Operator; }
 
 private:
-    const Token m_VarName;
+    const ExpressionPtr m_Target;
     const Token m_Operator;
 };
 
diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc
index 38b8c8346d..937655bca6 100644
--- a/src/genn/genn/transpiler/parser.cc
+++ b/src/genn/genn/transpiler/parser.cc
@@ -303,13 +303,12 @@ Expression::ExpressionPtr parsePostfix(ParserState &parserState)
     else if(parserState.match({Token::Type::PLUS_PLUS, Token::Type::MINUS_MINUS})) {
         Token op = parserState.previous();
 
-        // **TODO** everything all the way up(?) from unary are l-value so can be used - not just variable
-        auto expressionVariable = dynamic_cast<const Expression::Variable*>(expression.get());
-        if(expressionVariable) {
-            expression = std::make_unique<Expression::PostfixIncDec>(expressionVariable->getName(), op);
+        // If expression is a valid l-value,
+        if(expression->isLValue()) {
+            expression = std::make_unique<Expression::PostfixIncDec>(std::move(expression), op);
         }
         else {
-            parserState.error(op, "Invalid postfix target");
+            parserState.error(op, "Expression is not assignable");
         }
     }
     else {
@@ -335,8 +334,21 @@ Expression::ExpressionPtr parseUnary(ParserState &parserState)
     //  "!" 
cast-expression // "sizeof" unary-expression **TODO** // "sizeof" "(" type-name ")" **TODO** - if(parserState.match({Token::Type::AMPERSAND, Token::Type::STAR, Token::Type::PLUS, - Token::Type::MINUS, Token::Type::TILDA, Token::Type::NOT})) { + if(parserState.match(Token::Type::AMPERSAND)) { + Token op = parserState.previous(); + auto expression = parseCast(parserState); + + // If expression is a valid l-value, + if (expression->isLValue()) { + return std::make_unique(op, std::move(expression)); + } + else { + parserState.error(op, "Cannot take the address of r-value"); + } + } + else if(parserState.match({Token::Type::STAR, Token::Type::PLUS, Token::Type::MINUS, + Token::Type::TILDA, Token::Type::NOT})) + { Token op = parserState.previous(); return std::make_unique(op, parseCast(parserState)); } @@ -344,13 +356,12 @@ Expression::ExpressionPtr parseUnary(ParserState &parserState) Token op = parserState.previous(); auto expression = parseUnary(parserState); - // **TODO** everything all the way up(?) from unary are l-value so can be used - not just variable - auto expressionVariable = dynamic_cast(expression.get()); - if(expressionVariable) { - return std::make_unique(expressionVariable->getName(), op); + // If expression is a valid l-value, + if(expression->isLValue()) { + return std::make_unique(std::move(expression), op); } else { - parserState.error(op, "Invalid prefix target"); + parserState.error(op, "Expression is not assignable"); } } @@ -525,13 +536,12 @@ Expression::ExpressionPtr parseAssignment(ParserState &parserState) Token op = parserState.previous(); auto value = parseAssignment(parserState); - // **TODO** everything all the way up(?) from unary are l-value so can be used - not just variable - auto expressionVariable = dynamic_cast(expression.get()); - if(expressionVariable) { - return std::make_unique(expressionVariable->getName(), op, std::move(value)); + // If expression is a valid l-value, + if(expression->isLValue()) { + return std::make_unique(std::move(expression), op, std::move(value)); } else { - parserState.error(op, "Invalid assignment target"); + parserState.error(op, "Expression is not assignable"); } } From d32e09a913d2d46e85009c9677bea277c24e9f07 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 21 Apr 2023 18:14:51 +0100 Subject: [PATCH 147/725] WIP type checker refactor --- include/genn/genn/transpiler/typeChecker.h | 16 ------ src/genn/genn/transpiler/typeChecker.cc | 67 ++++++---------------- 2 files changed, 17 insertions(+), 66 deletions(-) diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index 5455fadad2..072c540f6c 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -44,28 +44,12 @@ class EnvironmentBase // Declared virtuals //------------------------------------------------------------------------ virtual void define(const Token &name, const Type::Base *type, ErrorHandlerBase &errorHandler) = 0; - virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler, - bool initializer = false) = 0; - virtual const Type::Base *incDec(const Token &name, Token::Type op, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler) = 0; virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) = 0; //--------------------------------------------------------------------------- // Public API 
//--------------------------------------------------------------------------- const Type::Base *getType(const Token &name, ErrorHandlerBase &errorHandler); - -protected: - //--------------------------------------------------------------------------- - // Protected API - //--------------------------------------------------------------------------- - const Type::Base *assign(const Token &name, Token::Type op, - const Type::Base *existingType, const Type::Base *assignedType, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler, - bool initializer = false) const; - const Type::Base *incDec(const Token &name, Token::Type op, - const Type::Base *existingType, ErrorHandlerBase &errorHandler) const; }; //--------------------------------------------------------------------------- diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 97d4b9dbbe..6c5c4e3ed3 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -90,35 +90,6 @@ class EnvironmentInternal : public EnvironmentBase } } - virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler, - bool initializer = false) final - { - // If type isn't found - auto existingType = m_Types.find(name.lexeme); - if(existingType == m_Types.end()) { - return m_Enclosing.assign(name, op, assignedType, - context, errorHandler, initializer); - } - - // Perform standard type-checking logic - return EnvironmentBase::assign(name, op, existingType->second, assignedType, - context, errorHandler, initializer); - } - - virtual const Type::Base *incDec(const Token &name, Token::Type op, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler) final - { - // If type isn't found - auto existingType = m_Types.find(name.lexeme); - if(existingType == m_Types.end()) { - return m_Enclosing.incDec(name, op, context, errorHandler); - } - - // Perform standard type-checking logic - return EnvironmentBase::incDec(name, op, existingType->second, errorHandler); - } - virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final { auto type = m_Types.find(name.lexeme); @@ -200,10 +171,12 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Assignment &assignment) final { + const auto lhsType = evaluateType(assignment.getAssignee()); const auto rhsType = evaluateType(assignment.getValue()); - setExpressionType(&assignment, - m_Environment.get().assign(assignment.getVarName(), assignment.getOperator().type, rhsType, - m_Context, m_ErrorHandler)); + + assert(false); + + setExpressionType(&assignment, lhsType); } virtual void visit(const Expression::Binary &binary) final @@ -407,6 +380,12 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final { + const auto lhsType = evaluateType(postfixIncDec.getTarget()); + + if(lhsType->hasQualifier(Type::Qualifier::CONSTANT)) { + m_ErrorHandler.error(postfixIncDec.getOperator(), "Increment/decrement of read-only variable"); + throw TypeCheckError(); + } setExpressionType(&postfixIncDec, m_Environment.get().incDec(postfixIncDec.getVarName(), postfixIncDec.getOperator().type, m_Context, m_ErrorHandler)); @@ -414,6 +393,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final 
{ + const auto rhsType = evaluateType(prefixIncDec.getTarget()); setExpressionType(&prefixIncDec, m_Environment.get().incDec(prefixIncDec.getVarName(), prefixIncDec.getOperator().type, m_Context, m_ErrorHandler)); @@ -694,17 +674,17 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::VarDeclaration &varDeclaration) final { + const auto *decType = varDeclaration.getType(); for (const auto &var : varDeclaration.getInitDeclaratorList()) { - m_Environment.get().define(std::get<0>(var), varDeclaration.getType(), m_ErrorHandler); + m_Environment.get().define(std::get<0>(var), decType, m_ErrorHandler); // If variable has an initialiser expression if (std::get<1>(var)) { // Evaluate type const auto initialiserType = evaluateType(std::get<1>(var).get()); - // Assign initialiser expression to variable - m_Environment.get().assign(std::get<0>(var), Token::Type::EQUAL, initialiserType, - m_Context, m_ErrorHandler, true); + assert(false); + // **TODO** check decType = initialiserType is implicit conversion } } } @@ -855,20 +835,7 @@ const Type::Base *EnvironmentBase::assign(const Token &name, Token::Type op, // **THINK** return existingType; } -//--------------------------------------------------------------------------- -const Type::Base *EnvironmentBase::incDec(const Token &name, Token::Type, - const Type::Base *existingType, ErrorHandlerBase &errorHandler) const -{ - // If existing type has a constant qualifier, give errors - if(existingType->hasQualifier(Type::Qualifier::CONSTANT)) { - errorHandler.error(name, "Increment/decrement of read-only variable"); - throw TypeCheckError(); - } - // Otherwise, return type - else { - return existingType; - } -} + //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker From 7ac8cf056a3bae87af92b0ab9b7ebda2997b0564 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 24 Apr 2023 18:05:48 +0100 Subject: [PATCH 148/725] start of drastic type system refactor --- include/genn/genn/transpiler/parser.h | 2 +- include/genn/genn/type.h | 395 +++++------------------- src/genn/genn/transpiler/typeChecker.cc | 3 + src/genn/genn/type.cc | 224 +++----------- 4 files changed, 128 insertions(+), 496 deletions(-) diff --git a/include/genn/genn/transpiler/parser.h b/include/genn/genn/transpiler/parser.h index ae9010d871..cb0ab11a0d 100644 --- a/include/genn/genn/transpiler/parser.h +++ b/include/genn/genn/transpiler/parser.h @@ -29,6 +29,6 @@ Expression::ExpressionPtr parseExpression(const std::vector &tokens, Erro Statement::StatementList parseBlockItemList(const std::vector &tokens, ErrorHandlerBase &errorHandler); //! 
Parse type from tokens -const GeNN::Type::NumericBase *parseNumericType(const std::vector &tokens, ErrorHandlerBase &errorHandler); +const GeNN::Type::Type parseNumericType(const std::vector &tokens, ErrorHandlerBase &errorHandler); } // MiniParse::MiniParse diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 59e5341175..0329d232fd 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -6,68 +6,23 @@ // Standard C++ includes #include +#include #include #include -#include #include -#include #include #include +#include #include // GeNN includes #include "gennExport.h" //---------------------------------------------------------------------------- -// Macros -//---------------------------------------------------------------------------- -#define DECLARE_TYPE(TYPE) \ - private: \ - GENN_EXPORT static TYPE *s_Instance; \ - public: \ - static const TYPE *getInstance() \ - { \ - if(s_Instance == NULL) \ - { \ - s_Instance = new TYPE; \ - } \ - return s_Instance; \ - } - -#define DECLARE_NUMERIC_TYPE(TYPE, UNDERLYING_TYPE, RANK, LITERAL_SUFFIX) \ - class TYPE : public Numeric \ - { \ - DECLARE_TYPE(TYPE) \ - TYPE(Qualifier qualifiers = Qualifier{0}) : Numeric(qualifiers){} \ - virtual std::string getName() const final{ return #UNDERLYING_TYPE; } \ - virtual std::string getResolvedName(const TypeContext&) const final{ return #UNDERLYING_TYPE; } \ - virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new TYPE(qualifiers); } \ - virtual std::string getLiteralSuffix(const TypeContext&) const final{ return LITERAL_SUFFIX; } \ - }; \ - template<> \ - struct TypeTraits \ - { \ - using NumericType = TYPE; \ - } - -#define IMPLEMENT_TYPE(TYPE) TYPE *TYPE::s_Instance = NULL - -//---------------------------------------------------------------------------- -// GeNN::Type::TypeTraits +// GeNN::Type::Qualifier //---------------------------------------------------------------------------- namespace GeNN::Type { -//! Empty type trait structure -template -struct TypeTraits -{ -}; - -typedef std::unordered_map TypeContext; - -//---------------------------------------------------------------------------- -// GeNN::Type::Qualifier -//---------------------------------------------------------------------------- enum class Qualifier : unsigned int { CONSTANT = (1 << 0) @@ -84,325 +39,121 @@ inline Qualifier operator | (Qualifier a, Qualifier b) } //---------------------------------------------------------------------------- -// GeNN::Type::Base +// GeNN::Type::Type //---------------------------------------------------------------------------- -//! Base class for all types -class Base +struct Type { -public: - Base(Qualifier qualifiers = Qualifier{0}) : m_Qualifiers(qualifiers){} - //------------------------------------------------------------------------ - // Declared virtuals + // Numeric //------------------------------------------------------------------------ - //! Get the (unqualified) name of this type - virtual std::string getName() const = 0; + struct Numeric + { + const int rank; + const double min; + const double max; + const double lowest; + const int maxDigits10; - //! Get fully-resolved (unqualified) name of this type - virtual std::string getResolvedName(const TypeContext &context) const = 0; - - //! Get size of this type in bytes - virtual size_t getSizeBytes(const TypeContext &context) const = 0; + const bool isSigned; + const bool isIntegral; + + const std::string literalSuffix; + }; - //! 
Return new version of this type with specified qualifiers - virtual Base *getQualifiedType(Qualifier qualifiers) const = 0; - - //------------------------------------------------------------------------ - // Public API //------------------------------------------------------------------------ - //! Return a pointer to this type, optionally, with specified qualifiers - const class Pointer *getPointerType(Qualifier qualifiers = Qualifier{0}) const; - - //! Does this type have qualifier? - bool hasQualifier(Qualifier qualifier) const{ return (m_Qualifiers & qualifier); }; - -private: + // Pointer //------------------------------------------------------------------------ - // Members - //------------------------------------------------------------------------ - //! Bitfield of qualifiers - const Qualifier m_Qualifiers; -}; - -//---------------------------------------------------------------------------- -// GeNN::Type::Pointer -//---------------------------------------------------------------------------- -//! Type representing a pointer -class Pointer : public Base -{ -public: - Pointer(const Base *valueType, Qualifier qualifiers = Qualifier{0}) - : Base(qualifiers), m_ValueType(valueType) + struct Pointer { - } + Pointer(const Type &valueType) : valueType(std::make_unique(valueType)) + {} + Pointer(const Pointer &other) : valueType(std::make_unique(*other.valueType)) + {} + + const std::unique_ptr valueType; + }; //------------------------------------------------------------------------ - // Base virtuals + // Function //------------------------------------------------------------------------ - virtual std::string getName() const{ return getValueType()->getName() + "*";} - virtual std::string getResolvedName(const TypeContext &context) const{ return getValueType()->getResolvedName(context) + "*"; } - virtual size_t getSizeBytes(const TypeContext&) const final{ return sizeof(char*); } - virtual Base *getQualifiedType(Qualifier qualifiers) const final{ return new Pointer(m_ValueType, qualifiers); } + /*struct Function + { + const std::unique_ptr returnType; + const std::vector argTypes; + };*/ - //------------------------------------------------------------------------ - // Public API - //------------------------------------------------------------------------ - const Base *getValueType() const{ return m_ValueType; } + Type(size_t size, Qualifier qualifiers, const Numeric &numeric) + : size(size), qualifiers(qualifiers), detail(numeric) + {} + Type(size_t size, Qualifier qualifiers, const Pointer &pointer) + : size(size), qualifiers(qualifiers), detail(pointer) + {} + Type(const Type &other) : size(other.size), qualifiers(qualifiers), detail(other.detail) + {} + Type(const Type other, Qualifier qualifiers) : size(other.size), qualifiers(qualifiers), detail(other.detail) + {} -private: //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - const Base *m_ValueType; -}; - -//---------------------------------------------------------------------------- -// GeNN::Type::ValueBase -//---------------------------------------------------------------------------- -class ValueBase : public Base -{ -public: - ValueBase(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} -}; - -//---------------------------------------------------------------------------- -// GeNN::Type::NumericBase -//---------------------------------------------------------------------------- -class NumericBase : public 
ValueBase -{ -public: - NumericBase(Qualifier qualifiers = Qualifier{0}) : ValueBase(qualifiers){} + const size_t size; - //------------------------------------------------------------------------ - // Declared virtuals - //------------------------------------------------------------------------ - virtual int getRank(const TypeContext&) const = 0; - virtual double getMin(const TypeContext&) const = 0; - virtual double getMax(const TypeContext&) const = 0; - virtual double getLowest(const TypeContext&) const = 0; - virtual int getMaxDigits10(const TypeContext&) const = 0; - - virtual bool isSigned(const TypeContext&) const = 0; - virtual bool isIntegral(const TypeContext&) const = 0; - - virtual std::string getLiteralSuffix(const TypeContext&) const = 0; -}; - -//---------------------------------------------------------------------------- -// GeNN::Type::Numeric -//---------------------------------------------------------------------------- -template -class Numeric : public NumericBase -{ -public: - Numeric(Qualifier qualifiers = Qualifier{0}) : NumericBase(qualifiers){} - - //------------------------------------------------------------------------ - // Base virtuals - //------------------------------------------------------------------------ - virtual size_t getSizeBytes(const TypeContext&) const final{ return sizeof(T); } - - //------------------------------------------------------------------------ - // NumericBase virtuals - //------------------------------------------------------------------------ - virtual int getRank(const TypeContext&) const final { return Rank; } - virtual double getMin(const TypeContext&) const final { return std::numeric_limits::min(); } - virtual double getMax(const TypeContext&) const final { return std::numeric_limits::max(); } - virtual double getLowest(const TypeContext&) const final { return std::numeric_limits::lowest(); } - virtual int getMaxDigits10(const TypeContext&) const final{ return std::numeric_limits::max_digits10; } - - virtual bool isSigned(const TypeContext&) const final { return std::is_signed::value; } - virtual bool isIntegral(const TypeContext&) const final { return std::is_integral::value; } -}; + const Qualifier qualifiers; -//---------------------------------------------------------------------------- -// GeNN::Type::NumericTypedef -//---------------------------------------------------------------------------- -class NumericTypedef : public NumericBase -{ -public: - NumericTypedef(const std::string &name, Qualifier qualifiers = Qualifier{0}) - : NumericBase(qualifiers), m_Name(name){} - - //------------------------------------------------------------------------ - // Base virtuals - //------------------------------------------------------------------------ - virtual std::string getName() const final{ return m_Name; } - virtual std::string getResolvedName(const TypeContext &context) const; - virtual size_t getSizeBytes(const TypeContext &context) const final; - virtual Base *getQualifiedType(Qualifier qualifiers) const final; - - //------------------------------------------------------------------------ - // NumericBase virtuals - //------------------------------------------------------------------------ - virtual int getRank(const TypeContext &context) const final; - virtual double getMin(const TypeContext &context) const final; - virtual double getMax(const TypeContext &context) const final; - virtual double getLowest(const TypeContext &context) const final; - virtual int getMaxDigits10(const TypeContext &context) const final; + const 
std::variant detail; - virtual bool isSigned(const TypeContext &context) const final; - virtual bool isIntegral(const TypeContext &context) const final; - - virtual std::string getLiteralSuffix(const TypeContext &context) const final; - //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - const Type::NumericBase *getResolvedType(const TypeContext &context) const; - -private: - //------------------------------------------------------------------------ - // Members - //------------------------------------------------------------------------ - const std::string m_Name; -}; + const Numeric &getNumeric() const{ return std::get(detail); } + const Pointer &getPointer() const{ return std::get(detail); } + const Type addQualifier(Qualifier qualifier) const{ return Type(*this, qualifier); } -//---------------------------------------------------------------------------- -// GeNN::Type::FunctionBase -//---------------------------------------------------------------------------- -class FunctionBase : public Base -{ -public: - FunctionBase(Qualifier qualifiers = Qualifier{0}) : Base(qualifiers){} - - //------------------------------------------------------------------------ - // Declared virtuals - //------------------------------------------------------------------------ - virtual const Base *getReturnType() const = 0; - virtual std::vector getArgumentTypes() const = 0; - - //---------------------------------------------------------------------------- - // Public API - //---------------------------------------------------------------------------- - bool isVariadic() const; -}; - -//---------------------------------------------------------------------------- -// GeNN::Type::Function -//---------------------------------------------------------------------------- -template -class Function : public FunctionBase -{ -public: - Function(Qualifier qualifiers = Qualifier{0}) : FunctionBase(qualifiers){} - - //------------------------------------------------------------------------ - // Base virtuals - //------------------------------------------------------------------------ - virtual std::string getName() const final - { - std::string typeName = getReturnType()->getName() + "("; - updateTypeName(typeName); - typeName += ")"; - return typeName; - } - - virtual std::string getResolvedName(const TypeContext &context) const final - { - std::string typeName = getReturnType()->getResolvedName(context) + "("; - updateResolvedTypeName(context, typeName); - typeName += ")"; - return typeName; - } - - virtual size_t getSizeBytes(const TypeContext&) const final - { - assert(false); - return 0; - } - - virtual Base *getQualifiedType(Qualifier qualifiers) const override - { - return new Function(qualifiers); - } - - //------------------------------------------------------------------------ - // FunctionBase virtuals - //------------------------------------------------------------------------ - virtual const Base *getReturnType() const final + bool operator == (const Type &other) const { - return ReturnType::getInstance(); + return (size == other.size && qualifiers == other.qualifiers && detail == other.detail); } - virtual std::vector getArgumentTypes() const final - { - std::vector args; - args.reserve(sizeof...(ArgTypes)); - updateArgumentTypes(args); - return args; - } - -private: //------------------------------------------------------------------------ - // Private methods + // Static API 
    //------------------------------------------------------------------------
-    template<typename T, typename... Args>
-    static void updateTypeName(std::string &typeName)
-    {
-        // Add argument typename to string
-        typeName += T::getInstance()->getName();
-
-        // If there are more arguments left in pack, add comma and recurse
-        if constexpr (sizeof...(Args) > 0) {
-            typeName += ", ";
-            updateTypeName<Args...>(typeName);
-        }
-    }
-
-    template<typename T, typename... Args>
-    static void updateResolvedTypeName(const TypeContext &context, std::string &typeName)
+    template<typename T>
+    static Type createNumeric(int rank, const std::string &literalSuffix = "", Qualifier qualifiers = Qualifier{0})
     {
-        // Add argument typename to string
-        typeName += T::getInstance()->getResolvedName(context);
-
-        // If there are more arguments left in pack, add comma and recurse
-        if constexpr (sizeof...(Args) > 0) {
-            typeName += ", ";
-            updateResolvedTypeName<Args...>(context, typeName);
-        }
-    }
-
-    template<typename T, typename... Args>
-    static void updateArgumentTypes(std::vector<const Base*> &args)
-    {
-        // Add argument typename to string
-        args.push_back(T::getInstance());
-
-        // If there are more arguments left in pack, recurse
-        if constexpr (sizeof...(Args) > 0) {
-            updateArgumentTypes<Args...>(args);
-        }
+        return Type(sizeof(T), qualifiers, Numeric{rank, std::numeric_limits<T>::min(), std::numeric_limits<T>::max(),
+                                                   std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max_digits10,
+                                                   std::is_signed<T>::value, std::is_integral<T>::value, literalSuffix});
     }
 };
 
+typedef std::unordered_map<std::string, Type> TypeContext;
+
 //----------------------------------------------------------------------------
 // Declare numeric types
 //----------------------------------------------------------------------------
-DECLARE_NUMERIC_TYPE(Bool, bool, 0, "");
-DECLARE_NUMERIC_TYPE(Int8, int8_t, 10, "");
-DECLARE_NUMERIC_TYPE(Int16, int16_t, 20, "");
-DECLARE_NUMERIC_TYPE(Int32, int32_t, 30, "");
+inline static const Type Bool = Type::createNumeric<bool>(0);
+inline static const Type Int8 = Type::createNumeric<int8_t>(10);
+inline static const Type Int16 = Type::createNumeric<int16_t>(20);
+inline static const Type Int32 = Type::createNumeric<int32_t>(30);
 //DECLARE_NUMERIC_TYPE(Int64, int64_t, 40);
-DECLARE_NUMERIC_TYPE(Uint8, uint8_t, 10, "u");
-DECLARE_NUMERIC_TYPE(Uint16, uint16_t, 20, "u");
-DECLARE_NUMERIC_TYPE(Uint32, uint32_t, 30, "u");
+inline static const Type Uint8 = Type::createNumeric<uint8_t>(10, "u");
+inline static const Type Uint16 = Type::createNumeric<uint16_t>(20, "u");
+inline static const Type Uint32 = Type::createNumeric<uint32_t>(30, "u");
 //DECLARE_NUMERIC_TYPE(Uint64, uint64_t, 40);
-DECLARE_NUMERIC_TYPE(Float, float, 50, "f");
-DECLARE_NUMERIC_TYPE(Double, double, 60, "");
+inline static const Type Float = Type::createNumeric<float>(50, "f");
+inline static const Type Double = Type::createNumeric<double>(60);
 
 //! Parse a numeric type
-const NumericBase *parseNumeric(const std::string &typeString);
+Type parseNumeric(const std::string &typeString);
 
 //! Look up numeric type based on set of type specifiers
-const NumericBase *getNumericType(const std::set<std::string> &typeSpecifiers);
+Type getNumericType(const std::set<std::string> &typeSpecifiers);
 
 //! Apply C type promotion rules to numeric type
-const NumericBase *getPromotedType(const NumericBase *type, const TypeContext &context);
+Type getPromotedType(const Type &type);
 
 //! Apply C rules to get common type between numeric types a and b
-const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b, const TypeContext &context);
+Type getCommonType(const Type &a, const Type &b);
+
 
-// **YUCK** unimplemented stream operator so we get linker errors if you try and write types directly to an IO stream
-std::ostream& operator<<(std::ostream &stream, const Base* value);
 }   // namespace GeNN::Type
diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index 6c5c4e3ed3..56aeeb8a4b 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -386,6 +386,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
             m_ErrorHandler.error(postfixIncDec.getOperator(), "Increment/decrement of read-only variable");
             throw TypeCheckError();
         }
+
+        // **TODO**
+
         setExpressionType(&postfixIncDec,
                           m_Environment.get().incDec(postfixIncDec.getVarName(), postfixIncDec.getOperator().type,
                                                      m_Context, m_ErrorHandler));
diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc
index 406047e3e7..53988ab7fc 100644
--- a/src/genn/genn/type.cc
+++ b/src/genn/genn/type.cc
@@ -17,44 +17,42 @@ using namespace GeNN;
 // Anonymous namespace
 namespace
 {
-const std::map<std::set<std::string>, const Type::NumericBase*> numericTypeSpecifiers{
-    {{"char"}, Type::Int8::getInstance()},
-    {{"int8_t"}, Type::Int8::getInstance()},
+const std::map<std::set<std::string>, Type::Type> numericTypeSpecifiers{
+    {{"char"}, Type::Int8},
+    {{"int8_t"}, Type::Int8},
 
-    {{"unsigned", "char"}, Type::Uint8::getInstance()},
-    {{"uint8_t"}, Type::Uint8::getInstance()},
-
-    {{"short"}, Type::Int16::getInstance()},
-    {{"short", "int"}, Type::Int16::getInstance()},
-    {{"signed", "short"}, Type::Int16::getInstance()},
-    {{"signed", "short", "int"}, Type::Int16::getInstance()},
-    {{"int16_t"}, Type::Int16::getInstance()},
+    {{"unsigned", "char"}, Type::Uint8},
+    {{"uint8_t"}, Type::Uint8},
+
+    {{"short"}, Type::Int16},
+    {{"short", "int"}, Type::Int16},
+    {{"signed", "short"}, Type::Int16},
+    {{"signed", "short", "int"}, Type::Int16},
+    {{"int16_t"}, Type::Int16},
 
-    {{"unsigned", "short"}, Type::Uint16::getInstance()},
-    {{"unsigned", "short", "int"}, Type::Uint16::getInstance()},
-    {{"uint16_t"}, Type::Uint8::getInstance()},
+    {{"unsigned", "short"}, Type::Uint16},
+    {{"unsigned", "short", "int"}, Type::Uint16},
+    {{"uint16_t"}, Type::Uint16},
 
-    {{"int"}, Type::Int32::getInstance()},
-    {{"signed"}, Type::Int32::getInstance()},
-    {{"signed", "int"}, Type::Int32::getInstance()},
-    {{"int32_t"}, Type::Int32::getInstance()},
+    {{"int"}, Type::Int32},
+    {{"signed"}, Type::Int32},
+    {{"signed", "int"}, Type::Int32},
+    {{"int32_t"}, Type::Int32},
 
-    {{"unsigned"}, Type::Uint32::getInstance()},
-    {{"unsigned", "int"}, Type::Uint32::getInstance()},
-    {{"uint32_t"}, Type::Uint32::getInstance()},
+    {{"unsigned"}, Type::Uint32},
+    {{"unsigned", "int"}, Type::Uint32},
+    {{"uint32_t"}, Type::Uint32},
 
-    {{"float"}, Type::Float::getInstance()},
-    {{"double"}, Type::Double::getInstance()},
-};
+    {{"float"}, Type::Float},
+    {{"double"}, Type::Double}};
 //----------------------------------------------------------------------------
 const std::set<std::string> scalarTypeSpecifier{{"scalar"}};
 //----------------------------------------------------------------------------
 // Mapping of signed integer numericTypeSpecifiers to their unsigned equivalents
-const std::unordered_map<const Type::NumericBase*, const Type::NumericBase*> unsignedType{
-    {Type::Int8::getInstance(), Type::Uint8::getInstance()},
-    {Type::Int16::getInstance(),
Type::Uint16::getInstance()}, - {Type::Int32::getInstance(), Type::Uint32::getInstance()} -}; +const std::unordered_map unsignedType{ + {Type::Int8, Type::Uint8}, + {Type::Int16, Type::Uint16}, + {Type::Int32, Type::Uint32}}; } // Anonymous namespace //---------------------------------------------------------------------------- @@ -62,125 +60,10 @@ const std::unordered_map uns //---------------------------------------------------------------------------- namespace GeNN::Type { -// Implement numeric types -IMPLEMENT_TYPE(Bool); -IMPLEMENT_TYPE(Int8); -IMPLEMENT_TYPE(Int16); -IMPLEMENT_TYPE(Int32); -IMPLEMENT_TYPE(Uint8); -IMPLEMENT_TYPE(Uint16); -IMPLEMENT_TYPE(Uint32); -IMPLEMENT_TYPE(Float); -IMPLEMENT_TYPE(Double); - -//---------------------------------------------------------------------------- -// GeNN::Type::Base -//---------------------------------------------------------------------------- -const Pointer *Base::getPointerType(Qualifier qualifiers) const -{ - // **TODO** befriend constructor - // **TODO** don't just leak these! - return new Pointer(this, qualifiers); -} - -//---------------------------------------------------------------------------- -// GeNN::Type::NumericTypedef -//---------------------------------------------------------------------------- -std::string NumericTypedef::getResolvedName(const TypeContext &context) const -{ - return getResolvedType(context)->getResolvedName(context); -} -//---------------------------------------------------------------------------- -size_t NumericTypedef::getSizeBytes(const TypeContext &context) const -{ - return getResolvedType(context)->getSizeBytes(context); -} -//---------------------------------------------------------------------------- -Base *NumericTypedef::getQualifiedType(Qualifier qualifiers) const -{ - return new NumericTypedef(m_Name, qualifiers); -} -//---------------------------------------------------------------------------- -int NumericTypedef::getRank(const TypeContext &context) const -{ - return getResolvedType(context)->getRank(context); -} -//---------------------------------------------------------------------------- -double NumericTypedef::getMin(const TypeContext &context) const -{ - return getResolvedType(context)->getMin(context); -} -//---------------------------------------------------------------------------- -double NumericTypedef::getMax(const TypeContext &context) const -{ - return getResolvedType(context)->getMax(context); -} -//---------------------------------------------------------------------------- -double NumericTypedef::getLowest(const TypeContext &context) const -{ - return getResolvedType(context)->getLowest(context); -} -//---------------------------------------------------------------------------- -int NumericTypedef::getMaxDigits10(const TypeContext &context) const -{ - return getResolvedType(context)->getMaxDigits10(context); -} -//---------------------------------------------------------------------------- -bool NumericTypedef::isSigned(const TypeContext &context) const -{ - return getResolvedType(context)->getSizeBytes(context); -} -//---------------------------------------------------------------------------- -bool NumericTypedef::isIntegral(const TypeContext &context) const -{ - return getResolvedType(context)->isIntegral(context); -} -//---------------------------------------------------------------------------- -std::string NumericTypedef::getLiteralSuffix(const TypeContext &context) const -{ - return getResolvedType(context)->getLiteralSuffix(context); -} 
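
Each of the NumericTypedef members deleted above was a one-line forward through getResolvedType(), re-resolving the typedef against the TypeContext on every query. With a value-based Type the resolution happens once and every later query is a plain field read. A minimal self-contained sketch of that pattern, using illustrative stand-in names rather than GeNN's real API:

#include <cassert>
#include <string>
#include <unordered_map>

// Illustrative value type standing in for the new GeNN::Type::Type
struct TypeSketch
{
    std::string name;
    int rank;
};

// A typedef such as "scalar" is resolved by one map lookup rather
// than via virtual forwarding on every rank/min/max/suffix query
using ContextSketch = std::unordered_map<std::string, TypeSketch>;

TypeSketch resolve(const std::string &name, const ContextSketch &context)
{
    return context.at(name);
}

int main()
{
    const ContextSketch context{{"scalar", {"float", 50}}};

    // After one resolution, every property is a direct field read
    const TypeSketch scalar = resolve("scalar", context);
    assert(scalar.rank == 50);
    return 0;
}

The lookup in the sketch throws std::out_of_range on an unknown name, standing in for the std::runtime_error paths in the code removed above.
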
-//---------------------------------------------------------------------------- -const Type::NumericBase *NumericTypedef::getResolvedType(const TypeContext &context) const -{ - const auto t = context.find(m_Name); - if (t == context.cend()) { - throw std::runtime_error("No context for typedef '" + m_Name + "'"); - } - else { - const NumericBase *numericType = dynamic_cast(t->second); - if (numericType) { - return numericType; - } - else { - throw std::runtime_error("Numeric typedef '" + m_Name + "' resolved to non-numeric type '" + t->second->getName() + "'"); - } - } -} - -//---------------------------------------------------------------------------- -// GeNN::Type::FunctionBase -//---------------------------------------------------------------------------- -bool FunctionBase::isVariadic() const -{ - // If variadic marker (nullptr) isn't found, function isn't variadic - const auto argTypes = getArgumentTypes(); - const auto variadicMarker = std::find(argTypes.cbegin(), argTypes.cend(), nullptr); - if(variadicMarker == argTypes.cend()) { - return false; - } - // Otherwise, after checking variadic marker is last argument, return true - else { - assert(argTypes.back() == nullptr); - return true; - } - -} - //---------------------------------------------------------------------------- // Free functions //---------------------------------------------------------------------------- -const NumericBase *parseNumeric(const std::string &typeString) +Type parseNumeric(const std::string &typeString) { using namespace Transpiler; @@ -189,87 +72,82 @@ const NumericBase *parseNumeric(const std::string &typeString) const auto tokens = Scanner::scanSource(typeString, errorHandler); // Parse type numeric type - const auto *type = Parser::parseNumericType(tokens, errorHandler); + const auto type = Parser::parseNumericType(tokens, errorHandler); // If an error was encountered while scanning or parsing, throw exception if (errorHandler.hasError()) { throw std::runtime_error("Error parsing type '" + std::string{typeString} + "'"); } - // If tokens did not contain a valid numeric type, throw exception - if (!type) { - throw std::runtime_error("Unable to parse type '" + std::string{typeString} + "'"); - } return type; } //---------------------------------------------------------------------------- -const NumericBase *getNumericType(const std::set &typeSpecifiers) +Type getNumericType(const std::set &typeSpecifiers) { // If type matches scalar type specifiers if(typeSpecifiers == scalarTypeSpecifier) { - return new NumericTypedef("scalar"); + //return new NumericTypedef("scalar"); } // Otherwise else { const auto type = numericTypeSpecifiers.find(typeSpecifiers); - return (type == numericTypeSpecifiers.cend()) ? nullptr : type->second; + //return (type == numericTypeSpecifiers.cend()) ? nullptr : type->second; + return type->second; } } //---------------------------------------------------------------------------- -const NumericBase *getPromotedType(const NumericBase *type, const TypeContext &context) +Type getPromotedType(const Type &type) { // If a small integer type is used in an expression, it is implicitly converted to int which is always signed. 
// This is known as the integer promotions or the integer promotion rule // **NOTE** this is true because in our type system unsigned short is uint16 which can be represented in int32 - if(type->getRank(context) < Int32::getInstance()->getRank(context)) { - return Int32::getInstance(); + if(type.getNumeric().rank < Int32.getNumeric().rank) { + return Int32; } else { return type; } } //---------------------------------------------------------------------------- -const NumericBase *getCommonType(const NumericBase *a, const NumericBase *b, const TypeContext &context) +Type getCommonType(const Type &a, const Type &b) { // If either type is double, common type is double - const auto &aTypeName = a->getResolvedName(context); - const auto &bTypeName = b->getResolvedName(context); - if(aTypeName == Double::getInstance()->getName() || bTypeName == Double::getInstance()->getName()) { - return Double::getInstance(); + if(a == Double || b == Double) { + return Double; } // Otherwise, if either type is float, common type is float - if(aTypeName == Float::getInstance()->getName() || bTypeName == Float::getInstance()->getName()) { - return Float::getInstance(); + if(a == Float || b == Float) { + return Float; } // Otherwise, must be an integer type else { // Promote both numeric types - const auto *aPromoted = getPromotedType(a, context); - const auto *bPromoted = getPromotedType(b, context); + const Type aPromoted = getPromotedType(a); + const Type bPromoted = getPromotedType(b); // If both promoted operands have the same type, then no further conversion is needed. - if(aPromoted->getResolvedName(context) == bPromoted->getResolvedName(context)) { + if(aPromoted == bPromoted) { return aPromoted; } // Otherwise, if both promoted operands have signed integer numeric types or both have unsigned integer numeric types, // the operand with the type of lesser integer conversion rank is converted to the type of the operand with greater rank. - else if(aPromoted->isSigned(context) == bPromoted->isSigned(context)) { - return (aPromoted->getRank(context) > bPromoted->getRank(context)) ? aPromoted : bPromoted; + else if(aPromoted.getNumeric().isSigned == bPromoted.getNumeric().isSigned) { + return (aPromoted.getNumeric().rank > bPromoted.getNumeric().rank) ? aPromoted : bPromoted; } // Otherwise, if signedness of promoted operands differ else { - const auto *signedOp = aPromoted->isSigned(context) ? aPromoted : bPromoted; - const auto *unsignedOp = aPromoted->isSigned(context) ? bPromoted : aPromoted; + const Type signedOp = aPromoted.getNumeric().isSigned ? aPromoted : bPromoted; + const Type unsignedOp = aPromoted.getNumeric().isSigned ? bPromoted : aPromoted; // Otherwise, if the operand that has unsigned integer type has rank greater or equal to the rank of the type of the other operand, // then the operand with signed integer type is converted to the type of the operand with unsigned integer type. - if(unsignedOp->getRank(context) >= signedOp->getRank(context)) { + if(unsignedOp.getNumeric().rank >= signedOp.getNumeric().rank) { return unsignedOp; } // Otherwise, if the type of the operand with signed integer type can represent all of the values of the type of the operand with unsigned integer type, // then the operand with unsigned integer type is converted to the type of the operand with signed integer type. 
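
Since this function transcribes C's usual arithmetic conversions, its intended results can be pinned down with compile-time checks against the native types. The checks below are plain ISO C++, independent of GeNN; the last one assumes the common platforms where int is 32 bits and long long is 64:

#include <type_traits>

// int + unsigned int: equal rank, so the unsigned type wins
static_assert(std::is_same<decltype(1 + 1u), unsigned int>::value, "equal rank -> unsigned");

// unsigned char + unsigned char: both operands promote to int first
static_assert(std::is_same<decltype(static_cast<unsigned char>(1) + static_cast<unsigned char>(1)), int>::value,
              "small integer types promote to int");

// long long + unsigned int: the signed type has greater rank and can
// represent every unsigned int value, so the signed type wins
static_assert(std::is_same<decltype(1ll + 1u), long long>::value, "wider signed type wins");

int main() { return 0; }
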
- else if((signedOp->getMin(context) <= unsignedOp->getMin(context)) - && (signedOp->getMax(context) >= unsignedOp->getMax(context))) + else if((signedOp.getNumeric().min <= unsignedOp.getNumeric().min) + && (signedOp.getNumeric().max >= unsignedOp.getNumeric().max)) { return signedOp; } From bb8fa26bea16a55a8ee93f93527c50c4cd4c1574 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 26 Apr 2023 09:20:03 +0100 Subject: [PATCH 149/725] WIP update of type checker --- include/genn/genn/transpiler/expression.h | 6 +- include/genn/genn/transpiler/statement.h | 6 +- include/genn/genn/transpiler/typeChecker.h | 8 +- include/genn/genn/type.h | 111 +++- src/genn/genn/transpiler/typeChecker.cc | 566 +++++++++++---------- src/genn/genn/type.cc | 6 +- 6 files changed, 413 insertions(+), 290 deletions(-) diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h index 4a20e02edb..4d2499b78b 100644 --- a/include/genn/genn/transpiler/expression.h +++ b/include/genn/genn/transpiler/expression.h @@ -168,16 +168,16 @@ class Call : public Acceptable class Cast : public Acceptable { public: - Cast(const Type::Base *type, ExpressionPtr expression, Token closingParen) + Cast(const Type::Type &type, ExpressionPtr expression, Token closingParen) : m_Type(type), m_Expression(std::move(expression)), m_ClosingParen(closingParen) {} - const Type::Base *getType() const{ return m_Type; } + const Type::Type &getType() const{ return m_Type; } const Base *getExpression() const { return m_Expression.get(); } const Token &getClosingParen() const { return m_ClosingParen; } private: - const Type::Base *m_Type; + const Type::Type m_Type; const ExpressionPtr m_Expression; const Token m_ClosingParen; }; diff --git a/include/genn/genn/transpiler/statement.h b/include/genn/genn/transpiler/statement.h index 1dc454edd4..c9177d3112 100644 --- a/include/genn/genn/transpiler/statement.h +++ b/include/genn/genn/transpiler/statement.h @@ -255,15 +255,15 @@ class VarDeclaration : public Acceptable public: typedef std::vector> InitDeclaratorList; - VarDeclaration(const Type::Base *type, InitDeclaratorList initDeclaratorList) + VarDeclaration(const Type::Type &type, InitDeclaratorList initDeclaratorList) : m_Type(type), m_InitDeclaratorList(std::move(initDeclaratorList)) {} - const Type::Base *getType() const{ return m_Type; } + const Type::Type &getType() const{ return m_Type; } const InitDeclaratorList &getInitDeclaratorList() const { return m_InitDeclaratorList; } private: - const Type::Base *m_Type; + const Type::Type m_Type; const std::vector m_DeclarationSpecifiers; const InitDeclaratorList m_InitDeclaratorList; }; diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index 072c540f6c..afa454dd1e 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -32,7 +32,7 @@ class TypeCheckError : public std::runtime_error } }; -typedef std::unordered_map ResolvedTypeMap; +typedef std::unordered_map ResolvedTypeMap; //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker::EnvironmentBase @@ -43,13 +43,13 @@ class EnvironmentBase //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ - virtual void define(const Token &name, const Type::Base *type, ErrorHandlerBase &errorHandler) = 0; - virtual std::vector getTypes(const Token &name, 
ErrorHandlerBase &errorHandler) = 0; + virtual void define(const Token &name, const Type::Type &type, ErrorHandlerBase &errorHandler) = 0; + virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) = 0; //--------------------------------------------------------------------------- // Public API //--------------------------------------------------------------------------- - const Type::Base *getType(const Token &name, ErrorHandlerBase &errorHandler); + Type::Type getType(const Token &name, ErrorHandlerBase &errorHandler); }; //--------------------------------------------------------------------------- diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 0329d232fd..082d4a95fb 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -18,6 +18,11 @@ // GeNN includes #include "gennExport.h" +//---------------------------------------------------------------------------- +// Macros +//---------------------------------------------------------------------------- +#define CREATE_NUMERIC(TYPE, RANK, L_SUFFIX) Type::createNumeric(#TYPE, RANK, L_SUFFIX) + //---------------------------------------------------------------------------- // GeNN::Type::Qualifier //---------------------------------------------------------------------------- @@ -48,6 +53,8 @@ struct Type //------------------------------------------------------------------------ struct Numeric { + const std::string name; + const int rank; const double min; const double max; @@ -58,6 +65,21 @@ struct Type const bool isIntegral; const std::string literalSuffix; + + //------------------------------------------------------------------------ + // Operators + //------------------------------------------------------------------------ + bool operator == (const Numeric &other) const + { + return (std::make_tuple(rank, min, max, lowest, maxDigits10, isSigned, isIntegral) + == std::make_tuple(other.rank, other.min, other.max, other.lowest, other.maxDigits10, other.isSigned, other.isIntegral)); + } + + bool operator < (const Numeric &other) const + { + return (std::make_tuple(rank, min, max, lowest, maxDigits10, isSigned, isIntegral) + < std::make_tuple(other.rank, other.min, other.max, other.lowest, other.maxDigits10, other.isSigned, other.isIntegral)); + } }; //------------------------------------------------------------------------ @@ -71,24 +93,55 @@ struct Type {} const std::unique_ptr valueType; + + bool operator == (const Pointer &other) const + { + return (*valueType == *other.valueType); + } + + bool operator < (const Pointer &other) const + { + return (*valueType < *other.valueType); + } }; //------------------------------------------------------------------------ // Function //------------------------------------------------------------------------ - /*struct Function + struct Function { + Function(const Type &returnType, const std::vector &argTypes) + : returnType(std::make_unique(returnType)), argTypes(argTypes) + {} + Function(const Function &other) + : returnType(std::make_unique(*other.returnType)), argTypes(other.argTypes) + {} + const std::unique_ptr returnType; const std::vector argTypes; - };*/ + + bool operator == (const Function &other) const + { + return (*returnType == *other.returnType && argTypes == other.argTypes); + } + + bool operator < (const Function &other) const + { + return (*returnType < *other.returnType); + } + }; Type(size_t size, Qualifier qualifiers, const Numeric &numeric) : size(size), qualifiers(qualifiers), detail(numeric) {} - Type(size_t size, Qualifier 
qualifiers, const Pointer &pointer) - : size(size), qualifiers(qualifiers), detail(pointer) + Type(Qualifier qualifiers, const Pointer &pointer) + : size(sizeof(char*)), qualifiers(qualifiers), detail(pointer) + {} + Type(const Function &function) + : size(0), qualifiers(Qualifier{0}), detail(function) {} - Type(const Type &other) : size(other.size), qualifiers(qualifiers), detail(other.detail) + + Type(const Type &other) : size(other.size), qualifiers(other.qualifiers), detail(other.detail) {} Type(const Type other, Qualifier qualifiers) : size(other.size), qualifiers(qualifiers), detail(other.detail) {} @@ -97,33 +150,53 @@ struct Type // Members //------------------------------------------------------------------------ const size_t size; - const Qualifier qualifiers; - const std::variant detail; + const std::variant detail; //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ + bool isNumeric() const{ return std::holds_alternative(detail); } + bool isPointer() const{ return std::holds_alternative(detail); } + bool isFunction() const{ return std::holds_alternative(detail); } const Numeric &getNumeric() const{ return std::get(detail); } const Pointer &getPointer() const{ return std::get(detail); } + const Function &getFunction() const{ return std::get(detail); } + const Type addQualifier(Qualifier qualifier) const{ return Type(*this, qualifier); } + bool hasQualifier(Qualifier qualifier) const{ return (qualifiers & qualifier); } + //------------------------------------------------------------------------ + // Operators + //------------------------------------------------------------------------ bool operator == (const Type &other) const { - return (size == other.size && qualifiers == other.qualifiers && detail == other.detail); + return (std::make_tuple(size, qualifiers, detail) + == std::make_tuple(other.size, other.qualifiers, other.detail)); + } + + bool operator < (const Type &other) const + { + return (std::make_tuple(size, qualifiers, detail) + < std::make_tuple(other.size, other.qualifiers, other.detail)); } //------------------------------------------------------------------------ // Static API //------------------------------------------------------------------------ template - static Type createNumeric(int rank, const std::string &literalSuffix = "", Qualifier qualifiers = Qualifier{0}) + static Type createNumeric(const std::string &name, int rank, const std::string &literalSuffix = "", Qualifier qualifiers = Qualifier{0}) { - return Type(sizeof(T), qualifiers, Numeric{rank, std::numeric_limits::min(), std::numeric_limits::max(), + return Type(sizeof(T), qualifiers, Numeric{name, rank, std::numeric_limits::min(), std::numeric_limits::max(), std::numeric_limits::lowest(), std::numeric_limits::max_digits10, std::is_signed::value, std::is_integral::value, literalSuffix}); } + + static Type createPointer(const Type &valueType, Qualifier qualifiers = Qualifier{0}) + { + return Type(qualifiers, Pointer{valueType}); + } }; typedef std::unordered_map TypeContext; @@ -131,17 +204,17 @@ typedef std::unordered_map TypeContext; //---------------------------------------------------------------------------- // Declare numeric types //---------------------------------------------------------------------------- -inline static const Type Bool = Type::createNumeric(0); -inline static const Type Int8 = Type::createNumeric(10); -inline static const Type Int16 = Type::createNumeric(20); 
-inline static const Type Int32 = Type::createNumeric(30); +inline static const Type Bool = CREATE_NUMERIC(bool, 0, ""); +inline static const Type Int8 = CREATE_NUMERIC(int8_t, 10, ""); +inline static const Type Int16 = CREATE_NUMERIC(int16_t, 20, ""); +inline static const Type Int32 = CREATE_NUMERIC(int32_t, 30, ""); //DECLARE_NUMERIC_TYPE(Int64, int64_t, 40); -inline static const Type Uint8 = Type::createNumeric(10, "u"); -inline static const Type Uint16 = Type::createNumeric(20, "u"); -inline static const Type Uint32 = Type::createNumeric(30, "u"); +inline static const Type Uint8 = CREATE_NUMERIC(uint8_t, 10, "u"); +inline static const Type Uint16 = CREATE_NUMERIC(uint16_t, 20, "u"); +inline static const Type Uint32 = CREATE_NUMERIC(uint32_t, 30, "u"); //DECLARE_NUMERIC_TYPE(Uint64, uint64_t, 40); -inline static const Type Float = Type::createNumeric(50, "f"); -inline static const Type Double = Type::createNumeric(60); +inline static const Type Float = CREATE_NUMERIC(float, 50, "f"); +inline static const Type Double = CREATE_NUMERIC(double, 60, ""); //! Parse a numeric type Type parseNumeric(const std::string &typeString); diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 56aeeb8a4b..1579c30b51 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -3,6 +3,7 @@ // Standard C++ includes #include #include +#include #include #include @@ -26,46 +27,69 @@ namespace Type = GeNN::Type; //--------------------------------------------------------------------------- namespace { -bool checkPointerTypeAssignement(const Type::Base *rightType, const Type::Base *leftType, const Type::TypeContext &typeContext) +std::string getDescription(const Type::Type &type) { - // If both are pointers, recurse through value type - auto rightPointerType = dynamic_cast(rightType); - auto leftPointerType = dynamic_cast(leftType); - if (rightPointerType && leftPointerType) { - return checkPointerTypeAssignement(rightPointerType->getValueType(), leftPointerType->getValueType(), typeContext); - } - // Otherwise, if we've hit the value type at the end of the chain, check resolved names match - else if (!rightPointerType && !leftPointerType) { - return (rightType->getResolvedName(typeContext) == leftType->getResolvedName(typeContext)); - } - // Otherwise, pointers with different levels of indirection e.g. int* and int** are being compared - else { - return false; - } + const std::string qualifier = type.hasQualifier(Type::Qualifier::CONSTANT) ? 
"const " : ""; + return std::visit( + Utils::Overload{ + [&qualifier](const Type::Type::Numeric &numeric) + { + return qualifier + numeric.name; + }, + [&qualifier, &type](const Type::Type::Pointer &pointer) + { + return qualifier + getDescription(*pointer.valueType) + "*"; + }, + [&type](const Type::Type::Function &function) + { + std::string description = getDescription(*function.returnType) + "("; + for (const auto &a : function.argTypes) { + description += (getDescription(a) + ","); + } + return description + ")"; + }}, + type.detail); +} +//--------------------------------------------------------------------------- +bool checkPointerTypeAssignement(const Type::Type &rightType, const Type::Type &leftType) +{ + return std::visit( + Utils::Overload{ + [&rightType, &leftType](const Type::Type::Numeric &rightNumeric, const Type::Type::Numeric &leftNumeric) + { + return (rightType == leftType); + }, + [](const Type::Type::Pointer &rightPointer, const Type::Type::Pointer &leftPointer) + { + return checkPointerTypeAssignement(*rightPointer.valueType, *leftPointer.valueType); + }, + // Otherwise, pointers with different levels of indirection e.g. int* and int** are being compared + [](auto, auto) { return false; }}, + rightType.detail, leftType.detail); } //--------------------------------------------------------------------------- -bool checkForConstRemoval(const Type::Base *rightType, const Type::Base *leftType) +bool checkForConstRemoval(const Type::Type &rightType, const Type::Type &leftType) { // If const is being removed - if (rightType->hasQualifier(Type::Qualifier::CONSTANT) && !leftType->hasQualifier(Type::Qualifier::CONSTANT)) { - return false; - } - - // If both are pointers, recurse through value type - auto rightPointerType = dynamic_cast(rightType); - auto leftPointerType = dynamic_cast(leftType); - if (rightPointerType && leftPointerType) { - return checkForConstRemoval(rightPointerType->getValueType(), leftPointerType->getValueType()); - } - // Otherwise, if both are non-pointers, return true as const removal has been succesfully checked - else if (!rightPointerType && !leftPointerType) { - return true; - } - // Otherwise, pointers with different levels of indirection e.g. int* and int** are being compared - else { + if (rightType.hasQualifier(Type::Qualifier::CONSTANT) && !leftType.hasQualifier(Type::Qualifier::CONSTANT)) { return false; } + return std::visit( + Utils::Overload{ + // If both are non-pointers, return true as const removal has been succesfully checked + [](const Type::Type::Numeric &rightNumeric, const Type::Type::Numeric &leftNumeric) + { + return true; + }, + // Otherwise, if both are pointers, recurse through value type + [](const Type::Type::Pointer &rightPointer, const Type::Type::Pointer &leftPointer) + { + return checkForConstRemoval(*rightPointer.valueType, *leftPointer.valueType); + }, + // Otherwise, pointers with different levels of indirection e.g. 
int* and int** are being compared + [](auto, auto) { return false; }}, + rightType.detail, leftType.detail); } //--------------------------------------------------------------------------- @@ -82,7 +106,7 @@ class EnvironmentInternal : public EnvironmentBase //--------------------------------------------------------------------------- // EnvironmentBase virtuals //--------------------------------------------------------------------------- - virtual void define(const Token &name, const Type::Base *type, ErrorHandlerBase &errorHandler) final + virtual void define(const Token &name, const Type::Type &type, ErrorHandlerBase &errorHandler) final { if(!m_Types.try_emplace(name.lexeme, type).second) { errorHandler.error(name, "Redeclaration of variable"); @@ -90,7 +114,7 @@ class EnvironmentInternal : public EnvironmentBase } } - virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final + virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final { auto type = m_Types.find(name.lexeme); if(type == m_Types.end()) { @@ -106,7 +130,7 @@ class EnvironmentInternal : public EnvironmentBase // Members //--------------------------------------------------------------------------- EnvironmentBase &m_Enclosing; - std::unordered_map m_Types; + std::unordered_map m_Types; }; //--------------------------------------------------------------------------- @@ -144,23 +168,21 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- virtual void visit(const Expression::ArraySubscript &arraySubscript) final { - // Get pointer type + // Evaluate array type auto arrayType = evaluateType(arraySubscript.getArray()); - auto pointerType = dynamic_cast(arrayType); // If pointer is indeed a pointer - if (pointerType) { + if(arrayType.isPointer()) { // Evaluate pointer type auto indexType = evaluateType(arraySubscript.getIndex()); - auto indexNumericType = dynamic_cast(indexType); - if (!indexNumericType || !indexNumericType->isIntegral(m_Context)) { + if (!indexType.isNumeric() || !indexType.getNumeric().isIntegral) { m_ErrorHandler.error(arraySubscript.getClosingSquareBracket(), - "Invalid subscript index type '" + indexType->getName() + "'"); + "Invalid subscript index type '" + getDescription(indexType) + "'"); throw TypeCheckError(); } // Use value type of array - setExpressionType(&arraySubscript, pointerType->getValueType()); + setExpressionType(&arraySubscript, *arrayType.getPointer().valueType); } // Otherwise else { @@ -189,79 +211,96 @@ class Visitor : public Expression::Visitor, public Statement::Visitor else { // If we're subtracting two pointers const auto leftType = evaluateType(binary.getLeft()); - auto leftNumericType = dynamic_cast(leftType); - auto rightNumericType = dynamic_cast(rightType); - auto leftPointerType = dynamic_cast(leftType); - auto rightPointerType = dynamic_cast(rightType); - if (leftPointerType && rightPointerType && opType == Token::Type::MINUS) { - // Check pointers are compatible - if (leftPointerType->getResolvedName(m_Context) != rightPointerType->getResolvedName(m_Context)) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); - throw TypeCheckError(); - } - // **TODO** should be std::ptrdiff/Int64 - setExpressionType(&binary); - } - // Otherwise, if we're adding to or subtracting from pointers - else if (leftPointerType && rightNumericType && (opType == 
Token::Type::PLUS || opType == Token::Type::MINUS)) { // P + n or P - n - // Check that numeric operand is integer - if (!rightNumericType->isIntegral(m_Context)) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); - throw TypeCheckError(); - } - - // Use left type - setExpressionType(&binary, leftType); - } - // Otherwise, if we're adding a number to a pointer - else if (leftNumericType && rightPointerType && opType == Token::Type::PLUS) { // n + P - // Check that numeric operand is integer - if (!leftNumericType->isIntegral(m_Context)) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); - throw TypeCheckError(); - } - - // Use right type - setExpressionType(&binary, rightType); - } - // Otherwise, if both operands are numeric - else if (leftNumericType && rightNumericType) { - // Otherwise, if operator requires integer operands - if (opType == Token::Type::PERCENT || opType == Token::Type::SHIFT_LEFT - || opType == Token::Type::SHIFT_RIGHT || opType == Token::Type::CARET - || opType == Token::Type::AMPERSAND || opType == Token::Type::PIPE) - { - // Check that operands are integers - if (!leftNumericType->isIntegral(m_Context) || !rightNumericType->isIntegral(m_Context)) { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); - throw TypeCheckError(); - } - - // If operator is a shift, promote left type - if (opType == Token::Type::SHIFT_LEFT || opType == Token::Type::SHIFT_RIGHT) { - setExpressionType(&binary, Type::getPromotedType(leftNumericType, m_Context)); - } - // Otherwise, take common type - else { - setExpressionType(&binary, Type::getCommonType(leftNumericType, rightNumericType, m_Context)); - } + // Visit permutations of left and right types + const auto resultType = std::visit( + Utils::Overload{ + // If both operands are numeric + [&leftType, &rightType, opType, this] + (const Type::Type::Numeric &rightNumeric, const Type::Type::Numeric &leftNumeric) -> std::optional + { + // If operator requires integer operands + if (opType == Token::Type::PERCENT || opType == Token::Type::SHIFT_LEFT + || opType == Token::Type::SHIFT_RIGHT || opType == Token::Type::CARET + || opType == Token::Type::AMPERSAND || opType == Token::Type::PIPE) + { + // Check that operands are integers + if (leftNumeric.isIntegral && rightNumeric.isIntegral) { + // If operator is a shift, promote left type + if (opType == Token::Type::SHIFT_LEFT || opType == Token::Type::SHIFT_RIGHT) { + return Type::getPromotedType(leftType); + } + // Otherwise, take common type + else { + return Type::getCommonType(leftType, rightType); + } + } + else { + return std::nullopt; + } + } + // Otherwise, any numeric type will do, take common type + else { + return Type::getCommonType(leftType, rightType); + } + }, + // Otherwise, if both operands are pointers + [&binary, &leftType, &rightType, opType, this] + (const Type::Type::Pointer &rightPointer, const Type::Type::Pointer &leftPointer) -> std::optional + { + // If operator is minus and pointer types match + if (opType == Token::Type::MINUS && leftType == rightType) { + // **TODO** should be std::ptrdiff/Int64 + return Type::Int32; + } + else { + return std::nullopt; + } + }, + // Otherwise, if right is numeric and left is pointer + [&binary, &leftType, &rightType, opType, this] + (const Type::Type::Numeric &rightNumeric, const 
Type::Type::Pointer &leftPointer) -> std::optional + { + // If operator is valid and numeric type is integer + // P + n or P - n + if ((opType == Token::Type::PLUS || opType == Token::Type::MINUS) && rightNumeric.isIntegral) { + return leftType; + } + else { + return std::nullopt; + } + }, + // Otherwise, if right is pointer and left is numeric + [&binary, &rightType, opType, this] + (const Type::Type::Pointer &rightPointer, const Type::Type::Numeric &leftNumeric) -> std::optional + { + // n + P + if (opType == Token::Type::PLUS && leftNumeric.isIntegral) { + return rightType; + } + else { + return std::nullopt; + } + }, + // Otherwise, operator is being applied to unsupported types + [](auto, auto) -> std::optional + { + return std::nullopt; + }}, + rightType.detail, leftType.detail); + + if (resultType) { + setExpressionType(&binary, *resultType); } - // Otherwise, any numeric type will do, take common type else { - setExpressionType(&binary, Type::getCommonType(leftNumericType, rightNumericType, m_Context)); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + getDescription(leftType) + "' and '" + getDescription(rightType)); + throw TypeCheckError(); } - } - else { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType->getName() + "' and '" + rightType->getName()); - throw TypeCheckError(); - } } } virtual void visit(const Expression::Call &call) final { - // Evaluate argument types and store in top of stack m_CallArguments.emplace(); std::transform(call.getArguments().cbegin(), call.getArguments().cend(), std::back_inserter(m_CallArguments.top()), @@ -269,14 +308,13 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // Evaluate callee type auto calleeType = evaluateType(call.getCallee()); - auto calleeFunctionType = dynamic_cast(calleeType); // Pop stack m_CallArguments.pop(); // If callee's a function, type is return type of function - if (calleeFunctionType) { - setExpressionType(&call, calleeFunctionType->getReturnType()); + if (calleeType.isFunction()) { + setExpressionType(&call, *calleeType.getFunction().returnType); } // Otherwise else { @@ -292,48 +330,61 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If const is being removed if (!checkForConstRemoval(rightType, cast.getType())) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName()); + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + getDescription(cast.getType()) + "' and '" + getDescription(rightType)); throw TypeCheckError(); } - // If we're trying to cast pointer to pointer - auto rightNumericType = dynamic_cast(rightType); - auto rightPointerType = dynamic_cast(rightType); - auto leftNumericType = dynamic_cast(cast.getType()); - auto leftPointerType = dynamic_cast(cast.getType()); - if (rightPointerType && leftPointerType) { - // Check that value type at the end matches - if (!checkPointerTypeAssignement(rightPointerType->getValueType(), leftPointerType->getValueType(), m_Context)) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName()); - throw TypeCheckError(); - } + const auto resultType = std::visit( + Utils::Overload{ + // If types are numeric, any cast goes + [&cast](const Type::Type::Numeric &rightNumeric, const Type::Type::Numeric &castNumeric) -> std::optional + { + return cast.getType(); + }, + // Otherwise, if we're trying to 
cast pointer to pointer + [&cast](const Type::Type::Pointer &rightPointer, const Type::Type::Pointer &castPointer) -> std::optional + { + // Check that value type at the end matches + if (checkPointerTypeAssignement(*rightPointer.valueType, *castPointer.valueType)) { + return cast.getType(); + } + else { + return std::nullopt; + } + }, + // Otherwise, pointers can't be cast to non-pointers and vice versa + [](auto, auto) -> std::optional + { + return std::nullopt; + }}, + rightType.detail, cast.getType().detail); + + if (resultType) { + setExpressionType(&cast, *resultType); } - // Otherwise, if either operand isn't numeric - else if(!leftNumericType | !rightNumericType) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType()->getName() + "' and '" + rightType->getName()); - throw TypeCheckError(); + else { + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + getDescription(cast.getType()) + "' and '" + getDescription(rightType)); + throw TypeCheckError(); } - - setExpressionType(&cast, cast.getType()); } virtual void visit(const Expression::Conditional &conditional) final { const auto trueType = evaluateType(conditional.getTrue()); const auto falseType = evaluateType(conditional.getFalse()); - auto trueNumericType = dynamic_cast(trueType); - auto falseNumericType = dynamic_cast(falseType); - if (trueNumericType && falseNumericType) { + if (trueType.isNumeric() && falseType.isNumeric()) { // **TODO** check behaviour - const Type::Base *type = Type::getCommonType(trueNumericType, falseNumericType, m_Context); - if(trueType->hasQualifier(Type::Qualifier::CONSTANT) || falseType->hasQualifier(Type::Qualifier::CONSTANT)) { - type = type->getQualifiedType(Type::Qualifier::CONSTANT); + const auto commonType = Type::getCommonType(trueType, falseType); + if(trueType.hasQualifier(Type::Qualifier::CONSTANT) || falseType.hasQualifier(Type::Qualifier::CONSTANT)) { + setExpressionType(&conditional, commonType.addQualifier(Type::Qualifier::CONSTANT)); + } + else { + setExpressionType(&conditional, commonType); } - setExpressionType(&conditional, type); } else { m_ErrorHandler.error(conditional.getQuestion(), - "Invalid operand types '" + trueType->getName() + "' and '" + falseType->getName() + "' to conditional"); + "Invalid operand types '" + getDescription(trueType) + "' and '" + getDescription(falseType) + "' to conditional"); throw TypeCheckError(); } } @@ -348,23 +399,25 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // Convert number token type to type // **THINK** is it better to use typedef for scalar or resolve from m_Context if (literal.getValue().type == Token::Type::DOUBLE_NUMBER) { - setExpressionType(&literal); + setExpressionType(&literal, Type::Double); } else if (literal.getValue().type == Token::Type::FLOAT_NUMBER) { - setExpressionType(&literal); + setExpressionType(&literal, Type::Float); } else if (literal.getValue().type == Token::Type::SCALAR_NUMBER) { // **TODO** cache - setExpressionType(&literal, new Type::NumericTypedef("scalar")); + assert(false); + // **THINK** why not resolve here? 
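+            // (resolving here would need the 'scalar' typedef's definition from
+            //  the TypeContext already held in m_Context; until that is wired up,
+            //  SCALAR_NUMBER literals stop at the assert above)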
+ //setExpressionType(&literal, new Type::NumericTypedef("scalar")); } else if (literal.getValue().type == Token::Type::INT32_NUMBER) { - setExpressionType(&literal); + setExpressionType(&literal, Type::Int32); } else if (literal.getValue().type == Token::Type::UINT32_NUMBER) { - setExpressionType(&literal); + setExpressionType(&literal, Type::Uint32); } else if(literal.getValue().type == Token::Type::STRING) { - setExpressionType(&literal, Type::Int8::getInstance()->getPointerType()); + setExpressionType(&literal, Type::Type::createPointer(Type::Int8, Type::Qualifier::CONSTANT)); } else { assert(false); @@ -375,38 +428,38 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { logical.getLeft()->accept(*this); logical.getRight()->accept(*this); - setExpressionType(&logical); + setExpressionType(&logical, Type::Int32); } virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final { const auto lhsType = evaluateType(postfixIncDec.getTarget()); - - if(lhsType->hasQualifier(Type::Qualifier::CONSTANT)) { + if(lhsType.hasQualifier(Type::Qualifier::CONSTANT)) { m_ErrorHandler.error(postfixIncDec.getOperator(), "Increment/decrement of read-only variable"); throw TypeCheckError(); } - - // **TODO** - - setExpressionType(&postfixIncDec, - m_Environment.get().incDec(postfixIncDec.getVarName(), postfixIncDec.getOperator().type, - m_Context, m_ErrorHandler)); + else { + setExpressionType(&postfixIncDec, lhsType); + } } virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final { const auto rhsType = evaluateType(prefixIncDec.getTarget()); - setExpressionType(&prefixIncDec, - m_Environment.get().incDec(prefixIncDec.getVarName(), prefixIncDec.getOperator().type, - m_Context, m_ErrorHandler)); + if(rhsType.hasQualifier(Type::Qualifier::CONSTANT)) { + m_ErrorHandler.error(prefixIncDec.getOperator(), "Increment/decrement of read-only variable"); + throw TypeCheckError(); + } + else { + setExpressionType(&prefixIncDec, rhsType); + } } virtual void visit(const Expression::Variable &variable) { // If type is unambiguous and not a function const auto varTypes = m_Environment.get().getTypes(variable.getName(), m_ErrorHandler); - if (varTypes.size() == 1 && dynamic_cast(varTypes.front()) == nullptr) { + if (varTypes.size() == 1 && !varTypes.front().isFunction()) { setExpressionType(&variable, varTypes.front()); } // Otherwise @@ -415,19 +468,11 @@ class Visitor : public Expression::Visitor, public Statement::Visitor assert(!m_CallArguments.empty()); // Loop through variable types - std::vector>> viableFunctions; - for(const auto *type : varTypes) { - // Cast to function (only functions should be overloaded) - const auto *func = dynamic_cast(type); - assert(func); - - // If function is variadic and there are at least as many vall arguments as actual (last is nullptr) - // function parameters or function is non-variadic and number of arguments match - const auto argumentTypes = func->getArgumentTypes(); - const bool variadic = func->isVariadic(); - if((variadic && m_CallArguments.top().size() >= (argumentTypes.size() - 1)) - || (!variadic && m_CallArguments.top().size() == argumentTypes.size())) - { + std::vector>> viableFunctions; + for(const auto &type : varTypes) { + // If function is non-variadic and number of arguments match + const auto &argumentTypes = type.getFunction().argTypes; + if(m_CallArguments.top().size() == argumentTypes.size()) { // Create vector to hold argument conversion rank std::vector argumentConversionRank; 
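                    // Ranks mirror a simplified C++ overload resolution:
                    // 0 = exact match, 1 = integer/float promotion,
                    // 2 = any other numeric conversion; viable overloads are
                    // then compared on these ranks to pick the best match.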
argumentConversionRank.reserve(m_CallArguments.top().size()); @@ -436,50 +481,64 @@ class Visitor : public Expression::Visitor, public Statement::Visitor bool viable = true; auto c = m_CallArguments.top().cbegin(); auto a = argumentTypes.cbegin(); - for(;c != m_CallArguments.top().cend() && *a != nullptr; c++, a++) { - auto cNumericType = dynamic_cast(*c); - auto aNumericType = dynamic_cast(*a); - - // If both are numeric - if(cNumericType && aNumericType) { - // If names are identical (we don't care about qualifiers), match is exact - if(cNumericType->getResolvedName(m_Context) == aNumericType->getResolvedName(m_Context)) { - argumentConversionRank.push_back(0); - } - // Integer promotion - else if(aNumericType->getName() == Type::Int32::getInstance()->getName() - && cNumericType->isIntegral(m_Context) - && cNumericType->getRank(m_Context) < Type::Int32::getInstance()->getRank(m_Context)) - { - argumentConversionRank.push_back(1); - } - // Float promotion - else if(aNumericType->getResolvedName(m_Context) == Type::Double::getInstance()->getName() - && cNumericType->getResolvedName(m_Context) == Type::Float::getInstance()->getName()) - { - argumentConversionRank.push_back(1); - } - // Otherwise, numeric conversion - // **TODO** integer to scalar promotion should be lower ranked than general conversion - else { - argumentConversionRank.push_back(2); - } - } - // Otherwise, if they are matching pointers - // **TODO** some more nuance here - else if(checkPointerTypeAssignement(*c, *a, m_Context)) { - argumentConversionRank.push_back(0); + for(;c != m_CallArguments.top().cend(); c++, a++) { + const auto argConversionRank = std::visit( + Utils::Overload{ + // If types are numeric, any cast goes + [c, a](const Type::Type::Numeric &cNumeric, const Type::Type::Numeric &aNumeric) -> std::optional + { + // If names are identical, match is exact + // **TODO** we don't care about qualifiers + if(*c == *a) { + return 0; + } + // Integer promotion + else if(*a == Type::Int32 && c->getNumeric().isIntegral + && c->getNumeric().rank < Type::Int32.getNumeric().rank) + { + return 1; + } + // Float promotion + else if(*a == Type::Double && *c == Type::Float) { + return 1; + } + // Otherwise, numeric conversion + // **TODO** integer to scalar promotion should be lower ranked than general conversion + else { + return 2; + } + }, + // Otherwise, if we're trying to cast pointer to pointer + [](const Type::Type::Pointer &cPointer, const Type::Type::Pointer &aPointer) -> std::optional + { + // Check that value type at the end matches + if (checkPointerTypeAssignement(*cPointer.valueType, *aPointer.valueType)) { + return 0; + } + else { + return std::nullopt; + } + }, + // Otherwise, pointers can't be cast to non-pointers and vice versa + [](auto, auto) -> std::optional + { + return std::nullopt; + }}, + c->detail, a->detail); + + // If there is a valid conversion between argument and definition + if (argConversionRank) { + argumentConversionRank.push_back(*argConversionRank); } // Otherwise, this function is not viable else { viable = false; - break; } } // If function is viable, add to vector along with vector of conversion ranks if(viable) { - viableFunctions.emplace_back(func, argumentConversionRank); + viableFunctions.emplace_back(type, argumentConversionRank); } } } @@ -506,53 +565,49 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If operator is pointer de-reference if (unary.getOperator().type == Token::Type::STAR) { - auto rightPointerType = dynamic_cast(rightType); - if 
(!rightPointerType) { + if (rightType.isPointer()) { + setExpressionType(&unary, *rightType.getPointer().valueType); + } + else { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getName() + "'"); + "Invalid operand type '" + getDescription(rightType) + "'"); throw TypeCheckError(); } - - // Return value type - setExpressionType(&unary, rightPointerType->getValueType()); } // Otherwise - else { - auto rightNumericType = dynamic_cast(rightType); - if (rightNumericType) { - // If operator is arithmetic, return promoted type - if (unary.getOperator().type == Token::Type::PLUS || unary.getOperator().type == Token::Type::MINUS) { + else if (rightType.isNumeric()) { + // If operator is arithmetic, return promoted type + if (unary.getOperator().type == Token::Type::PLUS || unary.getOperator().type == Token::Type::MINUS) { + // **THINK** const through these? + setExpressionType(&unary, Type::getPromotedType(rightType)); + } + // Otherwise, if operator is bitwise + else if (unary.getOperator().type == Token::Type::TILDA) { + // If type is integer, return promoted type + if (rightType.getNumeric().isIntegral) { // **THINK** const through these? - setExpressionType(&unary, Type::getPromotedType(rightNumericType, m_Context)); - } - // Otherwise, if operator is bitwise - else if (unary.getOperator().type == Token::Type::TILDA) { - // If type is integer, return promoted type - if (rightNumericType->isIntegral(m_Context)) { - // **THINK** const through these? - setExpressionType(&unary, Type::getPromotedType(rightNumericType, m_Context)); - } - else { - m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getName() + "'"); - throw TypeCheckError(); - } - } - // Otherwise, if operator is logical - else if (unary.getOperator().type == Token::Type::NOT) { - setExpressionType(&unary); + setExpressionType(&unary, Type::getPromotedType(rightType)); } - // Otherwise, if operator is address of, return pointer type - else if (unary.getOperator().type == Token::Type::AMPERSAND) { - setExpressionType(&unary, rightType->getPointerType()); + else { + m_ErrorHandler.error(unary.getOperator(), + "Invalid operand type '" + getDescription(rightType) + "'"); + throw TypeCheckError(); } } - else { - m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + rightType->getName() + "'"); - throw TypeCheckError(); + // Otherwise, if operator is logical + else if (unary.getOperator().type == Token::Type::NOT) { + setExpressionType(&unary, Type::Int32); + } + // Otherwise, if operator is address of, return pointer type + else if (unary.getOperator().type == Token::Type::AMPERSAND) { + setExpressionType(&unary, Type::Type::createPointer(rightType)); } } + else { + m_ErrorHandler.error(unary.getOperator(), + "Invalid operand type '" + getDescription(rightType) + "'"); + throw TypeCheckError(); + } } //--------------------------------------------------------------------------- @@ -649,10 +704,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor if (labelled.getValue()) { auto valType = evaluateType(labelled.getValue()); - auto valNumericType = dynamic_cast(valType); - if (!valNumericType || !valNumericType->isIntegral(m_Context)) { + if (!valType.isNumeric() || !valType.getNumeric().isIntegral) { m_ErrorHandler.error(labelled.getKeyword(), - "Invalid case value '" + valType->getName() + "'"); + "Invalid case value '" + getDescription(valType) + "'"); throw TypeCheckError(); } } @@ -663,10 +717,9 @@ class Visitor : public 
Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::Switch &switchStatement) final { auto condType = evaluateType(switchStatement.getCondition()); - auto condNumericType = dynamic_cast(condType); - if (!condNumericType || !condNumericType->isIntegral(m_Context)) { + if (!condType.isNumeric() || !condType.getNumeric().isIntegral) { m_ErrorHandler.error(switchStatement.getSwitch(), - "Invalid condition '" + condType->getName() + "'"); + "Invalid condition '" + getDescription(condType) + "'"); throw TypeCheckError(); } @@ -677,7 +730,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::VarDeclaration &varDeclaration) final { - const auto *decType = varDeclaration.getType(); + const auto decType = varDeclaration.getType(); for (const auto &var : varDeclaration.getInitDeclaratorList()) { m_Environment.get().define(std::get<0>(var), decType, m_ErrorHandler); @@ -709,26 +762,19 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- // Private methods //--------------------------------------------------------------------------- - const Type::Base *evaluateType(const Expression::Base *expression) + Type::Type evaluateType(const Expression::Base *expression) { expression->accept(*this); return m_ResolvedTypes.at(expression); } - void setExpressionType(const Expression::Base *expression, const Type::Base *type) + void setExpressionType(const Expression::Base *expression, const Type::Type &type) { if (!m_ResolvedTypes.emplace(expression, type).second) { throw std::runtime_error("Expression type resolved multiple times"); } } - template - void setExpressionType(const Expression::Base *expression) - { - if (!m_ResolvedTypes.emplace(expression, T::getInstance()).second) { - throw std::runtime_error("Expression type resolved multiple times"); - } - } //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- @@ -736,7 +782,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const Type::TypeContext &m_Context; ErrorHandlerBase &m_ErrorHandler; ResolvedTypeMap &m_ResolvedTypes; - std::stack> m_CallArguments; + std::stack> m_CallArguments; bool m_InLoop; bool m_InSwitch; }; @@ -745,7 +791,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker::EnvironmentBase //--------------------------------------------------------------------------- -const Type::Base *EnvironmentBase::getType(const Token &name, ErrorHandlerBase &errorHandler) +Type::Type EnvironmentBase::getType(const Token &name, ErrorHandlerBase &errorHandler) { const auto types = getTypes(name, errorHandler); if (types.size() == 1) { @@ -757,10 +803,10 @@ const Type::Base *EnvironmentBase::getType(const Token &name, ErrorHandlerBase & } } //--------------------------------------------------------------------------- -const Type::Base *EnvironmentBase::assign(const Token &name, Token::Type op, - const Type::Base *existingType, const Type::Base *assignedType, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler, - bool initializer) const +Type::Type EnvironmentBase::assign(const Token &name, Token::Type op, + const Type::Base *existingType, const Type::Base *assignedType, + const Type::TypeContext 
&context, ErrorHandlerBase &errorHandler,
+                       bool initializer) const
 {
     // If existing type is const qualified and isn't being initialized, give error
     if(!initializer && existingType->hasQualifier(Type::Qualifier::CONSTANT)) {
diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc
index 53988ab7fc..c3b398537e 100644
--- a/src/genn/genn/type.cc
+++ b/src/genn/genn/type.cc
@@ -49,7 +49,7 @@ const std::map<std::set<std::string>, Type::Type> numericTypeSpecifiers{
 const std::set<std::string> scalarTypeSpecifier{{"scalar"}};
 //----------------------------------------------------------------------------
 // Mapping of signed integer numericTypeSpecifiers to their unsigned equivalents
-const std::unordered_map<Type::Type, Type::Type> unsignedType{
+const std::map<Type::Type, Type::Type> unsignedType{
     {Type::Int8, Type::Uint8},
     {Type::Int16, Type::Uint16},
     {Type::Int32, Type::Uint32}};
@@ -86,6 +86,7 @@ Type getNumericType(const std::set<std::string> &typeSpecifiers)
 {
     // If type matches scalar type specifiers
     if(typeSpecifiers == scalarTypeSpecifier) {
+        assert(false);
         //return new NumericTypedef("scalar");
     }
     // Otherwise
@@ -101,6 +102,7 @@ Type getPromotedType(const Type &type)
     // If a small integer type is used in an expression, it is implicitly converted to int which is always signed.
     // This is known as the integer promotions or the integer promotion rule
     // **NOTE** this is true because in our type system unsigned short is uint16 which can be represented in int32
+    assert(type.isNumeric());
     if(type.getNumeric().rank < Int32.getNumeric().rank) {
         return Int32;
     }
@@ -112,6 +114,8 @@ Type getCommonType(const Type &a, const Type &b)
 {
     // If either type is double, common type is double
+    assert(a.isNumeric());
+    assert(b.isNumeric());
     if(a == Double || b == Double) {
         return Double;
     }

From 38673195916ee7af09a915768f67cd9105321041 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Wed, 26 Apr 2023 12:57:51 +0100
Subject: [PATCH 150/725] assignment type checking

---
 include/genn/genn/transpiler/typeChecker.h |   4 +-
 src/genn/genn/transpiler/typeChecker.cc    | 175 ++++++++++-----------
 2 files changed, 82 insertions(+), 97 deletions(-)

diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h
index afa454dd1e..53283d6ce8 100644
--- a/include/genn/genn/transpiler/typeChecker.h
+++ b/include/genn/genn/transpiler/typeChecker.h
@@ -58,6 +58,6 @@ class EnvironmentBase
 ResolvedTypeMap typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment,
                           const Type::TypeContext &context, ErrorHandlerBase &errorHandler);
-const Type::Base *typeCheck(const Expression::Base *expression, EnvironmentBase &environment,
-                            const Type::TypeContext &context, ErrorHandlerBase &errorHandler);
+Type::Type typeCheck(const Expression::Base *expression, EnvironmentBase &environment,
+                     const Type::TypeContext &context, ErrorHandlerBase &errorHandler);
 }   // namespace MiniParse::GeNN::Transpiler
diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index 1579c30b51..215f035662 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -77,7 +77,7 @@ bool checkForConstRemoval(const Type::Type &rightType, const Type::Type &leftTyp
     return std::visit(
         Utils::Overload{
-            // If both are non-pointers, return true as const removal has been successfully checked
+            // If both are numeric, return true as const removal has been successfully checked
            [](const Type::Type::Numeric &rightNumeric, const Type::Type::Numeric &leftNumeric)
            {
                return true;
            },
@@ -91,6 
+91,63 @@ bool checkForConstRemoval(const Type::Type &rightType, const Type::Type &leftTyp
            [](auto, auto) { return false; }},
        rightType.detail, leftType.detail);
 }
+//---------------------------------------------------------------------------
+bool checkImplicitConversion(const Type::Type &rightType, const Type::Type &leftType, Token::Type op = Token::Type::EQUAL)
+{
+    return std::visit(
+        Utils::Overload{
+            // If both are numeric, return true as any numeric types can be assigned
+            [op](const Type::Type::Numeric &rightNumeric, const Type::Type::Numeric &leftNumeric)
+            {
+                // If operator requires it and both arguments are integers, return true
+                if (op == Token::Type::PERCENT_EQUAL || op == Token::Type::SHIFT_LEFT_EQUAL
+                    || op == Token::Type::SHIFT_RIGHT_EQUAL || op == Token::Type::CARET
+                    || op == Token::Type::AMPERSAND_EQUAL || op == Token::Type::PIPE_EQUAL)
+                {
+                    return (leftNumeric.isIntegral && rightNumeric.isIntegral);
+                }
+                // Otherwise, assignment will work for any numeric type
+                else {
+                    return true;
+                }
+            },
+            // Otherwise, if both are pointers, recurse through value type
+            [op, &leftType, &rightType]
+            (const Type::Type::Pointer &rightPointer, const Type::Type::Pointer &leftPointer)
+            {
+                // If operator is equals
+                if (op == Token::Type::EQUAL) {
+                    // Check that value type at the end matches
+                    if (!checkPointerTypeAssignement(*rightPointer.valueType, *leftPointer.valueType)) {
+                        return false;
+                    }
+                    // Check we're not trying to make type less const
+                    else if(!checkForConstRemoval(rightType, leftType)) {
+                        return false;
+                    }
+                    else {
+                        return true;
+                    }
+                }
+                // Two pointers can only be assigned with =
+                else {
+                    return false;
+                }
+            },
+            // Otherwise, if left is pointer and right is numeric,
+            [op](const Type::Type::Numeric &rightNumeric, const Type::Type::Pointer &leftPointer)
+            {
+                if (op == Token::Type::PLUS_EQUAL || op == Token::Type::MINUS_EQUAL) {
+                    return rightNumeric.isIntegral;
+                }
+                else {
+                    return false;
+                }
+            },
+            // Otherwise, we're trying to assign invalid types
+            [](auto, auto) { return false; }},
+        rightType.detail, leftType.detail);
+}

 //---------------------------------------------------------------------------
 // EnvironmentInternal
@@ -193,12 +250,21 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
    virtual void visit(const Expression::Assignment &assignment) final
    {
-        const auto lhsType = evaluateType(assignment.getAssignee());
-        const auto rhsType = evaluateType(assignment.getValue());
+        const auto leftType = evaluateType(assignment.getAssignee());
+        const auto rightType = evaluateType(assignment.getValue());

-        assert(false);
+        // If left-hand type is const qualified, give error
+        if(leftType.hasQualifier(Type::Qualifier::CONSTANT)) {
+            m_ErrorHandler.error(assignment.getOperator(), "Assignment of read-only variable");
+            throw TypeCheckError();
+        }
+        // Otherwise, if implicit conversion fails, give error
+        else if (!checkImplicitConversion(rightType, leftType, assignment.getOperator().type)) {
+            m_ErrorHandler.error(assignment.getOperator(), "Invalid operand types '" + getDescription(leftType) + "' and '" + getDescription(rightType) + "'");
+            throw TypeCheckError();
+        }

-        setExpressionType(&assignment, lhsType);
+        setExpressionType(&assignment, leftType);
    }

    virtual void visit(const Expression::Binary &binary) final
@@ -433,6 +499,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
    virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final
    {
+        // **TODO** more general lvalue thing
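+        // (postfix ++/-- needs a modifiable lvalue; for now only const-ness is
+        //  checked and the operand's type is returned unchanged)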
const auto lhsType = evaluateType(postfixIncDec.getTarget()); if(lhsType.hasQualifier(Type::Qualifier::CONSTANT)) { m_ErrorHandler.error(postfixIncDec.getOperator(), "Increment/decrement of read-only variable"); @@ -445,6 +512,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final { + // **TODO** more general lvalue thing const auto rhsType = evaluateType(prefixIncDec.getTarget()); if(rhsType.hasQualifier(Type::Qualifier::CONSTANT)) { m_ErrorHandler.error(prefixIncDec.getOperator(), "Increment/decrement of read-only variable"); @@ -734,13 +802,13 @@ class Visitor : public Expression::Visitor, public Statement::Visitor for (const auto &var : varDeclaration.getInitDeclaratorList()) { m_Environment.get().define(std::get<0>(var), decType, m_ErrorHandler); - // If variable has an initialiser expression + // If variable has an initialiser expression, check that + // it can be implicitly converted to variable type if (std::get<1>(var)) { - // Evaluate type const auto initialiserType = evaluateType(std::get<1>(var).get()); - - assert(false); - // **TODO** check decType = initialiserType is implicit conversion + if (!checkImplicitConversion(initialiserType, decType)) { + m_ErrorHandler.error(std::get<0>(var), "Invalid operand types '" + getDescription(decType) + "' and '" + getDescription(initialiserType)); + } } } } @@ -802,89 +870,6 @@ Type::Type EnvironmentBase::getType(const Token &name, ErrorHandlerBase &errorHa throw TypeCheckError(); } } -//--------------------------------------------------------------------------- -Type::Type EnvironmentBase::assign(const Token &name, Token::Type op, - const Type::Base *existingType, const Type::Base *assignedType, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler, - bool initializer) const -{ - // If existing type is a const qualified and isn't being initialized, give error - if(!initializer && existingType->hasQualifier(Type::Qualifier::CONSTANT)) { - errorHandler.error(name, "Assignment of read-only variable"); - throw TypeCheckError(); - } - - // If assignment operation is plain equals, any type is fine so return - auto numericExistingType = dynamic_cast(existingType); - auto pointerExistingType = dynamic_cast(existingType); - auto numericAssignedType = dynamic_cast(assignedType); - auto pointerAssignedType = dynamic_cast(assignedType); - if(op == Token::Type::EQUAL) { - // If we're initialising a pointer with another pointer - if (pointerAssignedType && pointerExistingType) { - // Check that value type at the end matches - if (!checkPointerTypeAssignement(pointerAssignedType->getValueType(), pointerExistingType->getValueType(), context)) { - errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName() + "' and '" + pointerAssignedType->getName()); - throw TypeCheckError(); - } - - // If we're trying to make type less const - if (!checkForConstRemoval(pointerAssignedType, pointerExistingType)) { - errorHandler.error(name, "Invalid operand types '" + pointerExistingType->getName() + "' and '" + pointerAssignedType->getName()); - throw TypeCheckError(); - } - } - // Otherwise, if we're trying to initialise a pointer with a non-pointer or vice-versa - else if (pointerAssignedType || pointerExistingType) { - errorHandler.error(name, "Invalid operand types '" + existingType->getName() + "' and '" + assignedType->getName()); - throw TypeCheckError(); - } - } - // Otherwise, if operation is += or -- - else if (op == 
Token::Type::PLUS_EQUAL || op == Token::Type::MINUS_EQUAL) { - // If the operand being added isn't numeric or the type being added to is neither numeric or a pointer - if (!numericAssignedType || (!pointerExistingType && !numericExistingType)) - { - errorHandler.error(name, "Invalid operand types '" + existingType->getName() + "' and '" + assignedType->getName() + "'"); - throw TypeCheckError(); - } - - // If we're adding a numeric type to a pointer, check it's an integer - if (pointerExistingType && numericAssignedType->isIntegral(context)) { - errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName() + "'"); - throw TypeCheckError(); - } - } - // Otherwise, numeric types are required - else { - // If either type is non-numeric, give error - if(!numericAssignedType) { - errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName() + "'"); - throw TypeCheckError(); - } - if(!numericExistingType) { - errorHandler.error(name, "Invalid operand types '" + existingType->getName() + "'"); - throw TypeCheckError(); - } - - // If operand isn't one that takes any numeric type, check both operands are integral - if (op != Token::Type::STAR_EQUAL && op != Token::Type::SLASH_EQUAL) { - if(!numericAssignedType->isIntegral(context)) { - errorHandler.error(name, "Invalid operand types '" + numericAssignedType->getName() + "'"); - throw TypeCheckError(); - } - if(!numericExistingType->isIntegral(context)) { - errorHandler.error(name, "Invalid operand types '" + numericExistingType->getName() + "'"); - throw TypeCheckError(); - } - } - } - - // Return existing type - // **THINK** - return existingType; -} - //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker @@ -898,8 +883,8 @@ ResolvedTypeMap GeNN::Transpiler::TypeChecker::typeCheck(const Statement::Statem return expressionTypes; } //--------------------------------------------------------------------------- -const Type::Base *GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, EnvironmentBase &environment, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler) +Type::Type GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, EnvironmentBase &environment, + const Type::TypeContext &context, ErrorHandlerBase &errorHandler) { ResolvedTypeMap expressionTypes; EnvironmentInternal internalEnvironment(environment); From 3b2c2851cf29dcecbe8284b6026f6553e8379b26 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 26 Apr 2023 15:09:47 +0100 Subject: [PATCH 151/725] structs with immutable members just really doesn't work with C++ --- include/genn/genn/type.h | 68 +++++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 29 deletions(-) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 082d4a95fb..60e5b17981 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -53,32 +53,32 @@ struct Type //------------------------------------------------------------------------ struct Numeric { - const std::string name; + std::string name; - const int rank; - const double min; - const double max; - const double lowest; - const int maxDigits10; + int rank; + double min; + double max; + double lowest; + int maxDigits10; - const bool isSigned; - const bool isIntegral; + bool isSigned; + bool isIntegral; - const std::string literalSuffix; + std::string literalSuffix; //------------------------------------------------------------------------ // Operators 
//------------------------------------------------------------------------ bool operator == (const Numeric &other) const { - return (std::make_tuple(rank, min, max, lowest, maxDigits10, isSigned, isIntegral) - == std::make_tuple(other.rank, other.min, other.max, other.lowest, other.maxDigits10, other.isSigned, other.isIntegral)); + return (std::tie(rank, min, max, lowest, maxDigits10, isSigned, isIntegral) + == std::tie(other.rank, other.min, other.max, other.lowest, other.maxDigits10, other.isSigned, other.isIntegral)); } bool operator < (const Numeric &other) const { - return (std::make_tuple(rank, min, max, lowest, maxDigits10, isSigned, isIntegral) - < std::make_tuple(other.rank, other.min, other.max, other.lowest, other.maxDigits10, other.isSigned, other.isIntegral)); + return (std::tie(rank, min, max, lowest, maxDigits10, isSigned, isIntegral) + < std::tie(other.rank, other.min, other.max, other.lowest, other.maxDigits10, other.isSigned, other.isIntegral)); } }; @@ -92,7 +92,7 @@ struct Type Pointer(const Pointer &other) : valueType(std::make_unique(*other.valueType)) {} - const std::unique_ptr valueType; + std::unique_ptr valueType; bool operator == (const Pointer &other) const { @@ -103,6 +103,12 @@ struct Type { return (*valueType < *other.valueType); } + + Pointer &operator = (const Pointer &other) + { + valueType.reset(new Type(*other.valueType)); + return *this; + } }; //------------------------------------------------------------------------ @@ -117,17 +123,24 @@ struct Type : returnType(std::make_unique(*other.returnType)), argTypes(other.argTypes) {} - const std::unique_ptr returnType; - const std::vector argTypes; + std::unique_ptr returnType; + std::vector argTypes; bool operator == (const Function &other) const { - return (*returnType == *other.returnType && argTypes == other.argTypes); + return std::tie(*returnType, argTypes) == std::tie(*other.returnType, other.argTypes); } bool operator < (const Function &other) const { - return (*returnType < *other.returnType); + return std::tie(*returnType, argTypes) < std::tie(*other.returnType, other.argTypes); + } + + Function &operator = (const Function &other) + { + returnType.reset(new Type(*other.returnType)); + argTypes = other.argTypes; + return *this; } }; @@ -140,19 +153,16 @@ struct Type Type(const Function &function) : size(0), qualifiers(Qualifier{0}), detail(function) {} - - Type(const Type &other) : size(other.size), qualifiers(other.qualifiers), detail(other.detail) - {} - Type(const Type other, Qualifier qualifiers) : size(other.size), qualifiers(qualifiers), detail(other.detail) + Type(const Type &other, Qualifier qualifiers) : size(other.size), qualifiers(qualifiers), detail(other.detail) {} //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - const size_t size; - const Qualifier qualifiers; + size_t size; + Qualifier qualifiers; - const std::variant detail; + std::variant detail; //------------------------------------------------------------------------ // Public API @@ -172,14 +182,14 @@ struct Type //------------------------------------------------------------------------ bool operator == (const Type &other) const { - return (std::make_tuple(size, qualifiers, detail) - == std::make_tuple(other.size, other.qualifiers, other.detail)); + return (std::tie(size, qualifiers, detail) + == std::tie(other.size, other.qualifiers, other.detail)); } bool operator < (const Type &other) const { - return 
(std::make_tuple(size, qualifiers, detail)
-                < std::make_tuple(other.size, other.qualifiers, other.detail));
+        return (std::tie(size, qualifiers, detail)
+                < std::tie(other.size, other.qualifiers, other.detail));
     }

     //------------------------------------------------------------------------

From 580b2e9d38539196e9de9700e54185eb45869c42 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Wed, 26 Apr 2023 15:10:01 +0100
Subject: [PATCH 152/725] started updating types in backend

---
 .../genn/genn/code_generator/backendBase.h    | 22 +++++-----
 .../backends/single_threaded_cpu/backend.cc   | 41 +++++++++----------
 2 files changed, 31 insertions(+), 32 deletions(-)

diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h
index c36505c0f7..68192ef40c 100644
--- a/include/genn/genn/code_generator/backendBase.h
+++ b/include/genn/genn/code_generator/backendBase.h
@@ -251,15 +251,15 @@ class GENN_EXPORT BackendBase
    //! Generate code to define a variable in the appropriate header file
    virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal,
-                                       const Type::ValueBase *type, const std::string &name, VarLocation loc) const = 0;
+                                       const Type::Type &type, const std::string &name, VarLocation loc) const = 0;

    //! Generate code to instantiate a variable in the provided stream
    virtual void genVariableInstantiation(CodeStream &os,
-                                          const Type::ValueBase *type, const std::string &name, VarLocation loc) const = 0;
+                                          const Type::Type &type, const std::string &name, VarLocation loc) const = 0;

    //! Generate code to allocate variable with a size known at compile-time
    virtual void genVariableAllocation(CodeStream &os,
-                                       const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name,
+                                       const Type::Type &type, const Type::TypeContext &typeContext, const std::string &name,
                                        VarLocation loc, size_t count, MemAlloc &memAlloc) const = 0;

    //! Generate code to allocate variable with a size known at runtime
@@ -272,22 +272,22 @@ class GENN_EXPORT BackendBase
    //! Generate code for pushing a variable with a size known at compile-time to the 'device'
    virtual void genVariablePush(CodeStream &os,
-                                 const Type::ValueBase *type, const std::string &name,
+                                 const Type::Type &type, const std::string &name,
                                  VarLocation loc, bool autoInitialized, size_t count) const = 0;

    //! Generate code for pulling a variable with a size known at compile-time from the 'device'
    virtual void genVariablePull(CodeStream &os,
-                                 const Type::ValueBase *type, const std::string &name,
+                                 const Type::Type &type, const std::string &name,
                                  VarLocation loc, size_t count) const = 0;

    //! Generate code for pushing a variable's value in the current timestep to the 'device'
    virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng,
-                                        const Type::ValueBase *type, const std::string &name,
+                                        const Type::Type &type, const std::string &name,
                                         VarLocation loc, unsigned int batchSize) const = 0;

    //! Generate code for pulling a variable's value in the current timestep from the 'device'
    virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng,
-                                        const Type::ValueBase *type, const std::string &name,
+                                        const Type::Type &type, const std::string &name,
                                         VarLocation loc, unsigned int batchSize) const = 0;

    //! Generate code for pushing a variable with a size known at runtime to the 'device'
@@ -309,7 +309,7 @@ class GENN_EXPORT BackendBase
    virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const = 0;

    //! 
When generating merged structures what type to use for simulation RNGs - virtual const Type::ValueBase *getMergedGroupSimRNGType() const = 0; + virtual const Type::Type &getMergedGroupSimRNGType() const = 0; virtual void genPopVariableInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const = 0; virtual void genVariableInit(CodeStream &os, const std::string &count, const std::string &indexVarName, @@ -420,7 +420,7 @@ class GENN_EXPORT BackendBase //-------------------------------------------------------------------------- //! Helper function to generate matching push and pull functions for a variable void genVariablePushPull(CodeStream &push, CodeStream &pull, - const Type::ValueBase *type, const std::string &name, + const Type::Type &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const { genVariablePush(push, type, name, loc, autoInitialized, count); @@ -438,7 +438,7 @@ class GENN_EXPORT BackendBase //! Helper function to generate matching push and pull functions for the current state of a variable void genCurrentVariablePushPull(CodeStream &push, CodeStream &pull, const NeuronGroupInternal &ng, - const Type::ValueBase *type, const std::string &name, + const Type::Type &type, const std::string &name, VarLocation loc, unsigned int batchSize) const { genCurrentVariablePush(push, ng, type, name, loc, batchSize); @@ -456,7 +456,7 @@ class GENN_EXPORT BackendBase //! Helper function to generate matching definition, declaration, allocation and free code for a statically-sized array void genArray(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::Type &type, const Type::TypeContext &typeContext, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const { genVariableDefinition(definitions, definitionsInternal, type, name, loc); diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 6aa9d76a0a..23124f6108 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1293,36 +1293,35 @@ void Backend::genStepTimeFinalisePreamble(CodeStream &, const ModelSpecMerged &) } //-------------------------------------------------------------------------- void Backend::genVariableDefinition(CodeStream &definitions, CodeStream &, - const Type::ValueBase *type, const std::string &name, VarLocation) const + const Type::Type &type, const std::string &name, VarLocation) const { - definitions << "EXPORT_VAR " << type->getPointerType()->getName() << " " << name << ";" << std::endl; + definitions << "EXPORT_VAR " << type.getNumeric().name << "* " << name << ";" << std::endl; } //-------------------------------------------------------------------------- void Backend::genVariableInstantiation(CodeStream &os, - const Type::ValueBase *type, const std::string &name, VarLocation) const + const Type::Type &type, const std::string &name, VarLocation) const { - os << type->getPointerType()->getName() << " " << name << ";" << std::endl; + os << type.getNumeric().name << "* " << name << ";" << std::endl; } //-------------------------------------------------------------------------- void Backend::genVariableAllocation(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::Type &type, 
const Type::TypeContext &typeContext, const std::string &name, VarLocation, size_t count, MemAlloc &memAlloc) const { - os << name << " = new " << type->getName() << "[" << count << "];" << std::endl; + os << name << " = new " << type.getNumeric().name << "[" << count << "];" << std::endl; - memAlloc += MemAlloc::host(count * type->getSizeBytes(typeContext)); + memAlloc += MemAlloc::host(count * type.size); } //-------------------------------------------------------------------------- void Backend::genVariableDynamicAllocation(CodeStream &os, - const Type::Base *type, const std::string &name, VarLocation, + const Type::Type &type, const std::string &name, VarLocation, const std::string &countVarName, const std::string &prefix) const { - const auto *pointerType = dynamic_cast(type); - if (pointerType) { - os << "*" << prefix << name << " = new " << pointerType->getValueType()->getName() << "[" << countVarName << "];" << std::endl; + if (type.isPointer()) { + os << "*" << prefix << name << " = new " << type.getPointer().valueType->getNumeric().name << "[" << countVarName << "];" << std::endl; } else { - os << prefix << name << " = new " << type->getName() << "[" << countVarName << "];" << std::endl; + os << prefix << name << " = new " << type.getNumeric().name << "[" << countVarName << "];" << std::endl; } } //-------------------------------------------------------------------------- @@ -1331,39 +1330,39 @@ void Backend::genVariableFree(CodeStream &os, const std::string &name, VarLocati os << "delete[] " << name << ";" << std::endl; } //-------------------------------------------------------------------------- -void Backend::genVariablePush(CodeStream&, const Type::ValueBase*, const std::string&, VarLocation, bool, size_t) const +void Backend::genVariablePush(CodeStream&, const Type::Type&, const std::string&, VarLocation, bool, size_t) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- -void Backend::genVariablePull(CodeStream&, const Type::ValueBase*, const std::string&, VarLocation, size_t) const +void Backend::genVariablePull(CodeStream&, const Type::Type&, const std::string&, VarLocation, size_t) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- void Backend::genCurrentVariablePush(CodeStream&, const NeuronGroupInternal&, - const Type::ValueBase*, const std::string&, + const Type::Type&, const std::string&, VarLocation, unsigned int) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- void Backend::genCurrentVariablePull(CodeStream&, const NeuronGroupInternal&, - const Type::ValueBase*, const std::string&, + const Type::Type&, const std::string&, VarLocation, unsigned int) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- void Backend::genVariableDynamicPush(CodeStream&, - const Type::Base*, const std::string&, + const Type::Type&, const std::string&, VarLocation, const std::string&, const std::string&) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- void Backend::genVariableDynamicPull(CodeStream&, - const Type::Base*, const std::string&, + const Type::Type&, const std::string&, VarLocation, const std::string&, const std::string&) const { assert(!getPreferences().automaticCopy); @@ -1377,12 +1376,12 @@ 
void Backend::genMergedDynamicVariablePush(CodeStream &os, const std::string &su } //-------------------------------------------------------------------------- -std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) const +std::string Backend::getMergedGroupFieldHostTypeName(const Type::Type &type) const { - return type->getName(); + return type.getNumeric().name; } //-------------------------------------------------------------------------- -const Type::ValueBase *Backend::getMergedGroupSimRNGType() const +const Type::Type &Backend::getMergedGroupSimRNGType() const { assert(false); return nullptr; From 8ce7e54a5ec5d19e9939ad2736c681c8dfa31866 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 26 Apr 2023 17:32:07 +0100 Subject: [PATCH 153/725] started adding ``Value`` as top-level type --- .../backends/single_threaded_cpu/backend.h | 24 +++++------ .../genn/genn/code_generator/backendBase.h | 8 ++-- include/genn/genn/type.h | 43 ++++++++++++++----- src/genn/genn/transpiler/typeChecker.cc | 24 ++++++----- 4 files changed, 63 insertions(+), 36 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 7b3c215269..63bc6e946a 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -61,20 +61,20 @@ class BACKEND_EXPORT Backend : public BackendBase //! Generate code to define a variable in the appropriate header file virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, - const Type::ValueBase *type, const std::string &name, VarLocation loc) const final; + const Type::Type &type, const std::string &name, VarLocation loc) const final; //! Generate code to instantiate a variable in the provided stream virtual void genVariableInstantiation(CodeStream &os, - const Type::ValueBase *type, const std::string &name, VarLocation loc) const final; + const Type::Type &type, const std::string &name, VarLocation loc) const final; //! Generate code to allocate variable with a size known at compile-time virtual void genVariableAllocation(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::Type &type, const Type::TypeContext &typeContext, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const final; //! Generate code to allocate variable with a size known at runtime virtual void genVariableDynamicAllocation(CodeStream &os, - const Type::Base *type, const std::string &name, VarLocation loc, + const Type::Type &type, const std::string &name, VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; //! Generate code to free a variable @@ -82,32 +82,32 @@ class BACKEND_EXPORT Backend : public BackendBase //! Generate code for pushing a variable with a size known at compile-time to the 'device' virtual void genVariablePush(CodeStream &os, - const Type::ValueBase *type, const std::string &name, VarLocation loc, + const Type::Type &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const final; //! Generate code for pulling a variable with a size known at compile-time from the 'device' virtual void genVariablePull(CodeStream &os, - const Type::ValueBase *type, const std::string &name, + const Type::Type &type, const std::string &name, VarLocation loc, size_t count) const final; //! 
Generate code for pushing a variable's value in the current timestep to the 'device'
    virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng,
-                                        const Type::ValueBase *type, const std::string &name,
+                                        const Type::Type &type, const std::string &name,
                                         VarLocation loc, unsigned int batchSize) const final;

    //! Generate code for pulling a variable's value in the current timestep from the 'device'
    virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng,
-                                        const Type::ValueBase *type, const std::string &name,
+                                        const Type::Type &type, const std::string &name,
                                         VarLocation loc, unsigned int batchSize) const final;

    //! Generate code for pushing a variable with a size known at runtime to the 'device'
    virtual void genVariableDynamicPush(CodeStream &os,
-                                        const Type::Base *type, const std::string &name, VarLocation loc,
+                                        const Type::Type &type, const std::string &name, VarLocation loc,
                                         const std::string &countVarName = "count", const std::string &prefix = "") const final;

    //! Generate code for pulling a variable with a size known at runtime from the 'device'
    virtual void genVariableDynamicPull(CodeStream &os,
-                                        const Type::Base *type, const std::string &name, VarLocation loc,
+                                        const Type::Type &type, const std::string &name, VarLocation loc,
                                         const std::string &countVarName = "count", const std::string &prefix = "") const final;

    //! Generate code for pushing a new pointer to a dynamic variable into the merged group structure on 'device'
@@ -116,10 +116,10 @@ class BACKEND_EXPORT Backend : public BackendBase
                                          const std::string &egpName) const final;

    //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host
-    virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const final;
+    virtual std::string getMergedGroupFieldHostTypeName(const Type::Type &type) const final;

    //! When generating merged structures what type to use for simulation RNGs
-    virtual const Type::ValueBase *getMergedGroupSimRNGType() const final;
+    virtual const Type::Type &getMergedGroupSimRNGType() const final;

    virtual void genPopVariableInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const final;
    virtual void genVariableInit(CodeStream &os, const std::string &count, const std::string &indexVarName,
diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h
index 68192ef40c..9540ff2dd2 100644
--- a/include/genn/genn/code_generator/backendBase.h
+++ b/include/genn/genn/code_generator/backendBase.h
@@ -264,7 +264,7 @@ class GENN_EXPORT BackendBase
    //! Generate code to allocate variable with a size known at runtime
    virtual void genVariableDynamicAllocation(CodeStream &os,
-                                              const Type::Base *type, const std::string &name, VarLocation loc,
+                                              const Type::Type &type, const std::string &name, VarLocation loc,
                                               const std::string &countVarName = "count", const std::string &prefix = "") const = 0;

    //! Generate code to free a variable
@@ -292,12 +292,12 @@ class GENN_EXPORT BackendBase
    //! Generate code for pushing a variable with a size known at runtime to the 'device'
    virtual void genVariableDynamicPush(CodeStream &os,
-                                        const Type::Base *type, const std::string &name, VarLocation loc,
+                                        const Type::Type &type, const std::string &name, VarLocation loc,
                                         const std::string &countVarName = "count", const std::string &prefix = "") const = 0;

    //! 
Generate code for pulling a variable with a size known at runtime from the 'device' virtual void genVariableDynamicPull(CodeStream &os, - const Type::Base *type, const std::string &name, VarLocation loc, + const Type::Type &type, const std::string &name, VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0; //! Generate code for pushing a new pointer to a dynamic variable into the merged group structure on 'device' @@ -306,7 +306,7 @@ class GENN_EXPORT BackendBase const std::string &egpName) const = 0; //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const = 0; + virtual std::string getMergedGroupFieldHostTypeName(const Type::Type &type) const = 0; //! When generating merged structures what type to use for simulation RNGs virtual const Type::Type &getMergedGroupSimRNGType() const = 0; diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 60e5b17981..4d01fd2762 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -7,6 +7,7 @@ // Standard C++ includes #include #include +#include #include #include #include @@ -53,8 +54,6 @@ struct Type //------------------------------------------------------------------------ struct Numeric { - std::string name; - int rank; double min; double max; @@ -82,6 +81,28 @@ struct Type } }; + //------------------------------------------------------------------------ + // Value + //------------------------------------------------------------------------ + struct Value + { + std::string name; + std::optional numeric; + + //------------------------------------------------------------------------ + // Operators + //------------------------------------------------------------------------ + bool operator == (const Value &other) const + { + return (numeric == other.numeric); + } + + bool operator < (const Value &other) const + { + return (numeric < other.numeric); + } + }; + //------------------------------------------------------------------------ // Pointer //------------------------------------------------------------------------ @@ -144,7 +165,7 @@ struct Type } }; - Type(size_t size, Qualifier qualifiers, const Numeric &numeric) + Type(size_t size, Qualifier qualifiers, const Value &numeric) : size(size), qualifiers(qualifiers), detail(numeric) {} Type(Qualifier qualifiers, const Pointer &pointer) @@ -162,18 +183,20 @@ struct Type size_t size; Qualifier qualifiers; - std::variant detail; + std::variant detail; //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - bool isNumeric() const{ return std::holds_alternative(detail); } + bool isValue() const{ return std::holds_alternative(detail); } bool isPointer() const{ return std::holds_alternative(detail); } bool isFunction() const{ return std::holds_alternative(detail); } - const Numeric &getNumeric() const{ return std::get(detail); } + bool isNumeric() const{ return isValue() && getValue().numeric; } + const Value &getValue() const{ return std::get(detail); } const Pointer &getPointer() const{ return std::get(detail); } const Function &getFunction() const{ return std::get(detail); } - + const Numeric &getNumeric() const{ return *getValue().numeric; } + const Type addQualifier(Qualifier qualifier) const{ return Type(*this, qualifier); } 
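    // (CONSTANT and other qualifiers live on the outer Type; a Pointer's
    //  valueType is itself a Type with its own qualifiers, so pointer-to-const
    //  and const-pointer are expressed by qualifying different levels)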
bool hasQualifier(Qualifier qualifier) const{ return (qualifiers & qualifier); } @@ -198,9 +221,9 @@ struct Type template static Type createNumeric(const std::string &name, int rank, const std::string &literalSuffix = "", Qualifier qualifiers = Qualifier{0}) { - return Type(sizeof(T), qualifiers, Numeric{name, rank, std::numeric_limits::min(), std::numeric_limits::max(), - std::numeric_limits::lowest(), std::numeric_limits::max_digits10, - std::is_signed::value, std::is_integral::value, literalSuffix}); + return Type{sizeof(T), qualifiers, Value{name, Numeric{rank, std::numeric_limits::min(), std::numeric_limits::max(), + std::numeric_limits::lowest(), std::numeric_limits::max_digits10, + std::is_signed::value, std::is_integral::value, literalSuffix}}}; } static Type createPointer(const Type &valueType, Qualifier qualifiers = Qualifier{0}) diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 215f035662..b0143e3e68 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -32,9 +32,10 @@ std::string getDescription(const Type::Type &type) const std::string qualifier = type.hasQualifier(Type::Qualifier::CONSTANT) ? "const " : ""; return std::visit( Utils::Overload{ - [&qualifier](const Type::Type::Numeric &numeric) + [&qualifier](const Type::Type::Value &value) { - return qualifier + numeric.name; + assert(value.numeric); + return qualifier + value.name; }, [&qualifier, &type](const Type::Type::Pointer &pointer) { @@ -55,8 +56,9 @@ bool checkPointerTypeAssignement(const Type::Type &rightType, const Type::Type & { return std::visit( Utils::Overload{ - [&rightType, &leftType](const Type::Type::Numeric &rightNumeric, const Type::Type::Numeric &leftNumeric) + [&rightType, &leftType](const Type::Type::Value &leftValue, const Type::Type::Value &rightValue) { + assert(leftValue.numeric && rightValue.numeric); return (rightType == leftType); }, [](const Type::Type::Pointer &rightPointer, const Type::Type::Pointer &leftPointer) @@ -77,8 +79,8 @@ bool checkForConstRemoval(const Type::Type &rightType, const Type::Type &leftTyp return std::visit( Utils::Overload{ - // If both are numeric, return true as const removal has been succesfully checked - [](const Type::Type::Numeric &rightNumeric, const Type::Type::Numeric &leftNumeric) + // If both are value types + [](const Type::Type::Value &rightValue, const Type::Type::Value &leftValue) { return true; }, @@ -97,14 +99,15 @@ bool checkImplicitConversion(const Type::Type &rightType, const Type::Type &left return std::visit( Utils::Overload{ // If both are numeric, return true as any numeric types can be assigned - [op](const Type::Type::Numeric &rightNumeric, const Type::Type::Numeric &leftNumeric) + [op](const Type::Type::Value &rightValue, const Type::Type::Value &leftValue) { // If operator requires it and both arguments are integers, return true + assert(leftValue.numeric && rightValue.numeric); if (op == Token::Type::PERCENT_EQUAL || op == Token::Type::SHIFT_LEFT_EQUAL || op == Token::Type::SHIFT_RIGHT_EQUAL || op == Token::Type::CARET || op == Token::Type::AMPERSAND_EQUAL || op == Token::Type::PIPE_EQUAL) { - return (leftNumeric.isIntegral && rightNumeric.isIntegral); + return (leftValue.numeric->isIntegral && rightValue.numeric->isIntegral); } // Otherwise, assignement will work for any numeric type else { @@ -135,10 +138,11 @@ bool checkImplicitConversion(const Type::Type &rightType, const Type::Type &left } }, // Otherwise, if left is pointer and right is 
numeric, - [op](const Type::Type::Numeric &rightNumeric, const Type::Type::Pointer &leftPointer) + [op](const Type::Type::Value &rightValue, const Type::Type::Pointer &leftPointer) { + assert(rightValue.numeric); if (op == Token::Type::PLUS_EQUAL || op == Token::Type::MINUS_EQUAL) { - return rightNumeric.isIntegral; + return rightValue.numeric->isIntegral; } else { return false; @@ -283,7 +287,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor Utils::Overload{ // If both operands are numeric [&leftType, &rightType, opType, this] - (const Type::Type::Numeric &rightNumeric, const Type::Type::Numeric &leftNumeric) -> std::optional + (const Type::Type::Value &rightNumeric, const Type::Type::Numeric &leftNumeric) -> std::optional { // If operator requires integer operands if (opType == Token::Type::PERCENT || opType == Token::Type::SHIFT_LEFT From b8c734033dab0c1f83649b084c7ad7de539f6301 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 24 May 2023 15:13:20 +0100 Subject: [PATCH 154/725] moved ``size`` into ``Value`` --- include/genn/genn/type.h | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 4d01fd2762..29d38ef919 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -86,7 +86,8 @@ struct Type //------------------------------------------------------------------------ struct Value { - std::string name; + size_t size; + std::string name; // **TODO** delete me std::optional numeric; //------------------------------------------------------------------------ @@ -94,12 +95,12 @@ struct Type //------------------------------------------------------------------------ bool operator == (const Value &other) const { - return (numeric == other.numeric); + return std::tie(size, numeric) == std::tie(other.size, other.numeric); } bool operator < (const Value &other) const { - return (numeric < other.numeric); + return std::tie(size, numeric) < std::tie(other.size, other.numeric); } }; @@ -165,22 +166,21 @@ struct Type } }; - Type(size_t size, Qualifier qualifiers, const Value &numeric) - : size(size), qualifiers(qualifiers), detail(numeric) + Type(Qualifier qualifiers, const Value &value) + : qualifiers(qualifiers), detail(value) {} Type(Qualifier qualifiers, const Pointer &pointer) - : size(sizeof(char*)), qualifiers(qualifiers), detail(pointer) + : qualifiers(qualifiers), detail(pointer) {} Type(const Function &function) - : size(0), qualifiers(Qualifier{0}), detail(function) + : qualifiers(Qualifier{0}), detail(function) {} - Type(const Type &other, Qualifier qualifiers) : size(other.size), qualifiers(qualifiers), detail(other.detail) + Type(const Type &other, Qualifier qualifiers) : qualifiers(qualifiers), detail(other.detail) {} //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - size_t size; Qualifier qualifiers; std::variant detail; @@ -205,14 +205,12 @@ struct Type //------------------------------------------------------------------------ bool operator == (const Type &other) const { - return (std::tie(size, qualifiers, detail) - == std::tie(other.size, other.qualifiers, other.detail)); + return std::tie(qualifiers, detail) == std::tie(other.qualifiers, other.detail); } bool operator < (const Type &other) const { - return (std::tie(size, qualifiers, detail) - < std::tie(other.size, other.qualifiers, other.detail)); + return 
std::tie(qualifiers, detail) < std::tie(other.qualifiers, other.detail); } //------------------------------------------------------------------------ @@ -221,7 +219,7 @@ struct Type template static Type createNumeric(const std::string &name, int rank, const std::string &literalSuffix = "", Qualifier qualifiers = Qualifier{0}) { - return Type{sizeof(T), qualifiers, Value{name, Numeric{rank, std::numeric_limits::min(), std::numeric_limits::max(), + return Type{qualifiers, Value{sizeof(T), name, Numeric{rank, std::numeric_limits::min(), std::numeric_limits::max(), std::numeric_limits::lowest(), std::numeric_limits::max_digits10, std::is_signed::value, std::is_integral::value, literalSuffix}}}; } From ae8dde4c987c7558277de1224975077ae00f40da Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 24 May 2023 15:28:49 +0100 Subject: [PATCH 155/725] rename Type::Type to Type::ResolvedType --- .../backends/single_threaded_cpu/backend.h | 24 ++--- .../genn/genn/code_generator/backendBase.h | 30 +++---- include/genn/genn/transpiler/expression.h | 6 +- include/genn/genn/transpiler/parser.h | 2 +- include/genn/genn/transpiler/statement.h | 6 +- include/genn/genn/transpiler/typeChecker.h | 10 +-- include/genn/genn/type.h | 78 ++++++++--------- .../backends/single_threaded_cpu/backend.cc | 24 ++--- src/genn/genn/transpiler/typeChecker.cc | 87 ++++++++++--------- src/genn/genn/type.cc | 20 ++--- 10 files changed, 146 insertions(+), 141 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 63bc6e946a..7d07254dc2 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -61,20 +61,20 @@ class BACKEND_EXPORT Backend : public BackendBase //! Generate code to define a variable in the appropriate header file virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, - const Type::Type &type, const std::string &name, VarLocation loc) const final; + const Type::ResolvedType &type, const std::string &name, VarLocation loc) const final; //! Generate code to instantiate a variable in the provided stream virtual void genVariableInstantiation(CodeStream &os, - const Type::Type &type, const std::string &name, VarLocation loc) const final; + const Type::ResolvedType &type, const std::string &name, VarLocation loc) const final; //! Generate code to allocate variable with a size known at compile-time virtual void genVariableAllocation(CodeStream &os, - const Type::Type &type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ResolvedType &type, const Type::TypeContext &typeContext, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const final; //! Generate code to allocate variable with a size known at runtime virtual void genVariableDynamicAllocation(CodeStream &os, - const Type::Type &type, const std::string &name, VarLocation loc, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; //! Generate code to free a variable @@ -82,32 +82,32 @@ class BACKEND_EXPORT Backend : public BackendBase //! 
Generate code for pushing a variable with a size known at compile-time to the 'device' virtual void genVariablePush(CodeStream &os, - const Type::Type &type, const std::string &name, VarLocation loc, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const final; //! Generate code for pulling a variable with a size known at compile-time from the 'device' virtual void genVariablePull(CodeStream &os, - const Type::Type &type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, size_t count) const final; //! Generate code for pushing a variable's value in the current timestep to the 'device' virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, - const Type::Type &type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, unsigned int batchSize) const final; //! Generate code for pulling a variable's value in the current timestep from the 'device' virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, - const Type::Type &type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, unsigned int batchSize) const final; //! Generate code for pushing a variable with a size known at runtime to the 'device' virtual void genVariableDynamicPush(CodeStream &os, - const Type::Type &type, const std::string &name, VarLocation loc, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; //! Generate code for pulling a variable with a size known at runtime from the 'device' virtual void genVariableDynamicPull(CodeStream &os, - const Type::Type &type, const std::string &name, VarLocation loc, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; //! Generate code for pushing a new pointer to a dynamic variable into the merged group structure on 'device' @@ -116,10 +116,10 @@ class BACKEND_EXPORT Backend : public BackendBase const std::string &egpName) const final; //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostTypeName(const Type::Type &type) const final; + virtual std::string getMergedGroupFieldHostTypeName(const Type::ResolvedType &type) const final; //! When generating merged structures what type to use for simulation RNGs - virtual const Type::Type &getMergedGroupSimRNGType() const final; + virtual const Type::ResolvedType &getMergedGroupSimRNGType() const final; virtual void genPopVariableInit(CodeStream &os,const Substitutions &kernelSubs, Handler handler) const final; virtual void genVariableInit(CodeStream &os, const std::string &count, const std::string &indexVarName, diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 9540ff2dd2..3d2d1a9f33 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -251,20 +251,20 @@ class GENN_EXPORT BackendBase //!
Generate code to define a variable in the appropriate header file virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, - const Type::Type &type, const std::string &name, VarLocation loc) const = 0; + const Type::ResolvedType &type, const std::string &name, VarLocation loc) const = 0; //! Generate code to instantiate a variable in the provided stream virtual void genVariableInstantiation(CodeStream &os, - const Type::Type &type, const std::string &name, VarLocation loc) const = 0; + const Type::ResolvedType &type, const std::string &name, VarLocation loc) const = 0; //! Generate code to allocate variable with a size known at compile-time virtual void genVariableAllocation(CodeStream &os, - const Type::Type &type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ResolvedType &type, const Type::TypeContext &typeContext, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const = 0; //! Generate code to allocate variable with a size known at runtime virtual void genVariableDynamicAllocation(CodeStream &os, - const Type::Type &type, const std::string &name, VarLocation loc, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0; //! Generate code to free a variable @@ -272,32 +272,32 @@ class GENN_EXPORT BackendBase //! Generate code for pushing a variable with a size known at compile-time to the 'device' virtual void genVariablePush(CodeStream &os, - const Type::Type &type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const = 0; //! Generate code for pulling a variable with a size known at compile-time from the 'device' virtual void genVariablePull(CodeStream &os, - const Type::Type &type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, size_t count) const = 0; //! Generate code for pushing a variable's value in the current timestep to the 'device' virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, - const Type::Type &type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, unsigned int batchSize) const = 0; //! Generate code for pulling a variable's value in the current timestep from the 'device' virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, - const Type::Type &type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, unsigned int batchSize) const = 0; //! Generate code for pushing a variable with a size known at runtime to the 'device' virtual void genVariableDynamicPush(CodeStream &os, - const Type::Type &type, const std::string &name, VarLocation loc, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0; //! Generate code for pulling a variable with a size known at runtime from the 'device' virtual void genVariableDynamicPull(CodeStream &os, - const Type::Type &type, const std::string &name, VarLocation loc, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0; //!
Generate code for pushing a new pointer to a dynamic variable into the merged group structure on 'device' @@ -306,10 +306,10 @@ class GENN_EXPORT BackendBase const std::string &egpName) const = 0; //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostTypeName(const Type::Type &type) const = 0; + virtual std::string getMergedGroupFieldHostTypeName(const Type::ResolvedType &type) const = 0; //! When generating merged structures what type to use for simulation RNGs - virtual const Type::Type &getMergedGroupSimRNGType() const = 0; + virtual const Type::ResolvedType &getMergedGroupSimRNGType() const = 0; virtual void genPopVariableInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const = 0; virtual void genVariableInit(CodeStream &os, const std::string &count, const std::string &indexVarName, @@ -420,7 +420,7 @@ class GENN_EXPORT BackendBase //-------------------------------------------------------------------------- //! Helper function to generate matching push and pull functions for a variable void genVariablePushPull(CodeStream &push, CodeStream &pull, - const Type::Type &type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const { genVariablePush(push, type, name, loc, autoInitialized, count); @@ -438,7 +438,7 @@ class GENN_EXPORT BackendBase //! Helper function to generate matching push and pull functions for the current state of a variable void genCurrentVariablePushPull(CodeStream &push, CodeStream &pull, const NeuronGroupInternal &ng, - const Type::Type &type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, unsigned int batchSize) const { genCurrentVariablePush(push, ng, type, name, loc, batchSize); @@ -456,7 +456,7 @@ class GENN_EXPORT BackendBase //! 
Helper function to generate matching definition, declaration, allocation and free code for a statically-sized array void genArray(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const Type::Type &type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ResolvedType &type, const Type::TypeContext &typeContext, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const { genVariableDefinition(definitions, definitionsInternal, type, name, loc); diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h index 4d2499b78b..550f453a34 100644 --- a/include/genn/genn/transpiler/expression.h +++ b/include/genn/genn/transpiler/expression.h @@ -168,16 +168,16 @@ class Call : public Acceptable class Cast : public Acceptable { public: - Cast(const Type::Type &type, ExpressionPtr expression, Token closingParen) + Cast(const Type::ResolvedType &type, ExpressionPtr expression, Token closingParen) : m_Type(type), m_Expression(std::move(expression)), m_ClosingParen(closingParen) {} - const Type::Type &getType() const{ return m_Type; } + const Type::ResolvedType &getType() const{ return m_Type; } const Base *getExpression() const { return m_Expression.get(); } const Token &getClosingParen() const { return m_ClosingParen; } private: - const Type::Type m_Type; + const Type::ResolvedType m_Type; const ExpressionPtr m_Expression; const Token m_ClosingParen; }; diff --git a/include/genn/genn/transpiler/parser.h b/include/genn/genn/transpiler/parser.h index cb0ab11a0d..7f9302ab93 100644 --- a/include/genn/genn/transpiler/parser.h +++ b/include/genn/genn/transpiler/parser.h @@ -29,6 +29,6 @@ Expression::ExpressionPtr parseExpression(const std::vector &tokens, Erro Statement::StatementList parseBlockItemList(const std::vector &tokens, ErrorHandlerBase &errorHandler); //! 
Parse type from tokens -const GeNN::Type::Type parseNumericType(const std::vector &tokens, ErrorHandlerBase &errorHandler); +const GeNN::Type::ResolvedType parseNumericType(const std::vector &tokens, ErrorHandlerBase &errorHandler); } // MiniParse::MiniParse diff --git a/include/genn/genn/transpiler/statement.h b/include/genn/genn/transpiler/statement.h index c9177d3112..f36fa1264d 100644 --- a/include/genn/genn/transpiler/statement.h +++ b/include/genn/genn/transpiler/statement.h @@ -255,15 +255,15 @@ class VarDeclaration : public Acceptable public: typedef std::vector> InitDeclaratorList; - VarDeclaration(const Type::Type &type, InitDeclaratorList initDeclaratorList) + VarDeclaration(const Type::ResolvedType &type, InitDeclaratorList initDeclaratorList) : m_Type(type), m_InitDeclaratorList(std::move(initDeclaratorList)) {} - const Type::Type &getType() const{ return m_Type; } + const Type::ResolvedType &getType() const{ return m_Type; } const InitDeclaratorList &getInitDeclaratorList() const { return m_InitDeclaratorList; } private: - const Type::Type m_Type; + const Type::ResolvedType m_Type; const std::vector m_DeclarationSpecifiers; const InitDeclaratorList m_InitDeclaratorList; }; diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index 53283d6ce8..3e23220fe7 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -32,7 +32,7 @@ class TypeCheckError : public std::runtime_error } }; -typedef std::unordered_map ResolvedTypeMap; +typedef std::unordered_map ResolvedTypeMap; //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker::EnvironmentBase @@ -43,13 +43,13 @@ class EnvironmentBase //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ - virtual void define(const Token &name, const Type::Type &type, ErrorHandlerBase &errorHandler) = 0; - virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) = 0; + virtual void define(const Token &name, const Type::ResolvedType &type, ErrorHandlerBase &errorHandler) = 0; + virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) = 0; //--------------------------------------------------------------------------- // Public API //--------------------------------------------------------------------------- - Type::Type getType(const Token &name, ErrorHandlerBase &errorHandler); + Type::ResolvedType getType(const Token &name, ErrorHandlerBase &errorHandler); }; //--------------------------------------------------------------------------- @@ -58,6 +58,6 @@ class EnvironmentBase ResolvedTypeMap typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment, const Type::TypeContext &context, ErrorHandlerBase &errorHandler); -Type::Type typeCheck(const Expression::Base *expression, EnvironmentBase &environment, +Type::ResolvedType typeCheck(const Expression::Base *expression, EnvironmentBase &environment, const Type::TypeContext &context, ErrorHandlerBase &errorHandler); } // namespace MiniParse::GeNN::Transpiler diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 29d38ef919..f24d670d1d 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -22,7 +22,7 @@ //---------------------------------------------------------------------------- // Macros 
//---------------------------------------------------------------------------- -#define CREATE_NUMERIC(TYPE, RANK, L_SUFFIX) Type::createNumeric(#TYPE, RANK, L_SUFFIX) +#define CREATE_NUMERIC(TYPE, RANK, L_SUFFIX) ResolvedType::createNumeric(#TYPE, RANK, L_SUFFIX) //---------------------------------------------------------------------------- // GeNN::Type::Qualifier @@ -45,9 +45,9 @@ inline Qualifier operator | (Qualifier a, Qualifier b) } //---------------------------------------------------------------------------- -// GeNN::Type::Type +// GeNN::Type::ResolvedType //---------------------------------------------------------------------------- -struct Type +struct ResolvedType { //------------------------------------------------------------------------ // Numeric @@ -109,12 +109,12 @@ struct Type //------------------------------------------------------------------------ struct Pointer { - Pointer(const Type &valueType) : valueType(std::make_unique(valueType)) + Pointer(const ResolvedType &valueType) : valueType(std::make_unique(valueType)) {} - Pointer(const Pointer &other) : valueType(std::make_unique(*other.valueType)) + Pointer(const Pointer &other) : valueType(std::make_unique(*other.valueType)) {} - std::unique_ptr valueType; + std::unique_ptr valueType; bool operator == (const Pointer &other) const { @@ -128,7 +128,7 @@ struct Type Pointer &operator = (const Pointer &other) { - valueType.reset(new Type(*other.valueType)); + valueType.reset(new ResolvedType(*other.valueType)); return *this; } }; @@ -138,15 +138,15 @@ struct Type //------------------------------------------------------------------------ struct Function { - Function(const Type &returnType, const std::vector &argTypes) - : returnType(std::make_unique(returnType)), argTypes(argTypes) + Function(const ResolvedType &returnType, const std::vector &argTypes) + : returnType(std::make_unique(returnType)), argTypes(argTypes) {} Function(const Function &other) - : returnType(std::make_unique(*other.returnType)), argTypes(other.argTypes) + : returnType(std::make_unique(*other.returnType)), argTypes(other.argTypes) {} - std::unique_ptr returnType; - std::vector argTypes; + std::unique_ptr returnType; + std::vector argTypes; bool operator == (const Function &other) const { @@ -160,22 +160,22 @@ struct Type Function &operator = (const Function &other) { - returnType.reset(new Type(*other.returnType)); + returnType.reset(new ResolvedType(*other.returnType)); argTypes = other.argTypes; return *this; } }; - Type(Qualifier qualifiers, const Value &value) + ResolvedType(Qualifier qualifiers, const Value &value) : qualifiers(qualifiers), detail(value) {} - Type(Qualifier qualifiers, const Pointer &pointer) + ResolvedType(Qualifier qualifiers, const Pointer &pointer) : qualifiers(qualifiers), detail(pointer) {} - Type(const Function &function) + ResolvedType(const Function &function) : qualifiers(Qualifier{0}), detail(function) {} - Type(const Type &other, Qualifier qualifiers) : qualifiers(qualifiers), detail(other.detail) + ResolvedType(const ResolvedType &other, Qualifier qualifiers) : qualifiers(qualifiers), detail(other.detail) {} //------------------------------------------------------------------------ @@ -197,18 +197,18 @@ struct Type const Function &getFunction() const{ return std::get(detail); } const Numeric &getNumeric() const{ return *getValue().numeric; } - const Type addQualifier(Qualifier qualifier) const{ return Type(*this, qualifier); } + const ResolvedType addQualifier(Qualifier qualifier) const{ return 
ResolvedType(*this, qualifier); } bool hasQualifier(Qualifier qualifier) const{ return (qualifiers & qualifier); } //------------------------------------------------------------------------ // Operators //------------------------------------------------------------------------ - bool operator == (const Type &other) const + bool operator == (const ResolvedType &other) const { return std::tie(qualifiers, detail) == std::tie(other.qualifiers, other.detail); } - bool operator < (const Type &other) const + bool operator < (const ResolvedType &other) const { return std::tie(qualifiers, detail) < std::tie(other.qualifiers, other.detail); } @@ -217,16 +217,16 @@ struct Type // Static API //------------------------------------------------------------------------ template - static Type createNumeric(const std::string &name, int rank, const std::string &literalSuffix = "", Qualifier qualifiers = Qualifier{0}) + static ResolvedType createNumeric(const std::string &name, int rank, const std::string &literalSuffix = "", Qualifier qualifiers = Qualifier{0}) { - return Type{qualifiers, Value{sizeof(T), name, Numeric{rank, std::numeric_limits::min(), std::numeric_limits::max(), - std::numeric_limits::lowest(), std::numeric_limits::max_digits10, - std::is_signed::value, std::is_integral::value, literalSuffix}}}; + return ResolvedType{qualifiers, Value{sizeof(T), name, Numeric{rank, std::numeric_limits::min(), std::numeric_limits::max(), + std::numeric_limits::lowest(), std::numeric_limits::max_digits10, + std::is_signed::value, std::is_integral::value, literalSuffix}}}; } - static Type createPointer(const Type &valueType, Qualifier qualifiers = Qualifier{0}) + static ResolvedType createPointer(const ResolvedType &valueType, Qualifier qualifiers = Qualifier{0}) { - return Type(qualifiers, Pointer{valueType}); + return ResolvedType(qualifiers, Pointer{valueType}); } }; @@ -235,29 +235,29 @@ typedef std::unordered_map TypeContext; //---------------------------------------------------------------------------- // Declare numeric types //---------------------------------------------------------------------------- -inline static const Type Bool = CREATE_NUMERIC(bool, 0, ""); -inline static const Type Int8 = CREATE_NUMERIC(int8_t, 10, ""); -inline static const Type Int16 = CREATE_NUMERIC(int16_t, 20, ""); -inline static const Type Int32 = CREATE_NUMERIC(int32_t, 30, ""); +inline static const ResolvedType Bool = CREATE_NUMERIC(bool, 0, ""); +inline static const ResolvedType Int8 = CREATE_NUMERIC(int8_t, 10, ""); +inline static const ResolvedType Int16 = CREATE_NUMERIC(int16_t, 20, ""); +inline static const ResolvedType Int32 = CREATE_NUMERIC(int32_t, 30, ""); //DECLARE_NUMERIC_TYPE(Int64, int64_t, 40); -inline static const Type Uint8 = CREATE_NUMERIC(uint8_t, 10, "u"); -inline static const Type Uint16 = CREATE_NUMERIC(uint16_t, 20, "u"); -inline static const Type Uint32 = CREATE_NUMERIC(uint32_t, 30, "u"); +inline static const ResolvedType Uint8 = CREATE_NUMERIC(uint8_t, 10, "u"); +inline static const ResolvedType Uint16 = CREATE_NUMERIC(uint16_t, 20, "u"); +inline static const ResolvedType Uint32 = CREATE_NUMERIC(uint32_t, 30, "u"); //DECLARE_NUMERIC_TYPE(Uint64, uint64_t, 40); -inline static const Type Float = CREATE_NUMERIC(float, 50, "f"); -inline static const Type Double = CREATE_NUMERIC(double, 60, ""); +inline static const ResolvedType Float = CREATE_NUMERIC(float, 50, "f"); +inline static const ResolvedType Double = CREATE_NUMERIC(double, 60, ""); //! 
Parse a numeric type -Type parseNumeric(const std::string &typeString); +ResolvedType parseNumeric(const std::string &typeString); //! Look up numeric type based on set of type specifiers -Type getNumericType(const std::set &typeSpecifiers); +ResolvedType getNumericType(const std::set &typeSpecifiers); //! Apply C type promotion rules to numeric type -Type getPromotedType(const Type &type); +ResolvedType getPromotedType(const ResolvedType &type); //! Apply C rules to get common type between numeric types a and b -Type getCommonType(const Type &a, const Type &b); +ResolvedType getCommonType(const ResolvedType &a, const ResolvedType &b); } // namespace GeNN::Type diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 23124f6108..e08cbca98c 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1293,19 +1293,19 @@ void Backend::genStepTimeFinalisePreamble(CodeStream &, const ModelSpecMerged &) } //-------------------------------------------------------------------------- void Backend::genVariableDefinition(CodeStream &definitions, CodeStream &, - const Type::Type &type, const std::string &name, VarLocation) const + const Type::ResolvedType &type, const std::string &name, VarLocation) const { definitions << "EXPORT_VAR " << type.getNumeric().name << "* " << name << ";" << std::endl; } //-------------------------------------------------------------------------- void Backend::genVariableInstantiation(CodeStream &os, - const Type::Type &type, const std::string &name, VarLocation) const + const Type::ResolvedType &type, const std::string &name, VarLocation) const { os << type.getNumeric().name << "* " << name << ";" << std::endl; } //-------------------------------------------------------------------------- void Backend::genVariableAllocation(CodeStream &os, - const Type::Type &type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ResolvedType &type, const Type::TypeContext &typeContext, const std::string &name, VarLocation, size_t count, MemAlloc &memAlloc) const { os << name << " = new " << type.getNumeric().name << "[" << count << "];" << std::endl; @@ -1314,7 +1314,7 @@ void Backend::genVariableAllocation(CodeStream &os, } //-------------------------------------------------------------------------- void Backend::genVariableDynamicAllocation(CodeStream &os, - const Type::Type &type, const std::string &name, VarLocation, + const Type::ResolvedType &type, const std::string &name, VarLocation, const std::string &countVarName, const std::string &prefix) const { if (type.isPointer()) { @@ -1330,39 +1330,39 @@ void Backend::genVariableFree(CodeStream &os, const std::string &name, VarLocati os << "delete[] " << name << ";" << std::endl; } //-------------------------------------------------------------------------- -void Backend::genVariablePush(CodeStream&, const Type::Type&, const std::string&, VarLocation, bool, size_t) const +void Backend::genVariablePush(CodeStream&, const Type::ResolvedType&, const std::string&, VarLocation, bool, size_t) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- -void Backend::genVariablePull(CodeStream&, const Type::Type&, const std::string&, VarLocation, size_t) const +void Backend::genVariablePull(CodeStream&, const Type::ResolvedType&, const std::string&, VarLocation, size_t) const { assert(!getPreferences().automaticCopy); } 
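
For the single-threaded CPU backend these overrides reduce to plain host-side allocation, and the whole push/pull family simply asserts: host and 'device' memory are one and the same, so there is nothing to copy. A rough sketch of the generated output (illustrative; the variable name ``V`` and count ``100`` are hypothetical):

    // genVariableDefinition(..., Int32, "V", ...)                 ->  EXPORT_VAR int32_t* V;
    // genVariableInstantiation(..., Int32, "V", ...)              ->  int32_t* V;
    // genVariableAllocation(..., Int32, ..., "V", ..., 100, ...)  ->  V = new int32_t[100];
    // genVariableFree(..., "V", ...)                              ->  delete[] V;
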
//-------------------------------------------------------------------------- void Backend::genCurrentVariablePush(CodeStream&, const NeuronGroupInternal&, - const Type::Type&, const std::string&, + const Type::ResolvedType&, const std::string&, VarLocation, unsigned int) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- void Backend::genCurrentVariablePull(CodeStream&, const NeuronGroupInternal&, - const Type::Type&, const std::string&, + const Type::ResolvedType&, const std::string&, VarLocation, unsigned int) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- void Backend::genVariableDynamicPush(CodeStream&, - const Type::Type&, const std::string&, + const Type::ResolvedType&, const std::string&, VarLocation, const std::string&, const std::string&) const { assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- void Backend::genVariableDynamicPull(CodeStream&, - const Type::Type&, const std::string&, + const Type::ResolvedType&, const std::string&, VarLocation, const std::string&, const std::string&) const { assert(!getPreferences().automaticCopy); @@ -1376,12 +1376,12 @@ void Backend::genMergedDynamicVariablePush(CodeStream &os, const std::string &su } //-------------------------------------------------------------------------- -std::string Backend::getMergedGroupFieldHostTypeName(const Type::Type &type) const +std::string Backend::getMergedGroupFieldHostTypeName(const Type::ResolvedType &type) const { return type.getNumeric().name; } //-------------------------------------------------------------------------- -const Type::Type &Backend::getMergedGroupSimRNGType() const +const Type::ResolvedType &Backend::getMergedGroupSimRNGType() const { assert(false); return nullptr; diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index b0143e3e68..9db6c03fa4 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -27,21 +27,21 @@ namespace Type = GeNN::Type; //--------------------------------------------------------------------------- namespace { -std::string getDescription(const Type::Type &type) +std::string getDescription(const Type::ResolvedType &type) { const std::string qualifier = type.hasQualifier(Type::Qualifier::CONSTANT) ? 
"const " : ""; return std::visit( Utils::Overload{ - [&qualifier](const Type::Type::Value &value) + [&qualifier](const Type::ResolvedType::Value &value) { assert(value.numeric); return qualifier + value.name; }, - [&qualifier, &type](const Type::Type::Pointer &pointer) + [&qualifier, &type](const Type::ResolvedType::Pointer &pointer) { return qualifier + getDescription(*pointer.valueType) + "*"; }, - [&type](const Type::Type::Function &function) + [&type](const Type::ResolvedType::Function &function) { std::string description = getDescription(*function.returnType) + "("; for (const auto &a : function.argTypes) { @@ -52,16 +52,16 @@ std::string getDescription(const Type::Type &type) type.detail); } //--------------------------------------------------------------------------- -bool checkPointerTypeAssignement(const Type::Type &rightType, const Type::Type &leftType) +bool checkPointerTypeAssignement(const Type::ResolvedType &rightType, const Type::ResolvedType &leftType) { return std::visit( Utils::Overload{ - [&rightType, &leftType](const Type::Type::Value &leftValue, const Type::Type::Value &rightValue) + [&rightType, &leftType](const Type::ResolvedType::Value &leftValue, const Type::ResolvedType::Value &rightValue) { assert(leftValue.numeric && rightValue.numeric); return (rightType == leftType); }, - [](const Type::Type::Pointer &rightPointer, const Type::Type::Pointer &leftPointer) + [](const Type::ResolvedType::Pointer &rightPointer, const Type::ResolvedType::Pointer &leftPointer) { return checkPointerTypeAssignement(*rightPointer.valueType, *leftPointer.valueType); }, @@ -70,7 +70,7 @@ bool checkPointerTypeAssignement(const Type::Type &rightType, const Type::Type & rightType.detail, leftType.detail); } //--------------------------------------------------------------------------- -bool checkForConstRemoval(const Type::Type &rightType, const Type::Type &leftType) +bool checkForConstRemoval(const Type::ResolvedType &rightType, const Type::ResolvedType &leftType) { // If const is being removed if (rightType.hasQualifier(Type::Qualifier::CONSTANT) && !leftType.hasQualifier(Type::Qualifier::CONSTANT)) { @@ -80,12 +80,12 @@ bool checkForConstRemoval(const Type::Type &rightType, const Type::Type &leftTyp return std::visit( Utils::Overload{ // If both are value types - [](const Type::Type::Value &rightValue, const Type::Type::Value &leftValue) + [](const Type::ResolvedType::Value &rightValue, const Type::ResolvedType::Value &leftValue) { return true; }, // Otherwise, if both are pointers, recurse through value type - [](const Type::Type::Pointer &rightPointer, const Type::Type::Pointer &leftPointer) + [](const Type::ResolvedType::Pointer &rightPointer, const Type::ResolvedType::Pointer &leftPointer) { return checkForConstRemoval(*rightPointer.valueType, *leftPointer.valueType); }, @@ -94,12 +94,12 @@ bool checkForConstRemoval(const Type::Type &rightType, const Type::Type &leftTyp rightType.detail, leftType.detail); } //--------------------------------------------------------------------------- -bool checkImplicitConversion(const Type::Type &rightType, const Type::Type &leftType, Token::Type op = Token::Type::EQUAL) +bool checkImplicitConversion(const Type::ResolvedType &rightType, const Type::ResolvedType &leftType, Token::Type op = Token::Type::EQUAL) { return std::visit( Utils::Overload{ // If both are numeric, return true as any numeric types can be assigned - [op](const Type::Type::Value &rightValue, const Type::Type::Value &leftValue) + [op](const Type::ResolvedType::Value &rightValue, 
const Type::ResolvedType::Value &leftValue) { // If operator requires it and both arguments are integers, return true assert(leftValue.numeric && rightValue.numeric); @@ -116,7 +116,7 @@ bool checkImplicitConversion(const Type::Type &rightType, const Type::Type &left }, // Otherwise, if both are pointers, recurse through value type [op, &leftType, &rightType] - (const Type::Type::Pointer &rightPointer, const Type::Type::Pointer &leftPointer) + (const Type::ResolvedType::Pointer &rightPointer, const Type::ResolvedType::Pointer &leftPointer) { // If operator is equals if (op == Token::Type::EQUAL) { @@ -138,7 +138,7 @@ bool checkImplicitConversion(const Type::Type &rightType, const Type::Type &left } }, // Otherwise, if left is pointer and right is numeric, - [op](const Type::Type::Value &rightValue, const Type::Type::Pointer &leftPointer) + [op](const Type::ResolvedType::Value &rightValue, const Type::ResolvedType::Pointer &leftPointer) { assert(rightValue.numeric); if (op == Token::Type::PLUS_EQUAL || op == Token::Type::MINUS_EQUAL) { @@ -167,7 +167,7 @@ class EnvironmentInternal : public EnvironmentBase //--------------------------------------------------------------------------- // EnvironmentBase virtuals //--------------------------------------------------------------------------- - virtual void define(const Token &name, const Type::Type &type, ErrorHandlerBase &errorHandler) final + virtual void define(const Token &name, const Type::ResolvedType &type, ErrorHandlerBase &errorHandler) final { if(!m_Types.try_emplace(name.lexeme, type).second) { errorHandler.error(name, "Redeclaration of variable"); @@ -175,7 +175,7 @@ class EnvironmentInternal : public EnvironmentBase } } - virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final + virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final { auto type = m_Types.find(name.lexeme); if(type == m_Types.end()) { @@ -191,7 +191,7 @@ class EnvironmentInternal : public EnvironmentBase // Members //--------------------------------------------------------------------------- EnvironmentBase &m_Enclosing; - std::unordered_map m_Types; + std::unordered_map m_Types; }; //--------------------------------------------------------------------------- @@ -287,7 +287,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor Utils::Overload{ // If both operands are numeric [&leftType, &rightType, opType, this] - (const Type::Type::Value &rightNumeric, const Type::Type::Numeric &leftNumeric) -> std::optional + (const Type::ResolvedType::Value &rightValue, const Type::ResolvedType::Value &leftValue) -> std::optional { // If operator requires integer operands if (opType == Token::Type::PERCENT || opType == Token::Type::SHIFT_LEFT @@ -295,7 +295,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor || opType == Token::Type::AMPERSAND || opType == Token::Type::PIPE) { // Check that operands are integers - if (leftNumeric.isIntegral && rightNumeric.isIntegral) { + if (leftValue.numeric->isIntegral && rightValue.numeric->isIntegral) { // If operator is a shift, promote left type if (opType == Token::Type::SHIFT_LEFT || opType == Token::Type::SHIFT_RIGHT) { return Type::getPromotedType(leftType); @@ -316,7 +316,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor }, // Otherwise, if both operands are pointers [&binary, &leftType, &rightType, opType, this] - (const Type::Type::Pointer &rightPointer, const Type::Type::Pointer &leftPointer) -> 
std::optional + (const Type::ResolvedType::Pointer &rightPointer, const Type::ResolvedType::Pointer &leftPointer) -> std::optional { // If operator is minus and pointer types match if (opType == Token::Type::MINUS && leftType == rightType) { @@ -329,11 +329,11 @@ class Visitor : public Expression::Visitor, public Statement::Visitor }, // Otherwise, if right is numeric and left is pointer [&binary, &leftType, &rightType, opType, this] - (const Type::Type::Numeric &rightNumeric, const Type::Type::Pointer &leftPointer) -> std::optional + (const Type::ResolvedType::Value &rightValue, const Type::ResolvedType::Pointer &leftPointer) -> std::optional { // If operator is valid and numeric type is integer // P + n or P - n - if ((opType == Token::Type::PLUS || opType == Token::Type::MINUS) && rightNumeric.isIntegral) { + if ((opType == Token::Type::PLUS || opType == Token::Type::MINUS) && rightValue.numeric->isIntegral) { return leftType; } else { @@ -342,10 +342,10 @@ class Visitor : public Expression::Visitor, public Statement::Visitor }, // Otherwise, if right is pointer and left is numeric [&binary, &rightType, opType, this] - (const Type::Type::Pointer &rightPointer, const Type::Type::Numeric &leftNumeric) -> std::optional + (const Type::ResolvedType::Pointer &rightPointer, const Type::ResolvedType::Value &leftValue) -> std::optional { // n + P - if (opType == Token::Type::PLUS && leftNumeric.isIntegral) { + if (opType == Token::Type::PLUS && leftValue.numeric->isIntegral) { return rightType; } else { @@ -353,7 +353,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } }, // Otherwise, operator is being applied to unsupported types - [](auto, auto) -> std::optional + [](auto, auto) -> std::optional { return std::nullopt; }}, @@ -407,12 +407,17 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const auto resultType = std::visit( Utils::Overload{ // If types are numeric, any cast goes - [&cast](const Type::Type::Numeric &rightNumeric, const Type::Type::Numeric &castNumeric) -> std::optional + [&cast](const Type::ResolvedType::Value &rightValue, const Type::ResolvedType::Value &castValue) -> std::optional { - return cast.getType(); + if (rightValue.numeric && castValue.numeric) { + return cast.getType(); + } + else { + return std::nullopt; + } }, // Otherwise, if we're trying to cast pointer to pointer - [&cast](const Type::Type::Pointer &rightPointer, const Type::Type::Pointer &castPointer) -> std::optional + [&cast](const Type::ResolvedType::Pointer &rightPointer, const Type::ResolvedType::Pointer &castPointer) -> std::optional { // Check that value type at the end matches if (checkPointerTypeAssignement(*rightPointer.valueType, *castPointer.valueType)) { @@ -423,7 +428,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } }, // Otherwise, pointers can't be cast to non-pointers and vice versa - [](auto, auto) -> std::optional + [](auto, auto) -> std::optional { return std::nullopt; }}, @@ -487,7 +492,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor setExpressionType(&literal, Type::Uint32); } else if(literal.getValue().type == Token::Type::STRING) { - setExpressionType(&literal, Type::Type::createPointer(Type::Int8, Type::Qualifier::CONSTANT)); + setExpressionType(&literal, Type::ResolvedType::createPointer(Type::Int8, Type::Qualifier::CONSTANT)); } else { assert(false); @@ -540,7 +545,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor assert(!m_CallArguments.empty()); // Loop 
through variable types - std::vector>> viableFunctions; + std::vector>> viableFunctions; for(const auto &type : varTypes) { // If function is non-variadic and number of arguments match const auto &argumentTypes = type.getFunction().argTypes; @@ -557,7 +562,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const auto argConversionRank = std::visit( Utils::Overload{ // If types are numeric, any cast goes - [c, a](const Type::Type::Numeric &cNumeric, const Type::Type::Numeric &aNumeric) -> std::optional + [c, a](const Type::ResolvedType::Value &cValue, const Type::ResolvedType::Value &aValue) -> std::optional { // If names are identical, match is exact // **TODO** we don't care about qualifiers @@ -565,8 +570,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor return 0; } // Integer promotion - else if(*a == Type::Int32 && c->getNumeric().isIntegral - && c->getNumeric().rank < Type::Int32.getNumeric().rank) + else if(*a == Type::Int32 && cValue.numeric->isIntegral + && cValue.numeric->rank < Type::Int32.getNumeric().rank) { return 1; } @@ -581,7 +586,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } }, // Otherwise, if we're trying to cast pointer to pointer - [](const Type::Type::Pointer &cPointer, const Type::Type::Pointer &aPointer) -> std::optional + [](const Type::ResolvedType::Pointer &cPointer, const Type::ResolvedType::Pointer &aPointer) -> std::optional { // Check that value type at the end matches if (checkPointerTypeAssignement(*cPointer.valueType, *aPointer.valueType)) { @@ -672,7 +677,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // Otherwise, if operator is address of, return pointer type else if (unary.getOperator().type == Token::Type::AMPERSAND) { - setExpressionType(&unary, Type::Type::createPointer(rightType)); + setExpressionType(&unary, Type::ResolvedType::createPointer(rightType)); } } else { @@ -834,13 +839,13 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- // Private methods //--------------------------------------------------------------------------- - Type::Type evaluateType(const Expression::Base *expression) + Type::ResolvedType evaluateType(const Expression::Base *expression) { expression->accept(*this); return m_ResolvedTypes.at(expression); } - void setExpressionType(const Expression::Base *expression, const Type::Type &type) + void setExpressionType(const Expression::Base *expression, const Type::ResolvedType &type) { if (!m_ResolvedTypes.emplace(expression, type).second) { throw std::runtime_error("Expression type resolved multiple times"); @@ -854,7 +859,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const Type::TypeContext &m_Context; ErrorHandlerBase &m_ErrorHandler; ResolvedTypeMap &m_ResolvedTypes; - std::stack> m_CallArguments; + std::stack> m_CallArguments; bool m_InLoop; bool m_InSwitch; }; @@ -863,7 +868,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker::EnvironmentBase //--------------------------------------------------------------------------- -Type::Type EnvironmentBase::getType(const Token &name, ErrorHandlerBase &errorHandler) +Type::ResolvedType EnvironmentBase::getType(const Token &name, ErrorHandlerBase &errorHandler) { const auto types = getTypes(name, errorHandler); if 
(types.size() == 1) { @@ -887,7 +892,7 @@ ResolvedTypeMap GeNN::Transpiler::TypeChecker::typeCheck(const Statement::Statem return expressionTypes; } //--------------------------------------------------------------------------- -Type::Type GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, EnvironmentBase &environment, +Type::ResolvedType GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, EnvironmentBase &environment, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) { ResolvedTypeMap expressionTypes; diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index c3b398537e..1890d7d4e2 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -17,7 +17,7 @@ using namespace GeNN; // Anonymous namespace namespace { -const std::map, Type::Type> numericTypeSpecifiers{ +const std::map, Type::ResolvedType> numericTypeSpecifiers{ {{"char"}, Type::Int8}, {{"int8_t"}, Type::Int8}, @@ -49,7 +49,7 @@ const std::map, Type::Type> numericTypeSpecifiers{ const std::set scalarTypeSpecifier{{"scalar"}}; //---------------------------------------------------------------------------- // Mapping of signed integer numericTypeSpecifiers to their unsigned equivalents -const std::map unsignedType{ +const std::map unsignedType{ {Type::Int8, Type::Uint8}, {Type::Int16, Type::Uint16}, {Type::Int32, Type::Uint32}}; @@ -63,7 +63,7 @@ namespace GeNN::Type //---------------------------------------------------------------------------- // Free functions //---------------------------------------------------------------------------- -Type parseNumeric(const std::string &typeString) +ResolvedType parseNumeric(const std::string &typeString) { using namespace Transpiler; @@ -82,7 +82,7 @@ Type parseNumeric(const std::string &typeString) return type; } //---------------------------------------------------------------------------- -Type getNumericType(const std::set &typeSpecifiers) +ResolvedType getNumericType(const std::set &typeSpecifiers) { // If type matches scalar type specifiers if(typeSpecifiers == scalarTypeSpecifier) { @@ -97,7 +97,7 @@ Type getNumericType(const std::set &typeSpecifiers) } } //---------------------------------------------------------------------------- -Type getPromotedType(const Type &type) +ResolvedType getPromotedType(const ResolvedType &type) { // If a small integer type is used in an expression, it is implicitly converted to int which is always signed. // This is known as the integer promotions or the integer promotion rule @@ -111,7 +111,7 @@ Type getPromotedType(const Type &type) } } //---------------------------------------------------------------------------- -Type getCommonType(const Type &a, const Type &b) +ResolvedType getCommonType(const ResolvedType &a, const ResolvedType &b) { // If either type is double, common type is double assert(a.isNumeric()); @@ -126,8 +126,8 @@ Type getCommonType(const Type &a, const Type &b) // Otherwise, must be an integer type else { // Promote both numeric types - const Type aPromoted = getPromotedType(a); - const Type bPromoted = getPromotedType(b); + const ResolvedType aPromoted = getPromotedType(a); + const ResolvedType bPromoted = getPromotedType(b); // If both promoted operands have the same type, then no further conversion is needed. if(aPromoted == bPromoted) { @@ -140,8 +140,8 @@ Type getCommonType(const Type &a, const Type &b) } // Otherwise, if signedness of promoted operands differ else { - const Type signedOp = aPromoted.getNumeric().isSigned ? 
aPromoted : bPromoted; - const Type unsignedOp = aPromoted.getNumeric().isSigned ? bPromoted : aPromoted; + const ResolvedType signedOp = aPromoted.getNumeric().isSigned ? aPromoted : bPromoted; + const ResolvedType unsignedOp = aPromoted.getNumeric().isSigned ? bPromoted : aPromoted; // Otherwise, if the operand that has unsigned integer type has rank greater or equal to the rank of the type of the other operand, // then the operand with signed integer type is converted to the type of the operand with unsigned integer type. From 35454b53a62c2ea30d48c65f3cb5b12fe054a02d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 24 May 2023 16:22:25 +0100 Subject: [PATCH 156/725] moved ``Utils::Overload`` into GeNN utils and removed transpiler utils --- include/genn/genn/gennUtils.h | 4 +++ .../genn/genn/transpiler/transpilerUtils.h | 34 ------------------- src/genn/genn/genn.vcxproj | 1 - src/genn/genn/transpiler/typeChecker.cc | 3 +- src/genn/genn/type.cc | 20 +++++++++++ 5 files changed, 26 insertions(+), 36 deletions(-) delete mode 100644 include/genn/genn/transpiler/transpilerUtils.h diff --git a/include/genn/genn/gennUtils.h b/include/genn/genn/gennUtils.h index 3aabbdb2e6..413781b8df 100644 --- a/include/genn/genn/gennUtils.h +++ b/include/genn/genn/gennUtils.h @@ -179,6 +179,10 @@ inline void updateHash(const std::unordered_map &map, boost::uuids::detail } } +// Boilerplate for overloading base std::visit +template<typename... Ts> struct Overload : Ts... { using Ts::operator()...; }; +template<typename... Ts> Overload(Ts...) -> Overload<Ts...>; // line not needed in C++20 + //! Functor for generating a hash suitable for use in std::unordered_map etc (i.e. size_t size) from a SHA1 digest struct SHA1Hash { diff --git a/include/genn/genn/transpiler/transpilerUtils.h b/include/genn/genn/transpiler/transpilerUtils.h deleted file mode 100644 index 92f79059b3..0000000000 --- a/include/genn/genn/transpiler/transpilerUtils.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -// Standard C++ includes -//#include -#include -#include - -namespace GeNN::Transpiler::Utils -{ -template<typename... Ts> struct Overload : Ts... { using Ts::operator()...; }; -template<typename... Ts> Overload(Ts...) -> Overload<Ts...>; // line not needed in C++20 - -/*template<typename T> -T toCharsThrow(std::string_view input, int base = 10) -{ - T out; - std::from_chars_result result; - if constexpr (std::is_floating_point_v<T>) { - result = std::from_chars(input.data(), input.data() + input.size(), out, - (base == 10) ?
std::chars_format::general : std::chars_format::hex); - } - else { - result = std::from_chars(input.data(), input.data() + input.size(), out, base); - } - - if(result.ec == std::errc::invalid_argument) { - throw std::invalid_argument("Unable to convert chars '" + std::string{input} + "'"); - } - else if(result.ec == std::errc::result_out_of_range) { - throw std::out_of_range("Unable to convert chars '" + std::string{input} + "'"); - } - return out; -}*/ -} // namespace GeNN::Transpiler::Utils diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj index e2d1a220f3..5815b91f05 100644 --- a/src/genn/genn/genn.vcxproj +++ b/src/genn/genn/genn.vcxproj @@ -128,7 +128,6 @@ - diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 9db6c03fa4..f66d70fd18 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -11,16 +11,17 @@ #include // GeNN includes +#include "gennUtils.h" #include "type.h" // Transpiler includes #include "transpiler/errorHandler.h" #include "transpiler/expression.h" -#include "transpiler/transpilerUtils.h" using namespace GeNN::Transpiler; using namespace GeNN::Transpiler::TypeChecker; namespace Type = GeNN::Type; +namespace Utils = GeNN::Utils; //--------------------------------------------------------------------------- // Anonymous namespace diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index 1890d7d4e2..26d7f35be2 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -3,8 +3,10 @@ // Standard C++ includes #include #include +#include // GeNN includes +#include "gennUtils.h" #include "logging.h" // Transpiler includes @@ -60,6 +62,24 @@ const std::map unsignedType{ //---------------------------------------------------------------------------- namespace GeNN::Type { + +//---------------------------------------------------------------------------- +// UnresolvedType +//---------------------------------------------------------------------------- +ResolvedType UnresolvedType::resolve(const std::unordered_map &typeContext) const +{ + return std::visit( + Utils::Overload{ + [](const Type::ResolvedType &resolved) + { + return resolved; + }, + [&typeContext](const std::string &name) + { + return typeContext.at(name); + }}, + detail); +} //---------------------------------------------------------------------------- // Free functions //---------------------------------------------------------------------------- From bcad3d71bd274ae739e061cf7a7da9f131bfbb48 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 24 May 2023 16:30:46 +0100 Subject: [PATCH 157/725] started adding infrastructure to use ``Type::UnresolvedType`` in ``Models::Var`` * New ``updateHash`` functions for --- include/genn/genn/gennUtils.h | 33 ++++++++++-- include/genn/genn/models.h | 30 ++++++----- include/genn/genn/type.h | 98 +++++++++++++++++++++++++++++++---- src/genn/genn/models.cc | 35 ++----------- src/genn/genn/type.cc | 41 ++++++++++++++- 5 files changed, 179 insertions(+), 58 deletions(-) diff --git a/include/genn/genn/gennUtils.h b/include/genn/genn/gennUtils.h index 413781b8df..60f95e9b69 100644 --- a/include/genn/genn/gennUtils.h +++ b/include/genn/genn/gennUtils.h @@ -4,10 +4,12 @@ #include #include #include +#include #include #include #include #include +#include #include // Standard C includes @@ -125,6 +127,10 @@ inline std::string writePreciseString(T value, int maxDigits10 = std::numeric_li return s.str(); } +//! 
Boilerplate for overloading base std::visit
+template<class... Ts> struct Overload : Ts... { using Ts::operator()...; };
+template<class... Ts> Overload(Ts...) -> Overload<Ts...>; // line not needed in C++20
+
 //! Hash arithmetic types and enums
 template<typename T, typename std::enable_if<std::is_arithmetic<T>::value || std::is_enum<T>::value>::type* = nullptr>
 inline void updateHash(const T& value, boost::uuids::detail::sha1& hash)
@@ -179,6 +185,30 @@ inline void updateHash(const std::unordered_map<K, V> &map, boost::uuids::detail
 }
 }
-// Boilerplate for overloading base std::visit
-template<class... Ts> struct Overload : Ts... { using Ts::operator()...; };
-template<class... Ts> Overload(Ts...) -> Overload<Ts...>; // line not needed in C++20
+//! Hash optional types which can, themselves, be hashed
+template<typename T>
+inline void updateHash(const std::optional<T> &optional, boost::uuids::detail::sha1 &hash)
+{
+    updateHash(optional.has_value(), hash);
+    if (optional) {
+        updateHash(optional.value(), hash);
+    }
+}
+
+//! Hash variants of types which can, themselves, be hashed
+template<typename... T>
+inline void updateHash(const std::variant<T...> &variant, boost::uuids::detail::sha1 &hash)
+{
+    updateHash(variant.index(), hash);
+    std::visit(
+        Utils::Overload{
+            [&hash](const auto &v)
+            {
+                updateHash(v, hash);
+            }},
+        variant);
+}
+
 //! Functor for generating a hash suitable for use in std::unordered_map etc (i.e. size_t size) from a SHA1 digests
 struct SHA1Hash
 {
diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h
index 3c4bc42398..643b6dddb0 100644
--- a/include/genn/genn/models.h
+++ b/include/genn/genn/models.h
@@ -51,35 +51,39 @@ class GENN_EXPORT Base : public Snippet::Base
 if not specified, this results in a -Wmissing-field-initializers warning on GCC and Clang*/
 struct Var
 {
-    Var(const std::string &n, const Type::NumericBase *t, VarAccess a) : name(n), type(t), access(a)
+    Var(const std::string &n, const Type::ResolvedType &t, VarAccess a) : name(n), type(t), access(a)
     {}
-    Var(const std::string &n, const Type::NumericBase *t) : Var(n, t, VarAccess::READ_WRITE)
+    Var(const std::string &n, const Type::ResolvedType &t) : Var(n, t, VarAccess::READ_WRITE)
+    {}
+    Var(const std::string &n, const std::string &t, VarAccess a) : name(n), type(t), access(a)
     {}
-    Var(const std::string &n, const std::string &t, VarAccess a);
     Var(const std::string &n, const std::string &t) : Var(n, t, VarAccess::READ_WRITE)
     {}
-    bool operator == (const Var &other) const;
+    bool operator == (const Var &other) const
+    {
+        return (std::tie(name, type, access) == std::tie(other.name, other.type, other.access));
+    }
-    const std::string name;
-    const Type::NumericBase *type;
-    const VarAccess access;
+    std::string name;
+    Type::UnresolvedType type;
+    VarAccess access;
 };
 struct VarRef
 {
-    VarRef(const std::string &n, const Type::NumericBase *t, VarAccessMode a) : name(n), type(t), access(a)
+    VarRef(const std::string &n, const Type::ResolvedType &t, VarAccessMode a) : name(n), type(t), access(a)
     {}
-    VarRef(const std::string &n, const Type::NumericBase *t) : VarRef(n, t, VarAccessMode::READ_WRITE)
+    VarRef(const std::string &n, const Type::ResolvedType &t) : VarRef(n, t, VarAccessMode::READ_WRITE)
     {}
     VarRef(const std::string &n, const std::string &t, VarAccessMode a);
     VarRef(const std::string &n, const std::string &t);
     bool operator == (const VarRef &other) const;
-    const std::string name;
-    const Type::NumericBase *type;
-    const VarAccessMode access;
+    std::string name;
+    Type::UnresolvedType type;
+    VarAccessMode access;
 };
 //----------------------------------------------------------------------------
@@ -91,7 +95,7 @@ class GENN_EXPORT Base : public Snippet::Base
//---------------------------------------------------------------------------- // Declared virtuals //------------------------------------------------------------------------ - //! Gets names and types (as strings) of model variables + //! Gets model variables virtual VarVec getVars() const{ return {}; } //------------------------------------------------------------------------ diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index f24d670d1d..96a93af085 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -18,6 +18,7 @@ // GeNN includes #include "gennExport.h" +#include "gennUtils.h" //---------------------------------------------------------------------------- // Macros @@ -74,6 +75,12 @@ struct ResolvedType == std::tie(other.rank, other.min, other.max, other.lowest, other.maxDigits10, other.isSigned, other.isIntegral)); } + bool operator != (const Numeric &other) const + { + return (std::tie(rank, min, max, lowest, maxDigits10, isSigned, isIntegral) + != std::tie(other.rank, other.min, other.max, other.lowest, other.maxDigits10, other.isSigned, other.isIntegral)); + } + bool operator < (const Numeric &other) const { return (std::tie(rank, min, max, lowest, maxDigits10, isSigned, isIntegral) @@ -95,12 +102,17 @@ struct ResolvedType //------------------------------------------------------------------------ bool operator == (const Value &other) const { - return std::tie(size, numeric) == std::tie(other.size, other.numeric); + return (std::tie(size, numeric) == std::tie(other.size, other.numeric)); + } + + bool operator != (const Value &other) const + { + return (std::tie(size, numeric) != std::tie(other.size, other.numeric)); } bool operator < (const Value &other) const { - return std::tie(size, numeric) < std::tie(other.size, other.numeric); + return (std::tie(size, numeric) < std::tie(other.size, other.numeric)); } }; @@ -121,6 +133,11 @@ struct ResolvedType return (*valueType == *other.valueType); } + bool operator != (const Pointer &other) const + { + return (*valueType != *other.valueType); + } + bool operator < (const Pointer &other) const { return (*valueType < *other.valueType); @@ -150,12 +167,17 @@ struct ResolvedType bool operator == (const Function &other) const { - return std::tie(*returnType, argTypes) == std::tie(*other.returnType, other.argTypes); + return (std::tie(*returnType, argTypes) == std::tie(*other.returnType, other.argTypes)); + } + + bool operator != (const Function &other) const + { + return (std::tie(*returnType, argTypes) != std::tie(*other.returnType, other.argTypes)); } bool operator < (const Function &other) const { - return std::tie(*returnType, argTypes) < std::tie(*other.returnType, other.argTypes); + return (std::tie(*returnType, argTypes) < std::tie(*other.returnType, other.argTypes)); } Function &operator = (const Function &other) @@ -205,12 +227,17 @@ struct ResolvedType //------------------------------------------------------------------------ bool operator == (const ResolvedType &other) const { - return std::tie(qualifiers, detail) == std::tie(other.qualifiers, other.detail); + return (std::tie(qualifiers, detail) == std::tie(other.qualifiers, other.detail)); + } + + bool operator != (const ResolvedType &other) const + { + return (std::tie(qualifiers, detail) != std::tie(other.qualifiers, other.detail)); } bool operator < (const ResolvedType &other) const { - return std::tie(qualifiers, detail) < std::tie(other.qualifiers, other.detail); + return (std::tie(qualifiers, detail) < std::tie(other.qualifiers, 
other.detail));
    }
 //------------------------------------------------------------------------
@@ -230,6 +257,47 @@ struct ResolvedType
     }
 };
+//----------------------------------------------------------------------------
+// UnresolvedType
+//----------------------------------------------------------------------------
+struct UnresolvedType
+{
+    UnresolvedType(const ResolvedType &type)
+    :   detail(type)
+    {}
+    UnresolvedType(const std::string &name)
+    :   detail(name)
+    {}
+
+    //------------------------------------------------------------------------
+    // Members
+    //------------------------------------------------------------------------
+    std::variant<ResolvedType, std::string> detail;
+
+    //------------------------------------------------------------------------
+    // Public API
+    //------------------------------------------------------------------------
+    ResolvedType resolve(const std::unordered_map<std::string, ResolvedType> &typeContext) const;
+
+    //------------------------------------------------------------------------
+    // Operators
+    //------------------------------------------------------------------------
+    bool operator == (const UnresolvedType &other) const
+    {
+        return (detail == other.detail);
+    }
+
+    bool operator != (const UnresolvedType &other) const
+    {
+        return (detail != other.detail);
+    }
+
+    bool operator < (const UnresolvedType &other) const
+    {
+        return (detail < other.detail);
+    }
+};
+
 typedef std::unordered_map<std::string, ResolvedType> TypeContext;
 //----------------------------------------------------------------------------
@@ -248,16 +316,24 @@ inline static const ResolvedType Float = CREATE_NUMERIC(float, 50, "f");
 inline static const ResolvedType Double = CREATE_NUMERIC(double, 60, "");
 //! Parse a numeric type
-ResolvedType parseNumeric(const std::string &typeString);
+GENN_EXPORT ResolvedType parseNumeric(const std::string &typeString);
 //! Look up numeric type based on set of type specifiers
-ResolvedType getNumericType(const std::set<std::string> &typeSpecifiers);
+GENN_EXPORT ResolvedType getNumericType(const std::set<std::string> &typeSpecifiers);
 //! Apply C type promotion rules to numeric type
-ResolvedType getPromotedType(const ResolvedType &type);
+GENN_EXPORT ResolvedType getPromotedType(const ResolvedType &type);
 //!
Apply C rules to get common type between numeric types a and b -ResolvedType getCommonType(const ResolvedType &a, const ResolvedType &b); - +GENN_EXPORT ResolvedType getCommonType(const ResolvedType &a, const ResolvedType &b); +//---------------------------------------------------------------------------- +// updateHash overrides +//---------------------------------------------------------------------------- +GENN_EXPORT void updateHash(const ResolvedType::Numeric &v, boost::uuids::detail::sha1 &hash); +GENN_EXPORT void updateHash(const ResolvedType::Value &v, boost::uuids::detail::sha1 &hash); +GENN_EXPORT void updateHash(const ResolvedType::Pointer &v, boost::uuids::detail::sha1 &hash); +GENN_EXPORT void updateHash(const ResolvedType::Function &v, boost::uuids::detail::sha1 &hash); +GENN_EXPORT void updateHash(const ResolvedType &v, boost::uuids::detail::sha1 &hash); +GENN_EXPORT void updateHash(const UnresolvedType &v, boost::uuids::detail::sha1 &hash); } // namespace GeNN::Type diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 6b3b3b7115..ad9b50f7b5 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -6,38 +6,12 @@ #include "currentSourceInternal.h" #include "neuronGroupInternal.h" #include "synapseGroupInternal.h" -#include "type.h" //---------------------------------------------------------------------------- -// GeNN::Models::Base::Var +// GeNN::Models::Base //---------------------------------------------------------------------------- namespace GeNN::Models { -Base::Var::Var(const std::string &n, const std::string &t, VarAccess a) : name(n), type(Type::parseNumeric(t)), access(a) -{} -//---------------------------------------------------------------------------- -bool Base::Var::operator == (const Var &other) const -{ - return (std::make_tuple(name, type->getName(), access) == std::make_tuple(other.name, other.type->getName(), other.access)); -} - -//---------------------------------------------------------------------------- -// GeNN::Models::Base::VarRef -//---------------------------------------------------------------------------- -Base::VarRef::VarRef(const std::string &n, const std::string &t, VarAccessMode a) : name(n), type(Type::parseNumeric(t)), access(a) -{} -//---------------------------------------------------------------------------- -Base::VarRef::VarRef(const std::string &n, const std::string &t) : VarRef(n, t, VarAccessMode::READ_WRITE) -{} -//---------------------------------------------------------------------------- -bool Base::VarRef::operator == (const VarRef &other) const -{ - return (std::make_tuple(name, type->getName(), access) == std::make_tuple(other.name, other.type->getName(), other.access)); -} - -//---------------------------------------------------------------------------- -// GeNN::Models::Base -//---------------------------------------------------------------------------- void Base::updateHash(boost::uuids::detail::sha1 &hash) const { // Superclass @@ -195,7 +169,8 @@ WUVarReference::WUVarReference(SynapseGroup *sg, const std::string &varName, } // Check types - if(getVar().type->getName() != getTransposeVar().type->getName()) { + // **NOTE** this is a bit over-conservative as, at this point, types are not resolved so "scalar" cannot be compared with "float" + if(getVar().type != getTransposeVar().type) { throw std::runtime_error("Transpose updates can only be performed on variables with the same type"); } @@ -236,14 +211,14 @@ SynapseGroup *WUVarReference::getTransposeSynapseGroup() const void updateHash(const 
Base::Var &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.name, hash); - Utils::updateHash(v.type->getName(), hash); + Type::updateHash(v.type, hash); Utils::updateHash(v.access, hash); } //---------------------------------------------------------------------------- void updateHash(const Base::VarRef &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.name, hash); - Utils::updateHash(v.type->getName(), hash); + Type::updateHash(v.type, hash); Utils::updateHash(v.access, hash); } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index 26d7f35be2..9edde1855e 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -6,7 +6,6 @@ #include // GeNN includes -#include "gennUtils.h" #include "logging.h" // Transpiler includes @@ -182,4 +181,44 @@ ResolvedType getCommonType(const ResolvedType &a, const ResolvedType &b) } } } +//---------------------------------------------------------------------------- +void updateHash(const ResolvedType::Numeric &v, boost::uuids::detail::sha1 &hash) +{ + Utils::updateHash(v.rank, hash); + Utils::updateHash(v.min, hash); + Utils::updateHash(v.max, hash); + Utils::updateHash(v.lowest, hash); + Utils::updateHash(v.maxDigits10, hash); + Utils::updateHash(v.isSigned, hash); + Utils::updateHash(v.isIntegral, hash); + Utils::updateHash(v.literalSuffix, hash); +} +//---------------------------------------------------------------------------- +void updateHash(const ResolvedType::Value &v, boost::uuids::detail::sha1 &hash) +{ + Utils::updateHash(v.size, hash); + Utils::updateHash(v.numeric, hash); +} +//---------------------------------------------------------------------------- +void updateHash(const ResolvedType::Pointer &v, boost::uuids::detail::sha1 &hash) +{ + updateHash(*v.valueType, hash); +} +//---------------------------------------------------------------------------- +void updateHash(const ResolvedType::Function &v, boost::uuids::detail::sha1 &hash) +{ + updateHash(*v.returnType, hash); + Utils::updateHash(v.argTypes, hash); +} +//---------------------------------------------------------------------------- +void updateHash(const ResolvedType &v, boost::uuids::detail::sha1 &hash) +{ + Utils::updateHash(v.qualifiers, hash); + Utils::updateHash(v.detail, hash); +} +//---------------------------------------------------------------------------- +void updateHash(const UnresolvedType &v, boost::uuids::detail::sha1 &hash) +{ + Utils::updateHash(v.detail, hash); +} } // namespace GeNN::Type From d9f9261edcad5e6e38787e26f94ebac3389d6b3d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 24 May 2023 17:13:14 +0100 Subject: [PATCH 158/725] started hooking up value types throughout API --- .../genn/genn/code_generator/backendBase.h | 19 +++----- .../genn/genn/code_generator/codeGenUtils.h | 2 +- include/genn/genn/modelSpec.h | 13 ++--- include/genn/genn/snippet.h | 48 ++++++++++--------- include/genn/genn/synapseGroup.h | 2 +- src/genn/genn/code_generator/backendBase.cc | 12 +++-- src/genn/genn/code_generator/codeGenUtils.cc | 12 ++--- .../genn/code_generator/generateRunner.cc | 4 +- src/genn/genn/modelSpec.cc | 32 +++++++++++-- src/genn/genn/snippet.cc | 36 ++------------ src/genn/genn/synapseGroup.cc | 12 ++--- 11 files changed, 92 insertions(+), 100 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 3d2d1a9f33..43c843bd12 100644 --- 
a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -494,15 +494,10 @@ class GENN_EXPORT BackendBase //! Simple struct to hold reduction targets struct ReductionTarget { - ReductionTarget(const std::string &n, const Type::NumericBase *t, VarAccessMode a, const std::string &i) - : name(n), type(t), access(a), index(i) - { - } - - const std::string name; - const Type::NumericBase *type; - const VarAccessMode access; - const std::string index; + std::string name; + Type::ResolvedType type; + VarAccessMode access; + std::string index; }; //-------------------------------------------------------------------------- @@ -517,11 +512,11 @@ class GENN_EXPORT BackendBase void genCustomConnectivityUpdateIndexCalculation(CodeStream &os, const CustomConnectivityUpdateGroupMerged &cu) const; //! Get the initial value to start reduction operations from - std::string getReductionInitialValue(VarAccessMode access, const Type::NumericBase *type, const Type::TypeContext &context) const; + std::string getReductionInitialValue(VarAccessMode access, const Type::ResolvedType &type) const; //! Generate a reduction operation to reduce value into reduction - std::string getReductionOperation(const std::string &reduction, const std::string &value, VarAccessMode access, - const Type::NumericBase *type, const Type::TypeContext &context) const; + std::string getReductionOperation(const std::string &reduction, const std::string &value, + VarAccessMode access, const Type::ResolvedType &type) const; //! Helper function to generate initialisation code for any reduction operations carried out be custom update group. diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index ac07a50899..8363bd2a92 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -67,7 +67,7 @@ inline size_t padSize(size_t size, size_t blockSize) return ceilDivide(size, blockSize) * blockSize; } -GENN_EXPORT void genTypeRange(CodeStream &os, const Type::NumericBase *precision, const Type::TypeContext &typeContext, const std::string &prefix); +GENN_EXPORT void genTypeRange(CodeStream &os, const Type::ResolvedType &type, const std::string &prefix); //-------------------------------------------------------------------------- /*! \brief This function implements a parser that converts any floating point constant in a code snippet to a floating point constant with an explicit precision (by appending "f" or removing it). diff --git a/include/genn/genn/modelSpec.h b/include/genn/genn/modelSpec.h index 6d398c378d..5c7e502af3 100644 --- a/include/genn/genn/modelSpec.h +++ b/include/genn/genn/modelSpec.h @@ -25,6 +25,7 @@ Part of the code generation and generated code sections. // Standard C++ includes #include +#include #include #include #include @@ -231,10 +232,10 @@ class GENN_EXPORT ModelSpec void setName(const std::string &name){ m_Name = name; } //! Set numerical precision for floating point - void setPrecision(const Type::NumericBase *precision){ m_Precision = precision; } + void setPrecision(const Type::ResolvedType &precision); //! Set numerical precision for time - void setTimePrecision(const Type::NumericBase *timePrecision){ m_TimePrecision = timePrecision; } + void setTimePrecision(const Type::ResolvedType &timePrecision); //! 
Set the integration step size of the model
     void setDT(double dt){ m_DT = dt; }
@@ -278,10 +279,10 @@
     const std::string &getName() const{ return m_Name; }
     //! Gets the floating point numerical precision
-    const Type::NumericBase *getPrecision() const{ return m_Precision; }
+    const Type::ResolvedType &getPrecision() const{ return m_Precision; }
     //! Gets the floating point numerical precision used to represent time
-    const Type::NumericBase *getTimePrecision() const{ return m_TimePrecision ? m_TimePrecision : m_Precision; }
+    const Type::ResolvedType &getTimePrecision() const{ return m_TimePrecision ? m_TimePrecision.value() : m_Precision; }
     //! Gets the model integration step size
     double getDT() const { return m_DT; }
@@ -728,10 +729,10 @@ class GENN_EXPORT ModelSpec
     std::string m_Name;
     //! Type of floating point variables (float, double, ...; default: float)
-    const Type::NumericBase *m_Precision;
+    Type::ResolvedType m_Precision;
     //! Type of floating point variables used to store time
-    const Type::NumericBase *m_TimePrecision;
+    std::optional<Type::ResolvedType> m_TimePrecision;
     //! The integration time step of the model
     double m_DT;
diff --git a/include/genn/genn/snippet.h b/include/genn/genn/snippet.h
index 072f11c572..e9a6e05a7b 100644
--- a/include/genn/genn/snippet.h
+++ b/include/genn/genn/snippet.h
@@ -13,15 +13,7 @@
 // GeNN includes
 #include "gennExport.h"
 #include "gennUtils.h"
-
-// Forward declarations
-namespace GeNN
-{
-namespace Type
-{
-class NumericBase;
-}
-}
+#include "type.h"
 //----------------------------------------------------------------------------
 // Macros
@@ -65,31 +57,41 @@ class GENN_EXPORT Base
     //! An extra global parameter has a name and a type
     struct EGP
     {
-        EGP(const std::string &n, const Type::NumericBase *t);
-        EGP(const std::string &n, const std::string &t);
+        EGP(const std::string &n, const Type::ResolvedType &t) : name(n), type(t)
+        {}
+        EGP(const std::string &n, const std::string &t) : name(n), type(t)
+        {}
-        bool operator == (const EGP &other) const;
+        bool operator == (const EGP &other) const
+        {
+            return (std::tie(name, type) == std::tie(other.name, other.type));
+        }
-        const std::string name;
-        const Type::NumericBase *type;
+        std::string name;
+        Type::UnresolvedType type;
     };
     //! Additional input variables, row state variables and other things have a name, a type and an initial value
     struct ParamVal
     {
-        ParamVal(const std::string &n, const Type::NumericBase *t, const std::string &v) : name(n), type(t), value(v)
+        ParamVal(const std::string &n, const Type::ResolvedType &t, const std::string &v) : name(n), type(t), value(v)
+        {}
+        ParamVal(const std::string &n, const Type::ResolvedType &t, double v) : ParamVal(n, t, Utils::writePreciseString(v))
         {}
-        ParamVal(const std::string &n, const Type::NumericBase *t, double v) : ParamVal(n, t, Utils::writePreciseString(v))
+        ParamVal(const std::string &n, const std::string &t, const std::string &v) : name(n), type(t), value(v)
         {}
-        ParamVal(const std::string &n, const std::string &t, const std::string &v);
         ParamVal(const std::string &n, const std::string &t, double v) : ParamVal(n, t, Utils::writePreciseString(v))
         {}
-        bool operator == (const ParamVal &other) const;
+        bool operator == (const ParamVal &other) const
+        {
+            // **THINK** why isn't value included?
+            return (std::tie(name, type) == std::tie(other.name, other.type));
+        }
-        const std::string name;
-        const Type::NumericBase *type;
-        const std::string value;
+        std::string name;
+        Type::UnresolvedType type;
+        std::string value;
     };
     //!
A derived parameter has a name and a function for obtaining its value
@@ -100,8 +102,8 @@ class GENN_EXPORT Base
         return (name == other.name);
     }
-        const std::string name;
-        const std::function<double(const std::unordered_map<std::string, double>&, double)> func;
+        std::string name;
+        std::function<double(const std::unordered_map<std::string, double>&, double)> func;
     };
diff --git a/include/genn/genn/synapseGroup.h b/include/genn/genn/synapseGroup.h
index 48e089f302..88e5d13941 100644
--- a/include/genn/genn/synapseGroup.h
+++ b/include/genn/genn/synapseGroup.h
@@ -314,7 +314,7 @@ class GENN_EXPORT SynapseGroup
     bool isWUPostModelFused() const { return m_FusedWUPostVarSuffix != getName(); }
     //! Get the type to use for sparse connectivity indices for synapse group
-    const Type::NumericBase *getSparseIndType() const;
+    const Type::ResolvedType &getSparseIndType() const;
     //! Generate hash of weight update component of this synapse group
     /*! NOTE: this can only be called after model is finalized */
diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc
index 99e7f8802e..53876ae4ae 100644
--- a/src/genn/genn/code_generator/backendBase.cc
+++ b/src/genn/genn/code_generator/backendBase.cc
@@ -190,15 +190,16 @@ void BackendBase::genCustomConnectivityUpdateIndexCalculation(CodeStream &os, co
     }
 }
 //----------------------------------------------------------------------------
-std::string BackendBase::getReductionInitialValue(VarAccessMode access, const Type::NumericBase *type, const Type::TypeContext &context) const
+std::string BackendBase::getReductionInitialValue(VarAccessMode access, const Type::ResolvedType &type) const
 {
     // If reduction is a sum, initialise to zero
+    assert(type.isNumeric());
     if(access & VarAccessModeAttribute::SUM) {
         return "0";
     }
     // Otherwise, reduction is a maximum operation, return lowest value for type
     else if(access & VarAccessModeAttribute::MAX) {
-        return Utils::writePreciseString(type->getLowest(context));
+        return Utils::writePreciseString(type.getNumeric().lowest);
     }
     else {
         assert(false);
@@ -206,17 +207,18 @@ std::string BackendBase::getReductionInitialValue(VarAccessMode access, const Ty
     }
 }
 //----------------------------------------------------------------------------
-std::string BackendBase::getReductionOperation(const std::string &reduction, const std::string &value, VarAccessMode access,
-                                               const Type::NumericBase *type, const Type::TypeContext &context) const
+std::string BackendBase::getReductionOperation(const std::string &reduction, const std::string &value,
+                                               VarAccessMode access, const Type::ResolvedType &type) const
 {
     // If operation is sum, add output of custom update to sum
+    assert(type.isNumeric());
     if(access & VarAccessModeAttribute::SUM) {
         return reduction + " += " + value;
     }
     // Otherwise, if it's max
     else if(access & VarAccessModeAttribute::MAX) {
         // If type is integral, generate max call
-        if(type->isIntegral(context)) {
+        if(type.getNumeric().isIntegral) {
             return reduction + " = " + "max(" + reduction + ", " + value + ")";
         }
diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc
index 4040ad149e..327e5e7725 100644
--- a/src/genn/genn/code_generator/codeGenUtils.cc
+++ b/src/genn/genn/code_generator/codeGenUtils.cc
@@ -326,16 +326,12 @@ void functionSubstitute(std::string &code, const std::string &funcName,
     }
 }
 //----------------------------------------------------------------------------
-void genTypeRange(CodeStream &os,
const Type::ResolvedType &type, const std::string &prefix) { + const auto &numeric = type.getNumeric(); + os << "#define " << prefix << "_MIN " << Utils::writePreciseString(numeric.min, numeric.maxDigits10) << numeric.literalSuffix << std::endl << std::endl; - os << "#define " << prefix << "_MIN "; - Utils::writePreciseString(os, precision->getMin(typeContext), precision->getMaxDigits10(typeContext)); - os << precision->getLiteralSuffix(typeContext) << std::endl << std::endl; - - os << "#define " << prefix << "_MAX "; - Utils::writePreciseString(os, precision->getMax(typeContext), precision->getMaxDigits10(typeContext)); - os << precision->getLiteralSuffix(typeContext) << std::endl; + os << "#define " << prefix << "_MAX " << Utils::writePreciseString(numeric.max, numeric.maxDigits10) << numeric.literalSuffix << std::endl; } //---------------------------------------------------------------------------- std::string ensureFtype(const std::string &oldcode, const std::string &type) diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index c8c7f73d39..eb6d67bbe4 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -256,7 +256,7 @@ void genStatePushPull(CodeStream &definitionsFunc, CodeStream &runnerPushFunc, C void genVariable(const ModelSpecMerged &modelMerged, const BackendBase &backend, CodeStream &definitionsVar, CodeStream &definitionsFunc, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, CodeStream &push, CodeStream &pull, - const Type::ValueBase *type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count, MemAlloc &mem, std::vector &statePushPullFunction) { @@ -274,7 +274,7 @@ void genVariable(const ModelSpecMerged &modelMerged, const BackendBase &backend, //------------------------------------------------------------------------- void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase &backend, CodeStream &definitionsVar, CodeStream &definitionsFunc, CodeStream &definitionsInternalVar, CodeStream &runner, - CodeStream &extraGlobalParam, const Type::NumericBase *type, const std::string &name, bool apiRequired, VarLocation loc) + CodeStream &extraGlobalParam, const Type::ResolvedType &type, const std::string &name, bool apiRequired, VarLocation loc) { // Generate variables backend.genVariableDefinition(definitionsVar, definitionsInternalVar, type, name, loc); diff --git a/src/genn/genn/modelSpec.cc b/src/genn/genn/modelSpec.cc index 5aa032abc3..fd81fe5b33 100644 --- a/src/genn/genn/modelSpec.cc +++ b/src/genn/genn/modelSpec.cc @@ -36,7 +36,7 @@ namespace GeNN { ModelSpec::ModelSpec() -: m_Precision(Type::Float::getInstance()), m_TimePrecision(nullptr), m_DT(0.5), m_TimingEnabled(false), m_Seed(0), +: m_Precision(Type::Float), m_TimePrecision(std::nullopt), m_DT(0.5), m_TimingEnabled(false), m_Seed(0), m_DefaultVarLocation(VarLocation::HOST_DEVICE), m_DefaultExtraGlobalParamLocation(VarLocation::HOST_DEVICE), m_DefaultSparseConnectivityLocation(VarLocation::HOST_DEVICE), m_DefaultNarrowSparseIndEnabled(false), m_ShouldFusePostsynapticModels(false), m_ShouldFusePrePostWeightUpdateModels(false), m_BatchSize(1) @@ -47,6 +47,32 @@ ModelSpec::~ModelSpec() { } // --------------------------------------------------------------------------- +void ModelSpec::setPrecision(const Type::ResolvedType &precision) +{ + if 
(!precision.isNumeric()) {
+        throw std::runtime_error("Only numeric types can be used for precision");
+    }
+    else {
+        if (precision.getNumeric().isIntegral) {
+            throw std::runtime_error("Only floating point types can be used for precision");
+        }
+        m_Precision = precision;
+    }
+}
+// ---------------------------------------------------------------------------
+void ModelSpec::setTimePrecision(const Type::ResolvedType &timePrecision)
+{
+    if (!timePrecision.isNumeric()) {
+        throw std::runtime_error("Only numeric types can be used for time precision");
+    }
+    else {
+        if (timePrecision.getNumeric().isIntegral) {
+            throw std::runtime_error("Only floating point types can be used for time precision");
+        }
+        m_TimePrecision = timePrecision;
+    }
+}
+// ---------------------------------------------------------------------------
 unsigned int ModelSpec::getNumNeurons() const
 {
     // Return sum of local neuron group sizes
@@ -349,8 +375,8 @@ boost::uuids::detail::sha1::digest_type ModelSpec::getHashDigest() const
     boost::uuids::detail::sha1 hash;
     Utils::updateHash(getName(), hash);
-    Utils::updateHash(getPrecision()->getName(), hash);
-    Utils::updateHash(getTimePrecision()->getName(), hash);
+    Type::updateHash(getPrecision(), hash);
+    Type::updateHash(getTimePrecision(), hash);
     Utils::updateHash(getDT(), hash);
     Utils::updateHash(isTimingEnabled(), hash);
     Utils::updateHash(getBatchSize(), hash);
diff --git a/src/genn/genn/snippet.cc b/src/genn/genn/snippet.cc
index 95d4e5e6f7..0e511318f2 100644
--- a/src/genn/genn/snippet.cc
+++ b/src/genn/genn/snippet.cc
@@ -2,42 +2,12 @@
 // GeNN includes
 #include "logging.h"
-#include "type.h"
 //----------------------------------------------------------------------------
-// GeNN::Snippet::Base::EGP
+// GeNN::Snippet::Base
 //----------------------------------------------------------------------------
 namespace GeNN::Snippet
 {
-Base::EGP::EGP(const std::string &n, const std::string &t)
-:   name(n), type(Type::parseNumeric((t.back() == '*') ?
t.substr(0, t.length() - 1) : t)) -{ - // If type ends in a *, give warning as this is legacy syntax - if(t.back() == '*') { - LOGW_GENN << "Extra global parameters are now always arrays so * at end of type is no longer necessary"; - } -} -//---------------------------------------------------------------------------- -bool Base::EGP::operator == (const EGP &other) const -{ - return ((name == other.name) && (type->getName() == other.type->getName())); -} - -//---------------------------------------------------------------------------- -// GeNN::Snippet::Base::ParamVal -//---------------------------------------------------------------------------- -Base::ParamVal::ParamVal(const std::string &n, const std::string &t, const std::string &v) : name(n), type(Type::parseNumeric(t)), value(v) -{ -} -//---------------------------------------------------------------------------- -bool Base::ParamVal::operator == (const ParamVal &other) const -{ - return ((name == other.name) && (type->getName() == other.type->getName()) && (value == other.value)); -} - -//---------------------------------------------------------------------------- -// GeNN::Snippet::Base -//---------------------------------------------------------------------------- void Base::updateHash(boost::uuids::detail::sha1 &hash) const { Utils::updateHash(getParamNames(), hash); @@ -72,13 +42,13 @@ void Base::validate(const std::unordered_map ¶mValues, void updateHash(const Base::EGP &e, boost::uuids::detail::sha1 &hash) { Utils::updateHash(e.name, hash); - Utils::updateHash(e.type->getName(), hash); + Type::updateHash(e.type, hash); } //---------------------------------------------------------------------------- void updateHash(const Base::ParamVal &p, boost::uuids::detail::sha1 &hash) { Utils::updateHash(p.name, hash); - Utils::updateHash(p.type->getName(), hash); + Type::updateHash(p.type, hash); Utils::updateHash(p.value, hash); } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc index 35ed646961..ceb713c65a 100644 --- a/src/genn/genn/synapseGroup.cc +++ b/src/genn/genn/synapseGroup.cc @@ -702,23 +702,23 @@ bool SynapseGroup::canPreOutputBeFused() const return true; } //---------------------------------------------------------------------------- -const Type::NumericBase *SynapseGroup::getSparseIndType() const +const Type::ResolvedType &SynapseGroup::getSparseIndType() const { // If narrow sparse inds are enabled if(m_NarrowSparseIndEnabled) { // If number of target neurons can be represented using a uint8, use this type const unsigned int numTrgNeurons = getTrgNeuronGroup()->getNumNeurons(); - if(numTrgNeurons <= Type::Uint8::getInstance()->getMax({})) { - return Type::Uint8::getInstance();; + if(numTrgNeurons <= Type::Uint8.getNumeric().max) { + return Type::Uint8; } // Otherwise, if they can be represented as a uint16, use this type - else if(numTrgNeurons <= Type::Uint16::getInstance()->getMax({})) { - return Type::Uint16::getInstance(); + else if(numTrgNeurons <= Type::Uint16.getNumeric().max) { + return Type::Uint16; } } // Otherwise, use 32-bit int - return Type::Uint32::getInstance(); + return Type::Uint32; } //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type SynapseGroup::getWUHashDigest() const From 91da0fe54e5336fc1b25b26428e273e071ef4c91 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 25 May 2023 11:39:21 +0100 Subject: [PATCH 159/725] more updating 
of runner, group merged etc --- .../backends/single_threaded_cpu/backend.h | 6 +- .../genn/genn/code_generator/backendBase.h | 51 ++-- .../genn/genn/code_generator/environment.h | 7 +- .../genn/genn/code_generator/groupMerged.h | 62 ++--- .../genn/code_generator/modelSpecMerged.h | 16 +- .../genn/code_generator/supportCodeMerged.h | 2 +- include/genn/genn/transpiler/prettyPrinter.h | 2 +- include/genn/genn/type.h | 34 ++- src/genn/backends/cuda/backend.cc | 49 +--- .../backends/single_threaded_cpu/backend.cc | 23 +- src/genn/genn/code_generator/backendBase.cc | 4 +- src/genn/genn/code_generator/environment.cc | 11 +- .../genn/code_generator/generateRunner.cc | 239 +++++++++--------- src/genn/genn/code_generator/groupMerged.cc | 145 ++++++----- src/genn/genn/transpiler/typeChecker.cc | 59 ++--- src/genn/genn/type.cc | 51 +++- 16 files changed, 368 insertions(+), 393 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 7d07254dc2..6f960c250a 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -69,7 +69,7 @@ class BACKEND_EXPORT Backend : public BackendBase //! Generate code to allocate variable with a size known at compile-time virtual void genVariableAllocation(CodeStream &os, - const Type::ResolvedType &type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const final; //! Generate code to allocate variable with a size known at runtime @@ -130,10 +130,10 @@ class BACKEND_EXPORT Backend : public BackendBase virtual void genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUpdateInitGroupMerged &cu, const Substitutions &kernelSubs, Handler handler) const final; virtual void genGlobalDeviceRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, - CodeStream &allocations, CodeStream &free, const Type::TypeContext &typeContext, MemAlloc &memAlloc) const final; + CodeStream &allocations, CodeStream &free, MemAlloc &memAlloc) const final; virtual void genPopulationRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const Type::TypeContext &typeContext, const std::string &name, size_t count, MemAlloc &memAlloc) const final; + const std::string &name, size_t count, MemAlloc &memAlloc) const final; virtual void genTimer(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, CodeStream &stepTimeFinalise, const std::string &name, bool updateInStepTime) const final; diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 43c843bd12..fe3a239749 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -259,7 +259,7 @@ class GENN_EXPORT BackendBase //! Generate code to allocate variable with a size known at compile-time virtual void genVariableAllocation(CodeStream &os, - const Type::ResolvedType &type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const = 0; //! Generate code to allocate variable with a size known at runtime @@ -322,12 +322,12 @@ class GENN_EXPORT BackendBase //! 
Generate a single RNG instance /*! On single-threaded platforms this can be a standard RNG like M.T. but, on parallel platforms, it is likely to be a counter-based RNG */ virtual void genGlobalDeviceRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, - CodeStream &allocations, CodeStream &free, const Type::TypeContext &typeContext, MemAlloc &memAlloc) const = 0; + CodeStream &allocations, CodeStream &free, MemAlloc &memAlloc) const = 0; //! Generate an RNG with a state per population member virtual void genPopulationRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const Type::TypeContext &typeContext, const std::string &name, size_t count, MemAlloc &memAlloc) const = 0; + const std::string &name, size_t count, MemAlloc &memAlloc) const = 0; virtual void genTimer(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, CodeStream &stepTimeFinalise, const std::string &name, bool updateInStepTime) const = 0; @@ -427,15 +427,6 @@ class GENN_EXPORT BackendBase genVariablePull(pull, type, name, loc, count); } - //! Templated version of helper function to generate matching push and pull functions for - //! a variable when type is known at compile time - template - void genVariablePushPull(CodeStream &push, CodeStream &pull, - const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const - { - genVariablePushPull(push, pull, T::getInstance(), name, loc, autoInitialized, count); - } - //! Helper function to generate matching push and pull functions for the current state of a variable void genCurrentVariablePushPull(CodeStream &push, CodeStream &pull, const NeuronGroupInternal &ng, const Type::ResolvedType &type, const std::string &name, @@ -445,35 +436,18 @@ class GENN_EXPORT BackendBase genCurrentVariablePull(pull, ng, type, name, loc, batchSize); } - //! Templated version of gelper function to generate matching push and pull functions - //! for the current state of variable when type is known at compile time - template - void genCurrentVariablePushPull(CodeStream &push, CodeStream &pull, const NeuronGroupInternal &ng, - const std::string &name, VarLocation loc, unsigned int batchSize) const - { - genCurrentVariablePushPull(push, pull, ng, T::getInstance(), name, loc, batchSize); - } //! Helper function to generate matching definition, declaration, allocation and free code for a statically-sized array void genArray(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const Type::ResolvedType &type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const { genVariableDefinition(definitions, definitionsInternal, type, name, loc); genVariableInstantiation(runner, type, name, loc); genVariableFree(free, name, loc); - genVariableAllocation(allocations, type, typeContext, name, loc, count, memAlloc); + genVariableAllocation(allocations, type, name, loc, count, memAlloc); } - //! Templated version of helper function to generate matching definition, declaration, - //! 
allocation and free code for a statically-sized array when type is known at compile-time - template - void genArray(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const Type::TypeContext &typeContext, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const - { - genArray(definitions, definitionsInternal, runner, allocations, free, T::getInstance(), typeContext, name, loc, count, memAlloc); - - } //! Get the prefix for accessing the address of 'scalar' variables std::string getScalarAddressPrefix() const { @@ -482,6 +456,9 @@ class GENN_EXPORT BackendBase bool areSixtyFourBitSynapseIndicesRequired(const SynapseGroupMergedBase &sg) const; + //! Get backend-specific pointer size in bytes + size_t getPointerBytes() const{ return m_PointerBytes; } + const PreferencesBase &getPreferences() const { return m_Preferences; } template @@ -503,6 +480,11 @@ class GENN_EXPORT BackendBase //-------------------------------------------------------------------------- // Protected API //-------------------------------------------------------------------------- + void setPointerBytes(size_t pointerBytes) + { + m_PointerBytes = pointerBytes; + } + void genNeuronIndexCalculation(CodeStream &os, const NeuronUpdateGroupMerged &ng, unsigned int batchSize) const; void genSynapseIndexCalculation(CodeStream &os, const SynapseGroupMergedBase &sg, unsigned int batchSize) const; @@ -540,7 +522,7 @@ class GENN_EXPORT BackendBase for (const auto &v : cm->getVars()) { // If variable is a reduction target, define variable initialised to correct initial value for reduction if (v.access & VarAccessModeAttribute::REDUCE) { - os << v.type->getName() << " lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(v.access), v.type, cg.getTypeContext()) << ";" << std::endl; + os << v.type->getName() << " lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(v.access), v.type) << ";" << std::endl; reductionTargets.emplace_back(v.name, v.type, getVarAccessMode(v.access), cg.getVarIndex(getVarAccessDuplication(v.access), idx)); } @@ -552,7 +534,7 @@ class GENN_EXPORT BackendBase // If variable reference is a reduction target, define variable initialised to correct initial value for reduction if (modelVarRef.access & VarAccessModeAttribute::REDUCE) { - os << modelVarRef.type->getName() << " lr" << modelVarRef.name << " = " << getReductionInitialValue(modelVarRef.access, modelVarRef.type, cg.getTypeContext()) << ";" << std::endl; + os << modelVarRef.type->getName() << " lr" << modelVarRef.name << " = " << getReductionInitialValue(modelVarRef.access, modelVarRef.type) << ";" << std::endl; reductionTargets.emplace_back(modelVarRef.name, modelVarRef.type, modelVarRef.access, getVarRefIndexFn(varRef, idx)); } @@ -564,6 +546,9 @@ class GENN_EXPORT BackendBase //-------------------------------------------------------------------------- // Members //-------------------------------------------------------------------------- + //! How large is a device pointer? E.g. on some AMD devices this != sizeof(char*) + size_t m_PointerBytes; + //! 
Preferences
    const PreferencesBase &m_Preferences;
};
diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h
index e2e047dc6a..121ed4d1e8 100644
--- a/include/genn/genn/code_generator/environment.h
+++ b/include/genn/genn/code_generator/environment.h
@@ -14,7 +14,6 @@
 // GeNN transpiler includes
 #include "transpiler/prettyPrinter.h"
-#include "transpiler/transpilerUtils.h"
 //----------------------------------------------------------------------------
 // GeNN::CodeGenerator::EnvironmentExternal
 //----------------------------------------------------------------------------
@@ -51,7 +50,7 @@ class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase
     CodeStream &getContextStream() const;
-    std::string getContextName(const std::string &name, const Type::Base *type) const;
+    std::string getContextName(const std::string &name, const Type::ResolvedType &type) const;
 private:
     //------------------------------------------------------------------------
@@ -89,7 +88,7 @@ class EnvironmentSubstitute : public EnvironmentExternal
     //------------------------------------------------------------------------
     // PrettyPrinter::EnvironmentBase virtuals
     //------------------------------------------------------------------------
-    virtual std::string getName(const std::string &name, const Type::Base *type = nullptr) final;
+    virtual std::string getName(const std::string &name, const Type::ResolvedType &type) final;
     virtual CodeStream &getStream() final
     {
@@ -231,7 +230,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal
     //------------------------------------------------------------------------
     // PrettyPrinter::EnvironmentBase virtuals
     //------------------------------------------------------------------------
-    virtual std::string getName(const std::string &name, const Type::Base *type = nullptr) final
+    virtual std::string getName(const std::string &name, const Type::ResolvedType &type) final
     {
         // If variable with this name isn't found, try and get name from context
         auto var = m_VariablesReferenced.find(name);
diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h
index 5e06c3dc03..740fc11474 100644
--- a/include/genn/genn/code_generator/groupMerged.h
+++ b/include/genn/genn/code_generator/groupMerged.h
@@ -64,7 +64,7 @@ class GroupMerged
     typedef G GroupInternal;
     typedef std::function<std::string(const G &, size_t)> GetFieldValueFunc;
     typedef std::function<double(const G &, size_t)> GetFieldDoubleValueFunc;
-    typedef std::tuple<const Type::Base*, std::string, GetFieldValueFunc, GroupMergedFieldType> Field;
+    typedef std::tuple<Type::ResolvedType, std::string, GetFieldValueFunc, GroupMergedFieldType> Field;
     GroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector<std::reference_wrapper<const GroupInternal>> groups)
     :   m_Index(index), m_TypeContext(typeContext), m_Groups(std::move(groups))
@@ -99,10 +99,11 @@
         // Make a copy of fields and sort so largest come first. This should mean that due
         // to structure packing rules, significant memory is saved and estimate is more precise
         auto sortedFields = m_Fields;
+        const size_t pointerBytes = backend.getPointerBytes();
         std::sort(sortedFields.begin(), sortedFields.end(),
-                  [&backend, this](const Field &a, const Field &b)
+                  [pointerBytes](const Field &a, const Field &b)
                  {
-                      return (std::get<0>(a)->getSizeBytes(m_TypeContext) > std::get<0>(b)->getSizeBytes(m_TypeContext));
+                      return (std::get<0>(a).getSize(pointerBytes) > std::get<0>(b).getSize(pointerBytes));
                  });
         return sortedFields;
@@ -119,20 +120,20 @@
         for(const auto &f : sortedFields) {
             // If field is a pointer and not marked as being a host field
             // (in which case the backend should leave its type alone!)
- const auto *type = std::get<0>(f); - if(dynamic_cast(type) && !(std::get<3>(f) & GroupMergedFieldType::HOST)) { + const auto &type = std::get<0>(f); + if(type.isPointer() && !(std::get<3>(f) & GroupMergedFieldType::HOST)) { // If we are generating a host structure, allow the backend to override the type if(host) { os << backend.getMergedGroupFieldHostTypeName(type); } // Otherwise, allow the backend to add a prefix else { - os << backend.getPointerPrefix() << type->getName(); + os << backend.getPointerPrefix() << type.getName(); } } // Otherwise, leave the type alone else { - os << type->getName(); + os << type.getName(); } os << " " << std::get<1>(f) << ";" << std::endl; } @@ -205,8 +206,8 @@ class GroupMerged //------------------------------------------------------------------------ // Protected methods //------------------------------------------------------------------------ - const Type::NumericBase *getScalarType() const{ return dynamic_cast(m_TypeContext.at("scalar")); } - const Type::NumericBase *getTimeType() const{ return dynamic_cast(m_TypeContext.at("timepoint")); } + const Type::ResolvedType &getScalarType() const{ return m_TypeContext.at("scalar"); } + const Type::ResolvedType &getTimeType() const{ return m_TypeContext.at("timepoint"); } //! Helper to test whether parameter is referenced in vector of codestrings bool isParamReferenced(const std::vector &codeStrings, const std::string ¶mName) const @@ -248,38 +249,27 @@ class GroupMerged }); } - void addField(const Type::Base *type, const std::string &name, GetFieldValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) + void addField(const Type::ResolvedType &type, const std::string &name, GetFieldValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) { // Add field to data structure m_Fields.emplace_back(type, name, getFieldValue, fieldType); } - template - void addField(const std::string &name, GetFieldValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) - { - // Add field to data structure - m_Fields.emplace_back(T::getInstance(), name, getFieldValue, fieldType); - } - void addScalarField(const std::string &name, GetFieldDoubleValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) { addField(getScalarType(), name, [getFieldValue, this](const G &g, size_t i) { - return Utils::writePreciseString(getFieldValue(g, i), getScalarType()->getMaxDigits10(m_TypeContext)) + getScalarType()->getLiteralSuffix(m_TypeContext); + return Utils::writePreciseString(getFieldValue(g, i), getScalarType().getNumeric().maxDigits10) + getScalarType().getNumeric().literalSuffix; }, fieldType); } - void addPointerField(const Type::Base *type, const std::string &name, const std::string &prefix) - { - addField(type->getPointerType(), name, [prefix](const G &g, size_t) { return prefix + g.getName(); }); - } - - template - void addPointerField(const std::string &name, const std::string &prefix) + void addPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix) { - addField(T::getInstance()->getPointerType(), name, [prefix](const G &g, size_t) { return prefix + g.getName(); }); + assert(type.isValue()); + addField(type.createPointer(), name, + [prefix](const G &g, size_t) { return prefix + g.getName(); }); } @@ -296,7 +286,7 @@ class GroupMerged { // Loop through variables for(const auto &v : varReferences) { - addField(v.type->getPointerType(), v.name, + 
addField(v.type.resolve(getTypeContext()).createPointer(), v.name, [getVarRefFn, arrayPrefix, v](const G &g, size_t) { const auto varRef = getVarRefFn(g).at(v.name); @@ -308,7 +298,7 @@ class GroupMerged void addEGPs(const Snippet::Base::EGPVec &egps, const std::string &arrayPrefix, const std::string &varName = "") { for(const auto &e : egps) { - addField(e.type->getPointerType(), e.name + varName, + addField(e.type.resolve(getTypeContext()).createPointer(), e.name + varName, [e, arrayPrefix, varName](const G &g, size_t) { return arrayPrefix + e.name + varName + g.getName(); }, GroupMergedFieldType::DYNAMIC); } @@ -479,7 +469,7 @@ class GroupMerged // Loop through fields again to generate any EGP pushing functions that are require for(const auto &f : sortedFields) { // If this field is a dynamic pointer - if((std::get<3>(f) & GroupMergedFieldType::DYNAMIC) && dynamic_cast(std::get<0>(f))) { + if((std::get<3>(f) & GroupMergedFieldType::DYNAMIC) && std::get<0>(f).isPointer()) { definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << getIndex() << std::get<1>(f) << "ToDevice(unsigned int idx, "; definitionsInternalFunc << backend.getMergedGroupFieldHostTypeName(std::get<0>(f)) << " value);" << std::endl; } @@ -832,7 +822,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMergedgetPointerType(), e.name + prefix + std::to_string(childIndex), + addField(e.type.resolve(getTypeContext()).createPointer(), e.name + prefix + std::to_string(childIndex), [getEGPSuffixFn, childIndex, e, arrayPrefix](const NeuronGroupInternal&, size_t groupIndex) { return arrayPrefix + e.name + getEGPSuffixFn(groupIndex, childIndex); @@ -934,10 +924,10 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMergedgetName(), fieldName, hostGroup) - < std::make_tuple(other.mergedGroupIndex, other.type->getName(), other.fieldName, other.hostGroup)); + return (std::make_tuple(mergedGroupIndex, type, fieldName, hostGroup) + < std::make_tuple(other.mergedGroupIndex, other.type, other.fieldName, other.hostGroup)); } }; @@ -63,7 +63,7 @@ class GENN_EXPORT ModelSpecMerged //! Immutable structure for tracking where an extra global variable ends up after merging struct MergedEGP : public EGPField { - MergedEGP(size_t m, size_t g, const Type::Pointer *t, const std::string &f, bool h) + MergedEGP(size_t m, size_t g, const Type::ResolvedType &t, const std::string &f, bool h) : EGPField(m, t, f, h), groupIndex(g) {} const size_t groupIndex; diff --git a/include/genn/genn/code_generator/supportCodeMerged.h b/include/genn/genn/code_generator/supportCodeMerged.h index 4afcedbee2..f3dd100e2c 100644 --- a/include/genn/genn/code_generator/supportCodeMerged.h +++ b/include/genn/genn/code_generator/supportCodeMerged.h @@ -44,7 +44,7 @@ class SupportCodeMerged } //! Generate support code - void gen(CodeStream &os, const Type::NumericBase *scalarType, const bool supportsNamespace = true) const + void gen(CodeStream &os, const Type::ResolvedType &scalarType, const bool supportsNamespace = true) const { // Loop through support code for(const auto &s : m_SupportCode) { diff --git a/include/genn/genn/transpiler/prettyPrinter.h b/include/genn/genn/transpiler/prettyPrinter.h index c9e3949337..3ad5e9cd49 100644 --- a/include/genn/genn/transpiler/prettyPrinter.h +++ b/include/genn/genn/transpiler/prettyPrinter.h @@ -31,7 +31,7 @@ class EnvironmentBase virtual std::string define(const std::string &name) = 0; //! 
Get the name to use in code for the variable named by token
-    virtual std::string getName(const std::string &name, const Type::Base *type = nullptr) = 0;
+    virtual std::string getName(const std::string &name, const Type::ResolvedType &type) = 0;
     //! Get stream to write code within this environment to
     virtual CodeGenerator::CodeStream &getStream() = 0;
diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h
index 96a93af085..8b2052a599 100644
--- a/include/genn/genn/type.h
+++ b/include/genn/genn/type.h
@@ -93,8 +93,8 @@ struct ResolvedType
     //------------------------------------------------------------------------
     struct Value
     {
+        std::string name;
         size_t size;
-        std::string name;   // **TODO** delete me
         std::optional<Numeric> numeric;
     //------------------------------------------------------------------------
@@ -188,14 +188,14 @@ struct ResolvedType
         }
     };
-    ResolvedType(Qualifier qualifiers, const Value &value)
+    ResolvedType(const Value &value, Qualifier qualifiers = Qualifier{0})
     :   qualifiers(qualifiers), detail(value)
     {}
-    ResolvedType(Qualifier qualifiers, const Pointer &pointer)
+    ResolvedType(const Pointer &pointer, Qualifier qualifiers = Qualifier{0})
     :   qualifiers(qualifiers), detail(pointer)
     {}
     ResolvedType(const Function &function)
-    :   qualifiers(Qualifier{0}), detail(function)
+    :   qualifiers(Qualifier{0}), detail(function)
     {}
     ResolvedType(const ResolvedType &other, Qualifier qualifiers) : qualifiers(qualifiers), detail(other.detail)
     {}
@@ -214,6 +214,7 @@ struct ResolvedType
     bool isPointer() const{ return std::holds_alternative<Pointer>(detail); }
     bool isFunction() const{ return std::holds_alternative<Function>(detail); }
     bool isNumeric() const{ return isValue() && getValue().numeric; }
+
     const Value &getValue() const{ return std::get<Value>(detail); }
     const Pointer &getPointer() const{ return std::get<Pointer>(detail); }
     const Function &getFunction() const{ return std::get<Function>(detail); }
@@ -222,6 +223,14 @@ struct ResolvedType
     const ResolvedType addQualifier(Qualifier qualifier) const{ return ResolvedType(*this, qualifier); }
     bool hasQualifier(Qualifier qualifier) const{ return (qualifiers & qualifier); }
+    std::string getName() const;
+    size_t getSize(size_t pointerBytes) const;
+
+    ResolvedType createPointer(Qualifier qualifiers = Qualifier{0}) const
+    {
+        return ResolvedType(Pointer{*this}, qualifiers);
+    }
+
     //------------------------------------------------------------------------
     // Operators
     //------------------------------------------------------------------------
@@ -246,17 +255,21 @@ struct ResolvedType
     template<typename T>
     static ResolvedType createNumeric(const std::string &name, int rank, const std::string &literalSuffix = "", Qualifier qualifiers = Qualifier{0})
     {
-        return ResolvedType{qualifiers, Value{sizeof(T), name, Numeric{rank, std::numeric_limits<T>::min(), std::numeric_limits<T>::max(),
-                                                                      std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max_digits10,
-                                                                      std::is_signed<T>::value, std::is_integral<T>::value, literalSuffix}}};
+        return ResolvedType{Value{name, sizeof(T), Numeric{rank, std::numeric_limits<T>::min(), std::numeric_limits<T>::max(),
+                                                           std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max_digits10,
+                                                           std::is_signed<T>::value, std::is_integral<T>::value, literalSuffix}},
+                            qualifiers};
     }
-    static ResolvedType createPointer(const ResolvedType &valueType, Qualifier qualifiers = Qualifier{0})
+    template<typename T>
+    static ResolvedType createValue(const std::string &name, Qualifier qualifiers = Qualifier{0})
     {
-        return ResolvedType(qualifiers, Pointer{valueType});
+        return ResolvedType{Value{name, sizeof(T), std::nullopt},
+typedef std::unordered_map<std::string, ResolvedType> TypeContext;
+
 //----------------------------------------------------------------------------
 // UnresolvedType
 //----------------------------------------------------------------------------
@@ -277,7 +290,7 @@ struct UnresolvedType
    //------------------------------------------------------------------------
    // Public API
    //------------------------------------------------------------------------
-    ResolvedType resolve(const std::unordered_map<std::string, ResolvedType> &typeContext) const;
+    ResolvedType resolve(const TypeContext &typeContext) const;

    //------------------------------------------------------------------------
    // Operators
    //------------------------------------------------------------------------
@@ -298,7 +311,6 @@ struct UnresolvedType
    }
 };

-typedef std::unordered_map<std::string, ResolvedType> TypeContext;

 //----------------------------------------------------------------------------
 // Declare numeric types
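Taken together, the type.h changes replace the old singleton class hierarchy with a small value type. A sketch of how the new API composes (the rank values and the float mapping for "scalar" are illustrative, not taken from this patch, and the string constructor for UnresolvedType is assumed):

    using namespace GeNN::Type;

    // Numeric types are now ordinary values: copyable and comparable, no getInstance()
    const ResolvedType Int32 = ResolvedType::createNumeric<int32_t>("int32_t", 20);
    const ResolvedType Int32Ptr = Int32.createPointer();    // int32_t*, isPointer() == true

    // "scalar" is resolved against the model's TypeContext at generation time
    const TypeContext context{{"scalar", ResolvedType::createNumeric<float>("float", 50, "f")}};
    const ResolvedType scalar = UnresolvedType("scalar").resolve(context);    // float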
diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc
index 0f51347ab8..031e9a5ac2 100644
--- a/src/genn/backends/cuda/backend.cc
+++ b/src/genn/backends/cuda/backend.cc
@@ -4,6 +4,9 @@
 #include
 #include

+// CUDA includes
+#include <curand_kernel.h>
+
 // GeNN includes
 #include "gennUtils.h"
 #include "logging.h"
@@ -47,50 +50,8 @@ const std::vector cudaDoublePrecisionFunctions
 //--------------------------------------------------------------------------
 // CUDADeviceType
 //--------------------------------------------------------------------------
-//! Tag class used to mark types which are only usable on device
-struct CUDADeviceType
-{
-};
-
-//--------------------------------------------------------------------------
-// CURandState
-//--------------------------------------------------------------------------
-class CURandState : public Type::ValueBase, public CUDADeviceType
-{
-public:
-    DECLARE_TYPE(CURandState);
-
-    CURandState(Type::Qualifier qualifiers = Type::Qualifier{0}) : ValueBase(qualifiers){}
-
-    //------------------------------------------------------------------------
-    // Base overloads
-    //------------------------------------------------------------------------
-    virtual std::string getName() const final{ return "curandState"; }
-    virtual std::string getResolvedName(const Type::TypeContext&) const final{ return "curandState"; }
-    virtual Base *getQualifiedType(Type::Qualifier qualifiers) const { return new CURandState(qualifiers); }
-    virtual size_t getSizeBytes(const Type::TypeContext&) const final{ return 44; }
-};
-IMPLEMENT_TYPE(CURandState);
-
-//--------------------------------------------------------------------------
-// CURandStatePhilox43210
-//--------------------------------------------------------------------------
-class CURandStatePhilox43210 : public Type::ValueBase, public CUDADeviceType
-{
-public:
-    DECLARE_TYPE(CURandStatePhilox43210);
-
-    CURandStatePhilox43210(Type::Qualifier qualifiers = Type::Qualifier{0}) : ValueBase(qualifiers){}
-
-    //------------------------------------------------------------------------
-    // Base overloads
-    //------------------------------------------------------------------------
-    virtual std::string getName() const final{ return "curandStatePhilox4_32_10_t"; }
-    virtual std::string getResolvedName(const Type::TypeContext&) const final{ return "curandStatePhilox4_32_10_t"; }
-    virtual Base *getQualifiedType(Type::Qualifier qualifiers) const { return new CURandStatePhilox43210(qualifiers); }
-    virtual size_t getSizeBytes(const Type::TypeContext&) const final{ return 64; }
-};
-IMPLEMENT_TYPE(CURandStatePhilox43210);
+const Type::ResolvedType CURandState = Type::ResolvedType::createValue<curandState>("curandState");
+const Type::ResolvedType CURandStatePhilox43210 = Type::ResolvedType::createValue<curandStatePhilox4_32_10_t>("curandStatePhilox4_32_10_t");

 //--------------------------------------------------------------------------
 // Timer
 //--------------------------------------------------------------------------
diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc
index e08cbca98c..447fd6a0f7 100644
--- a/src/genn/backends/single_threaded_cpu/backend.cc
+++ b/src/genn/backends/single_threaded_cpu/backend.cc
@@ -579,7 +579,7 @@ void Backend::genCustomUpdate(CodeStream &os_, const ModelSpecMerged &modelMerge
            // Loop through reduction targets and generate reduction
            for (const auto &r : reductionTargets) {
-                env.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl;
+                env.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl;
            }
        }
@@ -1295,22 +1295,21 @@ void Backend::genStepTimeFinalisePreamble(CodeStream &, const ModelSpecMerged &)
 void Backend::genVariableDefinition(CodeStream &definitions, CodeStream &, const Type::ResolvedType &type, const std::string &name, VarLocation) const
 {
-    definitions << "EXPORT_VAR " << type.getNumeric().name << "* " << name << ";" << std::endl;
+    definitions << "EXPORT_VAR " << type.getValue().name << "* " << name << ";" << std::endl;
 }
 //--------------------------------------------------------------------------
 void Backend::genVariableInstantiation(CodeStream &os, const Type::ResolvedType &type, const std::string &name, VarLocation) const
 {
-    os << type.getNumeric().name << "* " << name << ";" << std::endl;
+    os << type.getValue().name << "* " << name << ";" << std::endl;
 }
 //--------------------------------------------------------------------------
-void Backend::genVariableAllocation(CodeStream &os,
-                                    const Type::ResolvedType &type, const Type::TypeContext &typeContext, const std::string &name,
+void Backend::genVariableAllocation(CodeStream &os, const Type::ResolvedType &type, const std::string &name,
                                    VarLocation, size_t count, MemAlloc &memAlloc) const
 {
-    os << name << " = new " << type.getNumeric().name << "[" << count << "];" << std::endl;
+    os << name << " = new " << type.getValue().name << "[" << count << "];" << std::endl;

-    memAlloc += MemAlloc::host(count * type.size);
+    memAlloc += MemAlloc::host(count * type.getValue().size);
 }
 //--------------------------------------------------------------------------
 void Backend::genVariableDynamicAllocation(CodeStream &os,
@@ -1318,10 +1317,10 @@ void Backend::genVariableDynamicAllocation(CodeStream &os,
                                           const std::string &countVarName, const std::string &prefix) const
 {
    if (type.isPointer()) {
-        os << "*" << prefix << name << " = new " << type.getPointer().valueType->getNumeric().name << "[" << countVarName << "];" << std::endl;
+        os << "*" << prefix << name << " = new " << type.getPointer().valueType->getValue().name << "[" << countVarName << "];" << std::endl;
    }
    else {
-        os << prefix << name << " = new " << type.getNumeric().name << "[" << countVarName << "];" << std::endl;
+        os << prefix << name << " = new " << type.getValue().name << "[" << countVarName << "];" << std::endl;
    }
 }
 //--------------------------------------------------------------------------
@@ -1378,7 +1377,7 @@ void Backend::genMergedDynamicVariablePush(CodeStream &os, const std::string &su
 //--------------------------------------------------------------------------
 std::string Backend::getMergedGroupFieldHostTypeName(const Type::ResolvedType &type) const
 {
-    return type.getNumeric().name;
+    return type.getName();
 }
 //--------------------------------------------------------------------------
 const Type::ResolvedType &Backend::getMergedGroupSimRNGType() const
@@ -1445,13 +1444,13 @@ void Backend::genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUp
    //genKernelIteration(os, cu, cu.getArchetype().getSynapseGroup()->getKernelSize().size(), kernelSubs, handler);
 }
 //--------------------------------------------------------------------------
-void Backend::genGlobalDeviceRNG(CodeStream&, CodeStream&, CodeStream&, CodeStream&, CodeStream&, const Type::TypeContext&, MemAlloc&) const
+void Backend::genGlobalDeviceRNG(CodeStream&, CodeStream&, CodeStream&, CodeStream&, CodeStream&, MemAlloc&) const
 {
    assert(false);
 }
 //--------------------------------------------------------------------------
 void Backend::genPopulationRNG(CodeStream&, CodeStream&, CodeStream&, CodeStream&, CodeStream&,
-                               const Type::TypeContext&, const std::string&, size_t, MemAlloc&) const
+                               const std::string&, size_t, MemAlloc&) const
 {
 }
 //--------------------------------------------------------------------------
diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc
index 53876ae4ae..3d4561c862 100644
--- a/src/genn/genn/code_generator/backendBase.cc
+++ b/src/genn/genn/code_generator/backendBase.cc
@@ -16,7 +16,7 @@ namespace GeNN::CodeGenerator
 {
 BackendBase::BackendBase(const PreferencesBase &preferences)
-:   m_Preferences(preferences)
+:   m_PointerBytes(sizeof(char *)), m_Preferences(preferences)
 {
 }
 //--------------------------------------------------------------------------
@@ -253,4 +253,4 @@ std::vector<BackendBase::ReductionTarget> BackendBase::genInitReductionTargets(C
                       index);
        });
 }
-} // namespace GeNN::CodeGenerator
+} // namespace GeNN::CodeGenerator
\ No newline at end of file
diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc
index 743036d9bc..9b0d9078ff 100644
--- a/src/genn/genn/code_generator/environment.cc
+++ b/src/genn/genn/code_generator/environment.cc
@@ -6,6 +6,9 @@
 // Standard C includes
 #include

+// GeNN includes
+#include "gennUtils.h"
+
 using namespace GeNN::CodeGenerator;

 //----------------------------------------------------------------------------
@@ -19,16 +22,16 @@ std::string EnvironmentExternal::define(const std::string&)
 CodeStream &EnvironmentExternal::getContextStream() const
 {
    return std::visit(
-        Transpiler::Utils::Overload{
+        Utils::Overload{
            [](std::reference_wrapper<EnvironmentBase> enclosing)->CodeStream& { return enclosing.get().getStream(); },
            [](std::reference_wrapper<CodeStream> os)->CodeStream& { return os.get(); }},
        getContext());
 }
 //----------------------------------------------------------------------------
-std::string EnvironmentExternal::getContextName(const std::string &name, const Type::Base *type) const
+std::string EnvironmentExternal::getContextName(const std::string &name, const Type::ResolvedType &type) const
 {
    return std::visit(
-        Transpiler::Utils::Overload{
+        Utils::Overload{
            [&name, type](std::reference_wrapper<EnvironmentBase> enclosing)->std::string { return enclosing.get().getName(name, type); },
            [&name](std::reference_wrapper<CodeStream>)->std::string { throw std::runtime_error("Variable '" + name + "' undefined"); }},
        getContext());
 }
@@ -51,7 +54,7 @@ EnvironmentSubstitute::~EnvironmentSubstitute()
    getContextStream() << m_ContentsStream.str();
 }
 //----------------------------------------------------------------------------
-std::string EnvironmentSubstitute::getName(const std::string &name, const Type::Base *type)
+std::string EnvironmentSubstitute::getName(const std::string &name, const Type::ResolvedType &type)
 {
    // If there isn't a substitution for this name, try and get name from context
    auto var = m_VarSubstitutions.find(name);
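Utils::Overload, used with std::visit throughout environment.cc, is the standard overloaded-lambda idiom; a typical definition looks like this (a sketch, the actual helper lives in the newly included gennUtils.h):

    // Aggregate several lambdas into one callable exposing all their operator()s
    template<typename... Ts> struct Overload : Ts... { using Ts::operator()...; };
    template<typename... Ts> Overload(Ts...) -> Overload<Ts...>;    // CTAD deduction guide

Each environment holds a variant of either an enclosing environment or a raw CodeStream, so name lookups and stream access recurse outwards until a substitution, or the top-level stream, is found.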
diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc
index eb6d67bbe4..a4019b27cd 100644
--- a/src/genn/genn/code_generator/generateRunner.cc
+++ b/src/genn/genn/code_generator/generateRunner.cc
@@ -87,26 +87,24 @@ void genSpikeMacros(CodeStream &os, const NeuronGroupInternal &ng, bool trueSpik
    os << std::endl << std::endl;
 }
 //--------------------------------------------------------------------------
-template<typename T>
 void genHostScalar(CodeStream &definitionsVar, CodeStream &runnerVarDecl,
-                   const std::string &name, const std::string &value)
+                   const Type::ResolvedType &type, const std::string &name, const std::string &value)
 {
-    definitionsVar << "EXPORT_VAR " << T::getInstance()->getName() << " " << name << ";" << std::endl;
-    runnerVarDecl << T::getInstance()->getName() << " " << name << " = " << value << ";" << std::endl;
+    definitionsVar << "EXPORT_VAR " << type.getValue().name << " " << name << ";" << std::endl;
+    runnerVarDecl << type.getValue().name << " " << name << " = " << value << ";" << std::endl;
 }
 //--------------------------------------------------------------------------
-template<typename T>
 void genHostDeviceScalar(const ModelSpecMerged &modelMerged, const BackendBase &backend,
                         CodeStream &definitionsVar, CodeStream &definitionsInternalVar,
                         CodeStream &runnerVarDecl, CodeStream &runnerVarAlloc, CodeStream &runnerVarFree,
-                         const std::string &name, const std::string &hostValue, MemAlloc &mem)
+                         const Type::ResolvedType &type, const std::string &name, const std::string &hostValue, MemAlloc &mem)
 {
    // Generate a host scalar
-    genHostScalar<T>(definitionsVar, runnerVarDecl, name, hostValue);
+    genHostScalar(definitionsVar, runnerVarDecl, type, name, hostValue);

    // Generate a single-element array on device
    if(backend.isDeviceScalarRequired()) {
        backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
-                         T::getInstance(), modelMerged.getTypeContext(), name, VarLocation::DEVICE, 1, mem);
+                         type, name, VarLocation::DEVICE, 1, mem);
    }
 }
 //--------------------------------------------------------------------------
@@ -183,7 +181,7 @@ void genSpikeGetters(CodeStream &definitionsFunc, CodeStream &runnerGetterFunc,
    // Generate getter for current spikes
    genVarGetterScope(definitionsFunc, runnerGetterFunc,
-                      loc, ng.getName() + (trueSpike ? "CurrentSpikes" : "CurrentSpikeEvents"), "unsigned int*",
+                      loc, ng.getName() + (trueSpike ? "CurrentSpikes" : "CurrentSpikeEvents"), "uint32_t*",
                      [&]()
                      {
                          runnerGetterFunc << "return (glbSpk" << eventSuffix << ng.getName();
@@ -201,7 +199,7 @@ void genSpikeGetters(CodeStream &definitionsFunc, CodeStream &runnerGetterFunc,
    // Generate getter for current spike counts
    genVarGetterScope(definitionsFunc, runnerGetterFunc,
-                      loc, ng.getName() + (trueSpike ? "CurrentSpikeCount" : "CurrentSpikeEventCount"), "unsigned int&",
+                      loc, ng.getName() + (trueSpike ? "CurrentSpikeCount" : "CurrentSpikeEventCount"), "uint32_t&",
"CurrentSpikeCount" : "CurrentSpikeEventCount"), "uint32_t&", [&]() { runnerGetterFunc << "return glbSpkCnt" << eventSuffix << ng.getName() << "["; @@ -269,16 +267,19 @@ void genVariable(const ModelSpecMerged &modelMerged, const BackendBase &backend, // Generate variables backend.genArray(definitionsVar, definitionsInternal, runner, allocations, free, - type, modelMerged.getTypeContext(), name, loc, count, mem); + type, name, loc, count, mem); } //------------------------------------------------------------------------- void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase &backend, CodeStream &definitionsVar, CodeStream &definitionsFunc, CodeStream &definitionsInternalVar, CodeStream &runner, - CodeStream &extraGlobalParam, const Type::ResolvedType &type, const std::string &name, bool apiRequired, VarLocation loc) + CodeStream &extraGlobalParam, const Type::UnresolvedType &type, const std::string &name, bool apiRequired, VarLocation loc) { + // Resolved type + const auto resolvedType = type.resolve(modelMerged.getTypeContext()); + // Generate variables - backend.genVariableDefinition(definitionsVar, definitionsInternalVar, type, name, loc); - backend.genVariableInstantiation(runner, type, name, loc); + backend.genVariableDefinition(definitionsVar, definitionsInternalVar, resolvedType, name, loc); + backend.genVariableInstantiation(runner, resolvedType, name, loc); // If API is required if(apiRequired) { @@ -290,7 +291,7 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase & extraGlobalParam << "void allocate" << name << "(unsigned int count)"; { CodeStream::Scope a(extraGlobalParam); - backend.genVariableDynamicAllocation(extraGlobalParam, type, name, loc); + backend.genVariableDynamicAllocation(extraGlobalParam, resolvedType, name, loc); // Loop through destinations in merged structures, the device EGP needs to be copied to // **TODO** rename to dynamic @@ -353,7 +354,7 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase & extraGlobalParam << "void push" << name << "ToDevice(unsigned int count)"; { CodeStream::Scope a(extraGlobalParam); - backend.genVariableDynamicPush(extraGlobalParam, type, name, loc); + backend.genVariableDynamicPush(extraGlobalParam, resolvedType, name, loc); } if(backend.getPreferences().generateExtraGlobalParamPull) { @@ -364,7 +365,7 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase & extraGlobalParam << "void pull" << name << "FromDevice(unsigned int count)"; { CodeGenerator::CodeStream::Scope a(extraGlobalParam); - backend.genVariableDynamicPull(extraGlobalParam, type, name, loc); + backend.genVariableDynamicPull(extraGlobalParam, resolvedType, name, loc); } } } @@ -415,8 +416,9 @@ void genRunnerVars(const ModelSpecMerged &modelMerged, const BackendBase &backen for(const auto &var : varAdaptor.getDefs()) { const auto *varInitSnippet = varAdaptor.getInitialisers().at(var.name).getSnippet(); const bool autoInitialized = !varInitSnippet->getCode().empty(); + const auto resolvedType = var.type.resolve(modelMerged.getTypeContext()); genVariable(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - runnerPushFunc, runnerPullFunc, var.type, var.name + group.getName(), varAdaptor.getLoc(var.name), + runnerPushFunc, runnerPullFunc, resolvedType, var.name + group.getName(), varAdaptor.getLoc(var.name), autoInitialized, getSizeFn(group, var), mem, statePushPullFunctions); // Loop through 
@@ -438,8 +440,9 @@ void genRunnerFusedVars(const ModelSpecMerged &modelMerged, const BackendBase &b
    // Loop through variables
    const V varAdaptor(group);
    for(const auto &var : varAdaptor.getDefs()) {
+        const auto resolvedType = var.type.resolve(modelMerged.getTypeContext());
        backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
-                         var.type, modelMerged.getTypeContext(), var.name + varAdaptor.getFusedSuffix(), varAdaptor.getLoc(var.name),
+                         resolvedType, var.name + varAdaptor.getFusedSuffix(), varAdaptor.getLoc(var.name),
                         getSizeFn(group, var), mem);

        // Loop through EGPs required to initialize variable
@@ -453,7 +456,7 @@ void genRunnerFusedVars(const ModelSpecMerged &modelMerged, const BackendBase &b
 }
 //-------------------------------------------------------------------------
 template<typename V, typename G, typename S>
-void genRunnerFusedVarPushPull(const BackendBase &backend, CodeStream &definitionsFunc,
+void genRunnerFusedVarPushPull(const ModelSpecMerged &modelMerged, const BackendBase &backend, CodeStream &definitionsFunc,
                               CodeStream &runnerPushFunc, CodeStream &runnerPullFunc,
                               const G &group, std::vector<std::string> &groupStatePushPullFunctions, S getSizeFn)
 {
@@ -461,12 +464,13 @@ void genRunnerFusedVarPushPull(const BackendBase &backend, CodeStream &definitio
    // Loop through variables
    const V varAdaptor(group);
    for(const auto &var : varAdaptor.getDefs()) {
        const bool autoInitialized = !varAdaptor.getInitialisers().at(var.name).getSnippet()->getCode().empty();
+        const auto resolvedType = var.type.resolve(modelMerged.getTypeContext());
        genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, varAdaptor.getLoc(var.name),
                            backend.getPreferences().automaticCopy, var.name + group.getName(), groupStatePushPullFunctions,
                            [&]()
                            {
                                backend.genVariablePushPull(runnerPushFunc, runnerPullFunc,
-                                                            var.type, var.name + group.getName(),
+                                                            resolvedType, var.name + group.getName(),
                                                            varAdaptor.getLoc(var.name), autoInitialized, getSizeFn(group, var));
                            });
    }
@@ -557,16 +561,11 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath,
    // write DT macro
    const ModelSpecInternal &model = modelMerged.getModel();
-    definitions << "#define DT " << Utils::writePreciseString(model.getDT()) << model.getTimePrecision()->getLiteralSuffix(modelMerged.getTypeContext()) << std::endl;
-
-    // Typedefine types in type context
-    for (const auto &t : modelMerged.getTypeContext()) {
-        definitions << "typedef " << t.second->getName() << " " << t.first << ";" << std::endl;
-    }
+    definitions << "#define DT " << Utils::writePreciseString(model.getDT()) << model.getTimePrecision().getNumeric().literalSuffix << std::endl;

    // Write ranges of scalar and time types
-    genTypeRange(definitions, model.getPrecision(), modelMerged.getTypeContext(), "SCALAR");
-    genTypeRange(definitions, model.getTimePrecision(), modelMerged.getTypeContext(), "TIME");
+    genTypeRange(definitions, model.getPrecision(), "SCALAR");
+    genTypeRange(definitions, model.getTimePrecision(), "TIME");
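For a float scalar type and a DT of 0.1, the preamble written here would come out something like the following (macro names inferred from genTypeRange's "SCALAR"/"TIME" prefix arguments; values purely illustrative):

    #define DT 0.100000f
    #define SCALAR_MIN 1.175494351e-38f
    #define SCALAR_MAX 3.402823466e+38f

Note the literal suffix now comes straight from the resolved numeric type (getNumeric().literalSuffix) rather than a virtual call parameterised by the type context.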
    definitions << "// ------------------------------------------------------------------------" << std::endl;
    definitions << "// bit tool macros" << std::endl;
@@ -629,8 +628,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath,
    }
    // If backend requires a global device RNG to simulate (or initialize) this model
    if(backend.isGlobalDeviceRNGRequired(modelMerged)) {
-        backend.genGlobalDeviceRNG(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
-                                   modelMerged.getTypeContext(), mem);
+        backend.genGlobalDeviceRNG(definitionsVar, definitionsInternalVar,
+                                   runnerVarDecl, runnerVarAlloc, runnerVarFree, mem);
    }
    // If backend required a global host RNG to simulate (or initialize) this model, generate a standard Mersenne Twister
    if(backend.isGlobalHostRNGRequired(modelMerged)) {
@@ -660,17 +659,17 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath,
    // Generate variables to store total elapsed time
    // **NOTE** we ALWAYS generate these so usercode doesn't require #ifdefs around timing code
-    genHostScalar<double>(definitionsVar, runnerVarDecl, "initTime", "0.0");
-    genHostScalar<double>(definitionsVar, runnerVarDecl, "initSparseTime", "0.0");
-    genHostScalar<double>(definitionsVar, runnerVarDecl, "neuronUpdateTime", "0.0");
-    genHostScalar<double>(definitionsVar, runnerVarDecl, "presynapticUpdateTime", "0.0");
-    genHostScalar<double>(definitionsVar, runnerVarDecl, "postsynapticUpdateTime", "0.0");
-    genHostScalar<double>(definitionsVar, runnerVarDecl, "synapseDynamicsTime", "0.0");
+    genHostScalar(definitionsVar, runnerVarDecl, Type::Double, "initTime", "0.0");
+    genHostScalar(definitionsVar, runnerVarDecl, Type::Double, "initSparseTime", "0.0");
+    genHostScalar(definitionsVar, runnerVarDecl, Type::Double, "neuronUpdateTime", "0.0");
+    genHostScalar(definitionsVar, runnerVarDecl, Type::Double, "presynapticUpdateTime", "0.0");
+    genHostScalar(definitionsVar, runnerVarDecl, Type::Double, "postsynapticUpdateTime", "0.0");
+    genHostScalar(definitionsVar, runnerVarDecl, Type::Double, "synapseDynamicsTime", "0.0");

    // Generate variables to store total elapsed time for each custom update group
    for(const auto &g : customUpdateGroups) {
-        genHostScalar<double>(definitionsVar, runnerVarDecl, "customUpdate" + g + "Time", "0.0");
-        genHostScalar<double>(definitionsVar, runnerVarDecl, "customUpdate" + g + "TransposeTime", "0.0");
+        genHostScalar(definitionsVar, runnerVarDecl, Type::Double, "customUpdate" + g + "Time", "0.0");
+        genHostScalar(definitionsVar, runnerVarDecl, Type::Double, "customUpdate" + g + "TransposeTime", "0.0");
    }

    // If timing is actually enabled
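Given the genHostScalar definition earlier in this patch, each of these calls expands to one definition and one declaration, e.g. for initTime (a sketch of the generated output):

    // definitions.h
    EXPORT_VAR double initTime;

    // runner.cc
    double initTime = 0.0;

Passing Type::Double explicitly replaces the old genHostScalar<double> template parameter with an ordinary argument.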
@@ -907,24 +906,24 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath,
        const size_t numNeuronDelaySlots = batchSize * (size_t)n.second.getNumNeurons() * (size_t)n.second.getNumDelaySlots();
        const size_t numSpikeCounts = n.second.isTrueSpikeRequired() ? (batchSize * n.second.getNumDelaySlots()) : batchSize;
        const size_t numSpikes = n.second.isTrueSpikeRequired() ? numNeuronDelaySlots : (batchSize * n.second.getNumNeurons());
-        backend.genArray<Type::Uint32>(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
-                                       modelMerged.getTypeContext(), "glbSpkCnt" + n.first,
-                                       n.second.getSpikeLocation(), numSpikeCounts, mem);
-        backend.genArray<Type::Uint32>(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
-                                       modelMerged.getTypeContext(), "glbSpk" + n.first,
-                                       n.second.getSpikeLocation(), numSpikes, mem);
+        backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
+                         Type::Uint32, "glbSpkCnt" + n.first,
+                         n.second.getSpikeLocation(), numSpikeCounts, mem);
+        backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
+                         Type::Uint32, "glbSpk" + n.first,
+                         n.second.getSpikeLocation(), numSpikes, mem);

        // True spike push and pull functions
        genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, n.second.getSpikeLocation(),
                            backend.getPreferences().automaticCopy, n.first + "Spikes",
                            [&]()
                            {
-                                backend.genVariablePushPull<Type::Uint32>(runnerPushFunc, runnerPullFunc,
-                                                                          "glbSpkCnt" + n.first,
-                                                                          n.second.getSpikeLocation(), true, numSpikeCounts);
-                                backend.genVariablePushPull<Type::Uint32>(runnerPushFunc, runnerPullFunc,
-                                                                          "glbSpk" + n.first,
-                                                                          n.second.getSpikeLocation(), true, numSpikes);
+                                backend.genVariablePushPull(runnerPushFunc, runnerPullFunc,
+                                                            Type::Uint32, "glbSpkCnt" + n.first,
+                                                            n.second.getSpikeLocation(), true, numSpikeCounts);
+                                backend.genVariablePushPull(runnerPushFunc, runnerPullFunc,
+                                                            Type::Uint32, "glbSpk" + n.first,
+                                                            n.second.getSpikeLocation(), true, numSpikes);
                            });

        // Current true spike getter functions
@@ -933,10 +932,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath,
        // If spike recording is enabled, define and declare variables and add free
        if(n.second.isSpikeRecordingEnabled()) {
            backend.genVariableDefinition(definitionsVar, definitionsInternalVar,
-                                          Type::Uint32::getInstance(), "recordSpk" + n.first,
+                                          Type::Uint32, "recordSpk" + n.first,
                                          VarLocation::HOST_DEVICE);
            backend.genVariableInstantiation(runnerVarDecl,
-                                             Type::Uint32::getInstance(), "recordSpk" + n.first,
+                                             Type::Uint32, "recordSpk" + n.first,
                                             VarLocation::HOST_DEVICE);
            backend.genVariableFree(runnerVarFree, "recordSpk" + n.first, VarLocation::HOST_DEVICE);
@@ -950,24 +949,24 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath,
        }

        // Spike-like event variables
-        backend.genArray<Type::Uint32>(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
-                                       modelMerged.getTypeContext(), "glbSpkCntEvnt" + n.first, n.second.getSpikeEventLocation(),
-                                       batchSize * n.second.getNumDelaySlots(), mem);
-        backend.genArray<Type::Uint32>(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
-                                       modelMerged.getTypeContext(), "glbSpkEvnt" + n.first, n.second.getSpikeEventLocation(),
-                                       numNeuronDelaySlots, mem);
+        backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
+                         Type::Uint32, "glbSpkCntEvnt" + n.first, n.second.getSpikeEventLocation(),
+                         batchSize * n.second.getNumDelaySlots(), mem);
+        backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
+                         Type::Uint32, "glbSpkEvnt" + n.first, n.second.getSpikeEventLocation(),
+                         numNeuronDelaySlots, mem);

        // Spike-like event push and pull functions
        genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, n.second.getSpikeEventLocation(),
                            backend.getPreferences().automaticCopy, n.first + "SpikeEvents",
                            [&]()
                            {
-                                backend.genVariablePushPull<Type::Uint32>(runnerPushFunc, runnerPullFunc,
-                                                                          "glbSpkCntEvnt" + n.first,
-                                                                          n.second.getSpikeLocation(), true, batchSize * n.second.getNumDelaySlots());
-                                backend.genVariablePushPull<Type::Uint32>(runnerPushFunc, runnerPullFunc,
-                                                                          "glbSpkEvnt" + n.first,
-                                                                          n.second.getSpikeLocation(), true, numNeuronDelaySlots);
+                                backend.genVariablePushPull(runnerPushFunc, runnerPullFunc,
+                                                            Type::Uint32, "glbSpkCntEvnt" + n.first,
+                                                            n.second.getSpikeLocation(), true, batchSize * n.second.getNumDelaySlots());
+                                backend.genVariablePushPull(runnerPushFunc, runnerPullFunc,
+                                                            Type::Uint32, "glbSpkEvnt" + n.first,
+                                                            n.second.getSpikeLocation(), true, numNeuronDelaySlots);
                            });

        // Current true spike getter functions
@@ -976,10 +975,10 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath,
        // If spike recording is enabled, define and declare variables and add free
        if(n.second.isSpikeEventRecordingEnabled()) {
            backend.genVariableDefinition(definitionsVar, definitionsInternalVar,
-                                          Type::Uint32::getInstance(), "recordSpkEvent" + n.first,
+                                          Type::Uint32, "recordSpkEvent" + n.first,
                                          VarLocation::HOST_DEVICE);
            backend.genVariableInstantiation(runnerVarDecl,
-                                             Type::Uint32::getInstance(), "recordSpkEvent" + n.first,
+                                             Type::Uint32, "recordSpkEvent" + n.first,
                                             VarLocation::HOST_DEVICE);
            backend.genVariableFree(runnerVarFree, "recordSpkEvent" + n.first, VarLocation::HOST_DEVICE);
        }
@@ -987,14 +986,15 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath,
        // If neuron group has axonal delays
        if (n.second.isDelayRequired()) {
-            genHostDeviceScalar<Type::Uint32>(modelMerged, backend, definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
-                                              "spkQuePtr" + n.first, "0", mem);
+            genHostDeviceScalar(modelMerged, backend, definitionsVar, definitionsInternalVar,
+                                runnerVarDecl, runnerVarAlloc, runnerVarFree,
+                                Type::Uint32, "spkQuePtr" + n.first, "0", mem);
        }

        // If neuron group needs to record its spike times
        if (n.second.isSpikeTimeRequired()) {
            backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree,
-                             model.getTimePrecision(), modelMerged.getTypeContext(), "sT" + n.first,
+                             model.getTimePrecision(), "sT" + n.first,
                             n.second.getSpikeTimeLocation(), numNeuronDelaySlots, mem);

            // Generate push and pull functions
@@ -1011,7 +1011,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const
filesystem::path &outputPath, // If neuron group needs to record its previous spike-like-event times if (n.second.isPrevSpikeEventTimeRequired()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - model.getTimePrecision(), modelMerged.getTypeContext(), "prevSET" + n.first, + model.getTimePrecision(), "prevSET" + n.first, n.second.getPrevSpikeEventTimeLocation(), numNeuronDelaySlots, mem); // Generate push and pull functions @@ -1062,7 +1062,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // If neuron group needs per-neuron RNGs if(n.second.isSimRNGRequired()) { backend.genPopulationRNG(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - modelMerged.getTypeContext(), "rng" + n.first, batchSize * n.second.getNumNeurons(), mem); + "rng" + n.first, batchSize * n.second.getNumNeurons(), mem); } // Neuron state variables @@ -1074,8 +1074,9 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const unsigned int numElements = getNumVarElements(var.access, n.second.getNumNeurons()); const size_t count = n.second.isVarQueueRequired(var.name) ? numCopies * numElements * n.second.getNumDelaySlots() : numCopies * n.second.getNumNeurons(); const bool autoInitialized = !varInitSnippet->getCode().empty(); + const auto resolvedType = var.type.resolve(modelMerged.getTypeContext()); genVariable(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, - runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, var.type, var.name + n.first, + runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, resolvedType, var.name + n.first, n.second.getVarLocation(var.name), autoInitialized, count, mem, neuronStatePushPullFunctions); // Current variable push and pull functions @@ -1084,14 +1085,14 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { backend.genCurrentVariablePushPull(runnerPushFunc, runnerPullFunc, n.second, - var.type, var.name, + resolvedType, var.name, n.second.getVarLocation(var.name), numCopies); }); // Write getter to get access to correct pointer const bool delayRequired = (n.second.isVarQueueRequired(var.name) && n.second.isDelayRequired()); genVarGetterScope(definitionsFunc, runnerGetterFunc, n.second.getVarLocation(var.name), - "Current" + var.name + n.first, var.type->getPointerType()->getName(), + "Current" + var.name + n.first, resolvedType.getValue().name + "*", [&]() { runnerGetterFunc << "return " << var.name << n.first; @@ -1223,7 +1224,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // If custom connectivity update group needs per-row RNGs if(c.second.isRowSimRNGRequired()) { backend.genPopulationRNG(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - modelMerged.getTypeContext(), "rowRNG" + c.first, c.second.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), mem); + "rowRNG" + c.first, c.second.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), mem); } @@ -1243,15 +1244,15 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Loop through merged postsynaptic models of incoming synaptic populations for(const auto *sg : n.second.getFusedPSMInSyn()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - model.getPrecision(), modelMerged.getTypeContext(), 
"inSyn" + sg->getFusedPSVarSuffix(), + model.getPrecision(), "inSyn" + sg->getFusedPSVarSuffix(), sg->getInSynLocation(), sg->getTrgNeuronGroup()->getNumNeurons() * batchSize, mem); if (sg->isDendriticDelayRequired()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - model.getPrecision(), modelMerged.getTypeContext(), "denDelay" + sg->getFusedPSVarSuffix(), + model.getPrecision(), "denDelay" + sg->getFusedPSVarSuffix(), sg->getDendriticDelayLocation(), (size_t)sg->getMaxDendriticDelayTimesteps() * (size_t)sg->getTrgNeuronGroup()->getNumNeurons() * batchSize, mem); - genHostDeviceScalar(modelMerged, backend, definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - "denDelayPtr" + sg->getFusedPSVarSuffix(), "0", mem); + genHostDeviceScalar(modelMerged, backend, definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + Type::Uint32, "denDelayPtr" + sg->getFusedPSVarSuffix(), "0", mem); } genRunnerFusedVars(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, @@ -1264,7 +1265,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Loop through fused outgoing synapse populations with weightupdate models that have presynaptic output for(const auto *sg : n.second.getFusedPreOutputOutSyn()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - model.getPrecision(), modelMerged.getTypeContext(), "revInSyn" + sg->getFusedPreOutputSuffix(), + model.getPrecision(), "revInSyn" + sg->getFusedPreOutputSuffix(), sg->getInSynLocation(), sg->getSrcNeuronGroup()->getNumNeurons() * batchSize, mem); } @@ -1304,18 +1305,18 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, if(s.second.getMatrixType() & SynapseMatrixConnectivity::BITMASK) { const size_t gpSize = ceilDivide((size_t)s.second.getSrcNeuronGroup()->getNumNeurons() * backend.getSynapticMatrixRowStride(s.second), 32); - backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - modelMerged.getTypeContext(), "gp" + s.second.getName(), - s.second.getSparseConnectivityLocation(), gpSize, mem); + backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + Type::Uint32, "gp" + s.second.getName(), + s.second.getSparseConnectivityLocation(), gpSize, mem); // Generate push and pull functions for bitmask genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, s.second.getSparseConnectivityLocation(), backend.getPreferences().automaticCopy, s.second.getName() + "Connectivity", connectivityPushPullFunctions, [&]() { - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - "gp" + s.second.getName(), - s.second.getSparseConnectivityLocation(), autoInitialized, gpSize); + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + Type::Uint32, "gp" + s.second.getName(), + s.second.getSparseConnectivityLocation(), autoInitialized, gpSize); }); } else if(s.second.getMatrixType() & SynapseMatrixConnectivity::SPARSE) { @@ -1327,13 +1328,13 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerVarDecl << "const unsigned int maxRowLength" << s.second.getName() << " = " << backend.getSynapticMatrixRowStride(s.second) << ";" << std::endl; // Row lengths - backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - 
modelMerged.getTypeContext(), "rowLength" + s.second.getName(), - varLoc, s.second.getSrcNeuronGroup()->getNumNeurons(), mem); + backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + Type::Uint32, "rowLength" + s.second.getName(), + varLoc, s.second.getSrcNeuronGroup()->getNumNeurons(), mem); // Target indices backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - s.second.getSparseIndType(), modelMerged.getTypeContext(), "ind" + s.second.getName(), + s.second.getSparseIndType(), "ind" + s.second.getName(), varLoc, size, mem); // **TODO** remap is not always required @@ -1341,14 +1342,14 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const size_t postSize = (size_t)s.second.getTrgNeuronGroup()->getNumNeurons() * (size_t)s.second.getMaxSourceConnections(); // Allocate column lengths - backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - modelMerged.getTypeContext(), "colLength" + s.second.getName(), - VarLocation::DEVICE, s.second.getTrgNeuronGroup()->getNumNeurons(), mem); + backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + Type::Uint32, "colLength" + s.second.getName(), + VarLocation::DEVICE, s.second.getTrgNeuronGroup()->getNumNeurons(), mem); // Allocate remap - backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - modelMerged.getTypeContext(), "remap" + s.second.getName(), - VarLocation::DEVICE, postSize, mem); + backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + Type::Uint32, "remap" + s.second.getName(), + VarLocation::DEVICE, postSize, mem); } // Generate push and pull functions for sparse connectivity @@ -1357,9 +1358,9 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&]() { // Row lengths - backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - "rowLength" + s.second.getName(), - s.second.getSparseConnectivityLocation(), autoInitialized, s.second.getSrcNeuronGroup()->getNumNeurons()); + backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, + Type::Uint32, "rowLength" + s.second.getName(), + s.second.getSparseConnectivityLocation(), autoInitialized, s.second.getSrcNeuronGroup()->getNumNeurons()); // Target indices backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, @@ -1386,10 +1387,11 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, for(const auto &wuVar : wu->getVars()) { const auto *varInitSnippet = s.second.getWUVarInitialisers().at(wuVar.name).getSnippet(); const bool autoInitialized = !varInitSnippet->getCode().empty(); + const auto resolvedType = wuVar.type.resolve(modelMerged.getTypeContext()); if(individualWeights) { const size_t size = (size_t)s.second.getSrcNeuronGroup()->getNumNeurons() * (size_t)backend.getSynapticMatrixRowStride(s.second); genVariable(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - runnerPushFunc, runnerPullFunc, wuVar.type, wuVar.name + s.second.getName(), s.second.getWUVarLocation(wuVar.name), + runnerPushFunc, runnerPullFunc, resolvedType, wuVar.name + s.second.getName(), s.second.getWUVarLocation(wuVar.name), autoInitialized, size * getNumVarCopies(wuVar.access, batchSize), mem, synapseGroupStatePushPullFunctions); } else 
if(kernelWeights) { @@ -1398,7 +1400,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Generate variable genVariable(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - runnerPushFunc, runnerPullFunc, wuVar.type, wuVar.name + s.second.getName(), s.second.getWUVarLocation(wuVar.name), + runnerPushFunc, runnerPullFunc, resolvedType, wuVar.name + s.second.getName(), s.second.getWUVarLocation(wuVar.name), autoInitialized, size, mem, synapseGroupStatePushPullFunctions); } @@ -1425,7 +1427,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, s.second.getInSynLocation(), true, s.second.getTrgNeuronGroup()->getNumNeurons() * batchSize); }); - genRunnerFusedVarPushPull(backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, + genRunnerFusedVarPushPull(modelMerged, backend, definitionsFunc, runnerPushFunc, runnerPullFunc, + s.second, synapseGroupStatePushPullFunctions, [batchSize](const SynapseGroupInternal &sg, const Models::Base::Var &var) { return getVarSize(var.access, sg.getTrgNeuronGroup()->getNumNeurons(), batchSize); @@ -1436,7 +1439,8 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // **NOTE** we generated initialisation and declaration code earlier - here we just generate push and pull as we want this per-synapse group if(!s.second.isWUPreModelFused()) { const unsigned int preDelaySlots = (s.second.getDelaySteps() == NO_DELAY) ? 1 : s.second.getSrcNeuronGroup()->getNumDelaySlots(); - genRunnerFusedVarPushPull(backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions, + genRunnerFusedVarPushPull(modelMerged, backend, definitionsFunc, runnerPushFunc, runnerPullFunc, + s.second, synapseGroupStatePushPullFunctions, [batchSize, preDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var) { return getVarSize(var.access, sg.getSrcNeuronGroup()->getNumNeurons(), batchSize, preDelaySlots); @@ -1448,11 +1452,12 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // **NOTE** we generated initialisation and declaration code earlier - here we just generate push and pull as we want this per-synapse group if(!s.second.isWUPostModelFused()) { const unsigned int postDelaySlots = (s.second.getBackPropDelaySteps() == NO_DELAY) ? 
1 : s.second.getTrgNeuronGroup()->getNumDelaySlots();
-            genRunnerFusedVarPushPull(backend, definitionsFunc, runnerPushFunc, runnerPullFunc, s.second, synapseGroupStatePushPullFunctions,
-                                      [batchSize, postDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var)
-                                      {
-                                          return getVarSize(var.access, sg.getTrgNeuronGroup()->getNumNeurons(), batchSize, postDelaySlots);
-                                      });
+            genRunnerFusedVarPushPull(modelMerged, backend, definitionsFunc, runnerPushFunc, runnerPullFunc,
+                                      s.second, synapseGroupStatePushPullFunctions,
+                                      [batchSize, postDelaySlots](const SynapseGroupInternal &sg, const Models::Base::Var &var)
+                                      {
+                                          return getVarSize(var.access, sg.getTrgNeuronGroup()->getNumNeurons(), batchSize, postDelaySlots);
+                                      });
        }

        // Add helper function to push and pull entire synapse group state
@@ -1602,8 +1607,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath,
            // Allocate spike array if required
            if(n.second.isSpikeRecordingEnabled()) {
                CodeStream::Scope b(runner);
-                backend.genVariableDynamicAllocation(runner,
-                                                     Type::Uint32::getInstance(), "recordSpk" + n.first,
+                backend.genVariableDynamicAllocation(runner, Type::Uint32, "recordSpk" + n.first,
                                                     VarLocation::HOST_DEVICE, "numWords");

                // Get destinations in merged structures, this EGP
@@ -1618,8 +1622,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath,
            // Allocate spike event array if required
            if(n.second.isSpikeEventRecordingEnabled()) {
                CodeStream::Scope b(runner);
-                backend.genVariableDynamicAllocation(runner,
-                                                     Type::Uint32::getInstance(), "recordSpkEvent" + n.first,
+                backend.genVariableDynamicAllocation(runner, Type::Uint32, "recordSpkEvent" + n.first,
                                                     VarLocation::HOST_DEVICE, "numWords");

                // Get destinations in merged structures, this EGP
@@ -1658,16 +1661,14 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath,
            // Pull spike array if required
            if(n.second.isSpikeRecordingEnabled()) {
                CodeStream::Scope b(runner);
-                backend.genVariableDynamicPull(runner,
-                                               Type::Uint32::getInstance(), "recordSpk" + n.first,
+                backend.genVariableDynamicPull(runner, Type::Uint32, "recordSpk" + n.first,
                                               VarLocation::HOST_DEVICE, "numWords");
            }
            // AllocaPullte spike event array if required
            // **YUCK** maybe this should be renamed pullDynamicArray
            if(n.second.isSpikeEventRecordingEnabled()) {
                CodeStream::Scope b(runner);
-                backend.genVariableDynamicPull(runner,
-                                               Type::Uint32::getInstance(), "recordSpkEvent" + n.first,
+                backend.genVariableDynamicPull(runner, Type::Uint32, "recordSpkEvent" + n.first,
                                               VarLocation::HOST_DEVICE, "numWords");
            }
        }
diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc
index e7d3529183..519b7bff56 100644
--- a/src/genn/genn/code_generator/groupMerged.cc
+++ b/src/genn/genn/code_generator/groupMerged.cc
@@ -26,13 +26,13 @@ NeuronSpikeQueueUpdateGroupMerged::NeuronSpikeQueueUpdateGroupMerged(size_t inde
    using namespace Type;

    if(getArchetype().isDelayRequired()) {
-        addPointerField<Uint32>("spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr");
+        addPointerField(Uint32, "spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr");
    }

-    addPointerField<Uint32>("spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt");
+    addPointerField(Uint32, "spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt");

    if(getArchetype().isSpikeEventRequired()) {
-        addPointerField<Uint32>("spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt");
+        addPointerField(Uint32, "spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt");
    }
 }
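Throughout groupMerged.cc the pattern is the same: the removed lines passed the element type as a template argument (addPointerField<Uint32>(...)), the new ones pass Type::Uint32 as an ordinary value. Presumably the built-in types are now plain constants along these lines (a sketch only; the actual declarations follow the "// Declare numeric types" comment in type.h, and the rank and suffix values here are invented):

    namespace GeNN::Type
    {
    inline const ResolvedType Uint32 = ResolvedType::createNumeric<uint32_t>("uint32_t", 40, "u");
    inline const ResolvedType Double = ResolvedType::createNumeric<double>("double", 60);
    }

with the `using namespace Type;` at the top of each constructor making the bare Uint32 spelling possible.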
"glbSpkCntEvnt"); } } //---------------------------------------------------------------------------- @@ -75,27 +75,27 @@ NeuronPrevSpikeTimeUpdateGroupMerged::NeuronPrevSpikeTimeUpdateGroupMerged(size_ using namespace Type; if(getArchetype().isDelayRequired()) { - addPointerField("spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); + addPointerField(Uint32, "spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); } - addPointerField("spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); + addPointerField(Uint32, "spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); if(getArchetype().isSpikeEventRequired()) { - addPointerField("spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); + addPointerField(Uint32, "spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); } if(getArchetype().isPrevSpikeTimeRequired()) { - addPointerField("spk", backend.getDeviceVarPrefix() + "glbSpk"); + addPointerField(Uint32, "spk", backend.getDeviceVarPrefix() + "glbSpk"); addPointerField(getTimeType(), "prevST", backend.getDeviceVarPrefix() + "prevST"); } if(getArchetype().isPrevSpikeEventTimeRequired()) { - addPointerField("spkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt"); + addPointerField(Uint32, "spkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt"); addPointerField(getTimeType(), "prevSET", backend.getDeviceVarPrefix() + "prevSET"); } if(getArchetype().isDelayRequired()) { - addField("numNeurons", - [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); + addField(Uint32, "numNeurons", + [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); } } @@ -201,19 +201,19 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeConte orderNeuronGroupChildren(m_SortedCurrentSources, &NeuronGroupInternal::getCurrentSources, init ? 
-    addField<Uint32>("numNeurons",
-                     [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); });
+    addField(Uint32, "numNeurons",
+             [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); });

-    addPointerField<Uint32>("spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt");
-    addPointerField<Uint32>("spk", backend.getDeviceVarPrefix() + "glbSpk");
+    addPointerField(Uint32, "spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt");
+    addPointerField(Uint32, "spk", backend.getDeviceVarPrefix() + "glbSpk");

    if(getArchetype().isSpikeEventRequired()) {
-        addPointerField<Uint32>("spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt");
-        addPointerField<Uint32>("spkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt");
+        addPointerField(Uint32, "spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt");
+        addPointerField(Uint32, "spkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt");
    }

    if(getArchetype().isDelayRequired()) {
-        addPointerField<Uint32>("spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr");
+        addPointerField(Uint32, "spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr");
    }

    if(getArchetype().isSpikeTimeRequired()) {
@@ -244,7 +244,8 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeConte
    for(const auto &var : vars) {
        // If we're not initialising or if there is initialization code for this variable
        if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) {
-            addPointerField(var.type, var.name, backend.getDeviceVarPrefix() + var.name);
+            addPointerField(var.type.resolve(getTypeContext()), var.name,
+                            backend.getDeviceVarPrefix() + var.name);
        }

        // If we're initializing, add any var init EGPs to structure
@@ -289,7 +290,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeConte
        // Add pointer to dendritic delay buffer if required
        if(sg->isDendriticDelayRequired()) {
            addMergedInSynPointerField(getScalarType(), "denDelayInSyn", i, backend.getDeviceVarPrefix() + "denDelay");
-            addMergedInSynPointerField(Uint32::getInstance(), "denDelayPtrInSyn", i, backend.getScalarAddressPrefix() + "denDelayPtr");
+            addMergedInSynPointerField(Uint32, "denDelayPtrInSyn", i, backend.getScalarAddressPrefix() + "denDelayPtr");
        }

        // Loop through variables
@@ -297,7 +298,8 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeConte
        for(const auto &var : sg->getPSModel()->getVars()) {
            // Add pointers to state variable
            if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) {
-                addMergedInSynPointerField(var.type, var.name + "InSyn", i, backend.getDeviceVarPrefix() + var.name);
+                addMergedInSynPointerField(var.type.resolve(getTypeContext()), var.name + "InSyn", i,
+                                           backend.getDeviceVarPrefix() + var.name);
            }

            // If we're generating an initialization structure, also add any heterogeneous parameters, derived parameters or extra global parameters required for initializers
@@ -308,10 +310,10 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeConte
                addHeterogeneousChildVarInitDerivedParams(varInitSnippet->getDerivedParams(), m_SortedMergedInSyns, i, var.name, "InSyn",
                                                          &NeuronGroupMergedBase::isPSMVarInitDerivedParamHeterogeneous, &SynapseGroupInternal::getPSVarInitialisers);
                addChildEGPs(varInitSnippet->getExtraGlobalParams(), i, backend.getDeviceVarPrefix(), var.name + "InSyn",
-                             [var, this](size_t groupIndex, size_t childIndex)
-                             {
-                                 return var.name + m_SortedMergedInSyns.at(groupIndex).at(childIndex)->getFusedPSVarSuffix();
-                             });
+                             [var, this](size_t groupIndex, size_t childIndex)
+                             {
+                                 return var.name + m_SortedMergedInSyns.at(groupIndex).at(childIndex)->getFusedPSVarSuffix();
+                             });
            }
        }
@@ -352,7 +354,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeConte
        for(const auto &var : cs->getCurrentSourceModel()->getVars()) {
            // Add pointers to state variable
            if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) {
-                addField(var.type->getPointerType(), var.name + "CS" + std::to_string(i),
+                addField(var.type.resolve(getTypeContext()).createPointer(), var.name + "CS" + std::to_string(i),
                         [&backend, i, var, this](const auto&, size_t groupIndex)
                         {
                             return backend.getDeviceVarPrefix() + var.name + m_SortedCurrentSources.at(groupIndex).at(i)->getName();
@@ -480,20 +482,22 @@ bool NeuronGroupMergedBase::isPSMVarInitParamReferenced(size_t childIndex, const
    return isParamReferenced({varInitSnippet->getCode()}, paramName);
 }
 //----------------------------------------------------------------------------
-void NeuronGroupMergedBase::addMergedInSynPointerField(const Type::NumericBase *type, const std::string &name,
+void NeuronGroupMergedBase::addMergedInSynPointerField(const Type::ResolvedType &type, const std::string &name,
                                                       size_t archetypeIndex, const std::string &prefix)
 {
-    addField(type->getPointerType(), name + std::to_string(archetypeIndex),
+    assert(type.isValue());
+    addField(type.createPointer(), name + std::to_string(archetypeIndex),
            [prefix, archetypeIndex, this](const auto&, size_t groupIndex)
            {
                return prefix + m_SortedMergedInSyns.at(groupIndex).at(archetypeIndex)->getFusedPSVarSuffix();
            });
 }
 //----------------------------------------------------------------------------
-void NeuronGroupMergedBase::addMergedPreOutputOutSynPointerField(const Type::NumericBase *type, const std::string &name,
+void NeuronGroupMergedBase::addMergedPreOutputOutSynPointerField(const Type::ResolvedType &type, const std::string &name,
                                                                 size_t archetypeIndex, const std::string &prefix)
 {
-    addField(type->getPointerType(), name + std::to_string(archetypeIndex),
+    assert(type.isValue());
+    addField(type.createPointer(), name + std::to_string(archetypeIndex),
            [prefix, archetypeIndex, this](const auto&, size_t groupIndex)
            {
                return prefix + m_SortedMergedPreOutputOutSyns.at(groupIndex).at(archetypeIndex)->getFusedPreOutputSuffix();
@@ -674,16 +678,16 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon
    // If role isn't an init role or weights aren't kernel
    if(role != Role::Init || !(getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL)) {
-        addField<Uint32>("rowStride",
-                         [&backend](const SynapseGroupInternal &sg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(sg)); });
-        addField<Uint32>("numSrcNeurons",
-                         [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); });
-        addField<Uint32>("numTrgNeurons",
-                         [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); });
+        addField(Uint32, "rowStride",
+                 [&backend](const SynapseGroupInternal &sg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(sg)); });
+        addField(Uint32, "numSrcNeurons",
+                 [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); });
+        addField(Uint32, "numTrgNeurons",
+                 [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); });
    }
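Value fields such as rowStride and numSrcNeurons are baked in per group as literal strings at merge time, while pointer fields resolve to runtime array names; the merged-struct initialiser generated from these fields would come out something like this (struct name, field order and values purely illustrative):

    // Generated runner code (sketch): one initialiser per group in the merged set
    MergedPresynapticUpdateGroup0 mergedPresynapticUpdateGroup0[] = {
        {32 /* rowStride */, 512 /* numSrcNeurons */, 1024 /* numTrgNeurons */, d_indPop0},
    };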
    if(role == Role::PostsynapticUpdate || role == Role::SparseInit) {
-        addField<Uint32>("colStride",
+        addField(Uint32, "colStride",
                 [](const auto &sg, size_t) { return std::to_string(sg.getMaxSourceConnections()); });
    }
@@ -691,7 +695,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon
    if(role == Role::PresynapticUpdate || role == Role::SynapseDynamics) {
        if(getArchetype().isDendriticDelayRequired()) {
            addPSPointerField(getScalarType(), "denDelay", backend.getDeviceVarPrefix() + "denDelay");
-            addPSPointerField(Uint32::getInstance(), "denDelayPtr", backend.getScalarAddressPrefix() + "denDelayPtr");
+            addPSPointerField(Uint32, "denDelayPtr", backend.getScalarAddressPrefix() + "denDelayPtr");
        }
        else {
            addPSPointerField(getScalarType(), "inSyn", backend.getDeviceVarPrefix() + "inSyn");
        }
@@ -700,18 +704,18 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon
    if(role == Role::PresynapticUpdate) {
        if(getArchetype().isTrueSpikeRequired()) {
-            addSrcPointerField(Uint32::getInstance(), "srcSpkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt");
-            addSrcPointerField(Uint32::getInstance(), "srcSpk", backend.getDeviceVarPrefix() + "glbSpk");
+            addSrcPointerField(Uint32, "srcSpkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt");
+            addSrcPointerField(Uint32, "srcSpk", backend.getDeviceVarPrefix() + "glbSpk");
        }

        if(getArchetype().isSpikeEventRequired()) {
-            addSrcPointerField(Uint32::getInstance(), "srcSpkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt");
-            addSrcPointerField(Uint32::getInstance(), "srcSpkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt");
+            addSrcPointerField(Uint32, "srcSpkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt");
+            addSrcPointerField(Uint32, "srcSpkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt");
        }
    }
    else if(role == Role::PostsynapticUpdate) {
-        addTrgPointerField(Uint32::getInstance(), "trgSpkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt");
-        addTrgPointerField(Uint32::getInstance(), "trgSpk", backend.getDeviceVarPrefix() + "glbSpk");
+        addTrgPointerField(Uint32, "trgSpkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt");
+        addTrgPointerField(Uint32, "trgSpk", backend.getDeviceVarPrefix() + "glbSpk");
    }

    // If this structure is used for updating rather than initializing
@@ -723,12 +727,12 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon
        // If presynaptic population has delay buffers
        if(getArchetype().getSrcNeuronGroup()->isDelayRequired()) {
-            addSrcPointerField(Uint32::getInstance(), "srcSpkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr");
+            addSrcPointerField(Uint32, "srcSpkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr");
        }

        // If postsynaptic population has delay buffers
        if(getArchetype().getTrgNeuronGroup()->isDelayRequired()) {
-            addTrgPointerField(Uint32::getInstance(), "trgSpkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr");
+            addTrgPointerField(Uint32, "trgSpkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr");
        }

        // Add heterogeneous presynaptic neuron model parameters
@@ -763,7 +767,8 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon
        for(const auto &v : preVars) {
            // If variable is referenced in code string, add source pointer
            if(code.find("$(" + v.name + "_pre)") != std::string::npos) {
-                addSrcPointerField(v.type, v.name + "Pre", backend.getDeviceVarPrefix() + v.name);
+                addSrcPointerField(v.type.resolve(getTypeContext()), v.name + "Pre",
+                                   backend.getDeviceVarPrefix() + v.name);
            }
        }
"Pre", + backend.getDeviceVarPrefix() + v.name); } } @@ -772,7 +777,8 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon for(const auto &v : postVars) { // If variable is referenced in code string, add target pointer if(code.find("$(" + v.name + "_post)") != std::string::npos) { - addTrgPointerField(v.type, v.name + "Post", backend.getDeviceVarPrefix() + v.name); + addTrgPointerField(v.type.resolve(getTypeContext()), v.name + "Post", + backend.getDeviceVarPrefix() + v.name); } } @@ -781,7 +787,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon for(const auto &e : preEGPs) { if(code.find("$(" + e.name + "_pre)") != std::string::npos) { const std::string prefix = backend.getDeviceVarPrefix(); - addField(e.type->getPointerType(), e.name + "Pre", + addField(e.type.resolve(getTypeContext()).createPointer(), e.name + "Pre", [e, prefix](const auto &sg, size_t) { return prefix + e.name + sg.getSrcNeuronGroup()->getName(); }, GroupMergedFieldType::DYNAMIC); } @@ -792,7 +798,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon for(const auto &e : postEGPs) { if(code.find("$(" + e.name + "_post)") != std::string::npos) { const std::string prefix = backend.getDeviceVarPrefix(); - addField(e.type->getPointerType(), e.name + "Post", + addField(e.type.resolve(getTypeContext()).createPointer(), e.name + "Post", [e, prefix](const auto &sg, size_t) { return prefix + e.name + sg.getTrgNeuronGroup()->getName(); }, GroupMergedFieldType::DYNAMIC); } @@ -832,14 +838,14 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon // Add presynaptic variables to struct for(const auto &v : wum->getPreVars()) { const std::string prefix = backend.getDeviceVarPrefix() + v.name; - addField(v.type->getPointerType(), v.name, + addField(v.type.resolve(getTypeContext()).createPointer(), v.name, [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPreVarSuffix(); }); } // Add presynaptic variables to struct for(const auto &v : wum->getPostVars()) { const std::string prefix = backend.getDeviceVarPrefix() + v.name; - addField(v.type->getPointerType(), v.name, + addField(v.type.resolve(getTypeContext()).createPointer(), v.name, [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPostVarSuffix(); }); } @@ -849,19 +855,19 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon // Add pointers to connectivity data if(getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - addPointerField("rowLength", backend.getDeviceVarPrefix() + "rowLength"); + addPointerField(Uint32, "rowLength", backend.getDeviceVarPrefix() + "rowLength"); addPointerField(getArchetype().getSparseIndType(), "ind", backend.getDeviceVarPrefix() + "ind"); // Add additional structure for postsynaptic access if(backend.isPostsynapticRemapRequired() && !wum->getLearnPostCode().empty() && (role == Role::PostsynapticUpdate || role == Role::SparseInit)) { - addPointerField("colLength", backend.getDeviceVarPrefix() + "colLength"); - addPointerField("remap", backend.getDeviceVarPrefix() + "remap"); + addPointerField(Uint32, "colLength", backend.getDeviceVarPrefix() + "colLength"); + addPointerField(Uint32, "remap", backend.getDeviceVarPrefix() + "remap"); } } else if(getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { - addPointerField("gp", backend.getDeviceVarPrefix() + "gp"); + addPointerField(Uint32, "gp", backend.getDeviceVarPrefix() + "gp"); } // If we're 
updating a group with procedural connectivity or initialising connectivity @@ -934,8 +940,8 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon for(size_t d = 0; d < getArchetype().getKernelSize().size(); d++) { // If this dimension has a heterogeneous size, add it to struct if(isKernelSizeHeterogeneous(d)) { - addField("kernelSize" + std::to_string(d), - [d](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getKernelSize().at(d)); }); + addField(Uint32, "kernelSize" + std::to_string(d), + [d](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getKernelSize().at(d)); }); } } } @@ -964,7 +970,8 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon // If we're performing an update with individual weights; or this variable should be initialised if((updateRole && individualWeights) || (kernelWeights && updateRole) || varInitRequired) { - addPointerField(var.type, var.name, backend.getDeviceVarPrefix() + var.name); + addPointerField(var.type.resolve(getTypeContext()), var.name, + backend.getDeviceVarPrefix() + var.name); } // If we're performing a procedural update or this variable should be initialised, add any var init EGPs to structure @@ -972,7 +979,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon const auto egps = snippet->getExtraGlobalParams(); for(const auto &e : egps) { const std::string prefix = backend.getDeviceVarPrefix(); - addField(e.type->getPointerType(), e.name + var.name, + addField(e.type.resolve(getTypeContext()).createPointer(), e.name + var.name, [e, prefix, var](const SynapseGroupInternal &sg, size_t) { return prefix + e.name + var.name + sg.getName(); @@ -1097,24 +1104,28 @@ boost::uuids::detail::sha1::digest_type SynapseGroupMergedBase::getHashDigest(Ro return hash.get_digest(); } //---------------------------------------------------------------------------- -void SynapseGroupMergedBase::addPSPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) +void SynapseGroupMergedBase::addPSPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix) { - addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPSVarSuffix(); }); + assert(type.isValue()); + addField(type.createPointer(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPSVarSuffix(); }); } //---------------------------------------------------------------------------- -void SynapseGroupMergedBase::addPreOutputPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) +void SynapseGroupMergedBase::addPreOutputPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix) { - addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPreOutputSuffix(); }); + assert(type.isValue()); + addField(type.createPointer(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPreOutputSuffix(); }); } //---------------------------------------------------------------------------- -void SynapseGroupMergedBase::addSrcPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) +void SynapseGroupMergedBase::addSrcPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix) { - 
addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getSrcNeuronGroup()->getName(); }); + assert(type.isValue()); + addField(type.createPointer(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getSrcNeuronGroup()->getName(); }); } //---------------------------------------------------------------------------- -void SynapseGroupMergedBase::addTrgPointerField(const Type::NumericBase *type, const std::string &name, const std::string &prefix) +void SynapseGroupMergedBase::addTrgPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix) { - addField(type->getPointerType(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getTrgNeuronGroup()->getName(); }); + assert(type.isValue()); + addField(type.createPointer(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getTrgNeuronGroup()->getName(); }); } //---------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index f66d70fd18..186a3008b0 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -28,31 +28,6 @@ namespace Utils = GeNN::Utils; //--------------------------------------------------------------------------- namespace { -std::string getDescription(const Type::ResolvedType &type) -{ - const std::string qualifier = type.hasQualifier(Type::Qualifier::CONSTANT) ? "const " : ""; - return std::visit( - Utils::Overload{ - [&qualifier](const Type::ResolvedType::Value &value) - { - assert(value.numeric); - return qualifier + value.name; - }, - [&qualifier, &type](const Type::ResolvedType::Pointer &pointer) - { - return qualifier + getDescription(*pointer.valueType) + "*"; - }, - [&type](const Type::ResolvedType::Function &function) - { - std::string description = getDescription(*function.returnType) + "("; - for (const auto &a : function.argTypes) { - description += (getDescription(a) + ","); - } - return description + ")"; - }}, - type.detail); -} -//--------------------------------------------------------------------------- bool checkPointerTypeAssignement(const Type::ResolvedType &rightType, const Type::ResolvedType &leftType) { return std::visit( @@ -239,7 +214,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto indexType = evaluateType(arraySubscript.getIndex()); if (!indexType.isNumeric() || !indexType.getNumeric().isIntegral) { m_ErrorHandler.error(arraySubscript.getClosingSquareBracket(), - "Invalid subscript index type '" + getDescription(indexType) + "'"); + "Invalid subscript index type '" + indexType.getName() + "'"); throw TypeCheckError(); } @@ -265,7 +240,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // Otherwise, if implicit conversion fails, give error else if (!checkImplicitConversion(rightType, leftType, assignment.getOperator().type)) { - m_ErrorHandler.error(assignment.getOperator(), "Invalid operand types '" + getDescription(leftType) + "' and '" + getDescription(rightType)); + m_ErrorHandler.error(assignment.getOperator(), "Invalid operand types '" + leftType.getName() + "' and '" + rightType.getName()); throw TypeCheckError(); } @@ -364,7 +339,7 @@ class Visitor : public Expression::Visitor, public 
Statement::Visitor setExpressionType(&binary, *resultType); } else { - m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + getDescription(leftType) + "' and '" + getDescription(rightType)); + m_ErrorHandler.error(binary.getOperator(), "Invalid operand types '" + leftType.getName() + "' and '" + rightType.getName()); throw TypeCheckError(); } } @@ -401,7 +376,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If const is being removed if (!checkForConstRemoval(rightType, cast.getType())) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + getDescription(cast.getType()) + "' and '" + getDescription(rightType)); + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType().getName() + "' and '" + rightType.getName()); throw TypeCheckError(); } @@ -439,7 +414,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor setExpressionType(&cast, *resultType); } else { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + getDescription(cast.getType()) + "' and '" + getDescription(rightType)); + m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType().getName() + "' and '" + rightType.getName()); throw TypeCheckError(); } } @@ -460,7 +435,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } else { m_ErrorHandler.error(conditional.getQuestion(), - "Invalid operand types '" + getDescription(trueType) + "' and '" + getDescription(falseType) + "' to conditional"); + "Invalid operand types '" + trueType.getName() + "' and '" + falseType.getName() + "' to conditional"); throw TypeCheckError(); } } @@ -473,7 +448,6 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Literal &literal) final { // Convert number token type to type - // **THINK** is it better to use typedef for scalar or resolve from m_Context if (literal.getValue().type == Token::Type::DOUBLE_NUMBER) { setExpressionType(&literal, Type::Double); } @@ -481,10 +455,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor setExpressionType(&literal, Type::Float); } else if (literal.getValue().type == Token::Type::SCALAR_NUMBER) { - // **TODO** cache - assert(false); - // **THINK** why not resolve here? 
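For illustration: after this change, a SCALAR_NUMBER literal is typed by looking the name "scalar" up in the type context rather than through the removed NumericTypedef machinery. A minimal sketch of how such a context might be populated, assuming only the Type::TypeContext typedef and the Type::Float/Type::Double constants used elsewhere in this patch series; the makeTypeContext helper itself is hypothetical:

    #include "type.h"

    // Hypothetical helper: map the model's "scalar" name to a concrete numeric
    // type so that SCALAR_NUMBER literals resolve via m_Context.at("scalar")
    GeNN::Type::TypeContext makeTypeContext(bool doublePrecision)
    {
        GeNN::Type::TypeContext context;
        context.emplace("scalar", doublePrecision ? GeNN::Type::Double : GeNN::Type::Float);
        return context;
    }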
- //setExpressionType(&literal, new Type::NumericTypedef("scalar")); + setExpressionType(&literal, m_Context.at("scalar")); } else if (literal.getValue().type == Token::Type::INT32_NUMBER) { setExpressionType(&literal, Type::Int32); @@ -493,7 +464,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor setExpressionType(&literal, Type::Uint32); } else if(literal.getValue().type == Token::Type::STRING) { - setExpressionType(&literal, Type::ResolvedType::createPointer(Type::Int8, Type::Qualifier::CONSTANT)); + setExpressionType(&literal, Type::Int8.createPointer(Type::Qualifier::CONSTANT)); } else { assert(false); @@ -648,7 +619,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } else { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + getDescription(rightType) + "'"); + "Invalid operand type '" + rightType.getName() + "'"); throw TypeCheckError(); } } @@ -668,7 +639,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } else { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + getDescription(rightType) + "'"); + "Invalid operand type '" + rightType.getName() + "'"); throw TypeCheckError(); } } @@ -678,12 +649,12 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } // Otherwise, if operator is address of, return pointer type else if (unary.getOperator().type == Token::Type::AMPERSAND) { - setExpressionType(&unary, Type::ResolvedType::createPointer(rightType)); + setExpressionType(&unary, rightType.createPointer()); } } else { m_ErrorHandler.error(unary.getOperator(), - "Invalid operand type '" + getDescription(rightType) + "'"); + "Invalid operand type '" + rightType.getName() + "'"); throw TypeCheckError(); } } @@ -784,7 +755,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto valType = evaluateType(labelled.getValue()); if (!valType.isNumeric() || !valType.getNumeric().isIntegral) { m_ErrorHandler.error(labelled.getKeyword(), - "Invalid case value '" + getDescription(valType) + "'"); + "Invalid case value '" + valType.getName() + "'"); throw TypeCheckError(); } } @@ -797,7 +768,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor auto condType = evaluateType(switchStatement.getCondition()); if (!condType.isNumeric() || !condType.getNumeric().isIntegral) { m_ErrorHandler.error(switchStatement.getSwitch(), - "Invalid condition '" + getDescription(condType) + "'"); + "Invalid condition '" + condType.getName() + "'"); throw TypeCheckError(); } @@ -817,7 +788,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor if (std::get<1>(var)) { const auto initialiserType = evaluateType(std::get<1>(var).get()); if (!checkImplicitConversion(initialiserType, decType)) { - m_ErrorHandler.error(std::get<0>(var), "Invalid operand types '" + getDescription(decType) + "' and '" + getDescription(initialiserType)); + m_ErrorHandler.error(std::get<0>(var), "Invalid operand types '" + decType.getName() + "' and '" + initialiserType.getName()); } } } diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index 9edde1855e..59e974cb91 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -6,6 +6,7 @@ #include // GeNN includes +#include "gennUtils.h" #include "logging.h" // Transpiler includes @@ -57,15 +58,57 @@ const std::map unsignedType{ } // Anonymous namespace //---------------------------------------------------------------------------- -// GeNN::Type +// GeNN::Type::ResolvedType 
//---------------------------------------------------------------------------- namespace GeNN::Type { - +std::string ResolvedType::getName() const +{ + const std::string qualifier = hasQualifier(Type::Qualifier::CONSTANT) ? "const " : ""; + return std::visit( + Utils::Overload{ + [&qualifier](const Type::ResolvedType::Value &value) + { + assert(value.numeric); + return qualifier + value.name; + }, + [&qualifier](const Type::ResolvedType::Pointer &pointer) + { + return qualifier + pointer.valueType->getName() + "*"; + }, + [&qualifier](const Type::ResolvedType::Function &function) + { + std::string description = qualifier + function.returnType->getName() + "("; + for (const auto &a : function.argTypes) { + description += (a.getName() + ","); + } + return description + ")"; + }}, + detail); +} +//---------------------------------------------------------------------------- +size_t ResolvedType::getSize(size_t pointerBytes) const +{ + return std::visit( + Utils::Overload{ + [](const Type::ResolvedType::Value &value) + { + return value.size; + }, + [pointerBytes](const Type::ResolvedType::Pointer&) + { + return pointerBytes; + }, + [](const Type::ResolvedType::Function&) + { + throw std::runtime_error("Function types do not have size"); + }}, + detail); +} //---------------------------------------------------------------------------- -// UnresolvedType +// GeNN::Type::UnresolvedType //---------------------------------------------------------------------------- -ResolvedType UnresolvedType::resolve(const std::unordered_map &typeContext) const +ResolvedType UnresolvedType::resolve(const TypeContext &typeContext) const { return std::visit( Utils::Overload{ From f52fe2aa25c82773ef22c9d4a717f5952ee9af6a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 25 May 2023 14:26:55 +0100 Subject: [PATCH 160/725] fixed lots of compiler errors --- .../genn/genn/code_generator/backendSIMT.h | 10 +- .../genn/genn/code_generator/environment.h | 6 +- .../genn/genn/code_generator/groupMerged.h | 13 ++- .../groupMergedTypeEnvironment.h | 71 ++---------- .../genn/code_generator/modelSpecMerged.h | 2 +- include/genn/genn/transpiler/expression.h | 60 +++++------ include/genn/genn/transpiler/prettyPrinter.h | 2 +- .../genn/genn/transpiler/standardLibrary.h | 11 +- include/genn/genn/transpiler/statement.h | 50 ++++----- include/genn/genn/type.h | 5 + src/genn/genn/code_generator/backendSIMT.cc | 36 +++---- src/genn/genn/code_generator/environment.cc | 4 +- src/genn/genn/code_generator/groupMerged.cc | 4 +- .../genn/code_generator/initGroupMerged.cc | 102 +++++++++--------- .../code_generator/neuronUpdateGroupMerged.cc | 20 ++-- .../synapseUpdateGroupMerged.cc | 6 +- src/genn/genn/transpiler/parser.cc | 4 +- src/genn/genn/transpiler/standardLibrary.cc | 53 ++++----- 18 files changed, 194 insertions(+), 265 deletions(-) diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index f742e259cc..c1061187d4 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -92,7 +92,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase virtual std::string getCLZ() const = 0; //! 
Get name of atomic operation - virtual std::string getAtomic(const Type::NumericBase *type, const Type::TypeContext &typeContext, + virtual std::string getAtomic(const Type::ResolvedType &type, AtomicOperation op = AtomicOperation::ADD, AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const = 0; @@ -160,14 +160,6 @@ class GENN_EXPORT BackendSIMT : public BackendBase size_t getPaddedNumCustomUpdateWUThreads(const CustomUpdateWUInternal &cg, unsigned int batchSize) const; size_t getPaddedNumCustomUpdateTransposeWUThreads(const CustomUpdateWUInternal &cg, unsigned int batchSize) const; - //! Helper to get name of atomic operation - template - std::string getAtomic(const Type::TypeContext &typeContext, - AtomicOperation op = AtomicOperation::ADD, - AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const - { - return getAtomic(T::getInstance(), typeContext, op, memSpace); - } //-------------------------------------------------------------------------- // Static API diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 121ed4d1e8..621e55cd12 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -50,7 +50,7 @@ class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase CodeStream &getContextStream() const; - std::string getContextName(const std::string &name, const Type::ResolvedType &type) const; + std::string getContextName(const std::string &name, std::optional type) const; private: //------------------------------------------------------------------------ @@ -88,7 +88,7 @@ class EnvironmentSubstitute : public EnvironmentExternal //------------------------------------------------------------------------ // PrettyPrinter::EnvironmentBase virtuals //------------------------------------------------------------------------ - virtual std::string getName(const std::string &name, const Type::ResolvedType &type) final; + virtual std::string getName(const std::string &name, std::optional type = std::nullopt) final; virtual CodeStream &getStream() final { @@ -230,7 +230,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal //------------------------------------------------------------------------ // PrettyPrinter::EnvironmentBase virtuals //------------------------------------------------------------------------ - virtual std::string getName(const std::string &name, const Type::ResolvedType &type) final + virtual std::string getName(const std::string &name, std::optional type = std::nullopt) final { // If variable with this name isn't found, try and get name from context auto var = m_VariablesReferenced.find(name); diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 740fc11474..4a16ac6892 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -70,8 +70,8 @@ class GroupMerged : m_Index(index), m_TypeContext(typeContext), m_Groups(std::move(groups)) {} - GroupMerged(const GroupMerged&) = delete; - GroupMerged(GroupMerged&&) = default; + //GroupMerged(const GroupMerged&) = delete; + //GroupMerged(GroupMerged&&) = default; //------------------------------------------------------------------------ // Public API @@ -164,7 +164,7 @@ class GroupMerged const auto sortedFields = getSortedFields(backend); for(const auto &f : sortedFields) { // Add size of field to total - const size_t fieldSize = std::get<0>(f)->getSizeBytes(m_TypeContext); + 
const size_t fieldSize = std::get<0>(f).getSize(backend.getPointerBytes()); structSize += fieldSize; // Update largest field size @@ -272,6 +272,11 @@ class GroupMerged [prefix](const G &g, size_t) { return prefix + g.getName(); }); } + void addPointerField(const Type::UnresolvedType &type, const std::string &name, const std::string &prefix) + { + addPointerField(type.resolve(getTypeContext()), name, prefix); + } + void addVars(const Models::Base::VarVec &vars, const std::string &arrayPrefix) { @@ -532,7 +537,7 @@ class GroupMerged //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - const size_t m_Index; + size_t m_Index; const Type::TypeContext &m_TypeContext; std::string m_MemorySpace; std::vector m_Fields; diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index 735490b02a..adf62a1f4f 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -32,59 +32,13 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa //--------------------------------------------------------------------------- // EnvironmentBase virtuals //--------------------------------------------------------------------------- - virtual void define(const Transpiler::Token &name, const Type::Base*, ErrorHandlerBase &errorHandler) final + virtual void define(const Transpiler::Token &name, const Type::ResolvedType&, ErrorHandlerBase &errorHandler) final { errorHandler.error(name, "Cannot declare variable in external environment"); throw TypeCheckError(); } - virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler, - bool initializer) final - { - // If type isn't found - auto existingType = m_Types.find(name.lexeme); - if(existingType == m_Types.end()) { - if(m_Enclosing) { - return m_Enclosing->assign(name, op, assignedType, - context, errorHandler, initializer); - } - else { - errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); - } - } - - // Add field to merged group if required - addField(existingType->second); - - // Perform standard type-checking logicGroupMergedTypeEnvironment - return EnvironmentBase::assign(name, op, existingType->second.first, assignedType, - context, errorHandler, initializer); - } - - virtual const Type::Base *incDec(const Token &name, Token::Type op, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler) final - { - auto existingType = m_Types.find(name.lexeme); - if(existingType == m_Types.end()) { - if(m_Enclosing) { - return m_Enclosing->incDec(name, op, context, errorHandler); - } - else { - errorHandler.error(name, "Undefined variable"); - throw TypeCheckError(); - } - } - - // Add field to merged group if required - addField(existingType->second); - - // Perform standard type-checking logic - return EnvironmentBase::incDec(name, op, existingType->second.first, errorHandler); - } - - virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final + virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final { auto type = m_Types.find(name.lexeme); if(type == m_Types.end()) { @@ -107,21 +61,16 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa 
//--------------------------------------------------------------------------- // Public API //--------------------------------------------------------------------------- - void defineField(const Type::Base *type, const std::string &name) + void defineField(const Type::ResolvedType &type, const std::string &name) { if(!m_Types.try_emplace(name, type, std::nullopt).second) { throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); } } - template - void defineField(const std::string &name) - { - defineField(T::getInstance(), name); - } - void defineField(const Type::Base *type, const std::string &name, - const Type::Base *fieldType, std::string_view fieldName, typename G::GetFieldValueFunc getFieldValue, + void defineField(const Type::ResolvedType &type, const std::string &name, + const Type::ResolvedType &fieldType, std::string_view fieldName, typename G::GetFieldValueFunc getFieldValue, GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD) { if(!m_Types.try_emplace(name, std::piecewise_construct, @@ -132,16 +81,16 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa } } - void definePointerField(const Type::NumericBase *type, const std::string &name,const std::string &prefix, VarAccessMode access) + void definePointerField(const Type::ResolvedType &type, const std::string &name,const std::string &prefix, VarAccessMode access) { - const auto *qualifiedType = (access & VarAccessModeAttribute::READ_ONLY) ? type->getQualifiedType(Type::Qualifier::CONSTANT) : type; + const auto *qualifiedType = type.addQualifier((access & VarAccessModeAttribute::READ_ONLY) ? Type::Qualifier::CONSTANT : Type::Qualifier{0}); defineField(qualifiedType, name, type->getPointerType(), name, [prefix](const auto &g, size_t) { return prefix + g.getName(); }); } void defineScalarField(const std::string &name, typename G::GetFieldDoubleValueFunc getFieldValue) { - defineField(m_GroupMerged.getScalarType()->getQualifiedType(Type::Qualifier::CONSTANT), name, + defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), name, m_GroupMerged.getScalarType(), name, [getFieldValue, this](const auto &g, size_t i) { @@ -231,7 +180,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa //--------------------------------------------------------------------------- // Private methods //--------------------------------------------------------------------------- - void addField(std::pair> &type) + void addField(std::pair> &type) { // If this type has an associated field if (type.second) { @@ -252,6 +201,6 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa G &m_GroupMerged; EnvironmentBase *m_Enclosing; - std::unordered_map>> m_Types; + std::unordered_map>> m_Types; }; } // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 077590d182..87c80cb1de 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -454,6 +454,6 @@ class GENN_EXPORT ModelSpecMerged MergedEGPMap m_MergedEGPs; //! 
Type context used to resolve all types used in model - const Type::TypeContext m_TypeContext; + Type::TypeContext m_TypeContext; }; } // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h index 550f453a34..7ec92e94c5 100644 --- a/include/genn/genn/transpiler/expression.h +++ b/include/genn/genn/transpiler/expression.h @@ -97,9 +97,9 @@ class ArraySubscript : public Acceptable const Base *getIndex() const { return m_Index.get(); } private: - const ExpressionPtr m_Array; - const Token m_ClosingSquareBracket; - const ExpressionPtr m_Index; + ExpressionPtr m_Array; + Token m_ClosingSquareBracket; + ExpressionPtr m_Index; }; //--------------------------------------------------------------------------- @@ -117,9 +117,9 @@ class Assignment : public Acceptable const Base *getValue() const { return m_Value.get(); } private: - const ExpressionPtr m_Assignee; - const Token m_Operator; - const ExpressionPtr m_Value; + ExpressionPtr m_Assignee; + Token m_Operator; + ExpressionPtr m_Value; }; //--------------------------------------------------------------------------- @@ -137,9 +137,9 @@ class Binary : public Acceptable const Base *getRight() const { return m_Right.get(); } private: - const ExpressionPtr m_Left; - const Token m_Operator; - const ExpressionPtr m_Right; + ExpressionPtr m_Left; + Token m_Operator; + ExpressionPtr m_Right; }; //--------------------------------------------------------------------------- @@ -157,9 +157,9 @@ class Call : public Acceptable const ExpressionList &getArguments() const { return m_Arguments; } private: - const ExpressionPtr m_Callee; - const Token m_ClosingParen; - const ExpressionList m_Arguments; + ExpressionPtr m_Callee; + Token m_ClosingParen; + ExpressionList m_Arguments; }; //--------------------------------------------------------------------------- @@ -177,9 +177,9 @@ class Cast : public Acceptable const Token &getClosingParen() const { return m_ClosingParen; } private: - const Type::ResolvedType m_Type; - const ExpressionPtr m_Expression; - const Token m_ClosingParen; + Type::ResolvedType m_Type; + ExpressionPtr m_Expression; + Token m_ClosingParen; }; //--------------------------------------------------------------------------- @@ -198,10 +198,10 @@ class Conditional : public Acceptable const Base *getFalse() const { return m_False.get(); } private: - const ExpressionPtr m_Condition; - const Token m_Question; - const ExpressionPtr m_True; - const ExpressionPtr m_False; + ExpressionPtr m_Condition; + Token m_Question; + ExpressionPtr m_True; + ExpressionPtr m_False; }; //--------------------------------------------------------------------------- @@ -222,7 +222,7 @@ class Grouping : public Acceptable const Base *getExpression() const { return m_Expression.get(); } private: - const ExpressionPtr m_Expression; + ExpressionPtr m_Expression; }; //--------------------------------------------------------------------------- @@ -261,9 +261,9 @@ class Logical : public Acceptable const Base *getRight() const { return m_Right.get(); } private: - const ExpressionPtr m_Left; - const Token m_Operator; - const ExpressionPtr m_Right; + ExpressionPtr m_Left; + Token m_Operator; + ExpressionPtr m_Right; }; //--------------------------------------------------------------------------- @@ -280,8 +280,8 @@ class PostfixIncDec : public Acceptable const Token &getOperator() const { return m_Operator; } private: - const ExpressionPtr m_Target; - const Token m_Operator; + ExpressionPtr m_Target; + Token 
m_Operator; }; //--------------------------------------------------------------------------- @@ -298,8 +298,8 @@ class PrefixIncDec : public Acceptable const Token &getOperator() const { return m_Operator; } private: - const ExpressionPtr m_Target; - const Token m_Operator; + ExpressionPtr m_Target; + Token m_Operator; }; //--------------------------------------------------------------------------- @@ -320,7 +320,7 @@ class Variable : public Acceptable const Token &getName() const { return m_Name; } private: - const Token m_Name; + Token m_Name; }; //--------------------------------------------------------------------------- @@ -337,7 +337,7 @@ class Unary : public Acceptable const Base *getRight() const { return m_Right.get(); } private: - const Token m_Operator; - const ExpressionPtr m_Right; + Token m_Operator; + ExpressionPtr m_Right; }; } // namespace GeNN::Transpiler::Expression diff --git a/include/genn/genn/transpiler/prettyPrinter.h b/include/genn/genn/transpiler/prettyPrinter.h index 3ad5e9cd49..34fbda05f5 100644 --- a/include/genn/genn/transpiler/prettyPrinter.h +++ b/include/genn/genn/transpiler/prettyPrinter.h @@ -31,7 +31,7 @@ class EnvironmentBase virtual std::string define(const std::string &name) = 0; //! Get the name to use in code for the variable named by token - virtual std::string getName(const std::string &name, const Type::ResolvedType &type) = 0; + virtual std::string getName(const std::string &name, std::optional type = std::nullopt) = 0; //! Get stream to write code within this environment to virtual CodeGenerator::CodeStream &getStream() = 0; diff --git a/include/genn/genn/transpiler/standardLibrary.h b/include/genn/genn/transpiler/standardLibrary.h index 7073c48146..d6e8598e17 100644 --- a/include/genn/genn/transpiler/standardLibrary.h +++ b/include/genn/genn/transpiler/standardLibrary.h @@ -24,13 +24,8 @@ class FunctionTypes : public TypeChecker::EnvironmentBase //------------------------------------------------------------------------ // EnvironmentBase virtuals //------------------------------------------------------------------------ - virtual void define(const Token &name, const Type::Base *type, ErrorHandlerBase &errorHandler) final; - virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler, - bool initializer = false) final; - virtual const Type::Base *incDec(const Token &name, Token::Type op, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler) final; - virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final; + virtual void define(const Token &name, const Type::ResolvedType &type, ErrorHandlerBase &errorHandler) final; + virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final; }; //--------------------------------------------------------------------------- @@ -46,7 +41,7 @@ class FunctionEnvironment : public CodeGenerator::EnvironmentExternal //------------------------------------------------------------------------ // PrettyPrinter::EnvironmentBase virtuals //------------------------------------------------------------------------ - virtual std::string getName(const std::string &name, const Type::Base *type) final; + virtual std::string getName(const std::string &name, const Type::ResolvedType &type) final; virtual CodeGenerator::CodeStream &getStream() final; }; } // namespace GeNN::Transpiler::StandardLibrary diff --git a/include/genn/genn/transpiler/statement.h 
b/include/genn/genn/transpiler/statement.h index f36fa1264d..964ade6a9a 100644 --- a/include/genn/genn/transpiler/statement.h +++ b/include/genn/genn/transpiler/statement.h @@ -88,7 +88,7 @@ class Break : public Acceptable const Token &getToken() const { return m_Token; } private: - const Token m_Token; + Token m_Token; }; //--------------------------------------------------------------------------- @@ -104,7 +104,7 @@ class Compound : public Acceptable const StatementList &getStatements() const { return m_Statements; } private: - const StatementList m_Statements; + StatementList m_Statements; }; //--------------------------------------------------------------------------- @@ -120,7 +120,7 @@ class Continue : public Acceptable const Token &getToken() const { return m_Token; } private: - const Token m_Token; + Token m_Token; }; //--------------------------------------------------------------------------- @@ -138,8 +138,8 @@ class Do : public Acceptable const Base *getBody() const { return m_Body.get(); } private: - const ExpressionPtr m_Condition; - const StatementPtr m_Body; + ExpressionPtr m_Condition; + StatementPtr m_Body; }; //--------------------------------------------------------------------------- @@ -156,7 +156,7 @@ class Expression : public Acceptable const ExpressionPtr::element_type *getExpression() const { return m_Expression.get(); } private: - const ExpressionPtr m_Expression; + ExpressionPtr m_Expression; }; //--------------------------------------------------------------------------- @@ -176,10 +176,10 @@ class For : public Acceptable const Base *getBody() const { return m_Body.get(); } private: - const StatementPtr m_Initialiser; - const ExpressionPtr m_Condition; - const ExpressionPtr m_Increment; - const StatementPtr m_Body; + StatementPtr m_Initialiser; + ExpressionPtr m_Condition; + ExpressionPtr m_Increment; + StatementPtr m_Body; }; //--------------------------------------------------------------------------- @@ -198,9 +198,9 @@ class If : public Acceptable const Base *getElseBranch() const { return m_ElseBranch.get(); } private: - const ExpressionPtr m_Condition; - const StatementPtr m_ThenBranch; - const StatementPtr m_ElseBranch; + ExpressionPtr m_Condition; + StatementPtr m_ThenBranch; + StatementPtr m_ElseBranch; }; //--------------------------------------------------------------------------- @@ -219,9 +219,9 @@ class Labelled : public Acceptable const Base *getBody() const { return m_Body.get(); } private: - const Token m_Keyword; - const ExpressionPtr m_Value; - const StatementPtr m_Body; + Token m_Keyword; + ExpressionPtr m_Value; + StatementPtr m_Body; }; @@ -241,9 +241,9 @@ class Switch : public Acceptable const Base *getBody() const { return m_Body.get(); } private: - const Token m_Switch; - const ExpressionPtr m_Condition; - const StatementPtr m_Body; + Token m_Switch; + ExpressionPtr m_Condition; + StatementPtr m_Body; }; @@ -263,9 +263,9 @@ class VarDeclaration : public Acceptable const InitDeclaratorList &getInitDeclaratorList() const { return m_InitDeclaratorList; } private: - const Type::ResolvedType m_Type; - const std::vector m_DeclarationSpecifiers; - const InitDeclaratorList m_InitDeclaratorList; + Type::ResolvedType m_Type; + std::vector m_DeclarationSpecifiers; + InitDeclaratorList m_InitDeclaratorList; }; //--------------------------------------------------------------------------- @@ -283,8 +283,8 @@ class While : public Acceptable const Base *getBody() const { return m_Body.get(); } private: - const ExpressionPtr m_Condition; - const 
StatementPtr m_Body; + ExpressionPtr m_Condition; + StatementPtr m_Body; }; //--------------------------------------------------------------------------- @@ -302,6 +302,6 @@ class Print : public Acceptable const ExpressionPtr::element_type *getExpression() const { return m_Expression.get(); } private: - const ExpressionPtr m_Expression; + ExpressionPtr m_Expression; }; } // namespace GeNN::Transpiler::Statement diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 8b2052a599..514b502abf 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -266,6 +266,11 @@ struct ResolvedType { return ResolvedType{Value{name, sizeof(T), std::nullopt}, qualifiers}; } + + static ResolvedType createFunction(const ResolvedType &returnType, const std::vector &argTypes) + { + return ResolvedType{Function{returnType, argTypes}, Qualifier{0}}; + } }; typedef std::unordered_map TypeContext; diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 0543ed30f4..b3028171a3 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -523,7 +523,7 @@ void BackendSIMT::genNeuronUpdateKernel(CodeStream &os, const Substitutions &ker os << "if (shSpkEvntCount > 0)"; { CodeStream::Scope b(os); - os << "shPosSpkEvnt = " << getAtomic(modelMerged.getTypeContext()) << "(&group->spkCntEvnt"; + os << "shPosSpkEvnt = " << getAtomic(Type::Uint32) << "(&group->spkCntEvnt"; if(ng.getArchetype().isDelayRequired()) { os << "[*group->spkQuePtr"; if(batchSize > 1) { @@ -546,7 +546,7 @@ void BackendSIMT::genNeuronUpdateKernel(CodeStream &os, const Substitutions &ker os << "if (shSpkCount > 0)"; { CodeStream::Scope b(os); - os << "shPosSpk = " << getAtomic(modelMerged.getTypeContext()) << "(&group->spkCnt"; + os << "shPosSpk = " << getAtomic(Type::Uint32) << "(&group->spkCnt"; if(ng.getArchetype().isDelayRequired() && ng.getArchetype().isTrueSpikeRequired()) { os << "[*group->spkQuePtr"; if(batchSize > 1) { @@ -814,7 +814,7 @@ void BackendSIMT::genPostsynapticUpdateKernel(CodeStream &os, const Substitution if(sg.getArchetype().isPresynapticOutputRequired()) { synSubs.addFuncSubstitution("addToPre", 1, - getAtomic(modelMerged.getModel().getPrecision(), modelMerged.getTypeContext()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); + getAtomic(modelMerged.getModel().getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); } sg.generateSynapseUpdate(*this, os, modelMerged, synSubs); @@ -875,16 +875,16 @@ void BackendSIMT::genSynapseDynamicsKernel(CodeStream &os, const Substitutions & // If dendritic delay is required, always use atomic operation to update dendritic delay buffer // **TODO** once synapse dynamics gets refactored into update strategy classes, move the index building code elsewhere if(sg.getArchetype().isDendriticDelayRequired()) { - synSubs.addFuncSubstitution("addToInSynDelay", 2, getAtomic(modelMerged.getModel().getPrecision(), modelMerged.getTypeContext()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, synSubs["id_post"], "$(1)") + "], $(0))"); + synSubs.addFuncSubstitution("addToInSynDelay", 2, getAtomic(modelMerged.getModel().getPrecision()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, synSubs["id_post"], "$(1)") + "], $(0))"); } // Otherwise else { - synSubs.addFuncSubstitution("addToInSyn", 1, getAtomic(modelMerged.getModel().getPrecision(), modelMerged.getTypeContext()) + 
"(&group->inSyn[" + sg.getPostISynIndex(batchSize, synSubs["id_post"]) + "], $(0))"); + synSubs.addFuncSubstitution("addToInSyn", 1, getAtomic(modelMerged.getModel().getPrecision()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, synSubs["id_post"]) + "], $(0))"); } if(sg.getArchetype().isPresynapticOutputRequired()) { synSubs.addFuncSubstitution("addToPre", 1, - getAtomic(modelMerged.getModel().getPrecision(), modelMerged.getTypeContext()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); + getAtomic(modelMerged.getModel().getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); } sg.generateSynapseUpdate(*this, os, modelMerged, synSubs); @@ -937,7 +937,7 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, const ModelSpe // Loop through reduction targets and generate reduction for(const auto &r : reductionTargets) { - cuEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; + cuEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; } } @@ -980,7 +980,7 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, const ModelSpe // Loop through reduction targets and generate reduction for (const auto &r : reductionTargets) { - reductionEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; + reductionEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; } } @@ -989,7 +989,7 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, const ModelSpe for (unsigned int i = 16; i > 0; i /= 2) { for (const auto &r : reductionTargets) { cuEnv.getStream() << getReductionOperation("lr" + r.name, "__shfl_down_sync(0xFFFFFFFF, lr" + r.name + ", " + std::to_string(i) + ")", - r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; + r.access, r.type) << ";" << std::endl; } } @@ -1146,7 +1146,7 @@ void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, const ModelS if(cg.getArchetype().isBatchReduction()) { // Loop through reduction targets and generate reduction for(const auto &r : reductionTargets) { - cuEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type, modelMerged.getTypeContext()) << ";" << std::endl; + cuEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; } // End for loop through batches @@ -1564,7 +1564,7 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne } // Otherwise else { - kernelInit << "group->ind[(($(0)) * group->rowStride) + " << getAtomic(modelMerged.getTypeContext()) << +"(&group->rowLength[$(0)], 1)] = " << popSubs["id_post"] << ";"; + kernelInit << "group->ind[(($(0)) * group->rowStride) + " << getAtomic(Type::Uint32) << +"(&group->rowLength[$(0)], 1)] = " << popSubs["id_post"] << ";"; } } // Otherwise, if it's bitmask @@ -1575,12 +1575,12 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne // If there is row-building code in this snippet if(!snippet->getRowBuildCode().empty()) { kernelInit << "const " << indexType << " rowStartGID = " << popSubs["id"] << " * (" << indexType << ")group->rowStride;" << std::endl; - kernelInit << getAtomic(modelMerged.getTypeContext(), AtomicOperation::OR) << 
"(&group->gp[(rowStartGID + ($(0))) / 32], 0x80000000 >> ((rowStartGID + ($(0))) & 31));" << std::endl; + kernelInit << getAtomic(Type::Uint32, AtomicOperation::OR) << "(&group->gp[(rowStartGID + ($(0))) / 32], 0x80000000 >> ((rowStartGID + ($(0))) & 31));" << std::endl; } // Otherwise else { kernelInit << "const " << indexType << " colStartGID = " << popSubs["id"] << ";" << std::endl; - kernelInit << getAtomic(modelMerged.getTypeContext(), AtomicOperation::OR) << "(&group->gp[(colStartGID + (($(0)) * group->rowStride)) / 32], 0x80000000 >> ((colStartGID + (($(0)) * group->rowStride)) & 31));" << std::endl; + kernelInit << getAtomic(Type::Uint32, AtomicOperation::OR) << "(&group->gp[(colStartGID + (($(0)) * group->rowStride)) / 32], 0x80000000 >> ((colStartGID + (($(0)) * group->rowStride)) & 31));" << std::endl; } } } @@ -1645,7 +1645,7 @@ void BackendSIMT::genInitializeSparseKernel(CodeStream &os, const Substitutions // Generate sparse synapse variable initialisation code genSparseSynapseVarInit( os, modelMerged, sg, popSubs, sg.getArchetype().isWUVarInitRequired(), - [&modelMerged, this](CodeStream &os, const SynapseSparseInitGroupMerged &sg, Substitutions&) + [this](CodeStream &os, const SynapseSparseInitGroupMerged &sg, Substitutions&) { // If postsynaptic learning is required if(!sg.getArchetype().getWUModel()->getLearnPostCode().empty()) { @@ -1656,7 +1656,7 @@ void BackendSIMT::genInitializeSparseKernel(CodeStream &os, const Substitutions // Atomically increment length of column of connectivity associated with this target // **NOTE** this returns previous length i.e. where to insert new entry - os << "const unsigned int colLocation = " << getAtomic(modelMerged.getTypeContext()) << "(&group->colLength[postIndex], 1);" << std::endl; + os << "const unsigned int colLocation = " << getAtomic(Type::Uint32) << "(&group->colLength[postIndex], 1);" << std::endl; // From this calculate index into column-major matrix os << "const unsigned int colMajorIndex = (postIndex * group->colStride) + colLocation;" << std::endl; @@ -1711,16 +1711,16 @@ size_t BackendSIMT::padKernelSize(size_t size, Kernel kernel) const //-------------------------------------------------------------------------- void BackendSIMT::genEmitSpike(const ModelSpecMerged &modelMerged, CodeStream &os, const Substitutions &subs, const std::string &suffix, bool recordingEnabled) const { - os << "const unsigned int spk" << suffix << "Idx = " << getAtomic(modelMerged.getTypeContext(), AtomicOperation::ADD, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Count, 1);" << std::endl; + os << "const unsigned int spk" << suffix << "Idx = " << getAtomic(Type::Uint32, AtomicOperation::ADD, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Count, 1);" << std::endl; os << "shSpk" << suffix << "[spk" << suffix << "Idx] = " << subs["id"] << ";" << std::endl; // If recording is enabled, set bit in recording word if(recordingEnabled) { if(m_KernelBlockSizes[KernelNeuronUpdate] == 32) { - os << getAtomic(modelMerged.getTypeContext(), AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record, 1 << " << getThreadID() << ");" << std::endl; + os << getAtomic(Type::Uint32, AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record, 1 << " << getThreadID() << ");" << std::endl; } else { - os << getAtomic(modelMerged.getTypeContext(), AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record[" << getThreadID() << " / 32], 1 << (" << getThreadID() << " % 32));" << std::endl; + os << 
getAtomic(Type::Uint32, AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record[" << getThreadID() << " / 32], 1 << (" << getThreadID() << " % 32));" << std::endl; } } } diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index 9b0d9078ff..4fed4dd26d 100644 --- a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -28,7 +28,7 @@ CodeStream &EnvironmentExternal::getContextStream() const getContext()); } //---------------------------------------------------------------------------- -std::string EnvironmentExternal::getContextName(const std::string &name, const Type::ResolvedType &type) const +std::string EnvironmentExternal::getContextName(const std::string &name, std::optional type) const { return std::visit( Utils::Overload{ @@ -54,7 +54,7 @@ EnvironmentSubstitute::~EnvironmentSubstitute() getContextStream() << m_ContentsStream.str(); } //---------------------------------------------------------------------------- -std::string EnvironmentSubstitute::getName(const std::string &name, const Type::ResolvedType &type) +std::string EnvironmentSubstitute::getName(const std::string &name, std::optional type) { // If there isn't a substitution for this name, try and get name from context auto var = m_VarSubstitutions.find(name); diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index 519b7bff56..98b8833e69 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -244,7 +244,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeConte for(const auto &var : vars) { // If we're not initialising or if there is initialization code for this variable if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addPointerField(var.type.resolve(getTypeContext()), var.name, + addPointerField(var.type, var.name, backend.getDeviceVarPrefix() + var.name); } @@ -970,7 +970,7 @@ SynapseGroupMergedBase::SynapseGroupMergedBase(size_t index, const Type::TypeCon // If we're performing an update with individual weights; or this variable should be initialised if((updateRole && individualWeights) || (kernelWeights && updateRole) || varInitRequired) { - addPointerField(var.type.resolve(getTypeContext()), var.name, + addPointerField(var.type, var.name, backend.getDeviceVarPrefix() + var.name); } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 50230d8b5d..26c72c5d7f 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -86,7 +86,7 @@ void genInitNeuronVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, co (CodeStream &os, Substitutions &varInitSubs) { // Generate initial value into temporary variable - os << var.type->getName() << " initVal;" << std::endl; + os << var.type.resolve(modelMerged.getTypeContext()).getName() << " initVal;" << std::endl; varInitSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); varInitSubs.applyCheckUnreplaced(code, "initVar : " + var.name + "merged" + std::to_string(groupIndex)); @@ -106,7 +106,7 @@ void genInitNeuronVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, co (CodeStream &os, Substitutions &varInitSubs) { // Generate initial value into temporary variable - os << var.type->getName() << " initVal;" << std::endl; + os << 
var.type.resolve(modelMerged.getTypeContext()).getName() << " initVal;" << std::endl; varInitSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); varInitSubs.applyCheckUnreplaced(code, "initVar : " + var.name + "merged" + std::to_string(groupIndex)); @@ -164,7 +164,7 @@ void genInitWUVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, const "", "group->", var.name); // Generate initial value into temporary variable - os << var.type->getName() << " initVal;" << std::endl; + os << var.type.resolve(modelMerged.getTypeContext()).getName() << " initVal;" << std::endl; varSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); varSubs.applyCheckUnreplaced(code, "initVar : merged" + var.name + std::to_string(groupIndex)); @@ -444,7 +444,7 @@ void NeuronInitGroupMerged::generateWUVar(const BackendBase &backend, for(const auto &var : vars) { // Add pointers to state variable if(!varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(var.type->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), + addField(var.type.resolve(getTypeContext()).createPointer(), var.name + fieldPrefixStem + std::to_string(i), [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto&, size_t groupIndex) { const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); @@ -652,7 +652,7 @@ void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase&, "", "group->", var.name); // Generate initial value into temporary variable - os << var.type->getName() << " initVal;" << std::endl; + os << var.type.resolve(getTypeContext()).getName() << " initVal;" << std::endl; popSubs.addVarSubstitution("value", "initVal"); std::string code = varInit.getSnippet()->getCode(); //popSubs.applyCheckUnreplaced(code, "initVar : merged" + vars[k].name + std::to_string(sg.getIndex())); @@ -691,7 +691,7 @@ void SynapseConnectivityInitGroupMerged::genInitConnectivity(CodeStream &os, Sub popSubs.applyCheckUnreplaced(value, "initSparseConnectivity state var : merged" + std::to_string(getIndex())); //value = ensureFtype(value, ftype); - os << a.type->getName() << " " << a.name << " = " << value << ";" << std::endl; + os << a.type.resolve(getTypeContext()).getName() << " " << a.name << " = " << value << ";" << std::endl; } os << "while(true)"; { @@ -721,12 +721,12 @@ SynapseConnectivityHostInitGroupMerged::SynapseConnectivityHostInitGroupMerged(s using namespace Type; // **TODO** these could be generic - addField("numSrcNeurons", - [](const auto &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); - addField("numTrgNeurons", - [](const auto &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); }); - addField("rowStride", - [&backend](const auto &sg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(sg)); }); + addField(Uint32, "numSrcNeurons", + [](const auto &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); + addField(Uint32, "numTrgNeurons", + [](const auto &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); }); + addField(Uint32, "rowStride", + [&backend](const auto &sg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(sg)); }); // Add heterogeneous connectivity initialiser model parameters addHeterogeneousParams( @@ -743,7 +743,7 @@ SynapseConnectivityHostInitGroupMerged::SynapseConnectivityHostInitGroupMerged(s // Add EGP pointers to 
struct for both host and device EGPs if they are separate
 const auto egps = getArchetype().getConnectivityInitialiser().getSnippet()->getExtraGlobalParams(); for(const auto &e : egps) {
-        const auto *pointerToPointerToEGP = e.type->getPointerType()->getPointerType();
+        const auto &pointerToPointerToEGP = e.type.resolve(getTypeContext()).createPointer().createPointer();
 addField(pointerToPointerToEGP, e.name, [e](const SynapseGroupInternal &g, size_t) { return "&" + e.name + g.getName(); }, GroupMergedFieldType::HOST_DYNAMIC);
@@ -806,7 +806,7 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac
 // Generate code to allocate this EGP with count specified by $(0) // **NOTE** we generate these with a pointer type as the fields are pointer to pointer std::stringstream allocStream;
-    const auto *pointerToEGP = egp.type->getPointerType();
+    const auto &pointerToEGP = egp.type.resolve(getTypeContext()).createPointer();
 CodeGenerator::CodeStream alloc(allocStream); backend.genVariableDynamicAllocation(alloc, pointerToEGP, egp.name,
@@ -864,8 +864,8 @@ CustomUpdateInitGroupMerged::CustomUpdateInitGroupMerged(size_t index, const Typ
 const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups) {
-    addField("size",
-             [](const auto &c, size_t) { return std::to_string(c.getSize()); });
+    addField(Type::Uint32, "size",
+             [](const auto &c, size_t) { return std::to_string(c.getSize()); });
 } //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type CustomUpdateInitGroupMerged::getHashDigest() const
@@ -906,18 +906,18 @@ CustomWUUpdateInitGroupMerged::CustomWUUpdateInitGroupMerged(size_t index, const
 for (size_t d = 0; d < getArchetype().getSynapseGroup()->getKernelSize().size(); d++) { // If this dimension has a heterogeneous size, add it to struct if (isKernelSizeHeterogeneous(d)) {
-            addField("kernelSize" + std::to_string(d),
-                     [d](const auto &g, size_t) { return std::to_string(g.getSynapseGroup()->getKernelSize().at(d)); });
+            addField(Uint32, "kernelSize" + std::to_string(d),
+                     [d](const auto &g, size_t) { return std::to_string(g.getSynapseGroup()->getKernelSize().at(d)); });
 } } } else {
-        addField("rowStride",
-                 [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); });
-        addField("numSrcNeurons",
-                 [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); });
-        addField("numTrgNeurons",
-                 [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); });
+        addField(Uint32, "rowStride",
+                 [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); });
+        addField(Uint32, "numSrcNeurons",
+                 [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); });
+        addField(Uint32, "numTrgNeurons",
+                 [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); });
 } } //----------------------------------------------------------------------------
@@ -1011,21 +1011,21 @@ CustomWUUpdateSparseInitGroupMerged::CustomWUUpdateSparseInitGroupMerged(size_t
 { using namespace Type;
-    addField("rowStride",
-             [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); });
+    addField(Uint32, "rowStride",
+
[&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); - addField("numSrcNeurons", - [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); - addField("numTrgNeurons", - [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); + addField(Uint32, "numSrcNeurons", + [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); + addField(Uint32, "numTrgNeurons", + [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); - addField(Uint32::getInstance()->getPointerType(), "rowLength", + addField(Uint32.createPointer(), "rowLength", [&backend](const auto &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); }); - addField(getArchetype().getSynapseGroup()->getSparseIndType()->getPointerType(), "ind", + addField(getArchetype().getSynapseGroup()->getSparseIndType().createPointer(), "ind", [&backend](const auto &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); @@ -1081,13 +1081,11 @@ CustomConnectivityUpdatePreInitGroupMerged::CustomConnectivityUpdatePreInitGroup const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups) { - using namespace Type; - - addField("size", - [](const auto &c, size_t) - { - return std::to_string(c.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); - }); + addField(Type::Uint32, "size", + [](const auto &c, size_t) + { + return std::to_string(c.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); + }); // If this backend initialises population RNGs on device and this group requires one for simulation if(backend.isPopulationRNGRequired() && getArchetype().isRowSimRNGRequired() && backend.isPopulationRNGInitialisedOnDevice()) { @@ -1130,11 +1128,11 @@ CustomConnectivityUpdatePostInitGroupMerged::CustomConnectivityUpdatePostInitGro const std::vector> &groups) : CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups) { - addField("size", - [](const auto &c, size_t) - { - return std::to_string(c.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); - }); + addField(Type::Uint32, "size", + [](const auto &c, size_t) + { + return std::to_string(c.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); + }); } //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePostInitGroupMerged::getHashDigest() const @@ -1174,21 +1172,21 @@ CustomConnectivityUpdateSparseInitGroupMerged::CustomConnectivityUpdateSparseIni { using namespace Type; - addField("rowStride", - [&backend](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); + addField(Uint32, "rowStride", + [&backend](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); - addField("numSrcNeurons", - [](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); - addField("numTrgNeurons", - [](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); 
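// [Editor's note: illustrative sketch, not part of the patch.] The recurring change in
// these hunks is that merged-group fields now name their type explicitly as a
// value-semantic Type::ResolvedType, instead of relying on an implicit default or on
// singleton pointer types. Assuming only the addField overloads and ResolvedType API
// already visible in the surrounding hunks, the three common cases look like:
//
//     addField(Uint32, "numSrcNeurons",                  // plain uint32_t field
//              [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); });
//     addField(Uint32.createPointer(), "rowLength",      // uint32_t* field
//              [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); });
//     addField(var.type.resolve(getTypeContext()).createPointer(),  // unresolved model types are
//              var.name, ...);                           // resolved against the TypeContext first
//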
+ addField(Uint32, "numSrcNeurons", + [](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); + addField(Uint32, "numTrgNeurons", + [](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); - addField(Uint32::getInstance()->getPointerType(), "rowLength", + addField(Uint32.createPointer(), "rowLength", [&backend](const CustomConnectivityUpdateInternal &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); }); - addField(getArchetype().getSynapseGroup()->getSparseIndType()->getPointerType(), "ind", + addField(getArchetype().getSynapseGroup()->getSparseIndType().createPointer(), "ind", [&backend](const auto &cg, size_t) { const SynapseGroupInternal *sg = cg.getSynapseGroup(); diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 40ead7cb6b..a75dd7c185 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -66,7 +66,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC // If EGP is referenced in event threshold code if(s.eventThresholdCode.find("$(" + egp.name + ")") != std::string::npos) { const std::string prefix = backend.getDeviceVarPrefix(); - addField(egp.type->getPointerType(), egp.name + "EventThresh" + std::to_string(i), + addField(egp.type.resolve(getTypeContext()).createPointer(), egp.name + "EventThresh" + std::to_string(i), [eventThresholdSGs, prefix, egp, i](const auto &, size_t groupIndex) { return prefix + egp.name + eventThresholdSGs.at(groupIndex).at(i)->getName(); @@ -80,7 +80,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC for(const auto &var : sgPreVars) { // If variable is referenced in event threshold code if(s.eventThresholdCode.find("$(" + var.name + ")") != std::string::npos) { - addField(var.type->getPointerType(), var.name + "EventThresh" + std::to_string(i), + addField(var.type.resolve(getTypeContext()).createPointer(), var.name + "EventThresh" + std::to_string(i), [&backend, eventThresholdSGs, var, i](const auto&, size_t groupIndex) { return backend.getDeviceVarPrefix() + var.name + eventThresholdSGs.at(groupIndex).at(i)->getName(); @@ -93,7 +93,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC if(getArchetype().isSpikeRecordingEnabled()) { // Add field for spike recording - addField(Uint32::getInstance()->getPointerType(), "recordSpk", + addField(Uint32.createPointer(), "recordSpk", [&backend](const auto &ng, size_t) { return backend.getDeviceVarPrefix() + "recordSpk" + ng.getName(); @@ -103,7 +103,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC if(getArchetype().isSpikeEventRecordingEnabled()) { // Add field for spike event recording - addField(Uint32::getInstance()->getPointerType(), "recordSpkEvent", + addField(Uint32.createPointer(), "recordSpkEvent", [&backend](const auto &ng, size_t) { return backend.getDeviceVarPrefix() + "recordSpkEvent" + ng.getName(); @@ -187,7 +187,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C if(v.access & VarAccessMode::READ_ONLY) { os << "const "; } - os << v.type->getName() << " l" << v.name << " = group->" << v.name << "["; + os << 
v.type.resolve(getTypeContext()).getName() << " l" << v.name << " = group->" << v.name << "["; const bool delayed = (getArchetype().isVarQueueRequired(v.name) && getArchetype().isDelayRequired()); os << getReadVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), popSubs["id"]) << "];" << std::endl; } @@ -249,7 +249,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C neuronSubs.applyCheckUnreplaced(value, "neuron additional input var : merged" + std::to_string(getIndex())); //value = ensureFtype(value, modelMerged.getModel().getPrecision()); - os << a.type->getName() << " " << a.name << " = " << value << ";" << std::endl; + os << a.type.resolve(getTypeContext()).getName() << " " << a.name << " = " << value << ";" << std::endl; } // Loop through incoming synapse groups @@ -282,7 +282,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C if(v.access & VarAccessMode::READ_ONLY) { os << "const "; } - os << v.type->getName() << " lps" << v.name << " = group->" << v.name << "InSyn" << i << "["; + os << v.type.resolve(getTypeContext()).getName() << " lps" << v.name << " = group->" << v.name << "InSyn" << i << "["; os << getVarIndex(batchSize, getVarAccessDuplication(v.access), neuronSubs["id"]) << "];" << std::endl; } @@ -366,7 +366,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C if(v.access & VarAccessMode::READ_ONLY) { os << "const "; } - os << v.type->getName() << " lcs" << v.name << " = " << "group->" << v.name << "CS" << i << "["; + os << v.type.resolve(getTypeContext()).getName() << " lcs" << v.name << " = " << "group->" << v.name << "CS" << i << "["; os << getVarIndex(batchSize, getVarAccessDuplication(v.access), popSubs["id"]) << "];" << std::endl; } @@ -716,7 +716,7 @@ void NeuronUpdateGroupMerged::generateWUVar(const BackendBase &backend, const s for(size_t v = 0; v < vars.size(); v++) { // Add pointers to state variable const auto var = vars[v]; - addField(var.type->getPointerType(), var.name + fieldPrefixStem + std::to_string(i), + addField(var.type.resolve(getTypeContext()).createPointer(), var.name + fieldPrefixStem + std::to_string(i), [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto &, size_t groupIndex) { const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); @@ -794,7 +794,7 @@ void NeuronUpdateGroupMerged::generateWUVarUpdate(CodeStream &os, const Substitu if(v.access & VarAccessMode::READ_ONLY) { os << "const "; } - os << v.type->getName() << " l" << v.name << " = group->" << v.name << fieldPrefixStem << i << "["; + os << v.type.resolve(getTypeContext()).getName() << " l" << v.name << " = group->" << v.name << fieldPrefixStem << i << "["; os << getReadVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), subs["id"]) << "];" << std::endl; } diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 004e7a3459..39e4aa3eae 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -90,7 +90,7 @@ void applySynapseSubstitutions(CodeStream &os, std::string code, const std::stri varSubs.applyCheckUnreplaced(code, "initVar : merged" + var.name + std::to_string(sg.getIndex())); // Declare local variable - os << var.type->getName() << " " << "l" << var.name << ";" << std::endl; + os << var.type.resolve(sg.getTypeContext()).getName() << " " << "l" << var.name << ";" 
<< std::endl; // Insert code to initialize variable into scope { @@ -251,7 +251,7 @@ void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendB std::string value = a.value; popSubs.applyCheckUnreplaced(value, "proceduralSparseConnectivity row build state var : merged" + std::to_string(getIndex())); //value = ensureFtype(value, modelMerged.getModel().getPrecision()); - os << a.type->getName() << " " << a.name << " = " << value << ";" << std::endl; + os << a.type.resolve(getTypeContext()).getName() << " " << a.name << " = " << value << ";" << std::endl; } // Loop through synapses in row @@ -325,7 +325,7 @@ SynapseDendriticDelayUpdateGroupMerged::SynapseDendriticDelayUpdateGroupMerged(s const std::vector> &groups) : GroupMerged(index, typeContext, groups) { - addField(Type::Uint32::getInstance()->getPointerType(), "denDelayPtr", + addField(Type::Uint32.createPointer(), "denDelayPtr", [&backend](const SynapseGroupInternal &sg, size_t) { return backend.getScalarAddressPrefix() + "denDelayPtr" + sg.getFusedPSVarSuffix(); diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 937655bca6..04cadacf46 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -183,7 +183,7 @@ Expression::ExpressionPtr parseBinary(ParserState &parserState, N nonTerminal, s return expression; } -const GeNN::Type::Base *parseDeclarationSpecifiers(ParserState &parserState) +const GeNN::Type::ResolvedType *parseDeclarationSpecifiers(ParserState &parserState) { using namespace GeNN::Type; @@ -215,7 +215,7 @@ const GeNN::Type::Base *parseDeclarationSpecifiers(ParserState &parserState) } while(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER, Token::Type::STAR})); // Lookup numeric type - const Base *type = getNumericType(typeSpecifiers); + const Type::ResolvedType *type = getNumericType(typeSpecifiers); // If there are any type qualifiers, add const // **THINK** this relies of const being only qualifier diff --git a/src/genn/genn/transpiler/standardLibrary.cc b/src/genn/genn/transpiler/standardLibrary.cc index c0d9ed9d5d..d84365a5bf 100644 --- a/src/genn/genn/transpiler/standardLibrary.cc +++ b/src/genn/genn/transpiler/standardLibrary.cc @@ -3,7 +3,6 @@ // Standard C++ library #include #include -#include // GeNN includes #include "type.h" @@ -21,16 +20,16 @@ namespace Type = GeNN::Type; // Macros //--------------------------------------------------------------------------- #define ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(NAME) \ - std::make_pair(#NAME, std::make_pair(std::make_unique>(), #NAME"($(0))")), \ - std::make_pair(#NAME, std::make_pair(std::make_unique>(), #NAME"($(0))")) + std::make_pair(#NAME, std::make_pair(Type::ResolvedType::createFunction(Type::Float, {Type::Float}), #NAME"($(0))")), \ + std::make_pair(#NAME, std::make_pair(Type::ResolvedType::createFunction(Type::Double, {Type::Double}), #NAME"($(0))")) #define ADD_TWO_ARG_FLOAT_DOUBLE_FUNC(NAME) \ - std::make_pair(#NAME, std::make_pair(std::make_unique>(), #NAME"($(0), $(1))")), \ - std::make_pair(#NAME, std::make_pair(std::make_unique>(), #NAME"($(0), $(1))")) + std::make_pair(#NAME, std::make_pair(Type::ResolvedType::createFunction(Type::Float, {Type::Float, Type::Float}), #NAME"($(0), $(1))")), \ + std::make_pair(#NAME, std::make_pair(Type::ResolvedType::createFunction(Type::Double, {Type::Double, Type::Double}), #NAME"($(0), $(1))")) #define ADD_THREE_ARG_FLOAT_DOUBLE_FUNC(NAME) \ - std::make_pair(#NAME, std::make_pair(std::make_unique>(), 
#NAME"($(0), $(1), $(2))")), \ - std::make_pair(#NAME, std::make_pair(std::make_unique>(), #NAME"($(0), $(1), $(2))")) + std::make_pair(#NAME, std::make_pair(Type::ResolvedType::createFunction(Type::Float, {Type::Float, Type::Float, Type::Float}), #NAME"($(0), $(1), $(2))")), \ + std::make_pair(#NAME, std::make_pair(Type::ResolvedType::createFunction(Type::Double, {Type::Double, Type::Double, Type::Double}), #NAME"($(0), $(1), $(2))")) //--------------------------------------------------------------------------- // Anonymous namespace @@ -40,7 +39,7 @@ namespace template auto initLibraryTypes(Args&&... args) { - std::unordered_multimap, std::string>> map; + std::unordered_multimap> map; (map.emplace(std::forward(args)), ...); return map; } @@ -72,18 +71,18 @@ const auto libraryTypes = initLibraryTypes( ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(expm1), ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(exp2), ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(pow), - std::make_pair("scalbn", std::make_pair(std::make_unique>(), "scalbn($(0), $(1))")), - std::make_pair("scalbn", std::make_pair(std::make_unique>(), "scalbn($(0), $(1))")), + std::make_pair("scalbn", std::make_pair(Type::ResolvedType::createFunction(Type::Float, {Type::Float, Type::Int32}), "scalbn($(0), $(1))")), + std::make_pair("scalbn", std::make_pair(Type::ResolvedType::createFunction(Type::Double, {Type::Double, Type::Int32}), "scalbn($(0), $(1))")), // Logarithm functions ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(log), ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(log1p), ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(log2), ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(log10), - std::make_pair("ldexp", std::make_pair(std::make_unique>(), "ldexp($(0), $(1))")), - std::make_pair("ldexp", std::make_pair(std::make_unique>(), "ldexp($(0), $(1))")), - std::make_pair("ilogb", std::make_pair(std::make_unique>(), "ilogb($(0))")), - std::make_pair("ilogb", std::make_pair(std::make_unique>(), "ilogb($(0))")), + std::make_pair("ldexp", std::make_pair(Type::ResolvedType::createFunction(Type::Float, {Type::Float, Type::Int32}), "ldexp($(0), $(1))")), + std::make_pair("ldexp", std::make_pair(Type::ResolvedType::createFunction(Type::Double, {Type::Double, Type::Int32}), "ldexp($(0), $(1))")), + std::make_pair("ilogb", std::make_pair(Type::ResolvedType::createFunction(Type::Int32, {Type::Float}), "ilogb($(0))")), + std::make_pair("ilogb", std::make_pair(Type::ResolvedType::createFunction(Type::Int32, {Type::Double}), "ilogb($(0))")), // Root functions ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(sqrt), @@ -134,27 +133,13 @@ FunctionTypes::FunctionTypes() { } //------------------------------------------------------------------------ -void FunctionTypes::define(const Token &name, const Type::Base*, ErrorHandlerBase &errorHandler) +void FunctionTypes::define(const Token &name, const Type::ResolvedType&, ErrorHandlerBase &errorHandler) { errorHandler.error(name, "Cannot declare variable in external environment"); throw TypeCheckError(); } //--------------------------------------------------------------------------- -const Type::Base *FunctionTypes::assign(const Token &name, Token::Type, const Type::Base*, - const Type::TypeContext&, ErrorHandlerBase &errorHandler, bool) -{ - errorHandler.error(name, "Cannot assign variable in external environment"); - throw TypeCheckError(); -} -//--------------------------------------------------------------------------- -const Type::Base *FunctionTypes::incDec(const Token &name, Token::Type, const Type::TypeContext&, - ErrorHandlerBase &errorHandler) -{ - errorHandler.error(name, "Cannot increment/decrement variable in external 
environment"); - throw TypeCheckError(); -} -//--------------------------------------------------------------------------- -std::vector FunctionTypes::getTypes(const Token &name, ErrorHandlerBase &errorHandler) +std::vector FunctionTypes::getTypes(const Token &name, ErrorHandlerBase &errorHandler) { const auto [typeBegin, typeEnd] = libraryTypes.equal_range(name.lexeme); if (typeBegin == typeEnd) { @@ -162,10 +147,10 @@ std::vector FunctionTypes::getTypes(const Token &name, ErrorH throw TypeCheckError(); } else { - std::vector types; + std::vector types; types.reserve(std::distance(typeBegin, typeEnd)); std::transform(typeBegin, typeEnd, std::back_inserter(types), - [](const auto &t) { return t.second.first.get(); }); + [](const auto &t) { return t.second.first; }); return types; } } @@ -173,7 +158,7 @@ std::vector FunctionTypes::getTypes(const Token &name, ErrorH //--------------------------------------------------------------------------- // GeNN::Transpiler::StandardLibrary::FunctionEnvironment //--------------------------------------------------------------------------- -std::string FunctionEnvironment::getName(const std::string &name, const Type::Base *type) +std::string FunctionEnvironment::getName(const std::string &name, const Type::ResolvedType &type) { const auto [libTypeBegin, libTypeEnd] = libraryTypes.equal_range(name); if (libTypeBegin == libTypeEnd) { @@ -181,7 +166,7 @@ std::string FunctionEnvironment::getName(const std::string &name, const Type::Ba } else { const auto libType = std::find_if(libTypeBegin, libTypeEnd, - [type](const auto &t){ return t.second.first.get() == type; }); + [type](const auto &t){ return t.second.first == type; }); assert(libType != libTypeEnd); return libType->second.second; } From 447a201f0f760fd201252c479d981a55e2b94825 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 25 May 2023 15:01:02 +0100 Subject: [PATCH 161/725] removed ``Token::Type::SCALAR`` and resolve this at scanning time --- include/genn/genn/transpiler/scanner.h | 9 +++-- include/genn/genn/transpiler/token.h | 2 +- .../code_generator/customUpdateGroupMerged.cc | 5 ++- src/genn/genn/transpiler/scanner.cc | 35 +++++++++++++------ 4 files changed, 31 insertions(+), 20 deletions(-) diff --git a/include/genn/genn/transpiler/scanner.h b/include/genn/genn/transpiler/scanner.h index b3af6b8b75..4c2ba5f375 100644 --- a/include/genn/genn/transpiler/scanner.h +++ b/include/genn/genn/transpiler/scanner.h @@ -8,14 +8,13 @@ #include #include +// GeNN includes +#include "type.h" + // Transpiler includes #include "transpiler/token.h" // Forward declarations -namespace GeNN::Type -{ -class NumericBase; -} namespace GeNN::Transpiler { class ErrorHandlerBase; @@ -26,6 +25,6 @@ class ErrorHandlerBase; //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Scanner { -std::vector scanSource(const std::string_view &source, ErrorHandlerBase &errorHandler); +std::vector scanSource(const std::string_view &source, const Type::TypeContext &context, ErrorHandlerBase &errorHandler); } // namespace Scanner diff --git a/include/genn/genn/transpiler/token.h b/include/genn/genn/transpiler/token.h index 9cc5b53881..62212b74ce 100644 --- a/include/genn/genn/transpiler/token.h +++ b/include/genn/genn/transpiler/token.h @@ -37,7 +37,7 @@ struct Token SHIFT_LEFT_EQUAL, SHIFT_RIGHT_EQUAL, // Literals - IDENTIFIER, UINT32_NUMBER, INT32_NUMBER, FLOAT_NUMBER, DOUBLE_NUMBER, SCALAR_NUMBER, STRING, + IDENTIFIER, UINT32_NUMBER, INT32_NUMBER, FLOAT_NUMBER, 
DOUBLE_NUMBER, STRING, // Types TYPE_SPECIFIER, diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index dc3b0b3917..26daae54d3 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -15,7 +15,6 @@ #include "transpiler/scanner.h" #include "transpiler/standardLibrary.h" #include "transpiler/typeChecker.h" -#include "transpiler/transpilerUtils.h" using namespace GeNN; @@ -75,7 +74,7 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const Type::TypeC // Scan, parse and type-check update code ErrorHandler errorHandler; const std::string code = upgradeCodeString(cm->getUpdateCode()); - const auto tokens = Scanner::scanSource(code, errorHandler); + const auto tokens = Scanner::scanSource(code, typeContext, errorHandler); m_UpdateStatements = Parser::parseBlockItemList(tokens, errorHandler); m_ResolvedTypes = TypeChecker::typeCheck(m_UpdateStatements, typeEnvironment, typeContext, errorHandler); } @@ -364,7 +363,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // Scan, parse and type-check update code ErrorHandler errorHandler; const std::string code = upgradeCodeString(cm->getUpdateCode()); - const auto tokens = Scanner::scanSource(code, errorHandler); + const auto tokens = Scanner::scanSource(code, typeContext, errorHandler); m_UpdateStatements = Parser::parseBlockItemList(tokens, errorHandler); m_ResolvedTypes = TypeChecker::typeCheck(m_UpdateStatements, typeEnvironment, typeContext, errorHandler); } diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc index 476138b596..f8e85fc796 100644 --- a/src/genn/genn/transpiler/scanner.cc +++ b/src/genn/genn/transpiler/scanner.cc @@ -12,7 +12,6 @@ // Transpiler includes #include "transpiler/errorHandler.h" -#include "transpiler/transpilerUtils.h" using namespace GeNN; using namespace GeNN::Transpiler; @@ -65,9 +64,20 @@ const std::map, Token::Type> integerLiteralTokenTypes{ class ScanState { public: - ScanState(std::string_view source, const std::unordered_set &typedefNames, ErrorHandlerBase &errorHandler) - : m_Start(0), m_Current(0), m_Line(1), m_Source(source), m_TypedefNames(typedefNames), m_ErrorHandler(errorHandler) - {} + ScanState(std::string_view source, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) + : m_Start(0), m_Current(0), m_Line(1), m_Source(source), m_Context(context), m_ErrorHandler(errorHandler) + { + const auto &scalarType = context.at("scalar"); + if (scalarType == Type::Float) { + m_ScalarTokenType = Token::Type::FLOAT_NUMBER; + } + else if (scalarType == Type::Double) { + m_ScalarTokenType = Token::Type::DOUBLE_NUMBER; + } + else { + throw std::runtime_error("Unsupported scalar type '" + scalarType.getName() + "'"); + } + } //--------------------------------------------------------------------------- // Public API @@ -130,8 +140,10 @@ class ScanState } bool isTypedefIdentifier(std::string_view lexeme) { - return (m_TypedefNames.find(std::string{lexeme}) != m_TypedefNames.cend()); + return (m_Context.find(std::string{lexeme}) != m_Context.cend()); } + + Token::Type getScalarTokenType() const{ return m_ScalarTokenType; } private: //--------------------------------------------------------------------------- @@ -141,9 +153,10 @@ class ScanState size_t m_Current; size_t m_Line; - const std::string_view m_Source; - const std::unordered_set m_TypedefNames; + std::string_view 
m_Source; + const Type::TypeContext &m_Context; ErrorHandlerBase &m_ErrorHandler; + Token::Type m_ScalarTokenType; }; bool isodigit(char c) @@ -234,9 +247,9 @@ void scanNumber(char c, ScanState &scanState, std::vector &tokens) scanState.advance(); emplaceToken(tokens, Token::Type::DOUBLE_NUMBER, scanState); } - // Otherwise, emplace SCALAR_NUMBER token + // Otherwise, emplace literal with whatever type is specified else { - emplaceToken(tokens, Token::Type::SCALAR_NUMBER, scanState); + emplaceToken(tokens, scanState.getScalarTokenType(), scanState); } } // Otherwise, emplace integer token @@ -450,11 +463,11 @@ void scanToken(ScanState &scanState, std::vector &tokens) //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Scanner { -std::vector scanSource(const std::string_view &source, ErrorHandlerBase &errorHandler) +std::vector scanSource(const std::string_view &source, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) { std::vector tokens; - ScanState scanState(source, {"scalar"}, errorHandler); + ScanState scanState(source, context, errorHandler); // Scan tokens while(!scanState.isAtEnd()) { From 4fee2ba6ac8b64e2452b23af9e5c42f4babc0d51 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 25 May 2023 15:52:07 +0100 Subject: [PATCH 162/725] parser now resolves types named in context --- include/genn/genn/transpiler/parser.h | 9 ++++--- include/genn/genn/type.h | 2 +- src/genn/genn/transpiler/parser.cc | 35 +++++++++++++++------------ src/genn/genn/type.cc | 26 +++++++++++--------- 4 files changed, 41 insertions(+), 31 deletions(-) diff --git a/include/genn/genn/transpiler/parser.h b/include/genn/genn/transpiler/parser.h index 7f9302ab93..9df753a5d6 100644 --- a/include/genn/genn/transpiler/parser.h +++ b/include/genn/genn/transpiler/parser.h @@ -5,6 +5,9 @@ #include #include +// GeNN includes +#include "type.h" + // Transpiler includes #include "transpiler/expression.h" #include "transpiler/statement.h" @@ -22,13 +25,13 @@ class ErrorHandlerBase; namespace GeNN::Transpiler::Parser { //! Parse expression from tokens -Expression::ExpressionPtr parseExpression(const std::vector &tokens, ErrorHandlerBase &errorHandler); +Expression::ExpressionPtr parseExpression(const std::vector &tokens, const Type::TypeContext &context, ErrorHandlerBase &errorHandler); //! Parse block item list from tokens /*! Block item lists are function body scope list of statements */ -Statement::StatementList parseBlockItemList(const std::vector &tokens, ErrorHandlerBase &errorHandler); +Statement::StatementList parseBlockItemList(const std::vector &tokens, const Type::TypeContext &context, ErrorHandlerBase &errorHandler); //! Parse type from tokens -const GeNN::Type::ResolvedType parseNumericType(const std::vector &tokens, ErrorHandlerBase &errorHandler); +const GeNN::Type::ResolvedType parseNumericType(const std::vector &tokens, const Type::TypeContext &context, ErrorHandlerBase &errorHandler); } // MiniParse::MiniParse diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 514b502abf..221a279be4 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -336,7 +336,7 @@ inline static const ResolvedType Double = CREATE_NUMERIC(double, 60, ""); GENN_EXPORT ResolvedType parseNumeric(const std::string &typeString); //! 
Look up numeric type based on set of type specifiers -GENN_EXPORT ResolvedType getNumericType(const std::set &typeSpecifiers); +GENN_EXPORT ResolvedType getNumericType(const std::set &typeSpecifiers, const TypeContext &context); //! Apply C type promotion rules to numeric type GENN_EXPORT ResolvedType getPromotedType(const ResolvedType &type); diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 04cadacf46..8cdae68804 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -17,6 +17,7 @@ // Transpiler includes #include "transpiler/errorHandler.h" +using namespace GeNN; using namespace GeNN::Transpiler; //--------------------------------------------------------------------------- @@ -38,8 +39,8 @@ class ParseError class ParserState { public: - ParserState(const std::vector &tokens, ErrorHandlerBase &errorHandler) - : m_Current(0), m_Tokens(tokens), m_ErrorHandler(errorHandler) + ParserState(const std::vector &tokens, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) + : m_Current(0), m_Tokens(tokens), m_Context(context), m_ErrorHandler(errorHandler) {} //--------------------------------------------------------------------------- @@ -128,6 +129,7 @@ class ParserState bool isAtEnd() const { return (peek().type == Token::Type::END_OF_FILE); } + const Type::TypeContext &getContext() const{ return m_Context; } private: //--------------------------------------------------------------------------- @@ -136,6 +138,7 @@ class ParserState size_t m_Current; const std::vector &m_Tokens; + const Type::TypeContext &m_Context; ErrorHandlerBase &m_ErrorHandler; }; @@ -183,7 +186,7 @@ Expression::ExpressionPtr parseBinary(ParserState &parserState, N nonTerminal, s return expression; } -const GeNN::Type::ResolvedType *parseDeclarationSpecifiers(ParserState &parserState) +GeNN::Type::ResolvedType parseDeclarationSpecifiers(ParserState &parserState) { using namespace GeNN::Type; @@ -215,18 +218,18 @@ const GeNN::Type::ResolvedType *parseDeclarationSpecifiers(ParserState &parserSt } while(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER, Token::Type::STAR})); // Lookup numeric type - const Type::ResolvedType *type = getNumericType(typeSpecifiers); + Type::ResolvedType type = getNumericType(typeSpecifiers, parserState.getContext()); // If there are any type qualifiers, add const // **THINK** this relies of const being only qualifier if(!typeQualifiers.empty()) { - type = type->getQualifiedType(Qualifier::CONSTANT); + type = type.addQualifier(Qualifier::CONSTANT); } // Loop through levels of pointer indirection // **THINK** this relies of const being only qualifier for(const auto &p : pointerTypeQualifiers) { - type = type->getPointerType(p.empty() ? Qualifier{0} : Qualifier::CONSTANT); + type = type.createPointer(p.empty() ? 
Qualifier{0} : Qualifier::CONSTANT); } return type; } @@ -238,7 +241,7 @@ Expression::ExpressionPtr parsePrimary(ParserState &parserState) // constant // "(" expression ")" if (parserState.match({Token::Type::FALSE, Token::Type::TRUE, Token::Type::STRING, - Token::Type::DOUBLE_NUMBER, Token::Type::FLOAT_NUMBER, Token::Type::SCALAR_NUMBER, + Token::Type::DOUBLE_NUMBER, Token::Type::FLOAT_NUMBER, Token::Type::INT32_NUMBER, Token::Type::UINT32_NUMBER})) { return std::make_unique(parserState.previous()); } @@ -379,7 +382,7 @@ Expression::ExpressionPtr parseCast(ParserState &parserState) // If this is followed by some part of a type declarator if(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER})) { // Parse declaration specifiers - const auto *type = parseDeclarationSpecifiers(parserState); + const auto type = parseDeclarationSpecifiers(parserState); const auto closingParen = parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after cast type."); @@ -798,7 +801,7 @@ Statement::StatementPtr parseDeclaration(ParserState &parserState) // "const" // Parse declaration specifiers - const auto *type = parseDeclarationSpecifiers(parserState); + const auto type = parseDeclarationSpecifiers(parserState); // Read init declarator list std::vector> initDeclaratorList; @@ -851,9 +854,9 @@ std::unique_ptr parseBlockItem(ParserState &parserState) //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Parser { -Expression::ExpressionPtr parseExpression(const std::vector &tokens, ErrorHandlerBase &errorHandler) +Expression::ExpressionPtr parseExpression(const std::vector &tokens, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) { - ParserState parserState(tokens, errorHandler); + ParserState parserState(tokens, context, errorHandler); try { return parseExpression(parserState); @@ -863,9 +866,9 @@ Expression::ExpressionPtr parseExpression(const std::vector &tokens, Erro } } //--------------------------------------------------------------------------- -Statement::StatementList parseBlockItemList(const std::vector &tokens, ErrorHandlerBase &errorHandler) +Statement::StatementList parseBlockItemList(const std::vector &tokens, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) { - ParserState parserState(tokens, errorHandler); + ParserState parserState(tokens, context, errorHandler); std::vector> statements; while(!parserState.isAtEnd()) { @@ -874,9 +877,9 @@ Statement::StatementList parseBlockItemList(const std::vector &tokens, Er return statements; } //--------------------------------------------------------------------------- -const GeNN::Type::NumericBase *parseNumericType(const std::vector &tokens, ErrorHandlerBase &errorHandler) +const GeNN::Type::ResolvedType parseNumericType(const std::vector &tokens, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) { - ParserState parserState(tokens, errorHandler); + ParserState parserState(tokens, context, errorHandler); std::set typeSpecifiers; while(parserState.match(Token::Type::TYPE_SPECIFIER)) { if(!typeSpecifiers.insert(parserState.previous().lexeme).second) { @@ -885,6 +888,6 @@ const GeNN::Type::NumericBase *parseNumericType(const std::vector &tokens }; // Return numeric type - return GeNN::Type::getNumericType(typeSpecifiers); + return GeNN::Type::getNumericType(typeSpecifiers, context); } } diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index 59e974cb91..b1004cb452 100644 --- a/src/genn/genn/type.cc +++ 
b/src/genn/genn/type.cc @@ -48,8 +48,6 @@ const std::map, Type::ResolvedType> numericTypeSpecifiers{ {{"float"}, Type::Float}, {{"double"}, Type::Double}}; //---------------------------------------------------------------------------- -const std::set scalarTypeSpecifier{{"scalar"}}; -//---------------------------------------------------------------------------- // Mapping of signed integer numericTypeSpecifiers to their unsigned equivalents const std::map unsignedType{ {Type::Int8, Type::Uint8}, @@ -144,18 +142,24 @@ ResolvedType parseNumeric(const std::string &typeString) return type; } //---------------------------------------------------------------------------- -ResolvedType getNumericType(const std::set &typeSpecifiers) +ResolvedType getNumericType(const std::set &typeSpecifiers, const TypeContext &context) { - // If type matches scalar type specifiers - if(typeSpecifiers == scalarTypeSpecifier) { - assert(false); - //return new NumericTypedef("scalar"); + // If type is numeric, return + const auto type = numericTypeSpecifiers.find(typeSpecifiers); + if (type != numericTypeSpecifiers.cend()) { + return type->second; } - // Otherwise else { - const auto type = numericTypeSpecifiers.find(typeSpecifiers); - //return (type == numericTypeSpecifiers.cend()) ? nullptr : type->second; - return type->second; + // **YUCK** use sets everywhere + if (typeSpecifiers.size() == 1) { + const auto contextType = context.find(*typeSpecifiers.begin()); + if (contextType != context.cend()) { + return contextType->second; + } + } + + // **TODO** improve error + throw std::runtime_error("Unknown numeric type specifier"); } } //---------------------------------------------------------------------------- From 9dac81d0986ce67692ff67b002513e21ba9034fc Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 25 May 2023 17:08:07 +0100 Subject: [PATCH 163/725] project compiles --- .../backends/single_threaded_cpu/backend.h | 2 +- .../genn/genn/code_generator/backendBase.h | 17 ++-- .../genn/genn/code_generator/environment.h | 11 +-- .../genn/genn/code_generator/groupMerged.h | 4 +- .../groupMergedTypeEnvironment.h | 31 ++++--- .../genn/code_generator/modelSpecMerged.h | 5 +- include/genn/genn/models.h | 2 +- .../genn/genn/transpiler/standardLibrary.h | 2 +- include/genn/genn/transpiler/typeChecker.h | 4 +- include/genn/genn/type.h | 2 +- .../backends/single_threaded_cpu/backend.cc | 26 +++--- .../customConnectivityUpdateGroupMerged.cc | 77 ++++++++-------- .../code_generator/customUpdateGroupMerged.cc | 87 +++++++++---------- src/genn/genn/code_generator/groupMerged.cc | 2 +- .../genn/code_generator/initGroupMerged.cc | 2 +- .../genn/code_generator/modelSpecMerged.cc | 9 +- .../presynapticUpdateStrategySIMT.cc | 36 ++++---- src/genn/genn/customConnectivityUpdate.cc | 4 +- src/genn/genn/customUpdate.cc | 4 +- src/genn/genn/synapseGroup.cc | 6 +- src/genn/genn/transpiler/prettyPrinter.cc | 66 +++----------- src/genn/genn/transpiler/standardLibrary.cc | 5 +- src/genn/genn/transpiler/typeChecker.cc | 28 +++--- src/genn/genn/type.cc | 9 +- 24 files changed, 206 insertions(+), 235 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 6f960c250a..155a4a8db9 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -119,7 +119,7 @@ class BACKEND_EXPORT Backend : public BackendBase virtual std::string getMergedGroupFieldHostTypeName(const Type::ResolvedType &type) const 
final; //! When generating merged structures what type to use for simulation RNGs - virtual const Type::ResolvedType &getMergedGroupSimRNGType() const final; + virtual std::optional getMergedGroupSimRNGType() const final; virtual void genPopVariableInit(CodeStream &os,const Substitutions &kernelSubs, Handler handler) const final; virtual void genVariableInit(CodeStream &os, const std::string &count, const std::string &indexVarName, diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index fe3a239749..ddc1e4e250 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -309,7 +310,7 @@ class GENN_EXPORT BackendBase virtual std::string getMergedGroupFieldHostTypeName(const Type::ResolvedType &type) const = 0; //! When generating merged structures what type to use for simulation RNGs - virtual const Type::ResolvedType &getMergedGroupSimRNGType() const = 0; + virtual std::optional getMergedGroupSimRNGType() const = 0; virtual void genPopVariableInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const = 0; virtual void genVariableInit(CodeStream &os, const std::string &count, const std::string &indexVarName, @@ -522,9 +523,10 @@ class GENN_EXPORT BackendBase for (const auto &v : cm->getVars()) { // If variable is a reduction target, define variable initialised to correct initial value for reduction if (v.access & VarAccessModeAttribute::REDUCE) { - os << v.type->getName() << " lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(v.access), v.type) << ";" << std::endl; - reductionTargets.emplace_back(v.name, v.type, getVarAccessMode(v.access), - cg.getVarIndex(getVarAccessDuplication(v.access), idx)); + const auto resolvedType = v.type.resolve(cg.getTypeContext()); + os << resolvedType.getName() << " lr" << v.name << " = " << getReductionInitialValue(getVarAccessMode(v.access), resolvedType) << ";" << std::endl; + reductionTargets.push_back({v.name, resolvedType, getVarAccessMode(v.access), + cg.getVarIndex(getVarAccessDuplication(v.access), idx)}); } } @@ -534,9 +536,10 @@ class GENN_EXPORT BackendBase // If variable reference is a reduction target, define variable initialised to correct initial value for reduction if (modelVarRef.access & VarAccessModeAttribute::REDUCE) { - os << modelVarRef.type->getName() << " lr" << modelVarRef.name << " = " << getReductionInitialValue(modelVarRef.access, modelVarRef.type) << ";" << std::endl; - reductionTargets.emplace_back(modelVarRef.name, modelVarRef.type, modelVarRef.access, - getVarRefIndexFn(varRef, idx)); + const auto resolvedType = modelVarRef.type.resolve(cg.getTypeContext()); + os << resolvedType.getName() << " lr" << modelVarRef.name << " = " << getReductionInitialValue(modelVarRef.access, resolvedType) << ";" << std::endl; + reductionTargets.push_back({modelVarRef.name, resolvedType, modelVarRef.access, + getVarRefIndexFn(varRef, idx)}); } } return reductionTargets; diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 621e55cd12..b60d7ef220 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -176,8 +176,8 @@ class EnvironmentLocalVarCache : public EnvironmentExternal using GetIndexFn = std::function; public: - EnvironmentLocalVarCache(const G &group, EnvironmentExternal &enclosing, GetIndexFn 
getIndex, const std::string &localPrefix = "l") - : EnvironmentExternal(static_cast(enclosing)), m_Group(group), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) + EnvironmentLocalVarCache(const G &group, const Type::TypeContext &context, EnvironmentExternal &enclosing, GetIndexFn getIndex, const std::string &localPrefix = "l") + : EnvironmentExternal(static_cast(enclosing)), m_Group(group), m_Context(context), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) { // Add name of each definition to map, initially with value set to value const auto defs = A(m_Group).getDefs(); @@ -203,7 +203,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal if(v.access & VarAccessMode::READ_ONLY) { getContextStream() << "const "; } - getContextStream() << v.type->getName() << " " << m_LocalPrefix << v.name; + getContextStream() << v.type.resolve(m_Context).getName() << " " << m_LocalPrefix << v.name; // If this isn't a reduction, read value from memory // **NOTE** by not initialising these variables for reductions, @@ -257,10 +257,11 @@ class EnvironmentLocalVarCache : public EnvironmentExternal // Members //------------------------------------------------------------------------ const G &m_Group; + const Type::TypeContext &m_Context; std::ostringstream m_ContentsStream; CodeStream m_Contents; - const std::string m_LocalPrefix; - const GetIndexFn m_GetIndex; + std::string m_LocalPrefix; + GetIndexFn m_GetIndex; std::unordered_map m_VariablesReferenced; }; } // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 4a16ac6892..897459f747 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -70,8 +70,8 @@ class GroupMerged : m_Index(index), m_TypeContext(typeContext), m_Groups(std::move(groups)) {} - //GroupMerged(const GroupMerged&) = delete; - //GroupMerged(GroupMerged&&) = default; + GroupMerged(const GroupMerged&) = delete; + GroupMerged(GroupMerged&&) = default; //------------------------------------------------------------------------ // Public API diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index adf62a1f4f..93079accbf 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -83,9 +83,14 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa void definePointerField(const Type::ResolvedType &type, const std::string &name,const std::string &prefix, VarAccessMode access) { - const auto *qualifiedType = type.addQualifier((access & VarAccessModeAttribute::READ_ONLY) ? Type::Qualifier::CONSTANT : Type::Qualifier{0}); + const auto qualifiedType = (access & VarAccessModeAttribute::READ_ONLY) ? 
type.addQualifier(Type::Qualifier::CONSTANT) : type; defineField(qualifiedType, name, - type->getPointerType(), name, [prefix](const auto &g, size_t) { return prefix + g.getName(); }); + type.createPointer(), name, [prefix](const auto &g, size_t) { return prefix + g.getName(); }); + } + + void definePointerField(const Type::UnresolvedType &type, const std::string &name, const std::string &prefix, VarAccessMode access) + { + definePointerField(type.resolve(m_GroupMerged.getTypeContext()), name, prefix, access); } void defineScalarField(const std::string &name, typename G::GetFieldDoubleValueFunc getFieldValue) @@ -94,8 +99,8 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa m_GroupMerged.getScalarType(), name, [getFieldValue, this](const auto &g, size_t i) { - return (Utils::writePreciseString(getFieldValue(g, i), m_GroupMerged.getScalarType()->getMaxDigits10(m_GroupMerged.getTypeContext())) - + m_GroupMerged.getScalarType()->getLiteralSuffix(m_GroupMerged.getTypeContext())); + return (Utils::writePreciseString(getFieldValue(g, i), m_GroupMerged.getScalarType().getNumeric().maxDigits10) + + m_GroupMerged.getScalarType().getNumeric().literalSuffix); }); } @@ -114,7 +119,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa } // Otherwise, just add a const-qualified scalar to the type environment else { - defineField(m_GroupMerged.getScalarType()->getQualifiedType(Type::Qualifier::CONSTANT), p + suffix); + defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), p + suffix); } } } @@ -133,7 +138,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa }); } else { - defineField(m_GroupMerged.getScalarType()->getQualifiedType(Type::Qualifier::CONSTANT), d.name + suffix); + defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), d.name + suffix); } } } @@ -152,9 +157,10 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa // Loop through variables for(const auto &v : varReferences) { // If variable access is read-only, qualify type with const - const auto *qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? v.type->getQualifiedType(Type::Qualifier::CONSTANT) : v.type; + const auto resolvedType = v.type.resolve(m_GroupMerged.getTypeContext()); + const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? 
resolvedType.addQualifier(Type::Qualifier::CONSTANT) : resolvedType; defineField(qualifiedType, v.name, - v.type->getPointerType(), v.name, + resolvedType.createPointer(), v.name, [arrayPrefix, getVarRefFn, v](const auto &g, size_t) { const auto varRef = getVarRefFn(g).at(v.name); @@ -162,12 +168,13 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa }); } } - + void defineEGPs(const Snippet::Base::EGPVec &egps, const std::string &arrayPrefix, const std::string &varName = "") { for(const auto &e : egps) { - defineField(e.type->getPointerType(), e.name, - e.type->getPointerType(), e.name + varName, + const auto pointerType = e.type.resolve(m_GroupMerged.getTypeContext()).createPointer(); + defineField(pointerType, e.name, + pointerType, e.name + varName, [arrayPrefix, e, varName](const auto &g, size_t) { return arrayPrefix + e.name + varName + g.getName(); @@ -180,7 +187,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa //--------------------------------------------------------------------------- // Private methods //--------------------------------------------------------------------------- - void addField(std::pair> &type) + void addField(std::pair> &type) { // If this type has an associated field if (type.second) { diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 87c80cb1de..846ce07bd9 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -323,12 +323,11 @@ class GENN_EXPORT ModelSpecMerged const auto &g = mergedGroups.back().getGroups()[groupIndex]; // Add reference to this group's variable to data structure - const auto *pointerType = dynamic_cast(std::get<0>(f)); - assert(pointerType); + assert(std::get<0>(f).isPointer()); m_MergedEGPs[std::get<2>(f)(g, groupIndex)].emplace( std::piecewise_construct, std::forward_as_tuple(MergedGroup::name), - std::forward_as_tuple(i, groupIndex, pointerType, std::get<1>(f), host)); + std::forward_as_tuple(i, groupIndex, std::get<0>(f), std::get<1>(f), host)); } } } diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 643b6dddb0..04d71ed16d 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -313,7 +313,7 @@ void checkVarReferences(const std::unordered_map &varRefs, const // Check types of variable references against those specified in model // **THINK** this is rather conservative but I think not allowing scalar and whatever happens to be scalar type is ok - if(varRef.getVar().type->getName() != modelVarRef.type->getName()) { + if(varRef.getVar().type != modelVarRef.type) { throw std::runtime_error("Incompatible type for variable reference '" + modelVarRef.name + "'"); } diff --git a/include/genn/genn/transpiler/standardLibrary.h b/include/genn/genn/transpiler/standardLibrary.h index d6e8598e17..7a4ceaa228 100644 --- a/include/genn/genn/transpiler/standardLibrary.h +++ b/include/genn/genn/transpiler/standardLibrary.h @@ -41,7 +41,7 @@ class FunctionEnvironment : public CodeGenerator::EnvironmentExternal //------------------------------------------------------------------------ // PrettyPrinter::EnvironmentBase virtuals //------------------------------------------------------------------------ - virtual std::string getName(const std::string &name, const Type::ResolvedType &type) final; + virtual std::string getName(const std::string &name, std::optional type = std::nullopt) final; virtual 
CodeGenerator::CodeStream &getStream() final; }; } // namespace GeNN::Transpiler::StandardLibrary diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index 3e23220fe7..d0cf01b92a 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -56,8 +56,8 @@ class EnvironmentBase // Free functions //--------------------------------------------------------------------------- ResolvedTypeMap typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler); + ErrorHandlerBase &errorHandler); Type::ResolvedType typeCheck(const Expression::Base *expression, EnvironmentBase &environment, - const Type::TypeContext &context, ErrorHandlerBase &errorHandler); + ErrorHandlerBase &errorHandler); } // namespace MiniParse::GeNN::Transpiler diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 221a279be4..29a64712ae 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -333,7 +333,7 @@ inline static const ResolvedType Float = CREATE_NUMERIC(float, 50, "f"); inline static const ResolvedType Double = CREATE_NUMERIC(double, 60, ""); //! Parse a numeric type -GENN_EXPORT ResolvedType parseNumeric(const std::string &typeString); +GENN_EXPORT ResolvedType parseNumeric(const std::string &typeString, const TypeContext &context); //! Look up numeric type based on set of type specifiers GENN_EXPORT ResolvedType getNumericType(const std::set &typeSpecifiers, const TypeContext &context); diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 447fd6a0f7..55d2f12761 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -153,7 +153,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged { CodeStream::Scope b(os); - Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName())); + Substitutions funcSubs(getFunctionTemplates(model.getPrecision().getName())); funcSubs.addVarSubstitution("t", "t"); funcSubs.addVarSubstitution("batch", "0"); @@ -320,7 +320,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge os << "void updateSynapses(timepoint t)"; { CodeStream::Scope b(os); - Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName())); + Substitutions funcSubs(getFunctionTemplates(model.getPrecision().getName())); funcSubs.addVarSubstitution("t", "t"); funcSubs.addVarSubstitution("batch", "0"); @@ -815,7 +815,7 @@ void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHa os << "void initialize()"; { CodeStream::Scope b(os); - Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName())); + Substitutions funcSubs(getFunctionTemplates(model.getPrecision().getName())); Timer t(os, "init", model.isTimingEnabled()); @@ -1070,7 +1070,7 @@ void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHa os << "void initializeSparse()"; { CodeStream::Scope b(os); - Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName())); + Substitutions funcSubs(getFunctionTemplates(model.getPrecision().getName())); Timer t(os, "initSparse", model.isTimingEnabled()); @@ -1224,9 +1224,9 @@ void Backend::genDefinitionsPreamble(CodeStream &os, const ModelSpecMerged &mode // If a global RNG is required, define standard host distributions as 
diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc
index 447fd6a0f7..55d2f12761 100644
--- a/src/genn/backends/single_threaded_cpu/backend.cc
+++ b/src/genn/backends/single_threaded_cpu/backend.cc
@@ -153,7 +153,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
     {
         CodeStream::Scope b(os);
 
-        Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName()));
+        Substitutions funcSubs(getFunctionTemplates(model.getPrecision().getName()));
         funcSubs.addVarSubstitution("t", "t");
         funcSubs.addVarSubstitution("batch", "0");
 
@@ -320,7 +320,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge
     os << "void updateSynapses(timepoint t)";
     {
         CodeStream::Scope b(os);
-        Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName()));
+        Substitutions funcSubs(getFunctionTemplates(model.getPrecision().getName()));
         funcSubs.addVarSubstitution("t", "t");
         funcSubs.addVarSubstitution("batch", "0");
 
@@ -815,7 +815,7 @@ void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHa
     os << "void initialize()";
     {
         CodeStream::Scope b(os);
-        Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName()));
+        Substitutions funcSubs(getFunctionTemplates(model.getPrecision().getName()));
 
         Timer t(os, "init", model.isTimingEnabled());
 
@@ -1070,7 +1070,7 @@ void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHa
     os << "void initializeSparse()";
     {
         CodeStream::Scope b(os);
-        Substitutions funcSubs(getFunctionTemplates(model.getPrecision()->getName()));
+        Substitutions funcSubs(getFunctionTemplates(model.getPrecision().getName()));
 
         Timer t(os, "initSparse", model.isTimingEnabled());
 
@@ -1224,9 +1224,9 @@ void Backend::genDefinitionsPreamble(CodeStream &os, const ModelSpecMerged &mode
     // If a global RNG is required, define standard host distributions as recreating them each call is slow
     if(isGlobalHostRNGRequired(modelMerged)) {
-        os << "EXPORT_VAR " << "std::uniform_real_distribution<" << model.getPrecision()->getName() << "> standardUniformDistribution;" << std::endl;
-        os << "EXPORT_VAR " << "std::normal_distribution<" << model.getPrecision()->getName() << "> standardNormalDistribution;" << std::endl;
-        os << "EXPORT_VAR " << "std::exponential_distribution<" << model.getPrecision()->getName() << "> standardExponentialDistribution;" << std::endl;
+        os << "EXPORT_VAR " << "std::uniform_real_distribution<" << model.getPrecision().getName() << "> standardUniformDistribution;" << std::endl;
+        os << "EXPORT_VAR " << "std::normal_distribution<" << model.getPrecision().getName() << "> standardNormalDistribution;" << std::endl;
+        os << "EXPORT_VAR " << "std::exponential_distribution<" << model.getPrecision().getName() << "> standardExponentialDistribution;" << std::endl;
         os << std::endl;
     }
 }
@@ -1272,9 +1272,9 @@ void Backend::genRunnerPreamble(CodeStream &os, const ModelSpecMerged &modelMerg
     // If a global RNG is required, implement standard host distributions as recreating them each call is slow
     if(isGlobalHostRNGRequired(modelMerged)) {
-        os << "std::uniform_real_distribution<" << model.getPrecision()->getName() << "> standardUniformDistribution(" << modelMerged.scalarExpr(0.0) << ", " << modelMerged.scalarExpr(1.0) << ");" << std::endl;
-        os << "std::normal_distribution<" << model.getPrecision()->getName() << "> standardNormalDistribution(" << modelMerged.scalarExpr(0.0) << ", " << modelMerged.scalarExpr(1.0) << ");" << std::endl;
-        os << "std::exponential_distribution<" << model.getPrecision()->getName() << "> standardExponentialDistribution(" << modelMerged.scalarExpr(1.0) << ");" << std::endl;
+        os << "std::uniform_real_distribution<" << model.getPrecision().getName() << "> standardUniformDistribution(" << modelMerged.scalarExpr(0.0) << ", " << modelMerged.scalarExpr(1.0) << ");" << std::endl;
+        os << "std::normal_distribution<" << model.getPrecision().getName() << "> standardNormalDistribution(" << modelMerged.scalarExpr(0.0) << ", " << modelMerged.scalarExpr(1.0) << ");" << std::endl;
+        os << "std::exponential_distribution<" << model.getPrecision().getName() << "> standardExponentialDistribution(" << modelMerged.scalarExpr(1.0) << ");" << std::endl;
         os << std::endl;
     }
     os << std::endl;
@@ -1380,10 +1380,10 @@ std::string Backend::getMergedGroupFieldHostTypeName(const Type::ResolvedType &t
     return type.getName();
 }
 //--------------------------------------------------------------------------
-const Type::ResolvedType &Backend::getMergedGroupSimRNGType() const
+std::optional<Type::ResolvedType> Backend::getMergedGroupSimRNGType() const
 {
     assert(false);
-    return nullptr;
+    return std::nullopt;
 }
 //--------------------------------------------------------------------------
 void Backend::genPopVariableInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const
@@ -1664,7 +1664,7 @@ void Backend::genPresynapticUpdate(CodeStream &os, const ModelSpecMerged &modelM
                 connSubs.applyCheckUnreplaced(value, "toeplitz diagonal build state var : merged" + std::to_string(sg.getIndex()));
                 //value = ensureFtype(value, modelMerged.getModel().getPrecision());
 
-                os << d.type->getName() << " " << d.name << " = " << value << ";" << std::endl;
+                os << d.type.resolve(sg.getTypeContext()).getName() << " " << d.name << " = " << value << ";" << std::endl;
             }
 
             // Detect spike events or spikes and do the update
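The getMergedGroupSimRNGType() change above replaces an assert-and-return-nullptr sentinel with std::optional, so "this backend has no per-population RNG" becomes a representable value rather than undefined behaviour. A self-contained sketch of the pattern (the RNG type name here is illustrative, not GeNN's):

    #include <iostream>
    #include <optional>
    #include <string>

    // Illustrative stand-in for a resolved type; GeNN's Type::ResolvedType is richer
    struct ResolvedType { std::string name; };

    // A backend with no per-population RNG support returns std::nullopt
    // instead of asserting and returning a null pointer
    std::optional<ResolvedType> getMergedGroupSimRNGType(bool supportsRNG)
    {
        if(supportsRNG) {
            return ResolvedType{"curandState"};
        }
        return std::nullopt;
    }

    int main()
    {
        // Callers that know a RNG exists dereference; others can test first
        if(const auto rngType = getMergedGroupSimRNGType(true)) {
            std::cout << "RNG field type: " << rngType->name << "\n";
        }
    }

This is why the call sites in the group-merged sources below now read *backend.getMergedGroupSimRNGType().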
diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc
index 04257d158f..8df17d8b7c 100644
--- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc
@@ -18,19 +18,19 @@ CustomConnectivityUpdateGroupMergedBase::CustomConnectivityUpdateGroupMergedBase
 {
     using namespace Type;
 
-    addField("numSrcNeurons",
-             [](const auto &cg, size_t)
-             {
-                 const SynapseGroupInternal *sgInternal = static_cast<const SynapseGroupInternal*>(cg.getSynapseGroup());
-                 return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons());
-             });
+    addField(Uint32, "numSrcNeurons",
+             [](const auto &cg, size_t)
+             {
+                 const SynapseGroupInternal *sgInternal = static_cast<const SynapseGroupInternal*>(cg.getSynapseGroup());
+                 return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons());
+             });
 
-    addField("numTrgNeurons",
-             [](const auto &cg, size_t)
-             {
-                 const SynapseGroupInternal *sgInternal = static_cast<const SynapseGroupInternal*>(cg.getSynapseGroup());
-                 return std::to_string(sgInternal->getTrgNeuronGroup()->getNumNeurons());
-             });
+    addField(Uint32, "numTrgNeurons",
+             [](const auto &cg, size_t)
+             {
+                 const SynapseGroupInternal *sgInternal = static_cast<const SynapseGroupInternal*>(cg.getSynapseGroup());
+                 return std::to_string(sgInternal->getTrgNeuronGroup()->getNumNeurons());
+             });
 
     // Add heterogeneous custom update model parameters
     addHeterogeneousParams(
@@ -80,11 +80,11 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t
     dependentVarsList.sort([](const auto &a, const auto &b)
                            {
                                boost::uuids::detail::sha1 hashA;
-                               Utils::updateHash(a.getVar().type->getName(), hashA);
+                               Type::updateHash(a.getVar().type, hashA);
                                Utils::updateHash(getVarAccessDuplication(a.getVar().access), hashA);
 
                                boost::uuids::detail::sha1 hashB;
-                               Utils::updateHash(b.getVar().type->getName(), hashB);
+                               Type::updateHash(b.getVar().type, hashB);
                                Utils::updateHash(getVarAccessDuplication(b.getVar().access), hashB);
 
                                return (hashA.get_digest() < hashB.get_digest());
@@ -102,22 +102,22 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t
                            }));
 
-    addField("rowStride",
-             [&backend](const auto &cg, size_t)
-             {
-                 const SynapseGroupInternal *sgInternal = static_cast<const SynapseGroupInternal*>(cg.getSynapseGroup());
-                 return std::to_string(backend.getSynapticMatrixRowStride(*sgInternal));
-             });
+    addField(Uint32, "rowStride",
+             [&backend](const auto &cg, size_t)
+             {
+                 const SynapseGroupInternal *sgInternal = static_cast<const SynapseGroupInternal*>(cg.getSynapseGroup());
+                 return std::to_string(backend.getSynapticMatrixRowStride(*sgInternal));
+             });
 
     assert(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE);
-    addField(getArchetype().getSynapseGroup()->getSparseIndType()->getPointerType(), "ind",
+    addField(getArchetype().getSynapseGroup()->getSparseIndType().createPointer(), "ind",
              [&backend](const auto &cg, size_t)
             {
                 return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName();
             });
 
-    addField(Uint32::getInstance()->getPointerType(), "rowLength",
+    addField(Uint32.createPointer(), "rowLength",
             [&backend](const auto &cg, size_t)
             {
                 return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName();
@@ -125,7 +125,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t
 
     // If some presynaptic variables are delayed, add delay pointer
     if (getArchetype().getPreDelayNeuronGroup() != nullptr) {
-        addField(Uint32::getInstance()->getPointerType(), "preSpkQuePtr",
+        addField(Uint32.createPointer(), "preSpkQuePtr",
                  [&backend](const auto &cg, size_t)
                  {
                      return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getPreDelayNeuronGroup()->getName();
@@ -134,7 +134,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t
 
     // If some postsynaptic variables are delayed, add delay pointer
     if (getArchetype().getPostDelayNeuronGroup() != nullptr) {
-        addField(Uint32::getInstance()->getPointerType(), "postSpkQuePtr",
+        addField(Uint32.createPointer(), "postSpkQuePtr",
                  [&backend](const auto &cg, size_t)
                  {
                      return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getPostDelayNeuronGroup()->getName();
@@ -143,7 +143,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t
 
     // If this backend requires per-population RNGs and this group requires one
     if(backend.isPopulationRNGRequired() && getArchetype().isRowSimRNGRequired()){
-        addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rowRNG");
+        addPointerField(*backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rowRNG");
     }
 
     // Add variables to struct
@@ -166,7 +166,8 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t
 
     // Loop through sorted dependent variables
     for(size_t i = 0; i < getSortedArchetypeDependentVars().size(); i++) {
-        addField(getSortedArchetypeDependentVars().at(i).getVar().type->getPointerType(), "_dependentVar" + std::to_string(i),
+        auto resolvedType = getSortedArchetypeDependentVars().at(i).getVar().type.resolve(getTypeContext());
+        addField(resolvedType.createPointer(), "_dependentVar" + std::to_string(i),
                  [i, &backend, this](const auto&, size_t g)
                  {
                      const auto &varRef = m_SortedDependentVars[g][i];
@@ -266,7 +267,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back
             if ((batchSize > 1) && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) {
                 // Copy parameter into a register (just incase it's e.g. a RNG call) and copy into all batches
-                addSynapse << "const " << ccuVarRefs[i].type->getName() << " _" << ccuVarRefs[i].name << "Val = $(" << (1 + ccuVars.size() + i) << ");" << std::endl;
+                addSynapse << "const " << ccuVarRefs[i].type.resolve(getTypeContext()).getName() << " _" << ccuVarRefs[i].name << "Val = $(" << (1 + ccuVars.size() + i) << ");" << std::endl;
                 addSynapse << "for(int b = 0; b < " << batchSize << "; b++)";
                 {
                     CodeStream::Scope b(addSynapse);
@@ -440,12 +441,13 @@ CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged
 
     // Add host extra global parameters
     for(const auto &e : cm->getExtraGlobalParams()) {
-        addField(e.type->getPointerType(), e.name,
+        const auto resolvedType = e.type.resolve(getTypeContext());
+        addField(resolvedType.createPointer(), e.name,
                  [e](const auto &g, size_t) { return e.name + g.getName(); },
                  GroupMergedFieldType::HOST_DYNAMIC);
 
         if(!backend.getDeviceVarPrefix().empty()) {
-            addField(e.type->getPointerType(), backend.getDeviceVarPrefix() + e.name,
+            addField(resolvedType.createPointer(), backend.getDeviceVarPrefix() + e.name,
                      [e, &backend](const auto &g, size_t)
                      {
                          return backend.getDeviceVarPrefix() + e.name + g.getName();
@@ -485,10 +487,12 @@ void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase &
 
     // Loop through EGPs
     for(const auto &egp : cm->getExtraGlobalParams()) {
+        const auto resolvedType = egp.type.resolve(getTypeContext());
+
         // Generate code to push this EGP with count specified by $(0)
         std::stringstream pushStream;
         CodeStream push(pushStream);
-        backend.genVariableDynamicPush(push, egp.type, egp.name,
+        backend.genVariableDynamicPush(push, resolvedType, egp.name,
                                        VarLocation::HOST_DEVICE, "$(0)", "group->");
 
         // Add substitution
@@ -497,7 +501,7 @@ void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase &
         // Generate code to pull this EGP with count specified by $(0)
         std::stringstream pullStream;
         CodeStream pull(pullStream);
-        backend.genVariableDynamicPull(pull, egp.type, egp.name,
+        backend.genVariableDynamicPull(pull, resolvedType, egp.name,
                                        VarLocation::HOST_DEVICE, "$(0)", "group->");
 
         // Add substitution
@@ -523,13 +527,15 @@ void CustomConnectivityHostUpdateGroupMerged::addVarPushPullFuncSubs(const Backe
 {
     // Loop through variables
     for(const auto &v : vars) {
+        const auto resolvedType = v.type.resolve(getTypeContext());
+
         // If var is located on the host
         const auto loc = std::invoke(getVarLocationFn, getArchetype(), v.name);
         if (loc & VarLocation::HOST) {
             // Generate code to push this variable
             std::stringstream pushStream;
             CodeStream push(pushStream);
-            backend.genVariableDynamicPush(push, v.type, v.name,
+            backend.genVariableDynamicPush(push, resolvedType, v.name,
                                            loc, count, "group->");
 
             // Add substitution
@@ -539,7 +545,7 @@ void CustomConnectivityHostUpdateGroupMerged::addVarPushPullFuncSubs(const Backe
             // **YUCK** these EGP functions should probably just be called dynamic or something
             std::stringstream pullStream;
             CodeStream pull(pullStream);
-            backend.genVariableDynamicPull(pull, v.type, v.name,
+            backend.genVariableDynamicPull(pull, resolvedType, v.name,
                                            loc, count, "group->");
 
             // Add substitution
@@ -556,14 +562,15 @@ void CustomConnectivityHostUpdateGroupMerged::addVars(const BackendBase &backend
     // Loop through variables
     for(const auto &v : vars) {
         // If var is located on the host
+        const auto resolvedType = v.type.resolve(getTypeContext());
         if (std::invoke(getVarLocationFn, getArchetype(), v.name) & VarLocation::HOST) {
-            addField(v.type->getPointerType(), v.name,
+            addField(resolvedType.createPointer(), v.name,
                      [v](const auto &g, size_t) { return v.name + g.getName(); },
                      GroupMergedFieldType::HOST);
 
             if(!backend.getDeviceVarPrefix().empty()) {
                 // **TODO** I think could use addPointerField
-                addField(v.type->getPointerType(), backend.getDeviceVarPrefix() + v.name,
+                addField(resolvedType.createPointer(), backend.getDeviceVarPrefix() + v.name,
                          [v, &backend](const auto &g, size_t)
                          {
                              return backend.getDeviceVarPrefix() + v.name + g.getName();
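A recurring shape in the hunks above is resolve-once-then-decorate: the UnresolvedType stored in each variable or EGP is resolved against the group's TypeContext a single time, and the resulting value type is copied into pointer or qualified variants as needed. A sketch assuming only the Type API introduced by this series (construction of UnresolvedType from a string is inferred from models.h and may differ in detail):

    // "scalar" is resolved through the context; the results are plain values
    const GeNN::Type::TypeContext context{{"scalar", GeNN::Type::Float}};
    const GeNN::Type::UnresolvedType unresolved("scalar");
    const auto resolved = unresolved.resolve(context);
    const auto devicePointer = resolved.createPointer();                             // e.g. float*
    const auto constant = resolved.addQualifier(GeNN::Type::Qualifier::CONSTANT);    // e.g. const float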
"id_syn"); @@ -237,7 +237,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back // Create an environment which caches variable references in local variables if they are accessed EnvironmentLocalVarCache varRefSubs( - getArchetype(), varSubs, + getArchetype(), getTypeContext(), varSubs, [this](const Models::WUVarReference &v, VarAccessMode) { return getVarRefIndex(getVarAccessDuplication(v.getVar().access), @@ -276,46 +276,46 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const for (size_t d = 0; d < getArchetype().getSynapseGroup()->getKernelSize().size(); d++) { // If this dimension has a heterogeneous size, add it to struct if (isKernelSizeHeterogeneous(d)) { - addField("kernelSize" + std::to_string(d), - [d](const auto &cu, size_t) - { - return std::to_string(cu.getSynapseGroup()->getKernelSize().at(d)); - }); + addField(Uint32, "kernelSize" + std::to_string(d), + [d](const auto &cu, size_t) + { + return std::to_string(cu.getSynapseGroup()->getKernelSize().at(d)); + }); } } } // Otherwise else { - addField("rowStride", - [&backend](const auto &cg, size_t) - { - const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); - return std::to_string(backend.getSynapticMatrixRowStride(*sgInternal)); - }); + addField(Uint32, "rowStride", + [&backend](const auto &cg, size_t) + { + const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); + return std::to_string(backend.getSynapticMatrixRowStride(*sgInternal)); + }); - addField("numSrcNeurons", - [](const auto &cg, size_t) - { - const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); - return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons()); - }); + addField(Uint32, "numSrcNeurons", + [](const auto &cg, size_t) + { + const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); + return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons()); + }); - addField("numTrgNeurons", - [](const auto &cg, size_t) - { - const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); - return std::to_string(sgInternal->getTrgNeuronGroup()->getNumNeurons()); - }); + addField(Uint32, "numTrgNeurons", + [](const auto &cg, size_t) + { + const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); + return std::to_string(sgInternal->getTrgNeuronGroup()->getNumNeurons()); + }); // If synapse group has sparse connectivity if(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - addField(getArchetype().getSynapseGroup()->getSparseIndType()->getPointerType(), "ind", + addField(getArchetype().getSynapseGroup()->getSparseIndType().createPointer(), "ind", [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); }); - addField(Uint32::getInstance()->getPointerType(), "rowLength", + addField(Uint32.createPointer(), "rowLength", [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); @@ -349,7 +349,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // If variable has a transpose if(getArchetype().getVarReferences().at(v.name).getTransposeSynapseGroup() != nullptr) { // Add field with transpose suffix, pointing to transpose var - addField(v.type->getPointerType(), v.name + "Transpose", + addField(v.type.resolve(getTypeContext()).createPointer(), v.name + "Transpose", [&backend, v](const auto &g, 
size_t) { const auto varRef = g.getVarReferences().at(v.name); @@ -364,8 +364,8 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const ErrorHandler errorHandler; const std::string code = upgradeCodeString(cm->getUpdateCode()); const auto tokens = Scanner::scanSource(code, typeContext, errorHandler); - m_UpdateStatements = Parser::parseBlockItemList(tokens, errorHandler); - m_ResolvedTypes = TypeChecker::typeCheck(m_UpdateStatements, typeEnvironment, typeContext, errorHandler); + m_UpdateStatements = Parser::parseBlockItemList(tokens, typeContext, errorHandler); + m_ResolvedTypes = TypeChecker::typeCheck(m_UpdateStatements, typeEnvironment, errorHandler); } // ---------------------------------------------------------------------------- @@ -389,13 +389,12 @@ CustomUpdateHostReductionGroupMerged::CustomUpdateHostReductionGroupMerged(size_ { using namespace Type; - addField("size", - [](const auto &c, size_t) { return std::to_string(c.getSize()); }); + addField(Uint32, "size", [](const auto &c, size_t) { return std::to_string(c.getSize()); }); // If some variables are delayed, add delay pointer // **NOTE** this is HOST delay pointer if(getArchetype().getDelayNeuronGroup() != nullptr) { - addField(Uint32::getInstance()->getPointerType(), "spkQuePtr", + addField(Uint32.createPointer(), "spkQuePtr", [](const auto &cg, size_t) { return "spkQuePtr" + cg.getDelayNeuronGroup()->getName(); @@ -414,9 +413,9 @@ CustomWUUpdateHostReductionGroupMerged::CustomWUUpdateHostReductionGroupMerged(s { using namespace Type; - addField("size", - [&backend](const auto &cg, size_t) - { - return std::to_string(cg.getSynapseGroup()->getMaxConnections() * (size_t)cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); - }); + addField(Uint32, "size", + [&backend](const auto &cg, size_t) + { + return std::to_string(cg.getSynapseGroup()->getMaxConnections() * (size_t)cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); + }); } diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index 98b8833e69..a669d42a35 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -234,7 +234,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeConte if(backend.isPopulationRNGRequired() && getArchetype().isSimRNGRequired() && (!init || backend.isPopulationRNGInitialisedOnDevice())) { - addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); + addPointerField(*backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); } // Loop through variables diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 26c72c5d7f..73b6c97084 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -1089,7 +1089,7 @@ CustomConnectivityUpdatePreInitGroupMerged::CustomConnectivityUpdatePreInitGroup // If this backend initialises population RNGs on device and this group requires one for simulation if(backend.isPopulationRNGRequired() && getArchetype().isRowSimRNGRequired() && backend.isPopulationRNGInitialisedOnDevice()) { - addPointerField(backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rowRNG"); + addPointerField(*backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rowRNG"); } } //---------------------------------------------------------------------------- 
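With the Scanner/Parser/TypeChecker hunks above, the TypeContext now enters the transpiler pipeline at the scanner and parser rather than at the type checker. The resulting three-stage flow, sketched with the same calls the merged groups use (code, errorHandler and typeEnvironment stand for the inputs shown in the unit tests later in this series):

    // Sketch of the pipeline after this change (assumes GeNN's Transpiler API)
    const GeNN::Type::TypeContext typeContext{{"scalar", GeNN::Type::Float}};
    const auto tokens = Scanner::scanSource(code, typeContext, errorHandler);
    const auto statements = Parser::parseBlockItemList(tokens, typeContext, errorHandler);
    const auto resolvedTypes = TypeChecker::typeCheck(statements, typeEnvironment, errorHandler);

Because literals and type names are already resolved during scanning and parsing, the checker no longer needs the context at all.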
diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc
index 4b36e0691f..ee7688a855 100644
--- a/src/genn/genn/code_generator/modelSpecMerged.cc
+++ b/src/genn/genn/code_generator/modelSpecMerged.cc
@@ -566,8 +566,8 @@ boost::uuids::detail::sha1::digest_type ModelSpecMerged::getInitArchetypeHashDig
 //----------------------------------------------------------------------------
 std::string ModelSpecMerged::scalarExpr(double value) const
 {
-    const auto *scalarType = dynamic_cast<const Type::NumericBase*>(m_TypeContext.at("scalar"));
-    return Utils::writePreciseString(value, scalarType->getMaxDigits10(m_TypeContext)) + scalarType->getLiteralSuffix(m_TypeContext);
+    const auto scalarNumeric = m_TypeContext.at("scalar").getNumeric();
+    return Utils::writePreciseString(value, scalarNumeric.maxDigits10) + scalarNumeric.literalSuffix;
 }
 //----------------------------------------------------------------------------
 bool ModelSpecMerged::anyPointerEGPs() const
@@ -577,10 +577,7 @@ bool ModelSpecMerged::anyPointerEGPs() const
         // If there's any pointer EGPs, return true
         // **TODO** without scalar EGPS, all EGPS are pointer EGPS!
         if(std::any_of(e.second.cbegin(), e.second.cend(),
-                       [](const MergedEGPDestinations::value_type &g)
-                       {
-                           return dynamic_cast<const Type::Pointer*>(g.second.type);
-                       }))
+                       [](const MergedEGPDestinations::value_type &g){ return g.second.type.isPointer(); }))
         {
             return true;
         }
diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc
index f8d08635bf..447d5fe168 100644
--- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc
+++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc
@@ -156,12 +156,12 @@ void PreSpan::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, cons
             // If dendritic delay is required, use atomic operation to update dendritic delay buffer
             if(sg.getArchetype().isDendriticDelayRequired()) {
                 synSubs.addFuncSubstitution("addToInSynDelay", 2,
-                                            backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, "ipost", "$(1)") + "], $(0))");
+                                            backend.getAtomic(model.getPrecision()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, "ipost", "$(1)") + "], $(0))");
             }
             // Otherwise, substitute global memory array for $(inSyn)
             else {
                 synSubs.addFuncSubstitution("addToInSyn", 1,
-                                            backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, "ipost") + "], $(0))");
+                                            backend.getAtomic(model.getPrecision()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, "ipost") + "], $(0))");
             }
 
             if(sg.getArchetype().isPresynapticOutputRequired()) {
@@ -184,7 +184,7 @@ void PreSpan::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, cons
         // Should this be in the Postamble?
         if(sg.getArchetype().isPresynapticOutputRequired()) {
             // write lrevInSyn to global memory if not 0
-            os << "if(lrevInSyn != 0.0) " << backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, "preInd") + "], lrevInSyn);" << std::endl;
+            os << "if(lrevInSyn != 0.0) " << backend.getAtomic(model.getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, "preInd") + "], lrevInSyn);" << std::endl;
         }
     }
 
@@ -343,7 +343,7 @@ void PostSpan::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, con
             // If dendritic delay is required, always use atomic operation to update dendritic delay buffer
             if(sg.getArchetype().isDendriticDelayRequired()) {
                 synSubs.addFuncSubstitution("addToInSynDelay", 2,
-                                            backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, synSubs["id_post"], "$(1)") + "], $(0))");
+                                            backend.getAtomic(model.getPrecision()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, synSubs["id_post"], "$(1)") + "], $(0))");
             }
             // Otherwise
             else {
@@ -359,13 +359,13 @@ void PostSpan::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, con
                 // Otherwise, use global memory atomic
                 else {
                     synSubs.addFuncSubstitution("addToInSyn", 1,
-                                                backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, synSubs["id_post"]) + "], $(0))");
+                                                backend.getAtomic(model.getPrecision()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, synSubs["id_post"]) + "], $(0))");
                 }
             }
 
             if(sg.getArchetype().isPresynapticOutputRequired()) {
                 synSubs.addFuncSubstitution("addToPre", 1,
-                                            backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))");
+                                            backend.getAtomic(model.getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))");
             }
 
             if(trueSpike) {
@@ -403,7 +403,7 @@ void PostSpan::genPostamble(CodeStream &os, const ModelSpecMerged &modelMerged,
             CodeStream::Scope b(os);
             const std::string inSyn = "group->inSyn[" + sg.getPostISynIndex(batchSize, popSubs["id"]) + "]";
             if(sg.getArchetype().isPSModelFused()) {
-                os << backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) << "(&" << inSyn << ", linSyn);" << std::endl;
+                os << backend.getAtomic(model.getPrecision()) << "(&" << inSyn << ", linSyn);" << std::endl;
             }
             else {
                 os << inSyn << " += linSyn;" << std::endl;
@@ -416,7 +416,7 @@ void PostSpan::genPostamble(CodeStream &os, const ModelSpecMerged &modelMerged,
         os << "if(" << backend.getThreadID() << " < group->numTrgNeurons)";
         {
             CodeGenerator::CodeStream::Scope b(os);
-            os << backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) << "(&group->inSyn[" << sg.getPostISynIndex(batchSize, backend.getThreadID()) << "], ";
+            os << backend.getAtomic(model.getPrecision()) << "(&group->inSyn[" << sg.getPostISynIndex(batchSize, backend.getThreadID()) << "], ";
             os << "shLg[" << backend.getThreadID() << "]); " << std::endl;
         }
     }
@@ -577,12 +577,12 @@ void PreSpanProcedural::genUpdate(CodeStream &os, const ModelSpecMerged &modelMe
             // If dendritic delay is required, use atomic operation to update dendritic delay buffer
             if(sg.getArchetype().isDendriticDelayRequired()) {
                 presynapticUpdateSubs.addFuncSubstitution("addToInSynDelay", 2,
-                                                          backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, "$(id_post)", "$(1)") + "], $(0))");
+                                                          backend.getAtomic(model.getPrecision()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, "$(id_post)", "$(1)") + "], $(0))");
             }
             // Otherwise, substitute global memory array for $(inSyn)
             else {
                 presynapticUpdateSubs.addFuncSubstitution("addToInSyn", 1,
-                                                          backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, "$(id_post)") + "], $(0))");
+                                                          backend.getAtomic(model.getPrecision()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, "$(id_post)") + "], $(0))");
             }
 
             if(sg.getArchetype().isPresynapticOutputRequired()) {
@@ -612,7 +612,7 @@ void PreSpanProcedural::genUpdate(CodeStream &os, const ModelSpecMerged &modelMe
         // Should this be in the Postamble?
         if(sg.getArchetype().isPresynapticOutputRequired()) {
             // write lrevInSyn to global memory if not 0
-            os << "if(lrevInSyn != 0.0) " << backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, "preInd") + "], lrevInSyn);" << std::endl;
+            os << "if(lrevInSyn != 0.0) " << backend.getAtomic(model.getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, "preInd") + "], lrevInSyn);" << std::endl;
         }
     }
 
@@ -753,7 +753,7 @@ void PostSpanBitmask::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerg
 
             if(sg.getArchetype().isPresynapticOutputRequired()) {
                 synSubs.addFuncSubstitution("addToPre", 1,
-                                            backend.getAtomic(modelMerged.getModel().getPrecision(), modelMerged.getTypeContext()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))");
+                                            backend.getAtomic(modelMerged.getModel().getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))");
             }
 
             if(trueSpike) {
@@ -793,7 +793,7 @@ void PostSpanBitmask::genPostamble(CodeStream &os, const ModelSpecMerged &modelM
             CodeStream::Scope b(os);
             const std::string inSyn = "group->inSyn[" + sg.getPostISynIndex(modelMerged.getModel().getBatchSize(), "glbIdx") +"]";
             if(sg.getArchetype().isPSModelFused()) {
-                os << backend.getAtomic(modelMerged.getModel().getPrecision(), modelMerged.getTypeContext()) << "(&" << inSyn << ", shLg[shIdx]);" << std::endl;
+                os << backend.getAtomic(modelMerged.getModel().getPrecision()) << "(&" << inSyn << ", shLg[shIdx]);" << std::endl;
             }
             else {
                 os << inSyn << " += shLg[shIdx];" << std::endl;
@@ -871,7 +871,7 @@ void PostSpanToeplitz::genUpdate(CodeStream &os, const ModelSpecMerged &modelMer
                 connSubs.applyCheckUnreplaced(value, "toeplitz diagonal build state var : merged" + std::to_string(sg.getIndex()));
                 //value = ensureFtype(value, modelMerged.getModel().getPrecision());
-                os << d.type->getName() << " " << d.name << " = " << value << ";" << std::endl;
+                os << d.type.resolve(sg.getTypeContext()).getName() << " " << d.name << " = " << value << ";" << std::endl;
             }
 
             os << "const unsigned int numSpikes = group->srcSpkCnt" << eventSuffix << "[" << sg.getPreSlot(batchSize) << "];" << std::endl;
@@ -936,7 +936,7 @@ void PostSpanToeplitz::genUpdate(CodeStream &os, const ModelSpecMerged &modelMer
                 // If dendritic delay is required, always use atomic operation to update dendritic delay buffer
                 if(sg.getArchetype().isDendriticDelayRequired()) {
                     presynapticUpdateSubs.addFuncSubstitution("addToInSynDelay", 2,
-                                                              backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, "$(id_post)", "$(1)") + "], $(0))");
+                                                              backend.getAtomic(model.getPrecision()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, "$(id_post)", "$(1)") + "], $(0))");
                 }
                 // Otherwise
                 else {
@@ -947,13 +947,13 @@ void PostSpanToeplitz::genUpdate(CodeStream &os, const ModelSpecMerged &modelMer
                     // Otherwise, use global memory atomic
                     else {
                         presynapticUpdateSubs.addFuncSubstitution("addToInSyn", 1,
-                                                                  backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, "$(id_post)") + "], $(0))");
+                                                                  backend.getAtomic(model.getPrecision()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, "$(id_post)") + "], $(0))");
                     }
                 }
 
                 if(sg.getArchetype().isPresynapticOutputRequired()) {
                     presynapticUpdateSubs.addFuncSubstitution("addToPre", 1,
-                                                              backend.getAtomic(model.getPrecision(), modelMerged.getTypeContext()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, presynapticUpdateSubs["id_pre"]) + "], $(0))");
+                                                              backend.getAtomic(model.getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, presynapticUpdateSubs["id_pre"]) + "], $(0))");
                 }
 
                 // Generate presynaptic simulation code into new stringstream-backed code stream
@@ -989,7 +989,7 @@ void PostSpanToeplitz::genPostamble(CodeStream &os, const ModelSpecMerged &model
         os << "if(" << backend.getThreadID() << " < group->numTrgNeurons)";
         {
             CodeGenerator::CodeStream::Scope b(os);
-            os << backend.getAtomic(modelMerged.getModel().getPrecision(), modelMerged.getTypeContext());
+            os << backend.getAtomic(modelMerged.getModel().getPrecision());
            os << "(&group->inSyn[" << sg.getPostISynIndex(modelMerged.getModel().getBatchSize(), backend.getThreadID()) << "], ";
             os << "shLg[" << backend.getThreadID() << "]); " << std::endl;
         }
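ModelSpecMerged::scalarExpr() at the top of these hunks now reads maxDigits10 and the literal suffix straight off the resolved scalar type. A self-contained illustration of what that produces, namely literals that round-trip exactly and carry the right suffix (a minimal sketch, not GeNN's Utils::writePreciseString):

    #include <iomanip>
    #include <iostream>
    #include <limits>
    #include <sstream>
    #include <string>

    // Emit a literal with enough digits to round-trip, plus the type's
    // literal suffix (e.g. "f" for float), mirroring scalarExpr above
    std::string preciseLiteral(double value, int maxDigits10, const std::string &suffix)
    {
        std::ostringstream stream;
        stream << std::setprecision(maxDigits10) << value;
        return stream.str() + suffix;
    }

    int main()
    {
        // e.g. "0.333333343f" when the scalar type is float
        std::cout << preciseLiteral(1.0 / 3.0, std::numeric_limits<float>::max_digits10, "f") << "\n";
    }

The repeated getAtomic() change in the same file follows from the same refactor: getPrecision() now returns an already resolved value type, so the extra TypeContext argument was redundant.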
diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc
index e7f2642a29..e45292ed3d 100644
--- a/src/genn/genn/customConnectivityUpdate.cc
+++ b/src/genn/genn/customConnectivityUpdate.cc
@@ -312,7 +312,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdate::getHashDigest(
     Utils::updateHash(getUpdateGroupName(), hash);
 
     Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash);
-    Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName(), hash);
+    Type::updateHash(getSynapseGroup()->getSparseIndType(), hash);
 
     // Because it adds and removes synapses, connectivity update has to update
     // ALL variables associated with synapse group being modified as well as
@@ -327,7 +327,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdate::getHashDigest(
                    [](const Models::WUVarReference &v)
                    {
                        boost::uuids::detail::sha1 hash;
-                       Utils::updateHash(v.getVar().type->getName(), hash);
+                       Type::updateHash(v.getVar().type, hash);
                        Utils::updateHash(v.isDuplicated(), hash);
                        return hash.get_digest();
                    });
diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc
index 7d6d02b2c4..497fe2e60d 100644
--- a/src/genn/genn/customUpdate.cc
+++ b/src/genn/genn/customUpdate.cc
@@ -267,7 +267,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getHashDigest() const
     CustomUpdateBase::updateHash(hash);
 
     Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash);
-    Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName(), hash);
+    Type::updateHash(getSynapseGroup()->getSparseIndType(), hash);
 
     // Loop through variable references
     for(const auto &v : getVarReferences()) {
@@ -288,7 +288,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWU::getInitHashDigest() cons
     CustomUpdateBase::updateInitHash(hash);
 
     Utils::updateHash(getSynapseMatrixConnectivity(getSynapseGroup()->getMatrixType()), hash);
-    Utils::updateHash(getSynapseGroup()->getSparseIndType()->getName(), hash);
+    Type::updateHash(getSynapseGroup()->getSparseIndType(), hash);
     return hash.get_digest();
 }
 }   // namespace GeNN
diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc
index ceb713c65a..960c1a6ea9 100644
--- a/src/genn/genn/synapseGroup.cc
+++ b/src/genn/genn/synapseGroup.cc
@@ -728,7 +728,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getWUHashDigest() const
     Utils::updateHash(getDelaySteps(), hash);
     Utils::updateHash(getBackPropDelaySteps(), hash);
     Utils::updateHash(getMaxDendriticDelayTimesteps(), hash);
-    Utils::updateHash(getSparseIndType()->getName(), hash);
+    Type::updateHash(getSparseIndType(), hash);
     Utils::updateHash(getNumThreadsPerSpike(), hash);
     Utils::updateHash(isEventThresholdReTestRequired(), hash);
     Utils::updateHash(getSpanType(), hash);
@@ -893,7 +893,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getWUInitHashDigest() cons
 {
     boost::uuids::detail::sha1 hash;
     Utils::updateHash(getMatrixType(), hash);
-    Utils::updateHash(getSparseIndType()->getName(), hash);
+    Type::updateHash(getSparseIndType(), hash);
     Utils::updateHash(getWUModel()->getVars(), hash);
     Utils::updateHash(getWUModel()->getSynapseDynamicsCode().empty(), hash);
 
@@ -958,7 +958,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getConnectivityInitHashDig
     boost::uuids::detail::sha1 hash;
     Utils::updateHash(getConnectivityInitialiser().getHashDigest(), hash);
     Utils::updateHash(getMatrixType(), hash);
-    Utils::updateHash(getSparseIndType()->getName(), hash);
+    Type::updateHash(getSparseIndType(), hash);
     return hash.get_digest();
 }
 //----------------------------------------------------------------------------
diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc
index a4834bb17a..e78af85453 100644
--- a/src/genn/genn/transpiler/prettyPrinter.cc
+++ b/src/genn/genn/transpiler/prettyPrinter.cc
@@ -11,7 +11,6 @@
 #include "code_generator/codeStream.h"
 
 // Transpiler includes
-#include "transpiler/transpilerUtils.h"
 #include "transpiler/typeChecker.h"
 
 using namespace GeNN;
@@ -47,7 +46,7 @@ class EnvironmentInternal : public EnvironmentBase
         return "_" + name;
     }
 
-    virtual std::string getName(const std::string &name, const Type::Base *type = nullptr) final
+    virtual std::string getName(const std::string &name, std::optional<Type::ResolvedType> type) final
     {
         if(m_LocalVariables.find(name) == m_LocalVariables.end()) {
             return m_Enclosing.getName(name, type);
@@ -100,7 +99,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     virtual void visit(const Expression::Assignment &assignement) final
     {
-        m_Environment.get().getStream() << m_Environment.get().getName(assignement.getVarName().lexeme) << " " << assignement.getOperator().lexeme << " ";
+        assignement.getAssignee()->accept(*this);
+        m_Environment.get().getStream() << " " << assignement.getOperator().lexeme << " ";
         assignement.getValue()->accept(*this);
     }
 
@@ -123,9 +123,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     virtual void visit(const Expression::Cast &cast) final
     {
-        m_Environment.get().getStream() << "(";
-        printType(cast.getType());
-        m_Environment.get().getStream() << ")";
+        m_Environment.get().getStream() << "(" << cast.getType().getName() << ")";
         cast.getExpression()->accept(*this);
     }
 
@@ -147,15 +145,10 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     virtual void visit(const Expression::Literal &literal) final
     {
-        // If literal is a double, we want to remove the d suffix in generated code
+        // If literal is a float, add f suffix
         std::string_view lexeme = literal.getValue().lexeme;
-        if (literal.getValue().type == Token::Type::DOUBLE_NUMBER){
-            m_Environment.get().getStream() << lexeme.substr(0, literal.getValue().lexeme.size() - 1);
-        }
-        // Otherwise, if literal is a scalar, we want to add appropriate suffix for scalar type
-        else if (literal.getValue().type == Token::Type::SCALAR_NUMBER) {
-            const Type::NumericBase *scalar = dynamic_cast<const Type::NumericBase*>(m_Context.at("scalar"));
-            m_Environment.get().getStream() << lexeme << scalar->getLiteralSuffix(m_Context);
+        if (literal.getValue().type == Token::Type::FLOAT_NUMBER){
+            m_Environment.get().getStream() << lexeme << "f";
         }
         // Otherwise, just write out original lexeme directly (strings are already quoted)
         else {
@@ -172,17 +165,19 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     virtual void visit(const Expression::PostfixIncDec &postfixIncDec) final
     {
-        m_Environment.get().getStream() << m_Environment.get().getName(postfixIncDec.getVarName().lexeme) << postfixIncDec.getOperator().lexeme;
+        postfixIncDec.getTarget()->accept(*this);
+        m_Environment.get().getStream() << postfixIncDec.getOperator().lexeme;
     }
 
     virtual void visit(const Expression::PrefixIncDec &prefixIncDec) final
     {
-        m_Environment.get().getStream() << prefixIncDec.getOperator().lexeme << m_Environment.get().getName(prefixIncDec.getVarName().lexeme);
+        m_Environment.get().getStream() << prefixIncDec.getOperator().lexeme;
+        prefixIncDec.getTarget()->accept(*this);
     }
 
     virtual void visit(const Expression::Variable &variable) final
     {
-        const auto *type = m_ResolvedTypes.at(&variable);
+        const auto &type = m_ResolvedTypes.at(&variable);
         m_Environment.get().getStream() << m_Environment.get().getName(variable.getName().lexeme, type);
     }
 
@@ -304,7 +299,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     virtual void visit(const Statement::VarDeclaration &varDeclaration) final
     {
-        printType(varDeclaration.getType());
+        m_Environment.get().getStream() << varDeclaration.getType().getName() << " ";
 
         for(const auto &var : varDeclaration.getInitDeclaratorList()) {
             m_Environment.get().getStream() << m_Environment.get().define(std::get<0>(var).lexeme);
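The cast and declaration printing above collapses to a single getName() call because, with value-semantic types, the type itself knows how to spell its own name; the hand-rolled printType() loop removed below becomes unnecessary. A toy illustration of the idea (a hypothetical layout, not GeNN's actual ResolvedType):

    #include <iostream>
    #include <string>

    // Illustrative: a recursive getName() builds e.g. "const float*" in one place
    struct Type
    {
        std::string value;             // numeric type name, e.g. "float"
        bool isConst = false;
        const Type *pointee = nullptr; // non-null when this is a pointer type

        std::string getName() const
        {
            if(pointee) {
                return pointee->getName() + "*";
            }
            return (isConst ? "const " : "") + value;
        }
    };

    int main()
    {
        const Type constFloat{"float", true};
        const Type pointer{"", false, &constFloat};
        std::cout << pointer.getName() << "\n";   // prints "const float*"
    }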
@@ -333,41 +328,6 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     }
 
 private:
-    void printType(const GeNN::Type::Base *type)
-    {
-        // **THINK** this should be Type::getName!
-        // Loop, building reversed list of tokens
-        std::vector<std::string> tokens;
-        while(true) {
-            // If type is a pointer
-            const auto *pointerType = dynamic_cast<const GeNN::Type::Pointer*>(type);
-            if(pointerType) {
-                // If pointer has const qualifier, add const
-                if(pointerType->hasQualifier(GeNN::Type::Qualifier::CONSTANT)) {
-                    tokens.push_back("const");
-                }
-
-                // Add *
-                tokens.push_back("*");
-
-                // Go to value type
-                type = pointerType->getValueType();
-            }
-            // Otherwise
-            else {
-                // Add type specifier
-                tokens.push_back(type->getName());
-
-                if(pointerType->hasQualifier(GeNN::Type::Qualifier::CONSTANT)) {
-                    tokens.push_back("const");
-                }
-                break;
-            }
-        }
-        // Copy tokens backwards into string stream, seperating with spaces
-        std::copy(tokens.rbegin(), tokens.rend(), std::ostream_iterator<std::string>(m_Environment.get().getStream(), " "));
-    }
-
     //---------------------------------------------------------------------------
     // Members
     //---------------------------------------------------------------------------
diff --git a/src/genn/genn/transpiler/standardLibrary.cc b/src/genn/genn/transpiler/standardLibrary.cc
index d84365a5bf..90d11df388 100644
--- a/src/genn/genn/transpiler/standardLibrary.cc
+++ b/src/genn/genn/transpiler/standardLibrary.cc
@@ -158,13 +158,16 @@ std::vector FunctionTypes::getTypes(const Token &name, Error
 //---------------------------------------------------------------------------
 // GeNN::Transpiler::StandardLibrary::FunctionEnvironment
 //---------------------------------------------------------------------------
-std::string FunctionEnvironment::getName(const std::string &name, const Type::ResolvedType &type)
+std::string FunctionEnvironment::getName(const std::string &name, std::optional<Type::ResolvedType> type)
 {
     const auto [libTypeBegin, libTypeEnd] = libraryTypes.equal_range(name);
     if (libTypeBegin == libTypeEnd) {
         return getContextName(name, type);
     }
     else {
+        if (!type) {
+            throw std::runtime_error("Ambiguous reference to '" + name + "' but no type provided to disambiguate");
+        }
         const auto libType = std::find_if(libTypeBegin, libTypeEnd,
                                           [type](const auto &t){ return t.second.first == type; });
         assert(libType != libTypeEnd);
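FunctionEnvironment::getName() above now takes an optional type so that, when a name has several registered library overloads, the caller's resolved type selects one and a missing type is reported rather than silently asserted. A self-contained sketch of the selection logic (names and container layout are illustrative):

    #include <algorithm>
    #include <cassert>
    #include <map>
    #include <optional>
    #include <stdexcept>
    #include <string>

    using TypeName = std::string;

    // Map from overloaded name to (parameter-type key, mangled implementation name)
    std::string mangledName(const std::multimap<std::string, std::pair<TypeName, std::string>> &library,
                            const std::string &name, std::optional<TypeName> type)
    {
        const auto [begin, end] = library.equal_range(name);
        if(begin == end) {
            return name;    // not a library function; fall through to the enclosing environment
        }
        if(!type) {
            throw std::runtime_error("Ambiguous reference to '" + name + "'");
        }
        const auto match = std::find_if(begin, end,
                                        [&type](const auto &t){ return t.second.first == *type; });
        assert(match != end);
        return match->second.second;
    }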
diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index 186a3008b0..daaea84aab 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -176,26 +176,26 @@ class EnvironmentInternal : public EnvironmentBase
 class Visitor : public Expression::Visitor, public Statement::Visitor
 {
 public:
-    Visitor(const Statement::StatementList &statements, const Type::TypeContext &context,
-            EnvironmentInternal &environment, ResolvedTypeMap &resolvedTypes, ErrorHandlerBase &errorHandler)
-    :   Visitor(context, environment, resolvedTypes, errorHandler)
+    Visitor(const Statement::StatementList &statements, EnvironmentInternal &environment,
+            ResolvedTypeMap &resolvedTypes, ErrorHandlerBase &errorHandler)
+    :   Visitor(environment, resolvedTypes, errorHandler)
     {
         for (auto &s : statements) {
             s.get()->accept(*this);
         }
     }
 
-    Visitor(const Expression::Base *expression, const Type::TypeContext &context,
-            EnvironmentInternal &environment, ResolvedTypeMap &resolvedTypes, ErrorHandlerBase &errorHandler)
-    :   Visitor(context, environment, resolvedTypes, errorHandler)
+    Visitor(const Expression::Base *expression, EnvironmentInternal &environment,
+            ResolvedTypeMap &resolvedTypes, ErrorHandlerBase &errorHandler)
+    :   Visitor(environment, resolvedTypes, errorHandler)
     {
         expression->accept(*this);
     }
 
private:
-    Visitor(const Type::TypeContext &context, EnvironmentInternal &environment,
+    Visitor(EnvironmentInternal &environment,
             ResolvedTypeMap &resolvedTypes, ErrorHandlerBase &errorHandler)
-    :   m_Environment(environment), m_Context(context), m_ErrorHandler(errorHandler),
+    :   m_Environment(environment), m_ErrorHandler(errorHandler),
         m_ResolvedTypes(resolvedTypes), m_InLoop(false), m_InSwitch(false)
     {
     }
@@ -454,9 +454,6 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         else if (literal.getValue().type == Token::Type::FLOAT_NUMBER) {
             setExpressionType(&literal, Type::Float);
         }
-        else if (literal.getValue().type == Token::Type::SCALAR_NUMBER) {
-            setExpressionType(&literal, m_Context.at("scalar"));
-        }
         else if (literal.getValue().type == Token::Type::INT32_NUMBER) {
             setExpressionType(&literal, Type::Int32);
         }
@@ -828,7 +825,6 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     // Members
     //---------------------------------------------------------------------------
     std::reference_wrapper<EnvironmentInternal> m_Environment;
-    const Type::TypeContext &m_Context;
     ErrorHandlerBase &m_ErrorHandler;
     ResolvedTypeMap &m_ResolvedTypes;
     std::stack<std::vector<Type::ResolvedType>> m_CallArguments;
@@ -856,19 +852,19 @@ Type::ResolvedType EnvironmentBase::getType(const Token &name, ErrorHandlerBase
 // GeNN::Transpiler::TypeChecker
 //---------------------------------------------------------------------------
 ResolvedTypeMap GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment,
-                                                         const Type::TypeContext &context, ErrorHandlerBase &errorHandler)
+                                                         ErrorHandlerBase &errorHandler)
 {
     ResolvedTypeMap expressionTypes;
     EnvironmentInternal internalEnvironment(environment);
-    Visitor visitor(statements, context, internalEnvironment, expressionTypes, errorHandler);
+    Visitor visitor(statements, internalEnvironment, expressionTypes, errorHandler);
     return expressionTypes;
 }
 //---------------------------------------------------------------------------
 Type::ResolvedType GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, EnvironmentBase &environment,
-                                                            const Type::TypeContext &context, ErrorHandlerBase &errorHandler)
+                                                            ErrorHandlerBase &errorHandler)
 {
     ResolvedTypeMap expressionTypes;
     EnvironmentInternal internalEnvironment(environment);
-    Visitor visitor(expression, context, internalEnvironment, expressionTypes, errorHandler);
+    Visitor visitor(expression, internalEnvironment, expressionTypes, errorHandler);
     return expressionTypes.at(expression);
 }
diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc
index b1004cb452..ecde457274 100644
--- a/src/genn/genn/type.cc
+++ b/src/genn/genn/type.cc
@@ -67,7 +67,6 @@ std::string ResolvedType::getName() const
         Utils::Overload{
             [&qualifier](const Type::ResolvedType::Value &value)
             {
-                assert(value.numeric);
                 return qualifier + value.name;
             },
             [&qualifier](const Type::ResolvedType::Pointer &pointer)
@@ -97,7 +96,7 @@ size_t ResolvedType::getSize(size_t pointerBytes) const
             {
                 return pointerBytes;
             },
-            [](const Type::ResolvedType::Function&)
+            [](const Type::ResolvedType::Function&)->size_t
             {
                 throw std::runtime_error("Function types do not have size");
             }},
@@ -123,16 +122,16 @@ ResolvedType UnresolvedType::resolve(const TypeContext &typeContext) const
 //----------------------------------------------------------------------------
 // Free functions
 //----------------------------------------------------------------------------
-ResolvedType parseNumeric(const std::string &typeString)
+ResolvedType parseNumeric(const std::string &typeString, const TypeContext &context)
 {
     using namespace Transpiler;
 
     // Scan type
     SingleLineErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource(typeString, errorHandler);
+    const auto tokens = Scanner::scanSource(typeString, context, errorHandler);
 
     // Parse type numeric type
-    const auto type = Parser::parseNumericType(tokens, errorHandler);
+    const auto type = Parser::parseNumericType(tokens, context, errorHandler);
 
     // If an error was encountered while scanning or parsing, throw exception
     if (errorHandler.hasError()) {

From 8d7e6b97773138d69cb5deffc11c3f3663ba476a Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 25 May 2023 17:42:14 +0100
Subject: [PATCH 164/725] unit tests running again

---
 include/genn/genn/models.h    |  23 +--
 tests/unit/modelSpecMerged.cc |   7 +-
 tests/unit/scanner.cc         |  12 +-
 tests/unit/typeChecker.cc     | 335 ++++++++++++++--------------------
 4 files changed, 158 insertions(+), 219 deletions(-)

diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h
index 04d71ed16d..1df29d2c24 100644
--- a/include/genn/genn/models.h
+++ b/include/genn/genn/models.h
@@ -51,15 +51,11 @@ class GENN_EXPORT Base : public Snippet::Base
         if not specified, this results in a -Wmissing-field-initializers warning on GCC and Clang*/
     struct Var
     {
-        Var(const std::string &n, const Type::ResolvedType &t, VarAccess a) : name(n), type(t), access(a)
+        Var(const std::string &n, const Type::ResolvedType &t, VarAccess a = VarAccess::READ_WRITE) : name(n), type(t), access(a)
         {}
-        Var(const std::string &n, const Type::ResolvedType &t) : Var(n, t, VarAccess::READ_WRITE)
+        Var(const std::string &n, const std::string &t, VarAccess a = VarAccess::READ_WRITE) : name(n), type(t), access(a)
         {}
-        Var(const std::string &n, const std::string &t, VarAccess a) : name(n), type(t), access(a)
-        {}
-        Var(const std::string &n, const std::string &t) : Var(n, t, VarAccess::READ_WRITE)
-        {}
-
+
         bool operator == (const Var &other) const
         {
             return (std::tie(name, type, access) == std::tie(other.name, other.type, other.access));
@@ -72,14 +68,15 @@ class GENN_EXPORT Base : public Snippet::Base
 
     struct VarRef
     {
-        VarRef(const std::string &n, const Type::ResolvedType &t, VarAccessMode a) : name(n), type(t), access(a)
+        VarRef(const std::string &n, const Type::ResolvedType &t, VarAccessMode a = VarAccessMode::READ_WRITE) : name(n), type(t), access(a)
         {}
-        VarRef(const std::string &n, const Type::ResolvedType &t) : VarRef(n, t, VarAccessMode::READ_WRITE)
+        VarRef(const std::string &n, const std::string &t, VarAccessMode a = VarAccessMode::READ_WRITE) : name(n), type(t), access(a)
         {}
-        VarRef(const std::string &n, const std::string &t, VarAccessMode a);
-        VarRef(const std::string &n, const std::string &t);
-
-        bool operator == (const VarRef &other) const;
+
+        bool operator == (const VarRef &other) const
+        {
+            return (std::tie(name, type, access) == std::tie(other.name, other.type, other.access));
+        }
 
         std::string name;
         Type::UnresolvedType type;
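The models.h consolidation above halves the constructor count by defaulting the access argument. A usage sketch, assuming GeNN's headers (the variable name and type string are illustrative):

    #include <cassert>
    #include "models.h"

    using namespace GeNN;

    void example()
    {
        // These two definitions are now equivalent
        const Models::Base::Var explicitVar("g", "scalar", VarAccess::READ_WRITE);
        const Models::Base::Var defaultedVar("g", "scalar");
        assert(explicitVar == defaultedVar);   // operator== compares name, type and access
    }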
diff --git a/tests/unit/modelSpecMerged.cc b/tests/unit/modelSpecMerged.cc
index 5c684c7ec3..018a05ea36 100644
--- a/tests/unit/modelSpecMerged.cc
+++ b/tests/unit/modelSpecMerged.cc
@@ -182,8 +182,7 @@ void test(const std::pair (&modelModifiers)[N], M applyModifierFn)
     model.setName("test");
     model.setDT(0.1);
     model.setTiming(false);
-    model.setPrecision(Type::Float::getInstance());
-    model.setTimePrecision(nullptr);
+    model.setPrecision(Type::Float);
     model.setBatchSize(1);
     model.setSeed(0);
 
@@ -319,8 +318,8 @@ TEST(ModelSpecMerged, CompareModelChanges)
         {[](ModelSpecInternal &model) { model.setName("interesting_name"); }, false},
         {[](ModelSpecInternal &model) { model.setDT(1.0); }, false},
         {[](ModelSpecInternal &model) { model.setTiming(true); }, false},
-        {[](ModelSpecInternal &model) { model.setPrecision(Type::Double::getInstance()); }, false},
-        {[](ModelSpecInternal &model) { model.setTimePrecision(Type::Double::getInstance()); }, false},
+        {[](ModelSpecInternal &model) { model.setPrecision(Type::Double); }, false},
+        {[](ModelSpecInternal &model) { model.setTimePrecision(Type::Double); }, false},
         {[](ModelSpecInternal &model) { model.setBatchSize(10); }, false},
         {[](ModelSpecInternal &model) { model.setSeed(1234); }, false}};
 
diff --git a/tests/unit/scanner.cc b/tests/unit/scanner.cc
index 351a3836b7..43e5ca7157 100644
--- a/tests/unit/scanner.cc
+++ b/tests/unit/scanner.cc
@@ -56,7 +56,7 @@ class TestErrorHandler : public ErrorHandlerBase
 TEST(Scanner, DecimalInt)
 {
     TestErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource("1234 4294967295U -2345 -2147483647", errorHandler);
+    const auto tokens = Scanner::scanSource("1234 4294967295U -2345 -2147483647", {}, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());
 
     ASSERT_EQ(tokens.size(), 7);
@@ -77,7 +77,7 @@ TEST(Scanner, DecimalInt)
 TEST(Scanner, HexInt)
 {
     TestErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource("0x1234 0xFFFFFFFFU -0x1234 -0x7FFFFFFF", errorHandler);
+    const auto tokens = Scanner::scanSource("0x1234 0xFFFFFFFFU -0x1234 -0x7FFFFFFF", {}, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());
 
     ASSERT_EQ(tokens.size(), 7);
@@ -98,12 +98,12 @@ TEST(Scanner, HexInt)
 TEST(Scanner, DecimalFloat)
 {
     TestErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource("1.0 0.2 100.0f 0.2f -12.0d -0.0004f", errorHandler);
+    const auto tokens = Scanner::scanSource("1.0 0.2 100.0f 0.2f -12.0d -0.0004f", {{"scalar", Type::Float}}, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());
 
     ASSERT_EQ(tokens.size(), 9);
-    ASSERT_EQ(tokens[0].type, Token::Type::SCALAR_NUMBER);
-    ASSERT_EQ(tokens[1].type, Token::Type::SCALAR_NUMBER);
+    ASSERT_EQ(tokens[0].type, Token::Type::FLOAT_NUMBER);
+    ASSERT_EQ(tokens[1].type, Token::Type::FLOAT_NUMBER);
     ASSERT_EQ(tokens[2].type, Token::Type::FLOAT_NUMBER);
     ASSERT_EQ(tokens[3].type, Token::Type::FLOAT_NUMBER);
     ASSERT_EQ(tokens[4].type, Token::Type::MINUS);
@@ -123,7 +123,7 @@ TEST(Scanner, DecimalFloat)
 TEST(Scanner, String)
 {
     TestErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource("\"hello world\" \"pre-processor\"", errorHandler);
+    const auto tokens = Scanner::scanSource("\"hello world\" \"pre-processor\"", {}, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());
 
     ASSERT_EQ(tokens.size(), 3);
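The scanner tests above show the practical effect of threading the TypeContext through scanSource(): with "scalar" bound to Type::Float, unsuffixed decimal literals scan as plain FLOAT_NUMBER tokens and the old SCALAR_NUMBER token type disappears. Sketched in the same style as the tests:

    TestErrorHandler errorHandler;
    const auto tokens = Scanner::scanSource("1.0 0.2", {{"scalar", Type::Float}}, errorHandler);
    ASSERT_EQ(tokens[0].type, Token::Type::FLOAT_NUMBER);   // was SCALAR_NUMBER before this patch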
&name, Type::Qualifier qualifiers = Type::Qualifier{0})
-    {
-        define(name, T::getInstance()->getQualifiedType(qualifiers));
-    }
-    template<typename T>
-    void definePointer(const std::string &name, Type::Qualifier valueQualifiers = Type::Qualifier{0},
+    void definePointer(const Type::ResolvedType &type, const std::string &name, Type::Qualifier valueQualifiers = Type::Qualifier{0},
                        Type::Qualifier pointerQualifiers = Type::Qualifier{0})
     {
-        define(name, T::getInstance()->getQualifiedType(valueQualifiers)->getPointerType(pointerQualifiers));
+        define(type.addQualifier(valueQualifiers).createPointer(pointerQualifiers), name);
     }

     //---------------------------------------------------------------------------
     // EnvironmentBase virtuals
     //---------------------------------------------------------------------------
-    virtual void define(const Token &name, const Type::Base*, ErrorHandlerBase &errorHandler) final
+    virtual void define(const Token &name, const Type::ResolvedType&, ErrorHandlerBase &errorHandler) final
     {
         errorHandler.error(name, "Cannot declare variable in external environment");
         throw TypeChecker::TypeCheckError();
     }

-    virtual const Type::Base *assign(const Token &name, Token::Type op, const Type::Base *assignedType,
-                                     const Type::TypeContext &context, ErrorHandlerBase &errorHandler,
-                                     bool initializer = false) final
-    {
-        // If type isn't found
-        auto existingType = m_Types.find(name.lexeme);
-        if(existingType == m_Types.end()) {
-            errorHandler.error(name, "Undefined variable");
-            throw TypeChecker::TypeCheckError();
-        }
-
-        // Perform standard type-checking logic
-        return EnvironmentBase::assign(name, op, existingType->second, assignedType, context, errorHandler, initializer);
-    }
-
-    virtual const Type::Base *incDec(const Token &name, Token::Type op,
-                                     const Type::TypeContext&, ErrorHandlerBase &errorHandler) final
-    {
-        auto existingType = m_Types.find(name.lexeme);
-        if(existingType == m_Types.end()) {
-            errorHandler.error(name, "Undefined variable");
-            throw TypeChecker::TypeCheckError();
-        }
-
-        // Perform standard type-checking logic
-        return EnvironmentBase::incDec(name, op, existingType->second, errorHandler);
-    }
-
-    virtual std::vector<const Type::Base*> getTypes(const Token &name, ErrorHandlerBase &errorHandler) final
+    virtual std::vector<Type::ResolvedType> getTypes(const Token &name, ErrorHandlerBase &errorHandler) final
     {
         auto type = m_Types.find(std::string{name.lexeme});
         if(type == m_Types.end()) {
@@ -132,44 +98,38 @@ class TestEnvironment : public TypeChecker::EnvironmentBase
     //---------------------------------------------------------------------------
     // Members
     //---------------------------------------------------------------------------
-    std::unordered_map<std::string, const Type::Base*> m_Types;
+    std::unordered_map<std::string, Type::ResolvedType> m_Types;
 };

-template<typename T>
-std::string getPointerTypeName()
-{
-    return T::getInstance()->getPointerType()->getName();
-}
-
 void typeCheckStatements(std::string_view code, TypeChecker::EnvironmentBase &typeEnvironment, const Type::TypeContext &typeContext = {})
 {
     // Scan
     TestErrorHandler errorHandler;
-    const auto tokens = Scanner::scanSource(code, errorHandler);
+    const auto tokens = Scanner::scanSource(code, typeContext, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());

     // Parse
-    const auto statements = Parser::parseBlockItemList(tokens, errorHandler);
+    const auto statements = Parser::parseBlockItemList(tokens, typeContext, errorHandler);
     ASSERT_FALSE(errorHandler.hasError());

     // Typecheck
-    TypeChecker::typeCheck(statements, typeEnvironment, typeContext, errorHandler);
+    TypeChecker::typeCheck(statements, typeEnvironment,
errorHandler); ASSERT_FALSE(errorHandler.hasError()); } -const Type::Base *typeCheckExpression(std::string_view code, TypeChecker::EnvironmentBase &typeEnvironment, const Type::TypeContext &typeContext = {}) +Type::ResolvedType typeCheckExpression(std::string_view code, TypeChecker::EnvironmentBase &typeEnvironment, const Type::TypeContext &typeContext = {}) { // Scan TestErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(code, errorHandler); + const auto tokens = Scanner::scanSource(code, typeContext, errorHandler); EXPECT_FALSE(errorHandler.hasError()); // Parse - const auto expression = Parser::parseExpression(tokens, errorHandler); + const auto expression = Parser::parseExpression(tokens, typeContext, errorHandler); EXPECT_FALSE(errorHandler.hasError()); // Typecheck - const auto *type = TypeChecker::typeCheck(expression.get(), typeEnvironment, typeContext, errorHandler); + const auto type = TypeChecker::typeCheck(expression.get(), typeEnvironment, errorHandler); EXPECT_FALSE(errorHandler.hasError()); return type; } @@ -183,10 +143,10 @@ TEST(TypeChecker, ArraySubscript) // Integer array indexing { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); - const auto *type = typeCheckExpression("intArray[4]", typeEnvironment); - EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT)); + typeEnvironment.definePointer(Type::Int32, "intArray"); + const auto type = typeCheckExpression("intArray[4]", typeEnvironment); + EXPECT_EQ(type, Type::Int32); + EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); } // Pointer to pointer, double indexing @@ -194,15 +154,15 @@ TEST(TypeChecker, ArraySubscript) // Float array indexing EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); + typeEnvironment.definePointer(Type::Int32, "intArray"); typeCheckExpression("intArray[4.0f]", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer indexing EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); - typeEnvironment.definePointer("indexArray"); + typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.definePointer(Type::Int32, "indexArray"); typeCheckExpression("intArray[indexArray]", typeEnvironment);}, TypeChecker::TypeCheckError); } @@ -212,9 +172,9 @@ TEST(TypeChecker, Assignment) // Numeric assignment { TestEnvironment typeEnvironment; - typeEnvironment.define("intVal"); - typeEnvironment.define("floatVal"); - typeEnvironment.define("intValConst", Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32, "intVal"); + typeEnvironment.define(Type::Float, "floatVal"); + typeEnvironment.define(Type::Int32, "intValConst", Type::Qualifier::CONSTANT); typeCheckStatements( "int w = intVal;\n" "float x = floatVal;\n" @@ -229,8 +189,8 @@ TEST(TypeChecker, Assignment) // Pointer assignement { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); - typeEnvironment.definePointer("intArrayConst", Type::Qualifier::CONSTANT); + typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.definePointer(Type::Int32, "intArrayConst", Type::Qualifier::CONSTANT); typeCheckStatements( "int *x = intArray;\n" "const int *y = intArray;\n" @@ -241,21 +201,21 @@ TEST(TypeChecker, Assignment) // Pointer assignement, attempt to remove const EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray", Type::Qualifier::CONSTANT); + 
typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier::CONSTANT); typeCheckStatements("int *x = intArray;", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer assignement without explicit cast EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); + typeEnvironment.definePointer(Type::Int32, "intArray"); typeCheckStatements("float *x = intArray;", typeEnvironment);}, TypeChecker::TypeCheckError); // Dereference assignment { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); + typeEnvironment.definePointer(Type::Int32, "intArray"); typeCheckStatements( "*intArray = 7;\n", typeEnvironment); @@ -268,10 +228,10 @@ TEST(TypeChecker, Binary) // Pointer difference { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray1"); - typeEnvironment.definePointer("intArray2"); - const auto *type = typeCheckExpression("intArray1 - intArray2", typeEnvironment); - EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); + typeEnvironment.definePointer(Type::Int32, "intArray1"); + typeEnvironment.definePointer(Type::Int32, "intArray2"); + const auto type = typeCheckExpression("intArray1 - intArray2", typeEnvironment); + EXPECT_EQ(type, Type::Int32); } // **TODO** different pointer types @@ -280,12 +240,10 @@ TEST(TypeChecker, Binary) // Pointer + integer { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); - typeEnvironment.define("offset"); - const auto *type = typeCheckExpression("intArray + offset", typeEnvironment); - const auto *pointerType = dynamic_cast(type); - EXPECT_TRUE(pointerType); - EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); + typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32, "offset"); + const auto type = typeCheckExpression("intArray + offset", typeEnvironment); + EXPECT_EQ(*type.getPointer().valueType, Type::Int32); } // **TODO** constness and @@ -293,16 +251,16 @@ TEST(TypeChecker, Binary) // Pointer + non-integer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); - typeEnvironment.define("offset"); + typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Float, "offset"); typeCheckExpression("intArray + offset", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer + pointer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray1"); - typeEnvironment.definePointer("intArray2"); + typeEnvironment.definePointer(Type::Int32, "intArray1"); + typeEnvironment.definePointer(Type::Int32, "intArray2"); typeCheckExpression("intArray1 + intArray2", typeEnvironment);}, TypeChecker::TypeCheckError); @@ -310,23 +268,19 @@ TEST(TypeChecker, Binary) // Pointer - integer { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); - typeEnvironment.define("offset"); - const auto *type = typeCheckExpression("intArray - offset", typeEnvironment); - const auto *pointerType = dynamic_cast(type); - EXPECT_TRUE(pointerType); - EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); + typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32, "offset"); + const auto type = typeCheckExpression("intArray - offset", typeEnvironment); + EXPECT_EQ(*type.getPointer().valueType, Type::Int32); } // Integer + pointer { TestEnvironment typeEnvironment; - 
typeEnvironment.definePointer("intArray"); - typeEnvironment.define("offset"); - const auto *type = typeCheckExpression("offset + intArray", typeEnvironment); - const auto *pointerType = dynamic_cast(type); - EXPECT_TRUE(pointerType); - EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); + typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32, "offset"); + const auto type = typeCheckExpression("offset + intArray", typeEnvironment); + EXPECT_EQ(*type.getPointer().valueType, Type::Int32); } /*integer only (opType == Token::Type::PERCENT || opType == Token::Type::SHIFT_LEFT @@ -350,34 +304,34 @@ TEST(TypeChecker, Call) // Floating point transcendental function { - const auto *type = typeCheckExpression("sin(1.0f)", stdLibraryEnv); - EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName()); + const auto type = typeCheckExpression("sin(1.0f)", stdLibraryEnv); + EXPECT_EQ(type, Type::Float); } // Double transcendental function { - const auto *type = typeCheckExpression("sin(1.0d)", stdLibraryEnv); - EXPECT_EQ(type->getName(), Type::Double::getInstance()->getName()); + const auto type = typeCheckExpression("sin(1.0d)", stdLibraryEnv); + EXPECT_EQ(type, Type::Double); } // Float scalar transcendental function { - const Type::TypeContext typeContext{{"scalar", Type::Float::getInstance()}}; - const auto *type = typeCheckExpression("sin(1.0)", stdLibraryEnv, typeContext); - EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName()); + const Type::TypeContext typeContext{{"scalar", Type::Float}}; + const auto type = typeCheckExpression("sin(1.0)", stdLibraryEnv, typeContext); + EXPECT_EQ(type, Type::Float); } // Double scalar transcendental function { - const Type::TypeContext typeContext{{"scalar", Type::Double::getInstance()}}; - const auto *type = typeCheckExpression("sin(1.0)", stdLibraryEnv, typeContext); - EXPECT_EQ(type->getName(), Type::Double::getInstance()->getName()); + const Type::TypeContext typeContext{{"scalar", Type::Double}}; + const auto type = typeCheckExpression("sin(1.0)", stdLibraryEnv, typeContext); + EXPECT_EQ(type, Type::Double); } // Nested transcendental function { - const auto *type = typeCheckExpression("sin(fmax(0.0f, 1.0f))", stdLibraryEnv); - EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName()); + const auto type = typeCheckExpression("sin(fmax(0.0f, 1.0f))", stdLibraryEnv); + EXPECT_EQ(type, Type::Float); } @@ -405,86 +359,81 @@ TEST(TypeChecker, Cast) // Numeric cast { TestEnvironment typeEnvironment; - typeEnvironment.define("intVal"); - const auto *type = typeCheckExpression("(float)intVal", typeEnvironment); - EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName()); - EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT)); + typeEnvironment.define(Type::Int32, "intVal"); + const auto type = typeCheckExpression("(float)intVal", typeEnvironment); + EXPECT_EQ(type, Type::Float); + EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); } // Numeric cast to const { TestEnvironment typeEnvironment; - typeEnvironment.define("intVal"); - const auto *type = typeCheckExpression("(const int)intVal", typeEnvironment); - EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONSTANT)); + typeEnvironment.define(Type::Int32, "intVal"); + const auto type = typeCheckExpression("(const int)intVal", typeEnvironment); + EXPECT_EQ(type, Type::Int32); + 
EXPECT_TRUE(type.hasQualifier(Type::Qualifier::CONSTANT)); } // Pointer cast to value const { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); - const auto *type = typeCheckExpression("(const int*)intArray", typeEnvironment); - EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT)); + typeEnvironment.definePointer(Type::Int32, "intArray"); + const auto type = typeCheckExpression("(const int*)intArray", typeEnvironment); + EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); - const auto *pointerType = dynamic_cast(type); - EXPECT_TRUE(pointerType); - EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); - EXPECT_TRUE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONSTANT)); + EXPECT_EQ(*type.getPointer().valueType, Type::Int32); + EXPECT_TRUE(type.getPointer().valueType->hasQualifier(Type::Qualifier::CONSTANT)); } // Pointer cast to pointer const { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); - const auto *type = typeCheckExpression("(int * const)intArray", typeEnvironment); - EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONSTANT)); - - const auto *pointerType = dynamic_cast(type); - EXPECT_TRUE(pointerType); - EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); - EXPECT_FALSE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONSTANT)); + typeEnvironment.definePointer(Type::Int32, "intArray"); + const auto type = typeCheckExpression("(int * const)intArray", typeEnvironment); + EXPECT_TRUE(type.hasQualifier(Type::Qualifier::CONSTANT)); + EXPECT_EQ(*type.getPointer().valueType, Type::Int32); + EXPECT_FALSE(type.getPointer().valueType->hasQualifier(Type::Qualifier::CONSTANT)); } // Can't remove value const from numeric EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.define("intVal", Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32, "intVal", Type::Qualifier::CONSTANT); typeCheckExpression("(int)intVal", typeEnvironment);}, TypeChecker::TypeCheckError); // Can't remove value const from pointer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray", Type::Qualifier::CONSTANT); + typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier::CONSTANT); typeCheckExpression("(int*)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Can't remove pointer const from pointer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray", Type::Qualifier{0}, Type::Qualifier::CONSTANT); + typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier{0}, Type::Qualifier::CONSTANT); typeCheckExpression("(int*)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer cast can't reinterpret EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); + typeEnvironment.definePointer(Type::Int32, "intArray"); typeCheckExpression("(float*)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer can't be cast to numeric EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); + typeEnvironment.definePointer(Type::Int32, "intArray"); typeCheckExpression("(int)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Numeric can't be cast to pointer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.define("intVal"); + typeEnvironment.define(Type::Int32, "intVal"); typeCheckExpression("(int*)intVal", 
typeEnvironment);}, TypeChecker::TypeCheckError); } @@ -498,45 +447,42 @@ TEST(TypeChecker, IncDec) // Can increment numeric { TestEnvironment typeEnvironment; - typeEnvironment.define("intVal"); - const auto *type = typeCheckExpression("intVal++", typeEnvironment); - EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT)); + typeEnvironment.define(Type::Int32, "intVal"); + const auto type = typeCheckExpression("intVal++", typeEnvironment); + EXPECT_EQ(type, Type::Int32); + EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); } // Can increment pointer { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); - const auto *type = typeCheckExpression("intArray++", typeEnvironment); - EXPECT_EQ(type->getName(), getPointerTypeName()); - EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT)); + typeEnvironment.definePointer(Type::Int32, "intArray"); + const auto type = typeCheckExpression("intArray++", typeEnvironment); + EXPECT_EQ(type, Type::Int32.getPointer()); + EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); } // Can increment pointer to const { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray", Type::Qualifier::CONSTANT); - const auto *type = typeCheckExpression("intArray++", typeEnvironment); - EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT)); - - const auto *pointerType = dynamic_cast(type); - EXPECT_TRUE(pointerType); - EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); - EXPECT_TRUE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONSTANT)); + typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier::CONSTANT); + const auto type = typeCheckExpression("intArray++", typeEnvironment); + EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); + EXPECT_EQ(*type.getPointer().valueType, Type::Int32); + EXPECT_TRUE(type.getPointer().valueType->hasQualifier(Type::Qualifier::CONSTANT)); } // Can't increment const number EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.define("intVal", Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32, "intVal", Type::Qualifier::CONSTANT); typeCheckExpression("intVal++", typeEnvironment);}, TypeChecker::TypeCheckError); // Can't increment const pointer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray", Type::Qualifier{0}, Type::Qualifier::CONSTANT); + typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier{0}, Type::Qualifier::CONSTANT); typeCheckExpression("intArray++", typeEnvironment);}, TypeChecker::TypeCheckError); } @@ -546,52 +492,52 @@ TEST(TypeChecker, Literal) // Float { TestEnvironment typeEnvironment; - const auto *type = typeCheckExpression("1.0f", typeEnvironment); - EXPECT_EQ(type->getName(), Type::Float::getInstance()->getName()); + const auto type = typeCheckExpression("1.0f", typeEnvironment); + EXPECT_EQ(type, Type::Float); } // Scalar with single-precision { TestEnvironment typeEnvironment; - const Type::TypeContext typeContext{{"scalar", Type::Float::getInstance()}}; - const auto *type = typeCheckExpression("1.0", typeEnvironment, typeContext); - EXPECT_EQ(type->getResolvedName(typeContext), Type::Float::getInstance()->getName()); + const Type::TypeContext typeContext{{"scalar", Type::Float}}; + const auto type = typeCheckExpression("1.0", typeEnvironment, typeContext); + EXPECT_EQ(type, Type::Float); } // Scalar with double-precision 
{ TestEnvironment typeEnvironment; - const Type::TypeContext typeContext{{"scalar", Type::Double::getInstance()}}; - const auto *type = typeCheckExpression("1.0", typeEnvironment, typeContext); - EXPECT_EQ(type->getResolvedName(typeContext), Type::Double::getInstance()->getName()); + const Type::TypeContext typeContext{{"scalar", Type::Double}}; + const auto type = typeCheckExpression("1.0", typeEnvironment, typeContext); + EXPECT_EQ(type, Type::Double); } // Double { TestEnvironment typeEnvironment; - const auto *type = typeCheckExpression("1.0d", typeEnvironment); - EXPECT_EQ(type->getName(), Type::Double::getInstance()->getName()); + const auto type = typeCheckExpression("1.0d", typeEnvironment); + EXPECT_EQ(type, Type::Double); } // Integer { TestEnvironment typeEnvironment; - const auto *type = typeCheckExpression("100", typeEnvironment); - EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); + const auto type = typeCheckExpression("100", typeEnvironment); + EXPECT_EQ(type, Type::Int32); } // Unsigned integer { TestEnvironment typeEnvironment; - const auto *type = typeCheckExpression("100U", typeEnvironment); - EXPECT_EQ(type->getName(), Type::Uint32::getInstance()->getName()); + const auto type = typeCheckExpression("100U", typeEnvironment); + EXPECT_EQ(type, Type::Uint32); } // String { TestEnvironment typeEnvironment; - const auto *type = typeCheckExpression("\"hello world\"", typeEnvironment); - EXPECT_EQ(type->getName(), Type::Int8::getInstance()->getPointerType()->getName()); + const auto type = typeCheckExpression("\"hello world\"", typeEnvironment); + EXPECT_EQ(type, Type::Int8.getPointer()); } } //-------------------------------------------------------------------------- @@ -600,63 +546,60 @@ TEST(TypeChecker, Unary) // Dereference pointer { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); - const auto *type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT)); + typeEnvironment.definePointer(Type::Int32, "intArray"); + const auto type = typeCheckExpression("*intArray", typeEnvironment); + EXPECT_EQ(type, Type::Int32); + EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); } // Dereference pointer to const { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray", Type::Qualifier::CONSTANT); - const auto *type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONSTANT)); + typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier::CONSTANT); + const auto type = typeCheckExpression("*intArray", typeEnvironment); + EXPECT_EQ(type, Type::Int32); + EXPECT_TRUE(type.hasQualifier(Type::Qualifier::CONSTANT)); } // Dereference const pointer { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray", Type::Qualifier{0}, Type::Qualifier::CONSTANT); - const auto *type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT)); + typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier{0}, Type::Qualifier::CONSTANT); + const auto type = typeCheckExpression("*intArray", typeEnvironment); + EXPECT_EQ(type, Type::Int32); + EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); } // Dereference const 
pointer to const { TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray", Type::Qualifier::CONSTANT, Type::Qualifier::CONSTANT); - const auto *type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); - EXPECT_TRUE(type->hasQualifier(Type::Qualifier::CONSTANT)); + typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier::CONSTANT, Type::Qualifier::CONSTANT); + const auto type = typeCheckExpression("*intArray", typeEnvironment); + EXPECT_EQ(type, Type::Int32); + EXPECT_TRUE(type.hasQualifier(Type::Qualifier::CONSTANT)); } // Dereference numeric EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.define("intVal"); + typeEnvironment.define(Type::Int32, "intVal"); typeCheckExpression("*intVal", typeEnvironment); }, TypeChecker::TypeCheckError); // Address of numeric { TestEnvironment typeEnvironment; - typeEnvironment.define("intVal"); - const auto *type = typeCheckExpression("&intVal", typeEnvironment); - EXPECT_FALSE(type->hasQualifier(Type::Qualifier::CONSTANT)); - - const auto *pointerType = dynamic_cast(type); - EXPECT_TRUE(pointerType); - EXPECT_EQ(pointerType->getValueType()->getName(), Type::Int32::getInstance()->getName()); - EXPECT_FALSE(pointerType->getValueType()->hasQualifier(Type::Qualifier::CONSTANT)); + typeEnvironment.define(Type::Int32, "intVal"); + const auto type = typeCheckExpression("&intVal", typeEnvironment); + EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); + EXPECT_EQ(*type.getPointer().valueType, Type::Int32); + EXPECT_FALSE(type.getPointer().valueType,->hasQualifier(Type::Qualifier::CONSTANT)); } // Address of pointer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer("intArray"); + typeEnvironment.definePointer(Type::Int32, "intArray"); typeCheckExpression("&intArray", typeEnvironment);}, TypeChecker::TypeCheckError); } From 72e9b09ade990fad51d85bc7ce60aed95a61050f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 25 May 2023 17:58:16 +0100 Subject: [PATCH 165/725] fixed bug in scalar literal scanning --- src/genn/genn/transpiler/scanner.cc | 34 +++++++++++++++++------------ 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc index f8e85fc796..e7b2cddb18 100644 --- a/src/genn/genn/transpiler/scanner.cc +++ b/src/genn/genn/transpiler/scanner.cc @@ -67,16 +67,6 @@ class ScanState ScanState(std::string_view source, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) : m_Start(0), m_Current(0), m_Line(1), m_Source(source), m_Context(context), m_ErrorHandler(errorHandler) { - const auto &scalarType = context.at("scalar"); - if (scalarType == Type::Float) { - m_ScalarTokenType = Token::Type::FLOAT_NUMBER; - } - else if (scalarType == Type::Double) { - m_ScalarTokenType = Token::Type::DOUBLE_NUMBER; - } - else { - throw std::runtime_error("Unsupported scalar type '" + scalarType.getName() + "'"); - } } //--------------------------------------------------------------------------- @@ -143,7 +133,24 @@ class ScanState return (m_Context.find(std::string{lexeme}) != m_Context.cend()); } - Token::Type getScalarTokenType() const{ return m_ScalarTokenType; } + Token::Type getScalarTokenType() const + { + const auto scalarType = m_Context.find("scalar"); + if (scalarType == m_Context.cend()) { + throw std::runtime_error("Cannot scan scalar literals without 'scalar' type being defined in type context"); + } + else { + 
if (scalarType->second == Type::Float) { + return Token::Type::FLOAT_NUMBER; + } + else if (scalarType->second == Type::Double) { + return Token::Type::DOUBLE_NUMBER; + } + else { + throw std::runtime_error("Unsupported scalar type '" + scalarType->first + "'"); + } + } + } private: //--------------------------------------------------------------------------- @@ -156,7 +163,6 @@ class ScanState std::string_view m_Source; const Type::TypeContext &m_Context; ErrorHandlerBase &m_ErrorHandler; - Token::Type m_ScalarTokenType; }; bool isodigit(char c) @@ -238,14 +244,14 @@ void scanNumber(char c, ScanState &scanState, std::vector &tokens) // If number has an f suffix, emplace FLOAT_NUMBER token if (std::tolower(scanState.peek()) == 'f') { - scanState.advance(); emplaceToken(tokens, Token::Type::FLOAT_NUMBER, scanState); + scanState.advance(); } // Otherwise, if it has a d suffix, emplace DOUBLE_NUMBER token // **NOTE** 'd' is a GeNN extension not standard C else if (std::tolower(scanState.peek()) == 'd') { - scanState.advance(); emplaceToken(tokens, Token::Type::DOUBLE_NUMBER, scanState); + scanState.advance(); } // Otherwise, emplace literal with whatever type is specified else { From fe4e74154657616252d822d152ad29bad90f26a3 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 25 May 2023 17:58:32 +0100 Subject: [PATCH 166/725] fixed typo in type checker unit tests --- tests/unit/typeChecker.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index 123534dd06..b86ef1eef0 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -458,7 +458,7 @@ TEST(TypeChecker, IncDec) TestEnvironment typeEnvironment; typeEnvironment.definePointer(Type::Int32, "intArray"); const auto type = typeCheckExpression("intArray++", typeEnvironment); - EXPECT_EQ(type, Type::Int32.getPointer()); + EXPECT_EQ(type, Type::Int32.createPointer()); EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); } @@ -537,7 +537,7 @@ TEST(TypeChecker, Literal) { TestEnvironment typeEnvironment; const auto type = typeCheckExpression("\"hello world\"", typeEnvironment); - EXPECT_EQ(type, Type::Int8.getPointer()); + EXPECT_EQ(type, Type::Int8.createPointer()); } } //-------------------------------------------------------------------------- @@ -593,7 +593,7 @@ TEST(TypeChecker, Unary) const auto type = typeCheckExpression("&intVal", typeEnvironment); EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); EXPECT_EQ(*type.getPointer().valueType, Type::Int32); - EXPECT_FALSE(type.getPointer().valueType,->hasQualifier(Type::Qualifier::CONSTANT)); + EXPECT_FALSE(type.getPointer().valueType->hasQualifier(Type::Qualifier::CONSTANT)); } // Address of pointer From c4eb3937b580f3c53a0e33353d1b6529a0d284e7 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 25 May 2023 18:08:04 +0100 Subject: [PATCH 167/725] fixed more equality typos in type checker unit tests --- tests/unit/typeChecker.cc | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index b86ef1eef0..eb4da856d8 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -362,7 +362,6 @@ TEST(TypeChecker, Cast) typeEnvironment.define(Type::Int32, "intVal"); const auto type = typeCheckExpression("(float)intVal", typeEnvironment); EXPECT_EQ(type, Type::Float); - EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); } // Numeric cast to const @@ -370,8 +369,7 @@ 
TEST(TypeChecker, Cast) TestEnvironment typeEnvironment; typeEnvironment.define(Type::Int32, "intVal"); const auto type = typeCheckExpression("(const int)intVal", typeEnvironment); - EXPECT_EQ(type, Type::Int32); - EXPECT_TRUE(type.hasQualifier(Type::Qualifier::CONSTANT)); + EXPECT_EQ(type, Type::Int32.addQualifier(Type::Qualifier::CONSTANT)); } // Pointer cast to value const @@ -381,8 +379,7 @@ TEST(TypeChecker, Cast) const auto type = typeCheckExpression("(const int*)intArray", typeEnvironment); EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); - EXPECT_EQ(*type.getPointer().valueType, Type::Int32); - EXPECT_TRUE(type.getPointer().valueType->hasQualifier(Type::Qualifier::CONSTANT)); + EXPECT_EQ(*type.getPointer().valueType, Type::Int32.addQualifier(Type::Qualifier::CONSTANT)); } // Pointer cast to pointer const @@ -391,8 +388,7 @@ TEST(TypeChecker, Cast) typeEnvironment.definePointer(Type::Int32, "intArray"); const auto type = typeCheckExpression("(int * const)intArray", typeEnvironment); EXPECT_TRUE(type.hasQualifier(Type::Qualifier::CONSTANT)); - EXPECT_EQ(*type.getPointer().valueType, Type::Int32); - EXPECT_FALSE(type.getPointer().valueType->hasQualifier(Type::Qualifier::CONSTANT)); + EXPECT_EQ(*type.getPointer().valueType, Type::Int32); } // Can't remove value const from numeric @@ -468,8 +464,7 @@ TEST(TypeChecker, IncDec) typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier::CONSTANT); const auto type = typeCheckExpression("intArray++", typeEnvironment); EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); - EXPECT_EQ(*type.getPointer().valueType, Type::Int32); - EXPECT_TRUE(type.getPointer().valueType->hasQualifier(Type::Qualifier::CONSTANT)); + EXPECT_EQ(*type.getPointer().valueType, Type::Int32.addQualifier(Type::Qualifier::CONSTANT)); } // Can't increment const number @@ -537,7 +532,7 @@ TEST(TypeChecker, Literal) { TestEnvironment typeEnvironment; const auto type = typeCheckExpression("\"hello world\"", typeEnvironment); - EXPECT_EQ(type, Type::Int8.createPointer()); + EXPECT_EQ(type, Type::Int8.createPointer(Type::Qualifier::CONSTANT)); } } //-------------------------------------------------------------------------- @@ -549,7 +544,6 @@ TEST(TypeChecker, Unary) typeEnvironment.definePointer(Type::Int32, "intArray"); const auto type = typeCheckExpression("*intArray", typeEnvironment); EXPECT_EQ(type, Type::Int32); - EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); } // Dereference pointer to const @@ -557,8 +551,7 @@ TEST(TypeChecker, Unary) TestEnvironment typeEnvironment; typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier::CONSTANT); const auto type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type, Type::Int32); - EXPECT_TRUE(type.hasQualifier(Type::Qualifier::CONSTANT)); + EXPECT_EQ(type, Type::Int32.addQualifier(Type::Qualifier::CONSTANT)); } // Dereference const pointer @@ -567,7 +560,6 @@ TEST(TypeChecker, Unary) typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier{0}, Type::Qualifier::CONSTANT); const auto type = typeCheckExpression("*intArray", typeEnvironment); EXPECT_EQ(type, Type::Int32); - EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); } // Dereference const pointer to const @@ -575,8 +567,7 @@ TEST(TypeChecker, Unary) TestEnvironment typeEnvironment; typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier::CONSTANT, Type::Qualifier::CONSTANT); const auto type = typeCheckExpression("*intArray", typeEnvironment); - EXPECT_EQ(type, 
Type::Int32);
-        EXPECT_TRUE(type.hasQualifier(Type::Qualifier::CONSTANT));
+        EXPECT_EQ(type, Type::Int32.addQualifier(Type::Qualifier::CONSTANT));
     }

     // Dereference numeric

From 429b9e5e8ba4d3b7a5ef69e1ca9d29859910ea54 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 25 May 2023 18:15:31 +0100
Subject: [PATCH 168/725] whitespace

---
 src/genn/genn/transpiler/typeChecker.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index daaea84aab..e7dd31d134 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -415,7 +415,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         }
         else {
             m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType().getName() + "' and '" + rightType.getName());
-            throw TypeCheckError(); 
+            throw TypeCheckError();
         }
     }

From 862c08ce80ab059c00f6720657727362f56018b5 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Fri, 26 May 2023 09:36:30 +0100
Subject: [PATCH 169/725] when resolving types, use Type::parseNumeric

---
 src/genn/genn/type.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc
index ecde457274..a59a854871 100644
--- a/src/genn/genn/type.cc
+++ b/src/genn/genn/type.cc
@@ -115,7 +115,7 @@ ResolvedType UnresolvedType::resolve(const TypeContext &typeContext) const
                },
                [&typeContext](const std::string &name)
                {
-                   return typeContext.at(name);
+                   return parseNumeric(name, typeContext);
                }},
         detail);
 }

From 127749d39fe1e2bb541a73bc2f9a05a3d994d0cd Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Fri, 26 May 2023 09:48:17 +0100
Subject: [PATCH 170/725] slightly changed numeric literal handling

* Scanner now strips out suffixes (better for parsing if required)
* Pretty printer now adds them back in
---
 src/genn/genn/transpiler/prettyPrinter.cc | 12 ++++----
 src/genn/genn/transpiler/scanner.cc       | 35 +++++++++++------------
 tests/unit/scanner.cc                     | 10 +++----
 3 files changed, 29 insertions(+), 28 deletions(-)

diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc
index e78af85453..2b69a2dc9a 100644
--- a/src/genn/genn/transpiler/prettyPrinter.cc
+++ b/src/genn/genn/transpiler/prettyPrinter.cc
@@ -145,14 +145,16 @@ class Visitor : public Expression::Visitor, public Statement::Visitor

     virtual void visit(const Expression::Literal &literal) final
     {
+        // Write out lexeme
+        m_Environment.get().getStream() << literal.getValue().lexeme;
+
         // If literal is a float, add f suffix
-        std::string_view lexeme = literal.getValue().lexeme;
         if (literal.getValue().type == Token::Type::FLOAT_NUMBER){
-            m_Environment.get().getStream() << lexeme << "f";
+            m_Environment.get().getStream() << "f";
         }
-        // Otherwise, just write out original lexeme directly (strings are already quoted)
-        else {
-            m_Environment.get().getStream() << lexeme;
+        // Otherwise, if it's an unsigned integer, add u suffix
+        else if (literal.getValue().type == Token::Type::UINT32_NUMBER) {
+            m_Environment.get().getStream() << "u";
         }
     }

diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc
index e7b2cddb18..33c9911f7f 100644
--- a/src/genn/genn/transpiler/scanner.cc
+++ b/src/genn/genn/transpiler/scanner.cc
@@ -53,11 +53,6 @@ const std::unordered_map<std::string_view, Token::Type> keywords{
     {"int32_t", Token::Type::TYPE_SPECIFIER},
     {"bool", Token::Type::TYPE_SPECIFIER}};
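// [Illustrative sketch, not part of the patch] The net effect of this commit
// on numeric literals: a literal's suffix now lives only in the token type,
// and the pretty printer re-appends it on output. Using the
// Scanner::scanSource entry point and TestErrorHandler from the unit tests
// in this series:
//
//     TestErrorHandler errorHandler;
//     const auto tokens = Scanner::scanSource("100.0f 42U", {{"scalar", Type::Float}}, errorHandler);
//     // tokens[0] is Token::Type::FLOAT_NUMBER  with lexeme "100.0" - no 'f'
//     // tokens[1] is Token::Type::UINT32_NUMBER with lexeme "42"    - no 'U'
//
// and the pretty printer writes these back out as "100.0f" and "42u".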
 //---------------------------------------------------------------------------
-const std::map<std::set<char>, Token::Type> integerLiteralTokenTypes{
-    {{}, Token::Type::INT32_NUMBER},
-    {{'U'}, Token::Type::UINT32_NUMBER}
-};
-//---------------------------------------------------------------------------
 // ScanState
 //---------------------------------------------------------------------------
 //! Class encapsulated logic to navigate through source characters
@@ -176,16 +171,6 @@ void emplaceToken(std::vector<Token> &tokens, Token::Type type, const ScanState
     tokens.emplace_back(type, scanState.getLexeme(), scanState.getLine());
 }
 //---------------------------------------------------------------------------
-Token::Type scanIntegerSuffix(ScanState &scanState)
-{
-    // Read suffix
-    std::set<char> suffix;
-    while(std::toupper(scanState.peek()) == 'U' || std::toupper(scanState.peek()) == 'L') {
-        suffix.insert(std::toupper(scanState.advance()));
-    }
-    return integerLiteralTokenTypes.at(suffix);
-}
-//---------------------------------------------------------------------------
 void scanNumber(char c, ScanState &scanState, std::vector<Token> &tokens)
 {
     // If this is a hexadecimal literal
@@ -205,8 +190,15 @@ void scanNumber(char c, ScanState &scanState, std::vector<Token> &tokens)
             scanState.advance();
         }

-        // Add integer token
-        emplaceToken(tokens, scanIntegerSuffix(scanState), scanState);
+        // If there's a U suffix, emplace
+        if (std::toupper(scanState.peek()) == 'U') {
+            emplaceToken(tokens, Token::Type::UINT32_NUMBER, scanState);
+            scanState.advance();
+        }
+        else {
+            emplaceToken(tokens, Token::Type::INT32_NUMBER, scanState);
+        }
+
     }
     // Otherwise, if this is an octal integer
     else if(c == '0' && isodigit(scanState.peek())){
@@ -260,7 +252,14 @@ void scanNumber(char c, ScanState &scanState, std::vector<Token> &tokens)
         }
         // Otherwise, emplace integer token
         else {
-            emplaceToken(tokens, scanIntegerSuffix(scanState), scanState);
+            // If there's a U suffix, emplace
+            if (std::toupper(scanState.peek()) == 'U') {
+                emplaceToken(tokens, Token::Type::UINT32_NUMBER, scanState);
+                scanState.advance();
+            }
+            else {
+                emplaceToken(tokens, Token::Type::INT32_NUMBER, scanState);
+            }
         }
     }
 }
diff --git a/tests/unit/scanner.cc b/tests/unit/scanner.cc
index 43e5ca7157..731651175c 100644
--- a/tests/unit/scanner.cc
+++ b/tests/unit/scanner.cc
@@ -90,7 +90,7 @@ TEST(Scanner, HexInt)
     ASSERT_EQ(tokens[6].type, Token::Type::END_OF_FILE);

     ASSERT_EQ(tokens[0].lexeme, "0x1234");
-    ASSERT_EQ(tokens[1].lexeme, "0xFFFFFFFFU");
+    ASSERT_EQ(tokens[1].lexeme, "0xFFFFFFFF");
     ASSERT_EQ(tokens[3].lexeme, "0x1234");
     ASSERT_EQ(tokens[5].lexeme, "0x7FFFFFFF");
 }
@@ -114,10 +114,10 @@ TEST(Scanner, DecimalFloat)

     ASSERT_EQ(tokens[0].lexeme, "1.0");
     ASSERT_EQ(tokens[1].lexeme, "0.2");
-    ASSERT_EQ(tokens[2].lexeme, "100.0f");
-    ASSERT_EQ(tokens[3].lexeme, "0.2f");
-    ASSERT_EQ(tokens[5].lexeme, "12.0d");
-    ASSERT_EQ(tokens[7].lexeme, "0.0004f");
+    ASSERT_EQ(tokens[2].lexeme, "100.0");
+    ASSERT_EQ(tokens[3].lexeme, "0.2");
+    ASSERT_EQ(tokens[5].lexeme, "12.0");
+    ASSERT_EQ(tokens[7].lexeme, "0.0004");
 }
 //--------------------------------------------------------------------------
 TEST(Scanner, String)

From 2a0bcd6c772e567b345bff55c8aff9ba74f06548 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Fri, 26 May 2023 09:49:33 +0100
Subject: [PATCH 171/725] missed one unit test

---
 tests/unit/scanner.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/scanner.cc b/tests/unit/scanner.cc
index 731651175c..412415259a 100644
--- a/tests/unit/scanner.cc
+++ 
b/tests/unit/scanner.cc @@ -69,7 +69,7 @@ TEST(Scanner, DecimalInt) ASSERT_EQ(tokens[6].type, Token::Type::END_OF_FILE); ASSERT_EQ(tokens[0].lexeme, "1234"); - ASSERT_EQ(tokens[1].lexeme, "4294967295U"); + ASSERT_EQ(tokens[1].lexeme, "4294967295"); ASSERT_EQ(tokens[3].lexeme, "2345"); ASSERT_EQ(tokens[5].lexeme, "2147483647"); } From 0350dc0042b62202a45b2b99570a367f29b1e3e2 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 26 May 2023 13:42:42 +0100 Subject: [PATCH 172/725] Type::ResolvedType::addQualifier should actually ADD qualifiers rather than replacing them --- include/genn/genn/type.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 29a64712ae..9c0f584840 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -220,7 +220,7 @@ struct ResolvedType const Function &getFunction() const{ return std::get(detail); } const Numeric &getNumeric() const{ return *getValue().numeric; } - const ResolvedType addQualifier(Qualifier qualifier) const{ return ResolvedType(*this, qualifier); } + const ResolvedType addQualifier(Qualifier qualifier) const{ return ResolvedType(*this, qualifiers | qualifier); } bool hasQualifier(Qualifier qualifier) const{ return (qualifiers & qualifier); } std::string getName() const; From 6c404871769fb1e1e34f54b755fe95d6f9a02888 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 26 May 2023 13:43:27 +0100 Subject: [PATCH 173/725] use standard syntax for building types for tests rather than having bespoke methods --- tests/unit/typeChecker.cc | 85 ++++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 46 deletions(-) diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index eb4da856d8..45d7111762 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -58,21 +58,13 @@ class TestEnvironment : public TypeChecker::EnvironmentBase //--------------------------------------------------------------------------- // Public API //--------------------------------------------------------------------------- - void define(const Type::ResolvedType &type, const std::string &name, Type::Qualifier qualifiers = Type::Qualifier{0}) + void define(const Type::ResolvedType &type, const std::string &name) { - if(!m_Types.try_emplace(name, type.addQualifier(qualifiers)).second) { + if(!m_Types.try_emplace(name, type).second) { throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); } } - - void definePointer(const Type::ResolvedType &type, const std::string &name, Type::Qualifier valueQualifiers = Type::Qualifier{0}, - Type::Qualifier pointerQualifiers = Type::Qualifier{0}) - { - define(type.addQualifier(valueQualifiers).createPointer(pointerQualifiers), name); - } - - //--------------------------------------------------------------------------- // EnvironmentBase virtuals //--------------------------------------------------------------------------- @@ -143,7 +135,7 @@ TEST(TypeChecker, ArraySubscript) // Integer array indexing { TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); const auto type = typeCheckExpression("intArray[4]", typeEnvironment); EXPECT_EQ(type, Type::Int32); EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); @@ -154,15 +146,15 @@ TEST(TypeChecker, ArraySubscript) // Float array indexing EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); 
+ typeEnvironment.define(Type::Int32.createPointer(), "intArray"); typeCheckExpression("intArray[4.0f]", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer indexing EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); - typeEnvironment.definePointer(Type::Int32, "indexArray"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); + typeEnvironment.define(Type::Int32.createPointer(), "indexArray"); typeCheckExpression("intArray[indexArray]", typeEnvironment);}, TypeChecker::TypeCheckError); } @@ -174,7 +166,7 @@ TEST(TypeChecker, Assignment) TestEnvironment typeEnvironment; typeEnvironment.define(Type::Int32, "intVal"); typeEnvironment.define(Type::Float, "floatVal"); - typeEnvironment.define(Type::Int32, "intValConst", Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT), "intValConst"); typeCheckStatements( "int w = intVal;\n" "float x = floatVal;\n" @@ -189,8 +181,8 @@ TEST(TypeChecker, Assignment) // Pointer assignement { TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); - typeEnvironment.definePointer(Type::Int32, "intArrayConst", Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); + typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT).createPointer(), "intArrayConst"); typeCheckStatements( "int *x = intArray;\n" "const int *y = intArray;\n" @@ -201,21 +193,21 @@ TEST(TypeChecker, Assignment) // Pointer assignement, attempt to remove const EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT).createPointer(), "intArray"); typeCheckStatements("int *x = intArray;", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer assignement without explicit cast EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); typeCheckStatements("float *x = intArray;", typeEnvironment);}, TypeChecker::TypeCheckError); // Dereference assignment { TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); typeCheckStatements( "*intArray = 7;\n", typeEnvironment); @@ -228,8 +220,8 @@ TEST(TypeChecker, Binary) // Pointer difference { TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray1"); - typeEnvironment.definePointer(Type::Int32, "intArray2"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray1"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray2"); const auto type = typeCheckExpression("intArray1 - intArray2", typeEnvironment); EXPECT_EQ(type, Type::Int32); } @@ -240,7 +232,7 @@ TEST(TypeChecker, Binary) // Pointer + integer { TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); typeEnvironment.define(Type::Int32, "offset"); const auto type = typeCheckExpression("intArray + offset", typeEnvironment); EXPECT_EQ(*type.getPointer().valueType, Type::Int32); @@ -251,7 +243,7 @@ TEST(TypeChecker, Binary) // Pointer + non-integer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); + 
typeEnvironment.define(Type::Int32.createPointer(), "intArray"); typeEnvironment.define(Type::Float, "offset"); typeCheckExpression("intArray + offset", typeEnvironment);}, TypeChecker::TypeCheckError); @@ -259,8 +251,8 @@ TEST(TypeChecker, Binary) // Pointer + pointer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray1"); - typeEnvironment.definePointer(Type::Int32, "intArray2"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray1"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray2"); typeCheckExpression("intArray1 + intArray2", typeEnvironment);}, TypeChecker::TypeCheckError); @@ -268,7 +260,7 @@ TEST(TypeChecker, Binary) // Pointer - integer { TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); typeEnvironment.define(Type::Int32, "offset"); const auto type = typeCheckExpression("intArray - offset", typeEnvironment); EXPECT_EQ(*type.getPointer().valueType, Type::Int32); @@ -277,7 +269,7 @@ TEST(TypeChecker, Binary) // Integer + pointer { TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); typeEnvironment.define(Type::Int32, "offset"); const auto type = typeCheckExpression("offset + intArray", typeEnvironment); EXPECT_EQ(*type.getPointer().valueType, Type::Int32); @@ -375,7 +367,7 @@ TEST(TypeChecker, Cast) // Pointer cast to value const { TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); const auto type = typeCheckExpression("(const int*)intArray", typeEnvironment); EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); @@ -385,44 +377,45 @@ TEST(TypeChecker, Cast) // Pointer cast to pointer const { TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); const auto type = typeCheckExpression("(int * const)intArray", typeEnvironment); EXPECT_TRUE(type.hasQualifier(Type::Qualifier::CONSTANT)); EXPECT_EQ(*type.getPointer().valueType, Type::Int32); } // Can't remove value const from numeric + // **THINK** why not? 
it's a copy EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.define(Type::Int32, "intVal", Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT), "intVal"); typeCheckExpression("(int)intVal", typeEnvironment);}, TypeChecker::TypeCheckError); // Can't remove value const from pointer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT).createPointer(), "intArray"); typeCheckExpression("(int*)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Can't remove pointer const from pointer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier{0}, Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32.createPointer(Type::Qualifier::CONSTANT), "intArray"); typeCheckExpression("(int*)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer cast can't reinterpret EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); typeCheckExpression("(float*)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); // Pointer can't be cast to numeric EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); typeCheckExpression("(int)intArray", typeEnvironment);}, TypeChecker::TypeCheckError); @@ -449,10 +442,10 @@ TEST(TypeChecker, IncDec) EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); } - // Can increment pointer + // Can increment const int* pointer { TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); const auto type = typeCheckExpression("intArray++", typeEnvironment); EXPECT_EQ(type, Type::Int32.createPointer()); EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); @@ -461,7 +454,7 @@ TEST(TypeChecker, IncDec) // Can increment pointer to const { TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT).createPointer(), "intArray"); const auto type = typeCheckExpression("intArray++", typeEnvironment); EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); EXPECT_EQ(*type.getPointer().valueType, Type::Int32.addQualifier(Type::Qualifier::CONSTANT)); @@ -470,14 +463,14 @@ TEST(TypeChecker, IncDec) // Can't increment const number EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.define(Type::Int32, "intVal", Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT), "intVal"); typeCheckExpression("intVal++", typeEnvironment);}, TypeChecker::TypeCheckError); - // Can't increment const pointer + // Can't increment int * const pointer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier{0}, Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32.createPointer(Type::Qualifier::CONSTANT), "intArray"); typeCheckExpression("intArray++", typeEnvironment);}, TypeChecker::TypeCheckError); } @@ -541,7 +534,7 @@ TEST(TypeChecker, Unary) // Dereference pointer { 
TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); const auto type = typeCheckExpression("*intArray", typeEnvironment); EXPECT_EQ(type, Type::Int32); } @@ -549,7 +542,7 @@ TEST(TypeChecker, Unary) // Dereference pointer to const { TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT).createPointer(), "intArray"); const auto type = typeCheckExpression("*intArray", typeEnvironment); EXPECT_EQ(type, Type::Int32.addQualifier(Type::Qualifier::CONSTANT)); } @@ -557,7 +550,7 @@ TEST(TypeChecker, Unary) // Dereference const pointer { TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier{0}, Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32.createPointer(Type::Qualifier::CONSTANT), "intArray"); const auto type = typeCheckExpression("*intArray", typeEnvironment); EXPECT_EQ(type, Type::Int32); } @@ -565,7 +558,7 @@ TEST(TypeChecker, Unary) // Dereference const pointer to const { TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray", Type::Qualifier::CONSTANT, Type::Qualifier::CONSTANT); + typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT).createPointer(Type::Qualifier::CONSTANT), "intArray"); const auto type = typeCheckExpression("*intArray", typeEnvironment); EXPECT_EQ(type, Type::Int32.addQualifier(Type::Qualifier::CONSTANT)); } @@ -590,7 +583,7 @@ TEST(TypeChecker, Unary) // Address of pointer EXPECT_THROW({ TestEnvironment typeEnvironment; - typeEnvironment.definePointer(Type::Int32, "intArray"); + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); typeCheckExpression("&intArray", typeEnvironment);}, TypeChecker::TypeCheckError); } From c9ce1755cfeac359c04ed9d337a4d36822201ba3 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 26 May 2023 14:20:53 +0100 Subject: [PATCH 174/725] small fixes * only compare unqualified types in checkPointerTypeAssignement * throw type errors whenever one occurs --- src/genn/genn/transpiler/typeChecker.cc | 11 +++++++---- tests/unit/typeChecker.cc | 1 - 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index e7dd31d134..299465cbbf 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -32,10 +32,9 @@ bool checkPointerTypeAssignement(const Type::ResolvedType &rightType, const Type { return std::visit( Utils::Overload{ - [&rightType, &leftType](const Type::ResolvedType::Value &leftValue, const Type::ResolvedType::Value &rightValue) + [](const Type::ResolvedType::Value &leftValue, const Type::ResolvedType::Value &rightValue) { - assert(leftValue.numeric && rightValue.numeric); - return (rightType == leftType); + return (rightValue == leftValue); }, [](const Type::ResolvedType::Pointer &rightPointer, const Type::ResolvedType::Pointer &leftPointer) { @@ -592,7 +591,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // If there are no viable candidates, give error if(viableFunctions.empty()) { m_ErrorHandler.error(variable.getName(), - "No viable function candidates for '" + variable.getName().lexeme + "'"); + "No viable function candidates for '" + variable.getName().lexeme + "'"); throw TypeCheckError(); } // Otherwise, 
sort lexicographically by conversion rank and return the type of the lowest
@@ -663,6 +662,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     {
         if (!m_InLoop && !m_InSwitch) {
             m_ErrorHandler.error(breakStatement.getToken(), "Statement not within loop");
+            throw TypeCheckError();
         }
     }

@@ -687,6 +687,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     {
         if (!m_InLoop) {
             m_ErrorHandler.error(continueStatement.getToken(), "Statement not within loop");
+            throw TypeCheckError();
         }
     }

@@ -746,6 +747,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
     {
         if (!m_InSwitch) {
             m_ErrorHandler.error(labelled.getKeyword(), "Statement not within switch statement");
+            throw TypeCheckError();
         }

         if (labelled.getValue()) {
@@ -786,6 +788,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
             const auto initialiserType = evaluateType(std::get<1>(var).get());
             if (!checkImplicitConversion(initialiserType, decType)) {
                 m_ErrorHandler.error(std::get<0>(var), "Invalid operand types '" + decType.getName() + "' and '" + initialiserType.getName());
+                throw TypeCheckError();
             }
         }
     }
diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index 45d7111762..33cf77b896 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -370,7 +370,6 @@ TEST(TypeChecker, Cast)
         typeEnvironment.define(Type::Int32.createPointer(), "intArray");
         const auto type = typeCheckExpression("(const int*)intArray", typeEnvironment);
         EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT));
-        EXPECT_EQ(*type.getPointer().valueType, Type::Int32.addQualifier(Type::Qualifier::CONSTANT));
     }

From b7738453590269d9741b43e9acf48a2636b35a9f Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Fri, 26 May 2023 14:26:04 +0100
Subject: [PATCH 175/725] missed that unary * operator expressions are l-values

---
 include/genn/genn/transpiler/expression.h | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h
index 7ec92e94c5..d6c1e87769 100644
--- a/include/genn/genn/transpiler/expression.h
+++ b/include/genn/genn/transpiler/expression.h
@@ -333,6 +333,11 @@ class Unary : public Acceptable<Unary>
     :  m_Operator(op), m_Right(std::move(right))
     {}

+    //------------------------------------------------------------------------
+    // Expression::Base virtuals
+    //------------------------------------------------------------------------
+    virtual bool isLValue() const{ return (m_Operator.type == Token::Type::STAR); }
+
     const Token &getOperator() const { return m_Operator; }
     const Base *getRight() const { return m_Right.get(); }
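To make the effect of this one-liner concrete, here is a sketch using the test helpers from the unit tests above (the exact assertions are illustrative, not part of the patch): with unary '*' reporting itself as an l-value, assignment through a pointer now type-checks, while assignment to computed values still fails.

    TestEnvironment typeEnvironment;
    typeEnvironment.define(Type::Int32.createPointer(), "intArray");
    typeCheckStatements("*intArray = 7;\n", typeEnvironment);   // accepted: '*intArray' is an l-value
    // "(intArray + 1) = 7;" should still be rejected - binary results remain r-values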
From bb0ffc866d5a9d751089cf7118eb20585b36e46d Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Fri, 26 May 2023 14:37:02 +0100
Subject: [PATCH 176/725] renamed ``Variable`` expression to ``Identifier`` to
 better match C grammar

---
 include/genn/genn/transpiler/expression.h | 10 +++++-----
 src/genn/genn/transpiler/parser.cc | 2 +-
 src/genn/genn/transpiler/prettyPrinter.cc | 2 +-
 src/genn/genn/transpiler/typeChecker.cc | 2 +-
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/include/genn/genn/transpiler/expression.h b/include/genn/genn/transpiler/expression.h
index d6c1e87769..e888bf16e0 100644
--- a/include/genn/genn/transpiler/expression.h
+++ b/include/genn/genn/transpiler/expression.h
@@ -24,7 +24,7 @@ class Literal;
 class Logical;
 class PostfixIncDec;
 class PrefixIncDec;
-class Variable;
+class Identifier;
 class Unary;
 }

@@ -47,7 +47,7 @@ class Visitor
     virtual void visit(const Logical &logical) = 0;
     virtual void visit(const PostfixIncDec &postfixIncDec) = 0;
     virtual void visit(const PrefixIncDec &postfixIncDec) = 0;
-    virtual void visit(const Variable &variable) = 0;
+    virtual void visit(const Identifier &variable) = 0;
     virtual void visit(const Unary &unary) = 0;
 };

@@ -303,12 +303,12 @@ class PrefixIncDec : public Acceptable<PrefixIncDec>
 };

 //---------------------------------------------------------------------------
-// GeNN::Transpiler::Expression::Variable
+// GeNN::Transpiler::Expression::Identifier
 //---------------------------------------------------------------------------
-class Variable : public Acceptable<Variable>
+class Identifier : public Acceptable<Identifier>
 {
 public:
-    Variable(Token name)
+    Identifier(Token name)
     :  m_Name(name)
     {}

diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc
index 8cdae68804..b4ba31274f 100644
--- a/src/genn/genn/transpiler/parser.cc
+++ b/src/genn/genn/transpiler/parser.cc
@@ -246,7 +246,7 @@ Expression::ExpressionPtr parsePrimary(ParserState &parserState)
         return std::make_unique<Expression::Literal>(parserState.previous());
     }
     else if(parserState.match(Token::Type::IDENTIFIER)) {
-        return std::make_unique<Expression::Variable>(parserState.previous());
+        return std::make_unique<Expression::Identifier>(parserState.previous());
     }
     else if(parserState.match(Token::Type::LEFT_PAREN)) {
         auto expression = parseExpression(parserState);
diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc
index 2b69a2dc9a..53bf48af7b 100644
--- a/src/genn/genn/transpiler/prettyPrinter.cc
+++ b/src/genn/genn/transpiler/prettyPrinter.cc
@@ -177,7 +177,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         prefixIncDec.getTarget()->accept(*this);
     }

-    virtual void visit(const Expression::Variable &variable) final
+    virtual void visit(const Expression::Identifier &variable) final
     {
         const auto &type = m_ResolvedTypes.at(&variable);
         m_Environment.get().getStream() << m_Environment.get().getName(variable.getName().lexeme, type);
diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index 299465cbbf..fd28d92bb7 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -500,7 +500,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         }
     }

-    virtual void visit(const Expression::Variable &variable)
+    virtual void visit(const Expression::Identifier &variable)
     {
         // If type is unambiguous and not a function
         const auto varTypes = m_Environment.get().getTypes(variable.getName(), m_ErrorHandler);

From 9b38292a1ea4a2c05bb6a851a71ba6e68a78c58f Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Fri, 26 May 2023 14:37:43 +0100
Subject: [PATCH 177/725] Most of the Assignment unit tests actually test the
 VarDeclaration statement so split them out

---
 tests/unit/typeChecker.cc | 88 ++++++++++++++++++++++-----------------
 1 file changed, 50 insertions(+), 38 deletions(-)

diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc
index 33cf77b896..56a88b465d 100644
--- a/tests/unit/typeChecker.cc
+++ b/tests/unit/typeChecker.cc
@@ -161,55 +161,21 @@ TEST(TypeChecker, ArraySubscript)
 //--------------------------------------------------------------------------
 TEST(TypeChecker, Assignment)
 {
-    // Numeric assignment
-    {
-        TestEnvironment typeEnvironment;
-        typeEnvironment.define(Type::Int32, "intVal");
-        typeEnvironment.define(Type::Float, "floatVal");
-        typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT), "intValConst");
-
typeCheckStatements( - "int w = intVal;\n" - "float x = floatVal;\n" - "int y = floatVal;\n" - "float z = intVal;\n" - "int wc = intValConst;\n" - "const int cw = intVal;\n" - "const int cwc = intValConst;\n", - typeEnvironment); - } - - // Pointer assignement + // Dereference assignment { TestEnvironment typeEnvironment; typeEnvironment.define(Type::Int32.createPointer(), "intArray"); - typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT).createPointer(), "intArrayConst"); typeCheckStatements( - "int *x = intArray;\n" - "const int *y = intArray;\n" - "const int *z = intArrayConst;\n", + "*intArray = 7;\n", typeEnvironment); } - // Pointer assignement, attempt to remove const - EXPECT_THROW({ - TestEnvironment typeEnvironment; - typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT).createPointer(), "intArray"); - typeCheckStatements("int *x = intArray;", typeEnvironment);}, - TypeChecker::TypeCheckError); - - // Pointer assignement without explicit cast - EXPECT_THROW({ - TestEnvironment typeEnvironment; - typeEnvironment.define(Type::Int32.createPointer(), "intArray"); - typeCheckStatements("float *x = intArray;", typeEnvironment);}, - TypeChecker::TypeCheckError); - - // Dereference assignment + // Array subscript assignment { TestEnvironment typeEnvironment; typeEnvironment.define(Type::Int32.createPointer(), "intArray"); typeCheckStatements( - "*intArray = 7;\n", + "intArray[5] = 7;\n", typeEnvironment); } // **TODO** other assignements i.e. += -= %= @@ -586,3 +552,49 @@ TEST(TypeChecker, Unary) typeCheckExpression("&intArray", typeEnvironment);}, TypeChecker::TypeCheckError); } +//-------------------------------------------------------------------------- +TEST(TypeChecker, VarDeclaration) +{ + // Numeric var declaration + { + TestEnvironment typeEnvironment; + typeEnvironment.define(Type::Int32, "intVal"); + typeEnvironment.define(Type::Float, "floatVal"); + typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT), "intValConst"); + typeCheckStatements( + "int w = intVal;\n" + "float x = floatVal;\n" + "int y = floatVal;\n" + "float z = intVal;\n" + "int wc = intValConst;\n" + "const int cw = intVal;\n" + "const int cwc = intValConst;\n", + typeEnvironment); + } + + // Pointer var declaration + { + TestEnvironment typeEnvironment; + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); + typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT).createPointer(), "intArrayConst"); + typeCheckStatements( + "int *x = intArray;\n" + "const int *y = intArray;\n" + "const int *z = intArrayConst;\n", + typeEnvironment); + } + + // Pointer var declaration, attempt to remove const + EXPECT_THROW({ + TestEnvironment typeEnvironment; + typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT).createPointer(), "intArray"); + typeCheckStatements("int *x = intArray;", typeEnvironment);}, + TypeChecker::TypeCheckError); + + // Pointer var declaration without explicit cast + EXPECT_THROW({ + TestEnvironment typeEnvironment; + typeEnvironment.define(Type::Int32.createPointer(), "intArray"); + typeCheckStatements("float *x = intArray;", typeEnvironment);}, + TypeChecker::TypeCheckError); +} \ No newline at end of file From 2debcee2822f41c692308c001620fd61b115199d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 26 May 2023 15:38:05 +0100 Subject: [PATCH 178/725] totally remove support for & unary expression - solves our "addressable" issues as pointers can only be obtained from other 
pointers --- src/genn/genn/transpiler/parser.cc | 17 ++--------------- src/genn/genn/transpiler/typeChecker.cc | 7 ++++--- tests/unit/typeChecker.cc | 17 ----------------- 3 files changed, 6 insertions(+), 35 deletions(-) diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index b4ba31274f..50caa47481 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -329,7 +329,6 @@ Expression::ExpressionPtr parseUnary(ParserState &parserState) // postfix-expression // "++" unary-expression // "--" unary-expression - // "&" cast-expression // "*" cast-expression // "+" cast-expression // "-" cast-expression @@ -337,20 +336,8 @@ Expression::ExpressionPtr parseUnary(ParserState &parserState) // "!" cast-expression // "sizeof" unary-expression **TODO** // "sizeof" "(" type-name ")" **TODO** - if(parserState.match(Token::Type::AMPERSAND)) { - Token op = parserState.previous(); - auto expression = parseCast(parserState); - - // If expression is a valid l-value, - if (expression->isLValue()) { - return std::make_unique(op, std::move(expression)); - } - else { - parserState.error(op, "Cannot take the address of r-value"); - } - } - else if(parserState.match({Token::Type::STAR, Token::Type::PLUS, Token::Type::MINUS, - Token::Type::TILDA, Token::Type::NOT})) + if(parserState.match({Token::Type::STAR, Token::Type::PLUS, Token::Type::MINUS, + Token::Type::TILDA, Token::Type::NOT})) { Token op = parserState.previous(); return std::make_unique(op, parseCast(parserState)); diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index fd28d92bb7..96b16e97e2 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -643,9 +643,10 @@ class Visitor : public Expression::Visitor, public Statement::Visitor else if (unary.getOperator().type == Token::Type::NOT) { setExpressionType(&unary, Type::Int32); } - // Otherwise, if operator is address of, return pointer type - else if (unary.getOperator().type == Token::Type::AMPERSAND) { - setExpressionType(&unary, rightType.createPointer()); + else { + m_ErrorHandler.error(unary.getOperator(), + "Invalid operand type '" + rightType.getName() + "'"); + throw TypeCheckError(); } } else { diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index 56a88b465d..9464f2a3aa 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -534,23 +534,6 @@ TEST(TypeChecker, Unary) typeEnvironment.define(Type::Int32, "intVal"); typeCheckExpression("*intVal", typeEnvironment); }, TypeChecker::TypeCheckError); - - // Address of numeric - { - TestEnvironment typeEnvironment; - typeEnvironment.define(Type::Int32, "intVal"); - const auto type = typeCheckExpression("&intVal", typeEnvironment); - EXPECT_FALSE(type.hasQualifier(Type::Qualifier::CONSTANT)); - EXPECT_EQ(*type.getPointer().valueType, Type::Int32); - EXPECT_FALSE(type.getPointer().valueType->hasQualifier(Type::Qualifier::CONSTANT)); - } - - // Address of pointer - EXPECT_THROW({ - TestEnvironment typeEnvironment; - typeEnvironment.define(Type::Int32.createPointer(), "intArray"); - typeCheckExpression("&intArray", typeEnvironment);}, - TypeChecker::TypeCheckError); } //-------------------------------------------------------------------------- TEST(TypeChecker, VarDeclaration) From d65df285f364cdbec7733c124a3c7b5fc86f0943 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 26 May 2023 15:52:51 +0100 Subject: [PATCH 179/725] remove leftover print statement 
from lox --- include/genn/genn/transpiler/statement.h | 20 -------------------- include/genn/genn/transpiler/token.h | 2 +- src/genn/genn/transpiler/parser.cc | 20 ++++---------------- src/genn/genn/transpiler/prettyPrinter.cc | 7 ------- src/genn/genn/transpiler/scanner.cc | 1 - src/genn/genn/transpiler/typeChecker.cc | 5 ----- 6 files changed, 5 insertions(+), 50 deletions(-) diff --git a/include/genn/genn/transpiler/statement.h b/include/genn/genn/transpiler/statement.h index 964ade6a9a..97a171b813 100644 --- a/include/genn/genn/transpiler/statement.h +++ b/include/genn/genn/transpiler/statement.h @@ -24,7 +24,6 @@ class Labelled; class Switch; class VarDeclaration; class While; -class Print; } //--------------------------------------------------------------------------- @@ -46,7 +45,6 @@ class Visitor virtual void visit(const Switch &switchStatement) = 0; virtual void visit(const VarDeclaration &varDeclaration) = 0; virtual void visit(const While &whileStatement) = 0; - virtual void visit(const Print &print) = 0; }; //--------------------------------------------------------------------------- @@ -286,22 +284,4 @@ class While : public Acceptable ExpressionPtr m_Condition; StatementPtr m_Body; }; - -//--------------------------------------------------------------------------- -// GeNN::Transpiler::Statement::Print -//--------------------------------------------------------------------------- -// **HACK** temporary until function calling is working -class Print : public Acceptable -{ - using ExpressionPtr = GeNN::Transpiler::Expression::ExpressionPtr; -public: - Print(ExpressionPtr expression) - : m_Expression(std::move(expression)) - {} - - const ExpressionPtr::element_type *getExpression() const { return m_Expression.get(); } - -private: - ExpressionPtr m_Expression; -}; } // namespace GeNN::Transpiler::Statement diff --git a/include/genn/genn/transpiler/token.h b/include/genn/genn/transpiler/token.h index 62212b74ce..66b3cbcf13 100644 --- a/include/genn/genn/transpiler/token.h +++ b/include/genn/genn/transpiler/token.h @@ -44,7 +44,7 @@ struct Token TYPE_QUALIFIER, // Keywords - DO, ELSE, FALSE, FOR, IF, TRUE, WHILE, PRINT, SWITCH, CONTINUE, BREAK, CASE, DEFAULT, + DO, ELSE, FALSE, FOR, IF, TRUE, WHILE, SWITCH, CONTINUE, BREAK, CASE, DEFAULT, END_OF_FILE, }; diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 50caa47481..36fb3f7dd3 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -594,14 +594,6 @@ Statement::StatementPtr parseExpressionStatement(ParserState &parserState) return std::make_unique(std::move(expression)); } -Statement::StatementPtr parsePrintStatement(ParserState &parserState) -{ - auto expression = parseExpression(parserState); - - parserState.consume(Token::Type::SEMICOLON, "Expect ';' after expression"); - return std::make_unique(std::move(expression)); -} - Statement::StatementPtr parseSelectionStatement(ParserState &parserState) { // selection-statement ::= @@ -734,16 +726,15 @@ Statement::StatementPtr parseStatement(ParserState &parserState) // labeled-statement // compound-statement // expression-statement - // print-statement // **TEMP** // selection-statement // iteration-statement // jump-statement - if(parserState.match(Token::Type::PRINT)) { - return parsePrintStatement(parserState); - } - else if(parserState.match({Token::Type::CASE, Token::Type::DEFAULT})) { + if(parserState.match({Token::Type::CASE, Token::Type::DEFAULT})) { return parseLabelledStatement(parserState); } + else 
if(parserState.match(Token::Type::LEFT_BRACE)) { + return parseCompoundStatement(parserState); + } else if(parserState.match({Token::Type::IF, Token::Type::SWITCH})) { return parseSelectionStatement(parserState); } @@ -753,9 +744,6 @@ Statement::StatementPtr parseStatement(ParserState &parserState) else if(parserState.match({Token::Type::CONTINUE, Token::Type::BREAK})) { return parseJumpStatement(parserState); } - else if(parserState.match(Token::Type::LEFT_BRACE)) { - return parseCompoundStatement(parserState); - } else { return parseExpressionStatement(parserState); } diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index 53bf48af7b..c8f73ff3c4 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -322,13 +322,6 @@ class Visitor : public Expression::Visitor, public Statement::Visitor whileStatement.getBody()->accept(*this); } - virtual void visit(const Statement::Print &print) final - { - m_Environment.get().getStream() << "print "; - print.getExpression()->accept(*this); - m_Environment.get().getStream() << ";"; - } - private: //--------------------------------------------------------------------------- // Members diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc index 33c9911f7f..cc36196fa2 100644 --- a/src/genn/genn/transpiler/scanner.cc +++ b/src/genn/genn/transpiler/scanner.cc @@ -36,7 +36,6 @@ const std::unordered_map keywords{ {"continue", Token::Type::CONTINUE}, {"case", Token::Type::CASE}, {"default", Token::Type::DEFAULT}, - {"print", Token::Type::PRINT}, // **HACK** {"char", Token::Type::TYPE_SPECIFIER}, {"short", Token::Type::TYPE_SPECIFIER}, {"int", Token::Type::TYPE_SPECIFIER}, diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 96b16e97e2..c9077b71e6 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -803,11 +803,6 @@ class Visitor : public Expression::Visitor, public Statement::Visitor m_InLoop = false; } - virtual void visit(const Statement::Print &print) final - { - print.getExpression()->accept(*this); - } - private: //--------------------------------------------------------------------------- // Private methods From 009f9ecc5d3f1a482699b19d71175c8babb147c9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 26 May 2023 16:05:45 +0100 Subject: [PATCH 180/725] reinstated handling of variadic functions for printf --- include/genn/genn/type.h | 22 +++++++++++++-------- src/genn/genn/transpiler/standardLibrary.cc | 5 ++++- src/genn/genn/transpiler/typeChecker.cc | 11 ++++++++--- tests/unit/typeChecker.cc | 14 ++++++------- 4 files changed, 32 insertions(+), 20 deletions(-) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 9c0f584840..c256684223 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -155,35 +155,41 @@ struct ResolvedType //------------------------------------------------------------------------ struct Function { - Function(const ResolvedType &returnType, const std::vector &argTypes) - : returnType(std::make_unique(returnType)), argTypes(argTypes) + Function(const ResolvedType &returnType, const std::vector &argTypes, bool variadic=false) + : returnType(std::make_unique(returnType)), argTypes(argTypes), variadic(variadic) {} Function(const Function &other) - : returnType(std::make_unique(*other.returnType)), argTypes(other.argTypes) + : 
returnType(std::make_unique<ResolvedType>(*other.returnType)), argTypes(other.argTypes)
+        :  returnType(std::make_unique<ResolvedType>(*other.returnType)),
+           argTypes(other.argTypes), variadic(other.variadic)
        {}

        std::unique_ptr<ResolvedType> returnType;
        std::vector<ResolvedType> argTypes;
+       bool variadic;

        bool operator == (const Function &other) const
        {
-           return (std::tie(*returnType, argTypes) == std::tie(*other.returnType, other.argTypes));
+           return (std::tie(*returnType, argTypes, variadic)
+                   == std::tie(*other.returnType, other.argTypes, other.variadic));
        }

        bool operator != (const Function &other) const
        {
-           return (std::tie(*returnType, argTypes) != std::tie(*other.returnType, other.argTypes));
+           return (std::tie(*returnType, argTypes, variadic)
+                   != std::tie(*other.returnType, other.argTypes, other.variadic));
        }

        bool operator < (const Function &other) const
        {
-           return (std::tie(*returnType, argTypes) < std::tie(*other.returnType, other.argTypes));
+           return (std::tie(*returnType, argTypes, variadic)
+                   < std::tie(*other.returnType, other.argTypes, other.variadic));
        }

        Function &operator = (const Function &other)
        {
           returnType.reset(new ResolvedType(*other.returnType));
           argTypes = other.argTypes;
+          variadic = other.variadic;
           return *this;
        }
    };
@@ -267,9 +273,9 @@ struct ResolvedType
         return ResolvedType{Value{name, sizeof(T), std::nullopt}, qualifiers};
     }

-    static ResolvedType createFunction(const ResolvedType &returnType, const std::vector<ResolvedType> &argTypes)
+    static ResolvedType createFunction(const ResolvedType &returnType, const std::vector<ResolvedType> &argTypes, bool variadic=false)
    {
-        return ResolvedType{Function{returnType, argTypes}, Qualifier{0}};
+        return ResolvedType{Function{returnType, argTypes, variadic}, Qualifier{0}};
    }
 };

diff --git a/src/genn/genn/transpiler/standardLibrary.cc b/src/genn/genn/transpiler/standardLibrary.cc
index 90d11df388..266634a050 100644
--- a/src/genn/genn/transpiler/standardLibrary.cc
+++ b/src/genn/genn/transpiler/standardLibrary.cc
@@ -112,7 +112,10 @@ const auto libraryTypes = initLibraryTypes(
     ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(tgamma),
     ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(lgamma),
     ADD_TWO_ARG_FLOAT_DOUBLE_FUNC(copysign),
-    ADD_THREE_ARG_FLOAT_DOUBLE_FUNC(fma));
+    ADD_THREE_ARG_FLOAT_DOUBLE_FUNC(fma),
+
+    // Printf
+    std::make_pair("printf", std::make_pair(Type::ResolvedType::createFunction(Type::Int32, {Type::Int8.addQualifier(Type::Qualifier::CONSTANT).createPointer()}, true), "printf($(0), $(@))")));
 }

diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc
index c9077b71e6..b295737654 100644
--- a/src/genn/genn/transpiler/typeChecker.cc
+++ b/src/genn/genn/transpiler/typeChecker.cc
@@ -515,18 +515,23 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         // Loop through variable types
         std::vector<std::pair<Type::ResolvedType, std::vector<int>>> viableFunctions;
         for(const auto &type : varTypes) {
-            // If function is non-variadic and number of arguments match
+            // If function is non-variadic and number of arguments
+            // match or variadic and enough arguments are provided
             const auto &argumentTypes = type.getFunction().argTypes;
             const bool variadic = type.getFunction().variadic;
             if((!variadic && m_CallArguments.top().size() == argumentTypes.size())
                || (variadic && m_CallArguments.top().size() >= argumentTypes.size()))
             {
                 // Create vector to hold argument conversion rank
                 std::vector<int> argumentConversionRank;
                 argumentConversionRank.reserve(m_CallArguments.top().size());

                 // Loop through arguments
                 // **NOTE** we loop through function arguments to deal with variadic
                 bool viable = true;
                 auto c = m_CallArguments.top().cbegin();
                 auto a =
argumentTypes.cbegin(); - for(;c != m_CallArguments.top().cend(); c++, a++) { + for(;a != argumentTypes.cend(); c++, a++) { const auto argConversionRank = std::visit( Utils::Overload{ // If types are numeric, any cast goes diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index 9464f2a3aa..edd9ecdedf 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -292,24 +292,22 @@ TEST(TypeChecker, Call) EXPECT_EQ(type, Type::Float); } - // Variadic with too few arguments - /*EXPECT_THROW({ + EXPECT_THROW({ typeCheckExpression("printf()", stdLibraryEnv);}, TypeChecker::TypeCheckError); // Variadic function with no extra arguments { - const auto *type = typeCheckExpression("printf(\"hello world\")", stdLibraryEnv); - EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); + const auto type = typeCheckExpression("printf(\"hello world\")", stdLibraryEnv); + EXPECT_EQ(type, Type::Int32); } // Variadic function with extra arguments { - const auto *type = typeCheckExpression("printf(\"hello world %d, %f\", 12, cos(5.0f))", stdLibraryEnv); - EXPECT_EQ(type->getName(), Type::Int32::getInstance()->getName()); - }*/ - + const auto type = typeCheckExpression("printf(\"hello world %d, %f\", 12, cos(5.0f))", stdLibraryEnv); + EXPECT_EQ(type, Type::Int32); + } } //-------------------------------------------------------------------------- TEST(TypeChecker, Cast) From af5878fbdfd22c4e8a767168aed4f36b24c61d24 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 26 May 2023 17:19:24 +0100 Subject: [PATCH 181/725] Pretty printer now expands out function templates returned by environment --- src/genn/genn/transpiler/prettyPrinter.cc | 147 ++++++++++++++++++++-- 1 file changed, 137 insertions(+), 10 deletions(-) diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index c8f73ff3c4..bae2653029 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include // GeNN code generator includes @@ -69,6 +70,52 @@ class EnvironmentInternal : public EnvironmentBase std::unordered_set m_LocalVariables; }; +//--------------------------------------------------------------------------- +// EnvironmentCallArgument +//--------------------------------------------------------------------------- +class EnvironmentCallArgument : public EnvironmentBase +{ +public: + EnvironmentCallArgument(EnvironmentBase &enclosing) + : m_Enclosing(enclosing), m_CodeStream(m_Stream) + { + } + + //--------------------------------------------------------------------------- + // EnvironmentBase virtuals + //--------------------------------------------------------------------------- + virtual std::string define(const std::string &name) final + { + throw std::runtime_error("Cannot declare variable in call environment"); + } + + virtual std::string getName(const std::string &name, std::optional type) final + { + return m_Enclosing.getName(name, type); + } + + virtual CodeStream &getStream() + { + return m_CodeStream; + } + + //--------------------------------------------------------------------------- + // Public API + //--------------------------------------------------------------------------- + std::string getString() const + { + return m_Stream.str(); + } + +private: + //--------------------------------------------------------------------------- + // Members + //--------------------------------------------------------------------------- + EnvironmentBase 
&m_Enclosing; + std::ostringstream m_Stream; + CodeStream m_CodeStream; +}; + //--------------------------------------------------------------------------- // Visitor //--------------------------------------------------------------------------- @@ -113,12 +160,36 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Call &call) final { - call.getCallee()->accept(*this); - m_Environment.get().getStream() << "("; - for(const auto &a : call.getArguments()) { + // Cache reference to current reference + std::reference_wrapper oldEnvironment = m_Environment; + + // Push new vector of arguments onto call argument stack and + // reserve memory to hold all arguments + m_CallArguments.emplace(); + m_CallArguments.top().reserve(call.getArguments().size()); + + // Loop through call arguments + for (const auto &a : call.getArguments()) { + // Create new call argument environment and set to current + EnvironmentCallArgument environment(oldEnvironment.get()); + m_Environment = environment; + + // Pretty print argument a->accept(*this); - } - m_Environment.get().getStream() << ")"; + + // Add pretty printed argument to vector on top of stack + m_CallArguments.top().push_back(environment.getString()); + } + + // Restore old environment + m_Environment = oldEnvironment; + + // Pretty print callee + call.getCallee()->accept(*this); + + // Pop stack + m_CallArguments.pop(); + } virtual void visit(const Expression::Cast &cast) final @@ -179,8 +250,63 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Identifier &variable) final { + // Get name of identifier const auto &type = m_ResolvedTypes.at(&variable); - m_Environment.get().getStream() << m_Environment.get().getName(variable.getName().lexeme, type); + std::string name = m_Environment.get().getName(variable.getName().lexeme, type); + + // If identifier is function i.e. 
name is a function template + if (type.isFunction()) { + // Check that there are call arguments on the stack + assert(!m_CallArguments.empty()); + + // Loop through call arguments on top of stack + size_t i = 0; + for (i = 0; i < m_CallArguments.top().size(); i++) { + // If name contains a $(i) placeholder to replace with this argument, replace with pretty-printed argument + const std::string placeholder = "$(" + std::to_string(i) + ")"; + const size_t found = name.find(placeholder); + if (found != std::string::npos) { + name.replace(found, placeholder.length(), m_CallArguments.top().at(i)); + } + // Otherwise, stop searching + else { + break; + } + } + + // If all arguments haven't been substituted + if (i != m_CallArguments.top().size()) { + // If function is variadic + if (type.getFunction().variadic) { + // If variadic placeholder is found + const std::string variadicPlaceholder = "$(@)"; + const size_t found = name.find(variadicPlaceholder); + if (found != std::string::npos) { + // Concatenate together all remaining arguments + std::ostringstream variadicArgumentsStream; + std::copy(m_CallArguments.top().cbegin() + i, m_CallArguments.top().cend(), + std::ostream_iterator(variadicArgumentsStream, ", ")); + + // Replace variadic placeholder with all remaining arguments (after trimming trailing ", ") + std::string variadicArguments = variadicArgumentsStream.str(); + name.replace(found, variadicPlaceholder.length(), + variadicArguments.substr(0, variadicArguments.length() - 2)); + } + else { + throw std::runtime_error("Variadic function template for '" + variable.getName().lexeme + "' (" + name + ") has " + "insufficient placeholders for " + std::to_string(m_CallArguments.top().size()) + " argument call and no variadic placeholder '$(@)'"); + } + } + // Otherwise, give error + else { + throw std::runtime_error("Function template for '" + variable.getName().lexeme + "' (" + name + ") has " + "insufficient placeholders for " + std::to_string(m_CallArguments.top().size()) + " argument call"); + } + } + } + // Print out name + // **NOTE** in case of function this will be full pretty-printed call + m_Environment.get().getStream() << name; } virtual void visit(const Expression::Unary &unary) final @@ -200,7 +326,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::Compound &compound) final { // Cache reference to current reference - std::reference_wrapper oldEnvironment = m_Environment; + std::reference_wrapper oldEnvironment = m_Environment; // Create new environment and set to current EnvironmentInternal environment(m_Environment); @@ -239,7 +365,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::For &forStatement) final { // Cache reference to current reference - std::reference_wrapper oldEnvironment = m_Environment; + std::reference_wrapper oldEnvironment = m_Environment; // Create new environment and set to current EnvironmentInternal environment(m_Environment); @@ -326,9 +452,10 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- // Members //--------------------------------------------------------------------------- - std::reference_wrapper m_Environment; + std::reference_wrapper m_Environment; const Type::TypeContext &m_Context; const TypeChecker::ResolvedTypeMap &m_ResolvedTypes; + std::stack> m_CallArguments; }; } // Anonymous namespace @@ -339,5 +466,5 @@ void 
GeNN::Transpiler::PrettyPrinter::print(const Statement::StatementList &stat const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes) { EnvironmentInternal internalEnvironment(environment); - Visitor(statements, internalEnvironment, context, resolvedTypes); + Visitor visitor(statements, internalEnvironment, context, resolvedTypes); } From 611fe4fe037d1ec820bdc9b520e7d651ef839f8b Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 26 May 2023 17:25:42 +0100 Subject: [PATCH 182/725] use "identifier" rather than "variable" in error messages --- include/genn/genn/code_generator/groupMergedTypeEnvironment.h | 2 +- src/genn/genn/code_generator/environment.cc | 2 +- src/genn/genn/transpiler/standardLibrary.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index 93079accbf..39ca462987 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -46,7 +46,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa return m_Enclosing->getTypes(name, errorHandler); } else { - errorHandler.error(name, "Undefined variable"); + errorHandler.error(name, "Undefined identifier"); throw TypeCheckError(); } } diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index 4fed4dd26d..c1a4fd2b5b 100644 --- a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -33,7 +33,7 @@ std::string EnvironmentExternal::getContextName(const std::string &name, std::op return std::visit( Utils::Overload{ [&name, type](std::reference_wrapper enclosing)->std::string { return enclosing.get().getName(name, type); }, - [&name](std::reference_wrapper)->std::string { throw std::runtime_error("Variable '" + name + "' undefined"); }}, + [&name](std::reference_wrapper)->std::string { throw std::runtime_error("Identifier '" + name + "' undefined"); }}, getContext()); } diff --git a/src/genn/genn/transpiler/standardLibrary.cc b/src/genn/genn/transpiler/standardLibrary.cc index 266634a050..6818c985b9 100644 --- a/src/genn/genn/transpiler/standardLibrary.cc +++ b/src/genn/genn/transpiler/standardLibrary.cc @@ -146,7 +146,7 @@ std::vector FunctionTypes::getTypes(const Token &name, Error { const auto [typeBegin, typeEnd] = libraryTypes.equal_range(name.lexeme); if (typeBegin == typeEnd) { - errorHandler.error(name, "Undefined variable"); + errorHandler.error(name, "Undefined identifier"); throw TypeCheckError(); } else { From ef89c7a2655d3ceff3c7abae876b1fee4dacaf05 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 30 May 2023 18:48:35 +0100 Subject: [PATCH 183/725] WIP child merged group --- .../code_generator/neuronUpdateGroupMerged.h | 21 +++++++++ .../code_generator/neuronUpdateGroupMerged.cc | 43 +++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index 85c398e57b..6850ba39c4 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -11,6 +11,27 @@ namespace GeNN::CodeGenerator class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase { public: + //---------------------------------------------------------------------------- + 
// GeNN::CodeGenerator::NeuronUpdateGroupMerged::CurrentSource
+    //----------------------------------------------------------------------------
+    class CurrentSource : public GroupMerged<CurrentSourceInternal>
+    {
+    public:
+        CurrentSource(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
+                      const std::vector<std::reference_wrapper<const CurrentSourceInternal>> &groups);
+
+        // Public API
+        //! Should the current source parameter be implemented heterogeneously?
+        bool isParamHeterogeneous(const std::string &paramName) const;
+
+        //! Should the current source derived parameter be implemented heterogeneously?
+        bool isDerivedParamHeterogeneous(const std::string &paramName) const;
+
+    private:
+        //! Is the current source parameter referenced?
+        bool isParamReferenced(const std::string &paramName) const;
+    };
+
     NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
                             const std::vector<std::reference_wrapper<const NeuronGroupInternal>> &groups);

diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
index a75dd7c185..e5c3235d07 100644
--- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
@@ -6,6 +6,49 @@
 using namespace GeNN;
 using namespace GeNN::CodeGenerator;

+//----------------------------------------------------------------------------
+// GeNN::CodeGenerator::NeuronUpdateGroupMerged::CurrentSource
+//----------------------------------------------------------------------------
+NeuronUpdateGroupMerged::CurrentSource::CurrentSource(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
+                                                      const std::vector<std::reference_wrapper<const CurrentSourceInternal>> &groups)
+:   GroupMerged<CurrentSourceInternal>(index, typeContext, groups)
+{
+    // **TODO** correct "CS" + index field names + source name (for current source just its name but for PSM etc fused name)
+    addVars(getArchetype().getCurrentSourceModel()->getVars(), backend.getDeviceVarPrefix());
+
+    // **TODO** correct "CS" + index field names + source name (for current source just its name but for PSM etc fused name)
+    addHeterogeneousParams(getArchetype().getCurrentSourceModel()->getParamNames(), "CS",
+                           &NeuronUpdateGroupMerged::CurrentSource::isParamHeterogeneous,
+                           &CurrentSourceInternal::getParams);
+
+    // **TODO** correct "CS" + index field names + source name (for current source just its name but for PSM etc fused name)
+    addHeterogeneousDerivedParams(getArchetype().getCurrentSourceModel()->getDerivedParams(), "CS",
+                                  &NeuronUpdateGroupMerged::CurrentSource::isDerivedParamHeterogeneous,
+                                  &CurrentSourceInternal::getDerivedParams);
+
+    // Add EGPs
+    // **TODO** correct field names + source name (for current source just its name but for PSM etc fused name)
+    addEGPs(getArchetype().getCurrentSourceModel()->getExtraGlobalParams(), backend.getDeviceVarPrefix());
+}
+//----------------------------------------------------------------------------
+bool NeuronUpdateGroupMerged::CurrentSource::isParamHeterogeneous(const std::string &paramName) const
+{
+    return (isParamReferenced(paramName) &&
+            isParamValueHeterogeneous(paramName, [](const CurrentSourceInternal &cs) { return cs.getParams(); }));
+}
+//----------------------------------------------------------------------------
+bool NeuronUpdateGroupMerged::CurrentSource::isDerivedParamHeterogeneous(const std::string &paramName) const
+{
+    return (isParamReferenced(paramName) &&
+            isParamValueHeterogeneous(paramName, [](const CurrentSourceInternal &cs) { return cs.getDerivedParams(); }));

+}
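+// A sketch of how these queries are used downstream (hypothetical parameter
+// name and values, not code from this commit): if every current source merged
+// into the group shares the same value for a parameter such as "tau", the code
+// generator can bake the literal into the kernel; only when values differ does
+// the parameter need to become a per-group struct field, e.g.
+//   os << (isParamHeterogeneous("tau") ? "group->tauCS0" : "20.0");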
+//----------------------------------------------------------------------------
+bool NeuronUpdateGroupMerged::CurrentSource::isParamReferenced(const std::string &paramName) const
+{
+    return isParamReferenced({getArchetype().getCurrentSourceModel()->getInjectionCode()}, paramName);
+}

 //----------------------------------------------------------------------------
 // GeNN::CodeGenerator::NeuronUpdateGroupMerged
 //----------------------------------------------------------------------------

From 7697868bfc4cf9c687da4bed5a6931f27ac80f82 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 1 Jun 2023 11:25:27 +0100
Subject: [PATCH 184/725] WIP implemented child groups for all bits of
 ``NeuronUpdateGroupMerged``

---
 .../genn/genn/code_generator/groupMerged.h | 11 +-
 .../code_generator/neuronUpdateGroupMerged.h | 97 +++++
 .../code_generator/neuronUpdateGroupMerged.cc | 390 +++++++++++++++++-
 3 files changed, 478 insertions(+), 20 deletions(-)

diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h
index 897459f747..e672238217 100644
--- a/include/genn/genn/code_generator/groupMerged.h
+++ b/include/genn/genn/code_generator/groupMerged.h
@@ -265,16 +265,19 @@ class GroupMerged
                  fieldType);
     }

-    void addPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix)
+    void addPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix,
+                         GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD)
     {
         assert(type.isValue());
         addField(type.createPointer(), name,
-                 [prefix](const G &g, size_t) { return prefix + g.getName(); });
+                 [prefix](const G &g, size_t) { return prefix + g.getName(); },
+                 fieldType);
     }

-    void addPointerField(const Type::UnresolvedType &type, const std::string &name, const std::string &prefix)
+    void addPointerField(const Type::UnresolvedType &type, const std::string &name, const std::string &prefix,
+                         GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD)
     {
-        addPointerField(type.resolve(getTypeContext()), name, prefix);
+        addPointerField(type.resolve(getTypeContext()), name, prefix, fieldType);
    }


diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h
index 6850ba39c4..b89f0b63f0 100644
--- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h
+++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h
@@ -14,13 +14,19 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase
     //----------------------------------------------------------------------------
     // GeNN::CodeGenerator::NeuronUpdateGroupMerged::CurrentSource
     //----------------------------------------------------------------------------
+    //! Child group merged for current sources attached to this neuron update group
     class CurrentSource : public GroupMerged<CurrentSourceInternal>
     {
     public:
         CurrentSource(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
                       const std::vector<std::reference_wrapper<const CurrentSourceInternal>> &groups);

+        //----------------------------------------------------------------------------
         // Public API
+        //----------------------------------------------------------------------------
+        void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng,
+                      const ModelSpecMerged &modelMerged, Substitutions &popSubs);
+
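+        // Each child merged group generates its own code but receives the parent
+        // NeuronUpdateGroupMerged, so indexing helpers such as ng.getVarIndex()
+        // can still be driven by the parent's batching and delay layout.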
+        //! Should the current source parameter be implemented heterogeneously?
+        bool isParamHeterogeneous(const std::string &paramName) const;
+
+        //! Should the current source derived parameter be implemented heterogeneously?
+        bool isDerivedParamHeterogeneous(const std::string &paramName) const;
+
+    private:
+        //----------------------------------------------------------------------------
+        // Private API
+        //----------------------------------------------------------------------------
+        //! Is the current source parameter referenced?
+        bool isParamReferenced(const std::string &paramName) const;
+    };
+
    //----------------------------------------------------------------------------
    // GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynPSM
    //----------------------------------------------------------------------------
+    //! Child group merged for incoming synapse groups
+    class InSynPSM : public GroupMerged<SynapseGroupInternal>
+    {
+    public:
+        InSynPSM(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
+                 const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups);
+
+        //----------------------------------------------------------------------------
+        // Public API
+        //----------------------------------------------------------------------------
+        void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng,
+                      const ModelSpecMerged &modelMerged, Substitutions &popSubs);
+
+        //! Should the current source parameter be implemented heterogeneously?
+        bool isParamHeterogeneous(const std::string &paramName) const;
+
+        //! Should the current source derived parameter be implemented heterogeneously?
+        bool isDerivedParamHeterogeneous(const std::string &paramName) const;
+
+    private:
+        //----------------------------------------------------------------------------
+        // Private API
+        //----------------------------------------------------------------------------
+        //! Is the current source parameter referenced?
+        bool isParamReferenced(const std::string &paramName) const;
+    };
+
+    //----------------------------------------------------------------------------
+    // GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynPreOutput
+    //----------------------------------------------------------------------------
+    //! Child group merged for outgoing synapse groups with $(addToPre) logic
+    class OutSynPreOutput : public GroupMerged<SynapseGroupInternal>
+    {
+    public:
+        OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
+                        const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups);
+
+        //----------------------------------------------------------------------------
+        // Public API
+        //----------------------------------------------------------------------------
+        void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng,
+                      const ModelSpecMerged &modelMerged, Substitutions &popSubs);
+    };
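+    // OutSynPreOutput handles $(addToPre) output flowing backwards from synapses:
+    // as its generate() implementation further down shows, each presynaptic neuron
+    // reads the accumulated revInSyn value into its target variable and then zeroes
+    // the buffer ready for the next timestep.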
+    //----------------------------------------------------------------------------
+    // GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynWUMPostCode
+    //----------------------------------------------------------------------------
+    //! Child group merged for incoming synapse groups with postsynaptic update/spike code
+    class InSynWUMPostCode : public GroupMerged<SynapseGroupInternal>
+    {
+    public:
+        InSynWUMPostCode(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
+                         const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups);
+
+        //----------------------------------------------------------------------------
+        // Public API
+        //----------------------------------------------------------------------------
+        void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng,
+                      const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike);
+
+        //! Should the current source parameter be implemented heterogeneously?
+        bool isParamHeterogeneous(const std::string &paramName) const;
+
+        //! Should the current source derived parameter be implemented heterogeneously?
+        bool isDerivedParamHeterogeneous(const std::string &paramName) const;
+
+    private:
+        //----------------------------------------------------------------------------
+        // Private API
+        //----------------------------------------------------------------------------
+        //! Is the current source parameter referenced?
+        bool isParamReferenced(const std::string &paramName) const;
+    };
+
+    //----------------------------------------------------------------------------
+    // GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynPreCode
+    //----------------------------------------------------------------------------
+    //! Child group merged for outgoing synapse groups with presynaptic update/spike code
+    class OutSynPreCode : public GroupMerged<SynapseGroupInternal>
+    {
+
+    };
+
+
     NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
                             const std::vector<std::reference_wrapper<const NeuronGroupInternal>> &groups);

diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
index e5c3235d07..a0d2f0f0c2 100644
--- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
@@ -9,26 +9,79 @@
 using namespace GeNN::CodeGenerator;

 //----------------------------------------------------------------------------
 // GeNN::CodeGenerator::NeuronUpdateGroupMerged::CurrentSource
 //----------------------------------------------------------------------------
+// **TODO**
+// * field suffix (string) and value suffix (function to get suffix from group) common to everything in group - GroupMerged fields?
+// * without nasty combined groups, getParams and getDerivedParams functions can use pointers to members +// * pre and post neuron stuff in synapse update group merged can also be child classes NeuronUpdateGroupMerged::CurrentSource::CurrentSource(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, typeContext, groups) +: GroupMerged(index, typeContext, groups) { - // **TODO** correct "CS" + index field names + source name (for current source just its name but for PSM etc fused name) - addVars(getArchetype().getCurrentSourceModel()->getVars(), backend.getDeviceVarPrefix()); - - // **TODO** correct "CS" + index field names + source name (for current source just its name but for PSM etc fused name) - addHeterogeneousParams(getArchetype().getCurrentSourceModel()->getParamNames(), "CS", - &NeuronUpdateGroupMerged::CurrentSource::isParamHeterogeneous, - &CurrentSourceInternal::getParams); + const std::string suffix = "CS" + std::to_string(getIndex()); - // **TODO** correct "CS" + index field names + source name (for current source just its name but for PSM etc fused name) - addHeterogeneousDerivedParams(getArchetype().getCurrentSourceModel()->getDerivedParams(), "CS", - &NeuronUpdateGroupMerged::CurrentSource::isDerivedParamHeterogeneous, - &CurrentSourceInternal::getDerivedParams); + // Add variables + for(const auto &var : getArchetype().getCurrentSourceModel()->getVars()) { + addPointerField(var.type, var.name + suffix, + backend.getDeviceVarPrefix() + var.name); + } + + // Add parameters and derived parameters + addHeterogeneousParams( + getArchetype().getCurrentSourceModel()->getParamNames(), suffix, + [](const auto &cs) { return cs.getParams(); }, + &NeuronUpdateGroupMerged::CurrentSource::isParamHeterogeneous); + addHeterogeneousDerivedParams( + getArchetype().getCurrentSourceModel()->getDerivedParams(), suffix, + [](const auto &cs) { return cs.getDerivedParams(); }, + &NeuronUpdateGroupMerged::CurrentSource::isDerivedParamHeterogeneous); // Add EGPs - // **TODO** correct field names + source name (for current source just its name but for PSM etc fused name) - addEGPs(getArchetype().getCurrentSourceModel()->getExtraGlobalParams(), backend.getDeviceVarPrefix()); + for(const auto &egp : getArchetype().getCurrentSourceModel()->getExtraGlobalParams()) { + addPointerField(egp.type, egp.name + suffix, + backend.getDeviceVarPrefix() + egp.name, + GroupMergedFieldType::DYNAMIC); + } +} +//---------------------------------------------------------------------------- +void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs) +{ + os << "// current source " << getIndex() << std::endl; + + // Read current source variables into registers + const std::string suffix = "CS" + std::to_string(getIndex()); + for(const auto &v : getArchetype().getCurrentSourceModel()->getVars()) { + if(v.access & VarAccessMode::READ_ONLY) { + os << "const "; + } + os << v.type.resolve(getTypeContext()).getName() << " lcs" << v.name << " = " << "group->" << v.name << suffix << "["; + os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), popSubs["id"]) << "];" << std::endl; + } + + Substitutions currSourceSubs(&popSubs); + currSourceSubs.addFuncSubstitution("injectCurrent", 1, "Isyn += $(0)"); + 
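+    // To make the substitution concrete (hypothetical injection code, not a model
+    // from this commit): addFuncSubstitution() registers a one-argument function
+    // template, so user code such as
+    //   injectCurrent(amp * exp(-t / tau));
+    // is expanded by the substitution pass into
+    //   Isyn += amp * exp(-t / tau);
+    // before the variable and parameter substitutions below are applied.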
currSourceSubs.addVarNameSubstitution(getArchetype().getCurrentSourceModel()->getVars(), "", "lcs"); + currSourceSubs.addParamValueSubstitution(getArchetype().getCurrentSourceModel()->getParamNames(), getArchetype().getParams(), + [this](const std::string &p) { return isParamHeterogeneous(p); }, + "", "group->", suffix); + currSourceSubs.addVarValueSubstitution(getArchetype().getCurrentSourceModel()->getDerivedParams(), getArchetype().getDerivedParams(), + [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }, + "", "group->", suffix); + currSourceSubs.addVarNameSubstitution(getArchetype().getCurrentSourceModel()->getExtraGlobalParams(), "", "group->", suffix); + + std::string iCode = getArchetype().getCurrentSourceModel()->getInjectionCode(); + currSourceSubs.applyCheckUnreplaced(iCode, "injectionCode : merged" + getIndex()); + //iCode = ensureFtype(iCode, model.getPrecision()); + os << iCode << std::endl; + + // Write read/write variables back to global memory + for(const auto &v : getArchetype().getCurrentSourceModel()->getVars()) { + if(v.access & VarAccessMode::READ_WRITE) { + os << "group->" << v.name << suffix << "["; + os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), currSourceSubs["id"]); + os << "] = lcs" << v.name << ";" << std::endl; + } + } } //---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::CurrentSource::isParamHeterogeneous(const std::string ¶mName) const @@ -46,7 +99,312 @@ bool NeuronUpdateGroupMerged::CurrentSource::isDerivedParamHeterogeneous( const //---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::CurrentSource::isParamReferenced(const std::string ¶mName) const { - return isParamReferenced({getArchetype().getCurrentSourceModel()->getInjectionCode()}, paramName); + return GroupMerged::isParamReferenced({getArchetype().getCurrentSourceModel()->getInjectionCode()}, + paramName); +} + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynPSM +//---------------------------------------------------------------------------- +NeuronUpdateGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups) +: GroupMerged(index, typeContext, groups) +{ + const std::string suffix = "InSyn" + std::to_string(getIndex()); + + // Add pointer to insyn + addField(getScalarType().createPointer(), "inSyn" + suffix, + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "inSyn" + g.getFusedPSVarSuffix(); }); + + // Add pointer to dendritic delay buffer if required + if(getArchetype().isDendriticDelayRequired()) { + addField(getScalarType().createPointer(), "denDelay" + suffix, + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); + + addField(Type::Uint32.createPointer(), "denDelayPtr" + suffix, + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); + } + + // Add pointers to state variable + // **FUSE** + for(const auto &var : getArchetype().getPSModel()->getVars()) { + addField(var.type.resolve(getTypeContext()).createPointer(), var.name + suffix, + [&backend, var](const auto &g, size_t) { return backend.getDeviceVarPrefix() + var.name + g.getFusedPSVarSuffix(); }); + } + + // Add any heterogeneous 
postsynaptic model parameters + addHeterogeneousParams( + getArchetype().getPSModel()->getParamNames(), suffix, + [](const auto &sg) { return sg.getPSParams(); }, + &NeuronUpdateGroupMerged::InSynPSM::isParamHeterogeneous); + + // Add any heterogeneous postsynaptic mode derived parameters + addHeterogeneousDerivedParams( + getArchetype().getPSModel()->getDerivedParams(), suffix, + [](const auto &sg) { return sg.getPSDerivedParams(); }, + &NeuronUpdateGroupMerged::InSynPSM::isDerivedParamHeterogeneous); + + // Add EGPs + for(const auto &egp : getArchetype().getPSModel()->getExtraGlobalParams()) { + addField(egp.type.resolve(getTypeContext()).createPointer(), egp.name + suffix, + [&backend, egp](const auto &g, size_t) { return backend.getDeviceVarPrefix() + egp.name + g.getFusedPSVarSuffix(); }, + GroupMergedFieldType::DYNAMIC); + } +} +//---------------------------------------------------------------------------- +void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs) +{ + const std::string suffix = "InSyn" + std::to_string(getIndex()); + const auto *psm = getArchetype().getPSModel(); + + os << "// pull inSyn values in a coalesced access" << std::endl; + os << "scalar linSyn = group->inSynInSyn" << getIndex() << "["; + os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, popSubs["id"]); + os << "];" << std::endl; + + // If dendritic delay is required + if (getArchetype().isDendriticDelayRequired()) { + // Get reference to dendritic delay buffer input for this timestep + os << backend.getPointerPrefix() << "scalar *denDelayFront = "; + os << "&group->denDelay" << suffix << "[(*group->denDelayPtr" << suffix << " * group->numNeurons) + "; + os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, popSubs["id"]); + os << "];" << std::endl; + + // Add delayed input from buffer into inSyn + os << "linSyn += *denDelayFront;" << std::endl; + + // Zero delay buffer slot + os << "*denDelayFront = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; + } + + // Pull postsynaptic model variables in a coalesced access + for (const auto &v : psm->getVars()) { + if(v.access & VarAccessMode::READ_ONLY) { + os << "const "; + } + os << v.type.resolve(getTypeContext()).getName() << " lps" << v.name << " = group->" << v.name << suffix << "["; + os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), popSubs["id"]); + os << "];" << std::endl; + } + + Substitutions inSynSubs(&popSubs); + inSynSubs.addVarSubstitution("inSyn", "linSyn"); + + // Allow synapse group's PS output var to override what Isyn points to + inSynSubs.addVarSubstitution("Isyn", getArchetype().getPSTargetVar(), true); + inSynSubs.addVarNameSubstitution(psm->getVars(), "", "lps"); + + inSynSubs.addParamValueSubstitution(psm->getParamNames(), getArchetype().getPSParams(), + [this](const std::string &p) { return isParamHeterogeneous(p); }, + "", "group->", suffix); + inSynSubs.addVarValueSubstitution(psm->getDerivedParams(), getArchetype().getPSDerivedParams(), + [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }, + "", "group->", suffix); + inSynSubs.addVarNameSubstitution(psm->getExtraGlobalParams(), "", "group->", suffix); + + // Apply substitutions to current converter code + std::string psCode = psm->getApplyInputCode(); + inSynSubs.applyCheckUnreplaced(psCode, 
"postSyntoCurrent : merged " + getIndex()); + //psCode = ensureFtype(psCode, model.getPrecision()); + + // Apply substitutions to decay code + std::string pdCode = psm->getDecayCode(); + inSynSubs.applyCheckUnreplaced(pdCode, "decayCode : merged " + getIndex()); + //pdCode = ensureFtype(pdCode, model.getPrecision()); + + if (!psm->getSupportCode().empty() && backend.supportsNamespace()) { + os << "using namespace " << modelMerged.getPostsynapticDynamicsSupportCodeNamespace(psm->getSupportCode()) << ";" << std::endl; + } + + if (!psm->getSupportCode().empty() && !backend.supportsNamespace()) { + psCode = disambiguateNamespaceFunction(psm->getSupportCode(), psCode, modelMerged.getPostsynapticDynamicsSupportCodeNamespace(psm->getSupportCode())); + pdCode = disambiguateNamespaceFunction(psm->getSupportCode(), pdCode, modelMerged.getPostsynapticDynamicsSupportCodeNamespace(psm->getSupportCode())); + } + + os << psCode << std::endl; + os << pdCode << std::endl; + + if (!psm->getSupportCode().empty()) { + os << CodeStream::CB(29) << " // namespace bracket closed" << std::endl; + } + + // Write back linSyn + os << "group->inSyn" << suffix << "["; + os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, inSynSubs["id"]); + os << "] = linSyn;" << std::endl; + + // Copy any non-readonly postsynaptic model variables back to global state variables dd_V etc + for (const auto &v : psm->getVars()) { + if(v.access & VarAccessMode::READ_WRITE) { + os << "group->" << v.name << suffix << "["; + os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), inSynSubs["id"]); + os << "]" << " = lps" << v.name << ";" << std::endl; + } + } +} +//---------------------------------------------------------------------------- +bool NeuronUpdateGroupMerged::InSynPSM::isParamHeterogeneous(const std::string ¶mName) const +{ + return (isParamReferenced(paramName) && + isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getPSParams(); })); +} +//---------------------------------------------------------------------------- +bool NeuronUpdateGroupMerged::InSynPSM::isDerivedParamHeterogeneous( const std::string ¶mName) const +{ + return (isParamReferenced(paramName) && + isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getPSDerivedParams(); })); + +} +//---------------------------------------------------------------------------- +bool NeuronUpdateGroupMerged::InSynPSM::isParamReferenced(const std::string ¶mName) const +{ + return GroupMerged::isParamReferenced( + {getArchetype().getPSModel()->getApplyInputCode(), getArchetype().getPSModel()->getDecayCode()}, + paramName); +} + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynPreOutput +//---------------------------------------------------------------------------- +NeuronUpdateGroupMerged::OutSynPreOutput::OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups) +: GroupMerged(index, typeContext, groups) +{ + const std::string suffix = "OutSyn" + std::to_string(getIndex()); + + addField(getScalarType().createPointer(), "revInSyn" + suffix, + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "revInSyn" + g.getFusedPreOutputSuffix(); }); +} +//---------------------------------------------------------------------------- +void 
NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs) +{ + const std::string suffix = "OutSyn" + std::to_string(getIndex()); + + os << getArchetype().getPreTargetVar() << "+= "; + os << "group->revInSyn" << suffix << "["; + os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, popSubs["id"]); + os << "];" << std::endl; + os << "group->revInSyn" << suffix << "["; + os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, popSubs["id"]); + os << "]= " << modelMerged.scalarExpr(0.0) << ";" << std::endl; +} + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynWUMPostCode +//---------------------------------------------------------------------------- +NeuronUpdateGroupMerged::InSynWUMPostCode::InSynWUMPostCode(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups) +: GroupMerged(index, typeContext, groups) +{ + const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); + + // Add variables + for(const auto &var : getArchetype().getWUModel()->getPostVars()) { + addField(var.type.resolve(getTypeContext()).createPointer(), var.name + suffix, + [&backend, var](const auto &g, size_t) { return backend.getDeviceVarPrefix() + var.name + g.getFusedWUPostVarSuffix(); }); + } + + // Add parameters and derived parameters + addHeterogeneousParams( + getArchetype().getWUModel()->getParamNames(), suffix, + [](const auto &sg) { return sg.getWUParams(); }, + &NeuronUpdateGroupMerged::InSynWUMPostCode::isParamHeterogeneous); + addHeterogeneousDerivedParams( + getArchetype().getWUModel()->getDerivedParams(), suffix, + [](const auto &sg) { return sg.getWUDerivedParams(); }, + &NeuronUpdateGroupMerged::InSynWUMPostCode::isDerivedParamHeterogeneous); + + // Add EGPs + for(const auto &egp : getArchetype().getWUModel()->getExtraGlobalParams()) { + addField(egp.type.resolve(getTypeContext()).createPointer(), egp.name + suffix, + [&backend, egp](const auto &g, size_t) { return backend.getDeviceVarPrefix() + egp.name + g.getFusedWUPostVarSuffix(); }, + GroupMergedFieldType::DYNAMIC); + } +} +//---------------------------------------------------------------------------- +void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike) +{ + const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); + + const unsigned int batchSize = modelMerged.getModel().getBatchSize(); + + // If this code string isn't empty + std::string code = dynamicsNotSpike ? 
getArchetype().getWUModel()->getPostDynamicsCode() : getArchetype().getWUModel()->getPostSpikeCode(); + if(!code.empty()) { + Substitutions subs(&popSubs); + + // Fetch variables from global memory + os << "// perform WUM update required for merged" << getIndex() << std::endl; + const auto vars = getArchetype().getWUModel()->getPostVars(); + const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); + for(const auto &v : vars) { + if(v.access & VarAccessMode::READ_ONLY) { + os << "const "; + } + os << v.type.resolve(getTypeContext()).getName() << " l" << v.name << " = group->" << v.name << suffix << "["; + os << ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), subs["id"]) << "];" << std::endl; + } + + subs.addParamValueSubstitution(getArchetype().getWUModel()->getParamNames(), getArchetype().getWUParams(), + [this](const std::string &p) { return isParamHeterogeneous(p); }, + "", "group->", suffix); + subs.addVarValueSubstitution(getArchetype().getWUModel()->getDerivedParams(), getArchetype().getWUDerivedParams(), + [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }, + "", "group->", suffix); + subs.addVarNameSubstitution(getArchetype().getWUModel()->getExtraGlobalParams(), "", "group->", suffix); + subs.addVarNameSubstitution(vars, "", "l"); + + neuronSubstitutionsInSynapticCode(subs, &ng.getArchetype(), "", "_post", "", "", "", dynamicsNotSpike, + [&ng](const std::string &p) { return ng.isParamHeterogeneous(p); }, + [&ng](const std::string &p) { return ng.isDerivedParamHeterogeneous(p); }, + [&subs, &ng, batchSize](bool delay, VarAccessDuplication varDuplication) + { + return ng.getReadVarIndex(delay, batchSize, varDuplication, subs["id"]); + }, + [&subs, &ng, batchSize](bool delay, VarAccessDuplication varDuplication) + { + return ng.getReadVarIndex(delay, batchSize, varDuplication, subs["id"]); + }); + + // Perform standard substitutions + subs.applyCheckUnreplaced(code, "spikeCode : merged" + std::to_string(getIndex())); + //code = ensureFtype(code, precision); + os << code; + + // Write back presynaptic variables into global memory + for(const auto &v : vars) { + // If state variable is read/write - meaning that it may have been updated - or it is delayed - + // meaning that it needs to be copied into next delay slot whatever - copy neuron state variables + // back to global state variables dd_V etc + if((v.access & VarAccessMode::READ_WRITE) || delayed) { + os << "group->" << v.name << suffix << "["; + os << ng.getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), subs["id"]) << "] = l" << v.name << ";" << std::endl; + } + } + } +} +//---------------------------------------------------------------------------- +bool NeuronUpdateGroupMerged::InSynWUMPostCode::isParamHeterogeneous(const std::string &paramName) const +{ + return (isParamReferenced(paramName) && + isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUParams(); })); +} +//---------------------------------------------------------------------------- +bool NeuronUpdateGroupMerged::InSynWUMPostCode::isDerivedParamHeterogeneous(const std::string &paramName) const +{ + return (isParamReferenced(paramName) && + isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUDerivedParams(); })); + +} +//---------------------------------------------------------------------------- +bool NeuronUpdateGroupMerged::InSynWUMPostCode::isParamReferenced(const std::string &paramName) const +{ + return GroupMerged::isParamReferenced( + {getArchetype().getWUModel()->getPostDynamicsCode(), getArchetype().getWUModel()->getPostSpikeCode()}, + paramName); }
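The write-back loop above copies a variable out either because it is READ_WRITE - it may have changed - or because it is delayed: with a delay queue, even an unchanged read-only value has to be copied forward into the next slot. For a hypothetical postsynaptic weight update model variable "a" with back-propagation delay, the emitted code has roughly this shape (indexing simplified, names illustrative):

    const scalar la = group->aInSynWUMPost0[readDelayOffset + id];  // slot written some steps ago
    // ... substituted post-spike/dynamics code using la ...
    group->aInSynWUMPost0[writeDelayOffset + id] = la;              // propagated into the next delay slot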
//---------------------------------------------------------------------------- @@ -392,7 +750,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C os << "group->revInSynOutSyn" << i << "["; os << getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; os << "group->revInSynOutSyn" << i << "["; - os << getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "]= 0.0;" << std::endl; + os << getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "]= 0.0;" << std::endl; } // Loop through all of neuron group's current sources From 77c7b82a8d7e2c21c2aede96b78b55c65d9001e0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 1 Jun 2023 11:52:10 +0100 Subject: [PATCH 185/725] prior to further refactoring, adding in last child group --- .../code_generator/neuronUpdateGroupMerged.h | 27 +++- .../code_generator/neuronUpdateGroupMerged.cc | 148 ++++++++++++++++-- 2 files changed, 156 insertions(+), 19 deletions(-) diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index b89f0b63f0..b46c46b502 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -96,7 +96,7 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase { public: InSynWUMPostCode(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + const std::vector> &groups); //---------------------------------------------------------------------------- // Public API @@ -119,12 +119,33 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase }; //---------------------------------------------------------------------------- - // GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynPreCode + // GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynWUMPreCode //---------------------------------------------------------------------------- //! Child group merged for outgoing synapse groups with presynaptic update/spike code - class OutSynPreCode : public GroupMerged + class OutSynWUMPreCode : public GroupMerged { + public: + OutSynWUMPreCode(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups); + + //---------------------------------------------------------------------------- + // Public API + //---------------------------------------------------------------------------- + void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike); + + //! Should the weight update model parameter be implemented heterogeneously? + bool isParamHeterogeneous(const std::string &paramName) const; + + //! Should the weight update model derived parameter be implemented heterogeneously? + bool isDerivedParamHeterogeneous(const std::string &paramName) const; + + private: + //---------------------------------------------------------------------------- + // Private API + //---------------------------------------------------------------------------- + //! Is the weight update model parameter referenced? + bool isParamReferenced(const std::string &paramName) const; };
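The isParamReferenced()/isParamHeterogeneous() pair declared above follows the same two-stage gating as the other child groups: a parameter only becomes a field of the merged-group struct if it is both referenced by the relevant code strings and takes different values across the groups being merged; otherwise its single value is baked into the generated code as a literal. Schematically, in the generated code (values hypothetical):

    linSyn *= 0.90483741803595952;  // homogeneous parameter: value baked in at code-generation time
    linSyn *= group->tauInSyn0;     // heterogeneous parameter: read from a per-group scalar field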
diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index a0d2f0f0c2..188255348d 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -26,14 +26,14 @@ NeuronUpdateGroupMerged::CurrentSource::CurrentSource(size_t index, const Type:: } // Add parameters and derived parameters - addHeterogeneousParams( + addHeterogeneousParams( getArchetype().getCurrentSourceModel()->getParamNames(), suffix, [](const auto &cs) { return cs.getParams(); }, - &NeuronUpdateGroupMerged::CurrentSource::isParamHeterogeneous); - addHeterogeneousDerivedParams( + &CurrentSource::isParamHeterogeneous); + addHeterogeneousDerivedParams( getArchetype().getCurrentSourceModel()->getDerivedParams(), suffix, [](const auto &cs) { return cs.getDerivedParams(); }, - &NeuronUpdateGroupMerged::CurrentSource::isDerivedParamHeterogeneous); + &CurrentSource::isDerivedParamHeterogeneous); // Add EGPs for(const auto &egp : getArchetype().getCurrentSourceModel()->getExtraGlobalParams()) { @@ -133,16 +133,16 @@ NeuronUpdateGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContex } // Add any heterogeneous postsynaptic model parameters - addHeterogeneousParams( + addHeterogeneousParams( getArchetype().getPSModel()->getParamNames(), suffix, [](const auto &sg) { return sg.getPSParams(); }, - &NeuronUpdateGroupMerged::InSynPSM::isParamHeterogeneous); + &InSynPSM::isParamHeterogeneous); // Add any heterogeneous postsynaptic model derived parameters - addHeterogeneousDerivedParams( + addHeterogeneousDerivedParams( getArchetype().getPSModel()->getDerivedParams(), suffix, [](const auto &sg) { return sg.getPSDerivedParams(); }, - &NeuronUpdateGroupMerged::InSynPSM::isDerivedParamHeterogeneous); + &InSynPSM::isDerivedParamHeterogeneous); // Add EGPs for(const auto &egp : getArchetype().getPSModel()->getExtraGlobalParams()) { @@ -295,26 +295,26 @@ void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backe // GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynWUMPostCode //---------------------------------------------------------------------------- NeuronUpdateGroupMerged::InSynWUMPostCode::InSynWUMPostCode(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) + const std::vector> &groups) : GroupMerged(index, typeContext, groups) { const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); - // Add variables + // Add postsynaptic variables for(const auto &var : getArchetype().getWUModel()->getPostVars()) { addField(var.type.resolve(getTypeContext()).createPointer(), var.name + suffix, [&backend, var](const auto &g, size_t) { return backend.getDeviceVarPrefix() + var.name + g.getFusedWUPostVarSuffix(); }); } // Add parameters and derived parameters - addHeterogeneousParams( + addHeterogeneousParams( getArchetype().getWUModel()->getParamNames(), suffix, [](const auto &sg) { return sg.getWUParams(); }, - &NeuronUpdateGroupMerged::InSynWUMPostCode::isParamHeterogeneous); + &InSynWUMPostCode::isParamHeterogeneous); - addHeterogeneousDerivedParams( + addHeterogeneousDerivedParams( getArchetype().getWUModel()->getDerivedParams(), suffix, [](const auto &sg) { return sg.getWUDerivedParams(); }, - &NeuronUpdateGroupMerged::InSynWUMPostCode::isDerivedParamHeterogeneous); + &InSynWUMPostCode::isDerivedParamHeterogeneous); // Add EGPs for(const auto 
&egp : getArchetype().getWUModel()->getExtraGlobalParams()) { @@ -336,7 +336,7 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back if(!code.empty()) { Substitutions subs(&popSubs); - // Fetch variables from global memory + // Fetch postsynaptic variables from global memory os << "// perform WUM update required for merged" << getIndex() << std::endl; const auto vars = getArchetype().getWUModel()->getPostVars(); const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); @@ -374,7 +374,7 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back //code = ensureFtype(code, precision); os << code; - // Write back presynaptic variables into global memory + // Write back postsynaptic variables into global memory for(const auto &v : vars) { // If state variable is read/write - meaning that it may have been updated - or it is delayed - meaning that it needs to be copied into next delay slot whatever - copy neuron state variables @@ -407,6 +407,122 @@ bool NeuronUpdateGroupMerged::InSynWUMPostCode::isParamReferenced(const std::str paramName); } + //---------------------------------------------------------------------------- +// GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynWUMPreCode +//---------------------------------------------------------------------------- +NeuronUpdateGroupMerged::OutSynWUMPreCode::OutSynWUMPreCode(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups) +: GroupMerged(index, typeContext, groups) +{ + const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); + + // Add presynaptic variables + for(const auto &var : getArchetype().getWUModel()->getPreVars()) { + addField(var.type.resolve(getTypeContext()).createPointer(), var.name + suffix, + [&backend, var](const auto &g, size_t) { return backend.getDeviceVarPrefix() + var.name + g.getFusedWUPreVarSuffix(); }); + } + + // Add parameters and derived parameters + addHeterogeneousParams( + getArchetype().getWUModel()->getParamNames(), suffix, + [](const auto &sg) { return sg.getWUParams(); }, + &OutSynWUMPreCode::isParamHeterogeneous); + addHeterogeneousDerivedParams( + getArchetype().getWUModel()->getDerivedParams(), suffix, + [](const auto &sg) { return sg.getWUDerivedParams(); }, + &OutSynWUMPreCode::isDerivedParamHeterogeneous); + + // Add EGPs + for(const auto &egp : getArchetype().getWUModel()->getExtraGlobalParams()) { + addField(egp.type.resolve(getTypeContext()).createPointer(), egp.name + suffix, + [&backend, egp](const auto &g, size_t) { return backend.getDeviceVarPrefix() + egp.name + g.getFusedWUPreVarSuffix(); }, + GroupMergedFieldType::DYNAMIC); + } +}
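For orientation, the generate() method that follows emits code of roughly this shape for a hypothetical weight update model with one presynaptic variable "preTrace" and no axonal delay (after all substitutions have been applied):

    // perform WUM update required for merged0
    scalar lpreTrace = group->preTraceOutSynWUMPre0[id];
    lpreTrace = (lpreTrace * 0.9f) + 1.0f;  // the model's substituted pre-spike code
    group->preTraceOutSynWUMPre0[id] = lpreTrace;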
+//---------------------------------------------------------------------------- +void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike) +{ + const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); + + const unsigned int batchSize = modelMerged.getModel().getBatchSize(); + + // If this code string isn't empty + std::string code = dynamicsNotSpike ? getArchetype().getWUModel()->getPreDynamicsCode() : getArchetype().getWUModel()->getPreSpikeCode(); + if(!code.empty()) { + Substitutions subs(&popSubs); + + // Fetch presynaptic variables from global memory + os << "// perform WUM update required for merged" << getIndex() << std::endl; + const auto vars = getArchetype().getWUModel()->getPreVars(); + const bool delayed = (getArchetype().getDelaySteps() != NO_DELAY); + for(const auto &v : vars) { + if(v.access & VarAccessMode::READ_ONLY) { + os << "const "; + } + os << v.type.resolve(getTypeContext()).getName() << " l" << v.name << " = group->" << v.name << suffix << "["; + os << ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), subs["id"]) << "];" << std::endl; + } + + subs.addParamValueSubstitution(getArchetype().getWUModel()->getParamNames(), getArchetype().getWUParams(), + [this](const std::string &p) { return isParamHeterogeneous(p); }, + "", "group->", suffix); + subs.addVarValueSubstitution(getArchetype().getWUModel()->getDerivedParams(), getArchetype().getWUDerivedParams(), + [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }, + "", "group->", suffix); + subs.addVarNameSubstitution(getArchetype().getWUModel()->getExtraGlobalParams(), "", "group->", suffix); + subs.addVarNameSubstitution(vars, "", "l"); + + neuronSubstitutionsInSynapticCode(subs, &ng.getArchetype(), "", "_pre", "", "", "", dynamicsNotSpike, + [&ng](const std::string &p) { return ng.isParamHeterogeneous(p); }, + [&ng](const std::string &p) { return ng.isDerivedParamHeterogeneous(p); }, + [&subs, &ng, batchSize](bool delay, VarAccessDuplication varDuplication) + { + return ng.getReadVarIndex(delay, batchSize, varDuplication, subs["id"]); + }, + [&subs, &ng, batchSize](bool delay, VarAccessDuplication varDuplication) + { + return ng.getReadVarIndex(delay, batchSize, varDuplication, subs["id"]); + }); + + // Perform standard substitutions + subs.applyCheckUnreplaced(code, "spikeCode : merged" + std::to_string(getIndex())); + //code = ensureFtype(code, precision); + os << code; + + // Write back presynaptic variables into global memory + for(const auto &v : vars) { + // If state variable is read/write - meaning that it may have been updated - or it is delayed - + // meaning that it needs to be copied into next delay slot whatever - copy neuron state variables + // back to global state variables dd_V etc + if((v.access & VarAccessMode::READ_WRITE) || delayed) { + os << "group->" << v.name << suffix << "["; + os << ng.getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), subs["id"]) << "] = l" << v.name << ";" << std::endl; + } + } + } +} +//---------------------------------------------------------------------------- +bool NeuronUpdateGroupMerged::OutSynWUMPreCode::isParamHeterogeneous(const std::string &paramName) const +{ + return (isParamReferenced(paramName) && + isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUParams(); })); +} +//---------------------------------------------------------------------------- +bool NeuronUpdateGroupMerged::OutSynWUMPreCode::isDerivedParamHeterogeneous(const std::string &paramName) const +{ + return (isParamReferenced(paramName) && + isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUDerivedParams(); })); + +} +//---------------------------------------------------------------------------- +bool NeuronUpdateGroupMerged::OutSynWUMPreCode::isParamReferenced(const std::string &paramName) const +{ + return 
GroupMerged::isParamReferenced( + {getArchetype().getWUModel()->getPreDynamicsCode(), getArchetype().getWUModel()->getPreSpikeCode()}, + paramName); +} + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged //---------------------------------------------------------------------------- From 8166f09f0c9ea3772ed30934fa291314ba37ac08 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 1 Jun 2023 14:47:32 +0100 Subject: [PATCH 186/725] added matching child groups for ``NeuronInitGroupMerged`` --- .../genn/genn/code_generator/groupMerged.h | 8 +- .../genn/code_generator/initGroupMerged.h | 137 +++++++ .../code_generator/neuronUpdateGroupMerged.h | 2 - .../genn/code_generator/initGroupMerged.cc | 345 ++++++++++++++++++ 4 files changed, 486 insertions(+), 6 deletions(-) diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index e672238217..f600b71f22 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -351,7 +351,7 @@ class GroupMerged } template - void addHeterogeneousVarInitParams(H isHeterogeneous) + void addHeterogeneousVarInitParams(H isHeterogeneous, const std::string &suffix = "") { // Loop through weight update model variables const A archetypeAdaptor(getArchetype()); @@ -359,7 +359,7 @@ class GroupMerged // Loop through parameters for(const auto &p : archetypeAdaptor.getInitialisers().at(v.name).getParams()) { if((static_cast(this)->*isHeterogeneous)(v.name, p.first)) { - addScalarField(p.first + v.name, + addScalarField(p.first + v.name + suffix, [p, v](const G &g, size_t) { return A(g).getInitialisers().at(v.name).getParams().at(p.first); @@ -370,7 +370,7 @@ class GroupMerged } template - void addHeterogeneousVarInitDerivedParams(H isHeterogeneous) + void addHeterogeneousVarInitDerivedParams(H isHeterogeneous, const std::string &suffix = "") { // Loop through weight update model variables const A archetypeAdaptor(getArchetype()); @@ -378,7 +378,7 @@ class GroupMerged // Loop through parameters for(const auto &p : archetypeAdaptor.getInitialisers().at(v.name).getDerivedParams()) { if((static_cast(this)->*isHeterogeneous)(v.name, p.first)) { - addScalarField(p.first + v.name, + addScalarField(p.first + v.name + suffix, [p, v](const G &g, size_t) { return A(g).getInitialisers().at(v.name).getDerivedParams().at(p.first); diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index bd1598ef79..997506df48 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -11,6 +11,143 @@ namespace GeNN::CodeGenerator class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase { public: + //---------------------------------------------------------------------------- + // GeNN::CodeGenerator::NeuronInitGroupMerged::CurrentSource + //---------------------------------------------------------------------------- + //! 
Child group merged for current sources attached to this neuron init group + class CurrentSource : public GroupMerged + { + public: + CurrentSource(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups); + + //---------------------------------------------------------------------------- + // Public API + //---------------------------------------------------------------------------- + void generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs); + + private: + //---------------------------------------------------------------------------- + // Private methods + //---------------------------------------------------------------------------- + //! Is the var init parameter referenced? + bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const; + + //! Should the var init parameter be implemented heterogeneously? + bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + + //! Should the var init derived parameter be implemented heterogeneously? + bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + }; + + //---------------------------------------------------------------------------- + // GeNN::CodeGenerator::NeuronInitGroupMerged::InSynPSM + //---------------------------------------------------------------------------- + //! Child group merged for incoming synapse groups + class InSynPSM : public GroupMerged + { + public: + InSynPSM(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups); + + //---------------------------------------------------------------------------- + // Public API + //---------------------------------------------------------------------------- + void generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs); + + private: + //---------------------------------------------------------------------------- + // Private methods + //---------------------------------------------------------------------------- + //! Is the var init parameter referenced? + bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const; + + //! Should the var init parameter be implemented heterogeneously? + bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + + //! Should the var init derived parameter be implemented heterogeneously? + bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + };
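Unlike their update-side counterparts, these child groups key their heterogeneity tests on (variable name, parameter name) pairs because every variable carries its own initialiser with its own parameter set; combined with the suffix parameter added to addHeterogeneousVarInitParams() in the groupMerged.h hunk above, a heterogeneous initialiser parameter becomes a scalar field named from all three parts. A sketch of the resulting field name (variable and snippet hypothetical):

    // Variable "V" initialised with a Uniform snippet whose "min" differs across
    // the merged groups, in incoming-synapse child group 0:
    //   field name = "min" + "V" + "InSyn0"  ->  scalar minVInSyn0;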
+ + //---------------------------------------------------------------------------- + // GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynPreOutput + //---------------------------------------------------------------------------- + //! Child group merged for outgoing synapse groups with $(addToPre) logic + class OutSynPreOutput : public GroupMerged + { + public: + OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups); + + //---------------------------------------------------------------------------- + // Public API + //---------------------------------------------------------------------------- + void generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs); + }; + + //---------------------------------------------------------------------------- + // GeNN::CodeGenerator::NeuronInitGroupMerged::InSynWUMPostVars + //---------------------------------------------------------------------------- + //! Child group merged for incoming synapse groups with postsynaptic variables + class InSynWUMPostVars : public GroupMerged + { + public: + InSynWUMPostVars(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups); + + //---------------------------------------------------------------------------- + // Public API + //---------------------------------------------------------------------------- + void generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs); + + private: + //---------------------------------------------------------------------------- + // Private methods + //---------------------------------------------------------------------------- + //! Is the var init parameter referenced? + bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const; + + //! Should the var init parameter be implemented heterogeneously? + bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + + //! Should the var init derived parameter be implemented heterogeneously? + bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + }; + + //---------------------------------------------------------------------------- + // GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynWUMPreVars + //---------------------------------------------------------------------------- + //! Child group merged for outgoing synapse groups with presynaptic variables + class OutSynWUMPreVars : public GroupMerged + { + public: + OutSynWUMPreVars(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups); + + //---------------------------------------------------------------------------- + // Public API + //---------------------------------------------------------------------------- + void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs); + + private: + //---------------------------------------------------------------------------- + // Private methods + //---------------------------------------------------------------------------- + //! Is the var init parameter referenced? + bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const; + + //! Should the var init parameter be implemented heterogeneously? + bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + + //! Should the var init derived parameter be implemented heterogeneously? 
+ bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const; + }; + NeuronInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index b46c46b502..0af8b3c1e1 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -148,8 +148,6 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase bool isParamReferenced(const std::string ¶mName) const; }; - - NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups); diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 73b6c97084..df119b5504 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -180,6 +180,351 @@ void genInitWUVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, const } } // Anonymous namespace +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::NeuronInitGroupMerged::CurrentSource +//---------------------------------------------------------------------------- +NeuronInitGroupMerged::CurrentSource::CurrentSource(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups) +: GroupMerged(index, typeContext, groups) +{ + const std::string suffix = "CS" + std::to_string(getIndex()); + + // Loop through variables + // **TODO** adaptor + const auto &varInit = getArchetype().getVarInitialisers(); + for(const auto &var : getArchetype().getCurrentSourceModel()->getVars()) { + // Add pointers to state variable + if(!varInit.at(var.name).getSnippet()->getCode().empty()) { + addPointerField(var.type, var.name + suffix, + backend.getDeviceVarPrefix() + var.name); + } + + // Add heterogeneous var init parameters + addHeterogeneousVarInitParams( + &CurrentSource::isVarInitParamHeterogeneous, suffix); + addHeterogeneousVarInitDerivedParams( + &CurrentSource::isVarInitDerivedParamHeterogeneous, suffix); + + // Add extra global parameters + for(const auto &e : varInit.at(var.name).getSnippet()->getExtraGlobalParams()) { + addField(e.type.resolve(getTypeContext()).createPointer(), e.name + var.name + suffix, + [&backend, e, suffix, var](const auto &g, size_t) + { + return backend.getDeviceVarPrefix() + e.name + var.name + g.getName(); + }, + GroupMergedFieldType::DYNAMIC); + } + } +} +//---------------------------------------------------------------------------- +void NeuronInitGroupMerged::CurrentSource::generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs) +{ + const std::string suffix = "CS" + std::to_string(getIndex()); + + genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getCurrentSourceModel()->getVars(), getArchetype().getVarInitialisers(), + suffix, "numNeurons", getIndex(), modelMerged.getModel().getBatchSize(), + [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, + [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); +} 
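As a concrete picture of what genInitNeuronVarCode() produces here: for a current source variable "current" whose initialiser simply sets a constant, a single-threaded CPU backend would emit roughly the following (illustrative; SIMT backends replace the loop with a per-thread guard on id):

    for(unsigned int i = 0; i < group->numNeurons; i++) {
        group->currentCS0[i] = 0.0f;
    }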
+//---------------------------------------------------------------------------- +bool NeuronInitGroupMerged::CurrentSource::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const +{ + return (isVarInitParamReferenced(varName, paramName) && + isParamValueHeterogeneous(paramName, + [varName](const auto &cs){ return cs.getVarInitialisers().at(varName).getParams(); })); +} +//---------------------------------------------------------------------------- +bool NeuronInitGroupMerged::CurrentSource::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const +{ + return (isVarInitParamReferenced(varName, paramName) && + isParamValueHeterogeneous(paramName, + [varName](const auto &cs){ return cs.getVarInitialisers().at(varName).getDerivedParams(); })); +} +//---------------------------------------------------------------------------- +bool NeuronInitGroupMerged::CurrentSource::isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const +{ + const auto *varInitSnippet = getArchetype().getVarInitialisers().at(varName).getSnippet(); + return isParamReferenced({varInitSnippet->getCode()}, paramName); +} + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::NeuronInitGroupMerged::InSynPSM +//---------------------------------------------------------------------------- +NeuronInitGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups) +: GroupMerged(index, typeContext, groups) +{ + const std::string suffix = "InSyn" + std::to_string(getIndex()); + + // Add pointer to insyn + addField(getScalarType().createPointer(), "inSyn" + suffix, + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "inSyn" + g.getFusedPSVarSuffix(); }); + + // Add pointer to dendritic delay buffer if required + if(getArchetype().isDendriticDelayRequired()) { + addField(getScalarType().createPointer(), "denDelay" + suffix, + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); + + addField(Type::Uint32.createPointer(), "denDelayPtr" + suffix, + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); + } + + // Loop through variables + // **TODO** adaptor + const auto &varInit = getArchetype().getPSVarInitialisers(); + for(const auto &var : getArchetype().getPSModel()->getVars()) { + // Add pointers to state variable + if(!varInit.at(var.name).getSnippet()->getCode().empty()) { + addField(var.type.resolve(getTypeContext()).createPointer(), var.name + suffix, + [&backend, var](const auto &g, size_t) { return backend.getDeviceVarPrefix() + var.name + g.getFusedPSVarSuffix(); }); + } + + // Add heterogeneous var init parameters + addHeterogeneousVarInitParams( + &InSynPSM::isVarInitParamHeterogeneous, suffix); + addHeterogeneousVarInitDerivedParams( + &InSynPSM::isVarInitDerivedParamHeterogeneous, suffix); + + // Add extra global parameters + for(const auto &e : varInit.at(var.name).getSnippet()->getExtraGlobalParams()) { + addField(e.type.resolve(getTypeContext()).createPointer(), e.name + var.name + suffix, + [&backend, e, suffix, var](const auto &g, size_t) + { + return backend.getDeviceVarPrefix() + e.name + var.name + g.getFusedPSVarSuffix(); + }, + GroupMergedFieldType::DYNAMIC); + } + } +} 
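The fields registered above surface as members of a generated per-merged-group struct, the "InSyn" suffix keeping multiple incoming synapse groups apart. Roughly (illustrative layout; "x" stands for one postsynaptic model variable with non-empty init code, and scalar for the model's floating-point type):

    struct MergedNeuronInitGroup0
    {
        unsigned int numNeurons;
        scalar *inSynInSyn0;
        scalar *denDelayInSyn0;           // only present when dendritic delay is required
        unsigned int *denDelayPtrInSyn0;  // ditto
        scalar *xInSyn0;
    };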
+//---------------------------------------------------------------------------- +void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs) +{ + const std::string suffix = "InSyn" + std::to_string(getIndex()); + + // Zero InSyn + backend.genVariableInit(os, "group->numNeurons", "id", popSubs, + [&modelMerged, &suffix] (CodeStream &os, Substitutions &varSubs) + { + genVariableFill(os, "inSyn" + suffix, modelMerged.scalarExpr(0.0), + varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, + modelMerged.getModel().getBatchSize()); + + }); + + // If dendritic delays are required + if(getArchetype().isDendriticDelayRequired()) { + // Zero dendritic delay buffer + backend.genVariableInit(os, "group->numNeurons", "id", popSubs, + [&modelMerged, &suffix, this](CodeStream &os, Substitutions &varSubs) + { + genVariableFill(os, "denDelay" + suffix, modelMerged.scalarExpr(0.0), + varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, + modelMerged.getModel().getBatchSize(), + true, getArchetype().getMaxDendriticDelayTimesteps()); + }); + + // Zero dendritic delay pointer + backend.genPopVariableInit(os, popSubs, + [&suffix](CodeStream &os, Substitutions &) + { + os << "*group->denDelayPtr" << suffix << " = 0;" << std::endl; + }); + } + + // **TODO** adaptor + genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getPSModel()->getVars(), getArchetype().getPSVarInitialisers(), + suffix, "numNeurons", getIndex(), modelMerged.getModel().getBatchSize(), + [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, + [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); +} +//---------------------------------------------------------------------------- +bool NeuronInitGroupMerged::InSynPSM::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const +{ + return (isVarInitParamReferenced(varName, paramName) && + isParamValueHeterogeneous(paramName, + [varName](const auto &sg){ return sg.getPSVarInitialisers().at(varName).getParams(); })); +} +//---------------------------------------------------------------------------- +bool NeuronInitGroupMerged::InSynPSM::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const +{ + return (isVarInitParamReferenced(varName, paramName) && + isParamValueHeterogeneous(paramName, + [varName](const auto &sg){ return sg.getPSVarInitialisers().at(varName).getDerivedParams(); })); +} +//---------------------------------------------------------------------------- +bool NeuronInitGroupMerged::InSynPSM::isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const +{ + const auto *varInitSnippet = getArchetype().getPSVarInitialisers().at(varName).getSnippet(); + return isParamReferenced({varInitSnippet->getCode()}, paramName); +} + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynPreOutput +//---------------------------------------------------------------------------- +NeuronInitGroupMerged::OutSynPreOutput::OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups) +: GroupMerged(index, typeContext, groups) +{ + const std::string suffix = "OutSyn" + std::to_string(getIndex()); + + 
addField(getScalarType().createPointer(), "revInSyn" + suffix, + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "revInSyn" + g.getFusedPreOutputSuffix(); }); +} +//---------------------------------------------------------------------------- +void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs) +{ + const std::string suffix = "OutSyn" + std::to_string(getIndex()); + + backend.genVariableInit(os, "group->numNeurons", "id", popSubs, + [&modelMerged, suffix] (CodeStream &os, Substitutions &varSubs) + { + genVariableFill(os, "revInSyn" + suffix, modelMerged.scalarExpr(0.0), + varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, + modelMerged.getModel().getBatchSize()); + }); +} + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::NeuronInitGroupMerged::InSynWUMPostVars +//---------------------------------------------------------------------------- +NeuronInitGroupMerged::InSynWUMPostVars::InSynWUMPostVars(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups) +: GroupMerged(index, typeContext, groups) +{ + const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); + + // Loop through variables + // **TODO** adaptor + const auto &varInit = getArchetype().getWUPostVarInitialisers(); + for(const auto &var : getArchetype().getWUModel()->getPostVars()) { + // Add pointers to state variable + if(!varInit.at(var.name).getSnippet()->getCode().empty()) { + addField(var.type.resolve(getTypeContext()).createPointer(), var.name + suffix, + [&backend, var](const auto &g, size_t) { return backend.getDeviceVarPrefix() + var.name + g.getFusedWUPostVarSuffix(); }); + } + + // Add heterogeneous var init parameters + addHeterogeneousVarInitParams( + &InSynWUMPostVars::isVarInitParamHeterogeneous, suffix); + addHeterogeneousVarInitDerivedParams( + &InSynWUMPostVars::isVarInitDerivedParamHeterogeneous, suffix); + + // Add extra global parameters + for(const auto &e : varInit.at(var.name).getSnippet()->getExtraGlobalParams()) { + addField(e.type.resolve(getTypeContext()).createPointer(), e.name + var.name + suffix, + [&backend, e, suffix, var](const auto &g, size_t) + { + return backend.getDeviceVarPrefix() + e.name + var.name + g.getFusedWUPostVarSuffix(); + }, + GroupMergedFieldType::DYNAMIC); + } + } +} +//---------------------------------------------------------------------------- +void NeuronInitGroupMerged::InSynWUMPostVars::generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs) +{ + const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); + + genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getWUModel()->getPostVars(), getArchetype().getWUPostVarInitialisers(), + suffix, "numNeurons", getArchetype().getTrgNeuronGroup()->getNumDelaySlots(), getIndex(), modelMerged.getModel().getBatchSize(), + [this](const std::string&){ return (getArchetype().getBackPropDelaySteps() != NO_DELAY); }, + [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, + [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); +} +//---------------------------------------------------------------------------- +bool 
NeuronInitGroupMerged::InSynWUMPostVars::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const +{ + return (isVarInitParamReferenced(varName, paramName) && + isParamValueHeterogeneous(paramName, + [varName](const auto &sg){ return sg.getWUPostVarInitialisers().at(varName).getParams(); })); +} +//---------------------------------------------------------------------------- +bool NeuronInitGroupMerged::InSynWUMPostVars::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const +{ + return (isVarInitParamReferenced(varName, paramName) && + isParamValueHeterogeneous(paramName, + [varName](const auto &sg){ return sg.getWUPostVarInitialisers().at(varName).getDerivedParams(); })); +} +//---------------------------------------------------------------------------- +bool NeuronInitGroupMerged::InSynWUMPostVars::isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const +{ + const auto *varInitSnippet = getArchetype().getWUPostVarInitialisers().at(varName).getSnippet(); + return isParamReferenced({varInitSnippet->getCode()}, paramName); +} + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynWUMPreVars +//---------------------------------------------------------------------------- +NeuronInitGroupMerged::OutSynWUMPreVars::OutSynWUMPreVars(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups) +: GroupMerged(index, typeContext, groups) +{ + const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); + + // Loop through variables + // **TODO** adaptor + const auto &varInit = getArchetype().getWUPreVarInitialisers(); + for(const auto &var : getArchetype().getWUModel()->getPreVars()) { + // Add pointers to state variable + if(!varInit.at(var.name).getSnippet()->getCode().empty()) { + addField(var.type.resolve(getTypeContext()).createPointer(), var.name + suffix, + [&backend, var](const auto &g, size_t) { return backend.getDeviceVarPrefix() + var.name + g.getFusedWUPreVarSuffix(); }); + } + + // Add heterogeneous var init parameters + addHeterogeneousVarInitParams( + &OutSynWUMPreVars::isVarInitParamHeterogeneous, suffix); + addHeterogeneousVarInitDerivedParams( + &OutSynWUMPreVars::isVarInitDerivedParamHeterogeneous, suffix); + + // Add extra global parameters + for(const auto &e : varInit.at(var.name).getSnippet()->getExtraGlobalParams()) { + addField(e.type.resolve(getTypeContext()).createPointer(), e.name + var.name + suffix, + [&backend, e, suffix, var](const auto &g, size_t) + { + return backend.getDeviceVarPrefix() + e.name + var.name + g.getFusedWUPreVarSuffix(); + }, + GroupMergedFieldType::DYNAMIC); + } + } +} +//---------------------------------------------------------------------------- +void NeuronInitGroupMerged::OutSynWUMPreVars::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs) +{ + const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); + + genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getWUModel()->getPreVars(), getArchetype().getWUPreVarInitialisers(), + suffix, "numNeurons", getArchetype().getSrcNeuronGroup()->getNumDelaySlots(), getIndex(), modelMerged.getModel().getBatchSize(), + [this](const std::string&){ return (getArchetype().getDelaySteps() != NO_DELAY); }, + [this](const std::string &v, const 
std::string &p) { return isVarInitParamHeterogeneous(v, p); }, + [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); +} +//---------------------------------------------------------------------------- +bool NeuronInitGroupMerged::OutSynWUMPreVars::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const +{ + return (isVarInitParamReferenced(varName, paramName) && + isParamValueHeterogeneous(paramName, + [varName](const auto &sg){ return sg.getWUPreVarInitialisers().at(varName).getParams(); })); +} +//---------------------------------------------------------------------------- +bool NeuronInitGroupMerged::OutSynWUMPreVars::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const +{ + return (isVarInitParamReferenced(varName, paramName) && + isParamValueHeterogeneous(paramName, + [varName](const auto &sg){ return sg.getWUPreVarInitialisers().at(varName).getDerivedParams(); })); +} +//---------------------------------------------------------------------------- +bool NeuronInitGroupMerged::OutSynWUMPreVars::isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const +{ + const auto *varInitSnippet = getArchetype().getWUPreVarInitialisers().at(varName).getSnippet(); + return isParamReferenced({varInitSnippet->getCode()}, paramName); +} + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged //---------------------------------------------------------------------------- From 8ebd7c0055785c25b5f2c90886995bcb045dd8ec Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 1 Jun 2023 18:44:54 +0100 Subject: [PATCH 187/725] WIP surgery on ``NeuronUpdateGroupMerged`` --- .../genn/genn/code_generator/groupMerged.h | 212 +--------- .../code_generator/neuronUpdateGroupMerged.h | 65 +-- src/genn/genn/code_generator/groupMerged.cc | 121 +----- .../code_generator/neuronUpdateGroupMerged.cc | 387 +++--------------- 4 files changed, 95 insertions(+), 690 deletions(-) diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index f600b71f22..e3cc08ed32 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -624,59 +624,24 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged &getSortedArchetypeMergedInSyns() const { return m_SortedMergedInSyns.front(); } - - //! Get sorted vectors of merged outgoing synapse groups with presynaptic output belonging to archetype group - const std::vector &getSortedArchetypeMergedPreOutputOutSyns() const { return m_SortedMergedPreOutputOutSyns.front(); } - - //! 
Get sorted vectors of current sources belonging to archetype group - const std::vector &getSortedArchetypeCurrentSources() const { return m_SortedCurrentSources.front(); } - protected: - //------------------------------------------------------------------------ - // Protected methods - //------------------------------------------------------------------------ - NeuronGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - bool init, const std::vector> &groups); + NeuronGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + const std::vector> &groups); void updateBaseHash(bool init, boost::uuids::detail::sha1 &hash) const; - template - void orderNeuronGroupChildren(std::vector> &sortedGroupChildren, + template + void orderNeuronGroupChildren(std::vector &childGroups, const Type::TypeContext &typeContext, const BackendBase &backend, G getVectorFunc, H getHashDigestFunc) const { - const std::vector &archetypeChildren = (getArchetype().*getVectorFunc)(); + const std::vector &archetypeChildren = (getArchetype().*getVectorFunc)(); - // Reserve vector of vectors to hold children for all neuron groups, in archetype order - sortedGroupChildren.reserve(getGroups().size()); + // Resize vector of vectors to hold children for all neuron groups, sorted in a consistent manner + std::vector>> sortedGroupChildren; + sortedGroupChildren.resize(archetypeChildren.size()); // Create temporary vector of children and their digests - std::vector> childDigests; + std::vector> childDigests; childDigests.reserve(archetypeChildren.size()); // Loop through groups @@ -693,151 +658,27 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged &a, - const std::pair &b) + [](const auto &a, const auto &b) { return (a.first < b.first); }); - // Reserve vector for this group's children - sortedGroupChildren.emplace_back(); - sortedGroupChildren.back().reserve(groupChildren.size()); - - // Copy sorted child pointers into sortedGroupChildren - std::transform(childDigests.cbegin(), childDigests.cend(), std::back_inserter(sortedGroupChildren.back()), - [](const std::pair &a){ return a.second; }); - } - } - - //! Is the var init parameter referenced? - bool isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const; - - //! Is the current source parameter referenced? - bool isCurrentSourceParamReferenced(size_t childIndex, const std::string ¶mName) const; - - //! Is the current source var init parameter referenced? - bool isCurrentSourceVarInitParamReferenced(size_t childIndex, const std::string &varName, const std::string ¶mName) const; - - //! Is the postsynaptic model parameter referenced? - bool isPSMParamReferenced(size_t childIndex, const std::string ¶mName) const; - - //! Is the postsynaptic model var init parameter referenced? 
- bool isPSMVarInitParamReferenced(size_t childIndex, const std::string &varName, const std::string ¶mName) const; - - template - bool isChildParamValueHeterogeneous(size_t childIndex, const std::string ¶mName, - const std::vector> &sortedGroupChildren, G getParamValuesFn) const - { - // Get value of archetype derived parameter - const double firstValue = getParamValuesFn(sortedGroupChildren[0][childIndex]).at(paramName); - - // Loop through groups within merged group - for(size_t i = 0; i < sortedGroupChildren.size(); i++) { - const auto group = sortedGroupChildren[i][childIndex]; - if(getParamValuesFn(group).at(paramName) != firstValue) { - return true; - } - } - - return false; - } - - template - void addHeterogeneousChildParams(const Snippet::Base::StringVec ¶mNames, - const std::vector> &sortedGroupChildren, - size_t childIndex, const std::string &prefix, - H isChildParamHeterogeneousFn, V getValueFn) - { - // Loop through parameters - for(const auto &p : paramNames) { - // If parameter is heterogeneous - if((static_cast(this)->*isChildParamHeterogeneousFn)(childIndex, p)) { - addScalarField(p + prefix + std::to_string(childIndex), - [&sortedGroupChildren, childIndex, p, getValueFn](const NeuronGroupInternal &, size_t groupIndex) - { - const auto *child = sortedGroupChildren.at(groupIndex).at(childIndex); - return std::invoke(getValueFn, child).at(p); - }); - } - } - } - - template - void addHeterogeneousChildDerivedParams(const Snippet::Base::DerivedParamVec &derivedParams, - const std::vector> &sortedGroupChildren, - size_t childIndex, const std::string &prefix, - H isChildDerivedParamHeterogeneousFn, V getValueFn) - { - // Loop through derived parameters - for(const auto &p : derivedParams) { - // If parameter is heterogeneous - if((static_cast(this)->*isChildDerivedParamHeterogeneousFn)(childIndex, p.name)) { - addScalarField(p.name + prefix + std::to_string(childIndex), - [&sortedGroupChildren, childIndex, p, getValueFn](const NeuronGroupInternal &, size_t groupIndex) - { - const auto *child = sortedGroupChildren.at(groupIndex).at(childIndex); - return std::invoke(getValueFn, child).at(p.name); - }); - } - } - } - - template - void addHeterogeneousChildVarInitParams(const Snippet::Base::StringVec ¶mNames, - const std::vector> &sortedGroupChildren, - size_t childIndex, const std::string &varName, const std::string &prefix, - H isChildParamHeterogeneousFn, V getVarInitialiserFn) - { - // Loop through parameters - for(const auto &p : paramNames) { - // If parameter is heterogeneous - if((static_cast(this)->*isChildParamHeterogeneousFn)(childIndex, varName, p)) { - addScalarField(p + varName + prefix + std::to_string(childIndex), - [&sortedGroupChildren, childIndex, varName, p, getVarInitialiserFn](const NeuronGroupInternal &, size_t groupIndex) - { - const auto *child = sortedGroupChildren.at(groupIndex).at(childIndex); - return std::invoke(getVarInitialiserFn, child).at(varName).getParams().at(p); - }); + // Populate 'transpose' vector of vectors + for (size_t i = 0; i < childDigests.size(); i++) { + sortedGroupChildren[i].emplace_back(*childDigests[i].second); } } - } - template - void addHeterogeneousChildVarInitDerivedParams(const Snippet::Base::DerivedParamVec &derivedParams, - const std::vector> &sortedGroupChildren, - size_t childIndex, const std::string &varName, const std::string &prefix, - H isChildDerivedParamHeterogeneousFn, V getVarInitialiserFn) - { - // Loop through parameters - for(const auto &d : derivedParams) { - // If parameter is heterogeneous - // **TODO** 
std::invoke - if((static_cast(this)->*isChildDerivedParamHeterogeneousFn)(childIndex, varName, d.name)) { - addScalarField(d.name + varName + prefix + std::to_string(childIndex), - [&sortedGroupChildren, childIndex, varName, d, getVarInitialiserFn](const NeuronGroupInternal &, size_t groupIndex) - { - const auto *child = sortedGroupChildren.at(groupIndex).at(childIndex); - return std::invoke(getVarInitialiserFn, child).at(varName).getDerivedParams().at(d.name); - }); - } + // Reserve vector of child groups and create merged group objects based on vector of groups + childGroups.reserve(archetypeChildren); + for(size_t i = 0; i < sortedGroupChildren.size(); i++) { + childGroups.emplace_back(i, typeContext, backend, sortedGroupChildren[i]); } } - template - void addChildEGPs(const std::vector &egps, size_t childIndex, - const std::string &arrayPrefix, const std::string &prefix, - S getEGPSuffixFn) - { - for(const auto &e : egps) { - addField(e.type.resolve(getTypeContext()).createPointer(), e.name + prefix + std::to_string(childIndex), - [getEGPSuffixFn, childIndex, e, arrayPrefix](const NeuronGroupInternal&, size_t groupIndex) - { - return arrayPrefix + e.name + getEGPSuffixFn(groupIndex, childIndex); - }, - GroupMergedFieldType::DYNAMIC); - } - } + //! Is the var init parameter referenced? + bool isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const; template void updateChildParamHash(const std::vector> &sortedGroupChildren, @@ -931,21 +772,6 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged> m_SortedMergedInSyns; - std::vector> m_SortedMergedPreOutputOutSyns; - std::vector> m_SortedCurrentSources; }; //---------------------------------------------------------------------------- diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index 0af8b3c1e1..5e4a706c26 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -25,7 +25,10 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase // Public API //---------------------------------------------------------------------------- void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs); + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + + //! Update hash with child groups + void updateHash(boost::uuids::detail::sha1 &hash) const; //! Should the current source parameter be implemented heterogeneously? bool isParamHeterogeneous(const std::string ¶mName) const; @@ -49,13 +52,13 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase { public: InSynPSM(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + const std::vector> &groups); //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs); + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; //! Should the current source parameter be implemented heterogeneously? 
bool isParamHeterogeneous(const std::string ¶mName) const; @@ -85,7 +88,7 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase // Public API //---------------------------------------------------------------------------- void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs); + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; }; //---------------------------------------------------------------------------- @@ -102,7 +105,7 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase // Public API //---------------------------------------------------------------------------- void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike); + const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike) const; //! Should the current source parameter be implemented heterogeneously? bool isParamHeterogeneous(const std::string ¶mName) const; @@ -132,7 +135,7 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase // Public API //---------------------------------------------------------------------------- void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike); + const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike) const; //! Should the current source parameter be implemented heterogeneously? bool isParamHeterogeneous(const std::string ¶mName) const; @@ -154,24 +157,6 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - //! Should the incoming synapse weight update model parameter be implemented heterogeneously? - bool isInSynWUMParamHeterogeneous(size_t childIndex, const std::string ¶mName) const; - - //! Should the incoming synapse weight update model derived parameter be implemented heterogeneously? - bool isInSynWUMDerivedParamHeterogeneous(size_t childIndex, const std::string ¶mName) const; - - //! Should the outgoing synapse weight update model parameter be implemented heterogeneously? - bool isOutSynWUMParamHeterogeneous(size_t childIndex, const std::string ¶mName) const; - - //! Should the outgoing synapse weight update model derived parameter be implemented heterogeneously? - bool isOutSynWUMDerivedParamHeterogeneous(size_t childIndex, const std::string ¶mName) const; - - //! Get sorted vectors of incoming synapse groups with postsynaptic code belonging to archetype group - const std::vector &getSortedArchetypeInSynWithPostCode() const { return m_SortedInSynWithPostCode.front(); } - - //! Get sorted vectors of outgoing synapse groups with presynaptic code belonging to archetype group - const std::vector &getSortedArchetypeOutSynWithPreCode() const { return m_SortedOutSynWithPreCode.front(); } - //! 
Get hash digest used for detecting changes boost::uuids::detail::sha1::digest_type getHashDigest() const; @@ -203,36 +188,16 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //------------------------------------------------------------------------ // Private methods //------------------------------------------------------------------------ - //! Helper to generate merged struct fields for WU pre and post vars - void generateWUVar(const BackendBase &backend, const std::string &fieldPrefixStem, - const std::vector<std::vector<SynapseGroupInternal*>> &sortedSyn, - Models::Base::VarVec(WeightUpdateModels::Base::*getVars)(void) const, - bool(NeuronUpdateGroupMerged::*isParamHeterogeneous)(size_t, const std::string&) const, - bool(NeuronUpdateGroupMerged::*isDerivedParamHeterogeneous)(size_t, const std::string&) const, - const std::string&(SynapseGroupInternal::*getFusedVarSuffix)(void) const); - - //! Is the incoming synapse weight update model parameter referenced? - bool isInSynWUMParamReferenced(size_t childIndex, const std::string &paramName) const; - - //! Is the outgoing synapse weight update model parameter referenced? - bool isOutSynWUMParamReferenced(size_t childIndex, const std::string &paramName) const; - void addNeuronModelSubstitutions(Substitutions &substitution, const std::string &sourceSuffix = "", const std::string &destSuffix = "") const; - void generateWUVarUpdate(CodeStream &os, const Substitutions &popSubs, - const std::string &fieldPrefixStem, const std::string &sourceSuffix, - bool useLocalNeuronVars, unsigned int batchSize, - const std::vector<SynapseGroupInternal*> &archetypeSyn, - unsigned int(SynapseGroupInternal::*getDelaySteps)(void) const, - Models::Base::VarVec(WeightUpdateModels::Base::*getVars)(void) const, - std::string(WeightUpdateModels::Base::*getCode)(void) const, - bool(NeuronUpdateGroupMerged::*isParamHeterogeneous)(size_t, const std::string&) const, - bool(NeuronUpdateGroupMerged::*isDerivedParamHeterogeneous)(size_t, const std::string&) const) const; - //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - std::vector<std::vector<SynapseGroupInternal*>> m_SortedInSynWithPostCode; - std::vector<std::vector<SynapseGroupInternal*>> m_SortedOutSynWithPreCode; + std::vector<CurrentSource> m_CurrentSources; + std::vector<InSynPSM> m_InSynPSMs; + std::vector<OutSynPreOutput> m_OutSynPreOutput; + std::vector m_SortedInSynWithPostCode; + std::vector<InSynWUMPostCode> m_InSynWUMPostCode; + std::vector<OutSynWUMPreCode> m_OutSynWUMPreCode; }; } // namespace GeNN::CodeGenerator diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index a669d42a35..f367da4a54 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -126,81 +126,12 @@ bool NeuronGroupMergedBase::isVarInitDerivedParamHeterogeneous(const std::string [varName](const NeuronGroupInternal &sg){ return sg.getVarInitialisers().at(varName).getDerivedParams(); })); } //---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isCurrentSourceParamHeterogeneous(size_t childIndex, const std::string &paramName) const -{ - return (isCurrentSourceParamReferenced(childIndex, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedCurrentSources, - [](const CurrentSourceInternal *cs) { return cs->getParams(); })); -} -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isCurrentSourceDerivedParamHeterogeneous(size_t childIndex, const std::string &paramName) const -{ - return
(isCurrentSourceParamReferenced(childIndex, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedCurrentSources, - [](const CurrentSourceInternal *cs) { return cs->getDerivedParams(); })); - -} -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isCurrentSourceVarInitParamHeterogeneous(size_t childIndex, const std::string &varName, const std::string ¶mName) const -{ - return (isCurrentSourceVarInitParamReferenced(childIndex, varName, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedCurrentSources, - [varName](const CurrentSourceInternal *cs) { return cs->getVarInitialisers().at(varName).getParams(); })); -} -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isCurrentSourceVarInitDerivedParamHeterogeneous(size_t childIndex, const std::string &varName, const std::string ¶mName) const -{ - return (isCurrentSourceVarInitParamReferenced(childIndex, varName, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedCurrentSources, - [varName](const CurrentSourceInternal *cs) { return cs->getVarInitialisers().at(varName).getDerivedParams(); })); -} -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isPSMParamHeterogeneous(size_t childIndex, const std::string ¶mName) const -{ - return (isPSMParamReferenced(childIndex, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedMergedInSyns, - [](const SynapseGroupInternal *inSyn) { return inSyn->getPSParams(); })); -} -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isPSMDerivedParamHeterogeneous(size_t childIndex, const std::string ¶mName) const -{ - return (isPSMParamReferenced(childIndex, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedMergedInSyns, - [](const SynapseGroupInternal *inSyn) { return inSyn->getPSDerivedParams(); })); -} -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isPSMVarInitParamHeterogeneous(size_t childIndex, const std::string &varName, const std::string ¶mName) const -{ - return (isPSMVarInitParamReferenced(childIndex, varName, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedMergedInSyns, - [varName](const SynapseGroupInternal *inSyn){ return inSyn->getPSVarInitialisers().at(varName).getParams(); })); -} -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isPSMVarInitDerivedParamHeterogeneous(size_t childIndex, const std::string &varName, const std::string ¶mName) const -{ - return (isPSMVarInitParamReferenced(childIndex, varName, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedMergedInSyns, - [varName](const SynapseGroupInternal *inSyn){ return inSyn->getPSVarInitialisers().at(varName).getDerivedParams(); })); -} -//---------------------------------------------------------------------------- NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - bool init, const std::vector> &groups) + const std::vector> &groups) : GroupMerged(index, typeContext, groups) { using namespace Type; - // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group - 
orderNeuronGroupChildren(m_SortedMergedInSyns, &NeuronGroupInternal::getFusedPSMInSyn, - init ? &SynapseGroupInternal::getPSInitHashDigest : &SynapseGroupInternal::getPSHashDigest); - - // Build vector of vectors containing each child group's merged out syns with pre output, ordered to match those of the archetype group - orderNeuronGroupChildren(m_SortedMergedPreOutputOutSyns, &NeuronGroupInternal::getFusedPreOutputOutSyn, - init ? &SynapseGroupInternal::getPreOutputInitHashDigest : &SynapseGroupInternal::getPreOutputHashDigest); - - // Build vector of vectors containing each child group's current sources, ordered to match those of the archetype group - orderNeuronGroupChildren(m_SortedCurrentSources, &NeuronGroupInternal::getCurrentSources, - init ? &CurrentSourceInternal::getInitHashDigest : &CurrentSourceInternal::getHashDigest); - addField(Uint32, "numNeurons", [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); @@ -231,7 +162,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeConte } // If this backend initialises population RNGs on device and this group requires on for simulation - if(backend.isPopulationRNGRequired() && getArchetype().isSimRNGRequired() + /*if(backend.isPopulationRNGRequired() && getArchetype().isSimRNGRequired() && (!init || backend.isPopulationRNGInitialisedOnDevice())) { addPointerField(*backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); @@ -397,7 +328,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeConte }); } - } + }*/ } //---------------------------------------------------------------------------- void NeuronGroupMergedBase::updateBaseHash(bool init, boost::uuids::detail::sha1 &hash) const @@ -457,52 +388,6 @@ bool NeuronGroupMergedBase::isVarInitParamReferenced(const std::string &varName, const auto *varInitSnippet = getArchetype().getVarInitialisers().at(varName).getSnippet(); return isParamReferenced({varInitSnippet->getCode()}, paramName); } -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isCurrentSourceParamReferenced(size_t childIndex, const std::string ¶mName) const -{ - const auto *csm = getSortedArchetypeCurrentSources().at(childIndex)->getCurrentSourceModel(); - return isParamReferenced({csm->getInjectionCode()}, paramName); -} -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isCurrentSourceVarInitParamReferenced(size_t childIndex, const std::string &varName, const std::string ¶mName) const -{ - const auto *varInitSnippet = getSortedArchetypeCurrentSources().at(childIndex)->getVarInitialisers().at(varName).getSnippet(); - return isParamReferenced({varInitSnippet->getCode()}, paramName); -} -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isPSMParamReferenced(size_t childIndex, const std::string ¶mName) const -{ - const auto *psm = getSortedArchetypeMergedInSyns().at(childIndex)->getPSModel(); - return isParamReferenced({psm->getApplyInputCode(), psm->getDecayCode()}, paramName); -} -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isPSMVarInitParamReferenced(size_t childIndex, const std::string &varName, const std::string ¶mName) const -{ - const auto *varInitSnippet = getSortedArchetypeMergedInSyns().at(childIndex)->getPSVarInitialisers().at(varName).getSnippet(); - return 
isParamReferenced({varInitSnippet->getCode()}, paramName); -} -//---------------------------------------------------------------------------- -void NeuronGroupMergedBase::addMergedInSynPointerField(const Type::ResolvedType &type, const std::string &name, - size_t archetypeIndex, const std::string &prefix) -{ - assert(type.isValue()); - addField(type.createPointer(), name + std::to_string(archetypeIndex), - [prefix, archetypeIndex, this](const auto&, size_t groupIndex) - { - return prefix + m_SortedMergedInSyns.at(groupIndex).at(archetypeIndex)->getFusedPSVarSuffix(); - }); -} -//---------------------------------------------------------------------------- -void NeuronGroupMergedBase::addMergedPreOutputOutSynPointerField(const Type::ResolvedType &type, const std::string &name, - size_t archetypeIndex, const std::string &prefix) -{ - assert(type.isValue()); - addField(type.createPointer(), name + std::to_string(archetypeIndex), - [prefix, archetypeIndex, this](const auto&, size_t groupIndex) - { - return prefix + m_SortedMergedPreOutputOutSyns.at(groupIndex).at(archetypeIndex)->getFusedPreOutputSuffix(); - }); -} //---------------------------------------------------------------------------- // GeNN::CodeGenerator::SynapseGroupMergedBase diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 188255348d..d7518263c9 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -44,7 +44,7 @@ NeuronUpdateGroupMerged::CurrentSource::CurrentSource(size_t index, const Type:: } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { os << "// current source " << getIndex() << std::endl; @@ -84,6 +84,14 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend } } //---------------------------------------------------------------------------- +void NeuronUpdateGroupMerged::CurrentSource::updateHash(boost::uuids::detail::sha1 &hash) const +{ + updateParamHash(&CurrentSource::isParamReferenced, + [](const CurrentSourceInternal &g) { return g.getParams(); }, hash); + updateParamHash(&CurrentSource::isParamReferenced, + [](const CurrentSourceInternal &g) { return g.getDerivedParams(); }, hash); +} +//---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::CurrentSource::isParamHeterogeneous(const std::string ¶mName) const { return (isParamReferenced(paramName) && @@ -153,7 +161,7 @@ NeuronUpdateGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContex } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { const std::string suffix = "InSyn" + std::to_string(getIndex()); const auto *psm = getArchetype().getPSModel(); @@ -278,7 +286,7 @@ NeuronUpdateGroupMerged::OutSynPreOutput::OutSynPreOutput(size_t index, const Ty } //---------------------------------------------------------------------------- void 
NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { const std::string suffix = "OutSyn" + std::to_string(getIndex()); @@ -325,7 +333,7 @@ NeuronUpdateGroupMerged::InSynWUMPostCode::InSynWUMPostCode(size_t index, const } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike) + const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike) const { const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); @@ -441,7 +449,7 @@ NeuronUpdateGroupMerged::OutSynWUMPreCode::OutSynWUMPreCode(size_t index, const } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike) + const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike) const { const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); @@ -530,32 +538,38 @@ const std::string NeuronUpdateGroupMerged::name = "NeuronUpdate"; //---------------------------------------------------------------------------- NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: NeuronGroupMergedBase(index, typeContext, backend, false, groups) +: NeuronGroupMergedBase(index, typeContext, backend, groups) { using namespace Type; + // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group + orderNeuronGroupChildren(m_InSynPSMs, typeContext, backend, + &NeuronGroupInternal::getFusedPSMInSyn, + &SynapseGroupInternal::getPSHashDigest); + + // Build vector of vectors containing each child group's merged out syns with pre output, ordered to match those of the archetype group + orderNeuronGroupChildren(m_OutSynPreOutput, typeContext, backend, + &NeuronGroupInternal::getFusedPreOutputOutSyn, + SynapseGroupInternal::getPreOutputHashDigest); + + // Build vector of vectors containing each child group's current sources, ordered to match those of the archetype group + orderNeuronGroupChildren(m_CurrentSources, typeContext, backend, + &NeuronGroupInternal::getCurrentSources, + &CurrentSourceInternal::getHashDigest); + + // Build vector of vectors containing each child group's incoming synapse groups // with postsynaptic updates, ordered to match those of the archetype group - orderNeuronGroupChildren(m_SortedInSynWithPostCode, &NeuronGroupInternal::getFusedInSynWithPostCode, + orderNeuronGroupChildren(m_InSynWUMPostCode, typeContext, backend, + &NeuronGroupInternal::getFusedInSynWithPostCode, &SynapseGroupInternal::getWUPostHashDigest); // Build vector of vectors containing each child group's outgoing synapse groups // with presynaptic synaptic updates, ordered to match those of the archetype group - orderNeuronGroupChildren(m_SortedOutSynWithPreCode, &NeuronGroupInternal::getFusedOutSynWithPreCode, + orderNeuronGroupChildren(m_OutSynWUMPreCode, typeContext, 
backend, + &NeuronGroupInternal::getFusedOutSynWithPreCode, &SynapseGroupInternal::getWUPreHashDigest); - // Generate struct fields for incoming synapse groups with postsynaptic update code - generateWUVar(backend, "WUPost", m_SortedInSynWithPostCode, - &WeightUpdateModels::Base::getPostVars, &NeuronUpdateGroupMerged::isInSynWUMParamHeterogeneous, - &NeuronUpdateGroupMerged::isInSynWUMDerivedParamHeterogeneous, - &SynapseGroupInternal::getFusedWUPostVarSuffix); - - // Generate struct fields for outgoing synapse groups with presynaptic update code - generateWUVar(backend, "WUPre", m_SortedOutSynWithPreCode, - &WeightUpdateModels::Base::getPreVars, &NeuronUpdateGroupMerged::isOutSynWUMParamHeterogeneous, - &NeuronUpdateGroupMerged::isOutSynWUMDerivedParamHeterogeneous, - &SynapseGroupInternal::getFusedWUPreVarSuffix); - // Loop through neuron groups std::vector> eventThresholdSGs; for(const auto &g : getGroups()) { @@ -630,34 +644,6 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC } //---------------------------------------------------------------------------- -bool NeuronUpdateGroupMerged::isInSynWUMParamHeterogeneous(size_t childIndex, const std::string ¶mName) const -{ - return (isInSynWUMParamReferenced(childIndex, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedInSynWithPostCode, - [](const SynapseGroupInternal *s) { return s->getWUParams(); })); -} -//---------------------------------------------------------------------------- -bool NeuronUpdateGroupMerged::isInSynWUMDerivedParamHeterogeneous(size_t childIndex, const std::string ¶mName) const -{ - return (isInSynWUMParamReferenced(childIndex, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedInSynWithPostCode, - [](const SynapseGroupInternal *s) { return s->getWUDerivedParams(); })); -} -//---------------------------------------------------------------------------- -bool NeuronUpdateGroupMerged::isOutSynWUMParamHeterogeneous(size_t childIndex, const std::string ¶mName) const -{ - return (isOutSynWUMParamReferenced(childIndex, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedOutSynWithPreCode, - [](const SynapseGroupInternal *s) { return s->getWUParams(); })); -} -//---------------------------------------------------------------------------- -bool NeuronUpdateGroupMerged::isOutSynWUMDerivedParamHeterogeneous(size_t childIndex, const std::string ¶mName) const -{ - return (isOutSynWUMParamReferenced(childIndex, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedOutSynWithPreCode, - [](const SynapseGroupInternal *s) { return s->getWUDerivedParams(); })); -} -//---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type NeuronUpdateGroupMerged::getHashDigest() const { boost::uuids::detail::sha1 hash; @@ -770,146 +756,22 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C } // Loop through incoming synapse groups - for(size_t i = 0; i < getSortedArchetypeMergedInSyns().size(); i++) { + for(const auto &sg : m_InSynPSMs) { CodeStream::Scope b(os); - - const auto *sg = getSortedArchetypeMergedInSyns().at(i); - const auto *psm = sg->getPSModel(); - - os << "// pull inSyn values in a coalesced access" << std::endl; - os << "scalar linSyn = group->inSynInSyn" << i << "["; - os << getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; - - // If dendritic delay is required - if 
(sg->isDendriticDelayRequired()) { - // Get reference to dendritic delay buffer input for this timestep - os << backend.getPointerPrefix() << "scalar *denDelayFront = "; - os << "&group->denDelayInSyn" << i << "[(*group->denDelayPtrInSyn" << i << " * group->numNeurons) + "; - os << getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; - - // Add delayed input from buffer into inSyn - os << "linSyn += *denDelayFront;" << std::endl; - - // Zero delay buffer slot - os << "*denDelayFront = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; - } - - // Pull postsynaptic model variables in a coalesced access - for (const auto &v : psm->getVars()) { - if(v.access & VarAccessMode::READ_ONLY) { - os << "const "; - } - os << v.type.resolve(getTypeContext()).getName() << " lps" << v.name << " = group->" << v.name << "InSyn" << i << "["; - os << getVarIndex(batchSize, getVarAccessDuplication(v.access), neuronSubs["id"]) << "];" << std::endl; - } - - Substitutions inSynSubs(&neuronSubs); - inSynSubs.addVarSubstitution("inSyn", "linSyn"); - - // Allow synapse group's PS output var to override what Isyn points to - inSynSubs.addVarSubstitution("Isyn", sg->getPSTargetVar(), true); - inSynSubs.addVarNameSubstitution(psm->getVars(), "", "lps"); - - inSynSubs.addParamValueSubstitution(psm->getParamNames(), sg->getPSParams(), - [i, this](const std::string &p) { return isPSMParamHeterogeneous(i, p); }, - "", "group->", "InSyn" + std::to_string(i)); - inSynSubs.addVarValueSubstitution(psm->getDerivedParams(), sg->getPSDerivedParams(), - [i, this](const std::string &p) { return isPSMDerivedParamHeterogeneous(i, p); }, - "", "group->", "InSyn" + std::to_string(i)); - inSynSubs.addVarNameSubstitution(psm->getExtraGlobalParams(), "", "group->", "InSyn" + std::to_string(i)); - - // Apply substitutions to current converter code - std::string psCode = psm->getApplyInputCode(); - inSynSubs.applyCheckUnreplaced(psCode, "postSyntoCurrent : merged " + std::to_string(i)); - //psCode = ensureFtype(psCode, model.getPrecision()); - - // Apply substitutions to decay code - std::string pdCode = psm->getDecayCode(); - inSynSubs.applyCheckUnreplaced(pdCode, "decayCode : merged " + std::to_string(i)); - //pdCode = ensureFtype(pdCode, model.getPrecision()); - - if (!psm->getSupportCode().empty() && backend.supportsNamespace()) { - os << "using namespace " << modelMerged.getPostsynapticDynamicsSupportCodeNamespace(psm->getSupportCode()) << ";" << std::endl; - } - - if (!psm->getSupportCode().empty() && !backend.supportsNamespace()) { - psCode = disambiguateNamespaceFunction(psm->getSupportCode(), psCode, modelMerged.getPostsynapticDynamicsSupportCodeNamespace(psm->getSupportCode())); - pdCode = disambiguateNamespaceFunction(psm->getSupportCode(), pdCode, modelMerged.getPostsynapticDynamicsSupportCodeNamespace(psm->getSupportCode())); - } - - os << psCode << std::endl; - os << pdCode << std::endl; - - if (!psm->getSupportCode().empty()) { - os << CodeStream::CB(29) << " // namespace bracket closed" << std::endl; - } - - // Write back linSyn - os << "group->inSynInSyn" << i << "["; - os << getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, inSynSubs["id"]) << "] = linSyn;" << std::endl; - - // Copy any non-readonly postsynaptic model variables back to global state variables dd_V etc - for (const auto &v : psm->getVars()) { - if(v.access & VarAccessMode::READ_WRITE) { - os << "group->" << v.name << "InSyn" << i << "["; - os << getVarIndex(batchSize, getVarAccessDuplication(v.access), 
inSynSubs["id"]) << "]" << " = lps" << v.name << ";" << std::endl; - } - } + sg.generate(backend, os, *this, modelMerged, popSubs); } // Loop through outgoing synapse groups with presynaptic output - for(size_t i = 0; i < getSortedArchetypeMergedPreOutputOutSyns().size(); i++) { + for (const auto &sg : m_OutSynPreOutput) { CodeStream::Scope b(os); - const auto *sg = getSortedArchetypeMergedPreOutputOutSyns().at(i); - - os << sg->getPreTargetVar() << "+= "; - os << "group->revInSynOutSyn" << i << "["; - os << getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; - os << "group->revInSynOutSyn" << i << "["; - os << getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "]= 0.0;" << std::endl; + sg.generate(backend, os, *this, modelMerged, popSubs); } + // Loop through all of neuron group's current sources - for(size_t i = 0; i < getSortedArchetypeCurrentSources().size(); i++) { - const auto *cs = getSortedArchetypeCurrentSources().at(i); - - os << "// current source " << i << std::endl; + for (const auto &cs : m_CurrentSources) { CodeStream::Scope b(os); - - const auto *csm = cs->getCurrentSourceModel(); - - // Read current source variables into registers - for(const auto &v : csm->getVars()) { - if(v.access & VarAccessMode::READ_ONLY) { - os << "const "; - } - os << v.type.resolve(getTypeContext()).getName() << " lcs" << v.name << " = " << "group->" << v.name << "CS" << i << "["; - os << getVarIndex(batchSize, getVarAccessDuplication(v.access), popSubs["id"]) << "];" << std::endl; - } - - Substitutions currSourceSubs(&popSubs); - currSourceSubs.addFuncSubstitution("injectCurrent", 1, "Isyn += $(0)"); - currSourceSubs.addVarNameSubstitution(csm->getVars(), "", "lcs"); - currSourceSubs.addParamValueSubstitution(csm->getParamNames(), cs->getParams(), - [i, this](const std::string &p) { return isCurrentSourceParamHeterogeneous(i, p); }, - "", "group->", "CS" + std::to_string(i)); - currSourceSubs.addVarValueSubstitution(csm->getDerivedParams(), cs->getDerivedParams(), - [i, this](const std::string &p) { return isCurrentSourceDerivedParamHeterogeneous(i, p); }, - "", "group->", "CS" + std::to_string(i)); - currSourceSubs.addVarNameSubstitution(csm->getExtraGlobalParams(), "", "group->", "CS" + std::to_string(i)); - - std::string iCode = csm->getInjectionCode(); - currSourceSubs.applyCheckUnreplaced(iCode, "injectionCode : merged" + std::to_string(i)); - //iCode = ensureFtype(iCode, model.getPrecision()); - os << iCode << std::endl; - - // Write read/write variables back to global memory - for(const auto &v : csm->getVars()) { - if(v.access & VarAccessMode::READ_WRITE) { - os << "group->" << v.name << "CS" << i << "["; - os << getVarIndex(batchSize, getVarAccessDuplication(v.access), currSourceSubs["id"]) << "] = lcs" << v.name << ";" << std::endl; - } - } + cs.generate(backend, os, *this, modelMerged, popSubs); } if (!nm->getSupportCode().empty() && backend.supportsNamespace()) { @@ -951,19 +813,16 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C os << sCode << std::endl; // Generate var update for outgoing synaptic populations with presynaptic update code - generateWUVarUpdate(os, popSubs, "WUPre", "_pre", true, batchSize, - getSortedArchetypeOutSynWithPreCode(), &SynapseGroupInternal::getDelaySteps, - &WeightUpdateModels::Base::getPreVars, &WeightUpdateModels::Base::getPreDynamicsCode, - &NeuronUpdateGroupMerged::isOutSynWUMParamHeterogeneous, - 
&NeuronUpdateGroupMerged::isOutSynWUMDerivedParamHeterogeneous); - + for (const auto &sg : m_OutSynWUMPreCode) { + CodeStream::Scope b(os); + sg.generate(backend, os, *this, modelMerged, popSubs, true); + } // Generate var update for incoming synaptic populations with postsynaptic code - generateWUVarUpdate(os, popSubs, "WUPost", "_post", true, batchSize, - getSortedArchetypeInSynWithPostCode(), &SynapseGroupInternal::getBackPropDelaySteps, - &WeightUpdateModels::Base::getPostVars, &WeightUpdateModels::Base::getPostDynamicsCode, - &NeuronUpdateGroupMerged::isInSynWUMParamHeterogeneous, - &NeuronUpdateGroupMerged::isInSynWUMDerivedParamHeterogeneous); + for (const auto &sg : m_InSynWUMPostCode) { + CodeStream::Scope b(os); + sg.generate(backend, os, *this, modelMerged, popSubs, true); + } // look for spike type events first. if (getArchetype().isSpikeEventRequired()) { @@ -1147,23 +1006,19 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C } } //-------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { // Generate var update for outgoing synaptic populations with presynaptic update code - const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - generateWUVarUpdate(os, popSubs, "WUPre", "_pre", false, batchSize, - getSortedArchetypeOutSynWithPreCode(), &SynapseGroupInternal::getDelaySteps, - &WeightUpdateModels::Base::getPreVars, &WeightUpdateModels::Base::getPreSpikeCode, - &NeuronUpdateGroupMerged::isOutSynWUMParamHeterogeneous, - &NeuronUpdateGroupMerged::isOutSynWUMDerivedParamHeterogeneous); - + for (const auto &sg : m_OutSynWUMPreCode) { + CodeStream::Scope b(os); + sg.generate(backend, os, *this, modelMerged, popSubs, false); + } // Generate var update for incoming synaptic populations with postsynaptic code - generateWUVarUpdate(os, popSubs, "WUPost", "_post", false, batchSize, - getSortedArchetypeInSynWithPostCode(), &SynapseGroupInternal::getBackPropDelaySteps, - &WeightUpdateModels::Base::getPostVars, &WeightUpdateModels::Base::getPostSpikeCode, - &NeuronUpdateGroupMerged::isInSynWUMParamHeterogeneous, - &NeuronUpdateGroupMerged::isInSynWUMDerivedParamHeterogeneous); + for (const auto &sg : m_InSynWUMPostCode) { + CodeStream::Scope b(os); + sg.generate(backend, os, *this, modelMerged, popSubs, false); + } } //-------------------------------------------------------------------------- std::string NeuronUpdateGroupMerged::getVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const @@ -1216,60 +1071,6 @@ std::string NeuronUpdateGroupMerged::getWriteVarIndex(bool delay, unsigned int b } } //---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::generateWUVar(const BackendBase &backend, const std::string &fieldPrefixStem, - const std::vector<std::vector<SynapseGroupInternal*>> &sortedSyn, - Models::Base::VarVec (WeightUpdateModels::Base::*getVars)(void) const, - bool(NeuronUpdateGroupMerged::*isParamHeterogeneous)(size_t, const std::string&) const, - bool(NeuronUpdateGroupMerged::*isDerivedParamHeterogeneous)(size_t, const std::string&) const, - const std::string&(SynapseGroupInternal::*getFusedVarSuffix)(void) const) -{ - // Loop through synapse
groups - const auto &archetypeSyns = sortedSyn.front(); - for(size_t i = 0; i < archetypeSyns.size(); i++) { - const auto *sg = archetypeSyns.at(i); - - // Loop through variables - const auto vars = (sg->getWUModel()->*getVars)(); - for(size_t v = 0; v < vars.size(); v++) { - // Add pointers to state variable - const auto var = vars[v]; - addField(var.type.resolve(getTypeContext()).createPointer(), var.name + fieldPrefixStem + std::to_string(i), - [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto &, size_t groupIndex) - { - const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); - return backend.getDeviceVarPrefix() + var.name + varMergeSuffix; - }); - } - - // Add any heterogeneous parameters - addHeterogeneousChildParams(sg->getWUModel()->getParamNames(), sortedSyn, i, fieldPrefixStem, - isParamHeterogeneous, &SynapseGroupInternal::getWUParams); - - // Add any heterogeneous derived parameters - addHeterogeneousChildDerivedParams(sg->getWUModel()->getDerivedParams(), sortedSyn, i, fieldPrefixStem, - isDerivedParamHeterogeneous, &SynapseGroupInternal::getWUDerivedParams); - - // Add EGPs - addChildEGPs(sg->getWUModel()->getExtraGlobalParams(), i, backend.getDeviceVarPrefix(), fieldPrefixStem, - [&sortedSyn](size_t groupIndex, size_t childIndex) - { - return sortedSyn.at(groupIndex).at(childIndex)->getName(); - }); - } -} -//---------------------------------------------------------------------------- -bool NeuronUpdateGroupMerged::isInSynWUMParamReferenced(size_t childIndex, const std::string ¶mName) const -{ - const auto *wum = getSortedArchetypeInSynWithPostCode().at(childIndex)->getWUModel(); - return isParamReferenced({wum->getPostSpikeCode(), wum->getPostDynamicsCode()}, paramName); -} -//---------------------------------------------------------------------------- -bool NeuronUpdateGroupMerged::isOutSynWUMParamReferenced(size_t childIndex, const std::string ¶mName) const -{ - const auto *wum = getSortedArchetypeOutSynWithPreCode().at(childIndex)->getWUModel(); - return isParamReferenced({wum->getPreSpikeCode(), wum->getPreDynamicsCode()}, paramName); -} -//---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::addNeuronModelSubstitutions(Substitutions &substitution, const std::string &sourceSuffix, const std::string &destSuffix) const { const NeuronModels::Base *nm = getArchetype().getNeuronModel(); @@ -1282,75 +1083,3 @@ void NeuronUpdateGroupMerged::addNeuronModelSubstitutions(Substitutions &substit sourceSuffix, "group->"); substitution.addVarNameSubstitution(nm->getExtraGlobalParams(), sourceSuffix, "group->"); } -//-------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::generateWUVarUpdate(CodeStream &os, const Substitutions &popSubs, - const std::string &fieldPrefixStem, const std::string &sourceSuffix, - bool useLocalNeuronVars, unsigned int batchSize, - const std::vector &archetypeSyn, - unsigned int(SynapseGroupInternal::*getDelaySteps)(void) const, - Models::Base::VarVec(WeightUpdateModels::Base::*getVars)(void) const, - std::string(WeightUpdateModels::Base::*getCode)(void) const, - bool(NeuronUpdateGroupMerged::*isParamHeterogeneous)(size_t, const std::string&) const, - bool(NeuronUpdateGroupMerged::*isDerivedParamHeterogeneous)(size_t, const std::string&) const) const -{ - // Loop through synaptic populations - for(size_t i = 0; i < archetypeSyn.size(); i++) { - const SynapseGroupInternal *sg = archetypeSyn[i]; - - // If this 
code string isn't empty - std::string code = (sg->getWUModel()->*getCode)(); - if(!code.empty()) { - Substitutions subs(&popSubs); - CodeStream::Scope b(os); - - // Fetch variables from global memory - os << "// perform WUM update required for merged" << i << std::endl; - const auto vars = (sg->getWUModel()->*getVars)(); - const bool delayed = ((sg->*getDelaySteps)() != NO_DELAY); - for(const auto &v : vars) { - if(v.access & VarAccessMode::READ_ONLY) { - os << "const "; - } - os << v.type.resolve(getTypeContext()).getName() << " l" << v.name << " = group->" << v.name << fieldPrefixStem << i << "["; - os << getReadVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), subs["id"]) << "];" << std::endl; - } - - subs.addParamValueSubstitution(sg->getWUModel()->getParamNames(), sg->getWUParams(), - [i, isParamHeterogeneous , this](const std::string &p) { return (this->*isParamHeterogeneous)(i, p); }, - "", "group->", fieldPrefixStem + std::to_string(i)); - subs.addVarValueSubstitution(sg->getWUModel()->getDerivedParams(), sg->getWUDerivedParams(), - [i, isDerivedParamHeterogeneous, this](const std::string &p) { return (this->*isDerivedParamHeterogeneous)(i, p); }, - "", "group->", fieldPrefixStem + std::to_string(i)); - subs.addVarNameSubstitution(sg->getWUModel()->getExtraGlobalParams(), "", "group->", fieldPrefixStem + std::to_string(i)); - subs.addVarNameSubstitution(vars, "", "l"); - - neuronSubstitutionsInSynapticCode(subs, &getArchetype(), "", sourceSuffix, "", "", "", useLocalNeuronVars, - [this](const std::string &p) { return this->isParamHeterogeneous(p); }, - [this](const std::string &p) { return this->isDerivedParamHeterogeneous(p); }, - [&subs, batchSize, this](bool delay, VarAccessDuplication varDuplication) - { - return getReadVarIndex(delay, batchSize, varDuplication, subs["id"]); - }, - [&subs, batchSize, this](bool delay, VarAccessDuplication varDuplication) - { - return getReadVarIndex(delay, batchSize, varDuplication, subs["id"]); - }); - - // Perform standard substitutions - subs.applyCheckUnreplaced(code, "spikeCode : merged" + std::to_string(i)); - //code = ensureFtype(code, precision); - os << code; - - // Write back presynaptic variables into global memory - for(const auto &v : vars) { - // If state variables is read/write - meaning that it may have been updated - or it is delayed - - // meaning that it needs to be copied into next delay slot whatever - copy neuron state variables - // back to global state variables dd_V etc - if((v.access & VarAccessMode::READ_WRITE) || delayed) { - os << "group->" << v.name << fieldPrefixStem << i << "["; - os << getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), subs["id"]) << "] = l" << v.name << ";" << std::endl; - } - } - } - } -} From cf80bfa496f2950aa13d86c5b62d66dca38f2bfa Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 1 Jun 2023 21:02:18 +0100 Subject: [PATCH 188/725] hierarchical ``NeuronUpdateGroupMerged`` now compiles --- .../genn/genn/code_generator/groupMerged.h | 6 +- .../code_generator/neuronUpdateGroupMerged.h | 15 ++ src/genn/genn/code_generator/groupMerged.cc | 52 ------- .../code_generator/neuronUpdateGroupMerged.cc | 145 ++++++++++++------ 4 files changed, 111 insertions(+), 107 deletions(-) diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index e3cc08ed32..b40cafad70 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -628,8 +628,6 @@ class GENN_EXPORT 
NeuronGroupMergedBase : public GroupMerged> &groups); - void updateBaseHash(bool init, boost::uuids::detail::sha1 &hash) const; - template void orderNeuronGroupChildren(std::vector &childGroups, const Type::TypeContext &typeContext, const BackendBase &backend, G getVectorFunc, H getHashDigestFunc) const @@ -647,7 +645,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged &groupChildren = (g.get().*getVectorFunc)(); + const std::vector &groupChildren = (g.get().*getVectorFunc)(); assert(groupChildren.size() == archetypeChildren.size()); // Loop through children and add them and their digests to vector @@ -671,7 +669,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMergedgetVarInitialisers()) { - updateChildVarInitParamsHash( - m_SortedCurrentSources, c, v.first, &NeuronGroupMergedBase::isCurrentSourceVarInitParamReferenced, hash); - updateChildVarInitDerivedParamsHash( - m_SortedCurrentSources, c, v.first, &NeuronGroupMergedBase::isCurrentSourceVarInitParamReferenced, hash); - } - } - - // Loop through child merged insyns - for(size_t c = 0; c < getSortedArchetypeMergedInSyns().size(); c++) { - const auto *sg = getSortedArchetypeMergedInSyns().at(c); - - // Loop through variables and update hash with variable initialisation parameters and derived parameters - for(const auto &v : sg->getPSVarInitialisers()) { - updateChildVarInitParamsHash( - m_SortedMergedInSyns, c, v.first, &NeuronGroupMergedBase::isPSMVarInitParamReferenced, hash); - updateChildVarInitDerivedParamsHash( - m_SortedMergedInSyns, c, v.first, &NeuronGroupMergedBase::isPSMVarInitParamReferenced, hash); - } - } - } - else { - // Loop through child current sources - for(size_t i = 0; i < getSortedArchetypeCurrentSources().size(); i++) { - updateChildParamHash(m_SortedCurrentSources, i, &NeuronGroupMergedBase::isCurrentSourceParamReferenced, - &CurrentSourceInternal::getParams, hash); - updateChildDerivedParamHash(m_SortedCurrentSources, i, &NeuronGroupMergedBase::isCurrentSourceParamReferenced, - &CurrentSourceInternal::getDerivedParams, hash); - } - - // Loop through child merged insyns - for(size_t i = 0; i < getSortedArchetypeMergedInSyns().size(); i++) { - updateChildParamHash(m_SortedMergedInSyns, i, &NeuronGroupMergedBase::isPSMParamReferenced, - &SynapseGroupInternal::getPSParams, hash); - updateChildDerivedParamHash(m_SortedMergedInSyns, i, &NeuronGroupMergedBase::isPSMParamReferenced, - &SynapseGroupInternal::getPSDerivedParams, hash); - } - } -} -//---------------------------------------------------------------------------- bool NeuronGroupMergedBase::isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const { const auto *varInitSnippet = getArchetype().getVarInitialisers().at(varName).getSnippet(); diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index d7518263c9..7f9e03f7b9 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -252,6 +252,14 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Cod } } //---------------------------------------------------------------------------- +void NeuronUpdateGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &hash) const +{ + updateParamHash(&InSynPSM::isParamReferenced, + [](const SynapseGroupInternal &g) { return g.getPSParams(); }, hash); + updateParamHash(&InSynPSM::isParamReferenced, + [](const SynapseGroupInternal &g) { return 
g.getPSDerivedParams(); }, hash); +} +//---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::InSynPSM::isParamHeterogeneous(const std::string ¶mName) const { return (isParamReferenced(paramName) && @@ -395,6 +403,35 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back } } //---------------------------------------------------------------------------- +void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(CodeStream &os, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +{ + // If this group has a delay and no postsynaptic dynamics (which will already perform this copying) + const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); + if(getArchetype().getBackPropDelaySteps() != NO_DELAY && getArchetype().getWUModel()->getPostDynamicsCode().empty()) { + // Loop through variables and copy between read and write delay slots + for(const auto &v : getArchetype().getWUModel()->getPostVars()) { + if(v.access & VarAccessMode::READ_WRITE) { + os << "group->" << v.name << suffix << "["; + os << ng.getWriteVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), popSubs["id"]); + os << "] = "; + + os << "group->" << v.name << suffix << "["; + os << ng.getReadVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), popSubs["id"]); + os << "];" << std::endl; + } + } + } +} +//---------------------------------------------------------------------------- +void NeuronUpdateGroupMerged::InSynWUMPostCode::updateHash(boost::uuids::detail::sha1 &hash) const +{ + updateParamHash(&InSynWUMPostCode::isParamReferenced, + [](const SynapseGroupInternal &g) { return g.getWUParams(); }, hash); + updateParamHash(&InSynWUMPostCode::isParamReferenced, + [](const SynapseGroupInternal &g) { return g.getWUDerivedParams(); }, hash); +} +//---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::InSynWUMPostCode::isParamHeterogeneous(const std::string ¶mName) const { return (isParamReferenced(paramName) && @@ -511,6 +548,35 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back } } //---------------------------------------------------------------------------- +void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(CodeStream &os, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +{ + // If this group has a delay and no presynaptic dynamics (which will already perform this copying) + const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); + if(getArchetype().getDelaySteps() != NO_DELAY && getArchetype().getWUModel()->getPreDynamicsCode().empty()) { + // Loop through variables and copy between read and write delay slots + for(const auto &v : getArchetype().getWUModel()->getPreVars()) { + if(v.access & VarAccessMode::READ_WRITE) { + os << "group->" << v.name << suffix << "["; + os << ng.getWriteVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), popSubs["id"]); + os << "] = "; + + os << "group->" << v.name << suffix << "["; + os << ng.getReadVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), popSubs["id"]); + os << "];" << std::endl; + } + } + } +} +//---------------------------------------------------------------------------- +void 
NeuronUpdateGroupMerged::OutSynWUMPreCode::updateHash(boost::uuids::detail::sha1 &hash) const +{ + updateParamHash(&OutSynWUMPreCode::isParamReferenced, + [](const SynapseGroupInternal &g) { return g.getWUParams(); }, hash); + updateParamHash(&OutSynWUMPreCode::isParamReferenced, + [](const SynapseGroupInternal &g) { return g.getWUDerivedParams(); }, hash); +} +//---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::OutSynWUMPreCode::isParamHeterogeneous(const std::string &paramName) const { return (isParamReferenced(paramName) && @@ -550,7 +616,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC // Build vector of vectors containing each child group's merged out syns with pre output, ordered to match those of the archetype group orderNeuronGroupChildren(m_OutSynPreOutput, typeContext, backend, &NeuronGroupInternal::getFusedPreOutputOutSyn, - SynapseGroupInternal::getPreOutputHashDigest); + &SynapseGroupInternal::getPreOutputHashDigest); // Build vector of vectors containing each child group's current sources, ordered to match those of the archetype group orderNeuronGroupChildren(m_CurrentSources, typeContext, backend, @@ -648,8 +714,8 @@ boost::uuids::detail::sha1::digest_type NeuronUpdateGroupMerged::getHashDigest() { boost::uuids::detail::sha1 hash; - // Update hash with generic neuron group data - updateBaseHash(false, hash); + // Update hash with each group's neuron count + updateHash([](const NeuronGroupInternal &g) { return g.getNumNeurons(); }, hash); // Update hash with archetype's hash digest Utils::updateHash(getArchetype().getHashDigest(), hash); @@ -657,21 +723,19 @@ boost::uuids::detail::sha1::digest_type NeuronUpdateGroupMerged::getHashDigest() // Update hash with each group's parameters and derived parameters updateHash([](const NeuronGroupInternal &g) { return g.getParams(); }, hash); updateHash([](const NeuronGroupInternal &g) { return g.getDerivedParams(); }, hash); - - // Loop through child incoming synapse groups with postsynaptic update code - for(size_t i = 0; i < getSortedArchetypeInSynWithPostCode().size(); i++) { - updateChildParamHash(m_SortedInSynWithPostCode, i, &NeuronUpdateGroupMerged::isInSynWUMParamReferenced, - &SynapseGroupInternal::getWUParams, hash); - updateChildDerivedParamHash(m_SortedInSynWithPostCode, i, &NeuronUpdateGroupMerged::isInSynWUMParamReferenced, - &SynapseGroupInternal::getWUDerivedParams, hash); + + // Update hash with child groups + for (const auto &cs : m_CurrentSources) { + cs.updateHash(hash); } - - // Loop through child outgoing synapse groups with presynaptic update code - for(size_t i = 0; i < getSortedArchetypeOutSynWithPreCode().size(); i++) { - updateChildParamHash(m_SortedOutSynWithPreCode, i, &NeuronUpdateGroupMerged::isOutSynWUMParamReferenced, - &SynapseGroupInternal::getWUParams, hash); - updateChildDerivedParamHash( m_SortedOutSynWithPreCode, i, &NeuronUpdateGroupMerged::isOutSynWUMParamReferenced, - &SynapseGroupInternal::getWUDerivedParams, hash); + for(const auto &sg : m_InSynPSMs) { + sg.updateHash(hash); + } + for (const auto &sg : m_OutSynWUMPreCode) { + sg.updateHash(hash); + } + for (const auto &sg : m_InSynWUMPostCode) { + sg.updateHash(hash); } return hash.get_digest(); @@ -766,8 +830,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C CodeStream::Scope b(os); sg.generate(backend, os, *this, modelMerged, popSubs); } - - + // Loop through all of neuron group's current sources for (const auto &cs :
m_CurrentSources) { CodeStream::Scope b(os); @@ -929,20 +992,20 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C // Are there any outgoing synapse groups with presynaptic code // which have axonal delay and no presynaptic dynamics - const bool preVars = std::any_of(getSortedArchetypeOutSynWithPreCode().cbegin(), getSortedArchetypeOutSynWithPreCode().cend(), - [](const SynapseGroupInternal *sg) + const bool preVars = std::any_of(m_OutSynWUMPreCode.cbegin(), m_OutSynWUMPreCode.cend(), + [](const OutSynWUMPreCode &sg) { - return ((sg->getDelaySteps() != NO_DELAY) - && sg->getWUModel()->getPreDynamicsCode().empty()); + return ((sg.getArchetype().getDelaySteps() != NO_DELAY) + && sg.getArchetype().getWUModel()->getPreDynamicsCode().empty()); }); // Are there any incoming synapse groups with postsynaptic code // which have back-propagation delay and no postsynaptic dynamics - const bool postVars = std::any_of(getSortedArchetypeInSynWithPostCode().cbegin(), getSortedArchetypeInSynWithPostCode().cend(), - [](const SynapseGroupInternal *sg) + const bool postVars = std::any_of(m_InSynWUMPostCode.cbegin(), m_InSynWUMPostCode.cend(), + [](const auto &sg) { - return ((sg->getBackPropDelaySteps() != NO_DELAY) - && sg->getWUModel()->getPostDynamicsCode().empty()); + return ((sg.getArchetype().getBackPropDelaySteps() != NO_DELAY) + && sg.getArchetype().getWUModel()->getPostDynamicsCode().empty()); }); // If spike times, presynaptic variables or postsynaptic variables are required, add if clause @@ -961,33 +1024,13 @@ } // Loop through outgoing synapse groups with some sort of presynaptic code - for(size_t i = 0; i < getSortedArchetypeOutSynWithPreCode().size(); i++) { - const auto *sg = getSortedArchetypeOutSynWithPreCode().at(i); - // If this group has a delay and no presynaptic dynamics (which will already perform this copying) - if(sg->getDelaySteps() != NO_DELAY && sg->getWUModel()->getPreDynamicsCode().empty()) { - // Loop through variables and copy between read and write delay slots - for(const auto &v : sg->getWUModel()->getPreVars()) { - if(v.access & VarAccessMode::READ_WRITE) { - os << "group->" << v.name << "WUPre" << i << "[" << getWriteVarIndex(true, batchSize, getVarAccessDuplication(v.access), popSubs["id"]) << "] = "; - os << "group->" << v.name << "WUPre" << i << "[" << getReadVarIndex(true, batchSize, getVarAccessDuplication(v.access), popSubs["id"]) << "];" << std::endl; - } - } - } + for (const auto &sg : m_OutSynWUMPreCode) { + sg.genCopyDelayedVars(os, *this, modelMerged, popSubs); } // Loop through incoming synapse groups with some sort of postsynaptic code - for(size_t i = 0; i < getSortedArchetypeInSynWithPostCode().size(); i++) { - const auto *sg = getSortedArchetypeInSynWithPostCode().at(i); - // If this group has a delay and no postsynaptic dynamics (which will already perform this copying) - if(sg->getBackPropDelaySteps() != NO_DELAY && sg->getWUModel()->getPostDynamicsCode().empty()) { - // Loop through variables and copy between read and write delay slots - for(const auto &v : sg->getWUModel()->getPostVars()) { - if(v.access & VarAccessMode::READ_WRITE) { - os << "group->" << v.name << "WUPost" << i << "[" << getWriteVarIndex(true, batchSize, getVarAccessDuplication(v.access), popSubs["id"]) << "] = "; - os << "group->" << v.name << "WUPost" << i << "[" << getReadVarIndex(true, batchSize, getVarAccessDuplication(v.access), popSubs["id"]) << "];" << std::endl; - 
} - } - } + for (const auto &sg : m_InSynWUMPostCode) { + sg.genCopyDelayedVars(os, *this, modelMerged, popSubs); } } } From 3a185813b05511ab25fc9b64065af0ff9db75e88 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 2 Jun 2023 09:27:09 +0100 Subject: [PATCH 189/725] hierarchical ``NeuronInitGroupMerged`` also now compiles --- .../genn/genn/code_generator/groupMerged.h | 96 +----- .../genn/code_generator/initGroupMerged.h | 67 ++-- src/genn/genn/code_generator/groupMerged.cc | 169 --------- .../genn/code_generator/initGroupMerged.cc | 322 ++++++------------ .../code_generator/neuronUpdateGroupMerged.cc | 20 ++ 5 files changed, 157 insertions(+), 517 deletions(-) diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index b40cafad70..2bac5509ac 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -628,6 +628,9 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged> &groups); + //------------------------------------------------------------------------ + // Protected API + //------------------------------------------------------------------------ template void orderNeuronGroupChildren(std::vector &childGroups, const Type::TypeContext &typeContext, const BackendBase &backend, G getVectorFunc, H getHashDigestFunc) const @@ -677,99 +680,6 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged - void updateChildParamHash(const std::vector> &sortedGroupChildren, - size_t childIndex, R isChildParamReferencedFn, V getValueFn, - boost::uuids::detail::sha1 &hash) const - { - // Loop through parameters - const auto &archetypeParams = (sortedGroupChildren.front().at(childIndex)->*getValueFn)(); - for(const auto &p : archetypeParams) { - // If any of the code strings reference the parameter - // **TODO** std::invoke - if((static_cast(this)->*isChildParamReferencedFn)(childIndex, p.first)) { - // Loop through groups - for(size_t g = 0; g < getGroups().size(); g++) { - // Get child group - const auto *child = sortedGroupChildren.at(g).at(childIndex); - - // Update hash with parameter value - Utils::updateHash((child->*getValueFn)().at(p.first), hash); - } - } - } - } - - template - void updateChildDerivedParamHash(const std::vector> &sortedGroupChildren, - size_t childIndex, R isChildParamReferencedFn, V getValueFn, - boost::uuids::detail::sha1 &hash) const - { - // Loop through derived parameters - const auto &archetypeDerivedParams = (sortedGroupChildren.front().at(childIndex)->*getValueFn)(); - for(const auto &d : archetypeDerivedParams) { - // If any of the code strings reference the parameter - if((static_cast(this)->*isChildParamReferencedFn)(childIndex, d.first)) { - // Loop through groups - for(size_t g = 0; g < getGroups().size(); g++) { - // Get child group - const auto *child = sortedGroupChildren.at(g).at(childIndex); - - // Update hash with parameter value - Utils::updateHash((child->*getValueFn)().at(d.first), hash); - } - } - } - } - - template - void updateChildVarInitParamsHash(const std::vector> &sortedGroupChildren, - size_t childIndex, const std::string &varName, R isChildParamReferencedFn, - boost::uuids::detail::sha1 &hash) const - { - // Loop through parameters - const auto &archetypeVarInit = A(*sortedGroupChildren.front().at(childIndex)).getInitialisers(); - const auto &archetypeParams = archetypeVarInit.at(varName).getParams(); - for(const auto &p : archetypeParams) { - // If parameter is referenced - 
if((static_cast(this)->*isChildParamReferencedFn)(childIndex, varName, p.first)) { - // Loop through groups - for(size_t g = 0; g < getGroups().size(); g++) { - // Get child group and its variable initialisers - const auto *child = sortedGroupChildren.at(g).at(childIndex); - const auto &varInit = A(*child).getInitialisers(); - - // Update hash with parameter value - Utils::updateHash(varInit.at(varName).getParams().at(p.first), hash); - } - } - } - } - - template - void updateChildVarInitDerivedParamsHash(const std::vector> &sortedGroupChildren, - size_t childIndex, const std::string &varName, R isChildParamReferencedFn, - boost::uuids::detail::sha1 &hash) const - { - // Loop through derived parameters - const auto &archetypeVarInit = A(*sortedGroupChildren.front().at(childIndex)).getInitialisers(); - const auto &archetypeDerivedParams = archetypeVarInit.at(varName).getDerivedParams(); - for(const auto &d : archetypeDerivedParams) { - // If parameter is referenced - if((static_cast(this)->*isChildParamReferencedFn)(childIndex, varName, d.first)) { - // Loop through groups - for(size_t g = 0; g < getGroups().size(); g++) { - // Get child group and its variable initialisers - const auto *child = sortedGroupChildren.at(g).at(childIndex); - const auto &varInit = A(*child).getInitialisers(); - - // Update hash with parameter value - Utils::updateHash(varInit.at(varName).getDerivedParams().at(d.first), hash); - } - } - } - } }; //---------------------------------------------------------------------------- diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index 997506df48..bf3a582702 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -25,8 +25,11 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase // Public API //---------------------------------------------------------------------------- void generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs); + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + //! Update hash with child groups + void updateHash(boost::uuids::detail::sha1 &hash) const; + private: //---------------------------------------------------------------------------- // Private methods @@ -55,7 +58,10 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase // Public API //---------------------------------------------------------------------------- void generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs); + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + + //! 
Update hash with child groups + void updateHash(boost::uuids::detail::sha1 &hash) const; private: //---------------------------------------------------------------------------- @@ -85,7 +91,7 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase // Public API //---------------------------------------------------------------------------- void generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs); + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; }; //---------------------------------------------------------------------------- @@ -102,7 +108,10 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase // Public API //---------------------------------------------------------------------------- void generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs); + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + + //! Update hash with child groups + void updateHash(boost::uuids::detail::sha1 &hash) const; private: //---------------------------------------------------------------------------- @@ -131,9 +140,12 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs); - + void generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + + //! Update hash with child groups + void updateHash(boost::uuids::detail::sha1 &hash) const; + private: //---------------------------------------------------------------------------- // Private methods @@ -168,18 +180,6 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase void generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; - //! Should the incoming synapse weight update model var init parameter be implemented heterogeneously? - bool isInSynWUMVarInitParamHeterogeneous(size_t childIndex, const std::string &varName, const std::string ¶mName) const; - - //! Should the incoming synapse weight update model var init derived parameter be implemented heterogeneously? - bool isInSynWUMVarInitDerivedParamHeterogeneous(size_t childIndex, const std::string &varName, const std::string ¶mName) const; - - //! Should the outgoing synapse weight update model var init parameter be implemented heterogeneously? - bool isOutSynWUMVarInitParamHeterogeneous(size_t childIndex, const std::string &varName, const std::string ¶mName) const; - - //! Should the outgoing synapse weight update model var init derived parameter be implemented heterogeneously? 
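The header changes above all land on the same shape: every child kind (current source, PSM in-syn, pre-output out-syn and the WUM pre/post variable groups) carries its own const generate() and updateHash(), and the parent merged group simply walks one vector per kind and forwards to each child. What follows is a minimal self-contained sketch of that delegation pattern, not GeNN's real API: Sha1Stub is a hypothetical stand-in for boost::uuids::detail::sha1, and ChildGroup collapses the five real child classes into one illustrative type.

#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Stand-in for boost::uuids::detail::sha1, purely so the sketch compiles on its own
struct Sha1Stub
{
    void update(const std::string &s) { digest ^= std::hash<std::string>{}(s) + 0x9e3779b9 + (digest << 6) + (digest >> 2); }
    std::size_t digest = 0;
};

// One child group: in the real code this is CurrentSource, InSynPSM, InSynWUMPostVars etc.,
// each of which hashes its own var-init parameters against its archetype
class ChildGroup
{
public:
    explicit ChildGroup(std::string archetypeDigest) : m_ArchetypeDigest(std::move(archetypeDigest)) {}

    void updateHash(Sha1Stub &hash) const { hash.update(m_ArchetypeDigest); }

private:
    std::string m_ArchetypeDigest;
};

// The parent merged group owns one vector per child kind and
// folds every child into a single digest
class ParentGroupMerged
{
public:
    std::size_t getHashDigest() const
    {
        Sha1Stub hash;
        for(const auto &cs : m_CurrentSources) {
            cs.updateHash(hash);
        }
        for(const auto &sg : m_InSynWUMPostVars) {
            sg.updateHash(hash);
        }
        for(const auto &sg : m_OutSynWUMPreVars) {
            sg.updateHash(hash);
        }
        return hash.digest;
    }

private:
    std::vector<ChildGroup> m_CurrentSources;
    std::vector<ChildGroup> m_InSynWUMPostVars;
    std::vector<ChildGroup> m_OutSynWUMPreVars;
};

Because the per-kind loops are textually near-identical, it is easy to iterate the same vector twice and silently drop one child kind from the digest; keeping the member name in each loop next to a comment naming the kind makes such slips visible in review.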
- bool isOutSynWUMVarInitDerivedParamHeterogeneous(size_t childIndex, const std::string &varName, const std::string ¶mName) const; - //---------------------------------------------------------------------------- // Static constants //---------------------------------------------------------------------------- @@ -189,21 +189,6 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //------------------------------------------------------------------------ // Private methods //------------------------------------------------------------------------ - //! Helper to generate merged struct fields for WU pre and post vars - void generateWUVar(const BackendBase &backend, const std::string &fieldPrefixStem, - const std::vector> &sortedSyn, - Models::Base::VarVec(WeightUpdateModels::Base::*getVars)(void) const, - const std::unordered_map&(SynapseGroupInternal::*getVarInitialiserFn)(void) const, - bool(NeuronInitGroupMerged::*isParamHeterogeneousFn)(size_t, const std::string&, const std::string&) const, - bool(NeuronInitGroupMerged::*isDerivedParamHeterogeneousFn)(size_t, const std::string&, const std::string&) const, - const std::string&(SynapseGroupInternal::*getFusedVarSuffix)(void) const); - - //! Is the incoming synapse weight update model var init parameter referenced? - bool isInSynWUMVarInitParamReferenced(size_t childIndex, const std::string &varName, const std::string ¶mName) const; - - //! Is the outgoing synapse weight update model var init parameter referenced? - bool isOutSynWUMVarInitParamReferenced(size_t childIndex, const std::string &varName, const std::string ¶mName) const; - void genInitSpikeCount(CodeStream &os, const BackendBase &backend, const Substitutions &popSubs, bool spikeEvent, unsigned int batchSize) const; @@ -213,17 +198,15 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase void genInitSpikeTime(CodeStream &os, const BackendBase &backend, const Substitutions &popSubs, const std::string &varName, unsigned int batchSize) const; - //! Get sorted vectors of incoming synapse groups with postsynaptic variables belonging to archetype group - const std::vector &getSortedArchetypeInSynWithPostVars() const { return m_SortedInSynWithPostVars.front(); } - - //! 
Get sorted vectors of outgoing synapse groups with presynaptic variables belonging to archetype group - const std::vector &getSortedArchetypeOutSynWithPreVars() const { return m_SortedOutSynWithPreVars.front(); } - //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - std::vector> m_SortedInSynWithPostVars; - std::vector> m_SortedOutSynWithPreVars; + std::vector m_CurrentSources; + std::vector m_InSynPSMs; + std::vector m_OutSynPreOutput; + std::vector m_SortedInSynWithPostCode; + std::vector m_InSynWUMPostVars; + std::vector m_OutSynWUMPreVars; }; diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index 4d6d242d4e..67bcd89536 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -160,175 +160,6 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeConte if(getArchetype().isPrevSpikeEventTimeRequired()) { addPointerField(getTimeType(), "prevSET", backend.getDeviceVarPrefix() + "prevSET"); } - - // If this backend initialises population RNGs on device and this group requires on for simulation - /*if(backend.isPopulationRNGRequired() && getArchetype().isSimRNGRequired() - && (!init || backend.isPopulationRNGInitialisedOnDevice())) - { - addPointerField(*backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); - } - - // Loop through variables - const NeuronModels::Base *nm = getArchetype().getNeuronModel(); - const auto vars = nm->getVars(); - const auto &varInit = getArchetype().getVarInitialisers(); - for(const auto &var : vars) { - // If we're not initialising or if there is initialization code for this variable - if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addPointerField(var.type, var.name, - backend.getDeviceVarPrefix() + var.name); - } - - // If we're initializing, add any var init EGPs to structure - if(init) { - addEGPs(varInit.at(var.name).getSnippet()->getExtraGlobalParams(), backend.getDeviceVarPrefix(), var.name); - } - } - - // If we're generating a struct for initialization - if(init) { - // Add heterogeneous var init parameters - addHeterogeneousVarInitParams( - &NeuronGroupMergedBase::isVarInitParamHeterogeneous); - - addHeterogeneousVarInitDerivedParams( - &NeuronGroupMergedBase::isVarInitDerivedParamHeterogeneous); - } - // Otherwise - else { - addEGPs(nm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); - - // Add heterogeneous neuron model parameters - addHeterogeneousParams( - getArchetype().getNeuronModel()->getParamNames(), "", - [](const NeuronGroupInternal &ng) { return ng.getParams(); }, - &NeuronGroupMergedBase::isParamHeterogeneous); - - // Add heterogeneous neuron model derived parameters - addHeterogeneousDerivedParams( - getArchetype().getNeuronModel()->getDerivedParams(), "", - [](const NeuronGroupInternal &ng) { return ng.getDerivedParams(); }, - &NeuronGroupMergedBase::isDerivedParamHeterogeneous); - } - - // Loop through merged synaptic inputs to archetypical neuron group (0) in sorted order - for(size_t i = 0; i < getSortedArchetypeMergedInSyns().size(); i++) { - const SynapseGroupInternal *sg = getSortedArchetypeMergedInSyns().at(i); - - // Add pointer to insyn - addMergedInSynPointerField(getScalarType(), "inSynInSyn", i, backend.getDeviceVarPrefix() + "inSyn"); - - // Add pointer to dendritic delay buffer if required - if(sg->isDendriticDelayRequired()) { - 
addMergedInSynPointerField(getScalarType(), "denDelayInSyn", i, backend.getDeviceVarPrefix() + "denDelay"); - addMergedInSynPointerField(Uint32, "denDelayPtrInSyn", i, backend.getScalarAddressPrefix() + "denDelayPtr"); - } - - // Loop through variables - const auto &varInit = sg->getPSVarInitialisers(); - for(const auto &var : sg->getPSModel()->getVars()) { - // Add pointers to state variable - if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addMergedInSynPointerField(var.type.resolve(getTypeContext()), var.name + "InSyn", i, - backend.getDeviceVarPrefix() + var.name); - } - - // If we're generating an initialization structure, also add any heterogeneous parameters, derived parameters or extra global parameters required for initializers - if(init) { - const auto *varInitSnippet = varInit.at(var.name).getSnippet(); - addHeterogeneousChildVarInitParams(varInitSnippet->getParamNames(), m_SortedMergedInSyns, i, var.name, "InSyn", - &NeuronGroupMergedBase::isPSMVarInitParamHeterogeneous, &SynapseGroupInternal::getPSVarInitialisers); - addHeterogeneousChildVarInitDerivedParams(varInitSnippet->getDerivedParams(), m_SortedMergedInSyns, i, var.name, "InSyn", - &NeuronGroupMergedBase::isPSMVarInitDerivedParamHeterogeneous, &SynapseGroupInternal::getPSVarInitialisers); - addChildEGPs(varInitSnippet->getExtraGlobalParams(), i, backend.getDeviceVarPrefix(), var.name + "InSyn", - [var, this](size_t groupIndex, size_t childIndex) - { - return var.name + m_SortedMergedInSyns.at(groupIndex).at(childIndex)->getFusedPSVarSuffix(); - }); - } - } - - if(!init) { - // Add any heterogeneous postsynaptic model parameters - const auto paramNames = sg->getPSModel()->getParamNames(); - addHeterogeneousChildParams(paramNames, m_SortedMergedInSyns, i, "InSyn", - &NeuronGroupMergedBase::isPSMParamHeterogeneous, - &SynapseGroupInternal::getPSParams); - - // Add any heterogeneous postsynaptic mode derived parameters - const auto derivedParams = sg->getPSModel()->getDerivedParams(); - addHeterogeneousChildDerivedParams(derivedParams, m_SortedMergedInSyns, i, "InSyn", - &NeuronGroupMergedBase::isPSMDerivedParamHeterogeneous, - &SynapseGroupInternal::getPSDerivedParams); - - // Add EGPs - addChildEGPs(sg->getPSModel()->getExtraGlobalParams(), i, backend.getDeviceVarPrefix(), "InSyn", - [this](size_t groupIndex, size_t childIndex) - { - return m_SortedMergedInSyns.at(groupIndex).at(childIndex)->getFusedPSVarSuffix(); - }); - } - } - - // Loop through merged output synapses with presynaptic output of archetypical neuron group (0) in sorted order - for(size_t i = 0; i < getSortedArchetypeMergedPreOutputOutSyns().size(); i++) { - // Add pointer to revInSyn - addMergedPreOutputOutSynPointerField(getScalarType(), "revInSynOutSyn", i, backend.getDeviceVarPrefix() + "revInSyn"); - } - - // Loop through current sources to archetypical neuron group in sorted order - for(size_t i = 0; i < getSortedArchetypeCurrentSources().size(); i++) { - const auto *cs = getSortedArchetypeCurrentSources().at(i); - - // Loop through variables - const auto &varInit = cs->getVarInitialisers(); - for(const auto &var : cs->getCurrentSourceModel()->getVars()) { - // Add pointers to state variable - if(!init || !varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(var.type.resolve(getTypeContext()).createPointer(), var.name + "CS" + std::to_string(i), - [&backend, i, var, this](const auto&, size_t groupIndex) - { - return backend.getDeviceVarPrefix() + var.name + m_SortedCurrentSources.at(groupIndex).at(i)->getName(); - 
}); - } - - // If we're generating an initialization structure, also add any heterogeneous parameters, derived parameters or extra global parameters required for initializers - if(init) { - const auto *varInitSnippet = varInit.at(var.name).getSnippet(); - addHeterogeneousChildVarInitParams(varInitSnippet->getParamNames(), m_SortedCurrentSources, i, var.name, "CS", - &NeuronGroupMergedBase::isCurrentSourceVarInitParamHeterogeneous, &CurrentSourceInternal::getVarInitialisers); - addHeterogeneousChildVarInitDerivedParams(varInitSnippet->getDerivedParams(), m_SortedCurrentSources, i, var.name, "CS", - &NeuronGroupMergedBase::isCurrentSourceVarInitDerivedParamHeterogeneous, &CurrentSourceInternal::getVarInitialisers); - addChildEGPs(varInitSnippet->getExtraGlobalParams(), i, backend.getDeviceVarPrefix(), var.name + "CS", - [var, this](size_t groupIndex, size_t childIndex) - { - return var.name + m_SortedCurrentSources.at(groupIndex).at(childIndex)->getName(); - }); - } - } - - if(!init) { - // Add any heterogeneous current source parameters - const auto paramNames = cs->getCurrentSourceModel()->getParamNames(); - addHeterogeneousChildParams(paramNames, m_SortedCurrentSources, i, "CS", - &NeuronGroupMergedBase::isCurrentSourceParamHeterogeneous, - &CurrentSourceInternal::getParams); - - // Add any heterogeneous current source derived parameters - const auto derivedParams = cs->getCurrentSourceModel()->getDerivedParams(); - addHeterogeneousChildDerivedParams(derivedParams, m_SortedCurrentSources, i, "CS", - &NeuronGroupMergedBase::isCurrentSourceDerivedParamHeterogeneous, - &CurrentSourceInternal::getDerivedParams); - - // Add EGPs - addChildEGPs(cs->getCurrentSourceModel()->getExtraGlobalParams(), i, backend.getDeviceVarPrefix(), "CS", - [this](size_t groupIndex, size_t childIndex) - { - return m_SortedCurrentSources.at(groupIndex).at(childIndex)->getName(); - }); - - } - }*/ } //---------------------------------------------------------------------------- bool NeuronGroupMergedBase::isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index df119b5504..350a491f94 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -218,7 +218,7 @@ NeuronInitGroupMerged::CurrentSource::CurrentSource(size_t index, const Type::Ty } //---------------------------------------------------------------------------- void NeuronInitGroupMerged::CurrentSource::generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { const std::string suffix = "CS" + std::to_string(getIndex()); @@ -228,6 +228,12 @@ void NeuronInitGroupMerged::CurrentSource::generate(const BackendBase &backend, [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); } //---------------------------------------------------------------------------- +void NeuronInitGroupMerged::CurrentSource::updateHash(boost::uuids::detail::sha1 &hash) const +{ + updateVarInitParamHash(&CurrentSource::isVarInitParamReferenced, hash); + updateVarInitDerivedParamHash(&CurrentSource::isVarInitParamReferenced, hash); +} +//---------------------------------------------------------------------------- bool 
NeuronInitGroupMerged::CurrentSource::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const { return (isVarInitParamReferenced(varName, paramName) && @@ -299,7 +305,7 @@ NeuronInitGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContext } //---------------------------------------------------------------------------- void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { const std::string suffix = "InSyn" + std::to_string(getIndex()); @@ -340,6 +346,12 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, CodeS [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); } //---------------------------------------------------------------------------- +void NeuronInitGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &hash) const +{ + updateVarInitParamHash(&InSynPSM::isVarInitParamReferenced, hash); + updateVarInitDerivedParamHash(&InSynPSM::isVarInitParamReferenced, hash); +} +//---------------------------------------------------------------------------- bool NeuronInitGroupMerged::InSynPSM::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const { return (isVarInitParamReferenced(varName, paramName) && @@ -374,7 +386,7 @@ NeuronInitGroupMerged::OutSynPreOutput::OutSynPreOutput(size_t index, const Type } //---------------------------------------------------------------------------- void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { const std::string suffix = "OutSyn" + std::to_string(getIndex()); @@ -425,7 +437,7 @@ NeuronInitGroupMerged::InSynWUMPostVars::InSynWUMPostVars(size_t index, const Ty } //---------------------------------------------------------------------------- void NeuronInitGroupMerged::InSynWUMPostVars::generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); @@ -436,6 +448,12 @@ void NeuronInitGroupMerged::InSynWUMPostVars::generate(const BackendBase &backen [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); } //---------------------------------------------------------------------------- +void NeuronInitGroupMerged::InSynWUMPostVars::updateHash(boost::uuids::detail::sha1 &hash) const +{ + updateVarInitParamHash(&InSynWUMPostVars::isVarInitParamReferenced, hash); + updateVarInitDerivedParamHash(&InSynWUMPostVars::isVarInitParamReferenced, hash); +} +//---------------------------------------------------------------------------- bool NeuronInitGroupMerged::InSynWUMPostVars::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const { return (isVarInitParamReferenced(varName, paramName) && @@ -493,8 +511,8 @@ NeuronInitGroupMerged::OutSynWUMPreVars::OutSynWUMPreVars(size_t index, const Ty } } //---------------------------------------------------------------------------- -void 
NeuronInitGroupMerged::OutSynWUMPreVars::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) +void NeuronInitGroupMerged::OutSynWUMPreVars::generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, + const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); @@ -505,6 +523,12 @@ void NeuronInitGroupMerged::OutSynWUMPreVars::generate(const BackendBase &backen [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); } //---------------------------------------------------------------------------- +void NeuronInitGroupMerged::OutSynWUMPreVars::updateHash(boost::uuids::detail::sha1 &hash) const +{ + updateVarInitParamHash(&OutSynWUMPreVars::isVarInitParamReferenced, hash); + updateVarInitDerivedParamHash(&OutSynWUMPreVars::isVarInitParamReferenced, hash); +} +//---------------------------------------------------------------------------- bool NeuronInitGroupMerged::OutSynWUMPreVars::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const { return (isVarInitParamReferenced(varName, paramName) && @@ -532,40 +556,72 @@ const std::string NeuronInitGroupMerged::name = "NeuronInit"; //---------------------------------------------------------------------------- NeuronInitGroupMerged::NeuronInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: NeuronGroupMergedBase(index, typeContext, backend, true, groups) +: NeuronGroupMergedBase(index, typeContext, backend, groups) { - // Build vector of vectors containing each child group's incoming - // synapse groups, ordered to match those of the archetype group - orderNeuronGroupChildren(m_SortedInSynWithPostVars, &NeuronGroupInternal::getFusedInSynWithPostVars, + // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group + orderNeuronGroupChildren(m_InSynPSMs, typeContext, backend, + &NeuronGroupInternal::getFusedPSMInSyn, + &SynapseGroupInternal::getPSInitHashDigest ); + + // Build vector of vectors containing each child group's merged out syns with pre output, ordered to match those of the archetype group + orderNeuronGroupChildren(m_OutSynPreOutput, typeContext, backend, + &NeuronGroupInternal::getFusedPreOutputOutSyn, + &SynapseGroupInternal::getPreOutputInitHashDigest ); + + // Build vector of vectors containing each child group's current sources, ordered to match those of the archetype group + orderNeuronGroupChildren(m_CurrentSources, typeContext, backend, + &NeuronGroupInternal::getCurrentSources, + &CurrentSourceInternal::getInitHashDigest ); + + + // Build vector of vectors containing each child group's incoming synapse groups + // with postsynaptic weight update model variable, ordered to match those of the archetype group + orderNeuronGroupChildren(m_InSynWUMPostVars, typeContext, backend, + &NeuronGroupInternal::getFusedInSynWithPostVars, &SynapseGroupInternal::getWUPostInitHashDigest); - // Build vector of vectors containing each child group's outgoing - // synapse groups, ordered to match those of the archetype group - orderNeuronGroupChildren(m_SortedOutSynWithPreVars, &NeuronGroupInternal::getFusedOutSynWithPreVars, + // Build vector of vectors containing each child group's outgoing synapse groups + // with presynaptic weight 
update model variables, ordered to match those of the archetype group + orderNeuronGroupChildren(m_OutSynWUMPreVars, typeContext, backend, + &NeuronGroupInternal::getFusedOutSynWithPreVars, &SynapseGroupInternal::getWUPreInitHashDigest); - // Generate struct fields for incoming synapse groups with postsynaptic variables - generateWUVar(backend, "WUPost", m_SortedInSynWithPostVars, - &WeightUpdateModels::Base::getPostVars, &SynapseGroupInternal::getWUPostVarInitialisers, - &NeuronInitGroupMerged::isInSynWUMVarInitParamHeterogeneous, - &NeuronInitGroupMerged::isInSynWUMVarInitDerivedParamHeterogeneous, - &SynapseGroupInternal::getFusedWUPostVarSuffix); + if(backend.isPopulationRNGRequired() && getArchetype().isSimRNGRequired() + && backend.isPopulationRNGInitialisedOnDevice()) + { + addPointerField(*backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); + } + + // Loop through variables + const NeuronModels::Base *nm = getArchetype().getNeuronModel(); + const auto vars = nm->getVars(); + const auto &varInit = getArchetype().getVarInitialisers(); + for(const auto &var : vars) { + // If there is initialization code for this variable + if(!varInit.at(var.name).getSnippet()->getCode().empty()) { + addPointerField(var.type, var.name, + backend.getDeviceVarPrefix() + var.name); + } + + // Add any var init EGPs to structure + addEGPs(varInit.at(var.name).getSnippet()->getExtraGlobalParams(), backend.getDeviceVarPrefix(), var.name); + } + + // Add heterogeneous var init parameters + addHeterogeneousVarInitParams( + &NeuronGroupMergedBase::isVarInitParamHeterogeneous); - // Generate struct fields for outgoing synapse groups - generateWUVar(backend, "WUPre", m_SortedOutSynWithPreVars, - &WeightUpdateModels::Base::getPreVars, &SynapseGroupInternal::getWUPreVarInitialisers, - &NeuronInitGroupMerged::isOutSynWUMVarInitParamHeterogeneous, - &NeuronInitGroupMerged::isOutSynWUMVarInitDerivedParamHeterogeneous, - &SynapseGroupInternal::getFusedWUPreVarSuffix); + addHeterogeneousVarInitDerivedParams( + &NeuronGroupMergedBase::isVarInitDerivedParamHeterogeneous); } //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type NeuronInitGroupMerged::getHashDigest() const { boost::uuids::detail::sha1 hash; - // Update hash with generic neuron group data - updateBaseHash(true, hash); + // Update hash with each group's neuron count + updateHash([](const NeuronGroupInternal &g) { return g.getNumNeurons(); }, hash); // Update hash with archetype's hash digest Utils::updateHash(getArchetype().getInitHashDigest(), hash); @@ -574,30 +630,18 @@ boost::uuids::detail::sha1::digest_type NeuronInitGroupMerged::getHashDigest() c updateVarInitParamHash(&NeuronInitGroupMerged::isVarInitParamReferenced, hash); updateVarInitDerivedParamHash(&NeuronInitGroupMerged::isVarInitParamReferenced, hash); - // Loop through child incoming synapse groups with postsynaptic variables - for(size_t c = 0; c < getSortedArchetypeInSynWithPostVars().size(); c++) { - const auto *sg = getSortedArchetypeInSynWithPostVars().at(c); - - // Loop through variables and update hash with variable initialisation parameters and derived parameters - for(const auto &v : sg->getWUPostVarInitialisers()) { - updateChildVarInitParamsHash( - m_SortedInSynWithPostVars, c, v.first, &NeuronInitGroupMerged::isInSynWUMVarInitParamReferenced, hash); - updateChildVarInitDerivedParamsHash( - m_SortedInSynWithPostVars, c, v.first, 
&NeuronInitGroupMerged::isInSynWUMVarInitParamReferenced, hash); - } + // Update hash with child groups + for (const auto &cs : m_CurrentSources) { + cs.updateHash(hash); } - - // Loop through child outgoing synapse groups with presynaptic variables - for(size_t c = 0; c < getSortedArchetypeOutSynWithPreVars().size(); c++) { - const auto *sg = getSortedArchetypeOutSynWithPreVars().at(c); - - // Loop through variables and update hash with variable initialisation parameters and derived parameters - for(const auto &v : sg->getWUPreVarInitialisers()) { - updateChildVarInitParamsHash( - m_SortedOutSynWithPreVars, c, v.first, &NeuronInitGroupMerged::isOutSynWUMVarInitParamReferenced, hash); - updateChildVarInitDerivedParamsHash( - m_SortedOutSynWithPreVars, c, v.first, &NeuronInitGroupMerged::isOutSynWUMVarInitParamReferenced, hash); - } + for(const auto &sg : m_InSynPSMs) { + sg.updateHash(hash); + } + for (const auto &sg : m_OutSynWUMPreVars) { + sg.updateHash(hash); + } + for (const auto &sg : m_InSynWUMPostVars) { + sg.updateHash(hash); } return hash.get_digest(); @@ -651,178 +695,30 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, CodeStream [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); - // Loop through incoming synaptic populations - for(size_t i = 0; i < getSortedArchetypeMergedInSyns().size(); i++) { - CodeStream::Scope b(os); - - const auto *sg = getSortedArchetypeMergedInSyns().at(i); - - // Zero InSyn - backend.genVariableInit(os, "group->numNeurons", "id", popSubs, - [&model, &modelMerged, i] (CodeStream &os, Substitutions &varSubs) - { - genVariableFill(os, "inSynInSyn" + std::to_string(i), modelMerged.scalarExpr(0.0), - varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, model.getBatchSize()); - - }); - - // If dendritic delays are required - if(sg->isDendriticDelayRequired()) { - // Zero dendritic delay buffer - backend.genVariableInit(os, "group->numNeurons", "id", popSubs, - [&model, &modelMerged, sg, i](CodeStream &os, Substitutions &varSubs) - { - genVariableFill(os, "denDelayInSyn" + std::to_string(i), modelMerged.scalarExpr(0.0), - varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, model.getBatchSize(), - true, sg->getMaxDendriticDelayTimesteps()); - }); - - // Zero dendritic delay pointer - backend.genPopVariableInit(os, popSubs, - [i](CodeStream &os, Substitutions &) - { - os << "*group->denDelayPtrInSyn" << i << " = 0;" << std::endl; - }); - } - - // **TODO** adaptor - genInitNeuronVarCode(os, modelMerged, backend, popSubs, sg->getPSModel()->getVars(), sg->getPSVarInitialisers(), - "InSyn" + std::to_string(i), "numNeurons", i, model.getBatchSize(), - [i, this](const std::string &v, const std::string &p) { return isPSMVarInitParamHeterogeneous(i, v, p); }, - [i, this](const std::string &v, const std::string &p) { return isPSMVarInitDerivedParamHeterogeneous(i, v, p); }); + // Loop through all of neuron group's current sources + for (const auto &cs : m_CurrentSources) { + cs.generate(backend, os, *this, modelMerged, popSubs); } - // Loop through incoming synaptic populations with postsynaptic variables - // **NOTE** number of delay slots is based on the target neuron (for simplicity) but whether delay is required is based on the synapse group - for(size_t i = 0; i < getSortedArchetypeInSynWithPostVars().size(); i++) { - const auto *sg = 
getSortedArchetypeInSynWithPostVars().at(i); - // **TODO** adaptor - genInitNeuronVarCode(os, modelMerged, backend, popSubs, sg->getWUModel()->getPostVars(), sg->getWUPostVarInitialisers(), - "WUPost" + std::to_string(i), "numNeurons", sg->getTrgNeuronGroup()->getNumDelaySlots(), i, model.getBatchSize(), - [&sg](const std::string&){ return (sg->getBackPropDelaySteps() != NO_DELAY); }, - [i, this](const std::string &v, const std::string &p) { return isInSynWUMVarInitParamHeterogeneous(i, v, p); }, - [i, this](const std::string &v, const std::string &p) { return isInSynWUMVarInitDerivedParamHeterogeneous(i, v, p); }); -} - - // Loop through outgoing synaptic populations with presynaptic variables - // **NOTE** number of delay slots is based on the source neuron (for simplicity) but whether delay is required is based on the synapse group - for(size_t i = 0; i < getSortedArchetypeOutSynWithPreVars().size(); i++) { - const auto *sg = getSortedArchetypeOutSynWithPreVars().at(i); - // **TODO** adaptor - genInitNeuronVarCode(os, modelMerged, backend, popSubs, sg->getWUModel()->getPreVars(), sg->getWUPreVarInitialisers(), - "WUPre" + std::to_string(i), "numNeurons", sg->getSrcNeuronGroup()->getNumDelaySlots(), i, model.getBatchSize(), - [&sg](const std::string&){ return (sg->getDelaySteps() != NO_DELAY); }, - [i, this](const std::string &v, const std::string &p) { return isOutSynWUMVarInitParamHeterogeneous(i, v, p); }, - [i, this](const std::string &v, const std::string &p) { return isOutSynWUMVarInitDerivedParamHeterogeneous(i, v, p); }); + for(const auto &sg : m_InSynPSMs) { + sg.generate(backend, os, *this, modelMerged, popSubs); } - // Loop through outgoing synaptic populations with presynaptic output - for(size_t i = 0; i < getSortedArchetypeMergedPreOutputOutSyns().size(); i++) { - // Zero revInSynOutSyn - backend.genVariableInit(os, "group->numNeurons", "id", popSubs, - [&model, &modelMerged, i] (CodeStream &os, Substitutions &varSubs) - { - genVariableFill(os, "revInSynOutSyn" + std::to_string(i), modelMerged.scalarExpr(0.0), - varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, model.getBatchSize()); - }); + // Loop through outgoing synapse groups with presynaptic output + for (const auto &sg : m_OutSynPreOutput) { + sg.generate(backend, os, *this, modelMerged, popSubs); } - - // Loop through current sources - os << "// current source variables" << std::endl; - for(size_t i = 0; i < getSortedArchetypeCurrentSources().size(); i++) { - const auto *cs = getSortedArchetypeCurrentSources().at(i); - // **TODO** adaptor - genInitNeuronVarCode(os, modelMerged, backend, popSubs, cs->getCurrentSourceModel()->getVars(), cs->getVarInitialisers(), - "CS" + std::to_string(i), "numNeurons", i, model.getBatchSize(), - [i, this](const std::string &v, const std::string &p) { return isCurrentSourceVarInitParamHeterogeneous(i, v, p); }, - [i, this](const std::string &v, const std::string &p) { return isCurrentSourceVarInitDerivedParamHeterogeneous(i, v, p); }); + + + for (const auto &sg : m_OutSynWUMPreVars) { + sg.generate(backend, os, *this, modelMerged, popSubs); } -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::isInSynWUMVarInitParamHeterogeneous(size_t childIndex, const std::string &varName, const std::string ¶mName) const -{ - return (isInSynWUMVarInitParamReferenced(childIndex, varName, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedInSynWithPostVars, - [&varName](const SynapseGroupInternal 
*s) { return s->getWUPostVarInitialisers().at(varName).getParams(); })); -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::isInSynWUMVarInitDerivedParamHeterogeneous(size_t childIndex, const std::string &varName, const std::string ¶mName) const -{ - return (isInSynWUMVarInitParamReferenced(childIndex, varName, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedInSynWithPostVars, - [&varName](const SynapseGroupInternal *s) { return s->getWUPostVarInitialisers().at(varName).getDerivedParams(); })); -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::isOutSynWUMVarInitParamHeterogeneous(size_t childIndex, const std::string &varName, const std::string ¶mName) const -{ - return (isOutSynWUMVarInitParamReferenced(childIndex, varName, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedOutSynWithPreVars, - [&varName](const SynapseGroupInternal *s) { return s->getWUPreVarInitialisers().at(varName).getParams(); })); -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::isOutSynWUMVarInitDerivedParamHeterogeneous(size_t childIndex, const std::string &varName, const std::string ¶mName) const -{ - return (isOutSynWUMVarInitParamReferenced(childIndex, varName, paramName) && - isChildParamValueHeterogeneous(childIndex, paramName, m_SortedOutSynWithPreVars, - [&varName](const SynapseGroupInternal *s) { return s->getWUPreVarInitialisers().at(varName).getDerivedParams(); })); -} -//---------------------------------------------------------------------------- -void NeuronInitGroupMerged::generateWUVar(const BackendBase &backend, - const std::string &fieldPrefixStem, - const std::vector> &sortedSyn, - Models::Base::VarVec(WeightUpdateModels::Base::*getVars)(void) const, - const std::unordered_map &(SynapseGroupInternal::*getVarInitialiserFn)(void) const, - bool(NeuronInitGroupMerged::*isParamHeterogeneousFn)(size_t, const std::string&, const std::string&) const, - bool(NeuronInitGroupMerged::*isDerivedParamHeterogeneousFn)(size_t, const std::string&, const std::string&) const, - const std::string&(SynapseGroupInternal::*getFusedVarSuffix)(void) const) -{ - using namespace Type; - - // Loop through synapse groups - const auto &archetypeSyns = sortedSyn.front(); - for(size_t i = 0; i < archetypeSyns.size(); i++) { - const auto *sg = archetypeSyns.at(i); - - // Loop through variables - const auto vars = (sg->getWUModel()->*getVars)(); - const auto &varInit = (sg->*getVarInitialiserFn)(); - for(const auto &var : vars) { - // Add pointers to state variable - if(!varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(var.type.resolve(getTypeContext()).createPointer(), var.name + fieldPrefixStem + std::to_string(i), - [i, var, &backend, &sortedSyn, getFusedVarSuffix](const auto&, size_t groupIndex) - { - const std::string &varMergeSuffix = (sortedSyn.at(groupIndex).at(i)->*getFusedVarSuffix)(); - return backend.getDeviceVarPrefix() + var.name + varMergeSuffix; - }); - } - // Also add any heterogeneous, derived or extra global parameters required for initializers - const auto *varInitSnippet = varInit.at(var.name).getSnippet(); - addHeterogeneousChildVarInitParams(varInitSnippet->getParamNames(), sortedSyn, i, var.name, fieldPrefixStem, - isParamHeterogeneousFn, getVarInitialiserFn); - addHeterogeneousChildVarInitDerivedParams(varInitSnippet->getDerivedParams(), sortedSyn, i, 
var.name, fieldPrefixStem, - isDerivedParamHeterogeneousFn, getVarInitialiserFn); - addChildEGPs(varInitSnippet->getExtraGlobalParams(), i, backend.getDeviceVarPrefix(), var.name + fieldPrefixStem, - [var, &sortedSyn](size_t groupIndex, size_t childIndex) - { - return var.name + sortedSyn.at(groupIndex).at(childIndex)->getName(); - }); - } + // Generate var update for incoming synaptic populations with postsynaptic code + for (const auto &sg : m_InSynWUMPostVars) { + sg.generate(backend, os, *this, modelMerged, popSubs); } } -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::isInSynWUMVarInitParamReferenced(size_t childIndex, const std::string &varName, const std::string &paramName) const -{ - const auto *varInitSnippet = getSortedArchetypeInSynWithPostVars().at(childIndex)->getWUPostVarInitialisers().at(varName).getSnippet(); - return isParamReferenced({varInitSnippet->getCode()}, paramName); -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::isOutSynWUMVarInitParamReferenced(size_t childIndex, const std::string &varName, const std::string &paramName) const -{ - const auto *varInitSnippet = getSortedArchetypeOutSynWithPreVars().at(childIndex)->getWUPreVarInitialisers().at(varName).getSnippet(); - return isParamReferenced({varInitSnippet->getCode()}, paramName); -} //-------------------------------------------------------------------------- void NeuronInitGroupMerged::genInitSpikeCount(CodeStream &os, const BackendBase &backend, const Substitutions &popSubs, bool spikeEvent, unsigned int batchSize) const diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 7f9e03f7b9..6148610332 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -636,6 +636,26 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC &NeuronGroupInternal::getFusedOutSynWithPreCode, &SynapseGroupInternal::getWUPreHashDigest); + if(backend.isPopulationRNGRequired() && getArchetype().isSimRNGRequired()) { + addPointerField(*backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); + } + + // Add variables and extra global parameters + addVars(getArchetype().getNeuronModel()->getVars(), backend.getDeviceVarPrefix()); + addEGPs(getArchetype().getNeuronModel()->getExtraGlobalParams(), backend.getDeviceVarPrefix()); + + // Add heterogeneous neuron model parameters + addHeterogeneousParams( + getArchetype().getNeuronModel()->getParamNames(), "", + [](const NeuronGroupInternal &ng) { return ng.getParams(); }, + &NeuronGroupMergedBase::isParamHeterogeneous); + + // Add heterogeneous neuron model derived parameters + addHeterogeneousDerivedParams( + getArchetype().getNeuronModel()->getDerivedParams(), "", + [](const NeuronGroupInternal &ng) { return ng.getDerivedParams(); }, + &NeuronGroupMergedBase::isDerivedParamHeterogeneous); + // Loop through neuron groups std::vector> eventThresholdSGs; for(const auto &g : getGroups()) { From be529fd7bd970e2aaf6f6ed7ae8a626fcbb04582 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 2 Jun 2023 09:39:14 +0100 Subject: [PATCH 190/725] added accessors --- .../genn/code_generator/initGroupMerged.h | 17 +++++--- .../code_generator/neuronUpdateGroupMerged.h | 17 +++++--- include/genn/genn/neuronGroup.h | 4 +- .../genn/code_generator/initGroupMerged.cc | 28 ++++++------- 
.../code_generator/neuronUpdateGroupMerged.cc | 40 +++++++++---------- src/genn/genn/neuronGroup.cc | 6 +-- 6 files changed, 61 insertions(+), 51 deletions(-) diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index bf3a582702..50273a1fff 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -180,6 +180,12 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase void generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + const std::vector &getMergedCurrentSourceGroups() const { return m_MergedCurrentSourceGroups; } + const std::vector &getMergedInSynPSMGroups() const { return m_MergedInSynPSMGroups; } + const std::vector &getMergedOutSynPreOutputGroups() const { return m_MergedOutSynPreOutputGroups; } + const std::vector &getMergedInSynWUMPostVarGroups() const { return m_MergedInSynWUMPostVarGroups; } + const std::vector &getMergedOutSynWUMPreVarGroups() const { return m_MergedOutSynWUMPreVarGroups; } + //---------------------------------------------------------------------------- // Static constants //---------------------------------------------------------------------------- @@ -201,12 +207,11 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - std::vector m_CurrentSources; - std::vector m_InSynPSMs; - std::vector m_OutSynPreOutput; - std::vector m_SortedInSynWithPostCode; - std::vector m_InSynWUMPostVars; - std::vector m_OutSynWUMPreVars; + std::vector m_MergedCurrentSourceGroups; + std::vector m_MergedInSynPSMGroups; + std::vector m_MergedOutSynPreOutputGroups; + std::vector m_MergedInSynWUMPostVarGroups; + std::vector m_MergedOutSynWUMPreVarGroups; }; diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index ea9b9a782c..d6071574d5 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -194,6 +194,12 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase std::string getReadVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; std::string getWriteVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; + const std::vector &getMergedCurrentSourceGroups() const { return m_MergedCurrentSourceGroups; } + const std::vector &getMergedInSynPSMGroups() const { return m_MergedInSynPSMGroups; } + const std::vector &getMergedOutSynPreOutputGroups() const { return m_MergedOutSynPreOutputGroups; } + const std::vector &getMergedInSynWUMPostCodeGroups() const { return m_MergedInSynWUMPostCodeGroups; } + const std::vector &getMergedOutSynWUMPreCodeGroups() const { return m_MergedOutSynWUMPreCodeGroups; } + //---------------------------------------------------------------------------- // Static constants //---------------------------------------------------------------------------- @@ -208,11 +214,10 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //------------------------------------------------------------------------ // Members 
//------------------------------------------------------------------------ - std::vector m_CurrentSources; - std::vector m_InSynPSMs; - std::vector m_OutSynPreOutput; - std::vector m_SortedInSynWithPostCode; - std::vector m_InSynWUMPostCode; - std::vector m_OutSynWUMPreCode; + std::vector m_MergedCurrentSourceGroups; + std::vector m_MergedInSynPSMGroups; + std::vector m_MergedOutSynPreOutputGroups; + std::vector m_MergedInSynWUMPostCodeGroups; + std::vector m_MergedOutSynWUMPreCodeGroups; }; } // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/neuronGroup.h b/include/genn/genn/neuronGroup.h index 30a17fd695..8425f65077 100644 --- a/include/genn/genn/neuronGroup.h +++ b/include/genn/genn/neuronGroup.h @@ -229,7 +229,7 @@ class GENN_EXPORT NeuronGroup const std::vector &getFusedPreOutputOutSyn() const { return m_FusedPreOutputOutSyn; } //! Gets pointers to all current sources which provide input to this neuron group - const std::vector &getCurrentSources() const { return m_CurrentSources; } + const std::vector &getCurrentSources() const { return m_MergedCurrentSourceGroups; } const std::unordered_map &getDerivedParams() const{ return m_DerivedParams; } @@ -290,7 +290,7 @@ class GENN_EXPORT NeuronGroup std::vector m_FusedPreOutputOutSyn; std::set m_SpikeEventCondition; unsigned int m_NumDelaySlots; - std::vector m_CurrentSources; + std::vector m_MergedCurrentSourceGroups; //! Vector specifying which variables require queues std::vector m_VarQueueRequired; diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 350a491f94..1b1b00c476 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -559,30 +559,30 @@ NeuronInitGroupMerged::NeuronInitGroupMerged(size_t index, const Type::TypeConte : NeuronGroupMergedBase(index, typeContext, backend, groups) { // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group - orderNeuronGroupChildren(m_InSynPSMs, typeContext, backend, + orderNeuronGroupChildren(m_MergedInSynPSMGroups, typeContext, backend, &NeuronGroupInternal::getFusedPSMInSyn, &SynapseGroupInternal::getPSInitHashDigest ); // Build vector of vectors containing each child group's merged out syns with pre output, ordered to match those of the archetype group - orderNeuronGroupChildren(m_OutSynPreOutput, typeContext, backend, + orderNeuronGroupChildren(m_MergedOutSynPreOutputGroups, typeContext, backend, &NeuronGroupInternal::getFusedPreOutputOutSyn, &SynapseGroupInternal::getPreOutputInitHashDigest ); // Build vector of vectors containing each child group's current sources, ordered to match those of the archetype group - orderNeuronGroupChildren(m_CurrentSources, typeContext, backend, + orderNeuronGroupChildren(m_MergedCurrentSourceGroups, typeContext, backend, &NeuronGroupInternal::getCurrentSources, &CurrentSourceInternal::getInitHashDigest ); // Build vector of vectors containing each child group's incoming synapse groups // with postsynaptic weight update model variable, ordered to match those of the archetype group - orderNeuronGroupChildren(m_InSynWUMPostVars, typeContext, backend, + orderNeuronGroupChildren(m_MergedInSynWUMPostVarGroups, typeContext, backend, &NeuronGroupInternal::getFusedInSynWithPostVars, &SynapseGroupInternal::getWUPostInitHashDigest); // Build vector of vectors containing each child group's outgoing synapse groups // with presynaptic weight update model variables, ordered to 
match those of the archetype group - orderNeuronGroupChildren(m_OutSynWUMPreVars, typeContext, backend, + orderNeuronGroupChildren(m_MergedOutSynWUMPreVarGroups, typeContext, backend, &NeuronGroupInternal::getFusedOutSynWithPreVars, &SynapseGroupInternal::getWUPreInitHashDigest); @@ -631,16 +631,16 @@ boost::uuids::detail::sha1::digest_type NeuronInitGroupMerged::getHashDigest() c updateVarInitDerivedParamHash(&NeuronInitGroupMerged::isVarInitParamReferenced, hash); // Update hash with child groups - for (const auto &cs : m_CurrentSources) { + for (const auto &cs : getMergedCurrentSourceGroups()) { cs.updateHash(hash); } - for(const auto &sg : m_InSynPSMs) { + for(const auto &sg : getMergedInSynPSMGroups()) { sg.updateHash(hash); } - for (const auto &sg : m_OutSynWUMPreVars) { + for (const auto &sg : getMergedOutSynWUMPreVarGroups()) { sg.updateHash(hash); } - for (const auto &sg : m_InSynWUMPostVars) { + for (const auto &sg : getMergedInSynWUMPostVarGroups()) { sg.updateHash(hash); } @@ -696,26 +696,26 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, CodeStream [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); // Loop through all of neuron group's current sources - for (const auto &cs : m_CurrentSources) { + for (const auto &cs : getMergedCurrentSourceGroups()) { cs.generate(backend, os, *this, modelMerged, popSubs); } - for(const auto &sg : m_InSynPSMs) { + for(const auto &sg : getMergedInSynPSMGroups()) { sg.generate(backend, os, *this, modelMerged, popSubs); } // Loop through outgoing synapse groups with presynaptic output - for (const auto &sg : m_OutSynPreOutput) { + for (const auto &sg : getMergedOutSynPreOutputGroups()) { sg.generate(backend, os, *this, modelMerged, popSubs); } - for (const auto &sg : m_OutSynWUMPreVars) { + for (const auto &sg : getMergedOutSynWUMPreVarGroups()) { sg.generate(backend, os, *this, modelMerged, popSubs); } // Generate var update for incoming synaptic populations with postsynaptic code - for (const auto &sg : m_InSynWUMPostVars) { + for (const auto &sg : getMergedInSynWUMPostVarGroups()) { sg.generate(backend, os, *this, modelMerged, popSubs); } } diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 6148610332..7b38e09cf7 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -609,30 +609,30 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC using namespace Type; // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group - orderNeuronGroupChildren(m_InSynPSMs, typeContext, backend, + orderNeuronGroupChildren(m_MergedInSynPSMGroups, typeContext, backend, &NeuronGroupInternal::getFusedPSMInSyn, &SynapseGroupInternal::getPSHashDigest); // Build vector of vectors containing each child group's merged out syns with pre output, ordered to match those of the archetype group - orderNeuronGroupChildren(m_OutSynPreOutput, typeContext, backend, + orderNeuronGroupChildren(m_MergedOutSynPreOutputGroups, typeContext, backend, &NeuronGroupInternal::getFusedPreOutputOutSyn, &SynapseGroupInternal::getPreOutputHashDigest); // Build vector of vectors containing each child group's current sources, ordered to match those of the archetype group - orderNeuronGroupChildren(m_CurrentSources, typeContext, backend, + 
orderNeuronGroupChildren(m_MergedCurrentSourceGroups, typeContext, backend, &NeuronGroupInternal::getCurrentSources, &CurrentSourceInternal::getHashDigest); // Build vector of vectors containing each child group's incoming synapse groups // with postsynaptic updates, ordered to match those of the archetype group - orderNeuronGroupChildren(m_InSynWUMPostCode, typeContext, backend, + orderNeuronGroupChildren(m_MergedInSynWUMPostCodeGroups, typeContext, backend, &NeuronGroupInternal::getFusedInSynWithPostCode, &SynapseGroupInternal::getWUPostHashDigest); // Build vector of vectors containing each child group's outgoing synapse groups // with presynaptic synaptic updates, ordered to match those of the archetype group - orderNeuronGroupChildren(m_OutSynWUMPreCode, typeContext, backend, + orderNeuronGroupChildren(m_MergedOutSynWUMPreCodeGroups, typeContext, backend, &NeuronGroupInternal::getFusedOutSynWithPreCode, &SynapseGroupInternal::getWUPreHashDigest); @@ -745,16 +745,16 @@ boost::uuids::detail::sha1::digest_type NeuronUpdateGroupMerged::getHashDigest() updateHash([](const NeuronGroupInternal &g) { return g.getDerivedParams(); }, hash); // Update hash with child groups - for (const auto &cs : m_CurrentSources) { + for (const auto &cs : getMergedCurrentSourceGroups()) { cs.updateHash(hash); } - for(const auto &sg : m_InSynPSMs) { + for(const auto &sg : getMergedInSynPSMGroups()) { sg.updateHash(hash); } - for (const auto &sg : m_OutSynWUMPreCode) { + for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { sg.updateHash(hash); } - for (const auto &sg : m_OutSynWUMPreCode) { + for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { sg.updateHash(hash); } @@ -840,19 +840,19 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C } // Loop through incoming synapse groups - for(const auto &sg : m_InSynPSMs) { + for(const auto &sg : getMergedInSynPSMGroups()) { CodeStream::Scope b(os); sg.generate(backend, os, *this, modelMerged, popSubs); } // Loop through outgoing synapse groups with presynaptic output - for (const auto &sg : m_OutSynPreOutput) { + for (const auto &sg : getMergedOutSynPreOutputGroups()) { CodeStream::Scope b(os); sg.generate(backend, os, *this, modelMerged, popSubs); } // Loop through all of neuron group's current sources - for (const auto &cs : m_CurrentSources) { + for (const auto &cs : getMergedCurrentSourceGroups()) { CodeStream::Scope b(os); cs.generate(backend, os, *this, modelMerged, popSubs); } @@ -896,13 +896,13 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C os << sCode << std::endl; // Generate var update for outgoing synaptic populations with presynaptic update code - for (const auto &sg : m_OutSynWUMPreCode) { + for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { CodeStream::Scope b(os); sg.generate(backend, os, *this, modelMerged, popSubs, true); } // Generate var update for incoming synaptic populations with postsynaptic code - for (const auto &sg : m_OutSynWUMPreCode) { + for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { CodeStream::Scope b(os); sg.generate(backend, os, *this, modelMerged, popSubs, true); } @@ -1012,7 +1012,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C // Are there any outgoing synapse groups with presynaptic code // which have axonal delay and no presynaptic dynamics - const bool preVars = std::any_of(m_OutSynWUMPreCode.cbegin(), m_OutSynWUMPreCode.cend(), + const bool preVars = 
std::any_of(getMergedOutSynWUMPreCodeGroups().cbegin(), getMergedOutSynWUMPreCodeGroups().cend(), [](const OutSynWUMPreCode &sg) { return ((sg.getArchetype().getDelaySteps() != NO_DELAY) @@ -1021,7 +1021,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C // Are there any incoming synapse groups with postsynaptic code // which have back-propagation delay and no postsynaptic dynamics - const bool postVars = std::any_of(m_InSynWUMPostCode.cbegin(), m_InSynWUMPostCode.cend(), + const bool postVars = std::any_of(getMergedInSynWUMPostCodeGroups().cbegin(), getMergedInSynWUMPostCodeGroups().cend(), [](const auto &sg) { return ((sg.getArchetype().getBackPropDelaySteps() != NO_DELAY) @@ -1044,12 +1044,12 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C } // Loop through outgoing synapse groups with some sort of presynaptic code - for (const auto &sg : m_OutSynWUMPreCode) { + for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { sg.genCopyDelayedVars(os, *this, modelMerged, popSubs); } // Loop through outgoing synapse groups with some sort of postsynaptic code - for (const auto &sg : m_OutSynWUMPreCode) { + for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { sg.genCopyDelayedVars(os, *this, modelMerged, popSubs); } } @@ -1072,13 +1072,13 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { // Generate var update for outgoing synaptic populations with presynaptic update code - for (const auto &sg : m_OutSynWUMPreCode) { + for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { CodeStream::Scope b(os); sg.generate(backend, os, *this, modelMerged, popSubs, false); } // Generate var update for incoming synaptic populations with postsynaptic code - for (const auto &sg : m_OutSynWUMPreCode) { + for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { CodeStream::Scope b(os); sg.generate(backend, os, *this, modelMerged, popSubs, false); } diff --git a/src/genn/genn/neuronGroup.cc b/src/genn/genn/neuronGroup.cc index 6b9cf60cee..51b22baa6e 100644 --- a/src/genn/genn/neuronGroup.cc +++ b/src/genn/genn/neuronGroup.cc @@ -232,7 +232,7 @@ bool NeuronGroup::isSimRNGRequired() const } // Return true if any current sources require an RNG for simulation - if(std::any_of(m_CurrentSources.cbegin(), m_CurrentSources.cend(), + if(std::any_of(m_MergedCurrentSourceGroups.cbegin(), m_MergedCurrentSourceGroups.cend(), [](const CurrentSourceInternal *cs){ return cs->isSimRNGRequired(); })) { return true; @@ -256,7 +256,7 @@ bool NeuronGroup::isInitRNGRequired() const } // Return true if any current sources require an RNG for initialisation - if(std::any_of(m_CurrentSources.cbegin(), m_CurrentSources.cend(), + if(std::any_of(m_MergedCurrentSourceGroups.cbegin(), m_MergedCurrentSourceGroups.cend(), [](const CurrentSourceInternal *cs){ return cs->isInitRNGRequired(); })) { return true; @@ -300,7 +300,7 @@ bool NeuronGroup::isRecordingEnabled() const //---------------------------------------------------------------------------- void NeuronGroup::injectCurrent(CurrentSourceInternal *src) { - m_CurrentSources.push_back(src); + m_MergedCurrentSourceGroups.push_back(src); } //---------------------------------------------------------------------------- NeuronGroup::NeuronGroup(const std::string &name, int numNeurons, const NeuronModels::Base *neuronModel, 
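For readers following the API change in the commit above: the raw-pointer child vectors are renamed and downstream code reaches each child through the getMerged*Groups() accessors. Below is a minimal usage sketch, not part of the patch series; the helper name inspectMergedChildren and the assumption of an already-built NeuronUpdateGroupMerged `ng` are hypothetical, while the accessors, getArchetype() and the per-child isParamHeterogeneous() call are the ones introduced above and exercised by the unit tests in the following patches ("tau" is just the example parameter name those tests use).

    // Sketch only: consuming the renamed child-group accessors.
    // `inspectMergedChildren` is a hypothetical helper, not GeNN API.
    #include "code_generator/neuronUpdateGroupMerged.h"

    using namespace GeNN::CodeGenerator;

    void inspectMergedChildren(const NeuronUpdateGroupMerged &ng)
    {
        // Children are now merged-group objects rather than raw pointers,
        // so the underlying archetype is reached via getArchetype()
        for(const auto &cs : ng.getMergedCurrentSourceGroups()) {
            const auto *csModel = cs.getArchetype().getCurrentSourceModel();
            (void)csModel;
        }

        // Per-child heterogeneity queries replace the old index-based
        // isCurrentSourceParamHeterogeneous(index, name) style of call
        if(!ng.getMergedInSynPSMGroups().empty()) {
            const bool het = ng.getMergedInSynPSMGroups().front().isParamHeterogeneous("tau");
            (void)het;
        }
    }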
From f414e142b5943785d58772ab6fc3e72a6df7f42f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 2 Jun 2023 10:06:49 +0100 Subject: [PATCH 191/725] ``isVarInitParamHeterogeneous`` and ``isVarInitDerivedParamHeterogeneous`` need to be public for unit testing --- .../genn/code_generator/initGroupMerged.h | 50 +++++++++---------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index 50273a1fff..ca48a3f5b1 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -30,18 +30,18 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; + //! Should the var init parameter be implemented heterogeneously? + bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + + //! Should the var init derived parameter be implemented heterogeneously? + bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + private: //---------------------------------------------------------------------------- // Private methods //---------------------------------------------------------------------------- //! Is the var init parameter referenced? bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const; - - //! Should the var init parameter be implemented heterogeneously? - bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const; - - //! Should the var init derived parameter be implemented heterogeneously? - bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const; }; //---------------------------------------------------------------------------- @@ -63,18 +63,18 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; + //! Should the var init parameter be implemented heterogeneously? + bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + + //! Should the var init derived parameter be implemented heterogeneously? + bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + private: //---------------------------------------------------------------------------- // Private methods //---------------------------------------------------------------------------- //! Is the var init parameter referenced? bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const; - - //! Should the var init parameter be implemented heterogeneously? - bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const; - - //! Should the var init derived parameter be implemented heterogeneously? - bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const; }; //---------------------------------------------------------------------------- @@ -113,18 +113,18 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; + //! Should the var init parameter be implemented heterogeneously?
+ bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + + //! Should the var init derived parameter be implemented heterogeneously? + bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + private: //---------------------------------------------------------------------------- // Private methods //---------------------------------------------------------------------------- //! Is the var init parameter referenced? bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const; - - //! Should the var init parameter be implemented heterogeneously? - bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const; - - //! Should the var init derived parameter be implemented heterogeneously? - bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const; }; //---------------------------------------------------------------------------- @@ -146,18 +146,18 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; - private: - //---------------------------------------------------------------------------- - // Private methods - //---------------------------------------------------------------------------- - //! Is the var init parameter referenced? - bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const; - //! Should the var init parameter be implemented heterogeneously? bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const; //! Should the var init derived parameter be implemented heterogeneously? bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + + private: + //---------------------------------------------------------------------------- + // Private methods + //---------------------------------------------------------------------------- + //! Is the var init parameter referenced? + bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const; }; NeuronInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, From 91e0197e7389cc9ee9a240f3645a78c1fd53224f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 2 Jun 2023 10:07:05 +0100 Subject: [PATCH 192/725] updated unit tests --- tests/unit/neuronGroup.cc | 44 +++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tests/unit/neuronGroup.cc b/tests/unit/neuronGroup.cc index 147865d458..22aece424d 100644 --- a/tests/unit/neuronGroup.cc +++ b/tests/unit/neuronGroup.cc @@ -746,19 +746,19 @@ TEST(NeuronGroup, CompareCurrentSources) ASSERT_TRUE(dcDCMergedGroup.getGroups().size() == 1); // Find which child in the DC + poisson merged group is the poisson current source - const size_t poissonIndex = (dcPoissonMergedGroup.getSortedArchetypeCurrentSources().at(0)->getCurrentSourceModel() == CurrentSourceModels::PoissonExp::getInstance()) ? 0 : 1; + const size_t poissonIndex = (dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(0).getArchetype().getCurrentSourceModel() == CurrentSourceModels::PoissonExp::getInstance()) ?
0 : 1; // Check that only the ExpDecay and Init derived parameters of the poisson exp current sources are heterogeneous // **NOTE** tauSyn is not heterogeneous because it's not referenced directly - ASSERT_FALSE(dcDCMergedGroup.isCurrentSourceParamHeterogeneous(0, "amp")); - ASSERT_FALSE(dcDCMergedGroup.isCurrentSourceParamHeterogeneous(1, "amp")); - ASSERT_FALSE(dcPoissonMergedGroup.isCurrentSourceParamHeterogeneous(poissonIndex, "weight")); - ASSERT_FALSE(dcPoissonMergedGroup.isCurrentSourceParamHeterogeneous(poissonIndex, "tauSyn")); - ASSERT_FALSE(dcPoissonMergedGroup.isCurrentSourceParamHeterogeneous(poissonIndex, "rate")); - ASSERT_FALSE(dcPoissonMergedGroup.isCurrentSourceParamHeterogeneous(1 - poissonIndex, "amp")); - ASSERT_TRUE(dcPoissonMergedGroup.isCurrentSourceDerivedParamHeterogeneous(poissonIndex, "ExpDecay")); - ASSERT_TRUE(dcPoissonMergedGroup.isCurrentSourceDerivedParamHeterogeneous(poissonIndex, "Init")); - ASSERT_FALSE(dcPoissonMergedGroup.isCurrentSourceDerivedParamHeterogeneous(poissonIndex, "ExpMinusLambda")); + ASSERT_FALSE(dcDCMergedGroup.getMergedCurrentSourceGroups().at(0).isParamHeterogeneous("amp")); + ASSERT_FALSE(dcDCMergedGroup.getMergedCurrentSourceGroups().at(1).isParamHeterogeneous("amp")); + ASSERT_FALSE(dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(poissonIndex).isParamHeterogeneous("weight")); + ASSERT_FALSE(dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(poissonIndex).isParamHeterogeneous("tauSyn")); + ASSERT_FALSE(dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(poissonIndex).isParamHeterogeneous("rate")); + ASSERT_FALSE(dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(1 - poissonIndex).isParamHeterogeneous("amp")); + ASSERT_TRUE(dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(poissonIndex).isDerivedParamHeterogeneous("ExpDecay")); + ASSERT_TRUE(dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(poissonIndex).isDerivedParamHeterogeneous("Init")); + ASSERT_FALSE(dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(poissonIndex).isDerivedParamHeterogeneous("ExpMinusLambda")); } TEST(NeuronGroup, ComparePostsynapticModels) @@ -865,17 +865,17 @@ TEST(NeuronGroup, ComparePostsynapticModels) [](const CodeGenerator::NeuronInitGroupMerged &ng) { return (ng.getGroups().size() == 4); }); // Find which child in the DC + gaussian merged group is the gaussian current source - ASSERT_TRUE(deltaAlphaMergedUpdateGroup->getSortedArchetypeMergedInSyns().size() == 2); - ASSERT_TRUE(deltaAlphaMergedInitGroup->getSortedArchetypeMergedInSyns().size() == 2); - const size_t alphaUpdateIndex = (deltaAlphaMergedUpdateGroup->getSortedArchetypeMergedInSyns().at(0)->getPSModel() == AlphaCurr::getInstance()) ? 0 : 1; - const size_t alphaInitIndex = (deltaAlphaMergedInitGroup->getSortedArchetypeMergedInSyns().at(0)->getPSModel() == AlphaCurr::getInstance()) ? 0 : 1; + ASSERT_TRUE(deltaAlphaMergedUpdateGroup->getMergedInSynPSMGroups().size() == 2); + ASSERT_TRUE(deltaAlphaMergedInitGroup->getMergedInSynPSMGroups().size() == 2); + const size_t alphaUpdateIndex = (deltaAlphaMergedUpdateGroup->getMergedInSynPSMGroups().at(0).getArchetype().getPSModel() == AlphaCurr::getInstance()) ? 0 : 1; + const size_t alphaInitIndex = (deltaAlphaMergedInitGroup->getMergedInSynPSMGroups().at(0).getArchetype().getPSModel() == AlphaCurr::getInstance()) ? 
0 : 1; // Check that parameter and both derived parameters are heterogeneous // **NOTE** tau is NOT heterogeneous because it's unused - ASSERT_FALSE(deltaAlphaMergedUpdateGroup->isPSMParamHeterogeneous(alphaUpdateIndex, "tau")); - ASSERT_TRUE(deltaAlphaMergedUpdateGroup->isPSMDerivedParamHeterogeneous(alphaUpdateIndex, "expDecay")); - ASSERT_TRUE(deltaAlphaMergedUpdateGroup->isPSMDerivedParamHeterogeneous(alphaUpdateIndex, "init")); - ASSERT_TRUE(deltaAlphaMergedInitGroup->isPSMVarInitParamHeterogeneous(alphaInitIndex, "x", "constant")); + ASSERT_FALSE(deltaAlphaMergedUpdateGroup->getMergedInSynPSMGroups().at(alphaUpdateIndex).isParamHeterogeneous("tau")); + ASSERT_TRUE(deltaAlphaMergedUpdateGroup->getMergedInSynPSMGroups().at(alphaUpdateIndex).isDerivedParamHeterogeneous("expDecay")); + ASSERT_TRUE(deltaAlphaMergedUpdateGroup->getMergedInSynPSMGroups().at(alphaUpdateIndex).isDerivedParamHeterogeneous("init")); + ASSERT_TRUE(deltaAlphaMergedInitGroup->getMergedInSynPSMGroups().at(alphaInitIndex).isVarInitParamHeterogeneous("x", "constant")); } @@ -1056,8 +1056,8 @@ TEST(NeuronGroup, CompareWUPreUpdate) [](const CodeGenerator::NeuronInitGroupMerged &ng) { return (ng.getGroups().size() == 4); }); // Check that parameter is heterogeneous - ASSERT_TRUE(wumPreMergedUpdateGroup->isOutSynWUMParamHeterogeneous(0, "p")); - ASSERT_TRUE(wumPreMergedInitGroup->isOutSynWUMVarInitParamHeterogeneous(0, "s", "constant")); + ASSERT_TRUE(wumPreMergedUpdateGroup->getMergedOutSynWUMPreCodeGroups().at(0).isParamHeterogeneous("p")); + ASSERT_TRUE(wumPreMergedInitGroup->getMergedOutSynWUMPreVarGroups().at(0).isVarInitParamHeterogeneous("s", "constant")); } TEST(NeuronGroup, CompareWUPostUpdate) @@ -1157,6 +1157,6 @@ TEST(NeuronGroup, CompareWUPostUpdate) [](const CodeGenerator::NeuronInitGroupMerged &ng) { return (ng.getGroups().size() == 4); }); // Check that parameter is heterogeneous - ASSERT_TRUE(wumPostMergedUpdateGroup->isInSynWUMParamHeterogeneous(0, "p")); - ASSERT_TRUE(wumPostMergedInitGroup->isInSynWUMVarInitParamHeterogeneous(0, "s", "constant")); + ASSERT_TRUE(wumPostMergedUpdateGroup->getMergedInSynWUMPostCodeGroups().at(0).isParamHeterogeneous("p")); + ASSERT_TRUE(wumPostMergedInitGroup->getMergedInSynWUMPostVarGroups().at(0).isVarInitParamHeterogeneous("s", "constant")); } From 825d58d032208cbe04f63736737145bd0a177433 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 2 Jun 2023 10:07:36 +0100 Subject: [PATCH 193/725] fixed typos - thanks tests --- src/genn/genn/code_generator/initGroupMerged.cc | 15 ++++----------- .../code_generator/neuronUpdateGroupMerged.cc | 10 +++++----- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 1b1b00c476..216ccefdbd 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -637,7 +637,7 @@ boost::uuids::detail::sha1::digest_type NeuronInitGroupMerged::getHashDigest() c for(const auto &sg : getMergedInSynPSMGroups()) { sg.updateHash(hash); } - for (const auto &sg : getMergedOutSynWUMPreVarGroups()) { + for (const auto &sg : getMergedInSynWUMPostVarGroups()) { sg.updateHash(hash); } for (const auto &sg : getMergedOutSynWUMPreVarGroups()) { @@ -695,27 +695,20 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, CodeStream [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, [this](const std::string &v, const std::string &p) 
{ return isVarInitDerivedParamHeterogeneous(v, p); }); - // Loop through all of neuron group's current sources + // Generate initialisation code for child groups for (const auto &cs : getMergedCurrentSourceGroups()) { cs.generate(backend, os, *this, modelMerged, popSubs); } - for(const auto &sg : getMergedInSynPSMGroups()) { sg.generate(backend, os, *this, modelMerged, popSubs); } - - // Loop through outgoing synapse groups with presynaptic output for (const auto &sg : getMergedOutSynPreOutputGroups()) { sg.generate(backend, os, *this, modelMerged, popSubs); - } - - + } for (const auto &sg : getMergedOutSynWUMPreVarGroups()) { sg.generate(backend, os, *this, modelMerged, popSubs); } - - // Generate var update for incoming synaptic populations with postsynaptic code - for (const auto &sg : getMergedOutSynWUMPreVarGroups()) { + for (const auto &sg : getMergedInSynWUMPostVarGroups()) { sg.generate(backend, os, *this, modelMerged, popSubs); } } diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 7b38e09cf7..43f46776a7 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -751,7 +751,7 @@ boost::uuids::detail::sha1::digest_type NeuronUpdateGroupMerged::getHashDigest() for(const auto &sg : getMergedInSynPSMGroups()) { sg.updateHash(hash); } - for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { + for (const auto &sg : getMergedInSynWUMPostCodeGroups()) { sg.updateHash(hash); } for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { @@ -902,7 +902,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C } // Generate var update for incoming synaptic populations with postsynaptic code - for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { + for (const auto &sg : getMergedInSynWUMPostCodeGroups()) { CodeStream::Scope b(os); sg.generate(backend, os, *this, modelMerged, popSubs, true); } @@ -1048,8 +1048,8 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C sg.genCopyDelayedVars(os, *this, modelMerged, popSubs); } - // Loop through outgoing synapse groups with some sort of postsynaptic code - for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { + // Loop through incoming synapse groups with some sort of presynaptic code + for (const auto &sg : getMergedInSynWUMPostCodeGroups()) { sg.genCopyDelayedVars(os, *this, modelMerged, popSubs); } } @@ -1078,7 +1078,7 @@ void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase &backend, Co } // Generate var update for incoming synaptic populations with postsynaptic code - for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { + for (const auto &sg : getMergedInSynWUMPostCodeGroups()) { CodeStream::Scope b(os); sg.generate(backend, os, *this, modelMerged, popSubs, false); } From 094eed39c43779818510c338559f732a441dce3b Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 2 Jun 2023 11:29:15 +0100 Subject: [PATCH 194/725] * Refined new APIs * Update ``NeuronUpdateGroupMerged`` to build type environment --- .../genn/genn/code_generator/groupMerged.h | 14 +- .../groupMergedTypeEnvironment.h | 68 ++-- .../code_generator/neuronUpdateGroupMerged.h | 68 +++- .../code_generator/customUpdateGroupMerged.cc | 24 +- .../code_generator/neuronUpdateGroupMerged.cc | 322 ++++++++---------- 5 files changed, 260 insertions(+), 236 deletions(-) diff --git a/include/genn/genn/code_generator/groupMerged.h 
b/include/genn/genn/code_generator/groupMerged.h index 2bac5509ac..5ced7aac1e 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -19,12 +19,20 @@ #include "code_generator/backendBase.h" #include "code_generator/codeGenUtils.h" +// GeNN transpiler includes +#include "transpiler/typeChecker.h" + // Forward declarations namespace GeNN::CodeGenerator { class CodeStream; } +namespace Transpiler::TypeChecker +{ +class EnvironmentBase; +} + //------------------------------------------------------------------------ // GeNN::CodeGenerator::GroupMergedFieldType //------------------------------------------------------------------------ @@ -632,8 +640,8 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged - void orderNeuronGroupChildren(std::vector &childGroups, const Type::TypeContext &typeContext, const BackendBase &backend, - G getVectorFunc, H getHashDigestFunc) const + void orderNeuronGroupChildren(std::vector &childGroups, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, G getVectorFunc, H getHashDigestFunc) const { const std::vector &archetypeChildren = (getArchetype().*getVectorFunc)(); @@ -674,7 +682,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged &(GroupInternal::*)(void) const; + + template + using GetVarReferencesFn = const std::unordered_map &(GroupInternal::*)(void) const; public: GroupMergedTypeEnvironment(G &groupMerged, EnvironmentBase *enclosing = nullptr) @@ -81,22 +90,25 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa } } - void definePointerField(const Type::ResolvedType &type, const std::string &name,const std::string &prefix, VarAccessMode access) + void definePointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix, VarAccessMode access, + const std::string &fieldSuffix = "", GetVarSuffixFn getVarSuffixFn = &GroupInternal::getName) { const auto qualifiedType = (access & VarAccessModeAttribute::READ_ONLY) ? 
type.addQualifier(Type::Qualifier::CONSTANT) : type; defineField(qualifiedType, name, - type.createPointer(), name, [prefix](const auto &g, size_t) { return prefix + g.getName(); }); + type.createPointer(), name + fieldSuffix, + [prefix, getVarSuffixFn](const auto &g, size_t) { return prefix + std::invoke(getVarSuffixFn, g); }); } - void definePointerField(const Type::UnresolvedType &type, const std::string &name, const std::string &prefix, VarAccessMode access) + void definePointerField(const Type::UnresolvedType &type, const std::string &name, const std::string &prefix, VarAccessMode access, + const std::string &fieldSuffix = "", GetVarSuffixFn getVarSuffixFn = &GroupInternal::getName) { - definePointerField(type.resolve(m_GroupMerged.getTypeContext()), name, prefix, access); + definePointerField(type.resolve(m_GroupMerged.getTypeContext()), name, prefix, access, fieldSuffix, getVarSuffixFn); } - void defineScalarField(const std::string &name, typename G::GetFieldDoubleValueFunc getFieldValue) + void defineScalarField(const std::string &name, const std::string &fieldSuffix, typename G::GetFieldDoubleValueFunc getFieldValue) { defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), name, - m_GroupMerged.getScalarType(), name, + m_GroupMerged.getScalarType(), name + fieldSuffix, [getFieldValue, this](const auto &g, size_t i) { return (Utils::writePreciseString(getFieldValue(g, i), m_GroupMerged.getScalarType().getNumeric().maxDigits10) @@ -104,55 +104,56 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa }); } - template - void defineHeterogeneousParams(const Snippet::Base::StringVec &paramNames, const std::string &suffix, - P getParamValues, H isHeterogeneous) + void defineHeterogeneousParams(const Snippet::Base::StringVec &paramNames, const std::string &fieldSuffix, + GetParamValuesFn getParamValues, IsHeterogeneousFn isHeterogeneous) { // Loop through params for(const auto &p : paramNames) { if (std::invoke(isHeterogeneous, m_GroupMerged, p)) { - defineScalarField(p + suffix, + defineScalarField(p, fieldSuffix, [p, getParamValues](const auto &g, size_t) { - return getParamValues(g).at(p); + return std::invoke(getParamValues, g).at(p); }); } // Otherwise, just add a const-qualified scalar to the type environment else { - defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), p + suffix); + defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), p); } } } - template - void defineHeterogeneousDerivedParams(const Snippet::Base::DerivedParamVec &derivedParams, const std::string &suffix, - D getDerivedParamValues, H isHeterogeneous) + void defineHeterogeneousDerivedParams(const Snippet::Base::DerivedParamVec &derivedParams, const std::string &fieldSuffix, + GetParamValuesFn getDerivedParamValues, IsHeterogeneousFn isHeterogeneous) { // Loop through derived params for(const auto &d : derivedParams) { if (std::invoke(isHeterogeneous, m_GroupMerged, d.name)) { - defineScalarField(d.name + suffix, + defineScalarField(d.name, fieldSuffix, [d, getDerivedParamValues](const auto &g, size_t) { - return getDerivedParamValues(g).at(d.name); + return std::invoke(getDerivedParamValues, g).at(d.name); }); } else { - defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), d.name + suffix); + defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), d.name); } } } - void defineVars(const Models::Base::VarVec &vars, const std::string &arrayPrefix) +
void defineVars(const Models::Base::VarVec &vars, const std::string &arrayPrefix, + const std::string &fieldSuffix = "", GetVarSuffixFn getVarSuffixFn = &GroupInternal::getName) { // Loop through variables for(const auto &v : vars) { - definePointerField(v.type, v.name, arrayPrefix, getVarAccessMode(v.access)); + definePointerField(v.type, v.name, arrayPrefix, getVarAccessMode(v.access), + fieldSuffix, getVarSuffixFn); } } template - void defineVarReferences(const Models::Base::VarRefVec &varReferences, const std::string &arrayPrefix, V getVarRefFn) + void defineVarReferences(const Models::Base::VarRefVec &varReferences, const std::string &arrayPrefix, + const std::string &fieldSuffix = "", GetVarReferencesFn getVarRefFn = &GroupInternal::getVarReferences) { // Loop through variables for(const auto &v : varReferences) { @@ -160,24 +173,25 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa const auto resolvedType = v.type.resolve(m_GroupMerged.getTypeContext()); const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addQualifier(Type::Qualifier::CONSTANT) : resolvedType; defineField(qualifiedType, v.name, - resolvedType.createPointer(), v.name, + resolvedType.createPointer(), v.name + fieldSuffix, [arrayPrefix, getVarRefFn, v](const auto &g, size_t) { - const auto varRef = getVarRefFn(g).at(v.name); + const auto varRef = std::invoke(getVarRefFn, g).at(v.name); return arrayPrefix + varRef.getVar().name + varRef.getTargetName(); }); } } - void defineEGPs(const Snippet::Base::EGPVec &egps, const std::string &arrayPrefix, const std::string &varName = "") + void defineEGPs(const Snippet::Base::EGPVec &egps, const std::string &arrayPrefix, const std::string &varName = "", + const std::string &fieldSuffix = "", GetVarSuffixFn getVarSuffixFn = &GroupInternal::getName) { for(const auto &e : egps) { const auto pointerType = e.type.resolve(m_GroupMerged.getTypeContext()).createPointer(); defineField(pointerType, e.name, - pointerType, e.name + varName, - [arrayPrefix, e, varName](const auto &g, size_t) + pointerType, e.name + varName + fieldSuffix, + [arrayPrefix, e, varName, getVarSuffixFn](const auto &g, size_t) { - return arrayPrefix + e.name + varName + g.getName(); + return arrayPrefix + e.name + varName + std::invoke(getVarSuffixFn, g); }, GroupMergedFieldType::DYNAMIC); } diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index d6071574d5..e387d8c674 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -3,6 +3,10 @@ // GeNN code generator includes #include "code_generator/groupMerged.h" +// GeNN transpiler includes +#include "transpiler/statement.h" +#include "transpiler/typeChecker.h" + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged //---------------------------------------------------------------------------- @@ -18,8 +22,8 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase class CurrentSource : public GroupMerged { public: - CurrentSource(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + CurrentSource(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups); 
//---------------------------------------------------------------------------- // Public API @@ -40,8 +44,17 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Private API //---------------------------------------------------------------------------- - //! Is the current source parameter referenced? + //! Is the parameter referenced? **YUCK** only used for hashing bool isParamReferenced(const std::string &paramName) const; + + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + //! List of statements parsed and type-checked in constructor; and used to generate code + Transpiler::Statement::StatementList m_UpdateStatements; + + //! Resolved types used to generate code + Transpiler::TypeChecker::ResolvedTypeMap m_ResolvedTypes; }; //---------------------------------------------------------------------------- @@ -51,8 +64,8 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase class InSynPSM : public GroupMerged { public: - InSynPSM(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + InSynPSM(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- // Public API @@ -73,8 +86,17 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Private API //---------------------------------------------------------------------------- - //! Is the current source parameter referenced? + //! Is the parameter referenced? **YUCK** only used for hashing bool isParamReferenced(const std::string &paramName) const; + + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + //! List of statements parsed and type-checked in constructor; and used to generate code + Transpiler::Statement::StatementList m_UpdateStatements; + + //!
Resolved types used to generate code + Transpiler::TypeChecker::ResolvedTypeMap m_ResolvedTypes; }; //---------------------------------------------------------------------------- @@ -84,8 +106,8 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase class OutSynPreOutput : public GroupMerged { public: - OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- // Public API @@ -101,8 +123,8 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase class InSynWUMPostCode : public GroupMerged { public: - InSynWUMPostCode(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + InSynWUMPostCode(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- // Public API @@ -126,8 +148,17 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Private API //---------------------------------------------------------------------------- - //! Is the current source parameter referenced? + //! Is the parameter referenced? **YUCK** only used for hashing bool isParamReferenced(const std::string &paramName) const; + + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + //! List of statements parsed and type-checked in constructor; and used to generate code + Transpiler::Statement::StatementList m_UpdateStatements; + + //! Resolved types used to generate code + Transpiler::TypeChecker::ResolvedTypeMap m_ResolvedTypes; }; //---------------------------------------------------------------------------- @@ -137,8 +168,8 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase class OutSynWUMPreCode : public GroupMerged { public: - OutSynWUMPreCode(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + OutSynWUMPreCode(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- // Public API @@ -162,8 +193,17 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Private API //---------------------------------------------------------------------------- - //! Is the current source parameter referenced? + //! Is the parameter referenced? **YUCK** only used for hashing bool isParamReferenced(const std::string &paramName) const; + + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + //!
List of statements parsed and type-checked in constructor; and used to generate code + Transpiler::Statement::StatementList m_UpdateStatements; + + //! Resolved types used to generate code + Transpiler::TypeChecker::ResolvedTypeMap m_ResolvedTypes; }; NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index fb46fa7fdd..140de1de16 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -50,23 +50,22 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const Type::TypeC // Add heterogeneous custom update model parameters const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); - typeEnvironment.defineHeterogeneousParams( + typeEnvironment.defineHeterogeneousParams( cm->getParamNames(), "", - [](const auto &cg) { return cg.getParams(); }, + &CustomUpdateInternal::getParams, &CustomUpdateGroupMerged::isParamHeterogeneous); - // Add heterogeneous weight update model derived parameters - typeEnvironment.defineHeterogeneousDerivedParams( + // Add heterogeneous custom update model derived parameters + typeEnvironment.defineHeterogeneousDerivedParams( cm->getDerivedParams(), "", - [](const auto &cg) { return cg.getDerivedParams(); }, + &CustomUpdateInternal::getDerivedParams, &CustomUpdateGroupMerged::isDerivedParamHeterogeneous); // Add variables to struct typeEnvironment.defineVars(cm->getVars(), backend.getDeviceVarPrefix()); // Add variable references to struct - typeEnvironment.defineVarReferences(cm->getVarRefs(), backend.getDeviceVarPrefix(), - [](const auto &cg) { return cg.getVarReferences(); }); + typeEnvironment.defineVarReferences(cm->getVarRefs(), backend.getDeviceVarPrefix()); // Add EGPs to struct typeEnvironment.defineEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); @@ -325,15 +324,15 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // Add heterogeneous custom update model parameters const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); - typeEnvironment.defineHeterogeneousParams( + typeEnvironment.defineHeterogeneousParams( cm->getParamNames(), "", - [](const auto &cg) { return cg.getParams(); }, + &CustomUpdateWUInternal::getParams, &CustomUpdateWUGroupMergedBase::isParamHeterogeneous); // Add heterogeneous weight update model derived parameters - typeEnvironment.defineHeterogeneousDerivedParams( + typeEnvironment.defineHeterogeneousDerivedParams( cm->getDerivedParams(), "", - [](const auto &cg) { return cg.getDerivedParams(); }, + &CustomUpdateWUInternal::getDerivedParams, &CustomUpdateWUGroupMergedBase::isDerivedParamHeterogeneous); // Add variables to struct @@ -341,8 +340,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // Add variable references to struct const auto varRefs = cm->getVarRefs(); - typeEnvironment.defineVarReferences(varRefs, backend.getDeviceVarPrefix(), - [](const auto &cg) { return cg.getVarReferences(); }); + typeEnvironment.defineVarReferences(varRefs, backend.getDeviceVarPrefix()); // Loop through variables for(const auto &v : varRefs) { diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 43f46776a7..2232d9a16e 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ 
b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -1,10 +1,20 @@ #include "code_generator/neuronUpdateGroupMerged.h" // GeNN code generator includes +#include "code_generator/groupMergedTypeEnvironment.h" #include "code_generator/modelSpecMerged.h" +// GeNN transpiler includes +#include "transpiler/errorHandler.h" +#include "transpiler/parser.h" +#include "transpiler/prettyPrinter.h" +#include "transpiler/scanner.h" +#include "transpiler/standardLibrary.h" +#include "transpiler/typeChecker.h" + using namespace GeNN; using namespace GeNN::CodeGenerator; +using namespace GeNN::Transpiler; //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::CurrentSource @@ -13,34 +23,31 @@ using namespace GeNN::CodeGenerator; // * field suffix (string) and value suffix (function to get suffix from group) common to everything in group - GroupMerged fields? // * without nasty combined groups, getParams and getDerivedParams functions can use pointers to members // * pre and post neuron stuff in synapse update group merged can also be child classes -NeuronUpdateGroupMerged::CurrentSource::CurrentSource(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) +NeuronUpdateGroupMerged::CurrentSource::CurrentSource(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, typeContext, groups) { const std::string suffix = "CS" + std::to_string(getIndex()); + // Create type environment + GroupMergedTypeEnvironment typeEnvironment(*this, &enclosingEnv); + + // Add heterogeneous parameters + const auto *cm = getArchetype().getCurrentSourceModel(); + typeEnvironment.defineHeterogeneousParams(cm->getParamNames(), suffix, + &CurrentSourceInternal::getParams, + &CurrentSource::isParamHeterogeneous); + + // Add heterogeneous derived parameters + typeEnvironment.defineHeterogeneousDerivedParams(cm->getDerivedParams(), suffix, + &CurrentSourceInternal::getDerivedParams, + &CurrentSource::isDerivedParamHeterogeneous); + // Add variables - for(const auto &var : getArchetype().getCurrentSourceModel()->getVars()) { - addPointerField(var.type, var.name + suffix, - backend.getDeviceVarPrefix() + var.name); - } - - // Add parameters and derived parameters - addHeterogeneousParams( - getArchetype().getCurrentSourceModel()->getParamNames(), suffix, - [](const auto &cs) { return cs.getParams(); }, - &CurrentSource::isParamHeterogeneous); - addHeterogeneousDerivedParams( - getArchetype().getCurrentSourceModel()->getDerivedParams(), suffix, - [](const auto &cs) { return cs.getDerivedParams(); }, - &CurrentSource::isDerivedParamHeterogeneous); + typeEnvironment.defineVars(cm->getVars(), backend.getDeviceVarPrefix(), suffix); // Add EGPs - for(const auto &egp : getArchetype().getCurrentSourceModel()->getExtraGlobalParams()) { - addPointerField(egp.type, egp.name + suffix, - backend.getDeviceVarPrefix() + egp.name, - GroupMergedFieldType::DYNAMIC); - } + typeEnvironment.defineEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", suffix); } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, @@ -94,32 +101,25 @@ void NeuronUpdateGroupMerged::CurrentSource::updateHash(boost::uuids::detail::sh 
//---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::CurrentSource::isParamHeterogeneous(const std::string &paramName) const { - return (isParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const CurrentSourceInternal &cs) { return cs.getParams(); })); + return isParamValueHeterogeneous(paramName, [](const CurrentSourceInternal &cs) { return cs.getParams(); }); } //---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::CurrentSource::isDerivedParamHeterogeneous( const std::string &paramName) const { - return (isParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const CurrentSourceInternal &cs) { return cs.getDerivedParams(); })); + return isParamValueHeterogeneous(paramName, [](const CurrentSourceInternal &cs) { return cs.getDerivedParams(); }); } -//---------------------------------------------------------------------------- -bool NeuronUpdateGroupMerged::CurrentSource::isParamReferenced(const std::string &paramName) const -{ - return GroupMerged::isParamReferenced({getArchetype().getCurrentSourceModel()->getInjectionCode()}, - paramName); -} - //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynPSM //---------------------------------------------------------------------------- -NeuronUpdateGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) +NeuronUpdateGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, typeContext, groups) { const std::string suffix = "InSyn" + std::to_string(getIndex()); + // Create type environment + GroupMergedTypeEnvironment typeEnvironment(*this, &enclosingEnv); + // Add pointer to insyn addField(getScalarType().createPointer(), "inSyn" + suffix, [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "inSyn" + g.getFusedPSVarSuffix(); }); @@ -133,31 +133,24 @@ NeuronUpdateGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContex [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); } - // Add pointers to state variable - // **FUSE** - for(const auto &var : getArchetype().getPSModel()->getVars()) { - addField(var.type.resolve(getTypeContext()).createPointer(), var.name + suffix, - [&backend, var](const auto &g, size_t) { return backend.getDeviceVarPrefix() + var.name + g.getFusedPSVarSuffix(); }); - } + // Add heterogeneous parameters + const auto *psm = getArchetype().getPSModel(); + typeEnvironment.defineHeterogeneousParams(psm->getParamNames(), suffix, + &SynapseGroupInternal::getPSParams, + &InSynPSM::isParamHeterogeneous); - // Add any heterogeneous postsynaptic model parameters - addHeterogeneousParams( - getArchetype().getPSModel()->getParamNames(), suffix, - [](const auto &sg) { return sg.getPSParams(); }, - &InSynPSM::isParamHeterogeneous); + // Add heterogeneous derived parameters + typeEnvironment.defineHeterogeneousDerivedParams(psm->getDerivedParams(), suffix, + &SynapseGroupInternal::getPSDerivedParams, + &InSynPSM::isDerivedParamHeterogeneous); - // Add any heterogeneous postsynaptic mode derived parameters - addHeterogeneousDerivedParams( - getArchetype().getPSModel()->getDerivedParams(),
suffix, - [](const auto &sg) { return sg.getPSDerivedParams(); }, - &InSynPSM::isDerivedParamHeterogeneous); + // Add variables + typeEnvironment.defineVars(psm->getVars(), backend.getDeviceVarPrefix(), + suffix, &SynapseGroupInternal::getFusedPSVarSuffix); // Add EGPs - for(const auto &egp : getArchetype().getPSModel()->getExtraGlobalParams()) { - addField(egp.type.resolve(getTypeContext()).createPointer(), egp.name + suffix, - [&backend, egp](const auto &g, size_t) { return backend.getDeviceVarPrefix() + egp.name + g.getFusedPSVarSuffix(); }, - GroupMergedFieldType::DYNAMIC); - } + typeEnvironment.defineEGPs(psm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", + suffix, &SynapseGroupInternal::getFusedPSVarSuffix); } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, @@ -262,29 +255,19 @@ void NeuronUpdateGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &h //---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::InSynPSM::isParamHeterogeneous(const std::string &paramName) const { - return (isParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getPSParams(); })); + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getPSParams(); }); } //---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::InSynPSM::isDerivedParamHeterogeneous( const std::string &paramName) const { - return (isParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getPSDerivedParams(); })); - -} -//---------------------------------------------------------------------------- -bool NeuronUpdateGroupMerged::InSynPSM::isParamReferenced(const std::string &paramName) const -{ - return GroupMerged::isParamReferenced( - {getArchetype().getPSModel()->getApplyInputCode(), getArchetype().getPSModel()->getDecayCode()}, - paramName); + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getPSDerivedParams(); }); } //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynPreOutput //---------------------------------------------------------------------------- -NeuronUpdateGroupMerged::OutSynPreOutput::OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) +NeuronUpdateGroupMerged::OutSynPreOutput::OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase&, + const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, typeContext, groups) { const std::string suffix = "OutSyn" + std::to_string(getIndex()); @@ -298,46 +281,45 @@ void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backe { const std::string suffix = "OutSyn" + std::to_string(getIndex()); - os << getArchetype().getPreTargetVar() << "+= "; + os << getArchetype().getPreTargetVar() << " += "; os << "group->revInSyn" << suffix << "["; os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, popSubs["id"]); os << "];" << std::endl; os << "group->revInSyn" << suffix << "["; os << ng.getVarIndex(modelMerged.getModel().getBatchSize(),
VarAccessDuplication::DUPLICATE, popSubs["id"]); - os << "]= " << modelMerged.scalarExpr(0.0) << ";" << std::endl; + os << "] = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; } //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynWUMPostCode //---------------------------------------------------------------------------- -NeuronUpdateGroupMerged::InSynWUMPostCode::InSynWUMPostCode(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) +NeuronUpdateGroupMerged::InSynWUMPostCode::InSynWUMPostCode(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, typeContext, groups) { const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); - // Add postsynaptic variables - for(const auto &var : getArchetype().getWUModel()->getPostVars()) { - addField(var.type.resolve(getTypeContext()).createPointer(), var.name + suffix, - [&backend, var](const auto &g, size_t) { return backend.getDeviceVarPrefix() + var.name + g.getFusedWUPostVarSuffix(); }); - } - - // Add parameters and derived parameters - addHeterogeneousParams( - getArchetype().getWUModel()->getParamNames(), suffix, - [](const auto &sg) { return sg.getWUParams(); }, - &InSynWUMPostCode::isParamHeterogeneous); - addHeterogeneousDerivedParams( - getArchetype().getWUModel()->getDerivedParams(), suffix, - [](const auto &sg) { return sg.getWUDerivedParams(); }, - &InSynWUMPostCode::isDerivedParamHeterogeneous); + // Create type environment + GroupMergedTypeEnvironment typeEnvironment(*this, &enclosingEnv); + + // Add heterogeneous parameters + const auto *wum = getArchetype().getWUModel(); + typeEnvironment.defineHeterogeneousParams(wum->getParamNames(), suffix, + &SynapseGroupInternal::getWUParams, + &InSynWUMPostCode::isParamHeterogeneous); + + // Add heterogeneous derived parameters + typeEnvironment.defineHeterogeneousDerivedParams(wum->getDerivedParams(), suffix, + &SynapseGroupInternal::getWUDerivedParams, + &InSynWUMPostCode::isDerivedParamHeterogeneous); + + // Add variables + typeEnvironment.defineVars(wum->getPostVars(), backend.getDeviceVarPrefix(), + suffix, &SynapseGroupInternal::getFusedWUPostVarSuffix); // Add EGPs - for(const auto &egp : getArchetype().getWUModel()->getExtraGlobalParams()) { - addField(egp.type.resolve(getTypeContext()).createPointer(), egp.name + suffix, - [&backend, egp](const auto &g, size_t) { return backend.getDeviceVarPrefix() + egp.name + g.getFusedWUPostVarSuffix(); }, - GroupMergedFieldType::DYNAMIC); - } + typeEnvironment.defineEGPs(wum->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", + suffix, &SynapseGroupInternal::getFusedWUPostVarSuffix); } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, @@ -434,55 +416,44 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::updateHash(boost::uuids::detail: //---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::InSynWUMPostCode::isParamHeterogeneous(const std::string &paramName) const { - return (isParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUParams(); })); + return
isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUParams(); }); } //---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::InSynWUMPostCode::isDerivedParamHeterogeneous( const std::string &paramName) const { - return (isParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUDerivedParams(); })); - -} -//---------------------------------------------------------------------------- -bool NeuronUpdateGroupMerged::InSynWUMPostCode::isParamReferenced(const std::string &paramName) const -{ - return GroupMerged::isParamReferenced( - {getArchetype().getWUModel()->getPostDynamicsCode(), getArchetype().getWUModel()->getPostSpikeCode()}, - paramName); + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUDerivedParams(); }); } //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynWUMPreCode //---------------------------------------------------------------------------- -NeuronUpdateGroupMerged::OutSynWUMPreCode::OutSynWUMPreCode(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) +NeuronUpdateGroupMerged::OutSynWUMPreCode::OutSynWUMPreCode(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, typeContext, groups) { const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); - // Add presynaptic variables - for(const auto &var : getArchetype().getWUModel()->getPreVars()) { - addField(var.type.resolve(getTypeContext()).createPointer(), var.name + suffix, - [&backend, var](const auto &g, size_t) { return backend.getDeviceVarPrefix() + var.name + g.getFusedWUPreVarSuffix(); }); - } - - // Add parameters and derived parameters - addHeterogeneousParams( - getArchetype().getWUModel()->getParamNames(), suffix, - [](const auto &sg) { return sg.getWUParams(); }, - &OutSynWUMPreCode::isParamHeterogeneous); - addHeterogeneousDerivedParams( - getArchetype().getWUModel()->getDerivedParams(), suffix, - [](const auto &sg) { return sg.getWUDerivedParams(); }, - &OutSynWUMPreCode::isDerivedParamHeterogeneous); + // Create type environment + GroupMergedTypeEnvironment typeEnvironment(*this, &enclosingEnv); + + // Add heterogeneous parameters + const auto *wum = getArchetype().getWUModel(); + typeEnvironment.defineHeterogeneousParams(wum->getParamNames(), suffix, + &SynapseGroupInternal::getWUParams, + &OutSynWUMPreCode::isParamHeterogeneous); + + // Add heterogeneous derived parameters + typeEnvironment.defineHeterogeneousDerivedParams(wum->getDerivedParams(), suffix, + &SynapseGroupInternal::getWUDerivedParams, + &OutSynWUMPreCode::isDerivedParamHeterogeneous); + + // Add variables + typeEnvironment.defineVars(wum->getPreVars(), backend.getDeviceVarPrefix(), + suffix, &SynapseGroupInternal::getFusedWUPreVarSuffix); // Add EGPs - for(const auto &egp : getArchetype().getWUModel()->getExtraGlobalParams()) { - addField(egp.type.resolve(getTypeContext()).createPointer(), egp.name + suffix, - [&backend, egp](const auto &g, size_t) { return backend.getDeviceVarPrefix() + egp.name + g.getFusedWUPreVarSuffix(); }, - GroupMergedFieldType::DYNAMIC); - } + typeEnvironment.defineEGPs(wum->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", + suffix,
&SynapseGroupInternal::getFusedWUPreVarSuffix); } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, @@ -579,23 +550,14 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::updateHash(boost::uuids::detail: //---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::OutSynWUMPreCode::isParamHeterogeneous(const std::string &paramName) const { - return (isParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUParams(); })); + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUParams(); }); } //---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::OutSynWUMPreCode::isDerivedParamHeterogeneous( const std::string &paramName) const { - return (isParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUDerivedParams(); })); + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUDerivedParams(); }); } -//---------------------------------------------------------------------------- -bool NeuronUpdateGroupMerged::OutSynWUMPreCode::isParamReferenced(const std::string &paramName) const -{ - return GroupMerged::isParamReferenced( - {getArchetype().getWUModel()->getPreDynamicsCode(), getArchetype().getWUModel()->getPreSpikeCode()}, - paramName); -} //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged //---------------------------------------------------------------------------- @@ -608,53 +570,33 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC { using namespace Type; - // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedInSynPSMGroups, typeContext, backend, - &NeuronGroupInternal::getFusedPSMInSyn, - &SynapseGroupInternal::getPSHashDigest); - - // Build vector of vectors containing each child group's merged out syns with pre output, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedOutSynPreOutputGroups, typeContext, backend, - &NeuronGroupInternal::getFusedPreOutputOutSyn, - &SynapseGroupInternal::getPreOutputHashDigest); - - // Build vector of vectors containing each child group's current sources, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedCurrentSourceGroups, typeContext, backend, - &NeuronGroupInternal::getCurrentSources, - &CurrentSourceInternal::getHashDigest); - - - // Build vector of vectors containing each child group's incoming synapse groups - // with postsynaptic updates, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedInSynWUMPostCodeGroups, typeContext, backend, - &NeuronGroupInternal::getFusedInSynWithPostCode, - &SynapseGroupInternal::getWUPostHashDigest); - - // Build vector of vectors containing each child group's outgoing synapse groups - // with presynaptic synaptic updates, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedOutSynWUMPreCodeGroups, typeContext, backend, - &NeuronGroupInternal::getFusedOutSynWithPreCode, - &SynapseGroupInternal::getWUPreHashDigest); + // Create type environment + StandardLibrary::FunctionTypes stdLibraryEnv; +
GroupMergedTypeEnvironment typeEnvironment(*this, &stdLibraryEnv); + // Add RNG if(backend.isPopulationRNGRequired() && getArchetype().isSimRNGRequired()) { + // **TODO** inject RNG types into environment + addPointerField(*backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); } - // Add variables and extra global parameters - addVars(getArchetype().getNeuronModel()->getVars(), backend.getDeviceVarPrefix()); - addEGPs(getArchetype().getNeuronModel()->getExtraGlobalParams(), backend.getDeviceVarPrefix()); - // Add heterogeneous neuron model parameters - addHeterogeneousParams( - getArchetype().getNeuronModel()->getParamNames(), "", - [](const NeuronGroupInternal &ng) { return ng.getParams(); }, - &NeuronGroupMergedBase::isParamHeterogeneous); + const auto *nm = getArchetype().getNeuronModel(); + typeEnvironment.defineHeterogeneousParams(nm->getParamNames(), "", + &NeuronGroupInternal::getParams, + &NeuronUpdateGroupMerged::isParamHeterogeneous); - // Add heterogeneous neuron model derived parameters - addHeterogeneousDerivedParams( - getArchetype().getNeuronModel()->getDerivedParams(), "", - [](const NeuronGroupInternal &ng) { return ng.getDerivedParams(); }, - &NeuronGroupMergedBase::isDerivedParamHeterogeneous); + // Add heterogeneous neuron model derived parameters + typeEnvironment.defineHeterogeneousDerivedParams(nm->getDerivedParams(), "", + &NeuronGroupInternal::getDerivedParams, + &NeuronUpdateGroupMerged::isDerivedParamHeterogeneous); + + // Add variables + typeEnvironment.defineVars(nm->getVars(), backend.getDeviceVarPrefix()); + + // Add EGPs + typeEnvironment.defineEGPs(nm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); // Loop through neuron groups std::vector> eventThresholdSGs; @@ -728,6 +670,28 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC GroupMergedFieldType::DYNAMIC); } + // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group + orderNeuronGroupChildren(m_MergedInSynPSMGroups, typeContext, typeEnvironment, backend, + &NeuronGroupInternal::getFusedPSMInSyn, &SynapseGroupInternal::getPSHashDigest); + + // Build vector of vectors containing each child group's merged out syns with pre output, ordered to match those of the archetype group + orderNeuronGroupChildren(m_MergedOutSynPreOutputGroups, typeContext, typeEnvironment, backend, + &NeuronGroupInternal::getFusedPreOutputOutSyn, &SynapseGroupInternal::getPreOutputHashDigest); + + // Build vector of vectors containing each child group's current sources, ordered to match those of the archetype group + orderNeuronGroupChildren(m_MergedCurrentSourceGroups, typeContext, typeEnvironment, backend, + &NeuronGroupInternal::getCurrentSources, &CurrentSourceInternal::getHashDigest); + + + // Build vector of vectors containing each child group's incoming synapse groups + // with postsynaptic updates, ordered to match those of the archetype group + orderNeuronGroupChildren(m_MergedInSynWUMPostCodeGroups, typeContext, typeEnvironment, backend, + &NeuronGroupInternal::getFusedInSynWithPostCode, &SynapseGroupInternal::getWUPostHashDigest); + + // Build vector of vectors containing each child group's outgoing synapse groups + // with presynaptic updates, ordered to match those of the archetype group + orderNeuronGroupChildren(m_MergedOutSynWUMPreCodeGroups, typeContext, typeEnvironment, backend, + &NeuronGroupInternal::getFusedOutSynWithPreCode, &SynapseGroupInternal::getWUPreHashDigest); }
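// ----------------------------------------------------------------------------
// A minimal sketch of the pattern behind the defineHeterogeneousParams and
// defineHeterogeneousDerivedParams calls above, mirroring the var-init
// variants added to groupMergedTypeEnvironment.h later in this series. The
// GetParamValuesFn alias and the lambda bodies here are illustrative
// assumptions, not code from the patch: each parameter is exposed either as a
// per-group scalar field (when its value differs between the merged groups)
// or as a CONSTANT-qualified scalar the type checker can treat as a literal.
//
// void defineHeterogeneousParams(const std::vector<std::string> &paramNames,
//                                const std::string &suffix,
//                                GetParamValuesFn getParamValues,
//                                IsHeterogeneousFn isHeterogeneous)
// {
//     for(const auto &p : paramNames) {
//         if(std::invoke(isHeterogeneous, m_GroupMerged, p)) {
//             // Heterogeneous - add scalar field whose value is looked up per-group
//             defineScalarField(p, suffix,
//                               [p, getParamValues](const auto &g, size_t)
//                               {
//                                   return std::invoke(getParamValues, g).at(p);
//                               });
//         }
//         else {
//             // Homogeneous - define as a constant so the archetype's value
//             // can be substituted directly into generated code
//             defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT),
//                         p + suffix);
//         }
//     }
// }
// ----------------------------------------------------------------------------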
//---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type NeuronUpdateGroupMerged::getHashDigest() const From a634c65026691bd8438c0529fbb0ac3b8f03ea6a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 2 Jun 2023 15:05:51 +0100 Subject: [PATCH 195/725] moved some header includes around and started hooking up init group merged --- .../code_generator/customUpdateGroupMerged.h | 4 - .../genn/genn/code_generator/groupMerged.h | 1 + .../genn/code_generator/initGroupMerged.h | 22 ++--- .../code_generator/neuronUpdateGroupMerged.h | 4 - .../genn/code_generator/initGroupMerged.cc | 89 +++++++++++-------- .../code_generator/neuronUpdateGroupMerged.cc | 1 + 6 files changed, 64 insertions(+), 57 deletions(-) diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index 31df29b5a6..c96fcd3825 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -5,10 +5,6 @@ #include "code_generator/environment.h" #include "code_generator/groupMerged.h" -// GeNN transpiler includes -#include "transpiler/statement.h" -#include "transpiler/typeChecker.h" - //---------------------------------------------------------------------------- // GeNN::CodeGenerator::CustomUpdateGroupMerged //---------------------------------------------------------------------------- diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 5ced7aac1e..3d3066ad8d 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -20,6 +20,7 @@ #include "code_generator/codeGenUtils.h" // GeNN transpiler includes +#include "transpiler/statement.h" #include "transpiler/typeChecker.h" // Forward declarations diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index ca48a3f5b1..981898fb30 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -18,8 +18,8 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase class CurrentSource : public GroupMerged { public: - CurrentSource(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + CurrentSource(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- // Public API @@ -51,8 +51,8 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase class InSynPSM : public GroupMerged { public: - InSynPSM(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + InSynPSM(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- // Public API @@ -84,8 +84,8 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase class OutSynPreOutput : public GroupMerged { public: - OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + OutSynPreOutput(size_t index, 
const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- // Public API @@ -101,8 +101,8 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase class InSynWUMPostVars : public GroupMerged { public: - InSynWUMPostVars(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + InSynWUMPostVars(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- // Public API @@ -134,8 +134,8 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase class OutSynWUMPreVars: public GroupMerged { public: - OutSynWUMPreVars(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + OutSynWUMPreVars(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups); //---------------------------------------------------------------------------- // Public API @@ -222,7 +222,7 @@ class GENN_EXPORT SynapseInitGroupMerged : public SynapseGroupMergedBase { public: SynapseInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) + const std::vector> &groups) : SynapseGroupMergedBase(index, typeContext, backend, SynapseGroupMergedBase::Role::Init, "", groups) {} diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index e387d8c674..0a076b7b6c 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -3,10 +3,6 @@ // GeNN code generator includes #include "code_generator/groupMerged.h" -// GeNN transpiler includes -#include "transpiler/statement.h" -#include "transpiler/typeChecker.h" - //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged //---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 216ccefdbd..62c0192774 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -1,10 +1,20 @@ #include "code_generator/initGroupMerged.h" // GeNN code generator includes +#include "code_generator/groupMergedTypeEnvironment.h" #include "code_generator/modelSpecMerged.h" +// GeNN transpiler includes +#include "transpiler/errorHandler.h" +#include "transpiler/parser.h" +#include "transpiler/prettyPrinter.h" +#include "transpiler/scanner.h" +#include "transpiler/standardLibrary.h" +#include "transpiler/typeChecker.h" + using namespace GeNN; using namespace GeNN::CodeGenerator; +using namespace GeNN::Transpiler; //-------------------------------------------------------------------------- // Anonymous namespace @@ -183,8 +193,8 @@ void genInitWUVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, const //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::CurrentSource 
//---------------------------------------------------------------------------- -NeuronInitGroupMerged::CurrentSource::CurrentSource(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) +NeuronInitGroupMerged::CurrentSource::CurrentSource(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, typeContext, groups) { const std::string suffix = "CS" + std::to_string(getIndex()); @@ -257,8 +267,8 @@ bool NeuronInitGroupMerged::CurrentSource::isVarInitParamReferenced(const std::s //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::InSynPSM //---------------------------------------------------------------------------- -NeuronInitGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) +NeuronInitGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, typeContext, groups) { const std::string suffix = "InSyn" + std::to_string(getIndex()); @@ -375,8 +385,8 @@ bool NeuronInitGroupMerged::InSynPSM::isVarInitParamReferenced(const std::string //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynPreOutput //---------------------------------------------------------------------------- -NeuronInitGroupMerged::OutSynPreOutput::OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) +NeuronInitGroupMerged::OutSynPreOutput::OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &, + const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, typeContext, groups) { const std::string suffix = "OutSyn" + std::to_string(getIndex()); @@ -402,8 +412,8 @@ void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::InSynWUMPostVars //---------------------------------------------------------------------------- -NeuronInitGroupMerged::InSynWUMPostVars::InSynWUMPostVars(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) +NeuronInitGroupMerged::InSynWUMPostVars::InSynWUMPostVars(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, typeContext, groups) { const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); @@ -477,8 +487,8 @@ bool NeuronInitGroupMerged::InSynWUMPostVars::isVarInitParamReferenced(const std //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynWUMPreVars //---------------------------------------------------------------------------- -NeuronInitGroupMerged::OutSynWUMPreVars::OutSynWUMPreVars(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) +NeuronInitGroupMerged::OutSynWUMPreVars::OutSynWUMPreVars(size_t 
index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, + const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, typeContext, groups) { const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); @@ -558,38 +568,14 @@ NeuronInitGroupMerged::NeuronInitGroupMerged(size_t index, const Type::TypeConte const std::vector> &groups) : NeuronGroupMergedBase(index, typeContext, backend, groups) { - // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedInSynPSMGroups, typeContext, backend, - &NeuronGroupInternal::getFusedPSMInSyn, - &SynapseGroupInternal::getPSInitHashDigest ); - - // Build vector of vectors containing each child group's merged out syns with pre output, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedOutSynPreOutputGroups, typeContext, backend, - &NeuronGroupInternal::getFusedPreOutputOutSyn, - &SynapseGroupInternal::getPreOutputInitHashDigest ); - - // Build vector of vectors containing each child group's current sources, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedCurrentSourceGroups, typeContext, backend, - &NeuronGroupInternal::getCurrentSources, - &CurrentSourceInternal::getInitHashDigest ); - - - // Build vector of vectors containing each child group's incoming synapse groups - // with postsynaptic weight update model variable, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedInSynWUMPostVarGroups, typeContext, backend, - &NeuronGroupInternal::getFusedInSynWithPostVars, - &SynapseGroupInternal::getWUPostInitHashDigest); - - // Build vector of vectors containing each child group's outgoing synapse groups - // with presynaptic weight update model variables, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedOutSynWUMPreVarGroups, typeContext, backend, - &NeuronGroupInternal::getFusedOutSynWithPreVars, - &SynapseGroupInternal::getWUPreInitHashDigest); - + // Create type environment + StandardLibrary::FunctionTypes stdLibraryEnv; + GroupMergedTypeEnvironment typeEnvironment(*this, &stdLibraryEnv); if(backend.isPopulationRNGRequired() && getArchetype().isSimRNGRequired() && backend.isPopulationRNGInitialisedOnDevice()) { + // **TODO** inject RNG types into environment addPointerField(*backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); } @@ -614,6 +600,33 @@ NeuronInitGroupMerged::NeuronInitGroupMerged(size_t index, const Type::TypeConte addHeterogeneousVarInitDerivedParams( &NeuronGroupMergedBase::isVarInitDerivedParamHeterogeneous); + + // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group + orderNeuronGroupChildren(m_MergedInSynPSMGroups, typeContext, typeEnvironment, backend, + &NeuronGroupInternal::getFusedPSMInSyn, + &SynapseGroupInternal::getPSInitHashDigest ); + + // Build vector of vectors containing each child group's merged out syns with pre output, ordered to match those of the archetype group + orderNeuronGroupChildren(m_MergedOutSynPreOutputGroups, typeContext, typeEnvironment, backend, + &NeuronGroupInternal::getFusedPreOutputOutSyn, + &SynapseGroupInternal::getPreOutputInitHashDigest ); + + // Build vector of vectors containing each child group's current sources, ordered to match those of the archetype group + orderNeuronGroupChildren(m_MergedCurrentSourceGroups, 
typeContext, typeEnvironment, backend, + &NeuronGroupInternal::getCurrentSources, + &CurrentSourceInternal::getInitHashDigest ); + + // Build vector of vectors containing each child group's incoming synapse groups + // with postsynaptic weight update model variable, ordered to match those of the archetype group + orderNeuronGroupChildren(m_MergedInSynWUMPostVarGroups, typeContext, typeEnvironment, backend, + &NeuronGroupInternal::getFusedInSynWithPostVars, + &SynapseGroupInternal::getWUPostInitHashDigest); + + // Build vector of vectors containing each child group's outgoing synapse groups + // with presynaptic weight update model variables, ordered to match those of the archetype group + orderNeuronGroupChildren(m_MergedOutSynWUMPreVarGroups, typeContext, typeEnvironment, backend, + &NeuronGroupInternal::getFusedOutSynWithPreVars, + &SynapseGroupInternal::getWUPreInitHashDigest); } //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type NeuronInitGroupMerged::getHashDigest() const diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 2232d9a16e..d484b288f7 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -108,6 +108,7 @@ bool NeuronUpdateGroupMerged::CurrentSource::isDerivedParamHeterogeneous( const { return isParamValueHeterogeneous(paramName, [](const CurrentSourceInternal &cs) { return cs.getDerivedParams(); }); } + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynPSM //---------------------------------------------------------------------------- From 227271a98e8502372c0794d49759ad5a8414eb81 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 2 Jun 2023 18:13:03 +0100 Subject: [PATCH 196/725] started hooking up code generation --- .../genn/genn/code_generator/codeGenUtils.h | 11 + .../genn/genn/code_generator/environment.h | 23 +- .../groupMergedTypeEnvironment.h | 45 ++ .../genn/code_generator/initGroupMerged.h | 24 ++ .../code_generator/neuronUpdateGroupMerged.h | 68 +-- src/genn/genn/code_generator/codeGenUtils.cc | 24 ++ .../code_generator/customUpdateGroupMerged.cc | 22 +- src/genn/genn/code_generator/groupMerged.cc | 1 + .../genn/code_generator/initGroupMerged.cc | 46 +- .../code_generator/neuronUpdateGroupMerged.cc | 398 ++++++++---------- 10 files changed, 380 insertions(+), 282 deletions(-) diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 8363bd2a92..5fc2ae9f97 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -21,6 +21,10 @@ #include "substitutions.h" #include "teeStream.h" +// GeNN transpiler includes +#include "transpiler/statement.h" +#include "transpiler/typeChecker.h" + //-------------------------------------------------------------------------- // GeNN::CodeGenerator //-------------------------------------------------------------------------- @@ -93,6 +97,13 @@ GENN_EXPORT std::string disambiguateNamespaceFunction(const std::string supportC //-------------------------------------------------------------------------- GENN_EXPORT std::string upgradeCodeString(const std::string &codeString); +//-------------------------------------------------------------------------- +/*! 
\brief This function uses the transpiler to scan, parse and type check a code string + */ + //-------------------------------------------------------------------------- +GENN_EXPORT std::tuple scanParseAndTypeCheck( + const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment, Transpiler::ErrorHandlerBase &errorHandler); + //------------------------------------------------------------------------- /*! \brief Function for performing the code and value substitutions necessary to insert neuron related variables, parameters, and extraGlobal parameters into synaptic code. diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index b60d7ef220..2c782db246 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -173,11 +173,23 @@ class EnvironmentLocalVarCache : public EnvironmentExternal using InitialiserType = typename std::remove_reference_t>::mapped_type; //! Function used to provide index strings based on initialiser and access type - using GetIndexFn = std::function; + using GetIndexFn = std::function; public: + EnvironmentLocalVarCache(const G &group, const Type::TypeContext &context, EnvironmentExternal &enclosing, + GetIndexFn getReadIndex, GetIndexFn getWriteIndex, const std::string &localPrefix = "l") + : EnvironmentExternal(static_cast(enclosing)), m_Group(group), m_Context(context), m_Contents(m_ContentsStream), + m_LocalPrefix(localPrefix), m_GetReadIndex(getReadIndex), m_GetWriteIndex(getWriteIndex) + { + // Add name of each definition to map, initially with referenced flag set to false + const auto defs = A(m_Group).getDefs(); + std::transform(defs.cbegin(), defs.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()), + [](const auto &v){ return std::make_pair(v.name, false); }); + } + EnvironmentLocalVarCache(const G &group, const Type::TypeContext &context, EnvironmentExternal &enclosing, GetIndexFn getIndex, const std::string &localPrefix = "l") - : EnvironmentExternal(static_cast(enclosing)), m_Group(group), m_Context(context), m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetIndex(getIndex) + : EnvironmentExternal(static_cast(enclosing)), m_Group(group), m_Context(context), + m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetReadIndex(getIndex), m_GetWriteIndex(getIndex) { // Add name of each definition to map, initially with referenced flag set to false const auto defs = A(m_Group).getDefs(); @@ -209,7 +221,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal // **NOTE** by not initialising these variables for reductions, // compilers SHOULD emit a warning if user code doesn't set it to something if(!(v.access & VarAccessModeAttribute::REDUCE)) { - getContextStream() << " = group->" << v.name << "[" << m_GetIndex(initialisers.at(v.name), v.access) << "]"; + getContextStream() << " = group->" << v.name << "[" << m_GetReadIndex(v.name, initialisers.at(v.name), v.access) << "]"; } getContextStream() << ";" << std::endl; } @@ -221,7 +233,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal for(const auto &v : referencedVars) { // If variables are read-write if(v.access & VarAccessMode::READ_WRITE) { - getContextStream() << "group->" << v.name << "[" << m_GetIndex(initialisers.at(v.name), v.access) << "]"; + getContextStream() << "group->" << v.name << "[" << m_GetWriteIndex(v.name, initialisers.at(v.name), v.access) << "]"; getContextStream() << " = " <<
m_LocalPrefix << v.name << ";" << std::endl; } } @@ -261,7 +273,8 @@ class EnvironmentLocalVarCache : public EnvironmentExternal std::ostringstream m_ContentsStream; CodeStream m_Contents; std::string m_LocalPrefix; - GetIndexFn m_GetIndex; + GetIndexFn m_GetReadIndex; + GetIndexFn m_GetWriteIndex; std::unordered_map m_VariablesReferenced; }; } // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index 812b429f9b..72da9c1a33 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -24,6 +24,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa using TypeCheckError = Transpiler::TypeChecker::TypeCheckError; using IsHeterogeneousFn = bool (G::*)(const std::string&) const; + using IsVarInitHeterogeneousFn = bool (G::*)(const std::string&, const std::string&) const; using GroupInternal = typename G::GroupInternal; using GetVarSuffixFn = const std::string &(GroupInternal::*)(void) const; @@ -153,6 +154,50 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa } } + template + void defineHeterogeneousVarInitParams(IsVarInitHeterogeneousFn isHeterogeneous, const std::string &fieldSuffix = "") + { + // Loop through model variables + const A archetypeAdaptor(m_GroupMerged.getArchetype()); + for(const auto &v : archetypeAdaptor.getDefs()) { + // Loop through parameters + for(const auto &p : archetypeAdaptor.getInitialisers().at(v.name).getParams()) { + if(std::invoke(isHeterogeneous, m_GroupMerged, v.name, p.first)) { + defineScalarField(p.first, v.name + fieldSuffix, + [p, v](const auto &g, size_t) + { + return A(g).getInitialisers().at(v.name).getParams().at(p.first); + }); + } + else { + defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), p.first + v.name); + } + } + } + } + + template + void defineHeterogeneousVarInitDerivedParams(IsVarInitHeterogeneousFn isHeterogeneous, const std::string &fieldSuffix = "") + { + // Loop through model variables + const A archetypeAdaptor(m_GroupMerged.getArchetype()); + for(const auto &v : archetypeAdaptor.getDefs()) { + // Loop through parameters + for(const auto &p : archetypeAdaptor.getInitialisers().at(v.name).getDerivedParams()) { + if(std::invoke(isHeterogeneous, m_GroupMerged, v.name, p.first)) { + defineScalarField(p.first, v.name + fieldSuffix, + [p, v](const auto &g, size_t) + { + return A(g).getInitialisers().at(v.name).getDerivedParams().at(p.first); + }); + } + else { + defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), p.first + v.name); + } + } + } + } + void defineVars(const Models::Base::VarVec &vars, const std::string &arrayPrefix, const std::string &fieldSuffix = "", GetVarSuffixFn getVarSuffixFn = &GroupInternal::getName) { diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index 981898fb30..90da248205 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -42,6 +42,12 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- //! Is the var init parameter referenced?
bool isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const; + + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + //! Parsed statements and resolved types for initialising each variable + std::unordered_map> m_VarInitASTs; }; //---------------------------------------------------------------------------- @@ -75,6 +81,12 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- //! Is the var init parameter referenced? bool isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const; + + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + //! Parsed statements and resolved types for initialising each variable + std::unordered_map> m_VarInitASTs; }; //---------------------------------------------------------------------------- @@ -125,6 +137,12 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- //! Is the var init parameter referenced? bool isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const; + + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + //! Parsed statements and resolved types for initialising each variable + std::unordered_map> m_VarInitASTs; }; //---------------------------------------------------------------------------- @@ -158,6 +176,12 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- //! Is the var init parameter referenced? bool isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const; + + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + //! Parsed statements and resolved types for initialising each variable + std::unordered_map> m_VarInitASTs; }; NeuronInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index 0a076b7b6c..1cbe238caa 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -24,8 +24,8 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generate(const BackendBase &backend, EnvironmentExternal &env, + const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) const; //! 
Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -47,10 +47,10 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase // Members //---------------------------------------------------------------------------- //! List of statements parsed and type-checked in constructor; and used to generate code - Transpiler::Statement::StatementList m_UpdateStatements; + Transpiler::Statement::StatementList m_InjectionStatements; //! Resolved types used to generate code - Transpiler::TypeChecker::ResolvedTypeMap m_ResolvedTypes; + Transpiler::TypeChecker::ResolvedTypeMap m_InjectionResolvedTypes; }; //---------------------------------------------------------------------------- @@ -66,8 +66,8 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generate(const BackendBase &backend, EnvironmentExternal &env, + const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) const; //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -88,11 +88,17 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Members //---------------------------------------------------------------------------- - //! List of statements parsed and type-checked in constructor; and used to generate code - Transpiler::Statement::StatementList m_UpdateStatements; + //! List of statements parsed and type-checked in constructor; and used to generate decay code + Transpiler::Statement::StatementList m_DecayStatements; + + //! List of statements parsed and type-checked in constructor; and used to generate apply input code + Transpiler::Statement::StatementList m_ApplyInputStatements; - //! Resolved types used to generate code - Transpiler::TypeChecker::ResolvedTypeMap m_ResolvedTypes; + //! Resolved types used to generate decay code + Transpiler::TypeChecker::ResolvedTypeMap m_DecayResolvedTypes; + + //!
Resolved types used to generate apply input code + Transpiler::TypeChecker::ResolvedTypeMap m_ApplyInputResolvedTypes; }; //---------------------------------------------------------------------------- @@ -108,8 +114,8 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generate(EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged) const; }; //---------------------------------------------------------------------------- @@ -125,8 +131,8 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike) const; + void generate(const BackendBase &backend, EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) const; void genCopyDelayedVars(CodeStream &os, const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; @@ -150,11 +156,17 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Members //---------------------------------------------------------------------------- - //! List of statements parsed and type-checked in constructor; and used to generate code - Transpiler::Statement::StatementList m_UpdateStatements; + //! List of statements parsed and type-checked in constructor; and used to generate dynamics code + Transpiler::Statement::StatementList m_DynamicsStatements; - //! Resolved types used to generate code - Transpiler::TypeChecker::ResolvedTypeMap m_ResolvedTypes; + //! List of statements parsed and type-checked in constructor; and used to generate spike code + Transpiler::Statement::StatementList m_SpikeStatements; + + //! Resolved types used to generate dynamics code + Transpiler::TypeChecker::ResolvedTypeMap m_DynamicsResolvedTypes; + + //! Resolved types used to generate spike code + Transpiler::TypeChecker::ResolvedTypeMap m_SpikeResolvedTypes; }; //---------------------------------------------------------------------------- @@ -195,11 +207,17 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Members //---------------------------------------------------------------------------- - //! List of statements parsed and type-checked in constructor; and used to generate code - Transpiler::Statement::StatementList m_UpdateStatements; + //! List of statements parsed and type-checked in constructor; and used to generate dynamics code + Transpiler::Statement::StatementList m_DynamicsStatements; - //! Resolved types used to generate code - Transpiler::TypeChecker::ResolvedTypeMap m_ResolvedTypes; + //! 
List of statements parsed and type-checked in constructor; and used to generate spike code + Transpiler::Statement::StatementList m_SpikeStatements; + + //! Resolved types used to generate dynamics code + Transpiler::TypeChecker::ResolvedTypeMap m_DynamicsResolvedTypes; + + //! Resolved types used to generate spike code + Transpiler::TypeChecker::ResolvedTypeMap m_SpikeResolvedTypes; }; NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, @@ -220,9 +238,9 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase runnerVarDecl, runnerMergedStructAlloc, name); } - void generateNeuronUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs, - BackendBase::GroupHandler genEmitTrueSpike, - BackendBase::GroupHandler genEmitSpikeLikeEvent) const; + void generateNeuronUpdate(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged, + BackendBase::GroupHandlerEnv genEmitTrueSpike, + BackendBase::GroupHandlerEnv genEmitSpikeLikeEvent) const; void generateWUVarUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index 327e5e7725..de686efbaa 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -28,6 +28,10 @@ #include "code_generator/groupMerged.h" #include "code_generator/substitutions.h" +// GeNN transpiler includes +#include "transpiler/parser.h" +#include "transpiler/scanner.h" + //-------------------------------------------------------------------------- // Anonymous namespace //-------------------------------------------------------------------------- @@ -480,5 +484,25 @@ std::string upgradeCodeString(const std::string &codeString) return std::regex_replace(codeString, variable, "$1"); } +//---------------------------------------------------------------------------- +std::tuple scanParseAndTypeCheck( + const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment, Transpiler::ErrorHandlerBase &errorHandler) +{ + using namespace Transpiler; + + // Upgrade code string + const std::string upgradedCode = upgradeCodeString(code); + // Scan code string to convert to tokens + const auto tokens = Scanner::scanSource(upgradedCode, typeContext, errorHandler); + + // Parse tokens as block item list (function body) + auto updateStatements = Parser::parseBlockItemList(tokens, typeContext, errorHandler); + + // Resolve types + auto resolvedTypes = TypeChecker::typeCheck(updateStatements, environment, errorHandler); + + // Move into tuple and return + return std::make_tuple(std::move(updateStatements), std::move(resolvedTypes)); } } // namespace GeNN::CodeGenerator diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 140de1de16..a04bad3e84 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -72,10 +72,8 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const Type::TypeC // Scan, parse and type-check update code ErrorHandler errorHandler; - const std::string code = upgradeCodeString(cm->getUpdateCode()); - const auto tokens = Scanner::scanSource(code, typeContext, errorHandler); - m_UpdateStatements =
Parser::parseBlockItemList(tokens, typeContext, errorHandler); - m_ResolvedTypes = TypeChecker::typeCheck(m_UpdateStatements, typeEnvironment, errorHandler); + std::tie(m_UpdateStatements, m_ResolvedTypes) = scanParseAndTypeCheck(cm->getUpdateCode(), typeContext, + typeEnvironment, errorHandler); } //---------------------------------------------------------------------------- bool CustomUpdateGroupMerged::isParamHeterogeneous(const std::string ¶mName) const @@ -108,7 +106,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateGroupMerged::getHashDigest() //---------------------------------------------------------------------------- void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, EnvironmentExternal &env) const { - // Add parameters, derived parameters and EGPs to environment + // Add parameters, derived parameters and EGPs to environment EnvironmentSubstitute envSubs(env); const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); envSubs.addParamValueSubstitution(cm->getParamNames(), getArchetype().getParams(), @@ -120,7 +118,7 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varSubs( getArchetype(), getTypeContext(), envSubs, - [this](const Models::VarInit&, VarAccess a) + [this](const std::string&, const Models::VarInit&, VarAccess a) { return getVarIndex(getVarAccessDuplication(a), "id"); }); @@ -128,7 +126,7 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E // Create an environment which caches variable references in local variables if they are accessed EnvironmentLocalVarCache varRefSubs( getArchetype(), getTypeContext(), varSubs, - [this](const Models::VarReference &v, VarAccessMode) + [this](const std::string&, const Models::VarReference &v, VarAccessMode) { return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, getVarAccessDuplication(v.getVar().access), @@ -229,7 +227,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varSubs( getArchetype(), getTypeContext(), envSubs, - [this](const Models::VarInit&, VarAccess a) + [this](const std::string&, const Models::VarInit&, VarAccess a) { return getVarIndex(getVarAccessDuplication(a), "id_syn"); }); @@ -237,7 +235,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back // Create an environment which caches variable references in local variables if they are accessed EnvironmentLocalVarCache varRefSubs( getArchetype(), getTypeContext(), varSubs, - [this](const Models::WUVarReference &v, VarAccessMode) + [this](const std::string&, const Models::WUVarReference &v, VarAccessMode) { return getVarRefIndex(getVarAccessDuplication(v.getVar().access), "id_syn"); @@ -360,10 +358,8 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // Scan, parse and type-check update code ErrorHandler errorHandler; - const std::string code = upgradeCodeString(cm->getUpdateCode()); - const auto tokens = Scanner::scanSource(code, typeContext, errorHandler); - m_UpdateStatements = Parser::parseBlockItemList(tokens, typeContext, errorHandler); - m_ResolvedTypes = TypeChecker::typeCheck(m_UpdateStatements, typeEnvironment, errorHandler); + std::tie(m_UpdateStatements, m_ResolvedTypes) = scanParseAndTypeCheck(cm->getUpdateCode(), 
typeContext, + typeEnvironment, errorHandler); } // ---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index 67bcd89536..0454d950a8 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -147,6 +147,7 @@ NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeConte addPointerField(Uint32, "spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); } + // **TODO** add to type environment for update if(getArchetype().isSpikeTimeRequired()) { addPointerField(getTimeType(), "sT", backend.getDeviceVarPrefix() + "sT"); } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 62c0192774..6e33528fc7 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -290,26 +290,34 @@ NeuronInitGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContext // **TODO** adaptor const auto &varInit = getArchetype().getPSVarInitialisers(); for(const auto &var : getArchetype().getPSModel()->getVars()) { - // Add pointers to state variable - if(!varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(var.type.resolve(getTypeContext()).createPointer(), var.name + suffix, - [&backend, var](const auto &g, size_t) { return backend.getDeviceVarPrefix() + var.name + g.getFusedPSVarSuffix(); }); - } - - // Add heterogeneous var init parameters - addHeterogeneousVarInitParams( - &InSynPSM::isVarInitParamHeterogeneous, suffix); - addHeterogeneousVarInitDerivedParams( - &InSynPSM::isVarInitDerivedParamHeterogeneous, suffix); + // If there is any initialisation code + const auto *snippet = varInit.at(var.name).getSnippet(); + if (!snippet->getCode().empty()) { + // Create type environment for this variable's initialisation + GroupMergedTypeEnvironment typeEnvironment(*this, &enclosingEnv); + + // Add pointers to state variable itself + typeEnvironment.definePointerField(var.type, var.name, backend.getDeviceVarPrefix(), + getVarAccessMode(var.access), suffix, &SynapseGroupInternal::getFusedPSVarSuffix); + + // Add heterogeneous var init parameters + typeEnvironment.defineHeterogeneousVarInitParams(&InSynPSM::isVarInitParamHeterogeneous, suffix); + typeEnvironment.defineHeterogeneousVarInitDerivedParams(&InSynPSM::isVarInitDerivedParamHeterogeneous, suffix); + + // Add EGPs + typeEnvironment.defineEGPs(snippet->getExtraGlobalParams(), backend.getDeviceVarPrefix(), + var.name, suffix, &SynapseGroupInternal::getFusedPSVarSuffix); + + // Scan, parse and type-check variable initialisation code + ErrorHandler errorHandler; + const std::string code = upgradeCodeString(snippet->getCode()); + const auto tokens = Scanner::scanSource(code, typeContext, errorHandler); + + auto initStatements = Parser::parseBlockItemList(tokens, typeContext, errorHandler); + auto initTypes = TypeChecker::typeCheck(initStatements, typeEnvironment, errorHandler); - // Add extra global parameters - for(const auto &e : varInit.at(var.name).getSnippet()->getExtraGlobalParams()) { - addField(e.type.resolve(getTypeContext()).createPointer(), e.name + var.name + suffix, - [&backend, e, suffix, var](const auto &g, size_t) - { - return backend.getDeviceVarPrefix() + e.name + var.name + g.getFusedPSVarSuffix(); - }, - GroupMergedFieldType::DYNAMIC); + // Add to map of per-variable initialisation ASTs + m_VarInitASTs.emplace(var.name,
std::make_tuple(std::move(initStatements), std::move(initTypes))); } } } diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index d484b288f7..729089a771 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -48,47 +48,40 @@ NeuronUpdateGroupMerged::CurrentSource::CurrentSource(size_t index, const Type:: // Add EGPs typeEnvironment.defineEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", suffix); + + // Scan, parse and type-check injection code + ErrorHandler errorHandler; + std::tie(m_InjectionStatements, m_InjectionResolvedTypes) = scanParseAndTypeCheck(cm->getInjectionCode(), typeContext, + typeEnvironment, errorHandler); } //---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend, EnvironmentExternal &env, + const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) const { - os << "// current source " << getIndex() << std::endl; - - // Read current source variables into registers const std::string suffix = "CS" + std::to_string(getIndex()); - for(const auto &v : getArchetype().getCurrentSourceModel()->getVars()) { - if(v.access & VarAccessMode::READ_ONLY) { - os << "const "; - } - os << v.type.resolve(getTypeContext()).getName() << " lcs" << v.name << " = " << "group->" << v.name << suffix << "["; - os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), popSubs["id"]) << "];" << std::endl; - } + const auto *cm = getArchetype().getCurrentSourceModel(); - Substitutions currSourceSubs(&popSubs); - currSourceSubs.addFuncSubstitution("injectCurrent", 1, "Isyn += $(0)"); - currSourceSubs.addVarNameSubstitution(getArchetype().getCurrentSourceModel()->getVars(), "", "lcs"); - currSourceSubs.addParamValueSubstitution(getArchetype().getCurrentSourceModel()->getParamNames(), getArchetype().getParams(), - [this](const std::string &p) { return isParamHeterogeneous(p); }, - "", "group->", suffix); - currSourceSubs.addVarValueSubstitution(getArchetype().getCurrentSourceModel()->getDerivedParams(), getArchetype().getDerivedParams(), - [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }, - "", "group->", suffix); - currSourceSubs.addVarNameSubstitution(getArchetype().getCurrentSourceModel()->getExtraGlobalParams(), "", "group->", suffix); - - std::string iCode = getArchetype().getCurrentSourceModel()->getInjectionCode(); - currSourceSubs.applyCheckUnreplaced(iCode, "injectionCode : merged" + getIndex()); - //iCode = ensureFtype(iCode, model.getPrecision()); - os << iCode << std::endl; - - // Write read/write variables back to global memory - for(const auto &v : getArchetype().getCurrentSourceModel()->getVars()) { - if(v.access & VarAccessMode::READ_WRITE) { - os << "group->" << v.name << suffix << "["; - os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), currSourceSubs["id"]); - os << "] = lcs" << v.name << ";" << std::endl; - } - } + // Create new substitution environment and add parameters, derived parameters and extra global parameters + EnvironmentSubstitute envSubs(env); + envSubs.getStream() << "// current source " 
<< getIndex() << std::endl; + envSubs.addParamValueSubstitution(cm->getParamNames(), getArchetype().getParams(), + [this](const std::string &p) { return isParamHeterogeneous(p); }); + envSubs.addVarValueSubstitution(cm->getDerivedParams(), getArchetype().getDerivedParams(), + [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); + envSubs.addVarNameSubstitution(cm->getExtraGlobalParams()); + + // Create an environment which caches variables in local variables if they are accessed + EnvironmentLocalVarCache varSubs( + getArchetype(), getTypeContext(), envSubs, + [&modelMerged, &ng](const std::string&, const Models::VarInit&, VarAccess a) + { + return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), "id"); + }); + + //currSourceSubs.addFuncSubstitution("injectCurrent", 1, "Isyn += $(0)"); + + // Pretty print previously parsed update statements + PrettyPrinter::print(m_InjectionStatements, varSubs, getTypeContext(), m_InjectionResolvedTypes); } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::CurrentSource::updateHash(boost::uuids::detail::sha1 &hash) const @@ -152,98 +145,73 @@ NeuronUpdateGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContex // Add EGPs typeEnvironment.defineEGPs(psm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", suffix, &SynapseGroupInternal::getFusedPSVarSuffix); + + // Scan, parse and type-check decay and apply input code + ErrorHandler errorHandler; + std::tie(m_DecayStatements, m_DecayResolvedTypes) = scanParseAndTypeCheck(psm->getDecayCode(), typeContext, + typeEnvironment, errorHandler); + std::tie(m_ApplyInputStatements, m_ApplyInputResolvedTypes) = scanParseAndTypeCheck(psm->getApplyInputCode(), typeContext, + typeEnvironment, errorHandler); } //---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, EnvironmentExternal &env, + const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) const { const std::string suffix = "InSyn" + std::to_string(getIndex()); const auto *psm = getArchetype().getPSModel(); - os << "// pull inSyn values in a coalesced access" << std::endl; - os << "scalar linSyn = group->inSynInSyn" << getIndex() << "["; - os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, popSubs["id"]); - os << "];" << std::endl; + // Create new substitution environment + EnvironmentSubstitute envSubs(env); + + envSubs.getStream() << "// postsynaptic model " << getIndex() << std::endl; + envSubs.getStream() << "scalar linSyn = group->inSynInSyn" << getIndex() << "["; + envSubs.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"); + envSubs.getStream() << "];" << std::endl; // If dendritic delay is required if (getArchetype().isDendriticDelayRequired()) { // Get reference to dendritic delay buffer input for this timestep - os << backend.getPointerPrefix() << "scalar *denDelayFront = "; - os << "&group->denDelay" << suffix << "[(*group->denDelayPtr" << suffix << " * group->numNeurons) + "; - os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, popSubs["id"]); - os << "];" << std::endl; +
envSubs.getStream() << backend.getPointerPrefix() << "scalar *denDelayFront = "; + envSubs.getStream() << "&group->denDelay" << suffix << "[(*group->denDelayPtr" << suffix << " * group->numNeurons) + "; + envSubs.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"); + envSubs.getStream() << "];" << std::endl; // Add delayed input from buffer into inSyn - os << "linSyn += *denDelayFront;" << std::endl; + envSubs.getStream() << "linSyn += *denDelayFront;" << std::endl; // Zero delay buffer slot - os << "*denDelayFront = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; + envSubs.getStream() << "*denDelayFront = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; } - // Pull postsynaptic model variables in a coalesced access - for (const auto &v : psm->getVars()) { - if(v.access & VarAccessMode::READ_ONLY) { - os << "const "; - } - os << v.type.resolve(getTypeContext()).getName() << " lps" << v.name << " = group->" << v.name << suffix << "["; - os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), popSubs["id"]); - os << "];" << std::endl; - } - - Substitutions inSynSubs(&popSubs); - inSynSubs.addVarSubstitution("inSyn", "linSyn"); + // Add parameters, derived parameters and extra global parameters to environment + envSubs.addParamValueSubstitution(psm->getParamNames(), getArchetype().getPSParams(), + [this](const std::string &p) { return isParamHeterogeneous(p); }); + envSubs.addVarValueSubstitution(psm->getDerivedParams(), getArchetype().getPSDerivedParams(), + [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); + envSubs.addVarNameSubstitution(psm->getExtraGlobalParams()); + + // **TODO** naming convention + envSubs.addSubstitution("inSyn", "linSyn"); // Allow synapse group's PS output var to override what Isyn points to - inSynSubs.addVarSubstitution("Isyn", getArchetype().getPSTargetVar(), true); - inSynSubs.addVarNameSubstitution(psm->getVars(), "", "lps"); - - inSynSubs.addParamValueSubstitution(psm->getParamNames(), getArchetype().getPSParams(), - [this](const std::string &p) { return isParamHeterogeneous(p); }, - "", "group->", suffix); - inSynSubs.addVarValueSubstitution(psm->getDerivedParams(), getArchetype().getPSDerivedParams(), - [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }, - "", "group->", suffix); - inSynSubs.addVarNameSubstitution(psm->getExtraGlobalParams(), "", "group->", suffix); - - // Apply substitutions to current converter code - std::string psCode = psm->getApplyInputCode(); - inSynSubs.applyCheckUnreplaced(psCode, "postSyntoCurrent : merged " + getIndex()); - //psCode = ensureFtype(psCode, model.getPrecision()); - - // Apply substitutions to decay code - std::string pdCode = psm->getDecayCode(); - inSynSubs.applyCheckUnreplaced(pdCode, "decayCode : merged " + getIndex()); - //pdCode = ensureFtype(pdCode, model.getPrecision()); - - if (!psm->getSupportCode().empty() && backend.supportsNamespace()) { - os << "using namespace " << modelMerged.getPostsynapticDynamicsSupportCodeNamespace(psm->getSupportCode()) << ";" << std::endl; - } - - if (!psm->getSupportCode().empty() && !backend.supportsNamespace()) { - psCode = disambiguateNamespaceFunction(psm->getSupportCode(), psCode, modelMerged.getPostsynapticDynamicsSupportCodeNamespace(psm->getSupportCode())); - pdCode = disambiguateNamespaceFunction(psm->getSupportCode(), pdCode, modelMerged.getPostsynapticDynamicsSupportCodeNamespace(psm->getSupportCode())); - } + 
envSubs.addSubstitution("Isyn", getArchetype().getPSTargetVar()); - os << psCode << std::endl; - os << pdCode << std::endl; + // Create an environment which caches variables in local variables if they are accessed + EnvironmentLocalVarCache varSubs( + getArchetype(), getTypeContext(), envSubs, + [&modelMerged, &ng](const std::string&, const Models::VarInit&, VarAccess a) + { + return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), "id"); + }); - if (!psm->getSupportCode().empty()) { - os << CodeStream::CB(29) << " // namespace bracket closed" << std::endl; - } + // Pretty print previously parsed update statements + PrettyPrinter::print(m_ApplyInputStatements, varSubs, getTypeContext(), m_ApplyInputResolvedTypes); + PrettyPrinter::print(m_DecayStatements, varSubs, getTypeContext(), m_DecayResolvedTypes); // Write back linSyn - os << "group->inSyn" << suffix << "["; - os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, inSynSubs["id"]); - os << "] = linSyn;" << std::endl; - - // Copy any non-readonly postsynaptic model variables back to global state variables dd_V etc - for (const auto &v : psm->getVars()) { - if(v.access & VarAccessMode::READ_WRITE) { - os << "group->" << v.name << suffix << "["; - os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), inSynSubs["id"]); - os << "]" << " = lps" << v.name << ";" << std::endl; - } - } + varSubs.getStream() << "group->inSyn" << suffix << "["; + varSubs.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"); + varSubs.getStream() << "] = linSyn;" << std::endl; } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &hash) const @@ -277,18 +245,18 @@ NeuronUpdateGroupMerged::OutSynPreOutput::OutSynPreOutput(size_t index, const Ty [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "revInSyn" + g.getFusedPreOutputSuffix(); }); } //---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void NeuronUpdateGroupMerged::OutSynPreOutput::generate(EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged) const { const std::string suffix = "OutSyn" + std::to_string(getIndex()); - os << getArchetype().getPreTargetVar() << " += "; - os << "group->revInSyn" << suffix << "["; - os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, popSubs["id"]); - os << "];" << std::endl; - os << "group->revInSyn" << suffix << "["; - os << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, popSubs["id"]); - os << "] = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; + env.getStream() << getArchetype().getPreTargetVar() << " += "; + env.getStream() << "group->revInSyn" << suffix << "["; + env.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"); + env.getStream() << "];" << std::endl; + env.getStream() << "group->revInSyn" << suffix << "["; + env.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"); + env.getStream() << "] = " << 
modelMerged.scalarExpr(0.0) << ";" << std::endl; } //---------------------------------------------------------------------------- @@ -321,42 +289,50 @@ NeuronUpdateGroupMerged::InSynWUMPostCode::InSynWUMPostCode(size_t index, const // Add EGPs typeEnvironment.defineEGPs(wum->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", suffix, &SynapseGroupInternal::getFusedWUPostVarSuffix); + + // Scan, parse and type-check dynamics and spike code + ErrorHandler errorHandler; + std::tie(m_DynamicsStatements, m_DynamicsResolvedTypes) = scanParseAndTypeCheck(wum->getPostDynamicsCode(), typeContext, + typeEnvironment, errorHandler); + std::tie(m_SpikeStatements, m_SpikeResolvedTypes) = scanParseAndTypeCheck(wum->getPostSpikeCode(), typeContext, + typeEnvironment, errorHandler); } //---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike) const +void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &backend, EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) const { const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); - - const unsigned int batchSize = modelMerged.getModel().getBatchSize(); + const auto *wum = getArchetype().getWUModel(); - // If this code string isn't empty - std::string code = dynamicsNotSpike ? getArchetype().getWUModel()->getPostDynamicsCode() : getArchetype().getWUModel()->getPostSpikeCode(); - if(!code.empty()) { - Substitutions subs(&popSubs); + const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - // Fetch postsynaptic variables from global memory - os << "// perform WUM update required for merged" << getIndex() << std::endl; - const auto vars = getArchetype().getWUModel()->getPostVars(); + // If there are any statements to executre here + const auto &statements = dynamicsNotSpike ? m_DynamicsStatements : m_SpikeStatements; + const auto &resolvedTypes = dynamicsNotSpike ? 
m_DynamicsResolvedTypes : m_SpikeResolvedTypes; + if(!statements.empty()) { + // Create new substitution environment and add parameters, derived parameters and extra global parameters + EnvironmentSubstitute envSubs(env); + envSubs.getStream() << "// postsynaptic weight update " << getIndex() << std::endl; + envSubs.addParamValueSubstitution(wum->getParamNames(), getArchetype().getWUParams(), + [this](const std::string &p) { return isParamHeterogeneous(p); }); + envSubs.addVarValueSubstitution(wum->getDerivedParams(), getArchetype().getWUDerivedParams(), + [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); + envSubs.addVarNameSubstitution(wum->getExtraGlobalParams()); + + // Create an environment which caches variables in local variables if they are accessed const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); - for(const auto &v : vars) { - if(v.access & VarAccessMode::READ_ONLY) { - os << "const "; - } - os << v.type.resolve(getTypeContext()).getName() << " l" << v.name << " = group->" << v.name << suffix << "["; - os << ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), subs["id"]) << "];" << std::endl; - } - - subs.addParamValueSubstitution(getArchetype().getWUModel()->getParamNames(), getArchetype().getWUParams(), - [this](const std::string &p) { return isParamHeterogeneous(p); }, - "", "group->", suffix); - subs.addVarValueSubstitution(getArchetype().getWUModel()->getDerivedParams(), getArchetype().getWUDerivedParams(), - [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }, - "", "group->", suffix); - subs.addVarNameSubstitution(getArchetype().getWUModel()->getExtraGlobalParams(), "", "group->", suffix); - subs.addVarNameSubstitution(vars, "", "l"); - - neuronSubstitutionsInSynapticCode(subs, &ng.getArchetype(), "", "_post", "", "", "", dynamicsNotSpike, + EnvironmentLocalVarCache varSubs( + getArchetype(), getTypeContext(), envSubs, + [batchSize, delayed, &ng](const std::string&, const Models::VarInit&, VarAccess a) + { + return ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), "id"); + }, + [batchSize, delayed, &ng](const std::string&, const Models::VarInit&, VarAccess a) + { + return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), "id"); + }); + + /*neuronSubstitutionsInSynapticCode(subs, &ng.getArchetype(), "", "_post", "", "", "", dynamicsNotSpike, [&ng](const std::string &p) { return ng.isParamHeterogeneous(p); }, [&ng](const std::string &p) { return ng.isDerivedParamHeterogeneous(p); }, [&subs, &ng, batchSize](bool delay, VarAccessDuplication varDuplication) @@ -366,23 +342,10 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back [&subs, &ng, batchSize](bool delay, VarAccessDuplication varDuplication) { return ng.getReadVarIndex(delay, batchSize, varDuplication, subs["id"]); - }); + });*/ - // Perform standard substitutions - subs.applyCheckUnreplaced(code, "spikeCode : merged" + getIndex()); - //code = ensureFtype(code, precision); - os << code; - - // Write back postsynaptic variables into global memory - for(const auto &v : vars) { - // If state variables is read/write - meaning that it may have been updated - or it is delayed - - // meaning that it needs to be copied into next delay slot whatever - copy neuron state variables - // back to global state variables dd_V etc - if((v.access & VarAccessMode::READ_WRITE) || delayed) { - os << "group->" << v.name << suffix << "["; - os << ng.getWriteVarIndex(delayed, batchSize, 
getVarAccessDuplication(v.access), subs["id"]) << "] = l" << v.name << ";" << std::endl; - } - } + // Pretty print previously parsed statements + PrettyPrinter::print(statements, varSubs, getTypeContext(), resolvedTypes); } } //---------------------------------------------------------------------------- @@ -455,6 +418,13 @@ NeuronUpdateGroupMerged::OutSynWUMPreCode::OutSynWUMPreCode(size_t index, const // Add EGPs typeEnvironment.defineEGPs(wum->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", suffix, &SynapseGroupInternal::getFusedWUPreVarSuffix); + + // Scan, parse and type-check dynamics and spike code + ErrorHandler errorHandler; + std::tie(m_DynamicsStatements, m_DynamicsResolvedTypes) = scanParseAndTypeCheck(wum->getPreDynamicsCode(), typeContext, + typeEnvironment, errorHandler); + std::tie(m_SpikeStatements, m_SpikeResolvedTypes) = scanParseAndTypeCheck(wum->getPreSpikeCode(), typeContext, + typeEnvironment, errorHandler); } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, @@ -726,42 +696,29 @@ boost::uuids::detail::sha1::digest_type NeuronUpdateGroupMerged::getHashDigest() return hash.get_digest(); } //-------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs, - BackendBase::GroupHandler genEmitTrueSpike, - BackendBase::GroupHandler genEmitSpikeLikeEvent) const +void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged, + BackendBase::GroupHandlerEnv genEmitTrueSpike, + BackendBase::GroupHandlerEnv genEmitSpikeLikeEvent) const { const ModelSpecInternal &model = modelMerged.getModel(); const unsigned int batchSize = model.getBatchSize(); const NeuronModels::Base *nm = getArchetype().getNeuronModel(); + - // Generate code to copy neuron state into local variable - for(const auto &v : nm->getVars()) { - if(v.access & VarAccessMode::READ_ONLY) { - os << "const "; - } - os << v.type.resolve(getTypeContext()).getName() << " l" << v.name << " = group->" << v.name << "["; - const bool delayed = (getArchetype().isVarQueueRequired(v.name) && getArchetype().isDelayRequired()); - os << getReadVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), popSubs["id"]) << "];" << std::endl; - } - - // Also read spike and spike-like-event times into local variables if required - if(getArchetype().isSpikeTimeRequired()) { - os << "const timepoint lsT = group->sT["; - os << getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; - } - if(getArchetype().isPrevSpikeTimeRequired()) { - os << "const timepoint lprevST = group->prevST["; - os << getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; - } - if(getArchetype().isSpikeEventTimeRequired()) { - os << "const timepoint lseT = group->seT["; - os << getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; - } - if(getArchetype().isPrevSpikeEventTimeRequired()) { - os << "const timepoint lprevSET = group->prevSET["; - os << getReadVarIndex(getArchetype().isDelayRequired(), batchSize, 
VarAccessDuplication::DUPLICATE, popSubs["id"]) << "];" << std::endl; - } - os << std::endl; + // Create an environment which caches variables in local variables if they are accessed + // **NOTE** we do this right at the top so that local copies can be used by child groups + EnvironmentLocalVarCache neuronVarEnv( + getArchetype(), getTypeContext(), env, + [batchSize, this](const std::string &varName, const Models::VarInit&, VarAccess a) + { + const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); + return getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), "id") ; + }, + [batchSize, this](const std::string &varName, const Models::VarInit&, VarAccess a) + { + const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); + return getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), "id") ; + }); // If neuron model sim code references ISyn (could still be the case if there are no incoming synapses) // OR any incoming synapse groups have post synaptic models which reference $(Isyn), declare it @@ -779,18 +736,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C Substitutions neuronSubs(&popSubs); neuronSubs.addVarSubstitution("Isyn", "Isyn"); - if(getArchetype().isSpikeTimeRequired()) { - neuronSubs.addVarSubstitution("sT", "lsT"); - } - if(getArchetype().isPrevSpikeTimeRequired()) { - neuronSubs.addVarSubstitution("prev_sT", "lprevST"); - } - if(getArchetype().isSpikeEventTimeRequired()) { - neuronSubs.addVarSubstitution("seT", "lseT"); - } - if(getArchetype().isPrevSpikeEventTimeRequired()) { - neuronSubs.addVarSubstitution("prev_seT", "lprevSET"); - } + neuronSubs.addVarNameSubstitution(nm->getAdditionalInputVars()); addNeuronModelSubstitutions(neuronSubs); @@ -822,9 +768,36 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C cs.generate(backend, os, *this, modelMerged, popSubs); } - if (!nm->getSupportCode().empty() && backend.supportsNamespace()) { - os << "using namespace " << modelMerged.getNeuronUpdateSupportCodeNamespace(nm->getSupportCode()) << ";" << std::endl; + // Read spike and spike-like-event times into local variables if required + EnvironmentSubstitute neuronEnv(neuronVarEnv); + if(getArchetype().isSpikeTimeRequired()) { + const size_t initialiser = neuronEnv.addInitialiser( + "const timepoint lsT = group->sT[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "id") + "];"); + neuronEnv.addSubstitution("sT", "lsT", {initialiser}); + } + if(getArchetype().isPrevSpikeTimeRequired()) { + const size_t initialiser = neuronEnv.addInitialiser( + "const timepoint lprevST = group->prevST[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "id") + "];"); + neuronEnv.addSubstitution("prev_sT", "lprevST", {initialiser}); + } + if(getArchetype().isSpikeEventTimeRequired()) { + const size_t initialiser = neuronEnv.addInitialiser( + "const timepoint lseT = group->seT[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "id") + "];"); + neuronEnv.addSubstitution("seT", "lseT", {initialiser}); + } + if(getArchetype().isPrevSpikeEventTimeRequired()) { + const size_t initialiser = neuronEnv.addInitialiser( + "const timepoint lprevSET = group->prevSET[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "id") << "];"); + 
neuronEnv.addSubstitution("prev_seT", "lprevSET", {initialiser}); } + neuronEnv.getStream() << std::endl; + + // Add neuron parameters, derived parameters and extra global parameters to neuron environment + neuronEnv.addParamValueSubstitution(nm->getParamNames(), getArchetype().getParams(), + [this](const std::string &p) { return isParamHeterogeneous(p); }); + neuronEnv.addVarValueSubstitution(nm->getDerivedParams(), getArchetype().getDerivedParams(), + [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); + neuronEnv.addVarNameSubstitution(nm->getExtraGlobalParams()); // If a threshold condition is provided std::string thCode = nm->getThresholdConditionCode(); @@ -854,9 +827,6 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C neuronSubs.applyCheckUnreplaced(sCode, "simCode : merged" + std::to_string(getIndex())); //sCode = ensureFtype(sCode, model.getPrecision()); - if (!nm->getSupportCode().empty() && !backend.supportsNamespace()) { - sCode = disambiguateNamespaceFunction(nm->getSupportCode(), sCode, modelMerged.getNeuronUpdateSupportCodeNamespace(nm->getSupportCode())); - } os << sCode << std::endl; @@ -1020,18 +990,6 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, C } } } - - // Loop through neuron state variables - for(const auto &v : nm->getVars()) { - // If state variables is read/writes - meaning that it may have been updated - or it is delayed - - // meaning that it needs to be copied into next delay slot whatever - copy neuron state variables - // back to global state variables dd_V etc - const bool delayed = (getArchetype().isVarQueueRequired(v.name) && getArchetype().isDelayRequired()); - if((v.access & VarAccessMode::READ_WRITE) || delayed) { - os << "group->" << v.name << "["; - os << getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), popSubs["id"]) << "] = l" << v.name << ";" << std::endl; - } - } } //-------------------------------------------------------------------------- void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const From c8213bfe0654c50d08d921afa4cfc6e36b4e674e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 5 Jun 2023 17:13:04 +0100 Subject: [PATCH 197/725] exposed functionality to type check and pretty print expressions as well as statement lists --- .../genn/genn/code_generator/codeGenUtils.h | 12 ++++++++-- include/genn/genn/transpiler/prettyPrinter.h | 11 +++++++++ include/genn/genn/transpiler/typeChecker.h | 4 ++-- src/genn/genn/code_generator/codeGenUtils.cc | 23 ++++++++++++++++++- src/genn/genn/transpiler/prettyPrinter.cc | 14 +++++++++++ src/genn/genn/transpiler/typeChecker.cc | 4 ++-- 6 files changed, 61 insertions(+), 7 deletions(-) diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 5fc2ae9f97..4f693e2419 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -98,12 +98,20 @@ GENN_EXPORT std::string disambiguateNamespaceFunction(const std::string supportC GENN_EXPORT std::string upgradeCodeString(const std::string &codeString); //-------------------------------------------------------------------------- -/*! \brief This function uses the transpiler to scan, parse and type check a code string +/*! 
\brief This function uses the transpiler to scan, parse and type check statements contained in a code string */ //-------------------------------------------------------------------------- -GENN_EXPORT std::tuple scanParseAndTypeCheck( +GENN_EXPORT std::tuple scanParseAndTypeCheckStatements( const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment, Transpiler::ErrorHandlerBase &errorHandler); +//-------------------------------------------------------------------------- +/*! \brief This function uses the transpiler to scan, parse and type check expression contained in a code string + */ + //-------------------------------------------------------------------------- +GENN_EXPORT std::tuple scanParseAndTypeCheckExpression( + const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment, Transpiler::ErrorHandlerBase &errorHandler); + + //------------------------------------------------------------------------- /*! \brief Function for performing the code and value substitutions necessary to insert neuron related variables, parameters, and extraGlobal parameters into synaptic code. diff --git a/include/genn/genn/transpiler/prettyPrinter.h b/include/genn/genn/transpiler/prettyPrinter.h index 34fbda05f5..03dae95df1 100644 --- a/include/genn/genn/transpiler/prettyPrinter.h +++ b/include/genn/genn/transpiler/prettyPrinter.h @@ -35,6 +35,15 @@ class EnvironmentBase //! Get stream to write code within this environment to virtual CodeGenerator::CodeStream &getStream() = 0; + + //------------------------------------------------------------------------ + // Operators + //------------------------------------------------------------------------ + std::string operator[] (const std::string &name) + { + return getName(name); + } + }; //--------------------------------------------------------------------------- @@ -42,4 +51,6 @@ class EnvironmentBase //--------------------------------------------------------------------------- void print(const Statement::StatementList &statements, EnvironmentBase &environment, const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes); +void print(const Expression::ExpressionPtr &expression, EnvironmentBase &environment, + const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes); } diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index d0cf01b92a..c073641d35 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -58,6 +58,6 @@ class EnvironmentBase ResolvedTypeMap typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment, ErrorHandlerBase &errorHandler); -Type::ResolvedType typeCheck(const Expression::Base *expression, EnvironmentBase &environment, - ErrorHandlerBase &errorHandler); +ResolvedTypeMap typeCheck(const Expression::Base *expression, EnvironmentBase &environment, + ErrorHandlerBase &errorHandler); } // namespace MiniParse::GeNN::Transpiler diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index de686efbaa..bcd7636c90 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -485,7 +485,7 @@ std::string upgradeCodeString(const std::string &codeString) return std::regex_replace(codeString, variable, "$1"); } 
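//----------------------------------------------------------------------------
// Editorial sketch (not part of the patch): how the two entry points added by
// this commit are intended to fit together. `sketchExpressionPath` and its
// parameter names are hypothetical stand-ins; only the signatures visible in
// this diff are assumed.
void sketchExpressionPath(const std::string &thresholdCode,
                          const Type::TypeContext &typeContext,
                          Transpiler::TypeChecker::EnvironmentBase &typeEnvironment,
                          Transpiler::ErrorHandlerBase &errorHandler,
                          Transpiler::PrettyPrinter::EnvironmentBase &printEnvironment)
{
    // Scan, parse and type-check the expression once, typically in a
    // group-merged class constructor, caching the AST and resolved types
    auto [expression, resolvedTypes] = scanParseAndTypeCheckExpression(
        thresholdCode, typeContext, typeEnvironment, errorHandler);

    // Pretty-print the cached expression later, against whatever environment
    // generate() has built up
    Transpiler::PrettyPrinter::print(expression, printEnvironment,
                                     typeContext, resolvedTypes);
}
//----------------------------------------------------------------------------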
//---------------------------------------------------------------------------- -std::tuple scanParseAndTypeCheck( +std::tuple scanParseAndTypeCheckStatements( const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment, Transpiler::ErrorHandlerBase &errorHandler) { using namespace Transpiler; @@ -505,4 +505,25 @@ std::tuple scanParseAndTypeCheckExpression( + const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment, Transpiler::ErrorHandlerBase &errorHandler) +{ + using namespace Transpiler; + + // Upgrade code string + const std::string upgradedCode = upgradeCodeString(code); + + // Scan code string to convert to tokens + const auto tokens = Scanner::scanSource(upgradedCode, typeContext, errorHandler); + + // Parse tokens as expression + auto expression = Parser::parseExpression(tokens, typeContext, errorHandler); + + // Resolve types + auto resolvedTypes = TypeChecker::typeCheck(expression.get(), environment, errorHandler); + + // Move into tuple and return + return std::make_tuple(std::move(expression), std::move(resolvedTypes)); +} } // namespace GeNN::CodeGenerator diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index bae2653029..2f0a4fa597 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -132,6 +132,13 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } } + Visitor(const Expression::ExpressionPtr &expression, EnvironmentInternal &environment, + const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes) + : m_Environment(environment), m_Context(context), m_ResolvedTypes(resolvedTypes) + { + expression.get()->accept(*this); + } + private: //--------------------------------------------------------------------------- // Expression::Visitor virtuals //--------------------------------------------------------------------------- @@ -468,3 +475,10 @@ void GeNN::Transpiler::PrettyPrinter::print(const Statement::StatementList &stat EnvironmentInternal internalEnvironment(environment); Visitor visitor(statements, internalEnvironment, context, resolvedTypes); } +//--------------------------------------------------------------------------- +void GeNN::Transpiler::PrettyPrinter::print(const Expression::ExpressionPtr &expression, EnvironmentBase &environment, + const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes) +{ + EnvironmentInternal internalEnvironment(environment); + Visitor visitor(expression, internalEnvironment, context, resolvedTypes); +} \ No newline at end of file diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index b295737654..e83ca1f5e8 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -864,11 +864,11 @@ ResolvedTypeMap GeNN::Transpiler::TypeChecker::typeCheck(const Statement::Statem return expressionTypes; } //--------------------------------------------------------------------------- -Type::ResolvedType GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, EnvironmentBase &environment, +ResolvedTypeMap GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, EnvironmentBase &environment, ErrorHandlerBase &errorHandler) { ResolvedTypeMap expressionTypes; EnvironmentInternal internalEnvironment(environment); Visitor visitor(expression, internalEnvironment, expressionTypes, errorHandler); - return
expressionTypes.at(expression); + return expressionTypes; } From 11f32ffdb96401679567375bb5f72e996d136505 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 5 Jun 2023 17:14:45 +0100 Subject: [PATCH 198/725] approximately completed neuron update group --- .../code_generator/neuronUpdateGroupMerged.h | 37 +- .../code_generator/neuronUpdateGroupMerged.cc | 391 ++++++++---------- 2 files changed, 200 insertions(+), 228 deletions(-) diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index 1cbe238caa..5451aa3f5c 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -134,8 +134,8 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase void generate(const BackendBase &backend, EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) const; - void genCopyDelayedVars(CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void genCopyDelayedVars(EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged) const; //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -182,11 +182,11 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike) const; + void generate(const BackendBase &backend, EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) const; - void genCopyDelayedVars(CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void genCopyDelayedVars(EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged) const; //! 
Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -242,7 +242,7 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase BackendBase::GroupHandlerEnv genEmitTrueSpike, BackendBase::GroupHandlerEnv genEmitSpikeLikeEvent) const; - void generateWUVarUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generateWUVarUpdate(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged) const; std::string getVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; std::string getReadVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; @@ -260,11 +260,6 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase static const std::string name; private: - //------------------------------------------------------------------------ - // Private methods - //------------------------------------------------------------------------ - void addNeuronModelSubstitutions(Substitutions &substitution, const std::string &sourceSuffix = "", const std::string &destSuffix = "") const; - //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ @@ -273,5 +268,23 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase std::vector m_MergedOutSynPreOutputGroups; std::vector m_MergedInSynWUMPostCodeGroups; std::vector m_MergedOutSynWUMPreCodeGroups; + + //! List of statements parsed and type-checked in constructor; and used to generate sim code + Transpiler::Statement::StatementList m_SimStatements; + + //! Expression parsed and type-checked in constructor; and used to generate threshold condition code + Transpiler::Expression::ExpressionPtr m_ThresholdConditionExpression; + + //! List of statements parsed and type-checked in constructor; and used to generate reset code + Transpiler::Statement::StatementList m_ResetStatements; + + //! Resolved types used to generate sim code + Transpiler::TypeChecker::ResolvedTypeMap m_SimResolvedTypes; + + //! Resolved types used to generate threshold condition code + Transpiler::TypeChecker::ResolvedTypeMap m_ThresholdConditionResolvedTypes; + + //! 
Resolved types used to generate reset code + Transpiler::TypeChecker::ResolvedTypeMap m_ResetResolvedTypes; }; } // namespace GeNN::CodeGenerator diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 729089a771..0bcc9c0164 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -51,7 +51,7 @@ NeuronUpdateGroupMerged::CurrentSource::CurrentSource(size_t index, const Type:: // Scan, parse and type-check injection code ErrorHandler errorHandler; - std::tie(m_InjectionStatements, m_InjectionResolvedTypes) = scanParseAndTypeCheck(cm->getInjectionCode(), typeContext, + std::tie(m_InjectionStatements, m_InjectionResolvedTypes) = scanParseAndTypeCheckStatements(cm->getInjectionCode(), typeContext, typeEnvironment, errorHandler); } //---------------------------------------------------------------------------- @@ -73,9 +73,9 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varSubs( getArchetype(), getTypeContext(), envSubs, - [&modelMerged, &ng](const std::string&, const Models::VarInit&, VarAccess a) + [&envSubs, &modelMerged, &ng](const std::string&, const Models::VarInit&, VarAccess a) { - return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), "id"); + return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), envSubs["id"]); }); //currSourceSubs.addFuncSubstitution("injectCurrent", 1, "Isyn += $(0)"); @@ -148,9 +148,9 @@ NeuronUpdateGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContex // Scan, parse and type-check decay and apply input code ErrorHandler errorHandler; - std::tie(m_DecayStatements, m_DecayResolvedTypes) = scanParseAndTypeCheck(psm->getDecayCode(), typeContext, + std::tie(m_DecayStatements, m_DecayResolvedTypes) = scanParseAndTypeCheckStatements(psm->getDecayCode(), typeContext, typeEnvironment, errorHandler); - std::tie(m_ApplyInputStatements, m_ApplyInputResolvedTypes) = scanParseAndTypeCheck(psm->getApplyInputCode(), typeContext, + std::tie(m_ApplyInputStatements, m_ApplyInputResolvedTypes) = scanParseAndTypeCheckStatements(psm->getApplyInputCode(), typeContext, typeEnvironment, errorHandler); } //---------------------------------------------------------------------------- @@ -165,7 +165,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env envSubs.getStream() << "// postsynaptic model " << getIndex() << std::endl; envSubs.getStream() << "scalar linSyn = group->inSynInSyn" << getIndex() << "["; - envSubs.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"); + envSubs.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, envSubs["id"]); envSubs.getStream() << "];" << std::endl; // If dendritic delay is required @@ -173,7 +173,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Get reference to dendritic delay buffer input for this timestep envSubs.getStream() << backend.getPointerPrefix() << "scalar *denDelayFront = "; envSubs.getStream() << "&group->denDelay" << suffix << "[(*group->denDelayPtr" << suffix << " * group->numNeurons) + "; - envSubs.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(),
VarAccessDuplication::DUPLICATE, "id"); + envSubs.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, envSubs["id"]); envSubs.getStream() << "];" << std::endl; // Add delayed input from buffer into inSyn @@ -199,9 +199,9 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varSubs( getArchetype(), getTypeContext(), envSubs, - [&modelMerged, &ng](const std::string&, const Models::VarInit&, VarAccess a) + [&envSubs, &modelMerged, &ng](const std::string&, const Models::VarInit&, VarAccess a) { - return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), "id"); + return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), envSubs["id"]); }); // Pretty print previously parsed update statements @@ -210,7 +210,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Write back linSyn varSubs.getStream() << "group->inSyn" << suffix << "["; - varSubs.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"); + varSubs.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, envSubs["id"]); varSubs.getStream() << "] = linSyn;" << std::endl; } //---------------------------------------------------------------------------- @@ -252,10 +252,10 @@ void NeuronUpdateGroupMerged::OutSynPreOutput::generate(EnvironmentExternal &env env.getStream() << getArchetype().getPreTargetVar() << " += "; env.getStream() << "group->revInSyn" << suffix << "["; - env.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"); + env.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, env["id"]); env.getStream() << "];" << std::endl; env.getStream() << "group->revInSyn" << suffix << "["; - env.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"); + env.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, env["id"]); env.getStream() << "] = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; } @@ -292,9 +292,9 @@ NeuronUpdateGroupMerged::InSynWUMPostCode::InSynWUMPostCode(size_t index, const // Scan, parse and type-check dynamics and spike code ErrorHandler errorHandler; - std::tie(m_DynamicsStatements, m_DynamicsResolvedTypes) = scanParseAndTypeCheck(wum->getPostDynamicsCode(), typeContext, + std::tie(m_DynamicsStatements, m_DynamicsResolvedTypes) = scanParseAndTypeCheckStatements(wum->getPostDynamicsCode(), typeContext, typeEnvironment, errorHandler); - std::tie(m_SpikeStatements, m_SpikeResolvedTypes) = scanParseAndTypeCheck(wum->getPostSpikeCode(), typeContext, + std::tie(m_SpikeStatements, m_SpikeResolvedTypes) = scanParseAndTypeCheckStatements(wum->getPostSpikeCode(), typeContext, typeEnvironment, errorHandler); } //---------------------------------------------------------------------------- @@ -306,7 +306,7 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - // If there are any statements to executre here + // If there are any statements to execute here const auto &statements = dynamicsNotSpike ? 
m_DynamicsStatements : m_SpikeStatements; const auto &resolvedTypes = dynamicsNotSpike ? m_DynamicsResolvedTypes : m_SpikeResolvedTypes; if(!statements.empty()) { @@ -323,13 +323,13 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varSubs( getArchetype(), getTypeContext(), envSubs, - [batchSize, delayed, &ng](const std::string&, const Models::VarInit&, VarAccess a) + [batchSize, delayed, &envSubs, &ng](const std::string&, const Models::VarInit&, VarAccess a) { - return ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), "id"); + return ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), envSubs["id"]); }, - [batchSize, delayed, &ng](const std::string&, const Models::VarInit&, VarAccess a) + [batchSize, delayed, &envSubs, &ng](const std::string&, const Models::VarInit&, VarAccess a) { - return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), "id"); + return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), envSubs["id"]); }); /*neuronSubstitutionsInSynapticCode(subs, &ng.getArchetype(), "", "_post", "", "", "", dynamicsNotSpike, @@ -349,8 +349,8 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back } } //---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged) const { // If this group has a delay and no postsynaptic dynamics (which will already perform this copying) const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); @@ -358,13 +358,13 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(CodeStream &o // Loop through variables and copy between read and write delay slots for(const auto &v : getArchetype().getWUModel()->getPostVars()) { if(v.access & VarAccessMode::READ_WRITE) { - os << "group->" << v.name << suffix << "["; - os << ng.getWriteVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), popSubs["id"]); - os << "] = "; + env.getStream() << "group->" << v.name << suffix << "["; + env.getStream() << ng.getWriteVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), env["id"]); + env.getStream() << "] = "; - os << "group->" << v.name << suffix << "["; - os << ng.getReadVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), popSubs["id"]); - os << "];" << std::endl; + env.getStream() << "group->" << v.name << suffix << "["; + env.getStream() << ng.getReadVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), env["id"]); + env.getStream() << "];" << std::endl; } } } @@ -421,46 +421,48 @@ NeuronUpdateGroupMerged::OutSynWUMPreCode::OutSynWUMPreCode(size_t index, const // Scan, parse and type-check dynamics and spike code ErrorHandler errorHandler; - std::tie(m_DynamicsStatements, m_DynamicsResolvedTypes) = scanParseAndTypeCheck(wum->getPreDynamicsCode(), typeContext, + std::tie(m_DynamicsStatements, m_DynamicsResolvedTypes) = scanParseAndTypeCheckStatements(wum->getPreDynamicsCode(), typeContext, typeEnvironment, errorHandler); 
- std::tie(m_SpikeStatements, m_SpikeResolvedTypes) = scanParseAndTypeCheck(wum->getPreSpikeCode(), typeContext, + std::tie(m_SpikeStatements, m_SpikeResolvedTypes) = scanParseAndTypeCheckStatements(wum->getPreSpikeCode(), typeContext, typeEnvironment, errorHandler); } //---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &backend, CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs, bool dynamicsNotSpike) const +void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &backend, EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) const { const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); - + const auto *wum = getArchetype().getWUModel(); const unsigned int batchSize = modelMerged.getModel().getBatchSize(); + + // Get either the dynamics or the spike statements and their resolved types + const auto &statements = dynamicsNotSpike ? m_DynamicsStatements : m_SpikeStatements; + const auto &resolvedTypes = dynamicsNotSpike ? m_DynamicsResolvedTypes : m_SpikeResolvedTypes; + + // If there are any statements to execute here + if(!statements.empty()) { + // Create new substitution environment and add parameters, derived parameters and extra global parameters + EnvironmentSubstitute envSubs(env); + envSubs.getStream() << "// presynaptic weight update " << getIndex() << std::endl; + envSubs.addParamValueSubstitution(wum->getParamNames(), getArchetype().getWUParams(), + [this](const std::string &p) { return isParamHeterogeneous(p); }); + envSubs.addVarValueSubstitution(wum->getDerivedParams(), getArchetype().getWUDerivedParams(), + [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); + envSubs.addVarNameSubstitution(wum->getExtraGlobalParams()); - // If this code string isn't empty - std::string code = dynamicsNotSpike ?
getArchetype().getWUModel()->getPreDynamicsCode() : getArchetype().getWUModel()->getPreSpikeCode(); - if(!code.empty()) { - Substitutions subs(&popSubs); - - // Fetch presynaptic variables from global memory - os << "// perform WUM update required for merged" << getIndex() << std::endl; - const auto vars = getArchetype().getWUModel()->getPreVars(); + // Create an environment which caches variables in local variables if they are accessed const bool delayed = (getArchetype().getDelaySteps() != NO_DELAY); - for(const auto &v : vars) { - if(v.access & VarAccessMode::READ_ONLY) { - os << "const "; - } - os << v.type.resolve(getTypeContext()).getName() << " l" << v.name << " = group->" << v.name << suffix << "["; - os << ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), subs["id"]) << "];" << std::endl; - } - - subs.addParamValueSubstitution(getArchetype().getWUModel()->getParamNames(), getArchetype().getWUParams(), - [this](const std::string &p) { return isParamHeterogeneous(p); }, - "", "group->", suffix); - subs.addVarValueSubstitution(getArchetype().getWUModel()->getDerivedParams(), getArchetype().getWUDerivedParams(), - [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }, - "", "group->", suffix); - subs.addVarNameSubstitution(getArchetype().getWUModel()->getExtraGlobalParams(), "", "group->", suffix); - subs.addVarNameSubstitution(vars, "", "l"); - - neuronSubstitutionsInSynapticCode(subs, &ng.getArchetype(), "", "_pre", "", "", "", dynamicsNotSpike, + EnvironmentLocalVarCache varSubs( + getArchetype(), getTypeContext(), envSubs, + [batchSize, delayed, &envSubs, &ng](const std::string&, const Models::VarInit&, VarAccess a) + { + return ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), envSubs["id"]); + }, + [batchSize, delayed, &envSubs, &ng](const std::string&, const Models::VarInit&, VarAccess a) + { + return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), envSubs["id"]); + }); + + /*neuronSubstitutionsInSynapticCode(subs, &ng.getArchetype(), "", "_pre", "", "", "", dynamicsNotSpike, [&ng](const std::string &p) { return ng.isParamHeterogeneous(p); }, [&ng](const std::string &p) { return ng.isDerivedParamHeterogeneous(p); }, [&subs, &ng, batchSize](bool delay, VarAccessDuplication varDuplication) @@ -470,28 +472,15 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back [&subs, &ng, batchSize](bool delay, VarAccessDuplication varDuplication) { return ng.getReadVarIndex(delay, batchSize, varDuplication, subs["id"]); - }); - - // Perform standard substitutions - subs.applyCheckUnreplaced(code, "spikeCode : merged" + getIndex()); - //code = ensureFtype(code, precision); - os << code; - - // Write back presynaptic variables into global memory - for(const auto &v : vars) { - // If state variables is read/write - meaning that it may have been updated - or it is delayed - - // meaning that it needs to be copied into next delay slot whatever - copy neuron state variables - // back to global state variables dd_V etc - if((v.access & VarAccessMode::READ_WRITE) || delayed) { - os << "group->" << v.name << suffix << "["; - os << ng.getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(v.access), subs["id"]) << "] = l" << v.name << ";" << std::endl; - } - } + });*/ + + // Pretty print previously parsed statements + PrettyPrinter::print(statements, varSubs, getTypeContext(), resolvedTypes); } } //---------------------------------------------------------------------------- -void 
NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(CodeStream &os, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged) const { // If this group has a delay and no presynaptic dynamics (which will already perform this copying) const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); @@ -499,13 +488,13 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(CodeStream &o // Loop through variables and copy between read and write delay slots for(const auto &v : getArchetype().getWUModel()->getPreVars()) { if(v.access & VarAccessMode::READ_WRITE) { - os << "group->" << v.name << suffix << "["; - os << ng.getWriteVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), popSubs["id"]); - os << "] = "; + env.getStream() << "group->" << v.name << suffix << "["; + env.getStream() << ng.getWriteVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), env["id"]); + env.getStream() << "] = "; - os << "group->" << v.name << suffix << "["; - os << ng.getReadVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), popSubs["id"]); - os << "];" << std::endl; + env.getStream() << "group->" << v.name << suffix << "["; + env.getStream() << ng.getReadVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), env["id"]); + env.getStream() << "];" << std::endl; } } } @@ -641,6 +630,15 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC GroupMergedFieldType::DYNAMIC); } + // Parse code + ErrorHandler errorHandler; + std::tie(m_SimStatements, m_SimResolvedTypes) = scanParseAndTypeCheckStatements( + nm->getSimCode(), typeContext, typeEnvironment, errorHandler); + std::tie(m_ThresholdConditionExpression, m_ThresholdConditionResolvedTypes) = scanParseAndTypeCheckExpression( + nm->getThresholdConditionCode(), typeContext, typeEnvironment, errorHandler); + std::tie(m_ResetStatements, m_ResetResolvedTypes) = scanParseAndTypeCheckStatements( + nm->getResetCode(), typeContext, typeEnvironment, errorHandler); + // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group orderNeuronGroupChildren(m_MergedInSynPSMGroups, typeContext, typeEnvironment, backend, &NeuronGroupInternal::getFusedPSMInSyn, &SynapseGroupInternal::getPSHashDigest); @@ -705,114 +703,92 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E const NeuronModels::Base *nm = getArchetype().getNeuronModel(); + EnvironmentSubstitute neuronEnv(env); + neuronEnv.addSubstitution("Isyn", "Isyn", + {neuronEnv.addInitialiser("scalar Isyn = 0;")}); + + // **NOTE** arbitrary code in param value to be deprecated + for (const auto &v : nm->getAdditionalInputVars()) { + const std::string typeName = v.type.resolve(getTypeContext()).getName(); + neuronEnv.addSubstitution(v.name, v.value, + {neuronEnv.addInitialiser(typeName + " " + v.name + " = " + v.value + ";")}); + } + + neuronEnv.addParamValueSubstitution(nm->getParamNames(), getArchetype().getParams(), + [this](const std::string &p) { return isParamHeterogeneous(p); }); + neuronEnv.addVarValueSubstitution(nm->getDerivedParams(), getArchetype().getDerivedParams(), + [this](const std::string &p) { return 
isDerivedParamHeterogeneous(p); }); + neuronEnv.addVarNameSubstitution(nm->getExtraGlobalParams()); + + if(getArchetype().isSpikeTimeRequired()) { + const size_t initialiser = neuronEnv.addInitialiser( + "const timepoint lsT = group->sT[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, neuronEnv["id"]) + "];"); + neuronEnv.addSubstitution("sT", "lsT", {initialiser}); + } + if(getArchetype().isPrevSpikeTimeRequired()) { + const size_t initialiser = neuronEnv.addInitialiser( + "const timepoint lprevST = group->prevST[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, neuronEnv["id"]) + "];"); + neuronEnv.addSubstitution("prev_sT", "lprevST", {initialiser}); + } + if(getArchetype().isSpikeEventTimeRequired()) { + const size_t initialiser = neuronEnv.addInitialiser( + "const timepoint lseT = group->seT[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, neuronEnv["id"]) + "];"); + neuronEnv.addSubstitution("seT", "lseT", {initialiser}); + } + if(getArchetype().isPrevSpikeEventTimeRequired()) { + const size_t initialiser = neuronEnv.addInitialiser( + "const timepoint lprevSET = group->prevSET[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, neuronEnv["id"]) + "];"); + neuronEnv.addSubstitution("prev_seT", "lprevSET", {initialiser}); + } + // Create an environment which caches variables in local variables if they are accessed // **NOTE** we do this right at the top so that local copies can be used by child groups EnvironmentLocalVarCache neuronVarEnv( - getArchetype(), getTypeContext(), env, - [batchSize, this](const std::string &varName, const Models::VarInit&, VarAccess a) + getArchetype(), getTypeContext(), neuronEnv, + [batchSize, &neuronEnv, this](const std::string &varName, const Models::VarInit&, VarAccess a) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); - return getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), "id") ; + return getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), neuronEnv["id"]) ; }, - [batchSize, this](const std::string &varName, const Models::VarInit&, VarAccess a) + [batchSize, &neuronEnv, this](const std::string &varName, const Models::VarInit&, VarAccess a) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); - return getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), "id") ; + return getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), neuronEnv["id"]) ; }); - // If neuron model sim code references ISyn (could still be the case if there are no incoming synapses) - // OR any incoming synapse groups have post synaptic models which reference $(Isyn), declare it - if (nm->getSimCode().find("$(Isyn)") != std::string::npos || - std::any_of(getArchetype().getFusedPSMInSyn().cbegin(), getArchetype().getFusedPSMInSyn().cend(), - [](const SynapseGroupInternal *sg) - { - return (sg->getPSModel()->getApplyInputCode().find("$(Isyn)") != std::string::npos - || sg->getPSModel()->getDecayCode().find("$(Isyn)") != std::string::npos); - })) - { - os << "scalar Isyn = 0;" << std::endl; - } - - Substitutions neuronSubs(&popSubs); - neuronSubs.addVarSubstitution("Isyn", "Isyn"); - - - neuronSubs.addVarNameSubstitution(nm->getAdditionalInputVars()); - addNeuronModelSubstitutions(neuronSubs); - - // Initialise any additional input variables 
supported by neuron model - for (const auto &a : nm->getAdditionalInputVars()) { - // Apply substitutions to value - std::string value = a.value; - neuronSubs.applyCheckUnreplaced(value, "neuron additional input var : merged" + std::to_string(getIndex())); - //value = ensureFtype(value, modelMerged.getModel().getPrecision()); - - os << a.type.resolve(getTypeContext()).getName() << " " << a.name << " = " << value << ";" << std::endl; - } // Loop through incoming synapse groups for(const auto &sg : getMergedInSynPSMGroups()) { - CodeStream::Scope b(os); - sg.generate(backend, os, *this, modelMerged, popSubs); + CodeStream::Scope b(env.getStream()); + sg.generate(backend, env, *this, modelMerged); } // Loop through outgoing synapse groups with presynaptic output for (const auto &sg : getMergedOutSynPreOutputGroups()) { - CodeStream::Scope b(os); - sg.generate(backend, os, *this, modelMerged, popSubs); + CodeStream::Scope b(env.getStream()); + sg.generate(env, *this, modelMerged); } // Loop through all of neuron group's current sources for (const auto &cs : getMergedCurrentSourceGroups()) { - CodeStream::Scope b(os); - cs.generate(backend, os, *this, modelMerged, popSubs); + CodeStream::Scope b(env.getStream()); + cs.generate(backend, env, *this, modelMerged); } - // Read spike and spike-like-event times into local variables if required - EnvironmentSubstitute neuronEnv(neuronVarEnv); - if(getArchetype().isSpikeTimeRequired()) { - const size_t initialiser = neuronEnv.addInitialiser( - "const timepoint lsT = group->sT[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "id") + "];"); - neuronEnv.addSubstitution("sT", "lsT", {initialiser}); - } - if(getArchetype().isPrevSpikeTimeRequired()) { - const size_t initialiser = neuronEnv.addInitialiser( - "const timepoint lprevST = group->prevST[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "id") + "];"); - neuronEnv.addSubstitution("prev_sT", "lprevST", {initialiser}); - } - if(getArchetype().isSpikeEventTimeRequired()) { - const size_t initialiser = neuronEnv.addInitialiser( - "const timepoint lseT = group->seT[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "id") + "];"); - neuronEnv.addSubstitution("seT", "lseT", {initialiser}); - } - if(getArchetype().isPrevSpikeEventTimeRequired()) { - const size_t initialiser = neuronEnv.addInitialiser( - "const timepoint lprevSET = group->prevSET[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "id") << "];"); - neuronEnv.addSubstitution("prev_seT", "lprevSET", {initialiser}); - } - neuronEnv.getStream() << std::endl; - // Add neuron parameters, derived parameters and extra global parameters to neuron environment - neuronEnv.addParamValueSubstitution(nm->getParamNames(), getArchetype().getParams(), - [this](const std::string &p) { return isParamHeterogeneous(p); }); - neuronEnv.addVarValueSubstitution(nm->getDerivedParams(), getArchetype().getDerivedParams(), - [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); - neuronEnv.addVarNameSubstitution(nm->getExtraGlobalParams()); // If a threshold condition is provided - std::string thCode = nm->getThresholdConditionCode(); - if (!thCode.empty()) { - os << "// test whether spike condition was fulfilled previously" << std::endl; - - neuronSubs.applyCheckUnreplaced(thCode, "thresholdConditionCode : merged" + std::to_string(getIndex())); - //thCode= 
ensureFtype(thCode, model.getPrecision()); - - if (!nm->getSupportCode().empty() && !backend.supportsNamespace()) { - thCode = disambiguateNamespaceFunction(nm->getSupportCode(), thCode, modelMerged.getNeuronUpdateSupportCodeNamespace(nm->getSupportCode())); - } + if (m_ThresholdConditionExpression) { + neuronVarEnv.getStream() << "// test whether spike condition was fulfilled previously" << std::endl; + + //if (!nm->getSupportCode().empty() && !backend.supportsNamespace()) { + // thCode = disambiguateNamespaceFunction(nm->getSupportCode(), thCode, modelMerged.getNeuronUpdateSupportCodeNamespace(nm->getSupportCode())); + //} if (nm->isAutoRefractoryRequired()) { - os << "const bool oldSpike = (" << thCode << ");" << std::endl; + neuronVarEnv.getStream() << "const bool oldSpike = ("; + PrettyPrinter::print(m_ThresholdConditionExpression, neuronVarEnv, getTypeContext(), m_ThresholdConditionResolvedTypes); + neuronVarEnv.getStream() << ");" << std::endl; } } // Otherwise, if any outgoing synapse groups have spike-processing code @@ -822,30 +798,25 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E LOGW_CODE_GEN << "No thresholdConditionCode for neuron type " << typeid(*nm).name() << " used for population \"" << getName() << "\" was provided. There will be no spikes detected in this population!"; }*/ - os << "// calculate membrane potential" << std::endl; - std::string sCode = nm->getSimCode(); - neuronSubs.applyCheckUnreplaced(sCode, "simCode : merged" + std::to_string(getIndex())); - //sCode = ensureFtype(sCode, model.getPrecision()); - - - os << sCode << std::endl; + neuronVarEnv.getStream() << "// calculate membrane potential" << std::endl; + PrettyPrinter::print(m_SimStatements, neuronVarEnv, getTypeContext(), m_SimResolvedTypes); // Generate var update for outgoing synaptic populations with presynaptic update code for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { - CodeStream::Scope b(os); - sg.generate(backend, os, *this, modelMerged, popSubs, true); + CodeStream::Scope b(neuronVarEnv.getStream()); + sg.generate(backend, neuronVarEnv, *this, modelMerged, true); } // Generate var update for incoming synaptic populations with postsynaptic code for (const auto &sg : getMergedInSynWUMPostCodeGroups()) { - CodeStream::Scope b(os); - sg.generate(backend, os, *this, modelMerged, popSubs, true); + CodeStream::Scope b(neuronVarEnv.getStream()); + sg.generate(backend, neuronVarEnv, *this, modelMerged, true); } // look for spike type events first. 
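(A note on the pattern in this hunk: each model code string is now scanned, parsed and type-checked once, when the merged group is constructed, and the cached AST is pretty-printed against an environment at generation time, replacing the old string-substitution pipeline. Below is a minimal sketch of that flow, assuming the scanParseAndTypeCheckStatements, PrettyPrinter and EnvironmentSubstitute APIs introduced in this series; the code string, the parentEnv name and the "Isyn" substitution are purely illustrative:

    // Construction time: scan, parse and type-check the model code once
    ErrorHandler errorHandler;
    auto [simStatements, simResolvedTypes] = scanParseAndTypeCheckStatements(
        "V += (-V + Isyn) * (DT / tau);", typeContext, typeEnvironment, errorHandler);

    // Generation time: pretty-print the cached AST into any environment;
    // identifiers such as "Isyn" are resolved through the environment chain
    EnvironmentSubstitute env(parentEnv);
    env.addSubstitution("Isyn", "linSyn");
    PrettyPrinter::print(simStatements, env, typeContext, simResolvedTypes);

Because the AST is cached per merged group, the same statements can be printed repeatedly, for example once for the oldSpike capture above and once for the spike test further down, without re-parsing.)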
- if (getArchetype().isSpikeEventRequired()) { + /*if (getArchetype().isSpikeEventRequired()) { // Create local variable - os << "bool spikeLikeEvent = false;" << std::endl; + neuronVarEnv.getStream() << "bool spikeLikeEvent = false;" << std::endl; // Loop through outgoing synapse populations that will contribute to event condition code size_t i = 0; @@ -912,29 +883,26 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E os << "group->prevSET[" << getWriteVarIndex(true, batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "] = lprevSET;" << std::endl; } } - } + }*/ // test for true spikes if condition is provided - if (!thCode.empty()) { - os << "// test for and register a true spike" << std::endl; + if (m_ThresholdConditionExpression) { + neuronVarEnv.getStream() << "// test for and register a true spike" << std::endl; + neuronVarEnv.getStream() << "if (("; + PrettyPrinter::print(m_ThresholdConditionExpression, neuronVarEnv, getTypeContext(), m_ThresholdConditionResolvedTypes); + neuronVarEnv.getStream() << ")"; if (nm->isAutoRefractoryRequired()) { - os << "if ((" << thCode << ") && !(oldSpike))"; - } - else { - os << "if (" << thCode << ")"; + neuronVarEnv.getStream() << " && !oldSpike"; } + neuronVarEnv.getStream() << ")"; { - CodeStream::Scope b(os); - genEmitTrueSpike(os, *this, popSubs); + CodeStream::Scope b(neuronVarEnv.getStream()); + genEmitTrueSpike(neuronVarEnv, *this); // add after-spike reset if provided - if (!nm->getResetCode().empty()) { - std::string rCode = nm->getResetCode(); - neuronSubs.applyCheckUnreplaced(rCode, "resetCode : merged" + std::to_string(getIndex())); - //rCode = ensureFtype(rCode, model.getPrecision()); - - os << "// spike reset code" << std::endl; - os << rCode << std::endl; + if (!m_ResetStatements.empty()) { + neuronVarEnv.getStream() << "// spike reset code" << std::endl; + PrettyPrinter::print(m_ResetStatements, neuronVarEnv, getTypeContext(), m_ResetResolvedTypes); } } @@ -965,45 +933,49 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // If spike times, presynaptic variables or postsynaptic variables are required, add if clause if(getArchetype().isSpikeTimeRequired() || getArchetype().isPrevSpikeTimeRequired() || preVars || postVars) { - os << "else"; - CodeStream::Scope b(os); + neuronVarEnv.getStream() << "else"; + CodeStream::Scope b(neuronVarEnv.getStream()); // If spike times are required, copy times from register if(getArchetype().isSpikeTimeRequired()) { - os << "group->sT[" << getWriteVarIndex(true, batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "] = lsT;" << std::endl; + neuronVarEnv.getStream() << "group->sT["; + neuronVarEnv.getStream() << getWriteVarIndex(true, batchSize, VarAccessDuplication::DUPLICATE, neuronVarEnv["id"]); + neuronVarEnv.getStream() << "] = lsT;" << std::endl; } // If previous spike times are required, copy times from register if(getArchetype().isPrevSpikeTimeRequired()) { - os << "group->prevST[" << getWriteVarIndex(true, batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) << "] = lprevST;" << std::endl; + neuronVarEnv.getStream() << "group->prevST["; + neuronVarEnv.getStream() << getWriteVarIndex(true, batchSize, VarAccessDuplication::DUPLICATE, neuronVarEnv["id"]); + neuronVarEnv.getStream() << "] = lprevST;" << std::endl; } // Loop through outgoing synapse groups with some sort of presynaptic code for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { - sg.genCopyDelayedVars(os, *this, modelMerged, popSubs); + 
sg.genCopyDelayedVars(neuronVarEnv, *this, modelMerged); } // Loop through incoming synapse groups with some sort of presynaptic code for (const auto &sg : getMergedInSynWUMPostCodeGroups()) { - sg.genCopyDelayedVars(os, *this, modelMerged, popSubs); + sg.genCopyDelayedVars(neuronVarEnv, *this, modelMerged); } } } } } //-------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged) const { // Generate var update for outgoing synaptic populations with presynaptic update code for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { - CodeStream::Scope b(os); - sg.generate(backend, os, *this, modelMerged, popSubs, false); + CodeStream::Scope b(env.getStream()); + sg.generate(backend, env, *this, modelMerged, false); } // Generate var update for incoming synaptic populations with postsynaptic code for (const auto &sg : getMergedInSynWUMPostCodeGroups()) { - CodeStream::Scope b(os); - sg.generate(backend, os, *this, modelMerged, popSubs, false); + CodeStream::Scope b(env.getStream()); + sg.generate(backend, env, *this, modelMerged, false); } } //-------------------------------------------------------------------------- @@ -1056,16 +1028,3 @@ std::string NeuronUpdateGroupMerged::getWriteVarIndex(bool delay, unsigned int b return getVarIndex(batchSize, varDuplication, index); } } -//---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::addNeuronModelSubstitutions(Substitutions &substitution, const std::string &sourceSuffix, const std::string &destSuffix) const -{ - const NeuronModels::Base *nm = getArchetype().getNeuronModel(); - substitution.addVarNameSubstitution(nm->getVars(), sourceSuffix, "l", destSuffix); - substitution.addParamValueSubstitution(nm->getParamNames(), getArchetype().getParams(), - [this](const std::string &p) { return isParamHeterogeneous(p); }, - sourceSuffix, "group->"); - substitution.addVarValueSubstitution(nm->getDerivedParams(), getArchetype().getDerivedParams(), - [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }, - sourceSuffix, "group->"); - substitution.addVarNameSubstitution(nm->getExtraGlobalParams(), sourceSuffix, "group->"); -} From a9e2fa265cc1e827a51db6be37652c3ce95be42a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 5 Jun 2023 17:15:15 +0100 Subject: [PATCH 199/725] pull ids out of environment in custom update --- .../code_generator/customUpdateGroupMerged.cc | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index a04bad3e84..122619cb3d 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -72,7 +72,7 @@ CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const Type::TypeC // Scan, parse and type-check update code ErrorHandler errorHandler; - std::tie(m_UpdateStatements, m_ResolvedTypes) = scanParseAndTypeCheck(cm->getUpdateCode(), typeContext, + std::tie(m_UpdateStatements, m_ResolvedTypes) = scanParseAndTypeCheckStatements(cm->getUpdateCode(), typeContext, typeEnvironment, errorHandler); } 
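(Context for the change below: a hard-coded "id" bakes the loop variable's name into every generated index expression, whereas looking it up with envSubs["id"] walks the chain of enclosing environments, so an outer scope can redirect the identifier without the group code knowing. A minimal sketch of that fall-through behaviour, assuming the EnvironmentSubstitute semantics from this series; rootEnv, the substitution value and the variable name are illustrative only:

    EnvironmentSubstitute outer(rootEnv);
    outer.addSubstitution("id", "(batchOffset + id)");  // enclosing scope redirects "id"

    EnvironmentSubstitute inner(outer);
    // inner["id"] is resolved by walking up to the enclosing environment, so an
    // index expression built from it picks up the redirected form automatically
    const std::string index = "group->V[" + inner["id"] + "]";

This is why the lambdas below capture envSubs and pass envSubs["id"] or envSubs["id_syn"] into getVarIndex()/getVarRefIndex() instead of a literal string.)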
//---------------------------------------------------------------------------- @@ -118,19 +118,19 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varSubs( getArchetype(), getTypeContext(), envSubs, - [this](const std::string&, const Models::VarInit&, VarAccess a) + [this, &envSubs](const std::string&, const Models::VarInit&, VarAccess a) { - return getVarIndex(getVarAccessDuplication(a), "id"); + return getVarIndex(getVarAccessDuplication(a), envSubs["id"]); }); // Create an environment which caches variable references in local variables if they are accessed EnvironmentLocalVarCache varRefSubs( getArchetype(), getTypeContext(), varSubs, - [this](const std::string&, const Models::VarReference &v, VarAccessMode) + [this, &envSubs](const std::string&, const Models::VarReference &v, VarAccessMode) { return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, getVarAccessDuplication(v.getVar().access), - "id"); + envSubs["id"]); }); // Pretty print previously parsed update statements @@ -227,18 +227,18 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varSubs( getArchetype(), getTypeContext(), envSubs, - [this](const std::string&, const Models::VarInit&, VarAccess a) + [&envSubs, this](const std::string&, const Models::VarInit&, VarAccess a) { - return getVarIndex(getVarAccessDuplication(a), "id_syn"); + return getVarIndex(getVarAccessDuplication(a), envSubs["id_syn"]); }); // Create an environment which caches variable references in local variables if they are accessed EnvironmentLocalVarCache varRefSubs( getArchetype(), getTypeContext(), varSubs, - [this](const std::string&, const Models::WUVarReference &v, VarAccessMode) + [&envSubs, this](const std::string&, const Models::WUVarReference &v, VarAccessMode) { return getVarRefIndex(getVarAccessDuplication(v.getVar().access), - "id_syn"); + envSubs["id_syn"]); }); // Pretty print previously parsed update statements @@ -358,7 +358,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // Scan, parse and type-check update code ErrorHandler errorHandler; - std::tie(m_UpdateStatements, m_ResolvedTypes) = scanParseAndTypeCheck(cm->getUpdateCode(), typeContext, + std::tie(m_UpdateStatements, m_ResolvedTypes) = scanParseAndTypeCheckStatements(cm->getUpdateCode(), typeContext, typeEnvironment, errorHandler); } From 36be6e7c6b5ce8bb85bb44c7c0e351da75471bcb Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 5 Jun 2023 17:47:30 +0100 Subject: [PATCH 200/725] pull out spike time variables 'officially' --- src/genn/genn/code_generator/neuronUpdateGroupMerged.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 0bcc9c0164..fa40439e6e 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -940,14 +940,14 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E if(getArchetype().isSpikeTimeRequired()) { neuronVarEnv.getStream() << "group->sT["; neuronVarEnv.getStream() << getWriteVarIndex(true, batchSize, VarAccessDuplication::DUPLICATE, neuronVarEnv["id"]); - neuronVarEnv.getStream() << "] = 
lsT;" << std::endl; + neuronVarEnv.getStream() << "] = " << neuronVarEnv["sT"] << ";" << std::endl; } // If previous spike times are required, copy times from register if(getArchetype().isPrevSpikeTimeRequired()) { neuronVarEnv.getStream() << "group->prevST["; neuronVarEnv.getStream() << getWriteVarIndex(true, batchSize, VarAccessDuplication::DUPLICATE, neuronVarEnv["id"]); - neuronVarEnv.getStream() << "] = lprevST;" << std::endl; + neuronVarEnv.getStream() << "] = " << neuronVarEnv["prev_sT"] << ";" << std::endl; } // Loop through outgoing synapse groups with some sort of presynaptic code From cf7adcf741d8628162391a66f0b65ba2ad9f63b5 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 8 Jun 2023 13:03:56 +0100 Subject: [PATCH 201/725] hooking up and tidying initialisation --- .../backends/single_threaded_cpu/backend.h | 5 +- .../genn/genn/code_generator/backendBase.h | 7 +- .../genn/genn/code_generator/backendSIMT.h | 5 +- .../genn/genn/code_generator/environment.h | 28 +- .../groupMergedTypeEnvironment.h | 2 +- .../genn/code_generator/initGroupMerged.h | 44 +- include/genn/genn/currentSourceInternal.h | 2 + include/genn/genn/customUpdate.h | 2 + include/genn/genn/neuronGroupInternal.h | 4 + include/genn/genn/synapseGroupInternal.h | 12 +- .../backends/single_threaded_cpu/backend.cc | 20 +- .../genn/code_generator/initGroupMerged.cc | 466 +++++++----------- .../code_generator/neuronUpdateGroupMerged.cc | 38 +- 13 files changed, 282 insertions(+), 353 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 155a4a8db9..0de2c91fa2 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -121,9 +121,8 @@ class BACKEND_EXPORT Backend : public BackendBase //! When generating merged structures what type to use for simulation RNGs virtual std::optional getMergedGroupSimRNGType() const final; - virtual void genPopVariableInit(CodeStream &os,const Substitutions &kernelSubs, Handler handler) const final; - virtual void genVariableInit(CodeStream &os, const std::string &count, const std::string &indexVarName, - const Substitutions &kernelSubs, Handler handler) const final; + virtual void genPopVariableInit(EnvironmentExternal &env, HandlerEnv handler) const final; + virtual void genVariableInit(EnvironmentExternal &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const final; virtual void genSparseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const final; virtual void genDenseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const final; virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const final; diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index ddc1e4e250..aa0d8b2612 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -186,6 +186,8 @@ class GENN_EXPORT BackendBase typedef std::function HostHandler; typedef std::function Handler; + + typedef std::function HandlerEnv; template using GroupHandler = std::function ; @@ -312,9 +314,8 @@ class GENN_EXPORT BackendBase //! 
When generating merged structures what type to use for simulation RNGs virtual std::optional getMergedGroupSimRNGType() const = 0; - virtual void genPopVariableInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const = 0; - virtual void genVariableInit(CodeStream &os, const std::string &count, const std::string &indexVarName, - const Substitutions &kernelSubs, Handler handler) const = 0; + virtual void genPopVariableInit(EnvironmentExternal &env, HandlerEnv handler) const = 0; + virtual void genVariableInit(EnvironmentExternal &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const = 0; virtual void genSparseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const = 0; virtual void genDenseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const = 0; virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const = 0; diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index c1061187d4..b12d52bf65 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -122,9 +122,8 @@ class GENN_EXPORT BackendSIMT : public BackendBase //! This function returns the device prefix so it can be used in otherwise platform-independent code. virtual std::string getDeviceVarPrefix() const final { return getPreferences().automaticCopy ? "" : "d_"; } - virtual void genPopVariableInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const final; - virtual void genVariableInit(CodeStream &os, const std::string &count, const std::string &indexVarName, - const Substitutions &kernelSubs, Handler handler) const final; + virtual void genPopVariableInit(EnvironmentExternal &env, HandlerEnv handler) const final; + virtual void genVariableInit(EnvironmentExternal &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const final; virtual void genSparseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const final { genSynapseVariableRowInit(os, kernelSubs, handler); diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 2c782db246..382f7bd8f0 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -104,16 +104,16 @@ class EnvironmentSubstitute : public EnvironmentExternal size_t addInitialiser(const std::string &initialiser); template - void addVarNameSubstitution(const std::vector &variables) + void addVarNameSubstitution(const std::vector &variables, const std::string &fieldSuffix = "") { for(const auto &v : variables) { - addSubstitution(v.name, "group->" + v.name); + addSubstitution(v.name, "group->" + v.name + fieldSuffix); } } template void addParamValueSubstitution(const std::vector ¶mNames, const std::unordered_map &values, - G isHeterogeneousFn) + const std::string &fieldSuffix, G isHeterogeneousFn) { if(paramNames.size() != values.size()) { throw std::runtime_error("Number of parameters does not match number of values"); @@ -121,7 +121,7 @@ class EnvironmentSubstitute : public EnvironmentExternal for(const auto &p : paramNames) { if(isHeterogeneousFn(p)) { - addSubstitution(p, "group->" + p); + addSubstitution(p, "group->" + p + fieldSuffix); } else { // **TODO** scalar suffix @@ -132,7 
+132,7 @@ class EnvironmentSubstitute : public EnvironmentExternal template void addVarValueSubstitution(const std::vector &variables, const std::unordered_map &values, - G isHeterogeneousFn) + const std::string &fieldSuffix, G isHeterogeneousFn) { if(variables.size() != values.size()) { throw std::runtime_error("Number of variables does not match number of values"); @@ -140,7 +140,7 @@ class EnvironmentSubstitute : public EnvironmentExternal for(const auto &v : variables) { if(isHeterogeneousFn(v.name)) { - addSubstitution(v.name, "group->" + v.name); + addSubstitution(v.name, "group->" + v.name + fieldSuffix); } else { addSubstitution(v.name, Utils::writePreciseString(values.at(v.name))); @@ -156,7 +156,6 @@ class EnvironmentSubstitute : public EnvironmentExternal CodeStream m_Contents; std::unordered_map>> m_VarSubstitutions; std::vector> m_Initialisers; - }; //---------------------------------------------------------------------------- @@ -177,9 +176,10 @@ class EnvironmentLocalVarCache : public EnvironmentExternal public: EnvironmentLocalVarCache(const G &group, const Type::TypeContext &context, EnvironmentExternal &enclosing, - GetIndexFn getReadIndex, GetIndexFn getWriteIndex, const std::string &localPrefix = "l") + const std::string &fieldSuffix, const std::string & localPrefix, + GetIndexFn getReadIndex, GetIndexFn getWriteIndex) : EnvironmentExternal(static_cast(enclosing)), m_Group(group), m_Context(context), m_Contents(m_ContentsStream), - m_LocalPrefix(localPrefix), m_GetReadIndex(getReadIndex), m_GetWriteIndex(getWriteIndex) + m_FieldSuffix(fieldSuffix), m_LocalPrefix(localPrefix), m_GetReadIndex(getReadIndex), m_GetWriteIndex(getWriteIndex) { // Add name of each definition to map, initially with value set to value const auto defs = A(m_Group).getDefs(); @@ -187,9 +187,10 @@ class EnvironmentLocalVarCache : public EnvironmentExternal [](const auto &v){ return std::make_pair(v.name, false); }); } - EnvironmentLocalVarCache(const G &group, const Type::TypeContext &context, EnvironmentExternal &enclosing, GetIndexFn getIndex, const std::string &localPrefix = "l") + EnvironmentLocalVarCache(const G &group, const Type::TypeContext &context, EnvironmentExternal &enclosing, + const std::string &fieldSuffix, const std::string & localPrefix, GetIndexFn getIndex) : EnvironmentExternal(static_cast(enclosing)), m_Group(group), m_Context(context), - m_Contents(m_ContentsStream), m_LocalPrefix(localPrefix), m_GetReadIndex(getIndex), m_GetWriteIndex(getIndex) + m_Contents(m_ContentsStream), m_FieldSuffix(fieldSuffix), m_LocalPrefix(localPrefix), m_GetReadIndex(getIndex), m_GetWriteIndex(getIndex) { // Add name of each definition to map, initially with value set to value const auto defs = A(m_Group).getDefs(); @@ -221,7 +222,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal // **NOTE** by not initialising these variables for reductions, // compilers SHOULD emit a warning if user code doesn't set it to something if(!(v.access & VarAccessModeAttribute::REDUCE)) { - getContextStream() << " = group->" << v.name << "[" << m_GetReadIndex(v.name, initialisers.at(v.name), v.access) << "]"; + getContextStream() << " = group->" << v.name << m_FieldSuffix << "[" << m_GetReadIndex(v.name, initialisers.at(v.name), v.access) << "]"; } getContextStream() << ";" << std::endl; } @@ -233,7 +234,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal for(const auto &v : referencedVars) { // If variables are read-write if(v.access & VarAccessMode::READ_WRITE) { - getContextStream() << 
"group->" << v.name << "[" << m_GetWriteIndex(v.name, initialisers.at(v.name), v.access) << "]"; + getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << m_GetWriteIndex(v.name, initialisers.at(v.name), v.access) << "]"; getContextStream() << " = " << m_LocalPrefix << v.name << ";" << std::endl; } } @@ -272,6 +273,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternal const Type::TypeContext &m_Context; std::ostringstream m_ContentsStream; CodeStream m_Contents; + std::string m_FieldSuffix; std::string m_LocalPrefix; GetIndexFn m_GetReadIndex; GetIndexFn m_GetWriteIndex; diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h index 72da9c1a33..15b779d150 100644 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h @@ -97,7 +97,7 @@ class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBa const auto qualifiedType = (access & VarAccessModeAttribute::READ_ONLY) ? type.addQualifier(Type::Qualifier::CONSTANT) : type; defineField(qualifiedType, name, type.createPointer(), name + fieldSuffix, - [prefix, getVarSuffixFn](const auto &g, size_t) { return prefix + std::invoke(getVarSuffixFn, g); }); + [name, prefix, getVarSuffixFn](const auto &g, size_t) { return prefix + name + std::invoke(getVarSuffixFn, g); }); } void definePointerField(const Type::UnresolvedType &type, const std::string &name, const std::string &prefix, VarAccessMode access, diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index 90da248205..e38950bea7 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -11,6 +11,8 @@ namespace GeNN::CodeGenerator class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase { public: + using VarInitAST = std::unordered_map>; + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::CurrentSource //---------------------------------------------------------------------------- @@ -24,8 +26,8 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generate(const BackendBase &backend, EnvironmentExternal &env, + const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const; //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -47,7 +49,7 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase // Members //---------------------------------------------------------------------------- //! 
Parsed statements and resolved types for initialising each variable - std::unordered_map> m_VarInitASTs; + VarInitAST m_VarInitASTs; }; //---------------------------------------------------------------------------- @@ -63,8 +65,8 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generate(const BackendBase &backend, EnvironmentExternal &env, + const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const; //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -86,7 +88,7 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase // Members //---------------------------------------------------------------------------- //! Parsed statements and resolved types for initialising each variable - std::unordered_map> m_VarInitASTs; + VarInitAST m_VarInitASTs; }; //---------------------------------------------------------------------------- @@ -102,8 +104,8 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generate(const BackendBase &backend, EnvironmentExternal &env, + const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const; }; //---------------------------------------------------------------------------- @@ -119,8 +121,8 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generate(const BackendBase &backend, EnvironmentExternal &env, + const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const; //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -142,7 +144,7 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase // Members //---------------------------------------------------------------------------- //! 
Parsed statements and resolved types for initialising each variable - std::unordered_map> m_VarInitASTs; + VarInitAST m_VarInitASTs; }; //---------------------------------------------------------------------------- @@ -158,8 +160,8 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generate(const BackendBase &backend, EnvironmentExternal &env, + const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const; //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -181,7 +183,7 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase // Members //---------------------------------------------------------------------------- //! Parsed statements and resolved types for initialising each variable - std::unordered_map> m_VarInitASTs; + VarInitAST m_VarInitASTs; }; NeuronInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, @@ -202,7 +204,7 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase runnerVarDecl, runnerMergedStructAlloc, name); } - void generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generateInit(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged) const; const std::vector &getMergedCurrentSourceGroups() const { return m_MergedCurrentSourceGroups; } const std::vector &getMergedInSynPSMGroups() const { return m_MergedInSynPSMGroups; } @@ -219,14 +221,11 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //------------------------------------------------------------------------ // Private methods //------------------------------------------------------------------------ - void genInitSpikeCount(CodeStream &os, const BackendBase &backend, const Substitutions &popSubs, - bool spikeEvent, unsigned int batchSize) const; + void genInitSpikeCount(const BackendBase &backend, EnvironmentExternal &env, bool spikeEvent, unsigned int batchSize) const; - void genInitSpikes(CodeStream &os, const BackendBase &backend, const Substitutions &popSubs, - bool spikeEvent, unsigned int batchSize) const; + void genInitSpikes(const BackendBase &backend, EnvironmentExternal &env, bool spikeEvent, unsigned int batchSize) const; - void genInitSpikeTime(CodeStream &os, const BackendBase &backend, const Substitutions &popSubs, - const std::string &varName, unsigned int batchSize) const; + void genInitSpikeTime(const BackendBase &backend, EnvironmentExternal &env, const std::string &varName, unsigned int batchSize) const; //------------------------------------------------------------------------ // Members @@ -236,6 +235,9 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase std::vector m_MergedOutSynPreOutputGroups; std::vector m_MergedInSynWUMPostVarGroups; std::vector m_MergedOutSynWUMPreVarGroups; + + //! 
Parsed statements and resolved types for initialising each variable + VarInitAST m_VarInitASTs; }; diff --git a/include/genn/genn/currentSourceInternal.h b/include/genn/genn/currentSourceInternal.h index 06973292b2..ba6a7628a2 100644 --- a/include/genn/genn/currentSourceInternal.h +++ b/include/genn/genn/currentSourceInternal.h @@ -49,6 +49,8 @@ class CurrentSourceVarAdapter const std::unordered_map &getInitialisers() const{ return m_CS.getVarInitialisers(); } + const std::string &getNameSuffix() const{ return m_CS.getName(); } + private: //---------------------------------------------------------------------------- // Members diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index 21c258387e..37d5364bb1 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -185,6 +185,8 @@ class CustomUpdateVarAdapter const std::unordered_map &getInitialisers() const{ return m_CU.getVarInitialisers(); } + const std::string &getNameSuffix() const{ return m_CU.getName(); } + private: //---------------------------------------------------------------------------- // Members diff --git a/include/genn/genn/neuronGroupInternal.h b/include/genn/genn/neuronGroupInternal.h index 86ec247516..22f8ace253 100644 --- a/include/genn/genn/neuronGroupInternal.h +++ b/include/genn/genn/neuronGroupInternal.h @@ -66,6 +66,10 @@ class NeuronVarAdapter const std::unordered_map &getInitialisers() const{ return m_NG.getVarInitialisers(); } + bool isVarDelayed(const std::string &varName) const{ return m_NG.isVarQueueRequired(varName); } + + const std::string &getNameSuffix() const{ return m_NG.getName(); } + private: //---------------------------------------------------------------------------- // Members diff --git a/include/genn/genn/synapseGroupInternal.h b/include/genn/genn/synapseGroupInternal.h index ab5529346c..e1d911d411 100644 --- a/include/genn/genn/synapseGroupInternal.h +++ b/include/genn/genn/synapseGroupInternal.h @@ -95,7 +95,9 @@ class SynapsePSMVarAdapter const std::unordered_map &getInitialisers() const{ return m_SG.getPSVarInitialisers(); } - const std::string &getFusedSuffix() const{ return m_SG.getFusedPSVarSuffix(); } + const std::string &getNameSuffix() const{ return m_SG.getFusedPSVarSuffix(); } + + bool isVarDelayed(const std::string &) const { return false; } private: //---------------------------------------------------------------------------- @@ -147,7 +149,9 @@ class SynapseWUPreVarAdapter const std::unordered_map &getInitialisers() const{ return m_SG.getWUPreVarInitialisers(); } - const std::string &getFusedSuffix() const{ return m_SG.getFusedWUPreVarSuffix(); } + const std::string &getNameSuffix() const{ return m_SG.getFusedWUPreVarSuffix(); } + + bool isVarDelayed(const std::string&) const{ return (m_SG.getDelaySteps() != 0); } private: //---------------------------------------------------------------------------- @@ -174,7 +178,9 @@ class SynapseWUPostVarAdapter const std::unordered_map &getInitialisers() const{ return m_SG.getWUPostVarInitialisers(); } - const std::string &getFusedSuffix() const{ return m_SG.getFusedWUPostVarSuffix(); } + const std::string &getNameSuffix() const{ return m_SG.getFusedWUPostVarSuffix(); } + + bool isVarDelayed(const std::string&) const{ return (m_SG.getBackPropDelaySteps() != 0); } private: //---------------------------------------------------------------------------- diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 
55d2f12761..7c657721a0 100644
--- a/src/genn/backends/single_threaded_cpu/backend.cc
+++ b/src/genn/backends/single_threaded_cpu/backend.cc
@@ -1386,23 +1386,21 @@ std::optional Backend::getMergedGroupSimRNGType() const
     return std::nullopt;
 }
 //--------------------------------------------------------------------------
-void Backend::genPopVariableInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const
+void Backend::genPopVariableInit(EnvironmentExternal &env, HandlerEnv handler) const
 {
-    Substitutions varSubs(&kernelSubs);
-    handler(os, varSubs);
+    handler(env);
 }
 //--------------------------------------------------------------------------
-void Backend::genVariableInit(CodeStream &os, const std::string &count, const std::string &indexVarName,
-                              const Substitutions &kernelSubs, Handler handler) const
+void Backend::genVariableInit(EnvironmentExternal &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const
 {
-    // **TODO** loops like this should be generated like CUDA threads
-    os << "for (unsigned i = 0; i < (" << count << "); i++)";
+    // **TODO** loops like this should be generated like CUDA threads
+    env.getStream() << "for (unsigned i = 0; i < (" << count << "); i++)";
     {
-        CodeStream::Scope b(os);
+        CodeStream::Scope b(env.getStream());
 
-        Substitutions varSubs(&kernelSubs);
-        varSubs.addVarSubstitution(indexVarName, "i");
-        handler(os, varSubs);
+        EnvironmentSubstitute varSubs(env);
+        varSubs.addSubstitution(indexVarName, "i");
+        handler(varSubs);
    }
 }
 //--------------------------------------------------------------------------
diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc
index 6e33528fc7..972d4f0de9 100644
--- a/src/genn/genn/code_generator/initGroupMerged.cc
+++ b/src/genn/genn/code_generator/initGroupMerged.cc
@@ -60,72 +60,124 @@ void genScalarFill(CodeStream &os, const std::string &fieldName, const std::stri
         }
     }
 }
+
+template<typename A, typename G>
+NeuronInitGroupMerged::VarInitAST addInitNeuronVarFields(const BackendBase &backend, TypeChecker::EnvironmentBase &enclosingEnv,
+                                                         const G &groupMerged, const std::string &fieldSuffix)
+{
+    // Loop through variables
+    NeuronInitGroupMerged::VarInitAST varInitAST;
+    A archetypeAdaptor(groupMerged.getArchetype());
+    for (const auto &var : archetypeAdaptor.getDefs()) {
+        // If there is any initialisation code
+        const auto *snippet = archetypeAdaptor.getInitialisers().at(var.name).getSnippet();
+        if (!snippet->getCode().empty()) {
+            // Create type environment for this variable's initialisation
+            GroupMergedTypeEnvironment<G> typeEnvironment(groupMerged, &enclosingEnv);
+
+            // Add pointers to state variable itself
+            //typeEnvironment.definePointerField(var.type, var.name, backend.getDeviceVarPrefix(),
+            //                                   getVarAccessMode(var.access), fieldSuffix, &SynapseGroupInternal::getFusedPSVarSuffix);
+            const auto varResolvedType = var.type.resolve(groupMerged.getTypeContext());
+            const auto varQualifiedType = (var.access & VarAccessModeAttribute::READ_ONLY) ? varResolvedType.addQualifier(Type::Qualifier::CONSTANT) : varResolvedType;
+            typeEnvironment.defineField(varQualifiedType, var.name,
+                                        varResolvedType.createPointer(), var.name + fieldSuffix,
+                                        [&backend, var](const auto &g, size_t)
+                                        {
+                                            return backend.getDeviceVarPrefix() + var.name + A(g).getNameSuffix();
+                                        });
+
+            // Add heterogeneous var init parameters
+            typeEnvironment.defineHeterogeneousVarInitParams(&G::isVarInitParamHeterogeneous, fieldSuffix);
+            typeEnvironment.defineHeterogeneousVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, fieldSuffix);
+
+            // Add EGPs
+            typeEnvironment.defineEGPs(snippet->getExtraGlobalParams(), backend.getDeviceVarPrefix(),
+                                       var.name, fieldSuffix);
+
+            // Scan, parse and type-check update code
+            ErrorHandler errorHandler;
+            const std::string code = upgradeCodeString(snippet->getCode());
+            const auto tokens = Scanner::scanSource(code, groupMerged.getTypeContext(), errorHandler);
+
+            auto initStatements = Parser::parseBlockItemList(tokens, groupMerged.getTypeContext(), errorHandler);
+            auto initTypes = TypeChecker::typeCheck(initStatements, typeEnvironment, errorHandler);
+
+            // Add to map of per-variable initialisation AST
+            varInitAST.emplace(var.name, std::make_tuple(std::move(initStatements), std::move(initTypes)));
+        }
+    }
+
+    return varInitAST;
+}
 //------------------------------------------------------------------------
-template<typename Q, typename P, typename D>
-void genInitNeuronVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, const BackendBase &backend, const Substitutions &popSubs,
-                          const Models::Base::VarVec &vars, const std::unordered_map<std::string, Models::VarInit> &varInitialisers,
-                          const std::string &fieldSuffix, const std::string &countMember,
-                          size_t numDelaySlots, const size_t groupIndex, unsigned int batchSize,
-                          Q isVarQueueRequired, P isParamHeterogeneousFn, D isDerivedParamHeterogeneousFn)
+template<typename A, typename G>
+void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternal &env,
+                          const G &groupMerged, const NeuronInitGroupMerged::VarInitAST &varInitAST, const std::string &fieldSuffix,
+                          const std::string &countMember, size_t numDelaySlots, unsigned int batchSize)
 {
+    A adaptor(groupMerged.getArchetype());
     const std::string count = "group->" + countMember;
-    for (const auto &var : vars) {
-        const auto &varInit = varInitialisers.at(var.name);
+    for (const auto &var : adaptor.getDefs()) {
+        const auto &varInit = adaptor.getInitialisers().at(var.name);
+        const auto &varAST = varInitAST.at(var.name);
 
-        // If this variable has any initialisation code
-        if (!varInit.getSnippet()->getCode().empty()) {
-            CodeStream::Scope b(os);
+        // If there are any initialisation statements for this variable
+        if (!std::get<0>(varAST).empty()) {
+            CodeStream::Scope b(env.getStream());
 
-            Substitutions varSubs(&popSubs);
+            EnvironmentSubstitute varEnv(env);
 
             // Substitute in parameters and derived parameters for initialising variables
-            varSubs.addParamValueSubstitution(varInit.getSnippet()->getParamNames(), varInit.getParams(),
-                                              [&var, isParamHeterogeneousFn](const std::string &p) { return isParamHeterogeneousFn(var.name, p); },
-                                              "", "group->", var.name + fieldSuffix);
-            varSubs.addVarValueSubstitution(varInit.getSnippet()->getDerivedParams(), varInit.getDerivedParams(),
-                                            [&var, isDerivedParamHeterogeneousFn](const std::string &p) { return isDerivedParamHeterogeneousFn(var.name, p); },
-                                            "", "group->", var.name + fieldSuffix);
-            varSubs.addVarNameSubstitution(varInit.getSnippet()->getExtraGlobalParams(),
-                                           "", "group->", var.name + fieldSuffix);
+            varEnv.addParamValueSubstitution(varInit.getSnippet()->getParamNames(), varInit.getParams(), var.name + fieldSuffix,
+                                             [&groupMerged, &var](const std::string &p)
+                                             {
+                                                 return groupMerged.isVarInitParamHeterogeneous(var.name, p);
+                                             });
+            varEnv.addVarValueSubstitution(varInit.getSnippet()->getDerivedParams(), varInit.getDerivedParams(), var.name + fieldSuffix,
+                                           [&groupMerged, &var](const std::string &p)
+                                           {
+                                               return groupMerged.isVarInitDerivedParamHeterogeneous(var.name, p);
+                                           });
+            varEnv.addVarNameSubstitution(varInit.getSnippet()->getExtraGlobalParams(), var.name + fieldSuffix);
 
             // If variable is shared between neurons
             if (getVarAccessDuplication(var.access) == VarAccessDuplication::SHARED_NEURON) {
                 backend.genPopVariableInit(
-                    os, varSubs,
-                    [&var, &varInit, &fieldSuffix, &modelMerged, batchSize, groupIndex, numDelaySlots, isVarQueueRequired]
-                    (CodeStream &os, Substitutions &varInitSubs)
+                    varEnv,
+                    [&adaptor, &groupMerged, &var, &varAST, &fieldSuffix, batchSize, numDelaySlots]
+                    (EnvironmentExternal &varInitEnv)
                     {
                         // Generate initial value into temporary variable
-                        os << var.type.resolve(modelMerged.getTypeContext()).getName() << " initVal;" << std::endl;
-                        varInitSubs.addVarSubstitution("value", "initVal");
-                        std::string code = varInit.getSnippet()->getCode();
-                        varInitSubs.applyCheckUnreplaced(code, "initVar : " + var.name + "merged" + std::to_string(groupIndex));
-                        //code = ensureFtype(code, scalarType);
-                        os << code << std::endl;
-
+                        varInitEnv.getStream() << var.type.resolve(groupMerged.getTypeContext()).getName() << " initVal;" << std::endl;
+                        varInitEnv.addVarSubstitution("value", "initVal");
+
+                        // Pretty print variable initialisation code
+                        PrettyPrinter::print(std::get<0>(varAST), varInitEnv, groupMerged.getTypeContext(), std::get<1>(varAST));
+
                         // Fill value across all delay slots and batches
-                        genScalarFill(os, var.name + fieldSuffix, "initVal", getVarAccessDuplication(var.access),
-                                      batchSize, isVarQueueRequired(var.name), numDelaySlots);
+                        genScalarFill(varInitEnv.getStream(), var.name + fieldSuffix, "initVal", getVarAccessDuplication(var.access),
+                                      batchSize, adaptor.isVarDelayed(var.name), numDelaySlots);
                     });
             }
             // Otherwise
             else {
                 backend.genVariableInit(
-                    os, count, "id", varSubs,
-                    [&var, &varInit, &modelMerged, &fieldSuffix, batchSize, groupIndex, count, numDelaySlots, isVarQueueRequired]
-                    (CodeStream &os, Substitutions &varInitSubs)
+                    varEnv, count, "id",
+                    [&adaptor, &groupMerged, &var, &varAST, &fieldSuffix, batchSize, count, numDelaySlots]
+                    (EnvironmentExternal &varInitEnv)
                    {
                         // Generate initial value into temporary variable
-                        os << var.type.resolve(modelMerged.getTypeContext()).getName() << " initVal;" << std::endl;
-                        varInitSubs.addVarSubstitution("value", "initVal");
-                        std::string code = varInit.getSnippet()->getCode();
-                        varInitSubs.applyCheckUnreplaced(code, "initVar : " + var.name + "merged" + std::to_string(groupIndex));
-                        //code = ensureFtype(code, ftype);
-                        os << code << std::endl;
+                        varInitEnv.getStream() << var.type.resolve(groupMerged.getTypeContext()).getName() << " initVal;" << std::endl;
+                        varInitEnv.addVarSubstitution("value", "initVal");
+
+                        // Pretty print variable initialisation code
+                        PrettyPrinter::print(std::get<0>(varAST), varInitEnv, groupMerged.getTypeContext(), std::get<1>(varAST));
 
                         // Fill value across all delay slots and batches
-                        genVariableFill(os, var.name + fieldSuffix, "initVal", varInitSubs["id"], count,
-                                        getVarAccessDuplication(var.access), batchSize, isVarQueueRequired(var.name), numDelaySlots);
+                        genVariableFill(varInitEnv.getStream(), var.name + fieldSuffix, "initVal", varInitEnv["id"], count,
fieldSuffix, "initVal", varInitSubs["id"], count, + getVarAccessDuplication(var.access), batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); }); } } @@ -133,18 +185,6 @@ void genInitNeuronVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, co } } //------------------------------------------------------------------------ -template -void genInitNeuronVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, const BackendBase &backend, const Substitutions &popSubs, - const Models::Base::VarVec &vars, const std::unordered_map &varInitialisers, - const std::string &fieldSuffix, const std::string &countMember, const size_t groupIndex, - unsigned int batchSize, P isParamHeterogeneousFn, D isDerivedParamHeterogeneousFn) -{ - genInitNeuronVarCode(os, modelMerged, backend, popSubs, vars, varInitialisers, fieldSuffix, countMember, 0, groupIndex, batchSize, - [](const std::string&){ return false; }, - isParamHeterogeneousFn, - isDerivedParamHeterogeneousFn); -} -//------------------------------------------------------------------------ // Initialise one row of weight update model variables template void genInitWUVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, const Substitutions &popSubs, @@ -227,15 +267,11 @@ NeuronInitGroupMerged::CurrentSource::CurrentSource(size_t index, const Type::Ty } } //---------------------------------------------------------------------------- -void NeuronInitGroupMerged::CurrentSource::generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void NeuronInitGroupMerged::CurrentSource::generate(const BackendBase &backend, EnvironmentExternal &env, + const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const { - const std::string suffix = "CS" + std::to_string(getIndex()); - - genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getCurrentSourceModel()->getVars(), getArchetype().getVarInitialisers(), - suffix, "numNeurons", getIndex(), modelMerged.getModel().getBatchSize(), - [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, - [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); + genInitNeuronVarCode(backend, env, *this, m_VarInitASTs, "CS" + std::to_string(getIndex()), + "numNeurons", 0, modelMerged.getModel().getBatchSize()); } //---------------------------------------------------------------------------- void NeuronInitGroupMerged::CurrentSource::updateHash(boost::uuids::detail::sha1 &hash) const @@ -246,16 +282,14 @@ void NeuronInitGroupMerged::CurrentSource::updateHash(boost::uuids::detail::sha1 //---------------------------------------------------------------------------- bool NeuronInitGroupMerged::CurrentSource::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const { - return (isVarInitParamReferenced(varName, paramName) && - isParamValueHeterogeneous(paramName, - [varName](const auto &cs){ return cs.getVarInitialisers().at(varName).getParams(); })); + return isParamValueHeterogeneous(paramName, + [varName](const auto &cs){ return cs.getVarInitialisers().at(varName).getParams(); }); } //---------------------------------------------------------------------------- bool NeuronInitGroupMerged::CurrentSource::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const { - return (isVarInitParamReferenced(varName, paramName) && - 
isParamValueHeterogeneous(paramName, - [varName](const auto &cs){ return cs.getVarInitialisers().at(varName).getDerivedParams(); })); + return isParamValueHeterogeneous(paramName, + [varName](const auto &cs){ return cs.getVarInitialisers().at(varName).getDerivedParams(); }); } //---------------------------------------------------------------------------- bool NeuronInitGroupMerged::CurrentSource::isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const @@ -286,53 +320,21 @@ NeuronInitGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContext [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); } - // Loop through variables - // **TODO** adaptor - const auto &varInit = getArchetype().getPSVarInitialisers(); - for(const auto &var : getArchetype().getPSModel()->getVars()) { - // If there is any initialisation code - const auto *snippet = varInit.at(var.name).getSnippet(); - if (!snippet->getCode().empty()) { - // Create type environment for this variable's initialisation - GroupMergedTypeEnvironment typeEnvironment(*this, &enclosingEnv); - - // Add pointers to state variable itself - typeEnvironment.definePointerField(var.type, var.name, backend.getDeviceVarPrefix(), - getVarAccessMode(var.access), suffix, &SynapseGroupInternal::getFusedPSVarSuffix); - - // Add heterogeneous var init parameters - typeEnvironment.defineHeterogeneousVarInitParams(&InSynPSM::isVarInitParamHeterogeneous, suffix); - typeEnvironment.defineHeterogeneousVarInitDerivedParams(&InSynPSM::isVarInitDerivedParamHeterogeneous, suffix); - - // Add EGPs - typeEnvironment.defineEGPs(snippet->getExtraGlobalParams(), backend.getDeviceVarPrefix(), - var.name, suffix, &SynapseGroupInternal::getFusedPSVarSuffix); - - // Scan, parse and type-check update code - ErrorHandler errorHandler; - const std::string code = upgradeCodeString(snippet->getCode()); - const auto tokens = Scanner::scanSource(code, typeContext, errorHandler); - - auto initStatements = Parser::parseBlockItemList(tokens, typeContext, errorHandler); - auto initTypes = TypeChecker::typeCheck(initStatements, typeEnvironment, errorHandler); - - // Add to map of per-variable initialistion AST - m_VarInitASTs.emplace(var.name, std::make_tuple(std::move(initStatements), std::move(initTypes))); - } - } + // Add fields required to initialise PSM variables and get AST + m_VarInitASTs = addInitNeuronVarFields(backend, enclosingEnv, *this, suffix); } //---------------------------------------------------------------------------- -void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, EnvironmentExternal &env, + const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const { const std::string suffix = "InSyn" + std::to_string(getIndex()); // Zero InSyn - backend.genVariableInit(os, "group->numNeurons", "id", popSubs, - [&modelMerged, &suffix] (CodeStream &os, Substitutions &varSubs) + backend.genVariableInit(env, "group->numNeurons", "id", + [&modelMerged, &suffix] (EnvironmentExternal &varEnv) { - genVariableFill(os, "inSyn" + suffix, modelMerged.scalarExpr(0.0), - varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, + genVariableFill(varEnv.getStream(), "inSyn" + suffix, modelMerged.scalarExpr(0.0), + varEnv["id"], "group->numNeurons", 
VarAccessDuplication::DUPLICATE, modelMerged.getModel().getBatchSize()); }); @@ -340,28 +342,25 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, CodeS // If dendritic delays are required if(getArchetype().isDendriticDelayRequired()) { // Zero dendritic delay buffer - backend.genVariableInit(os, "group->numNeurons", "id", popSubs, - [&modelMerged, &suffix, this](CodeStream &os, Substitutions &varSubs) + backend.genVariableInit(env, "group->numNeurons", "id", + [&modelMerged, &suffix, this](EnvironmentExternal &varEnv) { - genVariableFill(os, "denDelay" + suffix, modelMerged.scalarExpr(0.0), - varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, + genVariableFill(varEnv.getStream(), "denDelay" + suffix, modelMerged.scalarExpr(0.0), + varEnv["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, modelMerged.getModel().getBatchSize(), true, getArchetype().getMaxDendriticDelayTimesteps()); }); // Zero dendritic delay pointer - backend.genPopVariableInit(os, popSubs, - [&suffix](CodeStream &os, Substitutions &) + backend.genPopVariableInit(env, + [&suffix](EnvironmentExternal &varEnv) { - os << "*group->denDelayPtr" << suffix << " = 0;" << std::endl; + varEnv.getStream() << "*group->denDelayPtr" << suffix << " = 0;" << std::endl; }); } - // **TODO** adaptor - genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getPSModel()->getVars(), getArchetype().getPSVarInitialisers(), - suffix, "numNeurons", getIndex(), modelMerged.getModel().getBatchSize(), - [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, - [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); + genInitNeuronVarCode(backend, env, *this, m_VarInitASTs, suffix, + "numNeurons", 1, modelMerged.getModel().getBatchSize()); } //---------------------------------------------------------------------------- void NeuronInitGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &hash) const @@ -372,16 +371,14 @@ void NeuronInitGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &has //---------------------------------------------------------------------------- bool NeuronInitGroupMerged::InSynPSM::isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const { - return (isVarInitParamReferenced(varName, paramName) && - isParamValueHeterogeneous(paramName, - [varName](const auto &sg){ return sg.getPSVarInitialisers().at(varName).getParams(); })); + return isParamValueHeterogeneous(paramName, + [varName](const auto &sg){ return sg.getPSVarInitialisers().at(varName).getParams(); }); } //---------------------------------------------------------------------------- bool NeuronInitGroupMerged::InSynPSM::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const { - return (isVarInitParamReferenced(varName, paramName) && - isParamValueHeterogeneous(paramName, - [varName](const auto &sg){ return sg.getPSVarInitialisers().at(varName).getDerivedParams(); })); + return isParamValueHeterogeneous(paramName, + [varName](const auto &sg){ return sg.getPSVarInitialisers().at(varName).getDerivedParams(); }); } //---------------------------------------------------------------------------- bool NeuronInitGroupMerged::InSynPSM::isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const @@ -403,18 +400,18 @@ NeuronInitGroupMerged::OutSynPreOutput::OutSynPreOutput(size_t index, const Type [&backend](const 
auto &g, size_t) { return backend.getDeviceVarPrefix() + "revInSyn" + g.getFusedPreOutputSuffix(); }); } //---------------------------------------------------------------------------- -void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, EnvironmentExternal &env, + const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const { const std::string suffix = "OutSyn" + std::to_string(getIndex()); - backend.genVariableInit(os, "group->numNeurons", "id", popSubs, - [&modelMerged, suffix] (CodeStream &os, Substitutions &varSubs) - { - genVariableFill(os, "revInSyn" + suffix, modelMerged.scalarExpr(0.0), - varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, - modelMerged.getModel().getBatchSize()); - }); + backend.genVariableInit(env, "group->numNeurons", "id", + [&modelMerged, suffix] (EnvironmentExternal &varEnv) + { + genVariableFill(varEnv.getStream(), "revInSyn" + suffix, modelMerged.scalarExpr(0.0), + varEnv["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, + modelMerged.getModel().getBatchSize()); + }); } //---------------------------------------------------------------------------- @@ -424,46 +421,16 @@ NeuronInitGroupMerged::InSynWUMPostVars::InSynWUMPostVars(size_t index, const Ty const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, typeContext, groups) { - const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); - - // Loop through variables - // **TODO** adaptor - const auto &varInit = getArchetype().getWUPostVarInitialisers(); - for(const auto &var : getArchetype().getWUModel()->getPostVars()) { - // Add pointers to state variable - if(!varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(var.type.resolve(getTypeContext()).createPointer(), var.name + suffix, - [&backend, var](const auto &g, size_t) { return backend.getDeviceVarPrefix() + var.name + g.getFusedWUPostVarSuffix(); }); - } - - // Add heterogeneous var init parameters - addHeterogeneousVarInitParams( - &InSynWUMPostVars::isVarInitParamHeterogeneous, suffix); - addHeterogeneousVarInitDerivedParams( - &InSynWUMPostVars::isVarInitDerivedParamHeterogeneous, suffix); - - // Add extra global parameters - for(const auto &e : varInit.at(var.name).getSnippet()->getExtraGlobalParams()) { - addField(e.type.resolve(getTypeContext()).createPointer(), e.name + var.name + suffix, - [&backend, e, suffix, var](const auto &g, size_t) - { - return backend.getDeviceVarPrefix() + e.name + var.name + g.getFusedWUPostVarSuffix(); - }, - GroupMergedFieldType::DYNAMIC); - } - } + // Add fields required to initialise WUM postsynaptic variables and get AST + m_VarInitASTs = addInitNeuronVarFields(backend, enclosingEnv, *this, + "InSynWUMPost" + std::to_string(getIndex())); } //---------------------------------------------------------------------------- -void NeuronInitGroupMerged::InSynWUMPostVars::generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void NeuronInitGroupMerged::InSynWUMPostVars::generate(const BackendBase &backend, EnvironmentExternal &env, + const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const { - const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); - - genInitNeuronVarCode(os, 
modelMerged, backend, popSubs, getArchetype().getWUModel()->getPostVars(), getArchetype().getWUPostVarInitialisers(), - suffix, "numNeurons", getArchetype().getTrgNeuronGroup()->getNumDelaySlots(), getIndex(), modelMerged.getModel().getBatchSize(), - [this](const std::string&){ return (getArchetype().getBackPropDelaySteps() != NO_DELAY); }, - [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, - [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); + genInitNeuronVarCode(backend, env, *this, m_VarInitASTs, "InSynWUMPost" + std::to_string(getIndex()), + "numNeurons", 1, modelMerged.getModel().getBatchSize()); } //---------------------------------------------------------------------------- void NeuronInitGroupMerged::InSynWUMPostVars::updateHash(boost::uuids::detail::sha1 &hash) const @@ -474,16 +441,14 @@ void NeuronInitGroupMerged::InSynWUMPostVars::updateHash(boost::uuids::detail::s //---------------------------------------------------------------------------- bool NeuronInitGroupMerged::InSynWUMPostVars::isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const { - return (isVarInitParamReferenced(varName, paramName) && - isParamValueHeterogeneous(paramName, - [varName](const auto &sg){ return sg.getWUPostVarInitialisers().at(varName).getParams(); })); + return isParamValueHeterogeneous(paramName, + [varName](const auto &sg){ return sg.getWUPostVarInitialisers().at(varName).getParams(); }); } //---------------------------------------------------------------------------- bool NeuronInitGroupMerged::InSynWUMPostVars::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const { - return (isVarInitParamReferenced(varName, paramName) && - isParamValueHeterogeneous(paramName, - [varName](const auto &sg){ return sg.getWUPostVarInitialisers().at(varName).getDerivedParams(); })); + return isParamValueHeterogeneous(paramName, + [varName](const auto &sg){ return sg.getWUPostVarInitialisers().at(varName).getDerivedParams(); }); } //---------------------------------------------------------------------------- bool NeuronInitGroupMerged::InSynWUMPostVars::isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const @@ -499,46 +464,16 @@ NeuronInitGroupMerged::OutSynWUMPreVars::OutSynWUMPreVars(size_t index, const Ty const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, typeContext, groups) { - const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); - - // Loop through variables - // **TODO** adaptor - const auto &varInit = getArchetype().getWUPreVarInitialisers(); - for(const auto &var : getArchetype().getWUModel()->getPreVars()) { - // Add pointers to state variable - if(!varInit.at(var.name).getSnippet()->getCode().empty()) { - addField(var.type.resolve(getTypeContext()).createPointer(), var.name + suffix, - [&backend, var](const auto &g, size_t) { return backend.getDeviceVarPrefix() + var.name + g.getFusedWUPreVarSuffix(); }); - } - - // Add heterogeneous var init parameters - addHeterogeneousVarInitParams( - &OutSynWUMPreVars::isVarInitParamHeterogeneous, suffix); - addHeterogeneousVarInitDerivedParams( - &OutSynWUMPreVars::isVarInitDerivedParamHeterogeneous, suffix); - - // Add extra global parameters - for(const auto &e : varInit.at(var.name).getSnippet()->getExtraGlobalParams()) { - addField(e.type.resolve(getTypeContext()).createPointer(), e.name + var.name + suffix, - 
[&backend, e, suffix, var](const auto &g, size_t) - { - return backend.getDeviceVarPrefix() + e.name + var.name + g.getFusedWUPreVarSuffix(); - }, - GroupMergedFieldType::DYNAMIC); - } - } + // Add fields required to initialise WUM presynaptic variables and get AST + m_VarInitASTs = addInitNeuronVarFields(backend, enclosingEnv, *this, + "OutSynWUMPre" + std::to_string(getIndex())); } //---------------------------------------------------------------------------- -void NeuronInitGroupMerged::OutSynWUMPreVars::generate(const BackendBase &backend, CodeStream &os, const NeuronInitGroupMerged &ng, - const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void NeuronInitGroupMerged::OutSynWUMPreVars::generate(const BackendBase &backend, EnvironmentExternal &env, + const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const { - const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); - - genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getWUModel()->getPreVars(), getArchetype().getWUPreVarInitialisers(), - suffix, "numNeurons", getArchetype().getSrcNeuronGroup()->getNumDelaySlots(), getIndex(), modelMerged.getModel().getBatchSize(), - [this](const std::string&){ return (getArchetype().getDelaySteps() != NO_DELAY); }, - [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, - [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); + genInitNeuronVarCode(backend, env, *this, m_VarInitASTs, "OutSynWUMPre" + std::to_string(getIndex()), + "numNeurons", 1, modelMerged.getModel().getBatchSize()); } //---------------------------------------------------------------------------- void NeuronInitGroupMerged::OutSynWUMPreVars::updateHash(boost::uuids::detail::sha1 &hash) const @@ -549,16 +484,14 @@ void NeuronInitGroupMerged::OutSynWUMPreVars::updateHash(boost::uuids::detail::s //---------------------------------------------------------------------------- bool NeuronInitGroupMerged::OutSynWUMPreVars::isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const { - return (isVarInitParamReferenced(varName, paramName) && - isParamValueHeterogeneous(paramName, - [varName](const auto &sg){ return sg.getWUPreVarInitialisers().at(varName).getParams(); })); + return isParamValueHeterogeneous(paramName, + [varName](const auto &sg){ return sg.getWUPreVarInitialisers().at(varName).getParams(); }); } //---------------------------------------------------------------------------- bool NeuronInitGroupMerged::OutSynWUMPreVars::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const { - return (isVarInitParamReferenced(varName, paramName) && - isParamValueHeterogeneous(paramName, - [varName](const auto &sg){ return sg.getWUPreVarInitialisers().at(varName).getDerivedParams(); })); + return isParamValueHeterogeneous(paramName, + [varName](const auto &sg){ return sg.getWUPreVarInitialisers().at(varName).getDerivedParams(); }); } //---------------------------------------------------------------------------- bool NeuronInitGroupMerged::OutSynWUMPreVars::isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const @@ -587,27 +520,8 @@ NeuronInitGroupMerged::NeuronInitGroupMerged(size_t index, const Type::TypeConte addPointerField(*backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); } - // Loop through variables - const NeuronModels::Base *nm = 
getArchetype().getNeuronModel(); - const auto vars = nm->getVars(); - const auto &varInit = getArchetype().getVarInitialisers(); - for(const auto &var : vars) { - // If we're not initialising or if there is initialization code for this variable - if(!varInit.at(var.name).getSnippet()->getCode().empty()) { - addPointerField(var.type, var.name, - backend.getDeviceVarPrefix() + var.name); - } - - // Add any var init EGPs to structure - addEGPs(varInit.at(var.name).getSnippet()->getExtraGlobalParams(), backend.getDeviceVarPrefix(), var.name); - } - - // Add heterogeneous var init parameters - addHeterogeneousVarInitParams( - &NeuronGroupMergedBase::isVarInitParamHeterogeneous); - - addHeterogeneousVarInitDerivedParams( - &NeuronGroupMergedBase::isVarInitDerivedParamHeterogeneous); + // Add fields required to initialise neuron variables and get AST + m_VarInitASTs = addInitNeuronVarFields(backend, typeEnvironment, *this, ""); // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group orderNeuronGroupChildren(m_MergedInSynPSMGroups, typeContext, typeEnvironment, backend, @@ -668,41 +582,41 @@ boost::uuids::detail::sha1::digest_type NeuronInitGroupMerged::getHashDigest() c return hash.get_digest(); } //---------------------------------------------------------------------------- -void NeuronInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void NeuronInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged) const { const auto &model = modelMerged.getModel(); // Initialise spike counts - genInitSpikeCount(os, backend, popSubs, false, model.getBatchSize()); - genInitSpikeCount(os, backend, popSubs, true, model.getBatchSize()); + genInitSpikeCount(backend, env, false, model.getBatchSize()); + genInitSpikeCount(backend, env, true, model.getBatchSize()); // Initialise spikes - genInitSpikes(os, backend, popSubs, false, model.getBatchSize()); - genInitSpikes(os, backend, popSubs, true, model.getBatchSize()); + genInitSpikes(backend, env, false, model.getBatchSize()); + genInitSpikes(backend, env, true, model.getBatchSize()); // Initialize spike times if(getArchetype().isSpikeTimeRequired()) { - genInitSpikeTime(os, backend, popSubs, "sT", model.getBatchSize()); + genInitSpikeTime(backend, env, "sT", model.getBatchSize()); } // Initialize previous spike times if(getArchetype().isPrevSpikeTimeRequired()) { - genInitSpikeTime(os, backend, popSubs, "prevST", model.getBatchSize()); + genInitSpikeTime(backend, env, "prevST", model.getBatchSize()); } // Initialize spike-like-event times if(getArchetype().isSpikeEventTimeRequired()) { - genInitSpikeTime(os, backend, popSubs, "seT", model.getBatchSize()); + genInitSpikeTime(backend, env, "seT", model.getBatchSize()); } // Initialize previous spike-like-event times if(getArchetype().isPrevSpikeEventTimeRequired()) { - genInitSpikeTime(os, backend, popSubs, "prevSET", model.getBatchSize()); + genInitSpikeTime(backend, env, "prevSET", model.getBatchSize()); } // If neuron group requires delays, zero spike queue pointer if(getArchetype().isDelayRequired()) { - backend.genPopVariableInit(os, popSubs, - [](CodeStream &os, Substitutions &) + backend.genPopVariableInit(env, + [](EnvironmentExternal &varEnv) { - os << "*group->spkQuePtr = 0;" << std::endl; + varEnv.getStream() << "*group->spkQuePtr = 0;" << std::endl; @@ -710,39 +624,36 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, CodeStream } // Initialise neuron variables - 
genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getNeuronModel()->getVars(), getArchetype().getVarInitialisers(), - "", "numNeurons", getArchetype().getNumDelaySlots(), getIndex(), model.getBatchSize(), - [this](const std::string &v){ return getArchetype().isVarQueueRequired(v); }, - [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, - [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); + genInitNeuronVarCode(backend, env, *this, m_VarInitASTs, "", + "numNeurons", 1, modelMerged.getModel().getBatchSize()); // Generate initialisation code for child groups for (const auto &cs : getMergedCurrentSourceGroups()) { - cs.generate(backend, os, *this, modelMerged, popSubs); + cs.generate(backend, env, *this, modelMerged); } for(const auto &sg : getMergedInSynPSMGroups()) { - sg.generate(backend, os, *this, modelMerged, popSubs); + sg.generate(backend, env, *this, modelMerged); } for (const auto &sg : getMergedOutSynPreOutputGroups()) { - sg.generate(backend, os, *this, modelMerged, popSubs); + sg.generate(backend, env, *this, modelMerged); } for (const auto &sg : getMergedOutSynWUMPreVarGroups()) { - sg.generate(backend, os, *this, modelMerged, popSubs); + sg.generate(backend, env, *this, modelMerged); } for (const auto &sg : getMergedInSynWUMPostVarGroups()) { - sg.generate(backend, os, *this, modelMerged, popSubs); + sg.generate(backend, env, *this, modelMerged); } } //-------------------------------------------------------------------------- -void NeuronInitGroupMerged::genInitSpikeCount(CodeStream &os, const BackendBase &backend, const Substitutions &popSubs, +void NeuronInitGroupMerged::genInitSpikeCount(const BackendBase &backend, EnvironmentExternal &env, bool spikeEvent, unsigned int batchSize) const { // Is initialisation required at all const bool initRequired = spikeEvent ? getArchetype().isSpikeEventRequired() : true; if(initRequired) { // Generate variable initialisation code - backend.genPopVariableInit(os, popSubs, - [batchSize, spikeEvent, this] (CodeStream &os, Substitutions &) + backend.genPopVariableInit(env, + [batchSize, spikeEvent, this] (EnvironmentExternal &spikeCountEnv) { // Get variable name const char *spikeCntName = spikeEvent ? "spkCntEvnt" : "spkCnt"; @@ -753,21 +664,21 @@ void NeuronInitGroupMerged::genInitSpikeCount(CodeStream &os, const BackendBase (getArchetype().isTrueSpikeRequired() && getArchetype().isDelayRequired()); // Zero across all delay slots and batches - genScalarFill(os, spikeCntName, "0", VarAccessDuplication::DUPLICATE, batchSize, delayRequired, getArchetype().getNumDelaySlots()); + genScalarFill(spikeCountEnv.getStream(), spikeCntName, "0", VarAccessDuplication::DUPLICATE, batchSize, delayRequired, getArchetype().getNumDelaySlots()); }); } } //-------------------------------------------------------------------------- -void NeuronInitGroupMerged::genInitSpikes(CodeStream &os, const BackendBase &backend, const Substitutions &popSubs, +void NeuronInitGroupMerged::genInitSpikes(const BackendBase &backend, EnvironmentExternal &env, bool spikeEvent, unsigned int batchSize) const { // Is initialisation required at all const bool initRequired = spikeEvent ? 
getArchetype().isSpikeEventRequired() : true; if(initRequired) { // Generate variable initialisation code - backend.genVariableInit(os, "group->numNeurons", "id", popSubs, - [batchSize, spikeEvent, this] (CodeStream &os, Substitutions &varSubs) + backend.genVariableInit(env, "group->numNeurons", "id", + [batchSize, spikeEvent, this] (EnvironmentExternal &varEnv) { // Get variable name const char *spikeName = spikeEvent ? "spkEvnt" : "spk"; @@ -778,20 +689,20 @@ void NeuronInitGroupMerged::genInitSpikes(CodeStream &os, const BackendBase &bac (getArchetype().isTrueSpikeRequired() && getArchetype().isDelayRequired()); // Zero across all delay slots and batches - genVariableFill(os, spikeName, "0", varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, - batchSize, delayRequired, getArchetype().getNumDelaySlots()); + genVariableFill(varEnv.getStream(), spikeName, "0", varEnv["id"], "group->numNeurons", + VarAccessDuplication::DUPLICATE, batchSize, delayRequired, getArchetype().getNumDelaySlots()); }); } } //------------------------------------------------------------------------ -void NeuronInitGroupMerged::genInitSpikeTime(CodeStream &os, const BackendBase &backend, const Substitutions &popSubs, +void NeuronInitGroupMerged::genInitSpikeTime(const BackendBase &backend, EnvironmentExternal &env, const std::string &varName, unsigned int batchSize) const { // Generate variable initialisation code - backend.genVariableInit(os, "group->numNeurons", "id", popSubs, - [batchSize, varName, this] (CodeStream &os, Substitutions &varSubs) + backend.genVariableInit(env, "group->numNeurons", "id", + [batchSize, varName, this] (EnvironmentExternal &varEnv) { - genVariableFill(os, varName, "-TIME_MAX", varSubs["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, + genVariableFill(varEnv.getStream(), varName, "-TIME_MAX", varEnv["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, batchSize, getArchetype().isDelayRequired(), getArchetype().getNumDelaySlots()); }); @@ -1139,6 +1050,9 @@ boost::uuids::detail::sha1::digest_type CustomUpdateInitGroupMerged::getHashDige void CustomUpdateInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { // Initialise custom update variables + genInitNeuronVarCode(backend, env, *this, m_VarInitASTs, "", + "size", 1, getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1); + genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getCustomUpdateModel()->getVars(), getArchetype().getVarInitialisers(), "", "size", getIndex(), getArchetype().isBatched() ? 
modelMerged.getModel().getBatchSize() : 1, [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index fa40439e6e..56c79cc2cd 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -64,15 +64,15 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend // Create new substitution environment and add parameters, derived parameters and extra global parameters EnvironmentSubstitute envSubs(env); envSubs.getStream() << "// current source " << getIndex() << std::endl; - envSubs.addParamValueSubstitution(cm->getParamNames(), getArchetype().getParams(), + envSubs.addParamValueSubstitution(cm->getParamNames(), getArchetype().getParams(), suffix, [this](const std::string &p) { return isParamHeterogeneous(p); }); - envSubs.addVarValueSubstitution(cm->getDerivedParams(), getArchetype().getDerivedParams(), + envSubs.addVarValueSubstitution(cm->getDerivedParams(), getArchetype().getDerivedParams(), suffix, [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); - envSubs.addVarNameSubstitution(cm->getExtraGlobalParams()); + envSubs.addVarNameSubstitution(cm->getExtraGlobalParams(), suffix); // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varSubs( - getArchetype(), getTypeContext(), envSubs, + getArchetype(), getTypeContext(), envSubs, "l", suffix, [&envSubs, &modelMerged, &ng](const std::string&, const Models::VarInit&, VarAccess a) { return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), envSubs["id"]); @@ -184,11 +184,11 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env } // Add parameters, derived parameters and extra global parameters to environment - envSubs.addParamValueSubstitution(psm->getParamNames(), getArchetype().getPSParams(), + envSubs.addParamValueSubstitution(psm->getParamNames(), getArchetype().getPSParams(), suffix, [this](const std::string &p) { return isParamHeterogeneous(p); }); - envSubs.addVarValueSubstitution(psm->getDerivedParams(), getArchetype().getPSDerivedParams(), + envSubs.addVarValueSubstitution(psm->getDerivedParams(), getArchetype().getPSDerivedParams(), suffix, [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); - envSubs.addVarNameSubstitution(psm->getExtraGlobalParams()); + envSubs.addVarNameSubstitution(psm->getExtraGlobalParams(), suffix); // **TODO** naming convention envSubs.addSubstitution("inSyn", "linSyn"); @@ -198,7 +198,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varSubs( - getArchetype(), getTypeContext(), envSubs, + getArchetype(), getTypeContext(), envSubs, "l", suffix, [&envSubs, &modelMerged, &ng](const std::string&, const Models::VarInit&, VarAccess a) { return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), envSubs["id"]); @@ -313,16 +313,16 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back // Create new substitution environment and add parameters, derived parameters and extra global parameters EnvironmentSubstitute envSubs(env); envSubs.getStream() << "// postsynaptic weight update " << 
getIndex() << std::endl; - envSubs.addParamValueSubstitution(wum->getParamNames(), getArchetype().getWUParams(), + envSubs.addParamValueSubstitution(wum->getParamNames(), getArchetype().getWUParams(), suffix, [this](const std::string &p) { return isParamHeterogeneous(p); }); - envSubs.addVarValueSubstitution(wum->getDerivedParams(), getArchetype().getWUDerivedParams(), + envSubs.addVarValueSubstitution(wum->getDerivedParams(), getArchetype().getWUDerivedParams(), suffix, [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); - envSubs.addVarNameSubstitution(wum->getExtraGlobalParams()); + envSubs.addVarNameSubstitution(wum->getExtraGlobalParams(), suffix); // Create an environment which caches variables in local variables if they are accessed const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varSubs( - getArchetype(), getTypeContext(), envSubs, + getArchetype(), getTypeContext(), envSubs, "l", suffix, [batchSize, delayed, &envSubs, &ng](const std::string&, const Models::VarInit&, VarAccess a) { return ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), envSubs["id"]); @@ -443,16 +443,16 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back // Create new substitution environment and add parameters, derived parameters and extra global parameters EnvironmentSubstitute envSubs(env); envSubs.getStream() << "// presynaptic weight update " << getIndex() << std::endl; - envSubs.addParamValueSubstitution(wum->getParamNames(), getArchetype().getWUParams(), + envSubs.addParamValueSubstitution(wum->getParamNames(), getArchetype().getWUParams(), suffix, [this](const std::string &p) { return isParamHeterogeneous(p); }); - envSubs.addVarValueSubstitution(wum->getDerivedParams(), getArchetype().getWUDerivedParams(), + envSubs.addVarValueSubstitution(wum->getDerivedParams(), getArchetype().getWUDerivedParams(), suffix, [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); - envSubs.addVarNameSubstitution(wum->getExtraGlobalParams()); + envSubs.addVarNameSubstitution(wum->getExtraGlobalParams(), suffix); // Create an environment which caches variables in local variables if they are accessed const bool delayed = (getArchetype().getDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varSubs( - getArchetype(), getTypeContext(), envSubs, + getArchetype(), getTypeContext(), envSubs, "l", suffix, [batchSize, delayed, &envSubs, &ng](const std::string&, const Models::VarInit&, VarAccess a) { return ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), envSubs["id"]); @@ -714,9 +714,9 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E {neuronEnv.addInitialiser(typeName + " " + v.name + " = " + v.value + ";")}); } - neuronEnv.addParamValueSubstitution(nm->getParamNames(), getArchetype().getParams(), + neuronEnv.addParamValueSubstitution(nm->getParamNames(), getArchetype().getParams(), "", [this](const std::string &p) { return isParamHeterogeneous(p); }); - neuronEnv.addVarValueSubstitution(nm->getDerivedParams(), getArchetype().getDerivedParams(), + neuronEnv.addVarValueSubstitution(nm->getDerivedParams(), getArchetype().getDerivedParams(), "", [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); neuronEnv.addVarNameSubstitution(nm->getExtraGlobalParams()); @@ -744,7 +744,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Create an environment which caches variables in local 
variables if they are accessed // **NOTE** we do this right at the top so that local copies can be used by child groups EnvironmentLocalVarCache neuronVarEnv( - getArchetype(), getTypeContext(), neuronEnv, + getArchetype(), getTypeContext(), neuronEnv, "l", "", [batchSize, &neuronEnv, this](const std::string &varName, const Models::VarInit&, VarAccess a) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); From bdd86b97b9c9c96fa2443f685b781aceb08292f8 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 9 Jun 2023 09:08:31 +0100 Subject: [PATCH 202/725] environment-based single threaded CPU neuron updates --- .../backends/single_threaded_cpu/backend.h | 2 +- .../backends/single_threaded_cpu/backend.cc | 121 +++++++++--------- 2 files changed, 62 insertions(+), 61 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 0de2c91fa2..3d4ec44d70 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -188,7 +188,7 @@ class BACKEND_EXPORT Backend : public BackendBase //-------------------------------------------------------------------------- void genPresynapticUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, const Substitutions &popSubs, bool trueSpike) const; - void genEmitSpike(CodeStream &os, const NeuronUpdateGroupMerged &ng, const Substitutions &subs, bool trueSpike, bool recordingEnabled) const; + void genEmitSpike(EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng,bool trueSpike, bool recordingEnabled) const; template void genMergedStructArrayPush(CodeStream &os, const std::vector &groups) const diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 7c657721a0..cc18d1941f 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -153,60 +153,61 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged { CodeStream::Scope b(os); - Substitutions funcSubs(getFunctionTemplates(model.getPrecision().getName())); - funcSubs.addVarSubstitution("t", "t"); - funcSubs.addVarSubstitution("batch", "0"); - - Timer t(os, "neuronUpdate", model.isTimingEnabled()); + StandardLibrary::FunctionEnvironment stdEnv(os); + EnvironmentSubstitute funcEnv(stdEnv); + funcEnv.addSubstitution("t", "t"); + funcEnv.addSubstitution("batch", "0"); + + Timer t(funcEnv.getStream(), "neuronUpdate", model.isTimingEnabled()); // Loop through merged previous spike time update groups for(const auto &n : modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups()) { - CodeStream::Scope b(os); - os << "// merged neuron prev spike update group " << n.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << n.getGroups().size() << "; g++)"; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged neuron prev spike update group " << n.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << n.getGroups().size() << "; g++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); // Get reference to group - os << "const auto *group = &mergedNeuronPrevSpikeTimeUpdateGroup" << n.getIndex() << "[g]; " << std::endl; + funcEnv.getStream() << "const auto *group = &mergedNeuronPrevSpikeTimeUpdateGroup" << n.getIndex() << "[g]; " << std::endl; 
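// NOTE: the (CodeStream &os, Substitutions &subs) to EnvironmentExternal migration in
// this patch boils down to carrying the output stream and the name-to-expression
// substitution map in a single object. A minimal sketch of the idea follows; the
// MiniEnv name and this exact interface are illustrative assumptions only, as the
// real EnvironmentExternal/EnvironmentSubstitute classes also handle nested scopes,
// lazily-applied initialisers and type information.
//
//     #include <ostream>
//     #include <string>
//     #include <unordered_map>
//
//     class MiniEnv
//     {
//     public:
//         explicit MiniEnv(std::ostream &stream) : m_Stream(stream) {}
//
//         // Register a substitution so later lookups of 'name' yield 'value'
//         void addSubstitution(const std::string &name, const std::string &value)
//         {
//             m_Substitutions.emplace(name, value);
//         }
//
//         // Look up the expression a name maps to
//         const std::string &operator[](const std::string &name) const
//         {
//             return m_Substitutions.at(name);
//         }
//
//         // Stream that generated code is written to
//         std::ostream &getStream() { return m_Stream; }
//
//     private:
//         std::ostream &m_Stream;
//         std::unordered_map<std::string, std::string> m_Substitutions;
//     };
//
// Usage mirroring the idiom in this hunk: after env.addSubstitution("id", "i"),
// writing env.getStream() << "group->spk[" << env["id"] << "]" emits group->spk[i].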
if(n.getArchetype().isDelayRequired()) { // Calculate delay slot corresponding to last timestep - os << "const unsigned int lastTimestepDelaySlot = (*group->spkQuePtr + " << (n.getArchetype().getNumDelaySlots() - 1) << ") % " << n.getArchetype().getNumDelaySlots() << ";" << std::endl; - os << "const unsigned int lastTimestepDelayOffset = lastTimestepDelaySlot * group->numNeurons;" << std::endl; + funcEnv.getStream() << "const unsigned int lastTimestepDelaySlot = (*group->spkQuePtr + " << (n.getArchetype().getNumDelaySlots() - 1) << ") % " << n.getArchetype().getNumDelaySlots() << ";" << std::endl; + funcEnv.getStream() << "const unsigned int lastTimestepDelayOffset = lastTimestepDelaySlot * group->numNeurons;" << std::endl; if(n.getArchetype().isPrevSpikeTimeRequired()) { // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep - os << "for(unsigned int i = 0; i < group->spkCnt[lastTimestepDelaySlot]; i++)"; + funcEnv.getStream() << "for(unsigned int i = 0; i < group->spkCnt[lastTimestepDelaySlot]; i++)"; { - CodeStream::Scope b(os); - os << "group->prevST[lastTimestepDelayOffset + group->spk[lastTimestepDelayOffset + i]] = t - DT;" << std::endl; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "group->prevST[lastTimestepDelayOffset + group->spk[lastTimestepDelayOffset + i]] = t - DT;" << std::endl; } } if(n.getArchetype().isPrevSpikeEventTimeRequired()) { // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep - os << "for(unsigned int i = 0; i < group->spkCntEvnt[lastTimestepDelaySlot]; i++)"; + funcEnv.getStream() << "for(unsigned int i = 0; i < group->spkCntEvnt[lastTimestepDelaySlot]; i++)"; { - CodeStream::Scope b(os); - os << "group->prevSET[lastTimestepDelayOffset + group->spkEvnt[lastTimestepDelayOffset + i]] = t - DT;" << std::endl; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "group->prevSET[lastTimestepDelayOffset + group->spkEvnt[lastTimestepDelayOffset + i]] = t - DT;" << std::endl; } } } else { if(n.getArchetype().isPrevSpikeTimeRequired()) { // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep - os << "for(unsigned int i = 0; i < group->spkCnt[0]; i++)"; + funcEnv.getStream() << "for(unsigned int i = 0; i < group->spkCnt[0]; i++)"; { - CodeStream::Scope b(os); - os << "group->prevST[group->spk[i]] = t - DT;" << std::endl; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "group->prevST[group->spk[i]] = t - DT;" << std::endl; } } if(n.getArchetype().isPrevSpikeEventTimeRequired()) { // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep - os << "for(unsigned int i = 0; i < group->spkCntEvnt[0]; i++)"; + funcEnv.getStream() << "for(unsigned int i = 0; i < group->spkCntEvnt[0]; i++)"; { - CodeStream::Scope b(os); - os << "group->prevSET[group->spkEvnt[i]] = t - DT;" << std::endl; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "group->prevSET[group->spkEvnt[i]] = t - DT;" << std::endl; } } } @@ -215,77 +216,77 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged // Loop through merged neuron spike queue update groups for(const auto &n : modelMerged.getMergedNeuronSpikeQueueUpdateGroups()) { - CodeStream::Scope b(os); - os << "// merged neuron spike queue update group " << n.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << n.getGroups().size() << "; 
g++)"; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged neuron spike queue update group " << n.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << n.getGroups().size() << "; g++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); // Get reference to group - os << "const auto *group = &mergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << "[g]; " << std::endl; + funcEnv.getStream() << "const auto *group = &mergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << "[g]; " << std::endl; // Generate spike count reset - n.genMergedGroupSpikeCountReset(os, 1); + n.genMergedGroupSpikeCountReset(funcEnv.getStream(), 1); } } // Loop through merged neuron update groups for(const auto &n : modelMerged.getMergedNeuronUpdateGroups()) { - CodeStream::Scope b(os); - os << "// merged neuron update group " << n.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << n.getGroups().size() << "; g++)"; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged neuron update group " << n.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << n.getGroups().size() << "; g++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); // Get reference to group - os << "const auto *group = &mergedNeuronUpdateGroup" << n.getIndex() << "[g]; " << std::endl; + funcEnv.getStream() << "const auto *group = &mergedNeuronUpdateGroup" << n.getIndex() << "[g]; " << std::endl; // If spike or spike-like event recording is in use if(n.getArchetype().isSpikeRecordingEnabled() || n.getArchetype().isSpikeEventRecordingEnabled()) { // Calculate number of words which will be used to record this population's spikes - os << "const unsigned int numRecordingWords = (group->numNeurons + 31) / 32;" << std::endl; + funcEnv.getStream() << "const unsigned int numRecordingWords = (group->numNeurons + 31) / 32;" << std::endl; // Zero spike recording buffer if(n.getArchetype().isSpikeRecordingEnabled()) { - os << "std::fill_n(&group->recordSpk[recordingTimestep * numRecordingWords], numRecordingWords, 0);" << std::endl; + funcEnv.getStream() << "std::fill_n(&group->recordSpk[recordingTimestep * numRecordingWords], numRecordingWords, 0);" << std::endl; } // Zero spike-like-event recording buffer if(n.getArchetype().isSpikeEventRecordingEnabled()) { - os << "std::fill_n(&group->recordSpkEvent[recordingTimestep * numRecordingWords], numRecordingWords, 0);" << std::endl; + funcEnv.getStream() << "std::fill_n(&group->recordSpkEvent[recordingTimestep * numRecordingWords], numRecordingWords, 0);" << std::endl; } } - genNeuronIndexCalculation(os, n, 1); - os << std::endl; + genNeuronIndexCalculation(funcEnv.getStream(), n, 1); + funcEnv.getStream() << std::endl; - os << "for(unsigned int i = 0; i < group->numNeurons; i++)"; + funcEnv.getStream() << "for(unsigned int i = 0; i < group->numNeurons; i++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); - Substitutions popSubs(&funcSubs); - popSubs.addVarSubstitution("id", "i"); + EnvironmentSubstitute popEnv(funcEnv); + popEnv.addSubstitution("id", "i"); // If this neuron group requires a simulation RNG, substitute in global RNG if(n.getArchetype().isSimRNGRequired()) { - popSubs.addVarSubstitution("rng", "hostRNG"); + popEnv.addSubstitution("rng", "hostRNG"); } - n.generateNeuronUpdate(*this, os, modelMerged, popSubs, + n.generateNeuronUpdate(*this, popEnv, modelMerged, // Emit true spikes - [&modelMerged, this](CodeStream 
&os, const NeuronUpdateGroupMerged &ng, Substitutions &subs) + [&modelMerged, this](EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng) { // Insert code to update WU vars - ng.generateWUVarUpdate(*this, os, modelMerged, subs); + ng.generateWUVarUpdate(*this, env, modelMerged); // Insert code to emit true spikes - genEmitSpike(os, ng, subs, true, ng.getArchetype().isSpikeRecordingEnabled()); + genEmitSpike(env, ng, true, ng.getArchetype().isSpikeRecordingEnabled()); }, // Emit spike-like events - [this](CodeStream &os, const NeuronUpdateGroupMerged &ng, Substitutions &subs) + [this](EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng) { // Insert code to emit spike-like events - genEmitSpike(os, ng, subs, false, ng.getArchetype().isSpikeEventRecordingEnabled()); + genEmitSpike(env, ng, false, ng.getArchetype().isSpikeEventRecordingEnabled()); }); } } @@ -1890,36 +1891,36 @@ void Backend::genPresynapticUpdate(CodeStream &os, const ModelSpecMerged &modelM } } //-------------------------------------------------------------------------- -void Backend::genEmitSpike(CodeStream &os, const NeuronUpdateGroupMerged &ng, const Substitutions &subs, bool trueSpike, bool recordingEnabled) const +void Backend::genEmitSpike(EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, bool trueSpike, bool recordingEnabled) const { // Determine if delay is required and thus, at what offset we should write into the spike queue const bool spikeDelayRequired = trueSpike ? (ng.getArchetype().isDelayRequired() && ng.getArchetype().isTrueSpikeRequired()) : ng.getArchetype().isDelayRequired(); const std::string spikeQueueOffset = spikeDelayRequired ? "writeDelayOffset + " : ""; const std::string suffix = trueSpike ? "" : "Evnt"; - os << "group->spk" << suffix << "[" << spikeQueueOffset << "group->spkCnt" << suffix; + env.getStream() << "group->spk" << suffix << "[" << spikeQueueOffset << "group->spkCnt" << suffix; if(spikeDelayRequired) { // WITH DELAY - os << "[*group->spkQuePtr]++]"; + env.getStream() << "[*group->spkQuePtr]++]"; } else { // NO DELAY - os << "[0]++]"; + env.getStream() << "[0]++]"; } - os << " = " << subs["id"] << ";" << std::endl; + env.getStream() << " = " << env["id"] << ";" << std::endl; // Reset spike and spike-like-event times const std::string queueOffset = ng.getArchetype().isDelayRequired() ? "writeDelayOffset + " : ""; if(trueSpike && ng.getArchetype().isSpikeTimeRequired()) { - os << "group->sT[" << queueOffset << subs["id"] << "] = " << subs["t"] << ";" << std::endl; + env.getStream() << "group->sT[" << queueOffset << env["id"] << "] = " << env["t"] << ";" << std::endl; } else if(!trueSpike && ng.getArchetype().isSpikeEventTimeRequired()) { - os << "group->seT[" << queueOffset << subs["id"] << "] = " << subs["t"] << ";" << std::endl; + env.getStream() << "group->seT[" << queueOffset << env["id"] << "] = " << env["t"] << ";" << std::endl; } // If recording is enabled if(recordingEnabled) { const std::string recordSuffix = trueSpike ? 
"" : "Event"; - os << "group->recordSpk" << recordSuffix << "[(recordingTimestep * numRecordingWords) + (" << subs["id"] << " / 32)]"; - os << " |= (1 << (" << subs["id"] << " % 32));" << std::endl; + env.getStream() << "group->recordSpk" << recordSuffix << "[(recordingTimestep * numRecordingWords) + (" << env["id"] << " / 32)]"; + env.getStream() << " |= (1 << (" << env["id"] << " % 32));" << std::endl; } } //-------------------------------------------------------------------------- From 835402d584e8e1ed32f04dcf1c55a6f3d476299d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 9 Jun 2023 09:33:20 +0100 Subject: [PATCH 203/725] moved custom update init to same structure --- include/genn/genn/code_generator/initGroupMerged.h | 9 ++++++++- src/genn/genn/code_generator/initGroupMerged.cc | 7 +------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index e38950bea7..6fcd522f69 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -508,12 +508,19 @@ class GENN_EXPORT CustomUpdateInitGroupMerged : public CustomUpdateInitGroupMerg runnerVarDecl, runnerMergedStructAlloc, name); } - void generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generateInit(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged) const; //---------------------------------------------------------------------------- // Static constants //---------------------------------------------------------------------------- static const std::string name; + +private: + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + //! Parsed statements and resolved types for initialising each variable + NeuronInitGroupMerged::VarInitAST m_VarInitASTs; }; diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 972d4f0de9..9240257650 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -1047,16 +1047,11 @@ boost::uuids::detail::sha1::digest_type CustomUpdateInitGroupMerged::getHashDige return hash.get_digest(); } // ---------------------------------------------------------------------------- -void CustomUpdateInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void CustomUpdateInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged) const { // Initialise custom update variables genInitNeuronVarCode(backend, env, *this, m_VarInitASTs, "", "size", 1, getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1); - - genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getCustomUpdateModel()->getVars(), getArchetype().getVarInitialisers(), - "", "size", getIndex(), getArchetype().isBatched() ? 
modelMerged.getModel().getBatchSize() : 1, - [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, - [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }); } // ---------------------------------------------------------------------------- From 6356fe64499e79f499e69b7633af485661db2d8f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 12 Jun 2023 10:01:19 +0100 Subject: [PATCH 204/725] renamed ``ModelSpecMerged::createMergedGroupsHash`` to ``ModelSpecMerged::createMergedGroups`` --- .../genn/code_generator/modelSpecMerged.h | 6 +-- .../genn/code_generator/modelSpecMerged.cc | 50 +++++++++---------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 846ce07bd9..13ee936733 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -291,7 +291,7 @@ class GENN_EXPORT ModelSpecMerged } template - void createMergedGroupsHash(const BackendBase &backend, + void createMergedGroups(const BackendBase &backend, const std::vector> &unmergedGroups, std::vector &mergedGroups, D getHashDigest, bool host = false) { @@ -337,7 +337,7 @@ class GENN_EXPORT ModelSpecMerged } template - void createMergedGroupsHash(const BackendBase &backend, + void createMergedGroups(const BackendBase &backend, const std::map &groups, std::vector &mergedGroups, F filter, U updateHash, bool host = false) { @@ -350,7 +350,7 @@ class GENN_EXPORT ModelSpecMerged } // Merge filtered vector - createMergedGroupsHash(backend, unmergedGroups, mergedGroups, updateHash, host); + createMergedGroups(backend, unmergedGroups, mergedGroups, updateHash, host); } //-------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index ee7688a855..76f3cc9389 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -35,32 +35,32 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode"), m_TypeContext{{"scalar", model.getPrecision()}, {"timepoint", model.getTimePrecision()}} { LOGD_CODE_GEN << "Merging neuron update groups:"; - createMergedGroupsHash(backend, model.getNeuronGroups(), m_MergedNeuronUpdateGroups, + createMergedGroups(backend, model.getNeuronGroups(), m_MergedNeuronUpdateGroups, [](const NeuronGroupInternal &){ return true; }, &NeuronGroupInternal::getHashDigest); LOGD_CODE_GEN << "Merging presynaptic update groups:"; - createMergedGroupsHash(backend, model.getSynapseGroups(), m_MergedPresynapticUpdateGroups, + createMergedGroups(backend, model.getSynapseGroups(), m_MergedPresynapticUpdateGroups, [](const SynapseGroupInternal &sg) { return (sg.isSpikeEventRequired() || sg.isTrueSpikeRequired()); }, &SynapseGroupInternal::getWUHashDigest); LOGD_CODE_GEN << "Merging postsynaptic update groups:"; - createMergedGroupsHash(backend, model.getSynapseGroups(), m_MergedPostsynapticUpdateGroups, + createMergedGroups(backend, model.getSynapseGroups(), m_MergedPostsynapticUpdateGroups, [](const SynapseGroupInternal &sg){ return !sg.getWUModel()->getLearnPostCode().empty(); }, &SynapseGroupInternal::getWUHashDigest); LOGD_CODE_GEN << "Merging synapse dynamics update groups:"; - 
createMergedGroupsHash(backend, model.getSynapseGroups(), m_MergedSynapseDynamicsGroups, + createMergedGroups(backend, model.getSynapseGroups(), m_MergedSynapseDynamicsGroups, [](const SynapseGroupInternal &sg){ return !sg.getWUModel()->getSynapseDynamicsCode().empty(); }, &SynapseGroupInternal::getWUHashDigest); LOGD_CODE_GEN << "Merging neuron initialization groups:"; - createMergedGroupsHash(backend, model.getNeuronGroups(), m_MergedNeuronInitGroups, + createMergedGroups(backend, model.getNeuronGroups(), m_MergedNeuronInitGroups, [](const NeuronGroupInternal &){ return true; }, &NeuronGroupInternal::getInitHashDigest); LOGD_CODE_GEN << "Merging synapse initialization groups:"; - createMergedGroupsHash(backend, model.getSynapseGroups(), m_MergedSynapseInitGroups, + createMergedGroups(backend, model.getSynapseGroups(), m_MergedSynapseInitGroups, [](const SynapseGroupInternal &sg) { return (((sg.getMatrixType() & SynapseMatrixConnectivity::DENSE) @@ -70,12 +70,12 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa &SynapseGroupInternal::getWUInitHashDigest); LOGD_CODE_GEN << "Merging custom update initialization groups:"; - createMergedGroupsHash(backend, model.getCustomUpdates(), m_MergedCustomUpdateInitGroups, + createMergedGroups(backend, model.getCustomUpdates(), m_MergedCustomUpdateInitGroups, [](const CustomUpdateInternal &cg) { return cg.isVarInitRequired(); }, &CustomUpdateInternal::getInitHashDigest); LOGD_CODE_GEN << "Merging custom weight update initialization groups:"; - createMergedGroupsHash(backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateInitGroups, + createMergedGroups(backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateInitGroups, [](const CustomUpdateWUInternal &cg) { return (((cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::DENSE) @@ -85,12 +85,12 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa &CustomUpdateWUInternal::getInitHashDigest); LOGD_CODE_GEN << "Merging synapse connectivity initialisation groups:"; - createMergedGroupsHash(backend, model.getSynapseGroups(), m_MergedSynapseConnectivityInitGroups, + createMergedGroups(backend, model.getSynapseGroups(), m_MergedSynapseConnectivityInitGroups, [](const SynapseGroupInternal &sg){ return sg.isSparseConnectivityInitRequired(); }, &SynapseGroupInternal::getConnectivityInitHashDigest); LOGD_CODE_GEN << "Merging synapse sparse initialization groups:"; - createMergedGroupsHash(backend, model.getSynapseGroups(), m_MergedSynapseSparseInitGroups, + createMergedGroups(backend, model.getSynapseGroups(), m_MergedSynapseSparseInitGroups, [&backend](const SynapseGroupInternal &sg) { return ((sg.getMatrixType() & SynapseMatrixConnectivity::SPARSE) && @@ -100,7 +100,7 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa &SynapseGroupInternal::getWUInitHashDigest); LOGD_CODE_GEN << "Merging custom sparse weight update initialization groups:"; - createMergedGroupsHash(backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateSparseInitGroups, + createMergedGroups(backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateSparseInitGroups, [](const CustomUpdateWUInternal &cg) { return (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) && cg.isVarInitRequired(); @@ -108,7 +108,7 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa &CustomUpdateWUInternal::getInitHashDigest); LOGD_CODE_GEN << "Merging custom connectivity update presynaptic 
initialisation groups:"; - createMergedGroupsHash(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePreInitGroups, + createMergedGroups(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePreInitGroups, [&backend](const CustomConnectivityUpdateInternal &cg) { return (cg.isPreVarInitRequired() || (backend.isPopulationRNGInitialisedOnDevice() && cg.isRowSimRNGRequired())); @@ -116,22 +116,22 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa &CustomConnectivityUpdateInternal::getInitHashDigest); LOGD_CODE_GEN << "Merging custom connectivity update postsynaptic initialisation groups:"; - createMergedGroupsHash(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePostInitGroups, + createMergedGroups(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePostInitGroups, [](const CustomConnectivityUpdateInternal &cg) { return cg.isPostVarInitRequired(); }, &CustomConnectivityUpdateInternal::getInitHashDigest); LOGD_CODE_GEN << "Merging custom connectivity update sparse initialisation groups:"; - createMergedGroupsHash(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateSparseInitGroups, + createMergedGroups(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateSparseInitGroups, [](const CustomConnectivityUpdateInternal &cg) { return cg.isVarInitRequired(); }, &CustomConnectivityUpdateInternal::getInitHashDigest); LOGD_CODE_GEN << "Merging neuron groups which require their spike queues updating:"; - createMergedGroupsHash(backend, model.getNeuronGroups(), m_MergedNeuronSpikeQueueUpdateGroups, + createMergedGroups(backend, model.getNeuronGroups(), m_MergedNeuronSpikeQueueUpdateGroups, [](const NeuronGroupInternal &){ return true; }, &NeuronGroupInternal::getSpikeQueueUpdateHashDigest); LOGD_CODE_GEN << "Merging neuron groups which require their previous spike times updating:"; - createMergedGroupsHash(backend, model.getNeuronGroups(), m_MergedNeuronPrevSpikeTimeUpdateGroups, + createMergedGroups(backend, model.getNeuronGroups(), m_MergedNeuronPrevSpikeTimeUpdateGroups, [](const NeuronGroupInternal &ng){ return (ng.isPrevSpikeTimeRequired() || ng.isPrevSpikeEventTimeRequired()); }, &NeuronGroupInternal::getPrevSpikeTimeUpdateHashDigest); @@ -145,11 +145,11 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa } } LOGD_CODE_GEN << "Merging synapse groups which require their dendritic delay updating:"; - createMergedGroupsHash(backend, synapseGroupsWithDendriticDelay, m_MergedSynapseDendriticDelayUpdateGroups, + createMergedGroups(backend, synapseGroupsWithDendriticDelay, m_MergedSynapseDendriticDelayUpdateGroups, &SynapseGroupInternal::getDendriticDelayUpdateHashDigest); LOGD_CODE_GEN << "Merging synapse groups which require host code to initialise their synaptic connectivity:"; - createMergedGroupsHash(backend, model.getSynapseGroups(), m_MergedSynapseConnectivityHostInitGroups, + createMergedGroups(backend, model.getSynapseGroups(), m_MergedSynapseConnectivityHostInitGroups, [](const SynapseGroupInternal &sg) { return !sg.getConnectivityInitialiser().getSnippet()->getHostInitCode().empty(); @@ -157,39 +157,39 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa &SynapseGroupInternal::getConnectivityHostInitHashDigest, true); LOGD_CODE_GEN << "Merging custom update groups:"; - createMergedGroupsHash(backend, model.getCustomUpdates(), 
m_MergedCustomUpdateGroups, + createMergedGroups(backend, model.getCustomUpdates(), m_MergedCustomUpdateGroups, [](const CustomUpdateInternal &) { return true; }, &CustomUpdateInternal::getHashDigest); LOGD_CODE_GEN << "Merging custom weight update groups:"; - createMergedGroupsHash(backend, model.getCustomWUUpdates(), m_MergedCustomUpdateWUGroups, + createMergedGroups(backend, model.getCustomWUUpdates(), m_MergedCustomUpdateWUGroups, [](const CustomUpdateWUInternal &cg) { return !cg.isTransposeOperation(); }, &CustomUpdateWUInternal::getHashDigest); LOGD_CODE_GEN << "Merging custom weight transpose update groups:"; - createMergedGroupsHash(backend, model.getCustomWUUpdates(), m_MergedCustomUpdateTransposeWUGroups, + createMergedGroups(backend, model.getCustomWUUpdates(), m_MergedCustomUpdateTransposeWUGroups, [](const CustomUpdateWUInternal &cg) { return cg.isTransposeOperation(); }, &CustomUpdateWUInternal::getHashDigest); if(backend.isHostReductionRequired()) { LOGD_CODE_GEN << "Merging custom weight update groups:"; - createMergedGroupsHash(backend, model.getCustomUpdates(), m_MergedCustomUpdateHostReductionGroups, + createMergedGroups(backend, model.getCustomUpdates(), m_MergedCustomUpdateHostReductionGroups, [](const CustomUpdateInternal &cg) { return cg.isBatchReduction(); }, &CustomUpdateInternal::getHashDigest, true); LOGD_CODE_GEN << "Merging custom weight transpose update groups:"; - createMergedGroupsHash(backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateHostReductionGroups, + createMergedGroups(backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateHostReductionGroups, [](const CustomUpdateWUInternal &cg) { return cg.isBatchReduction(); }, &CustomUpdateWUInternal::getHashDigest, true); } LOGD_CODE_GEN << "Merging custom connectivity update groups:"; - createMergedGroupsHash(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateGroups, + createMergedGroups(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateGroups, [](const CustomConnectivityUpdateInternal &cg) { return !cg.getCustomConnectivityUpdateModel()->getRowUpdateCode().empty(); }, &CustomConnectivityUpdateInternal::getHashDigest); LOGD_CODE_GEN << "Merging custom connectivity host update groups:"; - createMergedGroupsHash(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityHostUpdateGroups, + createMergedGroups(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityHostUpdateGroups, [](const CustomConnectivityUpdateInternal &cg) { return !cg.getCustomConnectivityUpdateModel()->getHostUpdateCode().empty(); }, &CustomConnectivityUpdateInternal::getHashDigest, true); From 232bd8e79f18357914b413f714bf0f0b55bd7ae0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 12 Jun 2023 10:55:22 +0100 Subject: [PATCH 205/725] made references to ModelSpecMerged non-const --- .../genn/backends/single_threaded_cpu/backend.h | 8 ++++---- include/genn/genn/code_generator/backendBase.h | 8 ++++---- .../backends/single_threaded_cpu/backend.cc | 6 +++--- src/genn/genn/code_generator/generateModules.cc | 17 +++++++++-------- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 3d4ec44d70..658e26abc4 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -42,13 +42,13 @@ class BACKEND_EXPORT Backend : public BackendBase 
//-------------------------------------------------------------------------- // CodeGenerator::BackendBase virtuals //-------------------------------------------------------------------------- - virtual void genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; + virtual void genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; - virtual void genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; + virtual void genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; - virtual void genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; + virtual void genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; - virtual void genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; + virtual void genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; virtual size_t getSynapticMatrixRowStride(const SynapseGroupInternal &sg) const final; diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index aa0d8b2612..b71a112a4b 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -208,25 +208,25 @@ class GENN_EXPORT BackendBase /*! \param os CodeStream to write function to \param modelMerged merged model to generate code for \param preambleHandler callback to write functions for pushing extra-global parameters*/ - virtual void genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; + virtual void genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; //! Generate platform-specific function to update the state of all synapses /*! \param os CodeStream to write function to \param modelMerged merged model to generate code for \param preambleHandler callback to write functions for pushing extra-global parameters*/ - virtual void genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; + virtual void genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; //! Generate platform-specific functions to perform custom updates /*! \param os CodeStream to write function to \param modelMerged merged model to generate code for \param preambleHandler callback to write functions for pushing extra-global parameters*/ - virtual void genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; + virtual void genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; //! Generate platform-specific function to initialise model /*! \param os CodeStream to write function to \param modelMerged merged model to generate code for \param preambleHandler callback to write functions for pushing extra-global parameters*/ - virtual void genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; + virtual void genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; //! 
Gets the stride used to access synaptic matrix rows, taking into account sparse data structure, padding etc virtual size_t getSynapticMatrixRowStride(const SynapseGroupInternal &sg) const = 0; diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index cc18d1941f..72e123fb10 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -294,7 +294,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged } } //-------------------------------------------------------------------------- -void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); if(model.getBatchSize() != 1) { @@ -493,7 +493,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge } } //-------------------------------------------------------------------------- -void Backend::genCustomUpdate(CodeStream &os_, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genCustomUpdate(CodeStream &os_, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); @@ -779,7 +779,7 @@ void Backend::genCustomUpdate(CodeStream &os_, const ModelSpecMerged &modelMerge } } //-------------------------------------------------------------------------- -void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); if(model.getBatchSize() != 1) { diff --git a/src/genn/genn/code_generator/generateModules.cc b/src/genn/genn/code_generator/generateModules.cc index 4923b6ef35..88d0bb370c 100644 --- a/src/genn/genn/code_generator/generateModules.cc +++ b/src/genn/genn/code_generator/generateModules.cc @@ -104,16 +104,17 @@ std::pair, MemAlloc> generateAll(const ModelSpecInterna ModelSpecMerged modelMerged(model, backend); // If force rebuild flag is set or model should be rebuilt - const auto hashDigest = modelMerged.getHashDigest(backend); + //const auto hashDigest = modelMerged.getHashDigest(backend); MemAlloc mem = MemAlloc::zero(); - if(forceRebuild || shouldRebuildModel(outputPath, hashDigest, mem)) { + if(true/*forceRebuild || shouldRebuildModel(outputPath, hashDigest, mem)*/) { // Generate modules - mem = generateRunner(outputPath, modelMerged, backend); + // **NOTE** these are ordered in terms of memory-space priority generateSynapseUpdate(outputPath, modelMerged, backend); generateNeuronUpdate(outputPath, modelMerged, backend); generateCustomUpdate(outputPath, modelMerged, backend); generateInit(outputPath, modelMerged, backend); - + mem = generateRunner(outputPath, modelMerged, backend); + // Generate support code module if the backend supports namespaces if(backend.supportsNamespace()) { generateSupportCode(outputPath, modelMerged); @@ -182,7 +183,7 @@ std::pair, MemAlloc> generateAll(const ModelSpecInterna return std::make_pair(modules, mem); } //-------------------------------------------------------------------------- -void generateNeuronUpdate(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, +void 
generateNeuronUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, const std::string &suffix) { // Create output stream to write to file and wrap in CodeStream @@ -206,7 +207,7 @@ void generateNeuronUpdate(const filesystem::path &outputPath, const ModelSpecMer }); } //-------------------------------------------------------------------------- -void generateCustomUpdate(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, +void generateCustomUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, const std::string &suffix) { // Create output stream to write to file and wrap in CodeStream @@ -229,7 +230,7 @@ void generateCustomUpdate(const filesystem::path &outputPath, const ModelSpecMer }); } //-------------------------------------------------------------------------- -void generateSynapseUpdate(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, +void generateSynapseUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, const std::string &suffix) { // Create output stream to write to file and wrap in CodeStream @@ -254,7 +255,7 @@ void generateSynapseUpdate(const filesystem::path &outputPath, const ModelSpecMe }); } //-------------------------------------------------------------------------- -void generateInit(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, +void generateInit(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, const std::string &suffix) { // Create output stream to write to file and wrap in CodeStream From a2a72494da156d2b0bd377da9376987262b3b6ed Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 12 Jun 2023 10:55:32 +0100 Subject: [PATCH 206/725] exposed group creation with callback --- .../genn/code_generator/modelSpecMerged.h | 251 +++++++++++++++++- .../genn/code_generator/modelSpecMerged.cc | 161 +---------- 2 files changed, 244 insertions(+), 168 deletions(-) diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 13ee936733..380dc61143 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -77,6 +77,9 @@ class GENN_EXPORT ModelSpecMerged typedef std::unordered_multimap MergedEGPDestinations; typedef std::map MergedEGPMap; + template + using GenerateMergedGroupFn = std::function; + //-------------------------------------------------------------------------- // Public API //-------------------------------------------------------------------------- @@ -161,6 +164,237 @@ class GENN_EXPORT ModelSpecMerged //! 
Get merged custom connectivity update groups where host processing needs to be performed const std::vector &getMergedCustomConnectivityHostUpdateGroups() const { return m_MergedCustomConnectivityHostUpdateGroups; } + template + void genMergedNeuronUpdateGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronUpdateGroups, + [](const NeuronGroupInternal &){ return true; }, + &NeuronGroupInternal::getHashDigest, generateGroup); + } + + template + void genMergedPresynapticUpdateGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedPresynapticUpdateGroups, + [](const SynapseGroupInternal &sg) { return (sg.isSpikeEventRequired() || sg.isTrueSpikeRequired()); }, + &SynapseGroupInternal::getWUHashDigest, generateGroup); + } + + template + void genMergedPostsynapticUpdateGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedPostsynapticUpdateGroups, + [](const SynapseGroupInternal &sg){ return !sg.getWUModel()->getLearnPostCode().empty(); }, + &SynapseGroupInternal::getWUHashDigest, generateGroup); + } + + template + void genMergedSynapseDynamicsGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseDynamicsGroups, + [](const SynapseGroupInternal &sg){ return !sg.getWUModel()->getSynapseDynamicsCode().empty(); }, + &SynapseGroupInternal::getWUHashDigest, generateGroup); + } + + template + void genMergedCustomUpdateGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateGroups, + [](const CustomUpdateInternal &) { return true; }, + &CustomUpdateInternal::getHashDigest, generateGroup); + } + + template + void genMergedCustomUpdateWUGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomUpdateWUGroups, + [](const CustomUpdateWUInternal &cg) { return !cg.isTransposeOperation(); }, + &CustomUpdateWUInternal::getHashDigest, generateGroup); + } + + template + void genMergedCustomUpdateTransposeWUGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomUpdateTransposeWUGroups, + [](const CustomUpdateWUInternal &cg) { return cg.isTransposeOperation(); }, + &CustomUpdateWUInternal::getHashDigest, generateGroup); + } + + template + void genMergedCustomUpdateHostReductionGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateHostReductionGroups, + [](const CustomUpdateInternal &cg) { return cg.isBatchReduction(); }, + &CustomUpdateInternal::getHashDigest, generateGroup, true); + } + + template + void genMergedCustomWUUpdateHostReductionGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateHostReductionGroups, + [](const CustomUpdateWUInternal &cg) { return cg.isBatchReduction(); }, + &CustomUpdateWUInternal::getHashDigest, generateGroup, true); + } + + template + void genMergedCustomConnectivityUpdateGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateGroups, + [](const CustomConnectivityUpdateInternal &cg) { return 
!cg.getCustomConnectivityUpdateModel()->getRowUpdateCode().empty(); }, + &CustomConnectivityUpdateInternal::getHashDigest, generateGroup); + } + + template<typename G> + void genMergedCustomConnectivityHostUpdateGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityHostUpdateGroups, + [](const CustomConnectivityUpdateInternal &cg) { return !cg.getCustomConnectivityUpdateModel()->getHostUpdateCode().empty(); }, + &CustomConnectivityUpdateInternal::getHashDigest, generateGroup, true); + } + + template<typename G> + void genMergedNeuronSpikeQueueUpdateGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronSpikeQueueUpdateGroups, + [](const NeuronGroupInternal &){ return true; }, + &NeuronGroupInternal::getSpikeQueueUpdateHashDigest, generateGroup); + } + + template<typename G> + void genMergedNeuronPrevSpikeTimeUpdateGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronPrevSpikeTimeUpdateGroups, + [](const NeuronGroupInternal &ng){ return (ng.isPrevSpikeTimeRequired() || ng.isPrevSpikeEventTimeRequired()); }, + &NeuronGroupInternal::getPrevSpikeTimeUpdateHashDigest, generateGroup); + } + + template<typename G> + void genMergedSynapseDendriticDelayUpdateGroups(const BackendBase &backend, G generateGroup) + { + std::vector<std::reference_wrapper<const SynapseGroupInternal>> synapseGroupsWithDendriticDelay; + for(const auto &n : getModel().getNeuronGroups()) { + for(const auto *sg : n.second.getFusedPSMInSyn()) { + if(sg->isDendriticDelayRequired()) { + synapseGroupsWithDendriticDelay.push_back(std::cref(*sg)); + } + } + } + createMergedGroups(backend, synapseGroupsWithDendriticDelay, m_MergedSynapseDendriticDelayUpdateGroups, + &SynapseGroupInternal::getDendriticDelayUpdateHashDigest, generateGroup); + } + + template<typename G> + void genMergedNeuronInitGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronInitGroups, + [](const NeuronGroupInternal &){ return true; }, + &NeuronGroupInternal::getInitHashDigest, generateGroup); + } + + template<typename G> + void genMergedCustomUpdateInitGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateInitGroups, + [](const CustomUpdateInternal &cg) { return cg.isVarInitRequired(); }, + &CustomUpdateInternal::getInitHashDigest, generateGroup); + } + + template<typename G> + void genMergedCustomWUUpdateInitGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateInitGroups, + [](const CustomUpdateWUInternal &cg) + { + return (((cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::DENSE) + || (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL)) + && cg.isVarInitRequired()); + }, + &CustomUpdateWUInternal::getInitHashDigest, generateGroup); + } + + template<typename G> + void genMergedSynapseInitGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseInitGroups, + [](const SynapseGroupInternal &sg) + { + return (((sg.getMatrixType() & SynapseMatrixConnectivity::DENSE) + || (sg.getMatrixType() & SynapseMatrixWeight::KERNEL)) + && sg.isWUVarInitRequired()); + }, + &SynapseGroupInternal::getWUInitHashDigest, generateGroup); + } + + template<typename G> + void genMergedSynapseConnectivityInitGroups(const BackendBase &backend, G
generateGroup) + { + createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseConnectivityInitGroups, + [](const SynapseGroupInternal &sg){ return sg.isSparseConnectivityInitRequired(); }, + &SynapseGroupInternal::getConnectivityInitHashDigest, generateGroup); + } + + template + void genMergedSynapseSparseInitGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseSparseInitGroups, + [&backend](const SynapseGroupInternal &sg) + { + return ((sg.getMatrixType() & SynapseMatrixConnectivity::SPARSE) && + (sg.isWUVarInitRequired() + || (backend.isPostsynapticRemapRequired() && !sg.getWUModel()->getLearnPostCode().empty()))); + }, + &SynapseGroupInternal::getWUInitHashDigest, generateGroup); + } + + template + void genMergedCustomWUUpdateSparseInitGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateSparseInitGroups, + [](const CustomUpdateWUInternal &cg) + { + return (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) && cg.isVarInitRequired(); + }, + &CustomUpdateWUInternal::getInitHashDigest, generateGroup); + } + + template + void genMergedCustomConnectivityUpdatePreInitGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePreInitGroups, + [&backend](const CustomConnectivityUpdateInternal &cg) + { + return (cg.isPreVarInitRequired() || (backend.isPopulationRNGInitialisedOnDevice() && cg.isRowSimRNGRequired())); + }, + &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); + } + + template + void genMergedCustomConnectivityUpdatePostInitGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePostInitGroups, + [](const CustomConnectivityUpdateInternal &cg) { return cg.isPostVarInitRequired(); }, + &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); + } + + template + void genMergedCustomConnectivityUpdateSparseInitGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateSparseInitGroups, + [](const CustomConnectivityUpdateInternal &cg) { return cg.isVarInitRequired(); }, + &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); + } + + template + void genMergedSynapseConnectivityHostInitGroups(const BackendBase &backend, G generateGroup) + { + createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseConnectivityHostInitGroups, + [](const SynapseGroupInternal &sg) + { + return !sg.getConnectivityInitialiser().getSnippet()->getHostInitCode().empty(); + }, + &SynapseGroupInternal::getConnectivityHostInitHashDigest, generateGroup, true); + } + void genMergedNeuronUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedNeuronUpdateGroups); } void genMergedPresynapticUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedPresynapticUpdateGroups); } void genMergedPostsynapticUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedPostsynapticUpdateGroups); } @@ -290,10 +524,10 @@ class GENN_EXPORT ModelSpecMerged } } - template + template void 
createMergedGroups(const BackendBase &backend, - const std::vector> &unmergedGroups, - std::vector &mergedGroups, D getHashDigest, bool host = false) + const std::vector> &unmergedGroups, + std::vector &mergedGroups, D getHashDigest, G generateGroup, bool host = false) { // Create a hash map to group together groups with the same SHA1 digest std::unordered_map + template void createMergedGroups(const BackendBase &backend, - const std::map &groups, std::vector &mergedGroups, - F filter, U updateHash, bool host = false) + const std::map &groups, std::vector &mergedGroups, + F filter, U updateHash, G generateGroup, bool host = false) { // Build temporary vector of references to groups that pass filter std::vector> unmergedGroups; @@ -350,7 +585,7 @@ class GENN_EXPORT ModelSpecMerged } // Merge filtered vector - createMergedGroups(backend, unmergedGroups, mergedGroups, updateHash, host); + createMergedGroups(backend, unmergedGroups, mergedGroups, updateHash, generateGroup, host); } //-------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index 76f3cc9389..6d1e34bf0d 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -34,165 +34,7 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa m_PresynapticUpdateSupportCode("PresynapticUpdateSupportCode"), m_PostsynapticUpdateSupportCode("PostsynapticUpdateSupportCode"), m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode"), m_TypeContext{{"scalar", model.getPrecision()}, {"timepoint", model.getTimePrecision()}} { - LOGD_CODE_GEN << "Merging neuron update groups:"; - createMergedGroups(backend, model.getNeuronGroups(), m_MergedNeuronUpdateGroups, - [](const NeuronGroupInternal &){ return true; }, - &NeuronGroupInternal::getHashDigest); - - LOGD_CODE_GEN << "Merging presynaptic update groups:"; - createMergedGroups(backend, model.getSynapseGroups(), m_MergedPresynapticUpdateGroups, - [](const SynapseGroupInternal &sg) { return (sg.isSpikeEventRequired() || sg.isTrueSpikeRequired()); }, - &SynapseGroupInternal::getWUHashDigest); - - LOGD_CODE_GEN << "Merging postsynaptic update groups:"; - createMergedGroups(backend, model.getSynapseGroups(), m_MergedPostsynapticUpdateGroups, - [](const SynapseGroupInternal &sg){ return !sg.getWUModel()->getLearnPostCode().empty(); }, - &SynapseGroupInternal::getWUHashDigest); - - LOGD_CODE_GEN << "Merging synapse dynamics update groups:"; - createMergedGroups(backend, model.getSynapseGroups(), m_MergedSynapseDynamicsGroups, - [](const SynapseGroupInternal &sg){ return !sg.getWUModel()->getSynapseDynamicsCode().empty(); }, - &SynapseGroupInternal::getWUHashDigest); - - LOGD_CODE_GEN << "Merging neuron initialization groups:"; - createMergedGroups(backend, model.getNeuronGroups(), m_MergedNeuronInitGroups, - [](const NeuronGroupInternal &){ return true; }, - &NeuronGroupInternal::getInitHashDigest); - - LOGD_CODE_GEN << "Merging synapse initialization groups:"; - createMergedGroups(backend, model.getSynapseGroups(), m_MergedSynapseInitGroups, - [](const SynapseGroupInternal &sg) - { - return (((sg.getMatrixType() & SynapseMatrixConnectivity::DENSE) - || (sg.getMatrixType() & SynapseMatrixWeight::KERNEL)) - && sg.isWUVarInitRequired()); - }, - &SynapseGroupInternal::getWUInitHashDigest); - - LOGD_CODE_GEN << "Merging custom update initialization groups:"; - createMergedGroups(backend, 
model.getCustomUpdates(), m_MergedCustomUpdateInitGroups, - [](const CustomUpdateInternal &cg) { return cg.isVarInitRequired(); }, - &CustomUpdateInternal::getInitHashDigest); - - LOGD_CODE_GEN << "Merging custom weight update initialization groups:"; - createMergedGroups(backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateInitGroups, - [](const CustomUpdateWUInternal &cg) - { - return (((cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::DENSE) - || (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL)) - && cg.isVarInitRequired()); - }, - &CustomUpdateWUInternal::getInitHashDigest); - - LOGD_CODE_GEN << "Merging synapse connectivity initialisation groups:"; - createMergedGroups(backend, model.getSynapseGroups(), m_MergedSynapseConnectivityInitGroups, - [](const SynapseGroupInternal &sg){ return sg.isSparseConnectivityInitRequired(); }, - &SynapseGroupInternal::getConnectivityInitHashDigest); - - LOGD_CODE_GEN << "Merging synapse sparse initialization groups:"; - createMergedGroups(backend, model.getSynapseGroups(), m_MergedSynapseSparseInitGroups, - [&backend](const SynapseGroupInternal &sg) - { - return ((sg.getMatrixType() & SynapseMatrixConnectivity::SPARSE) && - (sg.isWUVarInitRequired() - || (backend.isPostsynapticRemapRequired() && !sg.getWUModel()->getLearnPostCode().empty()))); - }, - &SynapseGroupInternal::getWUInitHashDigest); - - LOGD_CODE_GEN << "Merging custom sparse weight update initialization groups:"; - createMergedGroups(backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateSparseInitGroups, - [](const CustomUpdateWUInternal &cg) - { - return (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) && cg.isVarInitRequired(); - }, - &CustomUpdateWUInternal::getInitHashDigest); - - LOGD_CODE_GEN << "Merging custom connectivity update presynaptic initialisation groups:"; - createMergedGroups(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePreInitGroups, - [&backend](const CustomConnectivityUpdateInternal &cg) - { - return (cg.isPreVarInitRequired() || (backend.isPopulationRNGInitialisedOnDevice() && cg.isRowSimRNGRequired())); - }, - &CustomConnectivityUpdateInternal::getInitHashDigest); - - LOGD_CODE_GEN << "Merging custom connectivity update postsynaptic initialisation groups:"; - createMergedGroups(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePostInitGroups, - [](const CustomConnectivityUpdateInternal &cg) { return cg.isPostVarInitRequired(); }, - &CustomConnectivityUpdateInternal::getInitHashDigest); - - LOGD_CODE_GEN << "Merging custom connectivity update sparse initialisation groups:"; - createMergedGroups(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateSparseInitGroups, - [](const CustomConnectivityUpdateInternal &cg) { return cg.isVarInitRequired(); }, - &CustomConnectivityUpdateInternal::getInitHashDigest); - - LOGD_CODE_GEN << "Merging neuron groups which require their spike queues updating:"; - createMergedGroups(backend, model.getNeuronGroups(), m_MergedNeuronSpikeQueueUpdateGroups, - [](const NeuronGroupInternal &){ return true; }, - &NeuronGroupInternal::getSpikeQueueUpdateHashDigest); - - LOGD_CODE_GEN << "Merging neuron groups which require their previous spike times updating:"; - createMergedGroups(backend, model.getNeuronGroups(), m_MergedNeuronPrevSpikeTimeUpdateGroups, - [](const NeuronGroupInternal &ng){ return (ng.isPrevSpikeTimeRequired() || ng.isPrevSpikeEventTimeRequired()); }, - 
&NeuronGroupInternal::getPrevSpikeTimeUpdateHashDigest); - - // Build vector of merged synapse groups which require dendritic delay - std::vector> synapseGroupsWithDendriticDelay; - for(const auto &n : model.getNeuronGroups()) { - for(const auto *sg : n.second.getFusedPSMInSyn()) { - if(sg->isDendriticDelayRequired()) { - synapseGroupsWithDendriticDelay.push_back(std::cref(*sg)); - } - } - } - LOGD_CODE_GEN << "Merging synapse groups which require their dendritic delay updating:"; - createMergedGroups(backend, synapseGroupsWithDendriticDelay, m_MergedSynapseDendriticDelayUpdateGroups, - &SynapseGroupInternal::getDendriticDelayUpdateHashDigest); - - LOGD_CODE_GEN << "Merging synapse groups which require host code to initialise their synaptic connectivity:"; - createMergedGroups(backend, model.getSynapseGroups(), m_MergedSynapseConnectivityHostInitGroups, - [](const SynapseGroupInternal &sg) - { - return !sg.getConnectivityInitialiser().getSnippet()->getHostInitCode().empty(); - }, - &SynapseGroupInternal::getConnectivityHostInitHashDigest, true); - - LOGD_CODE_GEN << "Merging custom update groups:"; - createMergedGroups(backend, model.getCustomUpdates(), m_MergedCustomUpdateGroups, - [](const CustomUpdateInternal &) { return true; }, - &CustomUpdateInternal::getHashDigest); - - LOGD_CODE_GEN << "Merging custom weight update groups:"; - createMergedGroups(backend, model.getCustomWUUpdates(), m_MergedCustomUpdateWUGroups, - [](const CustomUpdateWUInternal &cg) { return !cg.isTransposeOperation(); }, - &CustomUpdateWUInternal::getHashDigest); - - LOGD_CODE_GEN << "Merging custom weight transpose update groups:"; - createMergedGroups(backend, model.getCustomWUUpdates(), m_MergedCustomUpdateTransposeWUGroups, - [](const CustomUpdateWUInternal &cg) { return cg.isTransposeOperation(); }, - &CustomUpdateWUInternal::getHashDigest); - - if(backend.isHostReductionRequired()) { - LOGD_CODE_GEN << "Merging custom weight update groups:"; - createMergedGroups(backend, model.getCustomUpdates(), m_MergedCustomUpdateHostReductionGroups, - [](const CustomUpdateInternal &cg) { return cg.isBatchReduction(); }, - &CustomUpdateInternal::getHashDigest, true); - - LOGD_CODE_GEN << "Merging custom weight transpose update groups:"; - createMergedGroups(backend, model.getCustomWUUpdates(), m_MergedCustomWUUpdateHostReductionGroups, - [](const CustomUpdateWUInternal &cg) { return cg.isBatchReduction(); }, - &CustomUpdateWUInternal::getHashDigest, true); - } - - LOGD_CODE_GEN << "Merging custom connectivity update groups:"; - createMergedGroups(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateGroups, - [](const CustomConnectivityUpdateInternal &cg) { return !cg.getCustomConnectivityUpdateModel()->getRowUpdateCode().empty(); }, - &CustomConnectivityUpdateInternal::getHashDigest); - - LOGD_CODE_GEN << "Merging custom connectivity host update groups:"; - createMergedGroups(backend, model.getCustomConnectivityUpdates(), m_MergedCustomConnectivityHostUpdateGroups, - [](const CustomConnectivityUpdateInternal &cg) { return !cg.getCustomConnectivityUpdateModel()->getHostUpdateCode().empty(); }, - &CustomConnectivityUpdateInternal::getHashDigest, true); - + // Get memory spaces available to this backend // **NOTE** Memory spaces are given out on a first-come, first-serve basis so subsequent groups are in preferential order auto memorySpaces = backend.getMergedGroupMemorySpaces(*this); @@ -253,7 +95,6 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa 
assignGroups(backend, m_MergedCustomConnectivityUpdatePreInitGroups, memorySpaces); assignGroups(backend, m_MergedCustomConnectivityUpdatePostInitGroups, memorySpaces); assignGroups(backend, m_MergedCustomConnectivityUpdateSparseInitGroups, memorySpaces); - } //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type ModelSpecMerged::getHashDigest(const BackendBase &backend) const From 713dcd8ca3f51c32116ae00e48d642c1fb2c9cb0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 12 Jun 2023 11:04:34 +0100 Subject: [PATCH 207/725] restructured single-threaded CPU version of genSynapseUpdate --- .../backends/single_threaded_cpu/backend.cc | 287 +++++++++--------- 1 file changed, 150 insertions(+), 137 deletions(-) diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 72e123fb10..ffd1665ff0 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -297,200 +297,213 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); - if(model.getBatchSize() != 1) { + if (model.getBatchSize() != 1) { throw std::runtime_error("The single-threaded CPU backend only supports simulations with a batch size of 1"); } - - // Generate struct definitions - // **YUCK** dendritic delay update structs not actually required - modelMerged.genMergedSynapseDendriticDelayUpdateStructs(os, *this); - modelMerged.genMergedPresynapticUpdateGroupStructs(os, *this); - modelMerged.genMergedPostsynapticUpdateGroupStructs(os, *this); - modelMerged.genMergedSynapseDynamicsGroupStructs(os, *this); - - // Generate arrays of merged structs and functions to set them - // **YUCK** dendritic delay update structs not actually required - genMergedStructArrayPush(os, modelMerged.getMergedSynapseDendriticDelayUpdateGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedPresynapticUpdateGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedPostsynapticUpdateGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedSynapseDynamicsGroups()); - - // Generate preamble - preambleHandler(os); - os << "void updateSynapses(timepoint t)"; + // Generate stream with synapse update code + std::ostringstream synapseUpdateStream; + CodeStream synapseUpdate(synapseUpdateStream); + synapseUpdate << "void updateSynapses(timepoint t)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(synapseUpdate); Substitutions funcSubs(getFunctionTemplates(model.getPrecision().getName())); funcSubs.addVarSubstitution("t", "t"); funcSubs.addVarSubstitution("batch", "0"); // Synapse dynamics { - // Loop through merged synapse dynamics groups - Timer t(os, "synapseDynamics", model.isTimingEnabled()); - for(const auto &s : modelMerged.getMergedSynapseDynamicsGroups()) { - CodeStream::Scope b(os); - os << "// merged synapse dynamics group " << s.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; + Timer t(synapseUpdate, "synapseDynamics", model.isTimingEnabled()); + modelMerged.genMergedSynapseDynamicsGroups( + *this, + [this, &funcSubs, &synapseUpdate](SynapseDynamicsGroupMerged &s) { - CodeStream::Scope b(os); + CodeStream::Scope b(synapseUpdate); + synapseUpdate << "// merged synapse dynamics group " << s.getIndex() << 
std::endl; + synapseUpdate << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; + { + CodeStream::Scope b(synapseUpdate); - // Get reference to group - os << "const auto *group = &mergedSynapseDynamicsGroup" << s.getIndex() << "[g]; " << std::endl; + // Get reference to group + synapseUpdate << "const auto *group = &mergedSynapseDynamicsGroup" << s.getIndex() << "[g]; " << std::endl; - genSynapseIndexCalculation(os, s, 1); + genSynapseIndexCalculation(synapseUpdate, s, 1); - // Loop through presynaptic neurons - os << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; - { - // If this synapse group has sparse connectivity, loop through length of this row - CodeStream::Scope b(os); - if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << "for(unsigned int s = 0; s < group->rowLength[i]; s++)"; - } - // Otherwise, if it's dense, loop through each postsynaptic neuron - else if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::DENSE) { - os << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)"; - } - else { - throw std::runtime_error("Only DENSE and SPARSE format connectivity can be used for synapse dynamics"); - } + // Loop through presynaptic neurons + synapseUpdate << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; { - CodeStream::Scope b(os); - - Substitutions synSubs(&funcSubs); + // If this synapse group has sparse connectivity, loop through length of this row + CodeStream::Scope b(synapseUpdate); if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - // Calculate index of synapse and use it to look up postsynaptic index - os << "const unsigned int n = (i * group->rowStride) + s;" << std::endl; - os << "const unsigned int j = group->ind[n];" << std::endl; - - synSubs.addVarSubstitution("id_syn", "n"); + synapseUpdate << "for(unsigned int s = 0; s < group->rowLength[i]; s++)"; + } + // Otherwise, if it's dense, loop through each postsynaptic neuron + else if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::DENSE) { + synapseUpdate << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)"; } else { - synSubs.addVarSubstitution("id_syn", "(i * group->numTrgNeurons) + j"); + throw std::runtime_error("Only DENSE and SPARSE format connectivity can be used for synapse dynamics"); } + { + CodeStream::Scope b(synapseUpdate); - // Add pre and postsynaptic indices to substitutions - synSubs.addVarSubstitution("id_pre", "i"); - synSubs.addVarSubstitution("id_post", "j"); + Substitutions synSubs(&funcSubs); + if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { + // Calculate index of synapse and use it to look up postsynaptic index + synapseUpdate << "const unsigned int n = (i * group->rowStride) + s;" << std::endl; + synapseUpdate << "const unsigned int j = group->ind[n];" << std::endl; - // Add correct functions for apply synaptic input - if(s.getArchetype().isDendriticDelayRequired()) { - synSubs.addFuncSubstitution("addToInSynDelay", 2, "group->denDelay[" + s.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)"); - } - else { - synSubs.addFuncSubstitution("addToInSyn", 1, "group->inSyn[" + s.getPostISynIndex(1, "j") + "] += $(0)"); - } + synSubs.addVarSubstitution("id_syn", "n"); + } + else { + synSubs.addVarSubstitution("id_syn", "(i * group->numTrgNeurons) + j"); + } + + // Add pre and postsynaptic indices to substitutions + synSubs.addVarSubstitution("id_pre", "i"); + synSubs.addVarSubstitution("id_post", "j"); - 
if(s.getArchetype().isPresynapticOutputRequired()) { - synSubs.addFuncSubstitution("addToPre", 1, "group->revInSyn[" + s.getPreISynIndex(1, synSubs["id_pre"]) + "] += $(0)"); + // Add correct functions for apply synaptic input + if(s.getArchetype().isDendriticDelayRequired()) { + synSubs.addFuncSubstitution("addToInSynDelay", 2, "group->denDelay[" + s.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)"); + } + else { + synSubs.addFuncSubstitution("addToInSyn", 1, "group->inSyn[" + s.getPostISynIndex(1, "j") + "] += $(0)"); + } + + if(s.getArchetype().isPresynapticOutputRequired()) { + synSubs.addFuncSubstitution("addToPre", 1, "group->revInSyn[" + s.getPreISynIndex(1, synSubs["id_pre"]) + "] += $(0)"); + } + // Call synapse dynamics handler + s.generateSynapseUpdate(*this, synapseUpdate, modelMerged, synSubs); } - // Call synapse dynamics handler - s.generateSynapseUpdate(*this, os, modelMerged, synSubs); } } - } - } + }); } // Presynaptic update { - Timer t(os, "presynapticUpdate", model.isTimingEnabled()); - for(const auto &s : modelMerged.getMergedPresynapticUpdateGroups()) { - CodeStream::Scope b(os); - os << "// merged presynaptic update group " << s.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; + Timer t(synapseUpdate, "presynapticUpdate", model.isTimingEnabled()); + modelMerged.genMergedPresynapticUpdateGroups( + *this, + [this, &funcSubs, &synapseUpdate](PresynapticUpdateGroupMerged &s) { - CodeStream::Scope b(os); + CodeStream::Scope b(synapseUpdate); + synapseUpdate << "// merged presynaptic update group " << s.getIndex() << std::endl; + synapseUpdate << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; + { + CodeStream::Scope b(synapseUpdate); - // Get reference to group - os << "const auto *group = &mergedPresynapticUpdateGroup" << s.getIndex() << "[g]; " << std::endl; + // Get reference to group + synapseUpdate << "const auto *group = &mergedPresynapticUpdateGroup" << s.getIndex() << "[g]; " << std::endl; - genSynapseIndexCalculation(os, s, 1); + genSynapseIndexCalculation(synapseUpdate, s, 1); - // generate the code for processing spike-like events - if (s.getArchetype().isSpikeEventRequired()) { - genPresynapticUpdate(os, modelMerged, s, funcSubs, false); - } + // generate the code for processing spike-like events + if (s.getArchetype().isSpikeEventRequired()) { + genPresynapticUpdate(synapseUpdate, modelMerged, s, funcSubs, false); + } - // generate the code for processing true spike events - if (s.getArchetype().isTrueSpikeRequired()) { - genPresynapticUpdate(os, modelMerged, s, funcSubs, true); + // generate the code for processing true spike events + if (s.getArchetype().isTrueSpikeRequired()) { + genPresynapticUpdate(synapseUpdate, modelMerged, s, funcSubs, true); + } + synapseUpdate << std::endl; } - os << std::endl; - } - } + }); } // Postsynaptic update { - Timer t(os, "postsynapticUpdate", model.isTimingEnabled()); - for(const auto &s : modelMerged.getMergedPostsynapticUpdateGroups()) { - CodeStream::Scope b(os); - os << "// merged postsynaptic update group " << s.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; + Timer t(synapseUpdate, "postsynapticUpdate", model.isTimingEnabled()); + modelMerged.genMergedPostsynapticUpdateGroups( + *this, + [this, &funcSubs, &synapseUpdate](PostsynapticUpdateGroupMerged &s) { - CodeStream::Scope b(os); - - // Get reference to group - os << "const auto *group = &mergedPostsynapticUpdateGroup" << s.getIndex() <<
"[g]; " << std::endl; - - genSynapseIndexCalculation(os, s, 1); - - // Get number of postsynaptic spikes - if (s.getArchetype().getTrgNeuronGroup()->isDelayRequired() && s.getArchetype().getTrgNeuronGroup()->isTrueSpikeRequired()) { - os << "const unsigned int numSpikes = group->trgSpkCnt[postDelaySlot];" << std::endl; - } - else { - os << "const unsigned int numSpikes = group->trgSpkCnt[0];" << std::endl; - } - - // Loop through postsynaptic spikes - os << "for (unsigned int j = 0; j < numSpikes; j++)"; + CodeStream::Scope b(synapseUpdate); + synapseUpdate << "// merged postsynaptic update group " << s.getIndex() << std::endl; + synapseUpdate << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(synapseUpdate); - const std::string offsetTrueSpkPost = (s.getArchetype().getTrgNeuronGroup()->isTrueSpikeRequired() && s.getArchetype().getTrgNeuronGroup()->isDelayRequired()) ? "postDelayOffset + " : ""; - os << "const unsigned int spike = group->trgSpk[" << offsetTrueSpkPost << "j];" << std::endl; + // Get reference to group + synapseUpdate << "const auto *group = &mergedPostsynapticUpdateGroup" << s.getIndex() << "[g]; " << std::endl; + + genSynapseIndexCalculation(synapseUpdate, s, 1); - // Loop through column of presynaptic neurons - if (s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << "const unsigned int npre = group->colLength[spike];" << std::endl; - os << "for (unsigned int i = 0; i < npre; i++)"; + // Get number of postsynaptic spikes + if (s.getArchetype().getTrgNeuronGroup()->isDelayRequired() && s.getArchetype().getTrgNeuronGroup()->isTrueSpikeRequired()) { + synapseUpdate << "const unsigned int numSpikes = group->trgSpkCnt[postDelaySlot];" << std::endl; } else { - os << "for (unsigned int i = 0; i < group->numSrcNeurons; i++)"; + synapseUpdate << "const unsigned int numSpikes = group->trgSpkCnt[0];" << std::endl; } + + // Loop through postsynaptic spikes + synapseUpdate << "for (unsigned int j = 0; j < numSpikes; j++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(synapseUpdate); - Substitutions synSubs(&funcSubs); - if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << "const unsigned int colMajorIndex = (spike * group->colStride) + i;" << std::endl; - os << "const unsigned int rowMajorIndex = group->remap[colMajorIndex];" << std::endl; + const std::string offsetTrueSpkPost = (s.getArchetype().getTrgNeuronGroup()->isTrueSpikeRequired() && s.getArchetype().getTrgNeuronGroup()->isDelayRequired()) ? 
"postDelayOffset + " : ""; + synapseUpdate << "const unsigned int spike = group->trgSpk[" << offsetTrueSpkPost << "j];" << std::endl; - // **TODO** fast divide optimisations - synSubs.addVarSubstitution("id_pre", "(rowMajorIndex / group->rowStride)"); - synSubs.addVarSubstitution("id_syn", "rowMajorIndex"); + // Loop through column of presynaptic neurons + if (s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { + synapseUpdate << "const unsigned int npre = group->colLength[spike];" << std::endl; + synapseUpdate << "for (unsigned int i = 0; i < npre; i++)"; } else { - synSubs.addVarSubstitution("id_pre", "i"); - synSubs.addVarSubstitution("id_syn", "((group->numTrgNeurons * i) + spike)"); - } - synSubs.addVarSubstitution("id_post", "spike"); - if (s.getArchetype().isPresynapticOutputRequired()) { - synSubs.addFuncSubstitution("addToPre", 1, "group->revInSyn[" + s.getPreISynIndex(1, synSubs["id_pre"]) + "] += $(0)"); + synapseUpdate << "for (unsigned int i = 0; i < group->numSrcNeurons; i++)"; } + { + CodeStream::Scope b(synapseUpdate); + + Substitutions synSubs(&funcSubs); + if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { + synapseUpdate << "const unsigned int colMajorIndex = (spike * group->colStride) + i;" << std::endl; + synapseUpdate << "const unsigned int rowMajorIndex = group->remap[colMajorIndex];" << std::endl; + + // **TODO** fast divide optimisations + synSubs.addVarSubstitution("id_pre", "(rowMajorIndex / group->rowStride)"); + synSubs.addVarSubstitution("id_syn", "rowMajorIndex"); + } + else { + synSubs.addVarSubstitution("id_pre", "i"); + synSubs.addVarSubstitution("id_syn", "((group->numTrgNeurons * i) + spike)"); + } + synSubs.addVarSubstitution("id_post", "spike"); + if (s.getArchetype().isPresynapticOutputRequired()) { + synSubs.addFuncSubstitution("addToPre", 1, "group->revInSyn[" + s.getPreISynIndex(1, synSubs["id_pre"]) + "] += $(0)"); + } - s.generateSynapseUpdate(*this, os, modelMerged, synSubs); + s.generateSynapseUpdate(*this, synapseUpdate, modelMerged, synSubs); + } } + synapseUpdate << std::endl; } - os << std::endl; - } - } + }); } } + + // Generate struct definitions + // **YUCK** dendritic delay update structs not actually required + modelMerged.genMergedSynapseDendriticDelayUpdateStructs(os, *this); + modelMerged.genMergedPresynapticUpdateGroupStructs(os, *this); + modelMerged.genMergedPostsynapticUpdateGroupStructs(os, *this); + modelMerged.genMergedSynapseDynamicsGroupStructs(os, *this); + + // Generate arrays of merged structs and functions to set them + // **YUCK** dendritic delay update structs not actually required + genMergedStructArrayPush(os, modelMerged.getMergedSynapseDendriticDelayUpdateGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedPresynapticUpdateGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedPostsynapticUpdateGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedSynapseDynamicsGroups()); + + // Generate preamble + preambleHandler(os); + + os << synapseUpdateStream.str(); } //-------------------------------------------------------------------------- void Backend::genCustomUpdate(CodeStream &os_, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const From 5ff0ceaadd9968efc6457e0c4a299d86e19cbd8b Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 12 Jun 2023 11:14:58 +0100 Subject: [PATCH 208/725] * made EnvironmentExternal multiply inherit type and pretty printing environments * merged together type and pretty printing standard library 
environments and moved into code generator --- .../genn/genn/code_generator/environment.h | 9 +++- .../genn/code_generator/standardLibrary.h | 30 ++++++++++++ .../genn/genn/transpiler/standardLibrary.h | 47 ------------------- src/genn/genn/code_generator/environment.cc | 6 +++ .../standardLibrary.cc | 25 +++------- src/genn/genn/genn.vcxproj | 4 +- 6 files changed, 52 insertions(+), 69 deletions(-) create mode 100644 include/genn/genn/code_generator/standardLibrary.h delete mode 100644 include/genn/genn/transpiler/standardLibrary.h rename src/genn/genn/{transpiler => code_generator}/standardLibrary.cc (88%) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 382f7bd8f0..e26d682617 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -14,13 +14,14 @@ // GeNN transpiler includes #include "transpiler/prettyPrinter.h" +#include "transpiler/typeChecker.h" //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentExternal //---------------------------------------------------------------------------- namespace GeNN::CodeGenerator { -class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase +class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase, public Transpiler::TypeChecker::EnvironmentBase { protected: using EnvironmentBase = Transpiler::PrettyPrinter::EnvironmentBase; @@ -42,6 +43,12 @@ class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase //------------------------------------------------------------------------ virtual std::string define(const std::string &name); + //------------------------------------------------------------------------ + // TypeChecker::EnvironmentBase virtuals + //------------------------------------------------------------------------ + virtual void define(const Transpiler::Token &name, const Type::ResolvedType &type, + Transpiler::ErrorHandlerBase &errorHandler); + protected: //------------------------------------------------------------------------ // Protected API diff --git a/include/genn/genn/code_generator/standardLibrary.h b/include/genn/genn/code_generator/standardLibrary.h new file mode 100644 index 0000000000..978d2f1e07 --- /dev/null +++ b/include/genn/genn/code_generator/standardLibrary.h @@ -0,0 +1,30 @@ +#pragma once + +// Standard C++ includes +#include +#include + +// Code generator includes +#include "code_generator/codeStream.h" +#include "code_generator/environment.h" + +//--------------------------------------------------------------------------- +// GeNN::CodeGenerator::StandardLibrary::Environment +//--------------------------------------------------------------------------- +namespace GeNN::CodeGenerator::StandardLibrary +{ +class Environment : public EnvironmentExternal +{ +public: + //------------------------------------------------------------------------ + // TypeChecker::EnvironmentBase virtuals + //------------------------------------------------------------------------ + virtual std::vector<Type::ResolvedType> getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) final; + + //------------------------------------------------------------------------ + // PrettyPrinter::EnvironmentBase virtuals + //------------------------------------------------------------------------ + virtual std::string getName(const std::string &name, std::optional<Type::ResolvedType> type = std::nullopt) final; + virtual
CodeGenerator::CodeStream &getStream() final; +}; +} // namespace GeNN::CodeGenerator::StandardLibrary diff --git a/include/genn/genn/transpiler/standardLibrary.h b/include/genn/genn/transpiler/standardLibrary.h deleted file mode 100644 index 7a4ceaa228..0000000000 --- a/include/genn/genn/transpiler/standardLibrary.h +++ /dev/null @@ -1,47 +0,0 @@ -#pragma once - -// Standard C++ includes -#include -#include - -// Code generator includes -#include "code_generator/codeStream.h" -#include "code_generator/environment.h" - -// Transpiler includes -#include "transpiler/typeChecker.h" - -//--------------------------------------------------------------------------- -// GeNN::Transpiler::StandardLibrary::FunctionTypes -//--------------------------------------------------------------------------- -namespace GeNN::Transpiler::StandardLibrary -{ -class FunctionTypes : public TypeChecker::EnvironmentBase -{ -public: - FunctionTypes(); - - //------------------------------------------------------------------------ - // EnvironmentBase virtuals - //------------------------------------------------------------------------ - virtual void define(const Token &name, const Type::ResolvedType &type, ErrorHandlerBase &errorHandler) final; - virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final; -}; - -//--------------------------------------------------------------------------- -// GeNN::Transpiler::StandardLibrary::FunctionEnvironment -//--------------------------------------------------------------------------- -class FunctionEnvironment : public CodeGenerator::EnvironmentExternal -{ -public: - FunctionEnvironment(CodeGenerator::CodeStream &os) - : CodeGenerator::EnvironmentExternal(os) - {} - - //------------------------------------------------------------------------ - // PrettyPrinter::EnvironmentBase virtuals - //------------------------------------------------------------------------ - virtual std::string getName(const std::string &name, std::optional type = std::nullopt) final; - virtual CodeGenerator::CodeStream &getStream() final; -}; -} // namespace GeNN::Transpiler::StandardLibrary diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index c1a4fd2b5b..551094c93a 100644 --- a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -19,6 +19,12 @@ std::string EnvironmentExternal::define(const std::string&) throw std::runtime_error("Cannot declare variable in external environment"); } //---------------------------------------------------------------------------- +void EnvironmentExternal::define(const Transpiler::Token&, const Type::ResolvedType&, + Transpiler::ErrorHandlerBase&) +{ + throw std::runtime_error("Cannot declare variable in external environment"); +} +//---------------------------------------------------------------------------- CodeStream &EnvironmentExternal::getContextStream() const { return std::visit( diff --git a/src/genn/genn/transpiler/standardLibrary.cc b/src/genn/genn/code_generator/standardLibrary.cc similarity index 88% rename from src/genn/genn/transpiler/standardLibrary.cc rename to src/genn/genn/code_generator/standardLibrary.cc index 6818c985b9..1fb50de5ef 100644 --- a/src/genn/genn/transpiler/standardLibrary.cc +++ b/src/genn/genn/code_generator/standardLibrary.cc @@ -1,4 +1,4 @@ -#include "transpiler/standardLibrary.h" +#include "code_generator/standardLibrary.h" // Standard C++ library #include @@ -12,7 +12,7 @@ #include "transpiler/typeChecker.h" 
using namespace GeNN::CodeGenerator; -using namespace GeNN::Transpiler::StandardLibrary; +using namespace GeNN::CodeGenerator::StandardLibrary; using namespace GeNN::Transpiler::TypeChecker; namespace Type = GeNN::Type; @@ -130,19 +130,9 @@ const auto libraryTypes = initLibraryTypes( //min, max, printf //--------------------------------------------------------------------------- -// GeNN::Transpiler::StandardLibrary::FunctionTypes +// GeNN::Transpiler::StandardLibrary::Environment //--------------------------------------------------------------------------- -FunctionTypes::FunctionTypes() -{ -} -//------------------------------------------------------------------------ -void FunctionTypes::define(const Token &name, const Type::ResolvedType&, ErrorHandlerBase &errorHandler) -{ - errorHandler.error(name, "Cannot declare variable in external environment"); - throw TypeCheckError(); -} -//--------------------------------------------------------------------------- -std::vector FunctionTypes::getTypes(const Token &name, ErrorHandlerBase &errorHandler) +std::vector Environment::getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) { const auto [typeBegin, typeEnd] = libraryTypes.equal_range(name.lexeme); if (typeBegin == typeEnd) { @@ -157,11 +147,8 @@ std::vector FunctionTypes::getTypes(const Token &name, Error return types; } } - -//--------------------------------------------------------------------------- -// GeNN::Transpiler::StandardLibrary::FunctionEnvironment //--------------------------------------------------------------------------- -std::string FunctionEnvironment::getName(const std::string &name, std::optional type) +std::string Environment::getName(const std::string &name, std::optional type) { const auto [libTypeBegin, libTypeEnd] = libraryTypes.equal_range(name); if (libTypeBegin == libTypeEnd) { @@ -178,7 +165,7 @@ std::string FunctionEnvironment::getName(const std::string &name, std::optional< } } //--------------------------------------------------------------------------- -CodeStream &FunctionEnvironment::getStream() +CodeStream &Environment::getStream() { return getContextStream(); } \ No newline at end of file diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj index 5815b91f05..6db5c9083f 100644 --- a/src/genn/genn/genn.vcxproj +++ b/src/genn/genn/genn.vcxproj @@ -36,6 +36,7 @@ + @@ -60,7 +61,6 @@ - @@ -90,6 +90,7 @@ + @@ -125,7 +126,6 @@ - From 538a433d3524bc033125c73964ae4e9500a80d27 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 12 Jun 2023 11:47:33 +0100 Subject: [PATCH 209/725] comments --- include/genn/genn/transpiler/prettyPrinter.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/genn/genn/transpiler/prettyPrinter.h b/include/genn/genn/transpiler/prettyPrinter.h index 03dae95df1..25cc2a3979 100644 --- a/include/genn/genn/transpiler/prettyPrinter.h +++ b/include/genn/genn/transpiler/prettyPrinter.h @@ -27,10 +27,10 @@ class EnvironmentBase //------------------------------------------------------------------------ // Declared virtuals //------------------------------------------------------------------------ - //! Define named variable and return the name as it should be used in code + //! Define identifier and return the name as it should be used in code virtual std::string define(const std::string &name) = 0; - //! Get the name to use in code for the variable named by token + //! 
Get the name to use in code for the named identifier virtual std::string getName(const std::string &name, std::optional type = std::nullopt) = 0; //! Get stream to write code within this environment to From 7bfd6514f7b730c0744fac32c560913a9e07b507 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 12 Jun 2023 11:47:52 +0100 Subject: [PATCH 210/725] bring in constructor for StandardLibrary::Environment --- include/genn/genn/code_generator/standardLibrary.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/genn/genn/code_generator/standardLibrary.h b/include/genn/genn/code_generator/standardLibrary.h index 978d2f1e07..24a0b7dd6e 100644 --- a/include/genn/genn/code_generator/standardLibrary.h +++ b/include/genn/genn/code_generator/standardLibrary.h @@ -16,6 +16,8 @@ namespace GeNN::CodeGenerator::StandardLibrary class Environment : public EnvironmentExternal { public: + using EnvironmentExternal::EnvironmentExternal; + //------------------------------------------------------------------------ // TypeChecker::EnvironmentBase virtuals //------------------------------------------------------------------------ From cd3deb0f2233aa0a5743d2f6e59b71facfdee374 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 12 Jun 2023 11:48:05 +0100 Subject: [PATCH 211/725] slightly empowered EnvironmentExternal --- .../genn/genn/code_generator/environment.h | 34 +++++++--- src/genn/genn/code_generator/environment.cc | 62 +++++++++++++++++-- 2 files changed, 82 insertions(+), 14 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index e26d682617..d2a7e1027d 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -8,6 +8,7 @@ // GeNN includes #include "gennUtils.h" #include "varAccess.h" +#include "type.h" // GeNN code generator includes #include "code_generator/codeStream.h" @@ -16,6 +17,11 @@ #include "transpiler/prettyPrinter.h" #include "transpiler/typeChecker.h" +namespace GeNN::Transpiler +{ +class ErrorHandlerBase; +struct Token; +} //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentExternal //---------------------------------------------------------------------------- @@ -23,15 +29,13 @@ namespace GeNN::CodeGenerator { class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase, public Transpiler::TypeChecker::EnvironmentBase { -protected: - using EnvironmentBase = Transpiler::PrettyPrinter::EnvironmentBase; public: - EnvironmentExternal(EnvironmentBase &enclosing) + explicit EnvironmentExternal(EnvironmentExternal &enclosing) : m_Context(enclosing) { } - EnvironmentExternal(CodeStream &os) + explicit EnvironmentExternal(CodeStream &os) : m_Context(os) { } @@ -41,13 +45,22 @@ class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase, p //------------------------------------------------------------------------ // PrettyPrinter::EnvironmentBase virtuals //------------------------------------------------------------------------ - virtual std::string define(const std::string &name); + virtual std::string define(const std::string &name) override; + + virtual std::string getName(const std::string &name, std::optional type = std::nullopt) override; //------------------------------------------------------------------------ // TypeChecker::EnvironmentBase virtuals //------------------------------------------------------------------------ - virtual void define(const 
Transpiler::Token &name, const Type::ResolvedType &type, - Transpiler::ErrorHandlerBase &errorHandler); + virtual void define(const Transpiler::Token &name, const GeNN::Type::ResolvedType &type, + Transpiler::ErrorHandlerBase &errorHandler) override; + + virtual std::vector getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) override; + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + void add(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &value); protected: //------------------------------------------------------------------------ @@ -63,7 +76,8 @@ class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase, p //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - std::variant, std::reference_wrapper> m_Context; + std::variant, std::reference_wrapper> m_Context; + std::unordered_map> m_Environment; }; //---------------------------------------------------------------------------- @@ -74,12 +88,12 @@ class EnvironmentSubstitute : public EnvironmentExternal { public: EnvironmentSubstitute(EnvironmentSubstitute &enclosing) - : EnvironmentExternal(static_cast(enclosing)), m_Contents(m_ContentsStream) + : EnvironmentExternal(static_cast(enclosing)), m_Contents(m_ContentsStream) { } EnvironmentSubstitute(EnvironmentExternal &enclosing) - : EnvironmentExternal(static_cast(enclosing)), m_Contents(m_ContentsStream) + : EnvironmentExternal(static_cast(enclosing)), m_Contents(m_ContentsStream) { } diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index 551094c93a..89c7d3f0d4 100644 --- a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -9,7 +9,12 @@ // GeNN includes #include "gennUtils.h" +// Transpiler includes +#include "transpiler/errorHandler.h" + +using namespace GeNN; using namespace GeNN::CodeGenerator; +using namespace GeNN::Transpiler; //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentExternal @@ -19,17 +24,66 @@ std::string EnvironmentExternal::define(const std::string&) throw std::runtime_error("Cannot declare variable in external environment"); } //---------------------------------------------------------------------------- -void EnvironmentExternal::define(const Transpiler::Token&, const Type::ResolvedType&, - Transpiler::ErrorHandlerBase&) +std::string EnvironmentExternal::getName(const std::string &name, std::optional type) +{ + // If name isn't found in environment + auto env = m_Environment.find(name); + if (env == m_Environment.end()) { + // If there's a parent environment in context, lookup there + if (std::holds_alternative>(m_Context)) { + return std::get>(m_Context).get().getName(name, type); + } + // Otherwise, give error + // **NOTE** this should never throw as type checking should happen first + else { + throw std::runtime_error("Undefined identifier '" + name + "'"); + } + } + // Otherwise, return it's value + else { + return env->second.second; + } +} +//---------------------------------------------------------------------------- +void EnvironmentExternal::define(const Token&, const Type::ResolvedType&, ErrorHandlerBase&) { throw std::runtime_error("Cannot declare variable in external environment"); } 
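To make the lookup rules concrete, here is a minimal sketch of how nested environments are intended to compose (illustrative only, not part of the patch; ``outer``/``inner`` are hypothetical, everything else is the ``EnvironmentExternal`` API declared above):

// Outermost environment wraps a CodeStream directly
std::ostringstream stream;
GeNN::CodeGenerator::CodeStream code(stream);
GeNN::CodeGenerator::EnvironmentExternal outer(code);

// Map identifier "id_pre" to the type Uint32 (for the type checker)
// and the string "i" (for the pretty printer)
outer.add(GeNN::Type::Uint32, "id_pre", "i");

// Lookups in an inner environment fall through to the enclosing one
GeNN::CodeGenerator::EnvironmentExternal inner(outer);
inner.getName("id_pre");   // returns "i"
inner.getName("unknown");  // throws std::runtime_error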
//---------------------------------------------------------------------------- +std::vector EnvironmentExternal::getTypes(const Token &name, ErrorHandlerBase &errorHandler) +{ + // If name isn't found in environment + auto env = m_Environment.find(name.lexeme); + if (env == m_Environment.end()) { + // If there's a parent environment in context, lookup there + if (std::holds_alternative>(m_Context)) { + return std::get>(m_Context).get().getTypes(name, errorHandler); + } + // Otherwise, give error + // **NOTE** this should never throw as type checking should happen first + else { + errorHandler.error(name, "Undefined identifier"); + throw TypeChecker::TypeCheckError(); + } + } + // Otherwise, return it's type + else { + return {env->second.first}; + } +} +//---------------------------------------------------------------------------- +void EnvironmentExternal::add(const Type::ResolvedType &type, const std::string &name, const std::string &value) +{ + if(!m_Environment.try_emplace(name, type, value).second) { + throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); + } +} +//---------------------------------------------------------------------------- CodeStream &EnvironmentExternal::getContextStream() const { return std::visit( Utils::Overload{ - [](std::reference_wrapper enclosing)->CodeStream& { return enclosing.get().getStream(); }, + [](std::reference_wrapper enclosing)->CodeStream& { return enclosing.get().getStream(); }, [](std::reference_wrapper os)->CodeStream& { return os.get(); }}, getContext()); } @@ -38,7 +92,7 @@ std::string EnvironmentExternal::getContextName(const std::string &name, std::op { return std::visit( Utils::Overload{ - [&name, type](std::reference_wrapper enclosing)->std::string { return enclosing.get().getName(name, type); }, + [&name, type](std::reference_wrapper enclosing)->std::string { return enclosing.get().getName(name, type); }, [&name](std::reference_wrapper)->std::string { throw std::runtime_error("Identifier '" + name + "' undefined"); }}, getContext()); } From 055d449e5a7c28622ac9b80411f5e7ab61c4067a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 12 Jun 2023 12:07:31 +0100 Subject: [PATCH 212/725] add little helper to add const to a type --- include/genn/genn/type.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index c256684223..8e61596c4d 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -227,6 +227,7 @@ struct ResolvedType const Numeric &getNumeric() const{ return *getValue().numeric; } const ResolvedType addQualifier(Qualifier qualifier) const{ return ResolvedType(*this, qualifiers | qualifier); } + const ResolvedType addConst() const{ return addQualifier(Qualifier::CONSTANT); } bool hasQualifier(Qualifier qualifier) const{ return (qualifiers & qualifier); } std::string getName() const; From 36f3b5af9fd9c43b0370f924ac9dbfc2fdef9ed5 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 12 Jun 2023 12:08:10 +0100 Subject: [PATCH 213/725] add a virtual ``EnvironmentExternalBase`` class --- .../genn/genn/code_generator/environment.h | 73 +++++++----- src/genn/genn/code_generator/environment.cc | 106 +++++++++--------- 2 files changed, 101 insertions(+), 78 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index d2a7e1027d..c155e384d5 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -23,60 +23,84 @@ class 
ErrorHandlerBase;
 struct Token;
 }
 //----------------------------------------------------------------------------
-// GeNN::CodeGenerator::EnvironmentExternal
+// GeNN::CodeGenerator::EnvironmentExternalBase
 //----------------------------------------------------------------------------
+//! Base class for external environments i.e. those defined OUTSIDE of transpiled code by the code generator
 namespace GeNN::CodeGenerator
 {
-class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase, public Transpiler::TypeChecker::EnvironmentBase
+class EnvironmentExternalBase : public Transpiler::PrettyPrinter::EnvironmentBase, public Transpiler::TypeChecker::EnvironmentBase
 {
 public:
-    explicit EnvironmentExternal(EnvironmentExternal &enclosing)
+    explicit EnvironmentExternalBase(EnvironmentExternalBase &enclosing)
     :   m_Context(enclosing)
     {
     }
 
-    explicit EnvironmentExternal(CodeStream &os)
+    explicit EnvironmentExternalBase(CodeStream &os)
     :   m_Context(os)
     {
     }
 
-    EnvironmentExternal(const EnvironmentExternal&) = delete;
+    EnvironmentExternalBase(const EnvironmentExternalBase&) = delete;
 
     //------------------------------------------------------------------------
     // PrettyPrinter::EnvironmentBase virtuals
     //------------------------------------------------------------------------
     virtual std::string define(const std::string &name) override;
-
-    virtual std::string getName(const std::string &name, std::optional<Type::ResolvedType> type = std::nullopt) override;
-
     //------------------------------------------------------------------------
     // TypeChecker::EnvironmentBase virtuals
     //------------------------------------------------------------------------
     virtual void define(const Transpiler::Token &name, const GeNN::Type::ResolvedType &type,
                         Transpiler::ErrorHandlerBase &errorHandler) override;
 
-    virtual std::vector<Type::ResolvedType> getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) override;
+protected:
+    //------------------------------------------------------------------------
+    // Protected API
+    //------------------------------------------------------------------------
+    //! Get stream exposed by context
+    CodeStream &getContextStream() const;
 
-    //------------------------------------------------------------------------
-    // Public API
-    //------------------------------------------------------------------------
-    void add(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &value);
+    //! Get name from context if it provides this functionality
+    std::string getContextName(const std::string &name, std::optional<Type::ResolvedType> type) const;
+
+    //! Get vector of types from context if it provides this functionality
+    std::vector<Type::ResolvedType> getContextTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) const;
+
+private:
     //------------------------------------------------------------------------
     // Members
     //------------------------------------------------------------------------
     std::variant<std::reference_wrapper<EnvironmentExternalBase>, std::reference_wrapper<CodeStream>> m_Context;
 };
+
+//----------------------------------------------------------------------------
+// GeNN::CodeGenerator::EnvironmentExternal
+//----------------------------------------------------------------------------
+//! 
Minimal environment, not tied to any sort of group - just lets you define things +class EnvironmentExternal : public EnvironmentExternalBase +{ +public: //------------------------------------------------------------------------ - // Protected API + // PrettyPrinter::EnvironmentBase virtuals //------------------------------------------------------------------------ - auto &getContext() const{ return m_Context; } + virtual std::string getName(const std::string &name, std::optional type = std::nullopt) override; + virtual CodeStream &getStream() override { return getContextStream(); } - CodeStream &getContextStream() const; + //------------------------------------------------------------------------ + // TypeChecker::EnvironmentBase virtuals + //------------------------------------------------------------------------ + virtual std::vector getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) override; - std::string getContextName(const std::string &name, std::optional type) const; + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + void add(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &value); private: //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - std::variant, std::reference_wrapper> m_Context; std::unordered_map> m_Environment; }; @@ -84,21 +108,16 @@ class EnvironmentExternal : public Transpiler::PrettyPrinter::EnvironmentBase, p // GeNN::CodeGenerator::EnvironmentSubstitute //---------------------------------------------------------------------------- //! Standard pretty printing environment simply allowing substitutions to be implemented -class EnvironmentSubstitute : public EnvironmentExternal +class EnvironmentSubstitute : public EnvironmentExternalBase { public: - EnvironmentSubstitute(EnvironmentSubstitute &enclosing) - : EnvironmentExternal(static_cast(enclosing)), m_Contents(m_ContentsStream) - { - } - - EnvironmentSubstitute(EnvironmentExternal &enclosing) - : EnvironmentExternal(static_cast(enclosing)), m_Contents(m_ContentsStream) + EnvironmentSubstitute(EnvironmentExternalBase &enclosing) + : EnvironmentExternalBase(static_cast(enclosing)), m_Contents(m_ContentsStream) { } EnvironmentSubstitute(CodeStream &os) - : EnvironmentExternal(os), m_Contents(m_ContentsStream) + : EnvironmentExternalBase(os), m_Contents(m_ContentsStream) { } @@ -184,7 +203,7 @@ class EnvironmentSubstitute : public EnvironmentExternal //---------------------------------------------------------------------------- //! Pretty printing environment which caches used variables in local variables template -class EnvironmentLocalVarCache : public EnvironmentExternal +class EnvironmentLocalVarCache : public EnvironmentExternalBase { //! 
Type of a single definition using DefType = typename std::invoke_result_t::value_type; diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index 89c7d3f0d4..4b17e89bb1 100644 --- a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -17,60 +17,81 @@ using namespace GeNN::CodeGenerator; using namespace GeNN::Transpiler; //---------------------------------------------------------------------------- -// GeNN::CodeGenerator::EnvironmentExternal +// GeNN::CodeGenerator::EnvironmentExternalBase //---------------------------------------------------------------------------- -std::string EnvironmentExternal::define(const std::string&) +std::string EnvironmentExternalBase::define(const std::string&) { throw std::runtime_error("Cannot declare variable in external environment"); } + //---------------------------------------------------------------------------- -std::string EnvironmentExternal::getName(const std::string &name, std::optional type) +void EnvironmentExternalBase::define(const Token&, const Type::ResolvedType&, ErrorHandlerBase&) { - // If name isn't found in environment - auto env = m_Environment.find(name); - if (env == m_Environment.end()) { - // If there's a parent environment in context, lookup there - if (std::holds_alternative>(m_Context)) { - return std::get>(m_Context).get().getName(name, type); - } - // Otherwise, give error - // **NOTE** this should never throw as type checking should happen first - else { - throw std::runtime_error("Undefined identifier '" + name + "'"); - } - } - // Otherwise, return it's value - else { - return env->second.second; - } + throw std::runtime_error("Cannot declare variable in external environment"); } //---------------------------------------------------------------------------- -void EnvironmentExternal::define(const Token&, const Type::ResolvedType&, ErrorHandlerBase&) +CodeStream &EnvironmentExternalBase::getContextStream() const { - throw std::runtime_error("Cannot declare variable in external environment"); + return std::visit( + Utils::Overload{ + [](std::reference_wrapper enclosing)->CodeStream& { return enclosing.get().getStream(); }, + [](std::reference_wrapper os)->CodeStream& { return os.get(); }}, + m_Context); } -//---------------------------------------------------------------------------- +//---------------------------------------------------------------------------- +std::string EnvironmentExternalBase::getContextName(const std::string &name, std::optional type) const +{ + return std::visit( + Utils::Overload{ + [&name, type](std::reference_wrapper enclosing)->std::string { return enclosing.get().getName(name, type); }, + [&name](std::reference_wrapper)->std::string { throw std::runtime_error("Identifier '" + name + "' undefined"); }}, + m_Context); +} +//---------------------------------------------------------------------------- +std::vector EnvironmentExternalBase::getContextTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) const +{ + return std::visit( + Utils::Overload{ + [&errorHandler, &name](std::reference_wrapper enclosing)->std::vector + { + return enclosing.get().getTypes(name, errorHandler); + }, + [&errorHandler, &name](std::reference_wrapper)->std::vector + { + errorHandler.error(name, "Undefined identifier"); + throw TypeChecker::TypeCheckError(); + }}, + m_Context); +} + +//---------------------------------------------------------------------------- +// 
GeNN::CodeGenerator::EnvironmentExternal +//---------------------------------------------------------------------------- std::vector EnvironmentExternal::getTypes(const Token &name, ErrorHandlerBase &errorHandler) { // If name isn't found in environment auto env = m_Environment.find(name.lexeme); if (env == m_Environment.end()) { - // If there's a parent environment in context, lookup there - if (std::holds_alternative>(m_Context)) { - return std::get>(m_Context).get().getTypes(name, errorHandler); - } - // Otherwise, give error - // **NOTE** this should never throw as type checking should happen first - else { - errorHandler.error(name, "Undefined identifier"); - throw TypeChecker::TypeCheckError(); - } + return getContextTypes(name, errorHandler); } // Otherwise, return it's type else { return {env->second.first}; } } +//---------------------------------------------------------------------------- +std::string EnvironmentExternal::getName(const std::string &name, std::optional type) +{ + // If name isn't found in environment + auto env = m_Environment.find(name); + if (env == m_Environment.end()) { + return getContextName(name, type); + } + // Otherwise, return it's value + else { + return env->second.second; + } +} //---------------------------------------------------------------------------- void EnvironmentExternal::add(const Type::ResolvedType &type, const std::string &name, const std::string &value) { @@ -78,24 +99,7 @@ void EnvironmentExternal::add(const Type::ResolvedType &type, const std::string throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); } } -//---------------------------------------------------------------------------- -CodeStream &EnvironmentExternal::getContextStream() const -{ - return std::visit( - Utils::Overload{ - [](std::reference_wrapper enclosing)->CodeStream& { return enclosing.get().getStream(); }, - [](std::reference_wrapper os)->CodeStream& { return os.get(); }}, - getContext()); -} -//---------------------------------------------------------------------------- -std::string EnvironmentExternal::getContextName(const std::string &name, std::optional type) const -{ - return std::visit( - Utils::Overload{ - [&name, type](std::reference_wrapper enclosing)->std::string { return enclosing.get().getName(name, type); }, - [&name](std::reference_wrapper)->std::string { throw std::runtime_error("Identifier '" + name + "' undefined"); }}, - getContext()); -} + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentSubstitute From 91de413b28f6ba813c4794166e5dc51c9776bd17 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 12 Jun 2023 16:00:33 +0100 Subject: [PATCH 214/725] lots of hacking - good progress being made --- .../genn/genn/code_generator/backendBase.h | 2 +- .../genn/genn/code_generator/environment.h | 281 +++++++++++++++++- .../genn/code_generator/standardLibrary.h | 4 +- .../code_generator/synapseUpdateGroupMerged.h | 2 +- include/genn/genn/synapseGroup.h | 6 - include/genn/genn/transpiler/prettyPrinter.h | 9 - include/genn/genn/type.h | 15 +- .../backends/single_threaded_cpu/backend.cc | 111 ++++--- src/genn/genn/code_generator/backendBase.cc | 20 +- src/genn/genn/code_generator/environment.cc | 44 ++- .../synapseUpdateGroupMerged.cc | 39 ++- src/genn/genn/synapseGroup.cc | 45 --- 12 files changed, 428 insertions(+), 150 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index b71a112a4b..c1f5ad029a 
100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -489,7 +489,7 @@ class GENN_EXPORT BackendBase void genNeuronIndexCalculation(CodeStream &os, const NeuronUpdateGroupMerged &ng, unsigned int batchSize) const; - void genSynapseIndexCalculation(CodeStream &os, const SynapseGroupMergedBase &sg, unsigned int batchSize) const; + void genSynapseIndexCalculation(EnvironmentExternal &env, const SynapseGroupMergedBase &sg, unsigned int batchSize) const; void genCustomUpdateIndexCalculation(CodeStream &os, const CustomUpdateGroupMerged &cu) const; diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index c155e384d5..0c3b10260f 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -12,6 +12,7 @@ // GeNN code generator includes #include "code_generator/codeStream.h" +#include "code_generator/groupMerged.h" // GeNN transpiler includes #include "transpiler/prettyPrinter.h" @@ -54,6 +55,14 @@ class EnvironmentExternalBase : public Transpiler::PrettyPrinter::EnvironmentBas virtual void define(const Transpiler::Token &name, const GeNN::Type::ResolvedType &type, Transpiler::ErrorHandlerBase &errorHandler) override; + //------------------------------------------------------------------------ + // Operators + //------------------------------------------------------------------------ + std::string operator[] (const std::string &name) + { + return getName(name); + } + protected: //------------------------------------------------------------------------ // Protected API @@ -77,33 +86,293 @@ class EnvironmentExternalBase : public Transpiler::PrettyPrinter::EnvironmentBas //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentExternal //---------------------------------------------------------------------------- -//! Minimal environment, not tied to any sort of group - just lets you define things +//! 
Minimal external environment, not tied to any sort of group - just lets you define things
 class EnvironmentExternal : public EnvironmentExternalBase
 {
 public:
+    using EnvironmentExternalBase::EnvironmentExternalBase;
+    EnvironmentExternal(const EnvironmentExternal&) = delete;
+    ~EnvironmentExternal();
+
     //------------------------------------------------------------------------
     // PrettyPrinter::EnvironmentBase virtuals
     //------------------------------------------------------------------------
-    virtual std::string getName(const std::string &name, std::optional<Type::ResolvedType> type = std::nullopt) override;
-    virtual CodeStream &getStream() override { return getContextStream(); }
+    virtual std::string getName(const std::string &name, std::optional<Type::ResolvedType> type = std::nullopt) final;
+    virtual CodeStream &getStream() final { return m_Contents; }
 
     //------------------------------------------------------------------------
     // TypeChecker::EnvironmentBase virtuals
     //------------------------------------------------------------------------
-    virtual std::vector<Type::ResolvedType> getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) override;
+    virtual std::vector<Type::ResolvedType> getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) final;
 
     //------------------------------------------------------------------------
     // Public API
     //------------------------------------------------------------------------
+    //! Map a type (for type-checking) and a value (for pretty-printing) to an identifier
     void add(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &value,
+             const std::vector<size_t> &initialisers = {}, const std::vector<std::string> &dependents = {});
+
+    size_t addInitialiser(const std::string &initialiser);
 
 private:
     //------------------------------------------------------------------------
     // Members
     //------------------------------------------------------------------------
+    std::ostringstream m_ContentsStream;
+    CodeStream m_Contents;
+
+    std::unordered_map<std::string, std::tuple<Type::ResolvedType, std::string, std::vector<size_t>, std::vector<std::string>>> m_Environment;
+    std::vector<std::pair<bool, std::string>> m_Initialisers;
 };
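The effect of the new ``initialisers`` and ``dependents`` arguments is easiest to see in a small sketch (illustrative only, not part of the patch; ``enclosing`` stands for any outer environment): a statement registered via ``addInitialiser`` is buffered and only written out, by the destructor, if an identifier tied to it is actually used.

EnvironmentExternal env(enclosing);

// Register a deferred initialiser statement and tie "id_syn" to it
const size_t idSynInit = env.addInitialiser("const unsigned int idSyn = (i * rowStride) + j;");
env.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit});

// If transpiled code never references "id_syn", the idSyn statement is never
// emitted; if it is referenced, getTypes() marks the initialiser as required
// and ~EnvironmentExternal() writes it out ahead of the buffered contents

//----------------------------------------------------------------------------
// GeNN::CodeGenerator::EnvironmentGroupMergedField
//----------------------------------------------------------------------------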
+//! External environment, for substituting identifiers with fields of a merged group
+template<typename G>
+class EnvironmentGroupMergedField : public EnvironmentExternalBase
+{
+    using GroupInternal = typename G::GroupInternal;
+    using IsHeterogeneousFn = bool (G::*)(const std::string&) const;
+    using IsVarInitHeterogeneousFn = bool (G::*)(const std::string&, const std::string&) const;
+
+    using GetVarSuffixFn = const std::string &(GroupInternal::*)(void) const;
+    using GetParamValuesFn = const std::unordered_map<std::string, double> &(GroupInternal::*)(void) const;
+
+    template<typename V>
+    using GetVarReferencesFn = const std::unordered_map<std::string, V> &(GroupInternal::*)(void) const;
+
+public:
+    EnvironmentGroupMergedField(G &group, EnvironmentExternalBase &enclosing)
+    :   EnvironmentExternalBase(enclosing), m_Group(group)
+    {
+    }
+    EnvironmentGroupMergedField(G &group, CodeStream &os)
+    :   EnvironmentExternalBase(os), m_Group(group)
+    {
+    }
+
+    //------------------------------------------------------------------------
+    // PrettyPrinter::EnvironmentBase virtuals
+    //------------------------------------------------------------------------
+    virtual std::string getName(const std::string &name, std::optional<Type::ResolvedType> type = std::nullopt) final
+    {
+        // If name isn't found in environment, defer to context
+        auto env = m_Environment.find(name);
+        if (env == m_Environment.end()) {
+            return getContextName(name, type);
+        }
+        // Otherwise, access field in environment via group pointer
+        else {
+            return "group->" + std::get<1>(std::get<2>(env->second).value());
+        }
+    }
+    virtual CodeStream &getStream() final { return getContextStream(); }
+
+    //------------------------------------------------------------------------
+    // TypeChecker::EnvironmentBase virtuals
+    //------------------------------------------------------------------------
+    virtual std::vector<Type::ResolvedType> getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) final
+    {
+        // If name isn't found in environment, defer to context
+        auto env = m_Environment.find(name.lexeme);
+        if (env == m_Environment.end()) {
+            return getContextTypes(name, errorHandler);
+        }
+        // Otherwise, return type
+        else {
+            // If field hasn't already been added
+            if (!std::get<1>(env->second)) {
+                // Call function to add field to underlying merged group
+                const auto &field = std::get<2>(env->second).value();
+                m_Group.get().addField(std::get<0>(field), std::get<1>(field),
+                                       std::get<2>(field), std::get<3>(field));
+
+                // Set flag so field doesn't get re-added
+                std::get<1>(env->second) = true;
+            }
+            // Return type
+            return {std::get<0>(env->second)};
+        }
+    }
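Concretely, and mirroring the single-threaded CPU backend changes later in this patch, lazy field registration is expected to be used something like this (illustrative only):

// Wrap a merged group so identifiers can be backed by struct fields
EnvironmentGroupMergedField<SynapseDynamicsGroupMerged> groupEnv(group, funcEnv);

groupEnv.add(Type::Uint32.addConst(), "num_pre",
             Type::Uint32, "numSrcNeurons",
             [](const SynapseGroupInternal &sg, size_t)
             {
                 return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons());
             });

// The numSrcNeurons field is only added to the merged group structure if
// "num_pre" is actually looked up during type checking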
+    //------------------------------------------------------------------------
+    // Public API
+    //------------------------------------------------------------------------
+    //! Map an identifier to a type (for type-checking) and a group merged field to back it
+    void add(const GeNN::Type::ResolvedType &type, const std::string &name,
+             const GeNN::Type::ResolvedType &fieldType, const std::string &fieldName, typename G::GetFieldValueFunc getFieldValue,
+             GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD)
+    {
+        if(!m_Environment.try_emplace(name, type, false,
+                                      std::make_tuple(fieldType, fieldName, getFieldValue, mergedFieldType)).second)
+        {
+            throw std::runtime_error("Redeclaration of '" + std::string{name} + "'");
+        }
+    }
+
+    void addScalar(const std::string &name, const std::string &fieldSuffix, typename G::GetFieldDoubleValueFunc getFieldValue)
+    {
+        add(m_Group.get().getScalarType().addConst(), name,
+            m_Group.get().getScalarType(), name + fieldSuffix,
+            [getFieldValue, this](const auto &g, size_t i)
+            {
+                return getScalarString(getFieldValue(g, i));
+            });
+    }
+
+    void addParams(const Snippet::Base::StringVec &paramNames, const std::string &fieldSuffix,
+                   GetParamValuesFn getParamValues, IsHeterogeneousFn isHeterogeneous)
+    {
+        // Loop through params
+        for(const auto &p : paramNames) {
+            // If parameter is heterogeneous across the merged group, add a scalar field
+            if (std::invoke(isHeterogeneous, m_Group.get(), p)) {
+                addScalar(p, fieldSuffix,
+                          [p, getParamValues](const auto &g, size_t)
+                          {
+                              return std::invoke(getParamValues, g).at(p);
+                          });
+            }
+            // Otherwise, just add a const-qualified scalar to the type environment
+            else {
+                add(m_Group.get().getScalarType().addConst(), p,
+                    getScalarString(std::invoke(getParamValues, m_Group.get().getArchetype()).at(p)));
+            }
+        }
+    }
+
+    void addDerivedParams(const Snippet::Base::DerivedParamVec &derivedParams, const std::string &fieldSuffix,
+                          GetParamValuesFn getDerivedParamValues, IsHeterogeneousFn isHeterogeneous)
+    {
+        // Loop through derived params
+        for(const auto &d : derivedParams) {
+            if (std::invoke(isHeterogeneous, m_Group.get(), d.name)) {
+                addScalar(d.name, fieldSuffix,
+                          [d, getDerivedParamValues](const auto &g, size_t)
+                          {
+                              return std::invoke(getDerivedParamValues, g).at(d.name);
+                          });
+            }
+            else {
+                add(m_Group.get().getScalarType().addConst(), d.name,
+                    getScalarString(std::invoke(getDerivedParamValues, m_Group.get().getArchetype()).at(d.name)));
+            }
+        }
+    }
+
+    template<typename A>
+    void addVarInitParams(IsVarInitHeterogeneousFn isHeterogeneous, const std::string &fieldSuffix = "")
+    {
+        // Loop through variables
+        const A archetypeAdaptor(m_Group.get().getArchetype());
+        for(const auto &v : archetypeAdaptor.getDefs()) {
+            // Loop through parameters of this variable's initialiser
+            for(const auto &p : archetypeAdaptor.getInitialisers().at(v.name).getParams()) {
+                if(std::invoke(isHeterogeneous, m_Group.get(), v.name, p.first)) {
+                    addScalar(p.first, v.name + fieldSuffix,
+                              [p, v](const auto &g, size_t)
+                              {
+                                  return A(g).getInitialisers().at(v.name).getParams().at(p.first);
+                              });
+                }
+                else {
+                    add(m_Group.get().getScalarType().addConst(), p.first,
+                        getScalarString(p.second));
+                }
+            }
+        }
+    }
+
+    template<typename A>
+    void addVarInitDerivedParams(IsVarInitHeterogeneousFn isHeterogeneous, const std::string &fieldSuffix = "")
+    {
+        // Loop through variables
+        const A archetypeAdaptor(m_Group.get().getArchetype());
+        for(const auto &v : archetypeAdaptor.getDefs()) {
+            // Loop through derived parameters of this variable's initialiser
+            for(const auto &p : archetypeAdaptor.getInitialisers().at(v.name).getDerivedParams()) {
+                if(std::invoke(isHeterogeneous, m_Group.get(), v.name, p.first)) {
+                    addScalar(p.first, v.name + fieldSuffix,
+                              [p, v](const auto &g, size_t)
+                              {
+                                  return A(g).getInitialisers().at(v.name).getDerivedParams().at(p.first);
+                              });
+                }
+                else {
+                    add(m_Group.get().getScalarType().addConst(), p.first,
+                        getScalarString(p.second));
+                }
+            }
+        }
+    }
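As elsewhere in GeNN's merged groups, the parameter helpers split heterogeneous from homogeneous values. A sketch of the two outcomes (illustrative; ``wum`` and the accessor/heterogeneity-test names are assumptions, not taken from the patch):

// If "tau" differs between the groups merged together, a scalar field is
// added and references to "tau" pretty-print as "group->tau"; if it is the
// same everywhere, no field is added and "tau" pretty-prints as a literal
// such as "20.0f", formatted by getScalarString()
groupEnv.addParams(wum->getParamNames(), "",
                   &SynapseGroupInternal::getWUParams,
                   &SynapseDynamicsGroupMerged::isWUParamHeterogeneous);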
+    template<typename A>
+    void addVars(const std::string &arrayPrefix, const std::string &fieldSuffix = "")
+    {
+        // Loop through variables
+        const A archetypeAdaptor(m_Group.get().getArchetype());
+        for(const auto &v : archetypeAdaptor.getDefs()) {
+            // If variable access is read-only, qualify type with const
+            const auto resolvedType = v.type.resolve(m_Group.get().getTypeContext());
+            const auto qualifiedType = (getVarAccessMode(v.access) & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType;
+            add(qualifiedType, v.name,
+                resolvedType.createPointer(), v.name + fieldSuffix,
+                [arrayPrefix, v](const auto &g, size_t)
+                {
+                    return arrayPrefix + v.name + A(g).getNameSuffix();
+                });
+        }
+    }
+
+    template<typename A>
+    void addVarRefs(const std::string &arrayPrefix, const std::string &fieldSuffix = "")
+    {
+        // Loop through variable references
+        const A archetypeAdaptor(m_Group.get().getArchetype());
+        for(const auto &v : archetypeAdaptor.getDefs()) {
+            // If variable access is read-only, qualify type with const
+            const auto resolvedType = v.type.resolve(m_Group.get().getTypeContext());
+            const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType;
+            add(qualifiedType, v.name,
+                resolvedType.createPointer(), v.name + fieldSuffix,
+                [arrayPrefix, v](const auto &g, size_t)
+                {
+                    const auto varRef = A(g).getInitialisers().at(v.name);
+                    return arrayPrefix + varRef.getVar().name + varRef.getTargetName();
+                });
+        }
+    }
+
+    template<typename A>
+    void addEGPs(const std::string &arrayPrefix, const std::string &varName = "", const std::string &fieldSuffix = "")
+    {
+        // Loop through EGPs
+        const A archetypeAdaptor(m_Group.get().getArchetype());
+        for(const auto &e : archetypeAdaptor.getDefs()) {
+            const auto pointerType = e.type.resolve(m_Group.get().getTypeContext()).createPointer();
+            add(pointerType, e.name,
+                pointerType, e.name + varName + fieldSuffix,
+                [arrayPrefix, e, varName](const auto &g, size_t)
+                {
+                    return arrayPrefix + e.name + varName + g.getName();
+                },
+                GroupMergedFieldType::DYNAMIC);
+        }
+    }
+
+private:
+    std::string getScalarString(double scalar) const
+    {
+        return (Utils::writePreciseString(scalar, m_Group.get().getScalarType().getNumeric().maxDigits10)
+                + m_Group.get().getScalarType().getNumeric().literalSuffix);
+    }
+
+    //------------------------------------------------------------------------
+    // Members
+    //------------------------------------------------------------------------
+    std::reference_wrapper<G> m_Group;
+
+    //! 
Environment mapping names to types to fields to pull values from + std::unordered_map>> m_Environment; +}; + + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentSubstitute //---------------------------------------------------------------------------- diff --git a/include/genn/genn/code_generator/standardLibrary.h b/include/genn/genn/code_generator/standardLibrary.h index 24a0b7dd6e..2bc6c6af8f 100644 --- a/include/genn/genn/code_generator/standardLibrary.h +++ b/include/genn/genn/code_generator/standardLibrary.h @@ -13,10 +13,10 @@ //--------------------------------------------------------------------------- namespace GeNN::CodeGenerator::StandardLibrary { -class Environment : public EnvironmentExternal +class Environment : public EnvironmentExternalBase { public: - using EnvironmentExternal::EnvironmentExternal; + using EnvironmentExternalBase::EnvironmentExternalBase; //------------------------------------------------------------------------ // TypeChecker::EnvironmentBase virtuals diff --git a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h index 1f887668e7..40ae0b56f3 100644 --- a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h @@ -103,7 +103,7 @@ class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase runnerVarDecl, runnerMergedStructAlloc, name); } - void generateSynapseUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const; //---------------------------------------------------------------------------- // Static constants diff --git a/include/genn/genn/synapseGroup.h b/include/genn/genn/synapseGroup.h index 88e5d13941..b6c4e3d21e 100644 --- a/include/genn/genn/synapseGroup.h +++ b/include/genn/genn/synapseGroup.h @@ -201,12 +201,6 @@ class GENN_EXPORT SynapseGroup /*! This is only used by extra global parameters which are pointers*/ VarLocation getSparseConnectivityExtraGlobalParamLocation(const std::string ¶mName) const; - //! Does this synapse group require dendritic delay? - bool isDendriticDelayRequired() const; - - //! Does this synapse group define presynaptic output? - bool isPresynapticOutputRequired() const; - //! Does this synapse group require an RNG to generate procedural connectivity? bool isProceduralConnectivityRNGRequired() const; diff --git a/include/genn/genn/transpiler/prettyPrinter.h b/include/genn/genn/transpiler/prettyPrinter.h index 25cc2a3979..456868471b 100644 --- a/include/genn/genn/transpiler/prettyPrinter.h +++ b/include/genn/genn/transpiler/prettyPrinter.h @@ -35,15 +35,6 @@ class EnvironmentBase //! 
Get stream to write code within this environment to virtual CodeGenerator::CodeStream &getStream() = 0; - - //------------------------------------------------------------------------ - // Operators - //------------------------------------------------------------------------ - std::string operator[] (const std::string &name) - { - return getName(name); - } - }; //--------------------------------------------------------------------------- diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 8e61596c4d..cb14ea5250 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -205,13 +205,15 @@ struct ResolvedType {} ResolvedType(const ResolvedType &other, Qualifier qualifiers) : qualifiers(qualifiers), detail(other.detail) {} + explicit ResolvedType(Qualifier qualifiers = Qualifier{0}) : qualifiers(qualifiers), detail(std::monostate{}) + {} //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ Qualifier qualifiers; - std::variant detail; + std::variant detail; //------------------------------------------------------------------------ // Public API @@ -220,6 +222,7 @@ struct ResolvedType bool isPointer() const{ return std::holds_alternative(detail); } bool isFunction() const{ return std::holds_alternative(detail); } bool isNumeric() const{ return isValue() && getValue().numeric; } + bool isVoid() const{ return std::holds_alternative(detail); } const Value &getValue() const{ return std::get(detail); } const Pointer &getPointer() const{ return std::get(detail); } @@ -339,6 +342,16 @@ inline static const ResolvedType Uint32 = CREATE_NUMERIC(uint32_t, 30, "u"); inline static const ResolvedType Float = CREATE_NUMERIC(float, 50, "f"); inline static const ResolvedType Double = CREATE_NUMERIC(double, 60, ""); +// Void +inline static const ResolvedType Void = ResolvedType(); + +//---------------------------------------------------------------------------- +// Standard function types +//---------------------------------------------------------------------------- +inline static const ResolvedType AddToPre = ResolvedType::createFunction(Void, {Uint32}); +inline static const ResolvedType AddToPost = ResolvedType::createFunction(Void, {Uint32}); +inline static const ResolvedType AddToPostDenDelay = ResolvedType::createFunction(Void, {Uint32, Uint32}); + //! 
Parse a numeric type GENN_EXPORT ResolvedType parseNumeric(const std::string &typeString, const TypeContext &context); diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index ffd1665ff0..d3dfa2d25e 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -8,10 +8,9 @@ #include "code_generator/codeStream.h" #include "code_generator/environment.h" #include "code_generator/modelSpecMerged.h" +#include "code_generator/standardLibrary.h" #include "code_generator/substitutions.h" -#include "transpiler/standardLibrary.h" - using namespace GeNN::CodeGenerator; using namespace GeNN::Transpiler; @@ -125,7 +124,7 @@ void genKernelIteration(EnvironmentExternal &env, const G &g, size_t numKernelDi //-------------------------------------------------------------------------- namespace GeNN::CodeGenerator::SingleThreadedCPU { -void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); if(model.getBatchSize() != 1) { @@ -153,7 +152,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged { CodeStream::Scope b(os); - StandardLibrary::FunctionEnvironment stdEnv(os); + StandardLibrary::Environment stdEnv(os); EnvironmentSubstitute funcEnv(stdEnv); funcEnv.addSubstitution("t", "t"); funcEnv.addSubstitution("batch", "0"); @@ -304,78 +303,102 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // Generate stream with synapse update code std::ostringstream synapseUpdateStream; CodeStream synapseUpdate(synapseUpdateStream); - synapseUpdate << "void updateSynapses(timepoint t)"; + + // Begin environment with standard library + StandardLibrary::Environment synapseUpdateEnv(synapseUpdate); + + synapseUpdateEnv.getStream() << "void updateSynapses(timepoint t)"; { - CodeStream::Scope b(synapseUpdate); - Substitutions funcSubs(getFunctionTemplates(model.getPrecision().getName())); - funcSubs.addVarSubstitution("t", "t"); - funcSubs.addVarSubstitution("batch", "0"); + CodeStream::Scope b(synapseUpdateEnv.getStream()); + + EnvironmentExternal funcEnv(synapseUpdateEnv); + funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); + funcEnv.add(Type::Uint32.addConst(), "batch", "0"); // Synapse dynamics { Timer t(synapseUpdate, "synapseDynamics", model.isTimingEnabled()); modelMerged.genMergedSynapseDynamicsGroups( *this, - [this, &funcSubs, &synapseUpdate](SynapseDynamicsGroupMerged &s) + [this, &funcEnv, &modelMerged, &synapseUpdate](SynapseDynamicsGroupMerged &s) { - CodeStream::Scope b(synapseUpdate); - synapseUpdate << "// merged synapse dynamics group " << s.getIndex() << std::endl; - synapseUpdate << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged synapse dynamics group " << s.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; { - CodeStream::Scope b(synapseUpdate); + CodeStream::Scope b(funcEnv.getStream()); // Get reference to group - synapseUpdate << "const auto *group = &mergedSynapseDynamicsGroup" << s.getIndex() << "[g]; " << std::endl; - - genSynapseIndexCalculation(synapseUpdate, s, 1); + funcEnv.getStream() << "const 
auto *group = &mergedSynapseDynamicsGroup" << s.getIndex() << "[g]; " << std::endl;
+
+                    // Create matching environment
+                    EnvironmentGroupMergedField<SynapseDynamicsGroupMerged> groupEnv(s, funcEnv);
+
+                    // Add basic fields **TODO** move to group merged
+                    groupEnv.add(Type::Uint32.addConst(), "num_pre",
+                                 Type::Uint32, "numSrcNeurons",
+                                 [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); });
+                    groupEnv.add(Type::Uint32.addConst(), "num_post",
+                                 Type::Uint32, "numTrgNeurons",
+                                 [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); });
+                    groupEnv.add(Type::Uint32.addConst(), "_row_stride",
+                                 Type::Uint32, "rowStride",
+                                 [this](const SynapseGroupInternal &sg, size_t) { return std::to_string(getSynapticMatrixRowStride(sg)); });
+
+                    // _row_length
+                    // _ind
+                    genSynapseIndexCalculation(funcEnv, s, 1);
 
                     // Loop through presynaptic neurons
-                    synapseUpdate << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)";
+                    groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["num_pre"] << "; i++)";
                     {
                         // If this synapse group has sparse connectivity, loop through length of this row
-                        CodeStream::Scope b(synapseUpdate);
+                        CodeStream::Scope b(groupEnv.getStream());
                         if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
-                            synapseUpdate << "for(unsigned int s = 0; s < group->rowLength[i]; s++)";
+                            groupEnv.getStream() << "for(unsigned int s = 0; s < " << groupEnv["_row_length"] << "[i]; s++)";
                         }
                         // Otherwise, if it's dense, loop through each postsynaptic neuron
                         else if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::DENSE) {
-                            synapseUpdate << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)";
+                            groupEnv.getStream() << "for (unsigned int j = 0; j < " << groupEnv["num_post"] << "; j++)";
                         }
                         else {
                             throw std::runtime_error("Only DENSE and SPARSE format connectivity can be used for synapse dynamics");
                         }
                         {
-                            CodeStream::Scope b(synapseUpdate);
+                            EnvironmentExternal synEnv(groupEnv);
+                            CodeStream::Scope b(synEnv.getStream());
 
-                            Substitutions synSubs(&funcSubs);
-                            if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
-                                // Calculate index of synapse and use it to look up postsynaptic index
-                                synapseUpdate << "const unsigned int n = (i * group->rowStride) + s;" << std::endl;
-                                synapseUpdate << "const unsigned int j = group->ind[n];" << std::endl;
+                            // Add presynaptic index to substitutions
+                            synEnv.add(Type::Uint32.addConst(), "id_pre", "i");
 
-                                synSubs.addVarSubstitution("id_syn", "n");
+                            if (s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
+                                // Add initialiser strings to calculate synaptic and presynaptic index
+                                const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["_row_stride"] + ") + s;");
+                                const size_t idPostInit = synEnv.addInitialiser("const unsigned int idPost = " + synEnv["_ind"] + "[idSyn];");
+
+                                // **TODO** id_syn can be 64-bit
+                                synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}, {"_row_stride"});
+                                synEnv.add(Type::Uint32.addConst(), "id_post", "idPost", {idPostInit, idSynInit}, {"_ind"});
                             }
                             else {
-                                synSubs.addVarSubstitution("id_syn", "(i * group->numTrgNeurons) + j");
-                            }
+                                // Add postsynaptic index to substitutions
+                                synEnv.add(Type::Uint32.addConst(), "id_post", "j");
 
-                            // Add pre and postsynaptic indices to substitutions
-                            synSubs.addVarSubstitution("id_pre", "i");
-                            synSubs.addVarSubstitution("id_post", "j");
+                                // Add initialiser to calculate synaptic index
+                                
const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["num_post"] + ") + j;");
 
-                            // Add correct functions for apply synaptic input
-                            if(s.getArchetype().isDendriticDelayRequired()) {
-                                synSubs.addFuncSubstitution("addToInSynDelay", 2, "group->denDelay[" + s.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)");
-                            }
-                            else {
-                                synSubs.addFuncSubstitution("addToInSyn", 1, "group->inSyn[" + s.getPostISynIndex(1, "j") + "] += $(0)");
-                            }
+                                // **TODO** id_syn can be 64-bit
+                                synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit});
 
-                            if(s.getArchetype().isPresynapticOutputRequired()) {
-                                synSubs.addFuncSubstitution("addToPre", 1, "group->revInSyn[" + s.getPreISynIndex(1, synSubs["id_pre"]) + "] += $(0)");
                             }
+
+                            // Add correct functions for apply synaptic input
+                            synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", synEnv["_den_delay"] + "[" + s.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)");
+                            synEnv.add(Type::AddToPost, "addToPost", synEnv["_out_post"] + "[" + s.getPostISynIndex(1, "j") + "] += $(0)");
+                            synEnv.add(Type::AddToPre, "addToPre", synEnv["_out_pre"] + "[" + s.getPreISynIndex(1, synEnv["id_pre"]) + "] += $(0)");
+
                             // Call synapse dynamics handler
-                            s.generateSynapseUpdate(*this, synapseUpdate, modelMerged, synSubs);
+                            s.generateSynapseUpdate(*this, synEnv, modelMerged);
                         }
                     }
                 }
 }
 //--------------------------------------------------------------------------
 void Backend::genCustomUpdate(CodeStream &os_, ModelSpecMerged &modelMerged, Hos
 {
     CodeStream::Scope b(os_);
 
-    StandardLibrary::FunctionEnvironment stdEnv(os_);
+    StandardLibrary::Environment stdEnv(os_);
     EnvironmentSubstitute funcEnv(stdEnv);
     funcEnv.addSubstitution("t", "t");
     funcEnv.addSubstitution("batch", "0");
diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc
index 3d4561c862..4c250ed47c 100644
--- a/src/genn/genn/code_generator/backendBase.cc
+++ b/src/genn/genn/code_generator/backendBase.cc
@@ -66,24 +66,28 @@ void BackendBase::genNeuronIndexCalculation(CodeStream &os, const NeuronUpdateGr
     }
 }
 //-----------------------------------------------------------------------
-void BackendBase::genSynapseIndexCalculation(CodeStream &os, const SynapseGroupMergedBase &sg, unsigned int batchSize) const
+void BackendBase::genSynapseIndexCalculation(EnvironmentExternal &env, const SynapseGroupMergedBase &sg, unsigned int batchSize) const
 {
     // If batching is enabled
     if(batchSize > 1) {
         // Calculate batch offsets into pre and postsynaptic populations
-        os << "const unsigned int preBatchOffset = group->numSrcNeurons * batch;" << std::endl;
-        os << "const unsigned int postBatchOffset = group->numTrgNeurons * batch;" << std::endl;
-
+        env.add(Type::Uint32.addConst(), "_pre_batch_offset", "preBatchOffset",
+                {env.addInitialiser("const unsigned int preBatchOffset = " + env["num_pre"] + " * " + env["batch"] + ";")});
+        env.add(Type::Uint32.addConst(), "_post_batch_offset", "postBatchOffset",
+                {env.addInitialiser("const unsigned int postBatchOffset = " + env["num_post"] + " * " + env["batch"] + ";")});
+
         // Calculate batch offsets into synapse arrays, using 64-bit arithmetic if necessary
         if(areSixtyFourBitSynapseIndicesRequired(sg)) {
-            os << "const uint64_t synBatchOffset = (uint64_t)preBatchOffset * (uint64_t)group->rowStride;" << std::endl;
+            assert(false);
+            //os << "const uint64_t synBatchOffset = (uint64_t)preBatchOffset * (uint64_t)group->rowStride;" << std::endl;
         }
         else {
-            os << "const unsigned int synBatchOffset = preBatchOffset * group->rowStride;" << std::endl;
+            
env.add(Type::Uint32.addConst(), "_syn_batch_offset", "synBatchOffset", + {env.addInitialiser("const unsigned int synBatchOffset = " + env["_pre_batch_offset"] + " * " + env["_row_stride"] + ";")}); } // If synapse group has kernel weights - const auto &kernelSize = sg.getArchetype().getKernelSize(); + /*const auto &kernelSize = sg.getArchetype().getKernelSize(); if((sg.getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL) && !kernelSize.empty()) { // Loop through kernel dimensions and multiply together os << "const unsigned int kernBatchOffset = "; @@ -151,7 +155,7 @@ void BackendBase::genSynapseIndexCalculation(CodeStream &os, const SynapseGroupM } } - } + }*/ } //----------------------------------------------------------------------- void BackendBase::genCustomUpdateIndexCalculation(CodeStream &os, const CustomUpdateGroupMerged &cu) const diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index 4b17e89bb1..40c107fb7b 100644 --- a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -67,6 +67,20 @@ std::vector EnvironmentExternalBase::getContextTypes(const T //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentExternal //---------------------------------------------------------------------------- +EnvironmentExternal::~EnvironmentExternal() +{ + // Loop through initialiser + for(const auto &i : m_Initialisers) { + // If variable requiring initialiser has been referenced, write out initialiser + if (i.first) { + getContextStream() << i.second << std::endl; + } + } + + // Write contents to context stream + getContextStream() << m_ContentsStream.str(); +} +//---------------------------------------------------------------------------- std::vector EnvironmentExternal::getTypes(const Token &name, ErrorHandlerBase &errorHandler) { // If name isn't found in environment @@ -74,9 +88,21 @@ std::vector EnvironmentExternal::getTypes(const Token &name, if (env == m_Environment.end()) { return getContextTypes(name, errorHandler); } - // Otherwise, return it's type + // Otherwise else { - return {env->second.first}; + // If this identifier relies on any initialiser statements, mark these initialisers as required + for(size_t i : std::get<2>(env->second)) { + m_Initialisers.at(i).first = true; + } + + // If this identifier relies on any others, get their types + // **YUCK** + for(const std::string &id : std::get<3>(env->second)) { + getTypes(Token{Token::Type::IDENTIFIER, id, 0}, errorHandler); + } + + // Return type of variables + return {std::get<0>(env->second)}; } } //---------------------------------------------------------------------------- @@ -89,17 +115,23 @@ std::string EnvironmentExternal::getName(const std::string &name, std::optional< } // Otherwise, return it's value else { - return env->second.second; + return std::get<1>(env->second); } } //---------------------------------------------------------------------------- -void EnvironmentExternal::add(const Type::ResolvedType &type, const std::string &name, const std::string &value) +void EnvironmentExternal::add(const Type::ResolvedType &type, const std::string &name, const std::string &value, + const std::vector &initialisers, const std::vector &dependents) { - if(!m_Environment.try_emplace(name, type, value).second) { + if(!m_Environment.try_emplace(name, type, value, initialisers, dependents).second) { throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); } 
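// ----------------------------------------------------------------------------
// Editorial sketch: the deferred-initialiser mechanism introduced above,
// reduced to a standalone toy. The names here (LazyEnv etc.) are illustrative
// inventions, not GeNN API; the real EnvironmentExternal additionally walks
// its 'dependents' list transitively in getTypes(). The idea: initialiser
// statements are only emitted if an identifier that depends on them is
// actually referenced, and are flushed ahead of the buffered contents when
// the environment is destroyed.
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

class LazyEnv
{
public:
    explicit LazyEnv(std::ostream &context) : m_Context(context) {}

    ~LazyEnv()
    {
        // Write out only the initialisers that were marked as required
        for(const auto &i : m_Initialisers) {
            if(i.first) {
                m_Context << i.second << std::endl;
            }
        }
        // Then flush the buffered body
        m_Context << m_Contents.str();
    }

    size_t addInitialiser(std::string statement)
    {
        m_Initialisers.emplace_back(false, std::move(statement));
        return m_Initialisers.size() - 1;
    }

    void add(const std::string &name, std::string value, std::vector<size_t> initialisers = {})
    {
        m_Env.emplace(name, std::make_pair(std::move(value), std::move(initialisers)));
    }

    // Looking up a name marks the initialisers it depends on as required
    const std::string &operator[](const std::string &name)
    {
        auto &entry = m_Env.at(name);
        for(size_t i : entry.second) {
            m_Initialisers.at(i).first = true;
        }
        return entry.first;
    }

    std::ostream &out() { return m_Contents; }

private:
    std::ostream &m_Context;
    std::ostringstream m_Contents;
    std::unordered_map<std::string, std::pair<std::string, std::vector<size_t>>> m_Env;
    std::vector<std::pair<bool, std::string>> m_Initialisers;
};

int main()
{
    LazyEnv env(std::cout);
    const size_t idSynInit = env.addInitialiser("const unsigned int idSyn = (i * rowStride) + s;");
    env.add("id_syn", "idSyn", {idSynInit});

    // Referencing id_syn pulls its initialiser in; unreferenced ones are dropped
    env.out() << "update(" << env["id_syn"] << ");" << std::endl;
}
// ----------------------------------------------------------------------------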
} - +//------------------------------------------------------------------------ +size_t EnvironmentExternal::addInitialiser(const std::string &initialiser) +{ + m_Initialisers.emplace_back(false, initialiser); + return (m_Initialisers.size() - 1); +} //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentSubstitute diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 39e4aa3eae..2d57740503 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -11,26 +11,23 @@ using namespace GeNN::CodeGenerator; //-------------------------------------------------------------------------- namespace { -void applySynapseSubstitutions(CodeStream &os, std::string code, const std::string &errorContext, - const SynapseGroupMergedBase &sg, const Substitutions &baseSubs, - const ModelSpecMerged &modelMerged, const bool backendSupportsNamespace) +template +void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBase &env, std::string code, const std::string &errorContext, + const G &sg, const ModelSpecMerged &modelMerged, bool backendSupportsNamespace) { const ModelSpecInternal &model = modelMerged.getModel(); const unsigned int batchSize = model.getBatchSize(); const auto *wu = sg.getArchetype().getWUModel(); - Substitutions synapseSubs(&baseSubs); + EnvironmentGroupMergedField synEnv(sg, env); // Substitute parameter and derived parameter names - synapseSubs.addParamValueSubstitution(wu->getParamNames(), sg.getArchetype().getWUParams(), - [&sg](const std::string &p) { return sg.isWUParamHeterogeneous(p); }, - "", "group->"); - synapseSubs.addVarValueSubstitution(wu->getDerivedParams(), sg.getArchetype().getWUDerivedParams(), - [&sg](const std::string &p) { return sg.isWUDerivedParamHeterogeneous(p); }, - "", "group->"); - synapseSubs.addVarNameSubstitution(wu->getExtraGlobalParams(), "", "group->"); + synEnv.addParams(wu->getParamNames(), "", &SynapseGroupInternal::getWUParams, &G::isWUParamHeterogeneous); + synEnv.addDerivedParams(wu->getDerivedParams(), "", &SynapseGroupInternal::getWUDerivedParams, &G::isWUDerivedParamHeterogeneous); + synEnv.addEGPs(backend.getDeviceVarPrefix()); - // Substitute names of pre and postsynaptic weight update variables + // Substitute names of pre and postsynaptic weight update variable + synEnv.addVars(backend.getDeviceVarPrefix()); synapseSubs.addVarNameSubstitution(wu->getPreVars(), "", "group->", [&sg, &synapseSubs, batchSize](VarAccess a, const std::string&) { @@ -220,14 +217,14 @@ void PresynapticUpdateGroupMerged::generateSpikeEventThreshold(const BackendBase //---------------------------------------------------------------------------- void PresynapticUpdateGroupMerged::generateSpikeEventUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { - applySynapseSubstitutions(os, getArchetype().getWUModel()->getEventCode(), "eventCode", - *this, popSubs, modelMerged, backend.supportsNamespace()); + applySynapseSubstitutions(backend, os, getArchetype().getWUModel()->getEventCode(), "eventCode", + *this, popSubs, modelMerged); } //---------------------------------------------------------------------------- void PresynapticUpdateGroupMerged::generateSpikeUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { - 
applySynapseSubstitutions(os, getArchetype().getWUModel()->getSimCode(), "simCode", - *this, popSubs, modelMerged, backend.supportsNamespace()); + applySynapseSubstitutions(backend, os, getArchetype().getWUModel()->getSimCode(), "simCode", + *this, popSubs, modelMerged); } //---------------------------------------------------------------------------- void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendBase&, CodeStream &os, Substitutions &popSubs) const @@ -304,15 +301,15 @@ void PostsynapticUpdateGroupMerged::generateSynapseUpdate(const BackendBase &bac //---------------------------------------------------------------------------- const std::string SynapseDynamicsGroupMerged::name = "SynapseDynamics"; //---------------------------------------------------------------------------- -void SynapseDynamicsGroupMerged::generateSynapseUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void SynapseDynamicsGroupMerged::generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const { const auto *wum = getArchetype().getWUModel(); - if (!wum->getSynapseDynamicsSuppportCode().empty() && backend.supportsNamespace()) { + /*if (!wum->getSynapseDynamicsSuppportCode().empty() && backend.supportsNamespace()) { os << "using namespace " << modelMerged.getSynapseDynamicsSupportCodeNamespace(wum->getSynapseDynamicsSuppportCode()) << ";" << std::endl; - } + }*/ - applySynapseSubstitutions(os, wum->getSynapseDynamicsCode(), "synapseDynamics", - *this, popSubs, modelMerged, backend.supportsNamespace()); + applySynapseSubstitutions(backend, env, wum->getSynapseDynamicsCode(), "synapseDynamics", + *this, modelMerged, backend.supportsNamespace()); } diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc index 960c1a6ea9..93815f9f03 100644 --- a/src/genn/genn/synapseGroup.cc +++ b/src/genn/genn/synapseGroup.cc @@ -305,51 +305,6 @@ VarLocation SynapseGroup::getSparseConnectivityExtraGlobalParamLocation(const st return m_ConnectivityExtraGlobalParamLocation[m_SparseConnectivityInitialiser.getSnippet()->getExtraGlobalParamIndex(paramName)]; } //---------------------------------------------------------------------------- -bool SynapseGroup::isDendriticDelayRequired() const -{ - // If addToInSynDelay function is used in sim code, return true - if(getWUModel()->getSimCode().find("$(addToInSynDelay") != std::string::npos) { - return true; - } - - // If addToInSynDelay function is used in event code, return true - if(getWUModel()->getEventCode().find("$(addToInSynDelay") != std::string::npos) { - return true; - } - - // If addToInSynDelay function is used in synapse dynamics, return true - if(getWUModel()->getSynapseDynamicsCode().find("$(addToInSynDelay") != std::string::npos) { - return true; - } - - return false; -} -//---------------------------------------------------------------------------- -bool SynapseGroup::isPresynapticOutputRequired() const -{ - // If addToPre function is used in sim_code, return true - if(getWUModel()->getSimCode().find("$(addToPre") != std::string::npos) { - return true; - } - - // If addToPre function is used in learn_post_code, return true - if(getWUModel()->getLearnPostCode().find("$(addToPre") != std::string::npos) { - return true; - } - - // If addToPre function is used in event_code, return true - if(getWUModel()->getEventCode().find("$(addToPre") != std::string::npos) { - return true; - } - - // If addToPre function is 
used in synapse_dynamics, return true - if(getWUModel()->getSynapseDynamicsCode().find("$(addToPre") != std::string::npos) { - return true; - } - - return false; -} -//---------------------------------------------------------------------------- bool SynapseGroup::isProceduralConnectivityRNGRequired() const { if(m_MatrixType & SynapseMatrixConnectivity::PROCEDURAL) { From e608bce0b96be404099325bd4f01be33e3341b9f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 14 Jun 2023 12:06:53 +0100 Subject: [PATCH 215/725] additional environment functionality: * Added new intermediate class to provide policy-based implementation of basic functionality * Fixed lots of typos --- .../genn/genn/code_generator/environment.h | 346 ++++++++++++------ src/genn/genn/code_generator/environment.cc | 69 ---- 2 files changed, 242 insertions(+), 173 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 0c3b10260f..dd7061c1d4 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -18,11 +18,13 @@ #include "transpiler/prettyPrinter.h" #include "transpiler/typeChecker.h" +// Forward declarations namespace GeNN::Transpiler { class ErrorHandlerBase; struct Token; } + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentExternalBase //---------------------------------------------------------------------------- @@ -84,72 +86,94 @@ class EnvironmentExternalBase : public Transpiler::PrettyPrinter::EnvironmentBas }; //---------------------------------------------------------------------------- -// GeNN::CodeGenerator::EnvironmentExternal +// GeNN::CodeGenerator::EnvironmentSubstitutionPolicy //---------------------------------------------------------------------------- -//! Minimal external environment, not tied to any sort of group - just lets you define things -class EnvironmentExternal : public EnvironmentExternalBase +struct EnvironmentSubstitutionPolicy { -public: - using EnvironmentExternalBase::EnvironmentExternalBase; - EnvironmentExternal(const EnvironmentExternal&) = delete; - ~EnvironmentExternal(); - - //------------------------------------------------------------------------ - // PrettyPrinter::EnvironmentBase virtuals - //------------------------------------------------------------------------ - virtual std::string getName(const std::string &name, std::optional type = std::nullopt) final; - virtual CodeStream &getStream() final { return m_Contents;; } - - //------------------------------------------------------------------------ - // TypeChecker::EnvironmentBase virtuals - //------------------------------------------------------------------------ - virtual std::vector getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) final; - - //------------------------------------------------------------------------ - // Public API - //------------------------------------------------------------------------ - //! 
Map a type (for type-checking) and a value (for pretty-printing) to an identifier - void add(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &value, - const std::vector &initialisers = {}, const std::vector &dependents = {}); + using Payload = std::string; - size_t addInitialiser(const std::string &initialiser); -private: - //------------------------------------------------------------------------ - // Members - //------------------------------------------------------------------------ - std::ostringstream m_ContentsStream; - CodeStream m_Contents; + std::string getName(const std::string &payload) + { + return payload; + } - std::unordered_map, std::vector>> m_Environment; - std::vector> m_Initialisers; + void setRequired(std::string&) + { + } }; //---------------------------------------------------------------------------- -// GeNN::CodeGenerator::EnvironmentGroupMergedField +// GeNN::CodeGenerator::EnvironmentFieldPolicy //---------------------------------------------------------------------------- -//! External environment, for substituting template -class EnvironmentGroupMergedField : public EnvironmentExternalBase +struct EnvironmentFieldPolicy { - using GroupInternal = typename G::GroupInternal; - using IsHeterogeneousFn = bool (G::*)(const std::string&) const; - using IsVarInitHeterogeneousFn = bool (G::*)(const std::string&, const std::string&) const; + using Payload = std::tuple>; - using GroupInternal = typename G::GroupInternal; - using GetVarSuffixFn = const std::string &(GroupInternal::*)(void) const; - using GetParamValuesFn = const std::unordered_map &(GroupInternal::*)(void) const; + EnvironmentFieldPolicy(G &group) : m_Group(group) + { + } - template - using GetVarReferencesFn = const std::unordered_map &(GroupInternal::*)(void) const; + std::string getName(const Payload &payload) + { + // If a field is specified + if(std::get<2>(payload)) { + return "group->" + std::get<1>(std::get<2>(payload).get()) + std::get<1>(payload); + } + // Otherwise, use value directly + else { + assert(!std::get<1>(payload).empty()); + return std::get<1>(payload); + } + } -public: - EnvironmentGroupMergedField(G &group, EnvironmentExternalBase &enclosing) - : EnvironmentExternalBase(enclosing), m_Group(group) + void setRequired(Payload &payload) { + // If a field is specified but it hasn't already been added + if (std::get<2>(payload) && !std::get<0>(payload)) { + // Call function to add field to underlying merged group + const auto &field = std::get<2>(payload).get(); + m_Group.addField(std::get<0>(field), std::get<1>(field), + std::get<2>(field), std::get<3>(field)); + + // Set flag so field doesn't get re-added + std::get<0>(payload) = true; + } } - EnvironmentGroupMergedField(G &group, CodeStream &os) - : EnvironmentExternalBase(os), m_Group(group) + +private: + std::reference_wrapper m_Group; +}; + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::EnvironmentExternalDynamicBase +//---------------------------------------------------------------------------- +template +class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, private P +{ +public: + template + EnvironmentExternalDynamicBase(EnvironmentExternalBase &enclosing, PolicyArgs&&... policyArgs) + : EnvironmentExternalBase(enclosing), P(std::forward(policyArgs)...) + {} + + template + EnvironmentExternalDynamicBase(CodeStream &os, PolicyArgs&&... policyArgs) + : EnvironmentExternalBase(os), P(std::forward(policyArgs)...) 
+ {} + + ~EnvironmentExternalDynamicBase() { + // Loop through initialiser + for(const auto &i : m_Initialisers) { + // If variable requiring initialiser has been referenced, write out initialiser + if (i.first) { + getContextStream() << i.second << std::endl; + } + } + + // Write contents to context stream + getContextStream() << m_ContentsStream.str(); } //------------------------------------------------------------------------ @@ -162,12 +186,13 @@ class EnvironmentGroupMergedField : public EnvironmentExternalBase if (env == m_Environment.end()) { return getContextName(name, type); } - // Otherwise, visit field in environment + // Otherwise, get name from payload else { - return "group->" + std::get<1>(env->second.second); + return getName(std::get<3>(env->second)); } } - virtual CodeStream &getStream() final { return getContextStream(); } + + virtual CodeStream &getStream() final { return m_Contents; } //------------------------------------------------------------------------ // TypeChecker::EnvironmentBase virtuals @@ -175,42 +200,122 @@ class EnvironmentGroupMergedField : public EnvironmentExternalBase virtual std::vector getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) final { // If name isn't found in environment - auto env = m_Environment.find(name); + auto env = m_Environment.find(name.lexeme); if (env == m_Environment.end()) { - return getContextType(name, type); + return getContextTypes(name, errorHandler); } - // Otherwise, return type + // Otherwise else { - // If field hasn't already been added - if (!std::get<1>(env->second)) { - // Call function to add field to underlying merged group - const auto &field = std::get<2>(env->second); - m_GroupMerged.addField(std::get<0>(field), std::get<1>(field), - std::get<2>(field), std::get<3>(field)); - - // Set flag so field doesn't get re-added - std::get<1>(env->second) = true; + // If this identifier relies on any initialiser statements, mark these initialisers as required + for(size_t i : std::get<1>(env->second)) { + m_Initialisers.at(i).first = true; } - // Return type + + // If this identifier relies on any others, get their types + // **YUCK** + for(const std::string &id : std::get<2>(env->second)) { + getTypes(Token{Token::Type::IDENTIFIER, id, 0}, errorHandler); + } + + // Perform any type-specific logic to mark this identifier as required + setRequired(std::get<3>(env->second)); + + // Return type of variables return {std::get<0>(env->second)}; } } + + size_t addInitialiser(const std::string &initialiser) + { + m_Initialisers.emplace_back(false, initialiser); + return (m_Initialisers.size() - 1); + } + +protected: + //------------------------------------------------------------------------ + // Protected API + //------------------------------------------------------------------------ + //! 
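// ----------------------------------------------------------------------------
// Editorial sketch: the policy-based split this commit describes ("new
// intermediate class to provide policy-based implementation of basic
// functionality"), reduced to a standalone toy. DynamicBase and
// SubstitutionPolicy are invented stand-ins mirroring the shape of
// EnvironmentExternalDynamicBase / EnvironmentSubstitutionPolicy: the shared
// name-lookup machinery lives in the base template, while the Payload type
// plus getName()/setRequired() are supplied by the policy.
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

template<typename P>
class DynamicBase : private P
{
public:
    template<typename... PolicyArgs>
    explicit DynamicBase(PolicyArgs&&... policyArgs) : P(std::forward<PolicyArgs>(policyArgs)...) {}

    void add(const std::string &name, typename P::Payload payload)
    {
        m_Env.emplace(name, std::move(payload));
    }

    std::string getName(const std::string &name)
    {
        auto &payload = m_Env.at(name);
        this->setRequired(payload);     // policy-specific bookkeeping (e.g. lazily adding a field)
        return P::getName(payload);     // policy-specific rendering of the payload
    }

private:
    std::unordered_map<std::string, typename P::Payload> m_Env;
};

// Policy mapping names straight to strings (cf. EnvironmentSubstitutionPolicy);
// a field policy would instead carry a (added-flag, suffix, field) tuple
struct SubstitutionPolicy
{
    using Payload = std::string;
    std::string getName(const Payload &p) { return p; }
    void setRequired(Payload&) {}
};

int main()
{
    DynamicBase<SubstitutionPolicy> env;
    env.add("id_pre", "i");
    std::cout << env.getName("id_pre") << std::endl;   // prints "i"
}
// ----------------------------------------------------------------------------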
Map an identifier to a type (for type-checking), lists of initialisers and dependencies and a payload + void addInternal(const GeNN::Type::ResolvedType &type, const std::string &name, const typename P::Payload &payload, + const std::vector &initialisers = {}, const std::vector &dependents = {}) + { + if(!m_Environment.try_emplace(name, type, initialisers, dependents, payload).second) { + throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); + } + } + +private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + std::ostringstream m_ContentsStream; + CodeStream m_Contents; + + std::unordered_map, std::vector, typename P::Payload>> m_Environment; + std::vector> m_Initialisers; +}; + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::EnvironmentExternal +//---------------------------------------------------------------------------- +//! Minimal external environment, not tied to any sort of group - just lets you define things +class EnvironmentExternal : public EnvironmentExternalDynamicBase +{ +public: + using EnvironmentExternalDynamicBase::EnvironmentExternalDynamicBase; + EnvironmentExternal(const EnvironmentExternal&) = delete; + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + //! Map a type (for type-checking) and a value (for pretty-printing) to an identifier + void add(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &value, + const std::vector &initialisers = {}, const std::vector &dependents = {}) + { + addInternal(type, name, value, initialisers, dependents); + } +}; + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::EnvironmentGroupMergedField +//---------------------------------------------------------------------------- +//! External environment, for substituting +template +class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase> +{ + using GroupInternal = typename G::GroupInternal; + using IsHeterogeneousFn = bool (G::*)(const std::string&) const; + using IsVarInitHeterogeneousFn = bool (G::*)(const std::string&, const std::string&) const; + + using GetVarSuffixFn = const std::string &(GroupInternal::*)(void) const; + using GetParamValuesFn = const std::unordered_map &(GroupInternal::*)(void) const; + + template + using GetVarReferencesFn = const std::unordered_map &(GroupInternal::*)(void) const; + +public: + using EnvironmentExternalDynamicBase::EnvironmentExternalDynamicBase; + //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ + //! Map a type and a value to an identifier + void add(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &value, + const std::vector &initialisers = {}, const std::vector &dependents = {}) + { + addInternal(type, name, std::make_tuple(false, value, std::nullopt), + initialisers, dependents); + } + //! 
Map a type (for type-checking) and a group merged field to back it to an identifier void add(const GeNN::Type::ResolvedType &type, const std::string &name, const GeNN::Type::ResolvedType &fieldType, const std::string &fieldName, typename G::GetFieldValueFunc getFieldValue, - GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD) + const std::string &indexSuffix = "", GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD, + const std::vector &initialisers = {}, const std::vector &dependents = {}) { - if(!m_Environment.try_emplace(name, std::piecewise_construct, - std::forward_as_tuple(type), - std::forward_as_tuple(false), - std::forward_as_tuple(std::in_place, fieldType, fieldName, getFieldValue, mergedFieldType)).second) - { - throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); - } + addInternal(type, name, std::make_tuple(false, indexSuffix, std::forward_as_tuple(std::in_place, fieldType, fieldName, getFieldValue, mergedFieldType)), + initialisers, dependents); } void addScalar(const std::string &name, const std::string &fieldSuffix, typename G::GetFieldDoubleValueFunc getFieldValue) @@ -219,7 +324,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalBase m_Group.getScalarType(), name + fieldSuffix, [getFieldValue, this](const auto &g, size_t i) { - return getScalarString(getFieldValue(g, i); + return getScalarString(getFieldValue(g, i)); }); } @@ -228,6 +333,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalBase { // Loop through params for(const auto &p : paramNames) { + // If parameter is heterogeneous, add scalar field if (std::invoke(isHeterogeneous, m_Group, p)) { addScalar(p, fieldSuffix, [p, getParamValues](const auto &g, size_t) @@ -237,7 +343,8 @@ class EnvironmentGroupMergedField : public EnvironmentExternalBase } // Otherwise, just add a const-qualified scalar to the type environment else { - add(m_Group.getScalarType().addConst(), p, getScalarString(std::invoke(getParamValues, m_Group.getArchetype()).at(p))); + add(m_Group.getScalarType().addConst(), p, + getScalarString(std::invoke(getParamValues, m_Group.getArchetype()).at(p))); } } } @@ -247,6 +354,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalBase { // Loop through derived params for(const auto &d : derivedParams) { + // If derived parameter is heterogeneous, add scalar field if (std::invoke(isHeterogeneous, m_Group, d.name)) { addScalar(d.name, fieldSuffix, [d, getDerivedParamValues](const auto &g, size_t) @@ -254,8 +362,10 @@ class EnvironmentGroupMergedField : public EnvironmentExternalBase return std::invoke(getDerivedParamValues, g).at(d.name); }); } + // Otherwise, just add a const-qualified scalar to the type environment with archetype value else { - add(m_Group.getScalarType().addConst(), d.name, getScalarString(std::invoke(getDerivedParamValues, m_Group).at(d.name)); + add(m_Group.getScalarType().addConst(), d.name, + getScalarString(std::invoke(getDerivedParamValues, m_Group).at(d.name))); } } } @@ -268,15 +378,17 @@ class EnvironmentGroupMergedField : public EnvironmentExternalBase for(const auto &v : archetypeAdaptor.getDefs()) { // Loop through parameters for(const auto &p : archetypeAdaptor.getInitialisers().at(v.name).getParams()) { + // If parameter is heterogeneous, add scalar field if(std::invoke(isHeterogeneous, m_Group, v.name, p.first)) { - defineScalarField(p.first, v.name + fieldSuffix, - [p, v](const auto &g, size_t) - { - return A(g).getInitialisers().at(v.name).getParams().at(p.first); - }); 
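// ----------------------------------------------------------------------------
// Editorial sketch: the heterogeneous-parameter decision used by addParams /
// addDerivedParams / addVarInitParams below, as a standalone toy (invented
// names, not GeNN API). When a parameter differs between the populations
// merged into one group it must become a struct field read at runtime; when
// it is identical everywhere it can be baked in as a literal constant.
#include <iostream>
#include <string>
#include <vector>

std::string paramReference(const std::string &name, const std::vector<double> &valuePerGroup)
{
    // Heterogeneous across merged groups -> refer to a per-group field
    for(double v : valuePerGroup) {
        if(v != valuePerGroup.front()) {
            return "group->" + name;
        }
    }
    // Homogeneous -> bake the archetype's value in as a literal
    return std::to_string(valuePerGroup.front());
}

int main()
{
    std::cout << paramReference("tau", {10.0, 10.0}) << std::endl;  // "10.000000"
    std::cout << paramReference("g", {0.1, 0.2}) << std::endl;      // "group->g"
}
// ----------------------------------------------------------------------------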
+                    addScalar(p.first, v.name + fieldSuffix,
+                              [p, v](const auto &g, size_t)
+                              {
+                                  return A(g).getInitialisers().at(v.name).getParams().at(p.first);
+                              });
                 }
+                // Otherwise, just add a const-qualified scalar to the type environment with archetype value
                 else {
-                    defineField(m_Group.getScalarType().addConst(), p.first);
+                    add(m_Group.getScalarType().addConst(), p.first, getScalarString(p.second));
                 }
             }
         }
     }
 
@@ -290,22 +402,25 @@ class EnvironmentGroupMergedField : public EnvironmentExternalBase
         for(const auto &v : archetypeAdaptor.getDefs()) {
             // Loop through parameters
             for(const auto &p : archetypeAdaptor.getInitialisers().at(v.name).getDerivedParams()) {
+                // If derived parameter is heterogeneous, add scalar field
                 if(std::invoke(isHeterogeneous, m_Group, v.name, p.first)) {
-                    defineScalarField(p.first, v.name + fieldSuffix,
-                                      [p, v](const auto &g, size_t)
-                                      {
-                                          return A(g).getInitialisers().at(v.name).getDerivedParams().at(p.first);
-                                      });
+                    addScalar(p.first, v.name + fieldSuffix,
+                              [p, v](const auto &g, size_t)
+                              {
+                                  return A(g).getInitialisers().at(v.name).getDerivedParams().at(p.first);
+                              });
                 }
+                // Otherwise, just add a const-qualified scalar to the type environment with archetype value
                 else {
-                    defineField(m_Group.getScalarType().addConst(), p.first);
+                    add(m_Group.getScalarType().addConst(), p.first, getScalarString(p.second));
                 }
             }
         }
     }
 
-    template<typename A>
-    void addVars(const std::string &arrayPrefix, const std::string &fieldSuffix = "")
+    template<typename A, typename I>
+    void addVars(const std::string &arrayPrefix, I getIndexFn, const std::string &fieldSuffix = "",
+                 const std::vector<std::string> &dependents = {})
     {
         // Loop through variables
         const A archetypeAdaptor(m_Group.getArchetype());
@@ -314,15 +429,25 @@ class EnvironmentGroupMergedField : public EnvironmentExternalBase
             const auto qualifiedType = (getVarAccessMode(v.access) & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType;
             add(qualifiedType, v.name,
                 resolvedType.createPointer(), v.name + fieldSuffix,
-                [arrayPrefix, v](const auto &g, size_t)
+                [arrayPrefix, getIndexFn, v](const auto &g, size_t)
                 {
                     return arrayPrefix + v.name + A(g).getNameSuffix();
-                });
+                },
+                getIndexFn(v.access, v.name), GroupMergedFieldType::STANDARD, {}, dependents);
         }
     }
 
     template<typename A>
-    void addVarRefs(const std::string &arrayPrefix, const std::string &fieldSuffix = "")
+    void addVars(const std::string &arrayPrefix, const std::string &index, const std::string &fieldSuffix = "",
+                 const std::vector<std::string> &dependents = {})
+    {
+        addVars<A>(arrayPrefix, [&index](VarAccess a, const std::string &) { return index; },
+                   fieldSuffix, dependents);
+    }
+
+    template<typename A, typename I>
+    void addVarRefs(const std::string &arrayPrefix, I getIndexFn, const std::string &fieldSuffix = "",
+                    const std::vector<std::string> &dependents = {})
     {
         // Loop through variable references
         const A archetypeAdaptor(m_Group.getArchetype());
             // If variable access is read-only, qualify type with const
             const auto resolvedType = v.type.resolve(m_Group.getTypeContext());
             const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ?
resolvedType.addConst() : resolvedType; - defineField(qualifiedType, v.name, - resolvedType.createPointer(), v.name + fieldSuffix, - [arrayPrefix, v](const auto &g, size_t) - { - const auto varRef = A(g).getInitialisers().at(v.name); - return arrayPrefix + varRef.getVar().name + varRef.getTargetName(); - }); + add(qualifiedType, v.name, + resolvedType.createPointer(), v.name + fieldSuffix, + [arrayPrefix, v](const auto &g, size_t) + { + const auto varRef = A(g).getInitialisers().at(v.name); + return arrayPrefix + varRef.getVar().name + varRef.getTargetName(); + }, + getIndexFn(v.access, v.name), GroupMergedFieldType::STANDARD, {}, dependents); } } + + template + void addVarRefs(const std::string &arrayPrefix, const std::string &index, const std::string &fieldSuffix = "", + const std::vector &dependents = {}) + { + addVarRefs(arrayPrefix, [&index](VarAccess a, const std::string &) { return index; }, + fieldSuffix, dependents); + } template void addEGPs(const std::string &arrayPrefix, const std::string &varName = "", const std::string &fieldSuffix = "") @@ -353,23 +487,27 @@ class EnvironmentGroupMergedField : public EnvironmentExternalBase { return arrayPrefix + e.name + varName + g.getName(); }, - GroupMergedFieldType::DYNAMIC); + "", GroupMergedFieldType::DYNAMIC); } } private: + //------------------------------------------------------------------------ + // Private API + //------------------------------------------------------------------------ std::string getScalarString(double scalar) const { return (Utils::writePreciseString(scalar, m_GroupMerged.getScalarType().getNumeric().maxDigits10) - + m_GroupMerged.getScalarType().getNumeric().literalSuffix)); + + m_GroupMerged.getScalarType().getNumeric().literalSuffix); } + //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ std::reference_wrapper m_Group; //! 
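// ----------------------------------------------------------------------------
// Editorial sketch: the overload-delegation pattern used by addVars/addVarRefs
// above, as a standalone toy (simplified, VarAccess argument dropped): the
// string-index overload wraps the fixed index in a lambda and forwards to the
// callable overload, so there is a single implementation path.
#include <iostream>
#include <string>

template<typename I>
void addVars(const std::string &arrayPrefix, I getIndexFn)
{
    std::cout << arrayPrefix << "g[" << getIndexFn("V") << "]" << std::endl;
}

void addVars(const std::string &arrayPrefix, const std::string &index)
{
    addVars(arrayPrefix, [&index](const std::string&) { return index; });
}

int main()
{
    addVars("d_", [](const std::string &name) { return name + "Idx"; });  // d_g[VIdx]
    addVars("d_", std::string("idSyn"));                                  // d_g[idSyn]
}
// ----------------------------------------------------------------------------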
Environment mapping names to types to fields to pull values from - std::unordered_map>> m_Environment; + std::unordered_map>> m_Environment; }; diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index 40c107fb7b..967dcc3be0 100644 --- a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -64,75 +64,6 @@ std::vector EnvironmentExternalBase::getContextTypes(const T m_Context); } -//---------------------------------------------------------------------------- -// GeNN::CodeGenerator::EnvironmentExternal -//---------------------------------------------------------------------------- -EnvironmentExternal::~EnvironmentExternal() -{ - // Loop through initialiser - for(const auto &i : m_Initialisers) { - // If variable requiring initialiser has been referenced, write out initialiser - if (i.first) { - getContextStream() << i.second << std::endl; - } - } - - // Write contents to context stream - getContextStream() << m_ContentsStream.str(); -} -//---------------------------------------------------------------------------- -std::vector EnvironmentExternal::getTypes(const Token &name, ErrorHandlerBase &errorHandler) -{ - // If name isn't found in environment - auto env = m_Environment.find(name.lexeme); - if (env == m_Environment.end()) { - return getContextTypes(name, errorHandler); - } - // Otherwise - else { - // If this identifier relies on any initialiser statements, mark these initialisers as required - for(size_t i : std::get<2>(env->second)) { - m_Initialisers.at(i).first = true; - } - - // If this identifier relies on any others, get their types - // **YUCK** - for(const std::string &id : std::get<3>(env->second)) { - getTypes(Token{Token::Type::IDENTIFIER, id, 0}, errorHandler); - } - - // Return type of variables - return {std::get<0>(env->second)}; - } -} -//---------------------------------------------------------------------------- -std::string EnvironmentExternal::getName(const std::string &name, std::optional type) -{ - // If name isn't found in environment - auto env = m_Environment.find(name); - if (env == m_Environment.end()) { - return getContextName(name, type); - } - // Otherwise, return it's value - else { - return std::get<1>(env->second); - } -} -//---------------------------------------------------------------------------- -void EnvironmentExternal::add(const Type::ResolvedType &type, const std::string &name, const std::string &value, - const std::vector &initialisers, const std::vector &dependents) -{ - if(!m_Environment.try_emplace(name, type, value, initialisers, dependents).second) { - throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); - } -} -//------------------------------------------------------------------------ -size_t EnvironmentExternal::addInitialiser(const std::string &initialiser) -{ - m_Initialisers.emplace_back(false, initialiser); - return (m_Initialisers.size() - 1); -} - //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentSubstitute //---------------------------------------------------------------------------- From 5e7e3e742a0534ac34a38f9a9090b30f0610afd1 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 14 Jun 2023 12:12:19 +0100 Subject: [PATCH 216/725] added context to error handler --- include/genn/genn/transpiler/errorHandler.h | 3 ++- src/genn/genn/transpiler/errorHandler.cc | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git 
a/include/genn/genn/transpiler/errorHandler.h b/include/genn/genn/transpiler/errorHandler.h index 279dbfcab1..e752a627fd 100644 --- a/include/genn/genn/transpiler/errorHandler.h +++ b/include/genn/genn/transpiler/errorHandler.h @@ -27,7 +27,7 @@ class ErrorHandlerBase class ErrorHandler : public ErrorHandlerBase { public: - ErrorHandler() : m_Error(false) + explicit ErrorHandler(const std::string &context) : m_Context(context), m_Error(false) { } @@ -51,6 +51,7 @@ class ErrorHandler : public ErrorHandlerBase //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ + std::string m_Context; bool m_Error; }; diff --git a/src/genn/genn/transpiler/errorHandler.cc b/src/genn/genn/transpiler/errorHandler.cc index 11f38cc315..82b23bad1e 100644 --- a/src/genn/genn/transpiler/errorHandler.cc +++ b/src/genn/genn/transpiler/errorHandler.cc @@ -25,7 +25,7 @@ void ErrorHandler::error(const Token &token, std::string_view message) //---------------------------------------------------------------------------- void ErrorHandler::report(size_t line, std::string_view where, std::string_view message) { - LOGE_TRANSPILER << "[line " << line << "] Error" << where << ": " << message; + LOGE_TRANSPILER << "[" << m_Context << ", line " << line << "] Error" << where << ": " << message; m_Error = true; } From 9a45a0ed85a4a6041ffe3dd1381d5a75f5972a69 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 14 Jun 2023 12:12:48 +0100 Subject: [PATCH 217/725] add single-pass transpiler helper to code gen utils --- .../genn/genn/code_generator/codeGenUtils.h | 18 +++++++++++++++ src/genn/genn/code_generator/codeGenUtils.cc | 23 +++++++++++++++++-- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 4f693e2419..701b0789db 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -25,6 +25,12 @@ #include "transpiler/statement.h" #include "transpiler/typeChecker.h" +// Forward declarations +namespace GeNN::CodeGenerator +{ +class EnvironmentExternalBase; +} + //-------------------------------------------------------------------------- // GeNN::CodeGenerator //-------------------------------------------------------------------------- @@ -112,6 +118,18 @@ GENN_EXPORT std::tuple scanParseAndTypeCheckExpression( +std::tuple scanParseAndTypeCheckExpression( const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment, Transpiler::ErrorHandlerBase &errorHandler) { using namespace Transpiler; @@ -526,4 +527,22 @@ GENN_EXPORT std::tuple(expressionTypes), env, typeContext, std::get<1>(expressionTypes)); +} + //-------------------------------------------------------------------------- +void prettyPrintStatements(const std::string &code, const Type::TypeContext &typeContext, EnvironmentExternalBase &env, Transpiler::ErrorHandlerBase &errorHandler) +{ + // Scan, parse and type check statements + auto statementTypes = scanParseAndTypeCheckStatements(code, typeContext, env, errorHandler); + + // Pretty print + Transpiler::PrettyPrinter::print(std::get<0>(statementTypes), env, typeContext, std::get<1>(statementTypes)); +} } // namespace GeNN::CodeGenerator From 1d1cd491809f3edf50f3e7d56878815483fc1906 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 14 Jun 2023 12:22:13 +0100 Subject: [PATCH 218/725] 
start integrating across synapse group merged classes --- .../backends/single_threaded_cpu/backend.h | 2 +- .../genn/genn/code_generator/backendBase.h | 4 +- .../code_generator/synapseUpdateGroupMerged.h | 12 +- .../backends/single_threaded_cpu/backend.cc | 23 +- src/genn/genn/code_generator/backendBase.cc | 2 +- .../synapseUpdateGroupMerged.cc | 201 ++++++++---------- 6 files changed, 111 insertions(+), 133 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 658e26abc4..0d3f1d2b2a 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -186,7 +186,7 @@ class BACKEND_EXPORT Backend : public BackendBase //-------------------------------------------------------------------------- // Private methods //-------------------------------------------------------------------------- - void genPresynapticUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, const Substitutions &popSubs, bool trueSpike) const; + void genPresynapticUpdate(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, bool trueSpike) const; void genEmitSpike(EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng,bool trueSpike, bool recordingEnabled) const; diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index c1f5ad029a..5ba9dd5ecd 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -35,7 +35,7 @@ class SynapseGroupInternal; namespace CodeGenerator { -class EnvironmentExternal; +class EnvironmentExternalBase; class ModelSpecMerged; class NeuronUpdateGroupMerged; class Substitutions; @@ -489,7 +489,7 @@ class GENN_EXPORT BackendBase void genNeuronIndexCalculation(CodeStream &os, const NeuronUpdateGroupMerged &ng, unsigned int batchSize) const; - void genSynapseIndexCalculation(EnvironmentExternal &env, const SynapseGroupMergedBase &sg, unsigned int batchSize) const; + void genSynapseIndexCalculation(EnvironmentExternalBase &env, const SynapseGroupMergedBase &sg, unsigned int batchSize) const; void genCustomUpdateIndexCalculation(CodeStream &os, const CustomUpdateGroupMerged &cu) const; diff --git a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h index 40ae0b56f3..a9ad211483 100644 --- a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h @@ -31,11 +31,11 @@ class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase runnerVarDecl, runnerMergedStructAlloc, name); } - void generateSpikeEventThreshold(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; - void generateSpikeEventUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; - void generateSpikeUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; - void generateProceduralConnectivity(const BackendBase &backend, CodeStream &os, Substitutions &popSubs) const; - void generateToeplitzConnectivity(const BackendBase &backend, CodeStream &os, Substitutions &popSubs) const; + void generateSpikeEventThreshold(const BackendBase &backend, EnvironmentExternalBase &env, 
const ModelSpecMerged &modelMerged) const; + void generateSpikeEventUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const; + void generateSpikeUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const; + void generateProceduralConnectivity(const BackendBase &backend, EnvironmentExternalBase &env) const; + void generateToeplitzConnectivity(const BackendBase &backend, EnvironmentExternalBase &env) const; //---------------------------------------------------------------------------- // Static constants @@ -69,7 +69,7 @@ class GENN_EXPORT PostsynapticUpdateGroupMerged : public SynapseGroupMergedBase runnerVarDecl, runnerMergedStructAlloc, name); } - void generateSynapseUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const; //---------------------------------------------------------------------------- // Static constants diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index d3dfa2d25e..76be6c5a43 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -332,7 +332,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos funcEnv.getStream() << "const auto *group = &mergedSynapseDynamicsGroup" << s.getIndex() << "[g]; " << std::endl; // Create matching environment - EnvironmentGroupMergedField groupEnv(s, funcEnv); + EnvironmentGroupMergedField groupEnv(funcEnv, s); // Add basic fields **TODO** move to group merged groupEnv.add(Type::Uint32.addConst(), "num_pre", @@ -347,7 +347,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // _row_length // _ind - genSynapseIndexCalculation(funcEnv, s, 1); + genSynapseIndexCalculation(groupEnv, s, 1); // Loop through presynaptic neurons groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["num_pre"] << "; i++)"; @@ -410,18 +410,21 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos Timer t(synapseUpdate, "presynapticUpdate", model.isTimingEnabled()); modelMerged.genMergedPresynapticUpdateGroups( *this, - [this, &funcSubs, &synapseUpdate](SynapseDynamicsGroupMerged &s) + [this, &funcEnv, &synapseUpdate](PresynapticUpdateGroupMerged &s) { - CodeStream::Scope b(synapseUpdate); - synapseUpdate << "// merged presynaptic update group " << s.getIndex() << std::endl; - synapseUpdate << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged presynaptic update group " << s.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; { - CodeStream::Scope b(synapseUpdate); + CodeStream::Scope b(funcEnv.getStream()); // Get reference to group - synapseUpdate << "const auto *group = &mergedPresynapticUpdateGroup" << s.getIndex() << "[g]; " << std::endl; + funcEnv.getStream() << "const auto *group = &mergedPresynapticUpdateGroup" << s.getIndex() << "[g]; " << std::endl; + + // Create matching environment + EnvironmentGroupMergedField groupEnv(funcEnv, s); - genSynapseIndexCalculation(synapseUpdate, s, 1); + genSynapseIndexCalculation(groupEnv, s, 1); // generate the code for processing spike-like 
events if (s.getArchetype().isSpikeEventRequired()) { @@ -1664,7 +1667,7 @@ boost::uuids::detail::sha1::digest_type Backend::getHashDigest() const return hash.get_digest(); } //-------------------------------------------------------------------------- -void Backend::genPresynapticUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, const Substitutions &popSubs, bool trueSpike) const +void Backend::genPresynapticUpdate(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, bool trueSpike) const { // Get suffix based on type of events const std::string eventSuffix = trueSpike ? "" : "Evnt"; diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 4c250ed47c..8b1267f59a 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -66,7 +66,7 @@ void BackendBase::genNeuronIndexCalculation(CodeStream &os, const NeuronUpdateGr } } //----------------------------------------------------------------------- -void BackendBase::genSynapseIndexCalculation(EnvironmentExternal &env, const SynapseGroupMergedBase &sg, unsigned int batchSize) const +void BackendBase::genSynapseIndexCalculation(EnvironmentExternalBase &env, const SynapseGroupMergedBase &sg, unsigned int batchSize) const { // If batching is enabled if(batchSize > 1) { diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 2d57740503..b8b9c0de71 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -3,6 +3,9 @@ // GeNN code generator includes #include "code_generator/modelSpecMerged.h" +// GeNN transpiler includes +#include "transpiler/errorHandler.h" + using namespace GeNN; using namespace GeNN::CodeGenerator; @@ -27,40 +30,41 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa synEnv.addEGPs(backend.getDeviceVarPrefix()); // Substitute names of pre and postsynaptic weight update variable - synEnv.addVars(backend.getDeviceVarPrefix()); - synapseSubs.addVarNameSubstitution(wu->getPreVars(), "", "group->", - [&sg, &synapseSubs, batchSize](VarAccess a, const std::string&) - { - return "[" + sg.getPreWUVarIndex(batchSize, getVarAccessDuplication(a), synapseSubs["id_pre"]) + "]"; - }); - - synapseSubs.addVarNameSubstitution(wu->getPostVars(), "", "group->", - [&sg, &synapseSubs, batchSize](VarAccess a, const std::string&) - { - return "[" + sg.getPostWUVarIndex(batchSize, getVarAccessDuplication(a), synapseSubs["id_post"]) + "]"; - }); - - // If this synapse group has a kernel and weights are either procedural and kernel - if (!sg.getArchetype().getKernelSize().empty() && ( - (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::PROCEDURAL) - || (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL))) - { - // Generate kernel index - os << "const unsigned int kernelInd = "; - sg.genKernelIndex(os, synapseSubs); - os << ";" << std::endl; + synEnv.addVars(backend.getDeviceVarPrefix(), + [&sg, &synEnv, batchSize](VarAccess a, const std::string&) + { + return "[" + sg.getPreWUVarIndex(batchSize, getVarAccessDuplication(a), synEnv["id_pre"]) + "]"; + }, + {"id_pre"}); + synEnv.addVars(backend.getDeviceVarPrefix(), + [&sg, &synEnv, batchSize](VarAccess a, const std::string&) + { + return "[" + sg.getPostWUVarIndex(batchSize, getVarAccessDuplication(a), 
synEnv["id_post"]) + "]"; + }, + {"id_post"}); + + + // If this synapse group has a kernel + if (!sg.getArchetype().getKernelSize().empty()) { + // Generate kernel index calculation + std::ostringstream kernelIndexStream; + kernelIndexStream << "const unsigned int kernelInd = "; + sg.genKernelIndex(kernelIndexStream, synEnv); + kernelIndexStream << ";" << std::endl; // Add substitution - synapseSubs.addVarSubstitution("id_kernel", "kernelInd"); + env.add(Type::Uint32, "id_kernel", "kernelInd", + {synEnv.addInitialiser(kernelIndexStream.str())}); } // If weights are individual, substitute variables for values stored in global memory if (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::INDIVIDUAL) { - synapseSubs.addVarNameSubstitution(wu->getVars(), "", "group->", - [&sg, &synapseSubs, batchSize](VarAccess a, const std::string&) - { - return "[" + sg.getSynVarIndex(batchSize, getVarAccessDuplication(a), synapseSubs["id_syn"]) + "]"; - }); + synEnv.addVars(backend.getDeviceVarPrefix(), + [&sg, &synEnv, batchSize](VarAccess a, const std::string&) + { + return "[" + sg.getSynVarIndex(batchSize, getVarAccessDuplication(a), synEnv["id_syn"]) + "]"; + }, + {"id_syn"}); } // Otherwise, if weights are procedual else if (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::PROCEDURAL) { @@ -98,18 +102,18 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa } // Substitute variables for newly-declared local variables - synapseSubs.addVarNameSubstitution(vars, "", "l"); + synEnv.add(vars, "", "l"); } - // Otherwise, if weights are kernels + // Otherwise, if weights are kernels, use kernel index to index into variables else if(sg.getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL) { assert(!sg.getArchetype().getKernelSize().empty()); - // Use kernel index to index into variables - synapseSubs.addVarNameSubstitution(wu->getVars(), "", "group->", - [&sg, &synapseSubs, batchSize](VarAccess a, const std::string&) - { - return "[" + sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), synapseSubs["id_kernel"]) + "]"; - }); + synEnv.addVars(backend.getDeviceVarPrefix(), + [&sg, &synEnv, batchSize](VarAccess a, const std::string&) + { + return "[" + sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), synapseSubs["id_kernel"]) + "]"; + }, + {"id_kernel"}); } // Otherwise, substitute variables for constant values else { @@ -119,7 +123,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa } // Make presynaptic neuron substitutions - const std::string axonalDelayOffset = Utils::writePreciseString(model.getDT() * (double)(sg.getArchetype().getDelaySteps() + 1u)) + " + "; + /*const std::string axonalDelayOffset = Utils::writePreciseString(model.getDT() * (double)(sg.getArchetype().getDelaySteps() + 1u)) + " + "; neuronSubstitutionsInSynapticCode(synapseSubs, sg.getArchetype().getSrcNeuronGroup(), axonalDelayOffset, "_pre", "Pre", "", "", false, [&sg](const std::string &p) { return sg.isSrcNeuronParamHeterogeneous(p); }, @@ -147,10 +151,10 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa [&synapseSubs, &sg, batchSize](bool delay, VarAccessDuplication varDuplication) { return sg.getPostPrevSpikeTimeIndex(delay, batchSize, varDuplication, synapseSubs["id_post"]); - }); + });*/ // If the backend does not support namespaces then we substitute all support code functions with namepsace as prefix - if (!backendSupportsNamespace) { + /*if (!backendSupportsNamespace) { if 
(!wu->getSimSupportCode().empty()) { code = disambiguateNamespaceFunction(wu->getSimSupportCode(), code, modelMerged.getPresynapticUpdateSupportCodeNamespace(wu->getSimSupportCode())); } @@ -160,12 +164,11 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa if (!wu->getSynapseDynamicsSuppportCode().empty()) { code = disambiguateNamespaceFunction(wu->getSynapseDynamicsSuppportCode(), code, modelMerged.getSynapseDynamicsSupportCodeNamespace(wu->getSynapseDynamicsSuppportCode())); } - } + }*/ - synapseSubs.apply(code); - //synapseSubs.applyCheckUnreplaced(code, errorContext + " : " + sg.getName()); - //code = ensureFtype(code, model.getPrecision()); - os << code; + // Pretty print code back to environment + Transpiler::ErrorHandler errorHandler(errorContext + std::to_string(sg.getIndex())); + prettyPrintStatements(code, sg.getTypeContext(), synEnv, errorHandler); } } // Anonymous namespace @@ -174,21 +177,18 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa //---------------------------------------------------------------------------- const std::string PresynapticUpdateGroupMerged::name = "PresynapticUpdate"; //---------------------------------------------------------------------------- -void PresynapticUpdateGroupMerged::generateSpikeEventThreshold(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void PresynapticUpdateGroupMerged::generateSpikeEventThreshold(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const { - Substitutions synapseSubs(&popSubs); + EnvironmentGroupMergedField synEnv(env, *this); - // Make weight update model substitutions - synapseSubs.addParamValueSubstitution(getArchetype().getWUModel()->getParamNames(), getArchetype().getWUParams(), - [this](const std::string &p) { return isWUParamHeterogeneous(p); }, - "", "group->"); - synapseSubs.addVarValueSubstitution(getArchetype().getWUModel()->getDerivedParams(), getArchetype().getWUDerivedParams(), - [this](const std::string &p) { return isWUDerivedParamHeterogeneous(p); }, - "", "group->"); - synapseSubs.addVarNameSubstitution(getArchetype().getWUModel()->getExtraGlobalParams(), "", "group->"); + // Substitute parameter and derived parameter names + const auto *wum = getArchetype().getWUModel(); + synEnv.addParams(wum->getParamNames(), "", &SynapseGroupInternal::getWUParams, &PresynapticUpdateGroupMerged::isWUParamHeterogeneous); + synEnv.addDerivedParams(wum->getDerivedParams(), "", &SynapseGroupInternal::getWUDerivedParams, &PresynapticUpdateGroupMerged::isWUDerivedParamHeterogeneous); + synEnv.addEGPs(backend.getDeviceVarPrefix()); // Substitute in presynaptic neuron properties - const unsigned int batchSize = modelMerged.getModel().getBatchSize(); + /*const unsigned int batchSize = modelMerged.getModel().getBatchSize(); neuronSubstitutionsInSynapticCode(synapseSubs, getArchetype().getSrcNeuronGroup(), "", "_pre", "Pre", "", "", false, [this](const std::string &p) { return isSrcNeuronParamHeterogeneous(p); }, [this](const std::string &p) { return isSrcNeuronDerivedParamHeterogeneous(p); }, @@ -199,40 +199,35 @@ void PresynapticUpdateGroupMerged::generateSpikeEventThreshold(const BackendBase [batchSize, &synapseSubs, this](bool delay, VarAccessDuplication varDuplication) { return getPrePrevSpikeTimeIndex(delay, batchSize, varDuplication, synapseSubs["id_pre"]); - }); - - const auto* wum = getArchetype().getWUModel(); + });*/ - // Get event threshold 
condition code
-    std::string code = wum->getEventThresholdConditionCode();
-    synapseSubs.applyCheckUnreplaced(code, "eventThresholdConditionCode");
-    //code = ensureFtype(code, modelMerged.getModel().getPrecision());
-
-    if (!backend.supportsNamespace() && !wum->getSimSupportCode().empty()) {
-        code = disambiguateNamespaceFunction(wum->getSimSupportCode(), code, modelMerged.getPresynapticUpdateSupportCodeNamespace(wum->getSimSupportCode()));
-    }
-
-    os << code;
+    // Pretty print code back to environment
+    Transpiler::ErrorHandler errorHandler("eventThresholdConditionCode" + std::to_string(getIndex()));
+    prettyPrintStatements(wum->getEventThresholdConditionCode(), getTypeContext(), synEnv, errorHandler);
 }
 //----------------------------------------------------------------------------
-void PresynapticUpdateGroupMerged::generateSpikeEventUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const
+void PresynapticUpdateGroupMerged::generateSpikeEventUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const
 {
-    applySynapseSubstitutions(backend, os, getArchetype().getWUModel()->getEventCode(), "eventCode",
-                              *this, popSubs, modelMerged);
+    applySynapseSubstitutions(backend, env, getArchetype().getWUModel()->getEventCode(), "eventCode",
+                              *this, modelMerged, backend.supportsNamespace());
 }
 //----------------------------------------------------------------------------
-void PresynapticUpdateGroupMerged::generateSpikeUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const
+void PresynapticUpdateGroupMerged::generateSpikeUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const
 {
-    applySynapseSubstitutions(backend, os, getArchetype().getWUModel()->getSimCode(), "simCode",
-                              *this, popSubs, modelMerged);
+    applySynapseSubstitutions(backend, env, getArchetype().getWUModel()->getSimCode(), "simCode",
+                              *this, modelMerged, backend.supportsNamespace());
 }
 //----------------------------------------------------------------------------
-void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendBase&, CodeStream &os, Substitutions &popSubs) const
+void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendBase&, EnvironmentExternalBase &env) const
 {
     const auto &connectInit = getArchetype().getConnectivityInitialiser();
+    EnvironmentGroupMergedField<PresynapticUpdateGroupMerged> synEnv(env, *this);
+
     // Add substitutions
-    popSubs.addFuncSubstitution("endRow", 0, "break");
+    //synEnv.addParams()
+    //synEnv.addParams(wu->getParamNames(), "", &SynapseGroupInternal::getWUParams, &G::isWUParamHeterogeneous);
+    //synEnv.addDerivedParams(wu->getDerivedParams(), "", &SynapseGroupInternal::getWUDerivedParams, &G::isWUDerivedParamHeterogeneous);
-    popSubs.addParamValueSubstitution(connectInit.getSnippet()->getParamNames(), connectInit.getParams(),
-                                      [this](const std::string &p) { return isSparseConnectivityInitParamHeterogeneous(p); },
-                                      "", "group->");
@@ -240,60 +235,40 @@ void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendB
-                                      [this](const std::string &p) { return isSparseConnectivityInitDerivedParamHeterogeneous(p); },
-                                      "", "group->");
-    popSubs.addVarNameSubstitution(connectInit.getSnippet()->getExtraGlobalParams(), "", "group->");
-    popSubs.addVarNameSubstitution(connectInit.getSnippet()->getRowBuildStateVars());
-
-    // Initialise row building state variables for procedural connectivity
-    for(const auto &a : connectInit.getSnippet()->getRowBuildStateVars()) {
-        // Apply substitutions to value
-        std::string value = a.value;
-        popSubs.applyCheckUnreplaced(value, "proceduralSparseConnectivity row build state var : merged" + std::to_string(getIndex()));
-        //value = ensureFtype(value, modelMerged.getModel().getPrecision());
-        os << a.type.resolve(getTypeContext()).getName() << " " << a.name << " = " << value << ";" << std::endl;
-    }
-
-    // Loop through synapses in row
-    os << "while(true)";
-    {
-        CodeStream::Scope b(os);
-        // Apply substitutions to row building code
-        std::string pCode = connectInit.getSnippet()->getRowBuildCode();
-        popSubs.applyCheckUnreplaced(pCode, "proceduralSparseConnectivity : merged " + std::to_string(getIndex()));
-        //pCode = ensureFtype(pCode, modelMerged.getModel().getPrecision());
-
-        // Write out code
-        os << pCode << std::endl;
-    }
+    // Pretty print row building code back to environment
+    Transpiler::ErrorHandler errorHandler("proceduralSparseConnectivity" + std::to_string(getIndex()));
+    prettyPrintStatements(connectInit.getSnippet()->getRowBuildCode(), getTypeContext(), synEnv, errorHandler);
 }
 //----------------------------------------------------------------------------
-void PresynapticUpdateGroupMerged::generateToeplitzConnectivity(const BackendBase&, CodeStream &os, Substitutions &popSubs) const
-{
-    const auto &connectInit = getArchetype().getToeplitzConnectivityInitialiser();
-
-    // Apply substitutions to diagonal building code
-    std::string pCode = connectInit.getSnippet()->getDiagonalBuildCode();
-    popSubs.applyCheckUnreplaced(pCode, "toeplitzSparseConnectivity : merged " + std::to_string(getIndex()));
-    //pCode = ensureFtype(pCode, modelMerged.getModel().getPrecision());
-
-    // Write out code
-    os << pCode << std::endl;
-}
+void PresynapticUpdateGroupMerged::generateToeplitzConnectivity(const BackendBase&, EnvironmentExternalBase &env) const
+{
+    // Pretty print code back to environment
+    Transpiler::ErrorHandler errorHandler("toeplitzSparseConnectivity" + std::to_string(getIndex()));
+    prettyPrintStatements(getArchetype().getToeplitzConnectivityInitialiser().getSnippet()->getDiagonalBuildCode(),
+                          getTypeContext(), env, errorHandler);
+}
 //----------------------------------------------------------------------------
 // GeNN::CodeGenerator::PostsynapticUpdateGroupMerged
 //----------------------------------------------------------------------------
 const std::string PostsynapticUpdateGroupMerged::name = "PostsynapticUpdate";
 //----------------------------------------------------------------------------
-void PostsynapticUpdateGroupMerged::generateSynapseUpdate(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const
+void PostsynapticUpdateGroupMerged::generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const
 {
     const auto *wum = getArchetype().getWUModel();
-    if (!wum->getLearnPostSupportCode().empty() && backend.supportsNamespace()) {
+    /*if (!wum->getLearnPostSupportCode().empty() && backend.supportsNamespace()) {
         os << "using namespace " << modelMerged.getPostsynapticUpdateSupportCodeNamespace(wum->getLearnPostSupportCode()) <<  ";" << std::endl;
-    }
+    }*/

-    applySynapseSubstitutions(os, wum->getLearnPostCode(), "learnPostCode",
-                              *this, popSubs, modelMerged, backend.supportsNamespace());
+    applySynapseSubstitutions(backend, env, wum->getLearnPostCode(), "learnPostCode",
+                              *this, modelMerged, backend.supportsNamespace());
 }
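The pattern above — handing raw model code to `prettyPrintStatements` instead of applying string substitutions — relies on environments that resolve names locally and defer misses to an enclosing scope. A minimal standalone sketch of that lookup chain (plain C++, all names invented; GeNN's real environments additionally carry types, streams and merged-group field registration):

```cpp
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>

// Each scope owns its own name->code mappings and falls back to the
// enclosing scope on a miss, mirroring how the environment classes in this
// patch delegate to the environment they were constructed around.
class Env {
public:
    explicit Env(const Env *enclosing = nullptr) : m_Enclosing(enclosing) {}

    void add(const std::string &name, const std::string &value) { m_Values.emplace(name, value); }

    std::optional<std::string> lookup(const std::string &name) const {
        const auto i = m_Values.find(name);
        if (i != m_Values.cend()) return i->second;
        return m_Enclosing ? m_Enclosing->lookup(name) : std::nullopt;
    }

private:
    const Env *m_Enclosing;
    std::unordered_map<std::string, std::string> m_Values;
};

int main() {
    Env funcEnv;
    funcEnv.add("t", "t");          // defined once per generated function

    Env synEnv(&funcEnv);
    synEnv.add("id_pre", "i");      // defined per synapse group

    // "id_pre" resolves locally, "t" falls through to the enclosing scope
    std::cout << *synEnv.lookup("id_pre") << ", " << *synEnv.lookup("t") << '\n';
}
```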
"synapselearnPostCodeDynamics", + *this, modelMerged, backend.supportsNamespace()); } //---------------------------------------------------------------------------- From 3fa9502d67d042da0e93db116b9172fc6ba8257f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 14 Jun 2023 18:42:34 +0100 Subject: [PATCH 219/725] hackathon continues --- .../backends/single_threaded_cpu/backend.h | 12 +- .../genn/genn/code_generator/backendBase.h | 144 +++++- .../genn/genn/code_generator/environment.h | 35 +- .../genn/code_generator/standardLibrary.h | 22 +- .../backends/single_threaded_cpu/backend.cc | 458 ++++++++++-------- src/genn/genn/code_generator/backendBase.cc | 125 ----- src/genn/genn/code_generator/environment.cc | 41 ++ .../genn/code_generator/standardLibrary.cc | 50 +- 8 files changed, 467 insertions(+), 420 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 0d3f1d2b2a..f9b89a33bc 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -121,12 +121,12 @@ class BACKEND_EXPORT Backend : public BackendBase //! When generating merged structures what type to use for simulation RNGs virtual std::optional getMergedGroupSimRNGType() const final; - virtual void genPopVariableInit(EnvironmentExternal &env, HandlerEnv handler) const final; - virtual void genVariableInit(EnvironmentExternal &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const final; - virtual void genSparseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const final; - virtual void genDenseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const final; - virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const final; - virtual void genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUpdateInitGroupMerged &cu, const Substitutions &kernelSubs, Handler handler) const final; + virtual void genPopVariableInit(EnvironmentExternalBase &env, HandlerEnv handler) const final; + virtual void genVariableInit(EnvironmentExternalBase &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const final; + virtual void genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const final; + virtual void genDenseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const final; + virtual void genKernelSynapseVariableInit(EnvironmentExternalBase &env, const SynapseInitGroupMerged &sg, HandlerEnv handler) const final; + virtual void genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, const CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const final; virtual void genGlobalDeviceRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, MemAlloc &memAlloc) const final; diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 5ba9dd5ecd..e72fa66499 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -35,6 +35,8 @@ class SynapseGroupInternal; namespace CodeGenerator { +template +class EnvironmentGroupMergedField; class EnvironmentExternalBase; class ModelSpecMerged; class NeuronUpdateGroupMerged; @@ -187,13 +189,13 
@@ -187,13 +189,13 @@ class GENN_EXPORT BackendBase
     typedef std::function<void(CodeStream &)> Handler;
-    typedef std::function<void(EnvironmentExternal &)> HandlerEnv;
+    typedef std::function<void(EnvironmentExternalBase &)> HandlerEnv;

     template<typename T>
     using GroupHandler = std::function<void(CodeStream &, const T &, Substitutions &)>;

     template<typename T>
-    using GroupHandlerEnv = std::function<void(EnvironmentExternal &, T &)>;
+    using GroupHandlerEnv = std::function<void(EnvironmentExternalBase &, T &)>;

     //! Vector of prefixes required to allocate in memory space and size of memory space
     typedef std::vector<std::pair<std::string, size_t>> MemorySpaces;
@@ -314,12 +316,12 @@ class GENN_EXPORT BackendBase
     //! When generating merged structures what type to use for simulation RNGs
     virtual std::optional<Type::ResolvedType> getMergedGroupSimRNGType() const = 0;

-    virtual void genPopVariableInit(EnvironmentExternal &env, HandlerEnv handler) const = 0;
-    virtual void genVariableInit(EnvironmentExternal &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const = 0;
-    virtual void genSparseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const = 0;
-    virtual void genDenseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const = 0;
-    virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const = 0;
-    virtual void genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUpdateInitGroupMerged &cu, const Substitutions &kernelSubs, Handler handler) const = 0;
+    virtual void genPopVariableInit(EnvironmentExternalBase &env, HandlerEnv handler) const = 0;
+    virtual void genVariableInit(EnvironmentExternalBase &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const = 0;
+    virtual void genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const = 0;
+    virtual void genDenseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const = 0;
+    virtual void genKernelSynapseVariableInit(EnvironmentExternalBase &env, const SynapseInitGroupMerged &sg, HandlerEnv handler) const = 0;
+    virtual void genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, const CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const = 0;

     //! Generate a single RNG instance
     /*! On single-threaded platforms this can be a standard RNG like M.T.
         but, on parallel platforms, it is likely to be a counter-based RNG */
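The `Handler`/`HandlerEnv` typedefs matter because handlers are how backends and group-merged classes split responsibility: the backend owns the loop structure and the choice of index variable, then calls back into a handler that emits the per-element body against that environment. A runnable toy analogue (not GeNN code; `std::ostream` stands in for `CodeStream` and all names are hypothetical):

```cpp
#include <functional>
#include <iostream>
#include <sstream>
#include <string>

// The backend emits the looping construct; the caller supplies a handler
// that emits the body using whatever index name the backend chose.
using HandlerEnv = std::function<void(std::ostream &, const std::string &indexVar)>;

void genVariableInit(std::ostream &os, const std::string &count, HandlerEnv handler) {
    os << "for (unsigned int i = 0; i < (" << count << "); i++) {\n";
    handler(os, "i");   // backend decided on "i"; the handler just uses it
    os << "}\n";
}

int main() {
    std::ostringstream code;
    genVariableInit(code, "group->numNeurons",
                    [](std::ostream &os, const std::string &idx) {
                        os << "    group->V[" << idx << "] = -60.0f;\n";
                    });
    std::cout << code.str();
}
```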
@@ -487,10 +489,132 @@ class GENN_EXPORT BackendBase
         m_PointerBytes = pointerBytes;
     }

-    void genNeuronIndexCalculation(CodeStream &os, const NeuronUpdateGroupMerged &ng, unsigned int batchSize) const;
+    template<typename G>
+    void genNeuronIndexCalculation(EnvironmentGroupMergedField<G> &env, const NeuronUpdateGroupMerged &ng, unsigned int batchSize) const
+    {
+        // If batching is enabled, calculate batch offset
+        if(batchSize > 1) {
+            env.getStream() << "const unsigned int batchOffset = group->numNeurons * batch;" << std::endl;
+        }
+
+        // If axonal delays are required
+        if(ng.getArchetype().isDelayRequired()) {
+            // We should READ from delay slot before spkQuePtr
+            env.getStream() << "const unsigned int readDelaySlot = (*group->spkQuePtr + " << (ng.getArchetype().getNumDelaySlots() - 1) << ") % " << ng.getArchetype().getNumDelaySlots() << ";" << std::endl;
+            env.getStream() << "const unsigned int readDelayOffset = readDelaySlot * group->numNeurons;" << std::endl;
+
+            // And we should WRITE to delay slot pointed to by spkQuePtr
+            env.getStream() << "const unsigned int writeDelaySlot = *group->spkQuePtr;" << std::endl;
+            env.getStream() << "const unsigned int writeDelayOffset = writeDelaySlot * group->numNeurons;" << std::endl;
+
+            // If batching is also enabled
+            if(batchSize > 1) {
+                // Calculate batched delay slots
+                env.getStream() << "const unsigned int readBatchDelaySlot = (batch * " << ng.getArchetype().getNumDelaySlots() << ") + readDelaySlot;" << std::endl;
+                env.getStream() << "const unsigned int writeBatchDelaySlot = (batch * " << ng.getArchetype().getNumDelaySlots() << ") + writeDelaySlot;" << std::endl;
+
+                // Calculate current batch offset
+                env.getStream() << "const unsigned int batchDelayOffset = batchOffset * " << ng.getArchetype().getNumDelaySlots() << ";" << std::endl;
+
+                // Calculate further offsets to include delay and batch
+                env.getStream() << "const unsigned int readBatchDelayOffset = readDelayOffset + batchDelayOffset;" << std::endl;
+                env.getStream() << "const unsigned int writeBatchDelayOffset = writeDelayOffset + batchDelayOffset;" << std::endl;
+            }
+        }
+    }
+
+    template<typename G>
+    void genSynapseIndexCalculation(EnvironmentGroupMergedField<G> &env, const SynapseGroupMergedBase &sg, unsigned int batchSize) const
+    {
+        // If batching is enabled
+        if(batchSize > 1) {
+            // Calculate batch offsets into pre and postsynaptic populations
+            env.add(Type::Uint32.addConst(), "_pre_batch_offset", "preBatchOffset",
+                    {env.addInitialiser("const unsigned int preBatchOffset = " + env["num_pre"] + " * " + env["batch"] + ";")});
+            env.add(Type::Uint32.addConst(), "_post_batch_offset", "postBatchOffset",
+                    {env.addInitialiser("const unsigned int postBatchOffset = " + env["num_post"] + " * " + env["batch"] + ";")});
+
+            // Calculate batch offsets into synapse arrays, using 64-bit arithmetic if necessary
+            if(areSixtyFourBitSynapseIndicesRequired(sg)) {
+                assert(false);
+                //os << "const uint64_t synBatchOffset = (uint64_t)preBatchOffset * (uint64_t)group->rowStride;" << std::endl;
+            }
+            else {
+                env.add(Type::Uint32.addConst(), "_syn_batch_offset", "synBatchOffset",
+                        {env.addInitialiser("const unsigned int synBatchOffset = " + env["_pre_batch_offset"] + " * " + env["_row_stride"] + ";")});
+            }
+
+            // If synapse group has kernel weights
+            /*const auto &kernelSize = sg.getArchetype().getKernelSize();
+            if((sg.getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL) && !kernelSize.empty()) {
+                // Loop through kernel dimensions and multiply together
+                os << "const unsigned int kernBatchOffset = ";
+                for(size_t i = 0; i < kernelSize.size(); i++) {
+                    os << 
sg.getKernelSize(i) << " * "; + } + + // And finally by batch + os << "batch;" << std::endl; + }*/ + } + + // If presynaptic neuron group has variable queues, calculate offset to read from its variables with axonal delay + /*if(sg.getArchetype().getSrcNeuronGroup()->isDelayRequired()) { + const unsigned int numDelaySteps = sg.getArchetype().getDelaySteps(); + const unsigned int numSrcDelaySlots = sg.getArchetype().getSrcNeuronGroup()->getNumDelaySlots(); + + os << "const unsigned int preDelaySlot = "; + if(numDelaySteps == 0) { + os << "*group->srcSpkQuePtr;" << std::endl; + } + else { + os << "(*group->srcSpkQuePtr + " << (numSrcDelaySlots - numDelaySteps) << ") % " << numSrcDelaySlots << ";" << std::endl; + } + os << "const unsigned int preDelayOffset = preDelaySlot * group->numSrcNeurons;" << std::endl; - void genSynapseIndexCalculation(EnvironmentExternalBase &env, const SynapseGroupMergedBase &sg, unsigned int batchSize) const; + if(batchSize > 1) { + os << "const unsigned int preBatchDelaySlot = preDelaySlot + (batch * " << numSrcDelaySlots << ");" << std::endl; + os << "const unsigned int preBatchDelayOffset = preDelayOffset + (preBatchOffset * " << numSrcDelaySlots << ");" << std::endl; + } + + if(sg.getArchetype().getWUModel()->isPrevPreSpikeTimeRequired() || sg.getArchetype().getWUModel()->isPrevPreSpikeEventTimeRequired()) { + os << "const unsigned int prePrevSpikeTimeDelayOffset = " << "((*group->srcSpkQuePtr + " << (numSrcDelaySlots - numDelaySteps - 1) << ") % " << numSrcDelaySlots << ")" << " * group->numSrcNeurons;" << std::endl; + + if(batchSize > 1) { + os << "const unsigned int prePrevSpikeTimeBatchDelayOffset = prePrevSpikeTimeDelayOffset + (preBatchOffset * " << numSrcDelaySlots << ");" << std::endl; + } + } + } + + // If postsynaptic neuron group has variable queues, calculate offset to read from its variables at current time + if(sg.getArchetype().getTrgNeuronGroup()->isDelayRequired()) { + const unsigned int numBackPropDelaySteps = sg.getArchetype().getBackPropDelaySteps(); + const unsigned int numTrgDelaySlots = sg.getArchetype().getTrgNeuronGroup()->getNumDelaySlots(); + + os << "const unsigned int postDelaySlot = "; + if(numBackPropDelaySteps == 0) { + os << "*group->trgSpkQuePtr;" << std::endl; + } + else { + os << "(*group->trgSpkQuePtr + " << (numTrgDelaySlots - numBackPropDelaySteps) << ") % " << numTrgDelaySlots << ";" << std::endl; + } + os << "const unsigned int postDelayOffset = postDelaySlot * group->numTrgNeurons;" << std::endl; + + if(batchSize > 1) { + os << "const unsigned int postBatchDelaySlot = postDelaySlot + (batch * " << numTrgDelaySlots << ");" << std::endl; + os << "const unsigned int postBatchDelayOffset = postDelayOffset + (postBatchOffset * " << numTrgDelaySlots << ");" << std::endl; + } + if(sg.getArchetype().getWUModel()->isPrevPostSpikeTimeRequired()) { + os << "const unsigned int postPrevSpikeTimeDelayOffset = " << "((*group->trgSpkQuePtr + " << (numTrgDelaySlots - numBackPropDelaySteps - 1) << ") % " << numTrgDelaySlots << ")" << " * group->numTrgNeurons;" << std::endl; + + if(batchSize > 1) { + os << "const unsigned int postPrevSpikeTimeBatchDelayOffset = postPrevSpikeTimeDelayOffset + (postBatchOffset * " << numTrgDelaySlots << ");" << std::endl; + } + + } + }*/ + } void genCustomUpdateIndexCalculation(CodeStream &os, const CustomUpdateGroupMerged &cu) const; void genCustomConnectivityUpdateIndexCalculation(CodeStream &os, const CustomConnectivityUpdateGroupMerged &cu) const; diff --git 
a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index dd7061c1d4..ef7cd88dc1 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -16,6 +16,7 @@ // GeNN transpiler includes #include "transpiler/prettyPrinter.h" +#include "transpiler/token.h" #include "transpiler/typeChecker.h" // Forward declarations @@ -85,6 +86,37 @@ class EnvironmentExternalBase : public Transpiler::PrettyPrinter::EnvironmentBas std::variant, std::reference_wrapper> m_Context; }; +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::EnvironmentLibrary +//---------------------------------------------------------------------------- +class EnvironmentLibrary : public EnvironmentExternalBase +{ +public: + using Library = std::unordered_multimap>; + + EnvironmentLibrary(EnvironmentExternalBase &enclosing, const Library &library) + : EnvironmentExternalBase(enclosing), m_Library(library) + {} + + EnvironmentLibrary(CodeStream &os, const Library &library) + : EnvironmentExternalBase(os), m_Library(library) + {} + + //------------------------------------------------------------------------ + // TypeChecker::EnvironmentBase virtuals + //------------------------------------------------------------------------ + virtual std::vector getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) final; + + //------------------------------------------------------------------------ + // PrettyPrinter::EnvironmentBase virtuals + //------------------------------------------------------------------------ + virtual std::string getName(const std::string &name, std::optional type = std::nullopt) final; + virtual CodeGenerator::CodeStream &getStream() final; + +private: + std::reference_wrapper m_Library; +}; + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentSubstitutionPolicy //---------------------------------------------------------------------------- @@ -214,7 +246,8 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, private P // If this identifier relies on any others, get their types // **YUCK** for(const std::string &id : std::get<2>(env->second)) { - getTypes(Token{Token::Type::IDENTIFIER, id, 0}, errorHandler); + getTypes(Transpiler::Token{Transpiler::Token::Type::IDENTIFIER, id, 0}, + errorHandler); } // Perform any type-specific logic to mark this identifier as required diff --git a/include/genn/genn/code_generator/standardLibrary.h b/include/genn/genn/code_generator/standardLibrary.h index 2bc6c6af8f..cf51594eaa 100644 --- a/include/genn/genn/code_generator/standardLibrary.h +++ b/include/genn/genn/code_generator/standardLibrary.h @@ -1,11 +1,6 @@ #pragma once -// Standard C++ includes -#include -#include - // Code generator includes -#include "code_generator/codeStream.h" #include "code_generator/environment.h" //--------------------------------------------------------------------------- @@ -13,20 +8,5 @@ //--------------------------------------------------------------------------- namespace GeNN::CodeGenerator::StandardLibrary { -class Environment : public EnvironmentExternalBase -{ -public: - using EnvironmentExternalBase::EnvironmentExternalBase; - - //------------------------------------------------------------------------ - // TypeChecker::EnvironmentBase virtuals - //------------------------------------------------------------------------ - virtual 
std::vector<Type::ResolvedType> getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) final;
-
-    //------------------------------------------------------------------------
-    // PrettyPrinter::EnvironmentBase virtuals
-    //------------------------------------------------------------------------
-    virtual std::string getName(const std::string &name, std::optional<Type::ResolvedType> type = std::nullopt) final;
-    virtual CodeGenerator::CodeStream &getStream() final;
-};
+const EnvironmentLibrary::Library &getFunctions();
 }   // namespace GeNN::CodeGenerator::StandardLibrary
diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc
index 76be6c5a43..442cac63ef 100644
--- a/src/genn/backends/single_threaded_cpu/backend.cc
+++ b/src/genn/backends/single_threaded_cpu/backend.cc
@@ -11,6 +11,7 @@
 #include "code_generator/standardLibrary.h"
 #include "code_generator/substitutions.h"

+using namespace GeNN;
 using namespace GeNN::CodeGenerator;
 using namespace GeNN::Transpiler;

@@ -19,22 +20,22 @@ using namespace GeNN::Transpiler;
 //--------------------------------------------------------------------------
 namespace
 {
-const std::vector cpuSinglePrecisionFunctions = {
-    {"gennrand_uniform", 0, "standardUniformDistribution($(rng))"},
-    {"gennrand_normal", 0, "standardNormalDistribution($(rng))"},
-    {"gennrand_exponential", 0, "standardExponentialDistribution($(rng))"},
-    {"gennrand_log_normal", 2, "std::lognormal_distribution($(0), $(1))($(rng))"},
-    {"gennrand_gamma", 1, "std::gamma_distribution($(0), 1.0f)($(rng))"},
-    {"gennrand_binomial", 2, "std::binomial_distribution($(0), $(1))($(rng))"}
+const EnvironmentLibrary::Library cpuSinglePrecisionFunctions = {
+    {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Float, {}), "standardUniformDistribution(hostRNG)"}},
+    {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Float, {}), "standardNormalDistribution(hostRNG)"}},
+    {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Float, {}), "standardExponentialDistribution(hostRNG)"}},
+    {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Float, {}), "std::lognormal_distribution($(0), $(1))(hostRNG)"}},
+    {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Float, {}), "std::gamma_distribution($(0), 1.0f)(hostRNG)"}},
+    {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Float, {}), "std::binomial_distribution($(0), $(1))(hostRNG)"}},
 };
-//--------------------------------------------------------------------------
-const std::vector cpuDoublePrecisionFunctions = {
-    {"gennrand_uniform", 0, "standardUniformDistribution($(rng))"},
-    {"gennrand_normal", 0, "standardNormalDistribution($(rng))"},
-    {"gennrand_exponential", 0, "standardExponentialDistribution($(rng))"},
-    {"gennrand_log_normal", 2, "std::lognormal_distribution($(0), $(1))($(rng))"},
-    {"gennrand_gamma", 1, "std::gamma_distribution($(0), 1.0)($(rng))"},
-    {"gennrand_binomial", 2, "std::binomial_distribution($(0), $(1))($(rng))"}
+
+const EnvironmentLibrary::Library cpuDoublePrecisionFunctions = {
+    {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Double, {}), "standardUniformDistribution(hostRNG)"}},
+    {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Double, {}), "standardNormalDistribution(hostRNG)"}},
+    {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Double, {}), "standardExponentialDistribution(hostRNG)"}},
+    {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Double, {}), "std::lognormal_distribution($(0), $(1))(hostRNG)"}},
+    {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Double, {}), "std::gamma_distribution($(0), 1.0)(hostRNG)"}},
+    {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Double, {}), "std::binomial_distribution($(0), $(1))(hostRNG)"}},
 };
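`EnvironmentLibrary::Library` is a multimap from name to a (type, implementation) pair so that one name can carry several typed overloads: `getTypes` hands all candidates for a name to the type checker, and `getName` picks the implementation whose type matches, both via `equal_range` (see environment.cc later in this patch). A self-contained sketch of that shape (entries invented, types abbreviated to plain strings):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

// Maps a function name to (type signature, implementation snippet);
// a multimap allows several overloads per name, disambiguated by type.
using Library = std::unordered_multimap<std::string, std::pair<std::string, std::string>>;

int main() {
    const Library library{
        {"gennrand_uniform", {"float()", "standardUniformDistribution(hostRNG)"}},
        {"exp", {"float(float)", "expf($(0))"}},
        {"exp", {"double(double)", "exp($(0))"}},
    };

    // All overloads registered under "exp"
    const auto [begin, end] = library.equal_range("exp");
    for (auto i = begin; i != end; ++i) {
        std::cout << i->second.first << " -> " << i->second.second << '\n';
    }
}
```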
"std::lognormal_distribution($(0), $(1))(hostRNG)"}}, + {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Float, {}), "std::gamma_distribution($(0), 1.0)(hostRNG)"}}, + {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Float, {}), "std::binomial_distribution($(0), $(1))(hostRNG)"}}, }; //-------------------------------------------------------------------------- @@ -131,166 +132,186 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host throw std::runtime_error("The single-threaded CPU backend only supports simulations with a batch size of 1"); } - // Generate struct definitions - modelMerged.genMergedNeuronUpdateGroupStructs(os, *this); - modelMerged.genMergedNeuronSpikeQueueUpdateStructs(os, *this); - modelMerged.genMergedNeuronPrevSpikeTimeUpdateStructs(os, *this); - - // Generate arrays of merged structs and functions to set them - genMergedStructArrayPush(os, modelMerged.getMergedNeuronUpdateGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedNeuronSpikeQueueUpdateGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups()); + + // Generate stream with neuron update code + std::ostringstream neuronUpdateStream; + CodeStream neuronUpdate(neuronUpdateStream); - // Generate preamble - preambleHandler(os); + // Begin environment with standard library + EnvironmentLibrary neuronUpdateEnv(neuronUpdate, StandardLibrary::getFunctions()); - os << "void updateNeurons(timepoint t"; + neuronUpdateEnv.getStream() << "void updateNeurons(timepoint t"; if(model.isRecordingInUse()) { - os << ", unsigned int recordingTimestep"; + neuronUpdateEnv.getStream() << ", unsigned int recordingTimestep"; } - os << ")"; + neuronUpdateEnv.getStream() << ")"; { - CodeStream::Scope b(os); + CodeStream::Scope b(neuronUpdateEnv.getStream()); - StandardLibrary::Environment stdEnv(os); - EnvironmentSubstitute funcEnv(stdEnv); - funcEnv.addSubstitution("t", "t"); - funcEnv.addSubstitution("batch", "0"); + EnvironmentExternal funcEnv(neuronUpdateEnv); + funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); + funcEnv.add(Type::Uint32.addConst(), "batch", "0"); Timer t(funcEnv.getStream(), "neuronUpdate", model.isTimingEnabled()); - - // Loop through merged previous spike time update groups - for(const auto &n : modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups()) { - CodeStream::Scope b(funcEnv.getStream()); - funcEnv.getStream() << "// merged neuron prev spike update group " << n.getIndex() << std::endl; - funcEnv.getStream() << "for(unsigned int g = 0; g < " << n.getGroups().size() << "; g++)"; + modelMerged.genMergedNeuronPrevSpikeTimeUpdateGroups( + *this, + [this, &funcEnv, &modelMerged](auto &n) { CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged neuron prev spike update group " << n.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << n.getGroups().size() << "; g++)"; + { + CodeStream::Scope b(funcEnv.getStream()); - // Get reference to group - funcEnv.getStream() << "const auto *group = &mergedNeuronPrevSpikeTimeUpdateGroup" << n.getIndex() << "[g]; " << std::endl; - - if(n.getArchetype().isDelayRequired()) { - // Calculate delay slot corresponding to last timestep - funcEnv.getStream() << "const unsigned int lastTimestepDelaySlot = (*group->spkQuePtr + " << (n.getArchetype().getNumDelaySlots() - 1) << ") % " << n.getArchetype().getNumDelaySlots() << ";" << std::endl; - funcEnv.getStream() << "const unsigned int 
lastTimestepDelayOffset = lastTimestepDelaySlot * group->numNeurons;" << std::endl; - - if(n.getArchetype().isPrevSpikeTimeRequired()) { - // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep - funcEnv.getStream() << "for(unsigned int i = 0; i < group->spkCnt[lastTimestepDelaySlot]; i++)"; - { - CodeStream::Scope b(funcEnv.getStream()); - funcEnv.getStream() << "group->prevST[lastTimestepDelayOffset + group->spk[lastTimestepDelayOffset + i]] = t - DT;" << std::endl; + // Get reference to group + funcEnv.getStream() << "const auto *group = &mergedNeuronPrevSpikeTimeUpdateGroup" << n.getIndex() << "[g]; " << std::endl; + + // Create matching environment + EnvironmentGroupMergedField groupEnv(funcEnv, n); + + if(n.getArchetype().isDelayRequired()) { + // Calculate delay slot corresponding to last timestep + groupEnv.getStream() << "const unsigned int lastTimestepDelaySlot = (*group->spkQuePtr + " << (n.getArchetype().getNumDelaySlots() - 1) << ") % " << n.getArchetype().getNumDelaySlots() << ";" << std::endl; + groupEnv.getStream() << "const unsigned int lastTimestepDelayOffset = lastTimestepDelaySlot * group->numNeurons;" << std::endl; + + if(n.getArchetype().isPrevSpikeTimeRequired()) { + // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep + groupEnv.getStream() << "for(unsigned int i = 0; i < group->spkCnt[lastTimestepDelaySlot]; i++)"; + { + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "group->prevST[lastTimestepDelayOffset + group->spk[lastTimestepDelayOffset + i]] = t - DT;" << std::endl; + } } - } - if(n.getArchetype().isPrevSpikeEventTimeRequired()) { - // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep - funcEnv.getStream() << "for(unsigned int i = 0; i < group->spkCntEvnt[lastTimestepDelaySlot]; i++)"; - { - CodeStream::Scope b(funcEnv.getStream()); - funcEnv.getStream() << "group->prevSET[lastTimestepDelayOffset + group->spkEvnt[lastTimestepDelayOffset + i]] = t - DT;" << std::endl; + if(n.getArchetype().isPrevSpikeEventTimeRequired()) { + // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep + groupEnv.getStream() << "for(unsigned int i = 0; i < group->spkCntEvnt[lastTimestepDelaySlot]; i++)"; + { + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "group->prevSET[lastTimestepDelayOffset + group->spkEvnt[lastTimestepDelayOffset + i]] = t - DT;" << std::endl; + } } } - } - else { - if(n.getArchetype().isPrevSpikeTimeRequired()) { - // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep - funcEnv.getStream() << "for(unsigned int i = 0; i < group->spkCnt[0]; i++)"; - { - CodeStream::Scope b(funcEnv.getStream()); - funcEnv.getStream() << "group->prevST[group->spk[i]] = t - DT;" << std::endl; + else { + if(n.getArchetype().isPrevSpikeTimeRequired()) { + // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep + groupEnv.getStream() << "for(unsigned int i = 0; i < group->spkCnt[0]; i++)"; + { + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "group->prevST[group->spk[i]] = t - DT;" << std::endl; + } } - } - if(n.getArchetype().isPrevSpikeEventTimeRequired()) { - // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep - funcEnv.getStream() << "for(unsigned int i = 0; i < 
group->spkCntEvnt[0]; i++)"; - { - CodeStream::Scope b(funcEnv.getStream()); - funcEnv.getStream() << "group->prevSET[group->spkEvnt[i]] = t - DT;" << std::endl; + if(n.getArchetype().isPrevSpikeEventTimeRequired()) { + // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep + groupEnv.getStream() << "for(unsigned int i = 0; i < group->spkCntEvnt[0]; i++)"; + { + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "group->prevSET[group->spkEvnt[i]] = t - DT;" << std::endl; + } } } } - } - } + }); // Loop through merged neuron spike queue update groups - for(const auto &n : modelMerged.getMergedNeuronSpikeQueueUpdateGroups()) { - CodeStream::Scope b(funcEnv.getStream()); - funcEnv.getStream() << "// merged neuron spike queue update group " << n.getIndex() << std::endl; - funcEnv.getStream() << "for(unsigned int g = 0; g < " << n.getGroups().size() << "; g++)"; + modelMerged.genMergedNeuronSpikeQueueUpdateGroups( + *this, + [this, &funcEnv, &modelMerged](auto &n) { CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged neuron spike queue update group " << n.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << n.getGroups().size() << "; g++)"; + { + CodeStream::Scope b(funcEnv.getStream()); - // Get reference to group - funcEnv.getStream() << "const auto *group = &mergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << "[g]; " << std::endl; + // Get reference to group + funcEnv.getStream() << "const auto *group = &mergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << "[g]; " << std::endl; + EnvironmentGroupMergedField groupEnv(funcEnv, n); + + // Generate spike count reset + n.genMergedGroupSpikeCountReset(groupEnv.getStream(), 1); + } + }); - // Generate spike count reset - n.genMergedGroupSpikeCountReset(funcEnv.getStream(), 1); - } - - } // Loop through merged neuron update groups - for(const auto &n : modelMerged.getMergedNeuronUpdateGroups()) { - CodeStream::Scope b(funcEnv.getStream()); - funcEnv.getStream() << "// merged neuron update group " << n.getIndex() << std::endl; - funcEnv.getStream() << "for(unsigned int g = 0; g < " << n.getGroups().size() << "; g++)"; + modelMerged.genMergedNeuronUpdateGroups( + *this, + [this, &funcEnv, &modelMerged](auto &n) { CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged neuron update group " << n.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << n.getGroups().size() << "; g++)"; + { + CodeStream::Scope b(funcEnv.getStream()); - // Get reference to group - funcEnv.getStream() << "const auto *group = &mergedNeuronUpdateGroup" << n.getIndex() << "[g]; " << std::endl; + // Get reference to group + funcEnv.getStream() << "const auto *group = &mergedNeuronUpdateGroup" << n.getIndex() << "[g]; " << std::endl; + EnvironmentGroupMergedField groupEnv(funcEnv, n); - // If spike or spike-like event recording is in use - if(n.getArchetype().isSpikeRecordingEnabled() || n.getArchetype().isSpikeEventRecordingEnabled()) { - // Calculate number of words which will be used to record this population's spikes - funcEnv.getStream() << "const unsigned int numRecordingWords = (group->numNeurons + 31) / 32;" << std::endl; + // If spike or spike-like event recording is in use + if(n.getArchetype().isSpikeRecordingEnabled() || n.getArchetype().isSpikeEventRecordingEnabled()) { + // Calculate number of words which will be used to record this population's spikes + groupEnv.getStream() << "const 
unsigned int numRecordingWords = (group->numNeurons + 31) / 32;" << std::endl;
-
-                // Zero spike recording buffer
-                if(n.getArchetype().isSpikeRecordingEnabled()) {
-                    funcEnv.getStream() << "std::fill_n(&group->recordSpk[recordingTimestep * numRecordingWords], numRecordingWords, 0);" << std::endl;
-                }
+                        // Zero spike recording buffer
+                        if(n.getArchetype().isSpikeRecordingEnabled()) {
+                            groupEnv.getStream() << "std::fill_n(&group->recordSpk[recordingTimestep * numRecordingWords], numRecordingWords, 0);" << std::endl;
+                        }

-                // Zero spike-like-event recording buffer
-                if(n.getArchetype().isSpikeEventRecordingEnabled()) {
-                    funcEnv.getStream() << "std::fill_n(&group->recordSpkEvent[recordingTimestep * numRecordingWords], numRecordingWords, 0);" << std::endl;
+                        // Zero spike-like-event recording buffer
+                        if(n.getArchetype().isSpikeEventRecordingEnabled()) {
+                            groupEnv.getStream() << "std::fill_n(&group->recordSpkEvent[recordingTimestep * numRecordingWords], numRecordingWords, 0);" << std::endl;
+                        }
                     }
-            }

-            genNeuronIndexCalculation(funcEnv.getStream(), n, 1);
-            funcEnv.getStream() << std::endl;
+                    genNeuronIndexCalculation(groupEnv, n, 1);
+                    groupEnv.getStream() << std::endl;

-            funcEnv.getStream() << "for(unsigned int i = 0; i < group->numNeurons; i++)";
-            {
-                CodeStream::Scope b(funcEnv.getStream());
+                    groupEnv.getStream() << "for(unsigned int i = 0; i < group->numNeurons; i++)";
+                    {
+                        CodeStream::Scope b(groupEnv.getStream());

-                EnvironmentSubstitute popEnv(funcEnv);
-                popEnv.addSubstitution("id", "i");
+                        groupEnv.add(Type::Uint32, "id", "i");

-                // If this neuron group requires a simulation RNG, substitute in global RNG
-                if(n.getArchetype().isSimRNGRequired()) {
-                    popEnv.addSubstitution("rng", "hostRNG");
-                }
+                        // Add RNG library
+                        EnvironmentLibrary rngEnv(groupEnv, (modelMerged.getModel().getPrecision() == Type::Float) ? cpuSinglePrecisionFunctions : cpuDoublePrecisionFunctions);
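The `add(type, name, value, {initialisers}, {dependencies})` calls used throughout this patch register index calculations that are only emitted when the name is actually referenced, so an unused offset never appears in the generated code. A standalone toy of the lazy-emission idea (plain C++, no type checking or dependency tracking, all names invented):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Each registered name carries the statements needed to define it; a
// statement is printed only the first time the name is used.
struct Entry {
    std::string value;
    std::vector<std::string> initialisers;
    bool emitted = false;
};

int main() {
    std::unordered_map<std::string, Entry> env{
        {"id_syn", {"idSyn", {"const unsigned int idSyn = (i * rowStride) + j;"}}}};

    auto use = [&](const std::string &name) {
        Entry &e = env.at(name);
        if (!e.emitted) {
            for (const auto &i : e.initialisers) std::cout << i << '\n';
            e.emitted = true;
        }
        return e.value;
    };

    const std::string syn = use("id_syn");              // emits the initialiser once
    std::cout << "g[" << syn << "] += inSyn;\n";
    std::cout << "d[" << use("id_syn") << "]--;\n";     // already emitted, no duplicate
}
```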
-                n.generateNeuronUpdate(*this, popEnv, modelMerged,
-                                       // Emit true spikes
-                                       [&modelMerged, this](EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng)
-                                       {
-                                           // Insert code to update WU vars
-                                           ng.generateWUVarUpdate(*this, env, modelMerged);
-
-                                           // Insert code to emit true spikes
-                                           genEmitSpike(env, ng, true, ng.getArchetype().isSpikeRecordingEnabled());
-                                       },
-                                       // Emit spike-like events
-                                       [this](EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng)
-                                       {
-                                           // Insert code to emit spike-like events
-                                           genEmitSpike(env, ng, false, ng.getArchetype().isSpikeEventRecordingEnabled());
-                                       });
+                        // Generate neuron update
+                        n.generateNeuronUpdate(*this, rngEnv, modelMerged,
+                                               // Emit true spikes
+                                               [&modelMerged, this](EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng)
+                                               {
+                                                   // Insert code to update WU vars
+                                                   ng.generateWUVarUpdate(*this, env, modelMerged);
+
+                                                   // Insert code to emit true spikes
+                                                   genEmitSpike(env, ng, true, ng.getArchetype().isSpikeRecordingEnabled());
+                                               },
+                                               // Emit spike-like events
+                                               [this](EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng)
+                                               {
+                                                   // Insert code to emit spike-like events
+                                                   genEmitSpike(env, ng, false, ng.getArchetype().isSpikeEventRecordingEnabled());
+                                               });
+                    }
                 }
-            }
-        }
+            });
     }
+
+    // Generate struct definitions
+    modelMerged.genMergedNeuronUpdateGroupStructs(os, *this);
+    modelMerged.genMergedNeuronSpikeQueueUpdateStructs(os, *this);
+    modelMerged.genMergedNeuronPrevSpikeTimeUpdateStructs(os, *this);
+
+    // Generate arrays of merged structs and functions to set them
+    genMergedStructArrayPush(os, modelMerged.getMergedNeuronUpdateGroups());
+    genMergedStructArrayPush(os, modelMerged.getMergedNeuronSpikeQueueUpdateGroups());
+    genMergedStructArrayPush(os, modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups());
+
+    // Generate preamble
+    preambleHandler(os);
+
+    os << neuronUpdateStream.str();
+
 }
 //--------------------------------------------------------------------------
 void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const
@@ -305,7 +326,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos
     CodeStream synapseUpdate(synapseUpdateStream);

     // Begin environment with standard library
-    StandardLibrary::Environment synapseUpdateEnv(synapseUpdate);
+    EnvironmentLibrary synapseUpdateEnv(synapseUpdate, StandardLibrary::getFunctions());

     synapseUpdateEnv.getStream() << "void updateSynapses(timepoint t)";
     {
@@ -317,10 +338,10 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos

         // Synapse dynamics
         {
-            Timer t(synapseUpdate, "synapseDynamics", model.isTimingEnabled());
+            Timer t(funcEnv.getStream(), "synapseDynamics", model.isTimingEnabled());
             modelMerged.genMergedSynapseDynamicsGroups(
                 *this,
-                [this, &funcEnv, &modelMerged, &synapseUpdate](SynapseDynamicsGroupMerged &s)
+                [this, &funcEnv, &modelMerged](auto &s)
                 {
                     CodeStream::Scope b(funcEnv.getStream());
                     funcEnv.getStream() << "// merged synapse dynamics group " << s.getIndex() << std::endl;
@@ -365,40 +386,39 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos
                         throw std::runtime_error("Only DENSE and SPARSE format connectivity can be used for synapse dynamics");
                     }
                     {
-                        EnvironmentExternal synEnv(groupEnv);
-                        CodeStream::Scope b(synEnv.getStream());
+                        CodeStream::Scope b(groupEnv.getStream());

                         // Add presynaptic index to substitutions
-                        synEnv.add(Type::Uint32.addConst(), "id_pre", 
"i"); + groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); if (s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // Add initialiser strings to calculate synaptic and presynaptic index - const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["_row_stride"] + ") + s;"); - const size_t idPostInit = synEnv.addInitialiser("const unsigned int idPost = " + synEnv["_ind"] + "[idSyn];"); + const size_t idSynInit = groupEnv.addInitialiser("const unsigned int idSyn = (i * " + groupEnv["_row_stride"] + ") + s;"); + const size_t idPostInit = groupEnv.addInitialiser("const unsigned int idPost = " + groupEnv["_ind"] + "[idSyn];"); // **TODO** id_syn can be 64-bit - synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}, {"_row_stride"}); - synEnv.add(Type::Uint32.addConst(), "id_post", "idPost", {idPostInit, idSynInit}, {"_ind"}); + groupEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}, {"_row_stride"}); + groupEnv.add(Type::Uint32.addConst(), "id_post", "idPost", {idPostInit, idSynInit}, {"_ind"}); } else { // Add postsynaptic index to substitutions - synEnv.add(Type::Uint32.addConst(), "id_post", "j"); + groupEnv.add(Type::Uint32.addConst(), "id_post", "j"); // Add initialiser to calculate synaptic index - const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["num_post"] + ") + j;"); + const size_t idSynInit = groupEnv.addInitialiser("const unsigned int idSyn = (i * " + groupEnv["num_post"] + ") + j;"); // **TODO** id_syn can be 64-bit - synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}); + groupEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}); } // Add correct functions for apply synaptic input - synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", synEnv["_den_delay"] + "[" + s.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)"); - synEnv.add(Type::AddToPost, "addToPost", synEnv["_out_post"] + "[" + s.getPostISynIndex(1, "j") + "] += $(0)"); - synEnv.add(Type::AddToPre, "addToPre", synEnv["_out_pre"] + "[" + s.getPreISynIndex(1, synEnv["id_pre"]) + "] += $(0)"); + groupEnv.add(Type::AddToPostDenDelay, "addToPostDelay", groupEnv["_den_delay"] + "[" + s.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)"); + groupEnv.add(Type::AddToPost, "addToPost", groupEnv["_out_post"] + "[" + s.getPostISynIndex(1, "j") + "] += $(0)"); + groupEnv.add(Type::AddToPre, "addToPre", groupEnv["_out_pre"] + "[" + s.getPreISynIndex(1, groupEnv["id_pre"]) + "] += $(0)"); // Call synapse dynamics handler - s.generateSynapseUpdate(*this, synEnv, modelMerged); + s.generateSynapseUpdate(*this, groupEnv, modelMerged); } } } @@ -407,10 +427,10 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // Presynaptic update { - Timer t(synapseUpdate, "presynapticUpdate", model.isTimingEnabled()); + Timer t(funcEnv.getStream(), "presynapticUpdate", model.isTimingEnabled()); modelMerged.genMergedPresynapticUpdateGroups( *this, - [this, &funcEnv, &synapseUpdate](PresynapticUpdateGroupMerged &s) + [this, &funcEnv](auto &s) { CodeStream::Scope b(funcEnv.getStream()); funcEnv.getStream() << "// merged presynaptic update group " << s.getIndex() << std::endl; @@ -428,85 +448,94 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // generate the code for processing spike-like events if (s.getArchetype().isSpikeEventRequired()) { - genPresynapticUpdate(synapseUpdate, modelMerged, s, funcSubs, false); + 
genPresynapticUpdate(groupEnv, modelMerged, s, false); } // generate the code for processing true spike events if (s.getArchetype().isTrueSpikeRequired()) { - genPresynapticUpdate(synapseUpdate, modelMerged, s, funcSubs, true); + genPresynapticUpdate(groupEnv, modelMerged, s, true); } - synapseUpdate << std::endl; + funcEnv.getStream() << std::endl; } }); } // Postsynaptic update { - Timer t(synapseUpdate, "postsynapticUpdate", model.isTimingEnabled()); + Timer t(funcEnv.getStream(), "postsynapticUpdate", model.isTimingEnabled()); modelMerged.genMergedPostsynapticUpdateGroups( *this, - [this, &funcSubs, &synapseUpdate](SynapseDynamicsGroupMerged &s) + [this, &funcEnv](auto &s) { - CodeStream::Scope b(synapseUpdate); - synapseUpdate << "// merged postsynaptic update group " << s.getIndex() << std::endl; - synapseUpdate << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged postsynaptic update group " << s.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; { - CodeStream::Scope b(synapseUpdate); + CodeStream::Scope b(funcEnv.getStream()); // Get reference to group - synapseUpdate << "const auto *group = &mergedPostsynapticUpdateGroup" << s.getIndex() << "[g]; " << std::endl; + funcEnv.getStream() << "const auto *group = &mergedPostsynapticUpdateGroup" << s.getIndex() << "[g]; " << std::endl; + + // Create matching environment + EnvironmentGroupMergedField groupEnv(funcEnv, s); - genSynapseIndexCalculation(synapseUpdate, s, 1); + genSynapseIndexCalculation(groupEnv, s, 1); // Get number of postsynaptic spikes if (s.getArchetype().getTrgNeuronGroup()->isDelayRequired() && s.getArchetype().getTrgNeuronGroup()->isTrueSpikeRequired()) { - synapseUpdate << "const unsigned int numSpikes = group->trgSpkCnt[postDelaySlot];" << std::endl; + groupEnv.getStream() << "const unsigned int numSpikes = group->trgSpkCnt[postDelaySlot];" << std::endl; } else { - synapseUpdate << "const unsigned int numSpikes = group->trgSpkCnt[0];" << std::endl; + groupEnv.getStream() << "const unsigned int numSpikes = group->trgSpkCnt[0];" << std::endl; } // Loop through postsynaptic spikes - synapseUpdate << "for (unsigned int j = 0; j < numSpikes; j++)"; + groupEnv.getStream() << "for (unsigned int j = 0; j < numSpikes; j++)"; { - CodeStream::Scope b(synapseUpdate); + CodeStream::Scope b(groupEnv.getStream()); + // **TODO** prod types const std::string offsetTrueSpkPost = (s.getArchetype().getTrgNeuronGroup()->isTrueSpikeRequired() && s.getArchetype().getTrgNeuronGroup()->isDelayRequired()) ? 
"postDelayOffset + " : ""; - synapseUpdate << "const unsigned int spike = group->trgSpk[" << offsetTrueSpkPost << "j];" << std::endl; + groupEnv.getStream() << "const unsigned int spike = group->trgSpk[" << offsetTrueSpkPost << "j];" << std::endl; // Loop through column of presynaptic neurons if (s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - synapseUpdate << "const unsigned int npre = group->colLength[spike];" << std::endl; - synapseUpdate << "for (unsigned int i = 0; i < npre; i++)"; + groupEnv.getStream() << "const unsigned int npre = group->colLength[spike];" << std::endl; + groupEnv.getStream() << "for (unsigned int i = 0; i < npre; i++)"; } else { - synapseUpdate << "for (unsigned int i = 0; i < group->numSrcNeurons; i++)"; + groupEnv.getStream() << "for (unsigned int i = 0; i < " << groupEnv["num_pre"] << "; i++)"; } { - CodeStream::Scope b(synapseUpdate); + CodeStream::Scope b(groupEnv.getStream()); - Substitutions synSubs(&funcSubs); if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - synapseUpdate << "const unsigned int colMajorIndex = (spike * group->colStride) + i;" << std::endl; - synapseUpdate << "const unsigned int rowMajorIndex = group->remap[colMajorIndex];" << std::endl; - + // Add initialisers to calculate column and row-major indices // **TODO** fast divide optimisations - synSubs.addVarSubstitution("id_pre", "(rowMajorIndex / group->rowStride)"); - synSubs.addVarSubstitution("id_syn", "rowMajorIndex"); + const size_t colMajorIdxInit = groupEnv.addInitialiser("const unsigned int colMajorIndex = (spike * " + groupEnv["_col_stride"] + ") + i;"); + const size_t rowMajorIdxInit = groupEnv.addInitialiser("const unsigned int rowMajorIndex = " + groupEnv["_remap"] + "[colMajorIndex];"); + const size_t idPreInit = groupEnv.addInitialiser("const unsigned int idPre = rowMajorIndex / " + groupEnv["_row_stride"] + ";"); + + // Add presynaptic and synapse index to environment + groupEnv.add("id_pre", "idPre", {colMajorIdxInit, rowMajorIdxInit, idPreInit}, {"_col_stride", "_row_stride", "_remap"}); + groupEnv.add("id_syn", "rowMajorIndex", {colMajorIdxInit, rowMajorIdxInit}, {"_col_stride", "_remap"}); } else { - synSubs.addVarSubstitution("id_pre", "i"); - synSubs.addVarSubstitution("id_syn", "((group->numTrgNeurons * i) + spike)"); - } - synSubs.addVarSubstitution("id_post", "spike"); - if (s.getArchetype().isPresynapticOutputRequired()) { - synSubs.addFuncSubstitution("addToPre", 1, "group->revInSyn[" + s.getPreISynIndex(1, synSubs["id_pre"]) + "] += $(0)"); + // Add initialiser to calculate synaptic index + const size_t idSynInit = groupEnv.addInitialiser("const unsigned int idSyn = (i * " + groupEnv["num_post"] + ") + spike;"); + + // Add presynaptic and synapse index to environment + groupEnv.add(Type::Uint32, "id_pre", "i"); + groupEnv.add(Type::Uint32, "id_syn", "idSyn", {idSynInit}, {"num_post"}); } + + groupEnv.add(Type::Uint32, "id_post", "spike"); + groupEnv.add(Type::AddToPre, "addToPre", groupEnv["_out_pre"] + "[" + s.getPreISynIndex(1, groupEnv["id_pre"]) + "] += $(0)"); - s.generateSynapseUpdate(*this, synapseUpdate, modelMerged, synSubs); + s.generateSynapseUpdate(*this, groupEnv, modelMerged); } } - synapseUpdate << std::endl; + groupEnv.getStream() << std::endl; } }); } @@ -1426,57 +1455,64 @@ std::optional Backend::getMergedGroupSimRNGType() const return std::nullopt; } //-------------------------------------------------------------------------- -void Backend::genPopVariableInit(EnvironmentExternal &env, 
-void Backend::genPopVariableInit(EnvironmentExternal &env, HandlerEnv handler) const
+void Backend::genPopVariableInit(EnvironmentExternalBase &env, HandlerEnv handler) const
 {
     handler(env);
 }
 //--------------------------------------------------------------------------
-void Backend::genVariableInit(EnvironmentExternal &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const
+void Backend::genVariableInit(EnvironmentExternalBase &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const
 {
     // **TODO** loops like this should be generated like CUDA threads
-    env.getStream() << "for (unsigned i = 0; i < (" << count << "); i++)";
+    env.getStream() << "for (unsigned int i = 0; i < (" << count << "); i++)";
     {
         CodeStream::Scope b(env.getStream());

-        EnvironmentSubstitute varSubs(env);
-        varSubs.addSubstitution(indexVarName, "i");
-        handler(varSubs);
+        EnvironmentExternal varEnv(env);
+        varEnv.add(Type::Uint32, indexVarName, "i");
+        handler(varEnv);
     }
 }
 //--------------------------------------------------------------------------
-void Backend::genSparseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const
+void Backend::genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const
 {
-    os << "for (unsigned j = 0; j < group->rowLength[" << kernelSubs["id_pre"] << "]; j++)";
+    env.getStream() << "for (unsigned int j = 0; j < group->rowLength[" << env["id_pre"] << "]; j++)";
     {
-        CodeStream::Scope b(os);
+        CodeStream::Scope b(env.getStream());

-        Substitutions varSubs(&kernelSubs);
-        varSubs.addVarSubstitution("id_syn", "(" + kernelSubs["id_pre"] + " * group->rowStride) + j");
-        varSubs.addVarSubstitution("id_post", "group->ind[(" + kernelSubs["id_pre"] + " * group->rowStride) + j]");
-        handler(os, varSubs);
+        EnvironmentExternal varEnv(env);
+        // **TODO** 64-bit
+        varEnv.add(Type::Uint32, "id_syn", "idSyn",
+                   {varEnv.addInitialiser("const unsigned int idSyn = (" + varEnv["id_pre"] + " * " + varEnv["_row_stride"] + ") + j;")},
+                   {"id_pre", "_row_stride"});
+        varEnv.add(Type::Uint32, "id_post", "idPost",
+                   {varEnv.addInitialiser("const unsigned int idPost = " + varEnv["_ind"] + "[(" + varEnv["id_pre"] + " * " + varEnv["_row_stride"] + ") + j];")},
+                   {"id_pre", "_row_stride", "_ind"});
+        handler(varEnv);
     }
 }
 //--------------------------------------------------------------------------
-void Backend::genDenseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const
+void Backend::genDenseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const
 {
-    os << "for (unsigned j = 0; j < group->numTrgNeurons; j++)";
+    env.getStream() << "for (unsigned int j = 0; j < " << env["num_post"] << "; j++)";
     {
-        CodeStream::Scope b(os);
+        CodeStream::Scope b(env.getStream());

-        Substitutions varSubs(&kernelSubs);
-        varSubs.addVarSubstitution("id_syn", "(" + kernelSubs["id_pre"] + " * group->rowStride) + j");
-        varSubs.addVarSubstitution("id_post", "j");
-        handler(os, varSubs);
+        EnvironmentExternal varEnv(env);
+        // **TODO** 64-bit
+        varEnv.add(Type::Uint32, "id_syn", "idSyn",
+                   {varEnv.addInitialiser("const unsigned int idSyn = (" + varEnv["id_pre"] + " * " + varEnv["_row_stride"] + ") + j;")},
+                   {"id_pre", "_row_stride"});
+        varEnv.add(Type::Uint32, "id_post", "j");
+        handler(varEnv);
     }
 }
 //--------------------------------------------------------------------------
-void
Backend::genKernelSynapseVariableInit(EnvironmentExternalBase &env, const SynapseInitGroupMerged &sg, HandlerEnv handler) const { assert(false); //genKernelIteration(os, sg, sg.getArchetype().getKernelSize().size(), kernelSubs, handler); } //-------------------------------------------------------------------------- -void Backend::genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUpdateInitGroupMerged &cu, const Substitutions &kernelSubs, Handler handler) const +void Backend::genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, const CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const { assert(false); //genKernelIteration(os, cu, cu.getArchetype().getSynapseGroup()->getKernelSize().size(), kernelSubs, handler); diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 8b1267f59a..57c257e270 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -33,131 +33,6 @@ bool BackendBase::areSixtyFourBitSynapseIndicesRequired(const SynapseGroupMerged return ((maxSynapses & 0xFFFFFFFF00000000ULL) != 0); } //----------------------------------------------------------------------- -void BackendBase::genNeuronIndexCalculation(CodeStream &os, const NeuronUpdateGroupMerged &ng, unsigned int batchSize) const -{ - // If batching is enabled, calculate batch offset - if(batchSize > 1) { - os << "const unsigned int batchOffset = group->numNeurons * batch;" << std::endl; - } - - // If axonal delays are required - if(ng.getArchetype().isDelayRequired()) { - // We should READ from delay slot before spkQuePtr - os << "const unsigned int readDelaySlot = (*group->spkQuePtr + " << (ng.getArchetype().getNumDelaySlots() - 1) << ") % " << ng.getArchetype().getNumDelaySlots() << ";" << std::endl; - os << "const unsigned int readDelayOffset = readDelaySlot * group->numNeurons;" << std::endl; - - // And we should WRITE to delay slot pointed to be spkQuePtr - os << "const unsigned int writeDelaySlot = *group->spkQuePtr;" << std::endl; - os << "const unsigned int writeDelayOffset = writeDelaySlot * group->numNeurons;" << std::endl; - - // If batching is also enabled - if(batchSize > 1) { - // Calculate batched delay slots - os << "const unsigned int readBatchDelaySlot = (batch * " << ng.getArchetype().getNumDelaySlots() << ") + readDelaySlot;" << std::endl; - os << "const unsigned int writeBatchDelaySlot = (batch * " << ng.getArchetype().getNumDelaySlots() << ") + writeDelaySlot;" << std::endl; - - // Calculate current batch offset - os << "const unsigned int batchDelayOffset = batchOffset * " << ng.getArchetype().getNumDelaySlots() << ";" << std::endl; - - // Calculate further offsets to include delay and batch - os << "const unsigned int readBatchDelayOffset = readDelayOffset + batchDelayOffset;" << std::endl; - os << "const unsigned int writeBatchDelayOffset = writeDelayOffset + batchDelayOffset;" << std::endl; - } - } -} -//----------------------------------------------------------------------- -void BackendBase::genSynapseIndexCalculation(EnvironmentExternalBase &env, const SynapseGroupMergedBase &sg, unsigned int batchSize) const -{ - // If batching is enabled - if(batchSize > 1) { - // Calculate batch offsets into pre and postsynaptic populations - env.add(Type::Uint32.addConst(), "_pre_batch_offset", "preBatchOffset", - {env.addInitialiser("const unsigned int preBatchOffset = " + env["num_pre"] + " * " + env["batch"] + ";")}); - env.add(Type::Uint32.addConst(), 
"_post_batch_offset", "postBatchOffset", - {env.addInitialiser("const unsigned int preBatchOffset = " + env["num_post"] + " * " + env["batch"] + ";")}); - - // Calculate batch offsets into synapse arrays, using 64-bit arithmetic if necessary - if(areSixtyFourBitSynapseIndicesRequired(sg)) { - assert(false); - //os << "const uint64_t synBatchOffset = (uint64_t)preBatchOffset * (uint64_t)group->rowStride;" << std::endl; - } - else { - env.add(Type::Uint32.addConst(), "_syn_batch_offset", "synBatchOffset", - {env.addInitialiser("const unsigned int synBatchOffset = " + env["_pre_batch_offset"] + " * " + env["_row_stride"] + ";")}); - } - - // If synapse group has kernel weights - /*const auto &kernelSize = sg.getArchetype().getKernelSize(); - if((sg.getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL) && !kernelSize.empty()) { - // Loop through kernel dimensions and multiply together - os << "const unsigned int kernBatchOffset = "; - for(size_t i = 0; i < kernelSize.size(); i++) { - os << sg.getKernelSize(i) << " * "; - } - - // And finally by batch - os << "batch;" << std::endl; - } - } - - // If presynaptic neuron group has variable queues, calculate offset to read from its variables with axonal delay - if(sg.getArchetype().getSrcNeuronGroup()->isDelayRequired()) { - const unsigned int numDelaySteps = sg.getArchetype().getDelaySteps(); - const unsigned int numSrcDelaySlots = sg.getArchetype().getSrcNeuronGroup()->getNumDelaySlots(); - - os << "const unsigned int preDelaySlot = "; - if(numDelaySteps == 0) { - os << "*group->srcSpkQuePtr;" << std::endl; - } - else { - os << "(*group->srcSpkQuePtr + " << (numSrcDelaySlots - numDelaySteps) << ") % " << numSrcDelaySlots << ";" << std::endl; - } - os << "const unsigned int preDelayOffset = preDelaySlot * group->numSrcNeurons;" << std::endl; - - if(batchSize > 1) { - os << "const unsigned int preBatchDelaySlot = preDelaySlot + (batch * " << numSrcDelaySlots << ");" << std::endl; - os << "const unsigned int preBatchDelayOffset = preDelayOffset + (preBatchOffset * " << numSrcDelaySlots << ");" << std::endl; - } - - if(sg.getArchetype().getWUModel()->isPrevPreSpikeTimeRequired() || sg.getArchetype().getWUModel()->isPrevPreSpikeEventTimeRequired()) { - os << "const unsigned int prePrevSpikeTimeDelayOffset = " << "((*group->srcSpkQuePtr + " << (numSrcDelaySlots - numDelaySteps - 1) << ") % " << numSrcDelaySlots << ")" << " * group->numSrcNeurons;" << std::endl; - - if(batchSize > 1) { - os << "const unsigned int prePrevSpikeTimeBatchDelayOffset = prePrevSpikeTimeDelayOffset + (preBatchOffset * " << numSrcDelaySlots << ");" << std::endl; - } - } - } - - // If postsynaptic neuron group has variable queues, calculate offset to read from its variables at current time - if(sg.getArchetype().getTrgNeuronGroup()->isDelayRequired()) { - const unsigned int numBackPropDelaySteps = sg.getArchetype().getBackPropDelaySteps(); - const unsigned int numTrgDelaySlots = sg.getArchetype().getTrgNeuronGroup()->getNumDelaySlots(); - - os << "const unsigned int postDelaySlot = "; - if(numBackPropDelaySteps == 0) { - os << "*group->trgSpkQuePtr;" << std::endl; - } - else { - os << "(*group->trgSpkQuePtr + " << (numTrgDelaySlots - numBackPropDelaySteps) << ") % " << numTrgDelaySlots << ";" << std::endl; - } - os << "const unsigned int postDelayOffset = postDelaySlot * group->numTrgNeurons;" << std::endl; - - if(batchSize > 1) { - os << "const unsigned int postBatchDelaySlot = postDelaySlot + (batch * " << numTrgDelaySlots << ");" << std::endl; - os << "const unsigned 
int postBatchDelayOffset = postDelayOffset + (postBatchOffset * " << numTrgDelaySlots << ");" << std::endl; - } - - if(sg.getArchetype().getWUModel()->isPrevPostSpikeTimeRequired()) { - os << "const unsigned int postPrevSpikeTimeDelayOffset = " << "((*group->trgSpkQuePtr + " << (numTrgDelaySlots - numBackPropDelaySteps - 1) << ") % " << numTrgDelaySlots << ")" << " * group->numTrgNeurons;" << std::endl; - - if(batchSize > 1) { - os << "const unsigned int postPrevSpikeTimeBatchDelayOffset = postPrevSpikeTimeDelayOffset + (postBatchOffset * " << numTrgDelaySlots << ");" << std::endl; - } - - } - }*/ -} -//----------------------------------------------------------------------- void BackendBase::genCustomUpdateIndexCalculation(CodeStream &os, const CustomUpdateGroupMerged &cu) const { // If batching is enabled, calculate batch offset diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index 967dcc3be0..0b33316570 100644 --- a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -64,6 +64,47 @@ std::vector EnvironmentExternalBase::getContextTypes(const T m_Context); } +//--------------------------------------------------------------------------- +// GeNN::CodeGenerator::EnvironmentLibrary +//--------------------------------------------------------------------------- +std::vector EnvironmentLibrary::getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) +{ + const auto [typeBegin, typeEnd] = m_Library.get().equal_range(name.lexeme); + if (typeBegin == typeEnd) { + errorHandler.error(name, "Undefined identifier"); + throw TypeChecker::TypeCheckError(); + } + else { + std::vector types; + types.reserve(std::distance(typeBegin, typeEnd)); + std::transform(typeBegin, typeEnd, std::back_inserter(types), + [](const auto &t) { return t.second.first; }); + return types; + } +} +//--------------------------------------------------------------------------- +std::string EnvironmentLibrary::getName(const std::string &name, std::optional type) +{ + const auto [libTypeBegin, libTypeEnd] = m_Library.get().equal_range(name); + if (libTypeBegin == libTypeEnd) { + return getContextName(name, type); + } + else { + if (!type) { + throw std::runtime_error("Ambiguous reference to '" + name + "' but no type provided to disambiguate"); + } + const auto libType = std::find_if(libTypeBegin, libTypeEnd, + [type](const auto &t){ return t.second.first == type; }); + assert(libType != libTypeEnd); + return libType->second.second; + } +} +//--------------------------------------------------------------------------- +CodeStream &EnvironmentLibrary::getStream() +{ + return getContextStream(); +} + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentSubstitute //---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/standardLibrary.cc b/src/genn/genn/code_generator/standardLibrary.cc index 1fb50de5ef..a09db367b7 100644 --- a/src/genn/genn/code_generator/standardLibrary.cc +++ b/src/genn/genn/code_generator/standardLibrary.cc @@ -7,13 +7,6 @@ // GeNN includes #include "type.h" -// Transpiler includes -#include "transpiler/errorHandler.h" -#include "transpiler/typeChecker.h" - -using namespace GeNN::CodeGenerator; -using namespace GeNN::CodeGenerator::StandardLibrary; -using namespace GeNN::Transpiler::TypeChecker; namespace Type = GeNN::Type; 
//--------------------------------------------------------------------------- @@ -39,7 +32,7 @@ namespace template auto initLibraryTypes(Args&&... args) { - std::unordered_multimap> map; + GeNN::CodeGenerator::EnvironmentLibrary::Library map; (map.emplace(std::forward(args)), ...); return map; } @@ -129,43 +122,8 @@ const auto libraryTypes = initLibraryTypes( */ //min, max, printf -//--------------------------------------------------------------------------- -// GeNN::Transpiler::StandardLibrary::Environment -//--------------------------------------------------------------------------- -std::vector Environment::getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) -{ - const auto [typeBegin, typeEnd] = libraryTypes.equal_range(name.lexeme); - if (typeBegin == typeEnd) { - errorHandler.error(name, "Undefined identifier"); - throw TypeCheckError(); - } - else { - std::vector types; - types.reserve(std::distance(typeBegin, typeEnd)); - std::transform(typeBegin, typeEnd, std::back_inserter(types), - [](const auto &t) { return t.second.first; }); - return types; - } -} -//--------------------------------------------------------------------------- -std::string Environment::getName(const std::string &name, std::optional type) + +const GeNN::CodeGenerator::EnvironmentLibrary::Library &GeNN::CodeGenerator::StandardLibrary::getFunctions() { - const auto [libTypeBegin, libTypeEnd] = libraryTypes.equal_range(name); - if (libTypeBegin == libTypeEnd) { - return getContextName(name, type); - } - else { - if (!type) { - throw std::runtime_error("Ambiguous reference to '" + name + "' but no type provided to disambiguate"); - } - const auto libType = std::find_if(libTypeBegin, libTypeEnd, - [type](const auto &t){ return t.second.first == type; }); - assert(libType != libTypeEnd); - return libType->second.second; - } + return libraryTypes; } -//--------------------------------------------------------------------------- -CodeStream &Environment::getStream() -{ - return getContextStream(); -} \ No newline at end of file From 1e40569f1e22ace61167675fe1774db35824e32d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 15 Jun 2023 10:04:58 +0100 Subject: [PATCH 220/725] started stripping out SynapeGroupMergedBase --- .../genn/genn/code_generator/groupMerged.h | 27 -- src/genn/genn/code_generator/groupMerged.cc | 350 ------------------ 2 files changed, 377 deletions(-) diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 3d3066ad8d..c5a362baaa 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -802,38 +802,16 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged> &groups); - //---------------------------------------------------------------------------- // Protected methods //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type getHashDigest(Role role) const; - const std::string &getArchetypeCode() const { return m_ArchetypeCode; } private: //------------------------------------------------------------------------ // Private methods //------------------------------------------------------------------------ - void addPSPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix); - void addPreOutputPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix); - void addSrcPointerField(const Type::ResolvedType 
&type, const std::string &name, const std::string &prefix); - void addTrgPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix); - std::string getVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index, const std::string &prefix) const; @@ -862,10 +840,5 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged> &groups) -: GroupMerged(index, typeContext, groups), m_ArchetypeCode(archetypeCode) -{ - using namespace Type; - - const bool updateRole = ((role == Role::PresynapticUpdate) - || (role == Role::PostsynapticUpdate) - || (role == Role::SynapseDynamics)); - const WeightUpdateModels::Base *wum = getArchetype().getWUModel(); - - // If role isn't an init role or weights aren't kernel - if(role != Role::Init || !(getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL)) { - addField(Uint32, "rowStride", - [&backend](const SynapseGroupInternal &sg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(sg)); }); - addField(Uint32, "numSrcNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); - addField(Uint32, "numTrgNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); }); - } - - if(role == Role::PostsynapticUpdate || role == Role::SparseInit) { - addField(Uint32, "colStride", - [](const auto &sg, size_t) { return std::to_string(sg.getMaxSourceConnections()); }); - } - - // If this role is one where postsynaptic input can be provided - if(role == Role::PresynapticUpdate || role == Role::SynapseDynamics) { - if(getArchetype().isDendriticDelayRequired()) { - addPSPointerField(getScalarType(), "denDelay", backend.getDeviceVarPrefix() + "denDelay"); - addPSPointerField(Uint32, "denDelayPtr", backend.getScalarAddressPrefix() + "denDelayPtr"); - } - else { - addPSPointerField(getScalarType(), "inSyn", backend.getDeviceVarPrefix() + "inSyn"); - } - } - - if(role == Role::PresynapticUpdate) { - if(getArchetype().isTrueSpikeRequired()) { - addSrcPointerField(Uint32, "srcSpkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); - addSrcPointerField(Uint32, "srcSpk", backend.getDeviceVarPrefix() + "glbSpk"); - } - - if(getArchetype().isSpikeEventRequired()) { - addSrcPointerField(Uint32, "srcSpkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); - addSrcPointerField(Uint32, "srcSpkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt"); - } - } - else if(role == Role::PostsynapticUpdate) { - addTrgPointerField(Uint32, "trgSpkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); - addTrgPointerField(Uint32, "trgSpk", backend.getDeviceVarPrefix() + "glbSpk"); - } - - // If this structure is used for updating rather than initializing - if(updateRole) { - // for all types of roles - if (getArchetype().isPresynapticOutputRequired()) { - addPreOutputPointerField(getScalarType(), "revInSyn", backend.getDeviceVarPrefix() + "revInSyn"); - } - - // If presynaptic population has delay buffers - if(getArchetype().getSrcNeuronGroup()->isDelayRequired()) { - addSrcPointerField(Uint32, "srcSpkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); - } - - // If postsynaptic population has delay buffers - if(getArchetype().getTrgNeuronGroup()->isDelayRequired()) { - addTrgPointerField(Uint32, "trgSpkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); - } - - // Add heterogeneous presynaptic neuron model parameters - addHeterogeneousParams( - 
getArchetype().getSrcNeuronGroup()->getNeuronModel()->getParamNames(), "Pre", - [](const auto &sg) { return sg.getSrcNeuronGroup()->getParams(); }, - &SynapseGroupMergedBase::isSrcNeuronParamHeterogeneous); - - // Add heterogeneous presynaptic neuron model derived parameters - addHeterogeneousDerivedParams( - getArchetype().getSrcNeuronGroup()->getNeuronModel()->getDerivedParams(), "Pre", - [](const auto &sg) { return sg.getSrcNeuronGroup()->getDerivedParams(); }, - &SynapseGroupMergedBase::isSrcNeuronDerivedParamHeterogeneous); - - // Add heterogeneous postsynaptic neuron model parameters - addHeterogeneousParams( - getArchetype().getTrgNeuronGroup()->getNeuronModel()->getParamNames(), "Post", - [](const auto &sg) { return sg.getTrgNeuronGroup()->getParams(); }, - &SynapseGroupMergedBase::isTrgNeuronParamHeterogeneous); - - // Add heterogeneous postsynaptic neuron model derived parameters - addHeterogeneousDerivedParams( - getArchetype().getTrgNeuronGroup()->getNeuronModel()->getDerivedParams(), "Post", - [](const auto &sg) { return sg.getTrgNeuronGroup()->getDerivedParams(); }, - &SynapseGroupMergedBase::isTrgNeuronDerivedParamHeterogeneous); - - // Get correct code string - const std::string code = getArchetypeCode(); - - // Loop through variables in presynaptic neuron model - const auto preVars = getArchetype().getSrcNeuronGroup()->getNeuronModel()->getVars(); - for(const auto &v : preVars) { - // If variable is referenced in code string, add source pointer - if(code.find("$(" + v.name + "_pre)") != std::string::npos) { - addSrcPointerField(v.type.resolve(getTypeContext()), v.name + "Pre", - backend.getDeviceVarPrefix() + v.name); - } - } - - // Loop through variables in postsynaptic neuron model - const auto postVars = getArchetype().getTrgNeuronGroup()->getNeuronModel()->getVars(); - for(const auto &v : postVars) { - // If variable is referenced in code string, add target pointer - if(code.find("$(" + v.name + "_post)") != std::string::npos) { - addTrgPointerField(v.type.resolve(getTypeContext()), v.name + "Post", - backend.getDeviceVarPrefix() + v.name); - } - } - - // Loop through extra global parameters in presynaptic neuron model - const auto preEGPs = getArchetype().getSrcNeuronGroup()->getNeuronModel()->getExtraGlobalParams(); - for(const auto &e : preEGPs) { - if(code.find("$(" + e.name + "_pre)") != std::string::npos) { - const std::string prefix = backend.getDeviceVarPrefix(); - addField(e.type.resolve(getTypeContext()).createPointer(), e.name + "Pre", - [e, prefix](const auto &sg, size_t) { return prefix + e.name + sg.getSrcNeuronGroup()->getName(); }, - GroupMergedFieldType::DYNAMIC); - } - } - - // Loop through extra global parameters in postsynaptic neuron model - const auto postEGPs = getArchetype().getTrgNeuronGroup()->getNeuronModel()->getExtraGlobalParams(); - for(const auto &e : postEGPs) { - if(code.find("$(" + e.name + "_post)") != std::string::npos) { - const std::string prefix = backend.getDeviceVarPrefix(); - addField(e.type.resolve(getTypeContext()).createPointer(), e.name + "Post", - [e, prefix](const auto &sg, size_t) { return prefix + e.name + sg.getTrgNeuronGroup()->getName(); }, - GroupMergedFieldType::DYNAMIC); - } - } - - // Add spike times if required - if(wum->isPreSpikeTimeRequired()) { - addSrcPointerField(getTimeType(), "sTPre", backend.getDeviceVarPrefix() + "sT"); - } - if(wum->isPostSpikeTimeRequired()) { - addTrgPointerField(getTimeType(), "sTPost", backend.getDeviceVarPrefix() + "sT"); - } - if(wum->isPreSpikeEventTimeRequired()) { - 
addSrcPointerField(getTimeType(), "seTPre", backend.getDeviceVarPrefix() + "seT"); - } - if(wum->isPrevPreSpikeTimeRequired()) { - addSrcPointerField(getTimeType(), "prevSTPre", backend.getDeviceVarPrefix() + "prevST"); - } - if(wum->isPrevPostSpikeTimeRequired()) { - addTrgPointerField(getTimeType(), "prevSTPost", backend.getDeviceVarPrefix() + "prevST"); - } - if(wum->isPrevPreSpikeEventTimeRequired()) { - addSrcPointerField(getTimeType(), "prevSETPre", backend.getDeviceVarPrefix() + "prevSET"); - } - // Add heterogeneous weight update model parameters - addHeterogeneousParams( - wum->getParamNames(), "", - [](const auto &sg) { return sg.getWUParams(); }, - &SynapseGroupMergedBase::isWUParamHeterogeneous); - - // Add heterogeneous weight update model derived parameters - addHeterogeneousDerivedParams( - wum->getDerivedParams(), "", - [](const auto &sg) { return sg.getWUDerivedParams(); }, - &SynapseGroupMergedBase::isWUDerivedParamHeterogeneous); - - // Add presynaptic variables to struct - for(const auto &v : wum->getPreVars()) { - const std::string prefix = backend.getDeviceVarPrefix() + v.name; - addField(v.type.resolve(getTypeContext()).createPointer(), v.name, - [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPreVarSuffix(); }); - } - - // Add presynaptic variables to struct - for(const auto &v : wum->getPostVars()) { - const std::string prefix = backend.getDeviceVarPrefix() + v.name; - addField(v.type.resolve(getTypeContext()).createPointer(), v.name, - [prefix](const auto &g, size_t) { return prefix + g.getFusedWUPostVarSuffix(); }); - } - - // Add EGPs to struct - addEGPs(wum->getExtraGlobalParams(), backend.getDeviceVarPrefix()); - } - - // Add pointers to connectivity data - if(getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - addPointerField(Uint32, "rowLength", backend.getDeviceVarPrefix() + "rowLength"); - addPointerField(getArchetype().getSparseIndType(), "ind", backend.getDeviceVarPrefix() + "ind"); - - // Add additional structure for postsynaptic access - if(backend.isPostsynapticRemapRequired() && !wum->getLearnPostCode().empty() - && (role == Role::PostsynapticUpdate || role == Role::SparseInit)) - { - addPointerField(Uint32, "colLength", backend.getDeviceVarPrefix() + "colLength"); - addPointerField(Uint32, "remap", backend.getDeviceVarPrefix() + "remap"); - } - } - else if(getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { - addPointerField(Uint32, "gp", backend.getDeviceVarPrefix() + "gp"); - } - - // If we're updating a group with procedural connectivity or initialising connectivity - if((getArchetype().getMatrixType() & SynapseMatrixConnectivity::PROCEDURAL) || (role == Role::ConnectivityInit)) { - // Add heterogeneous sparse connectivity initialiser model parameters - addHeterogeneousParams( - getArchetype().getConnectivityInitialiser().getSnippet()->getParamNames(), "", - [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getParams(); }, - &SynapseGroupMergedBase::isSparseConnectivityInitParamHeterogeneous); - - - // Add heterogeneous sparse connectivity initialiser derived parameters - addHeterogeneousDerivedParams( - getArchetype().getConnectivityInitialiser().getSnippet()->getDerivedParams(), "", - [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getDerivedParams(); }, - &SynapseGroupMergedBase::isSparseConnectivityInitDerivedParamHeterogeneous); - - addEGPs(getArchetype().getConnectivityInitialiser().getSnippet()->getExtraGlobalParams(), - 
backend.getDeviceVarPrefix()); - } - - // If we're updating a group with Toeplitz connectivity - if(updateRole && (getArchetype().getMatrixType() & SynapseMatrixConnectivity::TOEPLITZ)) { - // Add heterogeneous toeplitz connectivity initialiser model parameters - addHeterogeneousParams( - getArchetype().getToeplitzConnectivityInitialiser().getSnippet()->getParamNames(), "", - [](const SynapseGroupInternal &sg) { return sg.getToeplitzConnectivityInitialiser().getParams(); }, - &SynapseGroupMergedBase::isToeplitzConnectivityInitParamHeterogeneous); - - - // Add heterogeneous toeplitz initialiser derived parameters - addHeterogeneousDerivedParams( - getArchetype().getToeplitzConnectivityInitialiser().getSnippet()->getDerivedParams(), "", - [](const SynapseGroupInternal &sg) { return sg.getToeplitzConnectivityInitialiser().getDerivedParams(); }, - &SynapseGroupMergedBase::isToeplitzConnectivityInitDerivedParamHeterogeneous); - - addEGPs(getArchetype().getToeplitzConnectivityInitialiser().getSnippet()->getExtraGlobalParams(), - backend.getDeviceVarPrefix()); - } - - // If WU variables are global - const auto &varInit = getArchetype().getWUVarInitialisers(); - if(getArchetype().getMatrixType() & SynapseMatrixWeight::GLOBAL) { - // If this is an update role - // **NOTE **global variable values aren't useful during initialization - if(updateRole) { - for(const auto &var : wum->getVars()) { - // If variable should be implemented heterogeneously, add scalar field - if(isWUGlobalVarHeterogeneous(var.name)) { - addScalarField(var.name, - [var](const SynapseGroupInternal &sg, size_t) - { - return sg.getWUConstInitVals().at(var.name); - }); - } - } - } - } - // Otherwise (weights are individual or procedural) - else { - const bool connectInitRole = (role == Role::ConnectivityInit); - const bool varInitRole = (role == Role::Init || role == Role::SparseInit); - const bool proceduralWeights = (getArchetype().getMatrixType() & SynapseMatrixWeight::PROCEDURAL); - const bool kernelWeights = (getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL); - const bool individualWeights = (getArchetype().getMatrixType() & SynapseMatrixWeight::INDIVIDUAL); - - // If synapse group has a kernel and has kernel weights or initialising individual weights - if(!getArchetype().getKernelSize().empty() && ((proceduralWeights && updateRole) || kernelWeights || (connectInitRole && individualWeights))) { - // Loop through kernel size dimensions - for(size_t d = 0; d < getArchetype().getKernelSize().size(); d++) { - // If this dimension has a heterogeneous size, add it to struct - if(isKernelSizeHeterogeneous(d)) { - addField(Uint32, "kernelSize" + std::to_string(d), - [d](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getKernelSize().at(d)); }); - } - } - } - - // If weights are procedural, we're initializing individual variables or we're initialising variables in a kernel - // **NOTE** some of these won't actually be required - could do this per-variable in loop over vars - if((proceduralWeights && updateRole) || (connectInitRole && !getArchetype().getKernelSize().empty()) - || (varInitRole && (individualWeights || kernelWeights))) - { - // Add heterogeneous variable initialization parameters and derived parameters - addHeterogeneousVarInitParams( - &SynapseGroupMergedBase::isWUVarInitParamHeterogeneous); - - addHeterogeneousVarInitDerivedParams( - &SynapseGroupMergedBase::isWUVarInitDerivedParamHeterogeneous); - } - - // Loop through variables - for(const auto &var : wum->getVars()) { - // Variable 
initialisation is required if we're performing connectivity init and var init snippet requires a kernel or - // We're performing some other sort of initialisation, the snippet DOESN'T require a kernel but has SOME code - const auto *snippet = varInit.at(var.name).getSnippet(); - const bool varInitRequired = ((connectInitRole && snippet->requiresKernel()) - || (varInitRole && individualWeights && !snippet->requiresKernel() && !snippet->getCode().empty()) - || (varInitRole && kernelWeights && !snippet->getCode().empty())); - - // If we're performing an update with individual weights; or this variable should be initialised - if((updateRole && individualWeights) || (kernelWeights && updateRole) || varInitRequired) { - addPointerField(var.type, var.name, - backend.getDeviceVarPrefix() + var.name); - } - - // If we're performing a procedural update or this variable should be initialised, add any var init EGPs to structure - if((proceduralWeights && updateRole) || varInitRequired) { - const auto egps = snippet->getExtraGlobalParams(); - for(const auto &e : egps) { - const std::string prefix = backend.getDeviceVarPrefix(); - addField(e.type.resolve(getTypeContext()).createPointer(), e.name + var.name, - [e, prefix, var](const SynapseGroupInternal &sg, size_t) - { - return prefix + e.name + var.name + sg.getName(); - }, - GroupMergedFieldType::DYNAMIC); - } - } - } - } -} -//---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type SynapseGroupMergedBase::getHashDigest(Role role) const { const bool updateRole = ((role == Role::PresynapticUpdate) @@ -769,30 +443,6 @@ boost::uuids::detail::sha1::digest_type SynapseGroupMergedBase::getHashDigest(Ro return hash.get_digest(); } //---------------------------------------------------------------------------- -void SynapseGroupMergedBase::addPSPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix) -{ - assert(type.isValue()); - addField(type.createPointer(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPSVarSuffix(); }); -} -//---------------------------------------------------------------------------- -void SynapseGroupMergedBase::addPreOutputPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix) -{ - assert(type.isValue()); - addField(type.createPointer(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getFusedPreOutputSuffix(); }); -} -//---------------------------------------------------------------------------- -void SynapseGroupMergedBase::addSrcPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix) -{ - assert(type.isValue()); - addField(type.createPointer(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getSrcNeuronGroup()->getName(); }); -} -//---------------------------------------------------------------------------- -void SynapseGroupMergedBase::addTrgPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix) -{ - assert(type.isValue()); - addField(type.createPointer(), name, [prefix](const SynapseGroupInternal &sg, size_t) { return prefix + sg.getTrgNeuronGroup()->getName(); }); -} -//---------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index, const std::string 
&prefix) const { From 3c7958aa1cae3fa9b9f404c86218ae5120dbe9d7 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 15 Jun 2023 10:06:29 +0100 Subject: [PATCH 221/725] expand genSynapseIndexCalculation to include all generic fields --- .../genn/genn/code_generator/backendBase.h | 116 +++++++++++++++--- .../backends/single_threaded_cpu/backend.cc | 85 ++++++------- 2 files changed, 135 insertions(+), 66 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index e72fa66499..9ac3536d7a 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -526,13 +526,74 @@ class GENN_EXPORT BackendBase template void genSynapseIndexCalculation(EnvironmentGroupMergedField &env, const SynapseGroupMergedBase &sg, unsigned int batchSize) const { + // Synapse group fields + groupEnv.add(Type::Uint32.addConst(), "num_pre", + Type::Uint32, "numSrcNeurons", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); + groupEnv.add(Type::Uint32.addConst(), "num_post", + Type::Uint32, "numTrgNeurons", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); + groupEnv.add(Type::Uint32.addConst(), "_row_stride", + Type::Uint32, "rowStride", + [this](const SynapseGroupInternal &sg, size_t) { return std::to_string(getSynapticMatrixRowStride(sg)); }); + groupEnv.add(Type::Uint32.addConst(), "_col_stride", + Type::Uint32, "colStride", + [this](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getMaxSourceConnections()); }); + + // Postsynaptic model fields + groupEnv.add(modelMerged.getModel().getPrecision(), "_out_post", + modelMerged.getModel().getPrecision().createPointer(), "outPost", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); + groupEnv.add(modelMerged.getModel().getPrecision(), "_den_delay", + modelMerged.getModel().getPrecision().createPointer(), "denDelay", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); + groupEnv.add(Type::Uint32, "_den_delay_ptr", + Type::Uint32.createPointer(), "denDelayPtr", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); + + // Presynaptic output fields + groupEnv.add(modelMerged.getModel().getPrecision(), "_out_pre", + modelMerged.getModel().getPrecision().createPointer(), "outPre", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); + + + // Source neuron fields + groupEnv.add(Type::Uint32, "_src_spk_que_ptr", + Type::Uint32.createPointer(), "srcSpkQuePtr", + [&backend](const auto &g, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + g.getSrcNeuronGroup()->getName(); }); + groupEnv.add(Type::Uint32, "_src_spk_cnt", + Type::Uint32.createPointer(), "srcSpkCnt", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkCnt" + g.getSrcNeuronGroup()->getName(); }); + groupEnv.add(Type::Uint32, "_src_spk", + Type::Uint32.createPointer(), "srcSpk", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpk" + g.getSrcNeuronGroup()->getName(); }); + groupEnv.add(Type::Uint32, "_src_spk_evnt_cnt", + Type::Uint32.createPointer(), "srcSpkCntEvnt", + [&backend](const auto &g, size_t) { return 
backend.getDeviceVarPrefix() + "glbSpkCntEvnt" + g.getSrcNeuronGroup()->getName(); }); + groupEnv.add(Type::Uint32, "_src_spk_evnt", + Type::Uint32.createPointer(), "srcSpkEvnt", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkEvnt" + g.getSrcNeuronGroup()->getName(); }); + + // Target neuron fields + groupEnv.add(Type::Uint32, "_trg_spk_que_ptr", + Type::Uint32.createPointer(), "trgSpkQuePtr", + [&backend](const auto &g, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + g.getTrgNeuronGroup()->getName(); }); + groupEnv.add(Type::Uint32, "_trg_spk_cnt", + Type::Uint32.createPointer(), "trgSpkCnt", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkCnt" + g.getTrgNeuronGroup()->getName(); }); + groupEnv.add(Type::Uint32, "_trg_spk", + Type::Uint32.createPointer(), "trgSpk", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpk" + g.getTrgNeuronGroup()->getName(); }); + // If batching is enabled if(batchSize > 1) { // Calculate batch offsets into pre and postsynaptic populations env.add(Type::Uint32.addConst(), "_pre_batch_offset", "preBatchOffset", - {env.addInitialiser("const unsigned int preBatchOffset = " + env["num_pre"] + " * " + env["batch"] + ";")}); + {env.addInitialiser("const unsigned int preBatchOffset = " + env["num_pre"] + " * batch;")}, + {"num_pre"}); env.add(Type::Uint32.addConst(), "_post_batch_offset", "postBatchOffset", - {env.addInitialiser("const unsigned int preBatchOffset = " + env["num_post"] + " * " + env["batch"] + ";")}); + {env.addInitialiser("const unsigned int preBatchOffset = " + env["num_post"] + " * batch;")}, + {"num_post"}); // Calculate batch offsets into synapse arrays, using 64-bit arithmetic if necessary if(areSixtyFourBitSynapseIndicesRequired(sg)) { @@ -541,40 +602,59 @@ class GENN_EXPORT BackendBase } else { env.add(Type::Uint32.addConst(), "_syn_batch_offset", "synBatchOffset", - {env.addInitialiser("const unsigned int synBatchOffset = " + env["_pre_batch_offset"] + " * " + env["_row_stride"] + ";")}); + {env.addInitialiser("const unsigned int synBatchOffset = " + env["_pre_batch_offset"] + " * " + env["_row_stride"] + ";")}, + {"_pre_batch_offset", "_row_stride"}); } - // If synapse group has kernel weights - /*const auto &kernelSize = sg.getArchetype().getKernelSize(); - if((sg.getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL) && !kernelSize.empty()) { + // If synapse group has kernel + const auto &kernelSize = sg.getArchetype().getKernelSize(); + if(!kernelSize.empty()) { // Loop through kernel dimensions and multiply together - os << "const unsigned int kernBatchOffset = "; + // **TODO** extract list of kernel size variables referenced + std::ostringstream kernBatchOffsetInit; + kernBatchOffsetInit << "const unsigned int kernBatchOffset = "; for(size_t i = 0; i < kernelSize.size(); i++) { - os << sg.getKernelSize(i) << " * "; + kernBatchOffsetInit << sg.getKernelSize(i) << " * "; } // And finally by batch - os << "batch;" << std::endl; - }*/ + kernBatchOffsetInit << "batch;" << std::endl; + + env.add(Type::Uint32.addConst(), "_kern_batch_offset", "kernBatchOffset", + {env.addInitialiser(kernBatchOffsetInit.str())}); + } } // If presynaptic neuron group has variable queues, calculate offset to read from its variables with axonal delay - /*if(sg.getArchetype().getSrcNeuronGroup()->isDelayRequired()) { + if(sg.getArchetype().getSrcNeuronGroup()->isDelayRequired()) { const unsigned int numDelaySteps = 
sg.getArchetype().getDelaySteps(); const unsigned int numSrcDelaySlots = sg.getArchetype().getSrcNeuronGroup()->getNumDelaySlots(); - os << "const unsigned int preDelaySlot = "; + std::ostringstream preDelaySlotInit; + preDelaySlotInit << "const unsigned int preDelaySlot = "; if(numDelaySteps == 0) { - os << "*group->srcSpkQuePtr;" << std::endl; + preDelaySlotInit << "*" << env["_src_spk_que_ptr"] << ";" << std::endl; } else { - os << "(*group->srcSpkQuePtr + " << (numSrcDelaySlots - numDelaySteps) << ") % " << numSrcDelaySlots << ";" << std::endl; + preDelaySlotInit << "(*" << env["_src_spk_que_ptr"] << " + " << (numSrcDelaySlots - numDelaySteps) << ") % " << numSrcDelaySlots << ";" << std::endl; } - os << "const unsigned int preDelayOffset = preDelaySlot * group->numSrcNeurons;" << std::endl; + env.add(Type::Uint32, "_pre_delay_slot", "preDelaySlot", + {env.addInitialiser(preDelaySlotInit.str())}, {"_src_spk_que_ptr"}); + + env.add(Type::Uint32, "_pre_delay_offset", "preDelayOffset", + {env.addInitialiser("const unsigned int preDelayOffset = preDelaySlot * " + env["num_pre"] + ";")}, + {"num_pre", "_pre_delay_slot"}); if(batchSize > 1) { - os << "const unsigned int preBatchDelaySlot = preDelaySlot + (batch * " << numSrcDelaySlots << ");" << std::endl; - os << "const unsigned int preBatchDelayOffset = preDelayOffset + (preBatchOffset * " << numSrcDelaySlots << ");" << std::endl; + env.add(Type::Uint32, "_pre_batch_delay_slot", "preBatchDelaySlot", + {env.addInitialiser("const unsigned int preBatchDelaySlot = preDelaySlot + (batch * " + std::to_string(numSrcDelaySlots) + ");")}, + {"_pre_delay_slot"}); + + os << << std::endl; + + env.add(Type::Uint32, "_pre_batch_delay_offset", "preBatchDelayOffset", + {env.addInitialiser("const unsigned int preBatchDelayOffset = preDelayOffset + (preBatchOffset * " + std::to_string(numSrcDelaySlots) + ");")}, + {"_pre_delay_offset", "_pre_batch_offset"}); } if(sg.getArchetype().getWUModel()->isPrevPreSpikeTimeRequired() || sg.getArchetype().getWUModel()->isPrevPreSpikeEventTimeRequired()) { @@ -613,7 +693,7 @@ class GENN_EXPORT BackendBase } } - }*/ + } } void genCustomUpdateIndexCalculation(CodeStream &os, const CustomUpdateGroupMerged &cu) const; diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 442cac63ef..5cedb4a78b 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -127,13 +127,11 @@ namespace GeNN::CodeGenerator::SingleThreadedCPU { void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { - const ModelSpecInternal &model = modelMerged.getModel(); - if(model.getBatchSize() != 1) { + if(modelMerged.getModel().getBatchSize() != 1) { throw std::runtime_error("The single-threaded CPU backend only supports simulations with a batch size of 1"); } - - // Generate stream with neuron update code + // Generate stream with neuron update code std::ostringstream neuronUpdateStream; CodeStream neuronUpdate(neuronUpdateStream); @@ -141,7 +139,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host EnvironmentLibrary neuronUpdateEnv(neuronUpdate, StandardLibrary::getFunctions()); neuronUpdateEnv.getStream() << "void updateNeurons(timepoint t"; - if(model.isRecordingInUse()) { + if(modelMerged.getModel().isRecordingInUse()) { neuronUpdateEnv.getStream() << ", unsigned int recordingTimestep"; } neuronUpdateEnv.getStream() << ")"; @@ -152,7 +150,7 
@@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); funcEnv.add(Type::Uint32.addConst(), "batch", "0"); - Timer t(funcEnv.getStream(), "neuronUpdate", model.isTimingEnabled()); + Timer t(funcEnv.getStream(), "neuronUpdate", modelMerged.getModel().isTimingEnabled()); modelMerged.genMergedNeuronPrevSpikeTimeUpdateGroups( *this, [this, &funcEnv, &modelMerged](auto &n) @@ -316,8 +314,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host //-------------------------------------------------------------------------- void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { - const ModelSpecInternal &model = modelMerged.getModel(); - if (model.getBatchSize() != 1) { + if (modelMerged.getModel().getBatchSize() != 1) { throw std::runtime_error("The single-threaded CPU backend only supports simulations with a batch size of 1"); } @@ -338,7 +335,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // Synapse dynamics { - Timer t(funcEnv.getStream(), "synapseDynamics", model.isTimingEnabled()); + Timer t(funcEnv.getStream(), "synapseDynamics", modelMerged.getModel().isTimingEnabled()); modelMerged.genMergedSynapseDynamicsGroups( *this, [this, &funcEnv, &modelMerged](auto &s) @@ -355,19 +352,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, s); - // Add basic fields **TODO** move to group merged - groupEnv.add(Type::Uint32.addConst(), "num_pre", - Type::Uint32, "numSrcNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); - groupEnv.add(Type::Uint32.addConst(), "num_post", - Type::Uint32, "numTrgNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); - groupEnv.add(Type::Uint32.addConst(), "_row_stride", - Type::Uint32, "rowStride", - [this](const SynapseGroupInternal &sg, size_t) { return std::to_string(getSynapticMatrixRowStride(sg)); }); - - // _row_length - // _ind + // **TODO** rename as it does more! 
genSynapseIndexCalculation(groupEnv, s, 1); // Loop through presynaptic neurons @@ -387,38 +372,41 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos } { CodeStream::Scope b(groupEnv.getStream()); + EnvironmentGroupMergedField synEnv(groupEnv, s); // Add presynaptic index to substitutions - groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); + synEnv.add(Type::Uint32.addConst(), "id_pre", "i"); if (s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // Add initialiser strings to calculate synaptic and presynaptic index - const size_t idSynInit = groupEnv.addInitialiser("const unsigned int idSyn = (i * " + groupEnv["_row_stride"] + ") + s;"); - const size_t idPostInit = groupEnv.addInitialiser("const unsigned int idPost = " + groupEnv["_ind"] + "[idSyn];"); + const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["_row_stride"] + ") + s;"); + const size_t idPostInit = synEnv.addInitialiser("const unsigned int idPost = " + synEnv["_ind"] + "[idSyn];"); // **TODO** id_syn can be 64-bit - groupEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}, {"_row_stride"}); - groupEnv.add(Type::Uint32.addConst(), "id_post", "idPost", {idPostInit, idSynInit}, {"_ind"}); + synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}, {"_row_stride"}); + synEnv.add(Type::Uint32.addConst(), "id_post", "idPost", {idPostInit, idSynInit}, {"_ind"}); } else { // Add postsynaptic index to substitutions - groupEnv.add(Type::Uint32.addConst(), "id_post", "j"); + synEnv.add(Type::Uint32.addConst(), "id_post", "j"); // Add initialiser to calculate synaptic index - const size_t idSynInit = groupEnv.addInitialiser("const unsigned int idSyn = (i * " + groupEnv["num_post"] + ") + j;"); + const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["num_post"] + ") + j;"); // **TODO** id_syn can be 64-bit - groupEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}); - + synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}); } // Add correct functions for apply synaptic input - groupEnv.add(Type::AddToPostDenDelay, "addToPostDelay", groupEnv["_den_delay"] + "[" + s.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)"); - groupEnv.add(Type::AddToPost, "addToPost", groupEnv["_out_post"] + "[" + s.getPostISynIndex(1, "j") + "] += $(0)"); - groupEnv.add(Type::AddToPre, "addToPre", groupEnv["_out_pre"] + "[" + s.getPreISynIndex(1, groupEnv["id_pre"]) + "] += $(0)"); + synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", synEnv["_den_delay"] + "[" + s.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)", + {}, {"_den_delay"}); + synEnv.add(Type::AddToPost, "addToPost", synEnv["_out_post"] + "[" + s.getPostISynIndex(1, "j") + "] += $(0)", + {}, {"_out_post"}); + synEnv.add(Type::AddToPre, "addToPre", synEnv["_out_pre"] + "[" + s.getPreISynIndex(1, synEnv["id_pre"]) + "] += $(0)", + {}, {"id_pre"})); // Call synapse dynamics handler - s.generateSynapseUpdate(*this, groupEnv, modelMerged); + s.generateSynapseUpdate(*this, synEnv, modelMerged); } } } @@ -427,7 +415,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // Presynaptic update { - Timer t(funcEnv.getStream(), "presynapticUpdate", model.isTimingEnabled()); + Timer t(funcEnv.getStream(), "presynapticUpdate", modelMerged.getModel().isTimingEnabled()); modelMerged.genMergedPresynapticUpdateGroups( *this, [this, &funcEnv](auto &s) @@ -462,7 +450,7 @@ void 
Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // Postsynaptic update { - Timer t(funcEnv.getStream(), "postsynapticUpdate", model.isTimingEnabled()); + Timer t(funcEnv.getStream(), "postsynapticUpdate", modelMerged.getModel().isTimingEnabled()); modelMerged.genMergedPostsynapticUpdateGroups( *this, [this, &funcEnv](auto &s) @@ -508,31 +496,32 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos } { CodeStream::Scope b(groupEnv.getStream()); + EnvironmentGroupMergedField synEnv(groupEnv, s); if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // Add initialisers to calculate column and row-major indices // **TODO** fast divide optimisations - const size_t colMajorIdxInit = groupEnv.addInitialiser("const unsigned int colMajorIndex = (spike * " + groupEnv["_col_stride"] + ") + i;"); - const size_t rowMajorIdxInit = groupEnv.addInitialiser("const unsigned int rowMajorIndex = " + groupEnv["_remap"] + "[colMajorIndex];"); - const size_t idPreInit = groupEnv.addInitialiser("const unsigned int idPre = rowMajorIndex / " + groupEnv["_row_stride"] + ";"); + const size_t colMajorIdxInit = synEnv.addInitialiser("const unsigned int colMajorIndex = (spike * " + synEnv["_col_stride"] + ") + i;"); + const size_t rowMajorIdxInit = synEnv.addInitialiser("const unsigned int rowMajorIndex = " + synEnv["_remap"] + "[colMajorIndex];"); + const size_t idPreInit = synEnv.addInitialiser("const unsigned int idPre = rowMajorIndex / " + synEnv["_row_stride"] + ";"); // Add presynaptic and synapse index to environment - groupEnv.add("id_pre", "idPre", {colMajorIdxInit, rowMajorIdxInit, idPreInit}, {"_col_stride", "_row_stride", "_remap"}); - groupEnv.add("id_syn", "rowMajorIndex", {colMajorIdxInit, rowMajorIdxInit}, {"_col_stride", "_remap"}); + synEnv.add("id_pre", "idPre", {colMajorIdxInit, rowMajorIdxInit, idPreInit}, {"_col_stride", "_row_stride", "_remap"}); + synEnv.add("id_syn", "rowMajorIndex", {colMajorIdxInit, rowMajorIdxInit}, {"_col_stride", "_remap"}); } else { // Add initialiser to calculate synaptic index - const size_t idSynInit = groupEnv.addInitialiser("const unsigned int idSyn = (i * " + groupEnv["num_post"] + ") + spike;"); + const size_t idSynInit = groupEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["num_post"] + ") + spike;"); // Add presynaptic and synapse index to environment - groupEnv.add(Type::Uint32, "id_pre", "i"); - groupEnv.add(Type::Uint32, "id_syn", "idSyn", {idSynInit}, {"num_post"}); + synEnv.add(Type::Uint32, "id_pre", "i"); + synEnv.add(Type::Uint32, "id_syn", "idSyn", {idSynInit}, {"num_post"}); } - groupEnv.add(Type::Uint32, "id_post", "spike"); - groupEnv.add(Type::AddToPre, "addToPre", groupEnv["_out_pre"] + "[" + s.getPreISynIndex(1, groupEnv["id_pre"]) + "] += $(0)"); + synEnv.add(Type::Uint32, "id_post", "spike"); + synEnv.add(Type::AddToPre, "addToPre", synEnv["_out_pre"] + "[" + s.getPreISynIndex(1, synEnv["id_pre"]) + "] += $(0)"); - s.generateSynapseUpdate(*this, groupEnv, modelMerged); + s.generateSynapseUpdate(*this, synEnv, modelMerged); } } groupEnv.getStream() << std::endl; From ee5a94b9df3a8a8eb2d1a21e4d254566b48bab69 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 15 Jun 2023 17:34:06 +0100 Subject: [PATCH 222/725] closing in on neuron and synapse update groups --- .../backends/single_threaded_cpu/backend.h | 4 +- .../genn/genn/code_generator/backendBase.h | 166 +++++++++------ .../genn/genn/code_generator/environment.h | 149 ++++---------- 
.../genn/genn/code_generator/groupMerged.h | 25 ++- .../code_generator/neuronUpdateGroupMerged.h | 90 +------- .../backends/single_threaded_cpu/backend.cc | 192 +++++++++--------- src/genn/genn/code_generator/groupMerged.cc | 19 +- .../code_generator/neuronUpdateGroupMerged.cc | 163 +++++---------- .../synapseUpdateGroupMerged.cc | 6 +- 9 files changed, 338 insertions(+), 476 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index f9b89a33bc..7caf718a06 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -186,9 +186,9 @@ class BACKEND_EXPORT Backend : public BackendBase //-------------------------------------------------------------------------- // Private methods //-------------------------------------------------------------------------- - void genPresynapticUpdate(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, bool trueSpike) const; + void genPresynapticUpdate(EnvironmentExternalBase &env, const PresynapticUpdateGroupMerged &sg, const ModelSpecMerged &modelMerged, bool trueSpike) const; - void genEmitSpike(EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng,bool trueSpike, bool recordingEnabled) const; + void genEmitSpike(EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, bool trueSpike, bool recordingEnabled) const; template void genMergedStructArrayPush(CodeStream &os, const std::vector &groups) const diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 9ac3536d7a..8416eea854 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -490,41 +490,87 @@ class GENN_EXPORT BackendBase } template - void genNeuronIndexCalculation(EnvironmentGroupMergedField &env, const NeuronUpdateGroupMerged &ng, unsigned int batchSize) const + void genNeuronIndexCalculation(EnvironmentGroupMergedField &env, unsigned int batchSize) const { + env.add(Type::Uint32.addConst(), "num_neurons", + Type::Uint32, "numNeurons", + [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); + env.add(Type::Uint32.createPointer(), "_spk_cnt", "spkCnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getName(); }); + env.add(Type::Uint32.createPointer(), "_spk", "spk", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getName(); }); + env.add(Type::Uint32.createPointer(), "_spk_cnt_evnt", "spkCntEvnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCntEvnt" + g.getName(); }); + env.add(Type::Uint32.createPointer(), "_spk_evnt", "spkEvnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkEvnt" + g.getName(); }); + env.add(Type::Uint32.createPointer(), "_spk_que_ptr", "spkQuePtr", + [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getName(); }); + + env.add(env.getGroup().getTimeType().createPointer(), "_spk_time", "sT", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "sT" + g.getName(); }); + env.add(env.getGroup().getTimeType().createPointer(), "_spk_evnt_time", "seT", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "seT" + g.getName(); }); + env.add(env.getGroup().getTimeType().createPointer(), "_prev_spk_time", "prevST", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + 
"prevST" + g.getName(); }); + env.add(env.getGroup().getTimeType().createPointer(), "_prev_spk_evnt_time", "prevSET", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "prevSET" + g.getName(); }); + + // If batching is enabled, calculate batch offset if(batchSize > 1) { - os << "const unsigned int batchOffset = group->numNeurons * batch;" << std::endl; + env.add(Type::Uint32.addConst(), "_batchOffset", "batchOffset", + {env.addInitialiser("const unsigned int batchOffset = " + env["num_neurons"] + " * batch;")}, + {"num_neurons"}); } // If axonal delays are required - if(ng.getArchetype().isDelayRequired()) { + if(env.getGroup().getArchetype().isDelayRequired()) { // We should READ from delay slot before spkQuePtr - os << "const unsigned int readDelaySlot = (*group->spkQuePtr + " << (ng.getArchetype().getNumDelaySlots() - 1) << ") % " << ng.getArchetype().getNumDelaySlots() << ";" << std::endl; - os << "const unsigned int readDelayOffset = readDelaySlot * group->numNeurons;" << std::endl; + const unsigned int numDelaySlots = env.getGroup().getArchetype().getNumDelaySlots(); + const std::string numDelaySlotsStr = std::to_string(numDelaySlots); + env.add(Type::Uint32.addConst(), "_read_delay_slot", "readDelaySlot", + {env.addInitialiser("const unsigned int readDelaySlot = (*" + env["_spk_que_ptr"] + " + " + std::to_string(numDelaySlots - 1) + ") % " + numDelaySlotsStr+ ";")}, + {"_spk_que_ptr"}); + env.add(Type::Uint32.addConst(), "_read_delay_offset", "readDelayOffset", + {env.addInitialiser("const unsigned int readDelayOffset = readDelaySlot * " + env["num_neurons"] + ";")}, + {"num_neurons", "_read_delay_slot"}); // And we should WRITE to delay slot pointed to be spkQuePtr - os << "const unsigned int writeDelaySlot = *group->spkQuePtr;" << std::endl; - os << "const unsigned int writeDelayOffset = writeDelaySlot * group->numNeurons;" << std::endl; + env.add(Type::Uint32.addConst(), "_write_delay_slot", "writeDelaySlot", + {env.addInitialiser("const unsigned int writeDelaySlot = *" + env["_spk_que_ptr"] + ";")}, + {"_spk_que_ptr"}); + env.add(Type::Uint32.addConst(), "_write_delay_offset", "writeDelayOffset", + {env.addInitialiser("const unsigned int writeDelayOffset = writeDelaySlot * " + env["num_neurons"] + ";")}, + {"num_neurons", "_write_delay_slot"}); // If batching is also enabled if(batchSize > 1) { // Calculate batched delay slots - os << "const unsigned int readBatchDelaySlot = (batch * " << ng.getArchetype().getNumDelaySlots() << ") + readDelaySlot;" << std::endl; - os << "const unsigned int writeBatchDelaySlot = (batch * " << ng.getArchetype().getNumDelaySlots() << ") + writeDelaySlot;" << std::endl; + env.add(Type::Uint32.addConst(), "_read_batch_delay_slot", "readBatchDelaySlot", + {env.addInitialiser("const unsigned int readBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + readDelaySlot;")}, + {"_read_delay_slot"}); + env.add(Type::Uint32.addConst(), "_write_batch_delay_slot", "writeBatchDelaySlot", + {env.addInitialiser("const unsigned int writeBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + writeDelaySlot;")}, + {"_write_delay_slot"}); // Calculate current batch offset - os << "const unsigned int batchDelayOffset = batchOffset * " << ng.getArchetype().getNumDelaySlots() << ";" << std::endl; + env.add(Type::Uint32.addConst(), "_batch_delay_offset", "batchDelayOffset", + {env.addInitialiser("const unsigned int batchDelayOffset = batchOffset * " + numDelaySlotsStr + ";")}, + {"_batch_offset"}); // Calculate further offsets to include delay and batch - os 
<< "const unsigned int readBatchDelayOffset = readDelayOffset + batchDelayOffset;" << std::endl; - os << "const unsigned int writeBatchDelayOffset = writeDelayOffset + batchDelayOffset;" << std::endl; + env.add(Type::Uint32.addConst(), "_read_batch_delay_offset", "readBatchDelayOffset", + {env.addInitialiser("const unsigned int readBatchDelayOffset = readDelayOffset + batchDelayOffset;")}, + {"_read_delay_offset", "_batchDelayOffset"}); + env.add(Type::Uint32.addConst(), "_write_batch_delay_offset", "writeBatchDelayOffset", + {env.addInitialiser("const unsigned int writeBatchDelayOffset = writeDelayOffset + batchDelayOffset;")}, + {"_write_delay_offset", "_batchDelayOffset"}); } } } template - void genSynapseIndexCalculation(EnvironmentGroupMergedField &env, const SynapseGroupMergedBase &sg, unsigned int batchSize) const + void genSynapseIndexCalculation(EnvironmentGroupMergedField &env, unsigned int batchSize) const { // Synapse group fields groupEnv.add(Type::Uint32.addConst(), "num_pre", @@ -533,57 +579,43 @@ class GENN_EXPORT BackendBase groupEnv.add(Type::Uint32.addConst(), "num_post", Type::Uint32, "numTrgNeurons", [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); - groupEnv.add(Type::Uint32.addConst(), "_row_stride", - Type::Uint32, "rowStride", - [this](const SynapseGroupInternal &sg, size_t) { return std::to_string(getSynapticMatrixRowStride(sg)); }); - groupEnv.add(Type::Uint32.addConst(), "_col_stride", - Type::Uint32, "colStride", - [this](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getMaxSourceConnections()); }); + groupEnv.add(Type::Uint32, "_row_stride", "rowStride", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(getSynapticMatrixRowStride(sg)); }); + groupEnv.add(Type::Uint32, "_col_stride", "colStride", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getMaxSourceConnections()); }); // Postsynaptic model fields - groupEnv.add(modelMerged.getModel().getPrecision(), "_out_post", - modelMerged.getModel().getPrecision().createPointer(), "outPost", - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); - groupEnv.add(modelMerged.getModel().getPrecision(), "_den_delay", - modelMerged.getModel().getPrecision().createPointer(), "denDelay", - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); - groupEnv.add(Type::Uint32, "_den_delay_ptr", - Type::Uint32.createPointer(), "denDelayPtr", - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); + groupEnv.add(modelMerged.getModel().getPrecision().createPointer(), "_out_post", "outPost", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); + groupEnv.add(modelMerged.getModel().getPrecision().createPointer(), "_den_delay", "denDelay", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); + groupEnv.add(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); // Presynaptic output fields - groupEnv.add(modelMerged.getModel().getPrecision(), "_out_pre", - modelMerged.getModel().getPrecision().createPointer(), "outPre", - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPre" + 
g.getFusedPreOutputSuffix(); }); + groupEnv.add(modelMerged.getModel().getPrecision().createPointer(), "_out_pre", "outPre", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); // Source neuron fields - groupEnv.add(Type::Uint32, "_src_spk_que_ptr", - Type::Uint32.createPointer(), "srcSpkQuePtr", - [&backend](const auto &g, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + g.getSrcNeuronGroup()->getName(); }); - groupEnv.add(Type::Uint32, "_src_spk_cnt", - Type::Uint32.createPointer(), "srcSpkCnt", - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkCnt" + g.getSrcNeuronGroup()->getName(); }); - groupEnv.add(Type::Uint32, "_src_spk", - Type::Uint32.createPointer(), "srcSpk", - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpk" + g.getSrcNeuronGroup()->getName(); }); - groupEnv.add(Type::Uint32, "_src_spk_evnt_cnt", - Type::Uint32.createPointer(), "srcSpkCntEvnt", - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkCntEvnt" + g.getSrcNeuronGroup()->getName(); }); - groupEnv.add(Type::Uint32, "_src_spk_evnt", - Type::Uint32.createPointer(), "srcSpkEvnt", - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkEvnt" + g.getSrcNeuronGroup()->getName(); }); + groupEnv.add(Type::Uint32.createPointer(), "_src_spk_que_ptr", "srcSpkQuePtr", + [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getSrcNeuronGroup()->getName(); }); + groupEnv.add(Type::Uint32.createPointer(), "_src_spk_cnt", "srcSpkCnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getSrcNeuronGroup()->getName(); }); + groupEnv.add(Type::Uint32.createPointer(), "_src_spk", "srcSpk", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getSrcNeuronGroup()->getName(); }); + groupEnv.add(Type::Uint32.createPointer(), "_src_spk_evnt_cnt", "srcSpkCntEvnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCntEvnt" + g.getSrcNeuronGroup()->getName(); }); + groupEnv.add(Type::Uint32.createPointer(), "_src_spk_evnt", "srcSpkEvnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkEvnt" + g.getSrcNeuronGroup()->getName(); }); // Target neuron fields - groupEnv.add(Type::Uint32, "_trg_spk_que_ptr", - Type::Uint32.createPointer(), "trgSpkQuePtr", - [&backend](const auto &g, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + g.getTrgNeuronGroup()->getName(); }); - groupEnv.add(Type::Uint32, "_trg_spk_cnt", - Type::Uint32.createPointer(), "trgSpkCnt", - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkCnt" + g.getTrgNeuronGroup()->getName(); }); - groupEnv.add(Type::Uint32, "_trg_spk", - Type::Uint32.createPointer(), "trgSpk", - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpk" + g.getTrgNeuronGroup()->getName(); }); + groupEnv.add(Type::Uint32.createPointer(), "_trg_spk_que_ptr", "trgSpkQuePtr", + [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getTrgNeuronGroup()->getName(); }); + groupEnv.add(Type::Uint32.createPointer(), "_trg_spk_cnt", "trgSpkCnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getTrgNeuronGroup()->getName(); }); + groupEnv.add(Type::Uint32.createPointer(), "_trg_spk", "trgSpk", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + 
g.getTrgNeuronGroup()->getName(); }); // If batching is enabled if(batchSize > 1) { @@ -596,7 +628,7 @@ class GENN_EXPORT BackendBase {"num_post"}); // Calculate batch offsets into synapse arrays, using 64-bit arithmetic if necessary - if(areSixtyFourBitSynapseIndicesRequired(sg)) { + if(areSixtyFourBitSynapseIndicesRequired(env.getGroup())) { assert(false); //os << "const uint64_t synBatchOffset = (uint64_t)preBatchOffset * (uint64_t)group->rowStride;" << std::endl; } @@ -607,14 +639,14 @@ class GENN_EXPORT BackendBase } // If synapse group has kernel - const auto &kernelSize = sg.getArchetype().getKernelSize(); + const auto &kernelSize = env.getGroup().getArchetype().getKernelSize(); if(!kernelSize.empty()) { // Loop through kernel dimensions and multiply together // **TODO** extract list of kernel size variables referenced std::ostringstream kernBatchOffsetInit; kernBatchOffsetInit << "const unsigned int kernBatchOffset = "; for(size_t i = 0; i < kernelSize.size(); i++) { - kernBatchOffsetInit << sg.getKernelSize(i) << " * "; + kernBatchOffsetInit << env.getGroup().getKernelSize(i) << " * "; } // And finally by batch @@ -626,9 +658,9 @@ class GENN_EXPORT BackendBase } // If presynaptic neuron group has variable queues, calculate offset to read from its variables with axonal delay - if(sg.getArchetype().getSrcNeuronGroup()->isDelayRequired()) { - const unsigned int numDelaySteps = sg.getArchetype().getDelaySteps(); - const unsigned int numSrcDelaySlots = sg.getArchetype().getSrcNeuronGroup()->getNumDelaySlots(); + if(env.getGroup().getArchetype().getSrcNeuronGroup()->isDelayRequired()) { + const unsigned int numDelaySteps = env.getGroup().getArchetype().getDelaySteps(); + const unsigned int numSrcDelaySlots = env.getGroup().getArchetype().getSrcNeuronGroup()->getNumDelaySlots(); std::ostringstream preDelaySlotInit; preDelaySlotInit << "const unsigned int preDelaySlot = "; @@ -657,7 +689,9 @@ class GENN_EXPORT BackendBase {"_pre_delay_offset", "_pre_batch_offset"}); } - if(sg.getArchetype().getWUModel()->isPrevPreSpikeTimeRequired() || sg.getArchetype().getWUModel()->isPrevPreSpikeEventTimeRequired()) { + if(env.getGroup().getArchetype().getWUModel()->isPrevPreSpikeTimeRequired() + || env.getGroup().getArchetype().getWUModel()->isPrevPreSpikeEventTimeRequired()) + { os << "const unsigned int prePrevSpikeTimeDelayOffset = " << "((*group->srcSpkQuePtr + " << (numSrcDelaySlots - numDelaySteps - 1) << ") % " << numSrcDelaySlots << ")" << " * group->numSrcNeurons;" << std::endl; if(batchSize > 1) { @@ -667,9 +701,9 @@ class GENN_EXPORT BackendBase } // If postsynaptic neuron group has variable queues, calculate offset to read from its variables at current time - if(sg.getArchetype().getTrgNeuronGroup()->isDelayRequired()) { - const unsigned int numBackPropDelaySteps = sg.getArchetype().getBackPropDelaySteps(); - const unsigned int numTrgDelaySlots = sg.getArchetype().getTrgNeuronGroup()->getNumDelaySlots(); + if(env.getGroup().getArchetype().getTrgNeuronGroup()->isDelayRequired()) { + const unsigned int numBackPropDelaySteps = env.getGroup().getArchetype().getBackPropDelaySteps(); + const unsigned int numTrgDelaySlots = env.getGroup().getArchetype().getTrgNeuronGroup()->getNumDelaySlots(); os << "const unsigned int postDelaySlot = "; if(numBackPropDelaySteps == 0) { @@ -685,7 +719,7 @@ class GENN_EXPORT BackendBase os << "const unsigned int postBatchDelayOffset = postDelayOffset + (postBatchOffset * " << numTrgDelaySlots << ");" << std::endl; } - 
if(sg.getArchetype().getWUModel()->isPrevPostSpikeTimeRequired()) { + if(env.getGroup().getArchetype().getWUModel()->isPrevPostSpikeTimeRequired()) { os << "const unsigned int postPrevSpikeTimeDelayOffset = " << "((*group->trgSpkQuePtr + " << (numTrgDelaySlots - numBackPropDelaySteps - 1) << ") % " << numTrgDelaySlots << ")" << " * group->numTrgNeurons;" << std::endl; if(batchSize > 1) { diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index ef7cd88dc1..98ad8cb3da 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -351,6 +351,14 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &initialisers = {}, const std::vector &dependents = {}) + { + add(type, name, type, fieldName, getFieldValue, indexSuffix, mergedFieldType, initialisers, dependents); + } + void addScalar(const std::string &name, const std::string &fieldSuffix, typename G::GetFieldDoubleValueFunc getFieldValue) { add(m_Group.getScalarType().addConst(), name, @@ -397,8 +405,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase>> m_Environment; }; - -//---------------------------------------------------------------------------- -// GeNN::CodeGenerator::EnvironmentSubstitute -//---------------------------------------------------------------------------- -//! Standard pretty printing environment simply allowing substitutions to be implemented -class EnvironmentSubstitute : public EnvironmentExternalBase -{ -public: - EnvironmentSubstitute(EnvironmentExternalBase &enclosing) - : EnvironmentExternalBase(static_cast(enclosing)), m_Contents(m_ContentsStream) - { - } - - EnvironmentSubstitute(CodeStream &os) - : EnvironmentExternalBase(os), m_Contents(m_ContentsStream) - { - } - - EnvironmentSubstitute(const EnvironmentSubstitute&) = delete; - - ~EnvironmentSubstitute(); - - //------------------------------------------------------------------------ - // PrettyPrinter::EnvironmentBase virtuals - //------------------------------------------------------------------------ - virtual std::string getName(const std::string &name, std::optional type = std::nullopt) final; - - virtual CodeStream &getStream() final - { - return m_Contents; - } - - //------------------------------------------------------------------------ - // Public API - //------------------------------------------------------------------------ - void addSubstitution(const std::string &source, const std::string &destination, - std::vector initialisers = {}); - - size_t addInitialiser(const std::string &initialiser); - - template - void addVarNameSubstitution(const std::vector &variables, const std::string &fieldSuffix = "") - { - for(const auto &v : variables) { - addSubstitution(v.name, "group->" + v.name + fieldSuffix); - } - } - - template - void addParamValueSubstitution(const std::vector ¶mNames, const std::unordered_map &values, - const std::string &fieldSuffix, G isHeterogeneousFn) - { - if(paramNames.size() != values.size()) { - throw std::runtime_error("Number of parameters does not match number of values"); - } - - for(const auto &p : paramNames) { - if(isHeterogeneousFn(p)) { - addSubstitution(p, "group->" + p + fieldSuffix); - } - else { - // **TODO** scalar suffix - addSubstitution(p, Utils::writePreciseString(values.at(p))); - } - } - } - - template - void addVarValueSubstitution(const std::vector &variables, const std::unordered_map &values, - const std::string &fieldSuffix, G isHeterogeneousFn) 
- { - if(variables.size() != values.size()) { - throw std::runtime_error("Number of variables does not match number of values"); - } - - for(const auto &v : variables) { - if(isHeterogeneousFn(v.name)) { - addSubstitution(v.name, "group->" + v.name + fieldSuffix); - } - else { - addSubstitution(v.name, Utils::writePreciseString(values.at(v.name))); - } - } - } - -private: - //------------------------------------------------------------------------ - // Members - //------------------------------------------------------------------------ - std::ostringstream m_ContentsStream; - CodeStream m_Contents; - std::unordered_map>> m_VarSubstitutions; - std::vector> m_Initialisers; -}; - //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentLocalVarCache //---------------------------------------------------------------------------- @@ -655,10 +567,10 @@ class EnvironmentLocalVarCache : public EnvironmentExternalBase using GetIndexFn = std::function; public: - EnvironmentLocalVarCache(const G &group, const Type::TypeContext &context, EnvironmentExternal &enclosing, - const std::string &fieldSuffix, const std::string & localPrefix, + EnvironmentLocalVarCache(const G &group, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, + const std::string &fieldSuffix, const std::string &localPrefix, GetIndexFn getReadIndex, GetIndexFn getWriteIndex) - : EnvironmentExternal(static_cast(enclosing)), m_Group(group), m_Context(context), m_Contents(m_ContentsStream), + : EnvironmentExternalBase(enclosing), m_Group(group), m_Context(context), m_Contents(m_ContentsStream), m_FieldSuffix(fieldSuffix), m_LocalPrefix(localPrefix), m_GetReadIndex(getReadIndex), m_GetWriteIndex(getWriteIndex) { // Add name of each definition to map, initially with value set to value @@ -667,15 +579,10 @@ class EnvironmentLocalVarCache : public EnvironmentExternalBase [](const auto &v){ return std::make_pair(v.name, false); }); } - EnvironmentLocalVarCache(const G &group, const Type::TypeContext &context, EnvironmentExternal &enclosing, - const std::string &fieldSuffix, const std::string & localPrefix, GetIndexFn getIndex) - : EnvironmentExternal(static_cast(enclosing)), m_Group(group), m_Context(context), - m_Contents(m_ContentsStream), m_FieldSuffix(fieldSuffix), m_LocalPrefix(localPrefix), m_GetReadIndex(getIndex), m_GetWriteIndex(getIndex) + EnvironmentLocalVarCache(const G &group, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, + const std::string &fieldSuffix, const std::string &localPrefix, GetIndexFn getIndex) + : EnvironmentLocalVarCache(group, context, enclosing, fieldSuffix, getIndex, getIndex) { - // Add name of each definition to map, initially with value set to value - const auto defs = A(m_Group).getDefs(); - std::transform(defs.cbegin(), defs.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()), - [](const auto &v){ return std::make_pair(v.name, false); }); } EnvironmentLocalVarCache(const EnvironmentLocalVarCache&) = delete; @@ -720,6 +627,32 @@ class EnvironmentLocalVarCache : public EnvironmentExternalBase } } + //------------------------------------------------------------------------ + // TypeChecker::EnvironmentBase virtuals + //------------------------------------------------------------------------ + virtual std::vector getTypes(const Transpiler::Token &name, Transpiler::ErrorHandlerBase &errorHandler) final + { + // If name isn't found in environment + auto var = 
m_VariablesReferenced.find(name.lexeme); + if (var == m_VariablesReferenced.end()) { + return getContextTypes(name, errorHandler); + } + // Otherwise + else { + // Set flag to indicate that variable has been referenced + var->second = true; + + // Find corresponding variable definition + const auto varDefs = A(m_Group).getDefs(); + auto varDef = std::find_if(varDefs.cbegin(), varDefs.cend(), + [](const auto &v){ return v.name == name.lexeme; }); + assert(varDef != varDefs.cend()); + + // Return its resolved type + return {varDef->type.resolve(m_Context)}; + } + } + //------------------------------------------------------------------------ // PrettyPrinter::EnvironmentBase virtuals //------------------------------------------------------------------------
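// Aside: a minimal, self-contained sketch of the lookup-and-flag pattern that
// getTypes() above relies on; only variables which are actually referenced get
// flagged, so unused variables never have cached copies generated for them.
// All names below are illustrative and are not part of the GeNN API.
#include <algorithm>
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>
struct VarDef { std::string name; std::string type; };
class ReferencedVarTracker
{
public:
    explicit ReferencedVarTracker(const std::vector<VarDef> &defs) : m_Defs(defs)
    {
        // Initially, no variable has been referenced
        for(const auto &d : defs) {
            m_Referenced.emplace(d.name, false);
        }
    }
    // Flag variable as referenced and return the type from its definition
    const std::string &getType(const std::string &name)
    {
        auto var = m_Referenced.find(name);
        assert(var != m_Referenced.end());
        var->second = true;
        auto def = std::find_if(m_Defs.cbegin(), m_Defs.cend(),
                                [&name](const auto &d){ return d.name == name; });
        assert(def != m_Defs.cend());
        return def->type;
    }
private:
    std::vector<VarDef> m_Defs;
    std::unordered_map<std::string, bool> m_Referenced;
};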
diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index c5a362baaa..e9f98bec5d 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -577,7 +577,7 @@ class GENN_EXPORT NeuronSpikeQueueUpdateGroupMerged : public GroupMerged> &groups) + : GroupMerged(index, typeContext, groups), m_ArchetypeCode(archetypeCode) + {} + //---------------------------------------------------------------------------- // Protected methods //---------------------------------------------------------------------------- @@ -840,5 +858,10 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged genEmitTrueSpike, BackendBase::GroupHandlerEnv genEmitSpikeLikeEvent) const; - void generateWUVarUpdate(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged) const; + void generateWUVarUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const; std::string getVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; std::string getReadVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; @@ -268,23 +214,5 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase std::vector m_MergedOutSynPreOutputGroups; std::vector m_MergedInSynWUMPostCodeGroups; std::vector m_MergedOutSynWUMPreCodeGroups; - - //! List of statements parsed and type-checked in constructor; and used to generate sim code - Transpiler::Statement::StatementList m_SimStatements; - - //! Expression parsed and type-checked in constructor; and used to generate threshold condition code - Transpiler::Expression::ExpressionPtr m_ThresholdConditionExpression; - - //! List of statements parsed and type-checked in constructor; and used to generate reset code - Transpiler::Statement::StatementList m_ResetStatements; - - //! Resolved types used to generate sim code - Transpiler::TypeChecker::ResolvedTypeMap m_SimResolvedTypes; - - //! Resolved types used to generate threshold condition code - Transpiler::TypeChecker::ResolvedTypeMap m_ThresholdConditionResolvedTypes; - - //! Resolved types used to generate threshold condition code - Transpiler::TypeChecker::ResolvedTypeMap m_ResetResolvedTypes; }; } // namespace GeNN::CodeGenerator diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 5cedb4a78b..aedc56ae14 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -70,11 +70,6 @@ class Timer const bool m_TimingEnabled; }; -//----------------------------------------------------------------------- -const std::vector &getFunctionTemplates(const std::string &precision) -{ - return (precision == "double") ? cpuDoublePrecisionFunctions : cpuSinglePrecisionFunctions; -} //----------------------------------------------------------------------- template void genKernelIteration(EnvironmentExternal &env, const G &g, size_t numKernelDims, std::function/*BackendBase::Handler*/ handler) @@ -166,44 +161,41 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, n); - - if(n.getArchetype().isDelayRequired()) { - // Calculate delay slot corresponding to last timestep - groupEnv.getStream() << "const unsigned int lastTimestepDelaySlot = (*group->spkQuePtr + " << (n.getArchetype().getNumDelaySlots() - 1) << ") % " << n.getArchetype().getNumDelaySlots() << ";" << std::endl; - groupEnv.getStream() << "const unsigned int lastTimestepDelayOffset = lastTimestepDelaySlot * group->numNeurons;" << std::endl; + genNeuronIndexCalculation(groupEnv, 1); + if(n.getArchetype().isDelayRequired()) { if(n.getArchetype().isPrevSpikeTimeRequired()) { // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep - groupEnv.getStream() << "for(unsigned int i = 0; i < group->spkCnt[lastTimestepDelaySlot]; i++)"; + groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["_spk_cnt"] << "[" << groupEnv["_read_delay_slot"] << "]; i++)"; { CodeStream::Scope b(groupEnv.getStream()); - groupEnv.getStream() << "group->prevST[lastTimestepDelayOffset + group->spk[lastTimestepDelayOffset + i]] = t - DT;" << std::endl; + groupEnv.getStream() << groupEnv["_prev_spk_time"] << "[" << groupEnv["_read_delay_offset"] << " + " << groupEnv["_spk"] << "[" << groupEnv["_read_delay_offset"] << " + i]] = t - DT;" << std::endl; } } if(n.getArchetype().isPrevSpikeEventTimeRequired()) { // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep - groupEnv.getStream() << "for(unsigned int i = 0; i < group->spkCntEvnt[lastTimestepDelaySlot]; i++)"; + groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["_spk_cnt_evnt"] << "[" << groupEnv["_read_delay_slot"] << "]; i++)"; { CodeStream::Scope b(groupEnv.getStream()); - groupEnv.getStream() << "group->prevSET[lastTimestepDelayOffset + group->spkEvnt[lastTimestepDelayOffset + i]] = t - DT;" << std::endl; + groupEnv.getStream() << groupEnv["_prev_spk_evnt_time"] << "[" << groupEnv["_read_delay_offset"] << " + " << groupEnv["_spk_evnt"] << "[" << groupEnv["_read_delay_offset"] << " + i]] = t - DT;" << std::endl; } } } else { if(n.getArchetype().isPrevSpikeTimeRequired()) { // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep - groupEnv.getStream() << "for(unsigned int i = 0; i < group->spkCnt[0]; i++)"; + groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["_spk_cnt"] << "[0]; 
i++)"; { CodeStream::Scope b(groupEnv.getStream()); - groupEnv.getStream() << "group->prevST[group->spk[i]] = t - DT;" << std::endl; + groupEnv.getStream() << groupEnv["_prev_spk_time"] << "[" << groupEnv["_spk"] << "[i]] = t - DT;" << std::endl; } } if(n.getArchetype().isPrevSpikeEventTimeRequired()) { // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep - groupEnv.getStream() << "for(unsigned int i = 0; i < group->spkCntEvnt[0]; i++)"; + groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["_spk_cnt_evnt"] << "[0]; i++)"; { CodeStream::Scope b(groupEnv.getStream()); - groupEnv.getStream() << "group->prevSET[group->spkEvnt[i]] = t - DT;" << std::endl; + groupEnv.getStream() << groupEnv["_prev_spk_evnt_time"] << "[" << groupEnv["_spk_evnt"] << "[i]] = t - DT;" << std::endl; } } } @@ -224,9 +216,10 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Get reference to group funcEnv.getStream() << "const auto *group = &mergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << "[g]; " << std::endl; EnvironmentGroupMergedField groupEnv(funcEnv, n); + genNeuronIndexCalculation(groupEnv, 1); // Generate spike count reset - n.genMergedGroupSpikeCountReset(groupEnv.getStream(), 1); + n.genMergedGroupSpikeCountReset(groupEnv, 1); } }); @@ -248,7 +241,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // If spike or spike-like event recording is in use if(n.getArchetype().isSpikeRecordingEnabled() || n.getArchetype().isSpikeEventRecordingEnabled()) { // Calculate number of words which will be used to record this population's spikes - groupEnv.getStream() << "const unsigned int numRecordingWords = (group->numNeurons + 31) / 32;" << std::endl; + groupEnv.getStream() << "const unsigned int numRecordingWords = (" << groupEnv["num_neurons"] << " + 31) / 32;" << std::endl; // Zero spike recording buffer if(n.getArchetype().isSpikeRecordingEnabled()) { @@ -261,10 +254,10 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host } } - genNeuronIndexCalculation(groupEnv, n, 1); + genNeuronIndexCalculation(groupEnv, 1); groupEnv.getStream() << std::endl; - groupEnv.getStream() << "for(unsigned int i = 0; i < group->numNeurons; i++)"; + groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["num_neurons"] << "; i++)"; { CodeStream::Scope b(groupEnv.getStream()); @@ -276,7 +269,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Generate neuron update n.generateNeuronUpdate(*this, rngEnv, modelMerged, // Emit true spikes - [&modelMerged, this](EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng) + [&modelMerged, this](EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng) { // Insert code to update WU vars ng.generateWUVarUpdate(*this, env, modelMerged); @@ -285,7 +278,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host genEmitSpike(env, ng, true, ng.getArchetype().isSpikeRecordingEnabled()); }, // Emit spike-like events - [this](EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng) + [this](EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng) { // Insert code to emit spike-like events genEmitSpike(env, ng, false, ng.getArchetype().isSpikeEventRecordingEnabled()); @@ -353,7 +346,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos EnvironmentGroupMergedField groupEnv(funcEnv, s); // **TODO** rename as it 
does more! - genSynapseIndexCalculation(groupEnv, s, 1); + genSynapseIndexCalculation(groupEnv, 1); // Loop through presynaptic neurons groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["num_pre"] << "; i++)"; @@ -432,7 +425,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, s); - genSynapseIndexCalculation(groupEnv, s, 1); + genSynapseIndexCalculation(groupEnv, 1); // generate the code for processing spike-like events if (s.getArchetype().isSpikeEventRequired()) { @@ -467,7 +460,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, s); - genSynapseIndexCalculation(groupEnv, s, 1); + genSynapseIndexCalculation(groupEnv, 1); // Get number of postsynaptic spikes if (s.getArchetype().getTrgNeuronGroup()->isDelayRequired() && s.getArchetype().getTrgNeuronGroup()->isTrueSpikeRequired()) { @@ -1692,7 +1685,7 @@ boost::uuids::detail::sha1::digest_type Backend::getHashDigest() const return hash.get_digest(); } //-------------------------------------------------------------------------- -void Backend::genPresynapticUpdate(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, bool trueSpike) const +void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, const PresynapticUpdateGroupMerged &sg, const ModelSpecMerged &modelMerged, bool trueSpike) const { // Get suffix based on type of events const std::string eventSuffix = trueSpike ? "" : "Evnt"; @@ -1702,9 +1695,9 @@ void Backend::genPresynapticUpdate(EnvironmentExternal &env, const ModelSpecMerg const auto &connectInit = sg.getArchetype().getToeplitzConnectivityInitialiser(); // Loop through Toeplitz matrix diagonals - os << "for(unsigned int j = 0; j < group->rowStride; j++)"; + env.getStream() << "for(unsigned int j = 0; j < group->rowStride; j++)"; { - CodeStream::Scope b(os); + /*CodeStream::Scope b(env.getStream()); // Create substitution stack for generating procedural connectivity code Substitutions connSubs(&popSubs); @@ -1804,72 +1797,70 @@ void Backend::genPresynapticUpdate(EnvironmentExternal &env, const ModelSpecMerg if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { os << CodeStream::CB(130); // end if (eCode) } - } + }*/ } } else { // Detect spike events or spikes and do the update - os << "// process presynaptic events: " << (trueSpike ? "True Spikes" : "Spike type events") << std::endl; + env.getStream() << "// process presynaptic events: " << (trueSpike ? 
"True Spikes" : "Spike type events") << std::endl; if(sg.getArchetype().getSrcNeuronGroup()->isDelayRequired()) { - os << "for (unsigned int i = 0; i < group->srcSpkCnt" << eventSuffix << "[preDelaySlot]; i++)"; + env.getStream() << "for (unsigned int i = 0; i < group->srcSpkCnt" << eventSuffix << "[preDelaySlot]; i++)"; } else { - os << "for (unsigned int i = 0; i < group->srcSpkCnt" << eventSuffix << "[0]; i++)"; + env.getStream() << "for (unsigned int i = 0; i < group->srcSpkCnt" << eventSuffix << "[0]; i++)"; } { - CodeStream::Scope b(os); - if(!wu->getSimSupportCode().empty()) { + CodeStream::Scope b(env.getStream()); + /*if(!wu->getSimSupportCode().empty()) { os << "using namespace " << modelMerged.getPresynapticUpdateSupportCodeNamespace(wu->getSimSupportCode()) << ";" << std::endl; - } + }*/ + EnvironmentGroupMergedField groupEnv(env, sg); + const std::string queueOffset = sg.getArchetype().getSrcNeuronGroup()->isDelayRequired() ? "preDelayOffset + " : ""; - os << "const unsigned int ipre = group->srcSpk" << eventSuffix << "[" << queueOffset << "i];" << std::endl; + groupEnv.add(Type::Uint32, "id_pre", "idPre", + {groupEnv.addInitialiser("const unsigned int ipre = group->srcSpk" + eventSuffix + "[" + queueOffset + "i];")}); // If this is a spike-like event, insert threshold check for this presynaptic neuron if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { - os << "if("; - - Substitutions threshSubs(&popSubs); - threshSubs.addVarSubstitution("id_pre", "ipre"); + groupEnv.getStream() << "if("; // Generate weight update threshold condition - sg.generateSpikeEventThreshold(*this, os, modelMerged, threshSubs); - - os << ")"; - os << CodeStream::OB(10); - } - - Substitutions synSubs(&popSubs); - synSubs.addVarSubstitution("id_pre", "ipre"); - synSubs.addVarSubstitution("id_post", "ipost"); - synSubs.addVarSubstitution("id_syn", "synAddress"); + sg.generateSpikeEventThreshold(*this, groupEnv, modelMerged); - if(sg.getArchetype().isDendriticDelayRequired()) { - synSubs.addFuncSubstitution("addToInSynDelay", 2, "group->denDelay[" + sg.getPostDenDelayIndex(1, "ipost", "$(1)") + "] += $(0)"); - } - else { - synSubs.addFuncSubstitution("addToInSyn", 1, "group->inSyn[" + sg.getPostISynIndex(1, "ipost") + "] += $(0)"); + groupEnv.getStream() << ")"; + groupEnv.getStream() << CodeStream::OB(10); } - if(sg.getArchetype().isPresynapticOutputRequired()) { - synSubs.addFuncSubstitution("addToPre", 1, "group->revInSyn[" + sg.getPreISynIndex(1, synSubs["id_pre"]) + "] += $(0)"); - } + // Add correct functions for apply synaptic input + groupEnv.add(Type::AddToPostDenDelay, "addToPostDelay", env["_den_delay"] + "[" + sg.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)", + {}, {"_den_delay"}); + groupEnv.add(Type::AddToPost, "addToPost", env["_out_post"] + "[" + sg.getPostISynIndex(1, "j") + "] += $(0)", + {}, {"_out_post"}); + groupEnv.add(Type::AddToPre, "addToPre", env["_out_pre"] + "[" + sg.getPreISynIndex(1, env["id_pre"]) + "] += $(0)", + {}, {"id_pre"}); + // If connectivity is sparse if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << "const unsigned int npost = group->rowLength[ipre];" << std::endl; - os << "for (unsigned int j = 0; j < npost; j++)"; + groupEnv.getStream() << "const unsigned int npost = group->rowLength[ipre];" << std::endl; + groupEnv.getStream() << "for (unsigned int j = 0; j < npost; j++)"; { - CodeStream::Scope b(os); - - // **TODO** seperate stride from max connection - os << "const unsigned int synAddress = (ipre * 
group->rowStride) + j;" << std::endl; - os << "const unsigned int ipost = group->ind[synAddress];" << std::endl; - + CodeStream::Scope b(groupEnv.getStream()); + EnvironmentGroupMergedField synEnv(groupEnv, sg); + + // **TODO** 64-bit id_syn + synEnv.add(Type::Uint32, "id_syn", "idSyn", + {synEnv.addInitialiser("const unsigned int idSyn = (ipre * " + env["_row_stride"] + ") + j;")}, + {"_row_stride"}); + synEnv.add(Type::Uint32, "id_post", "idPost", + {synEnv.addInitialiser("const unsigned int idPost = " + env["_ind"] + "[idSyn];")}, + {"_ind", "id_syn"}); + if(trueSpike) { - sg.generateSpikeUpdate(*this, os, modelMerged, synSubs); + sg.generateSpikeUpdate(*this, synEnv, modelMerged); } else { - sg.generateSpikeEventUpdate(*this, os, modelMerged, synSubs); + sg.generateSpikeEventUpdate(*this, synEnv, modelMerged); } } } @@ -1878,84 +1869,93 @@ void Backend::genPresynapticUpdate(EnvironmentExternal &env, const ModelSpecMerg } else if(getPreferences().enableBitmaskOptimisations && (sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK)) { // Determine the number of words in each row - os << "const unsigned int rowWords = ((group->numTrgNeurons + 32 - 1) / 32);" << std::endl; - os << "for(unsigned int w = 0; w < rowWords; w++)"; + groupEnv.getStream() << "const unsigned int rowWords = ((" << env["_num_post"] << " + 32 - 1) / 32);" << std::endl; + groupEnv.getStream() << "for(unsigned int w = 0; w < rowWords; w++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(groupEnv.getStream()); // Read row word - os << "uint32_t connectivityWord = group->gp[(ipre * rowWords) + w];" << std::endl; + groupEnv.getStream() << "uint32_t connectivityWord = group->gp[(ipre * rowWords) + w];" << std::endl; // Set ipost to first synapse in connectivity word - os << "unsigned int ipost = w * 32;" << std::endl; + groupEnv.getStream() << "unsigned int ipost = w * 32;" << std::endl; + groupEnv.add(Type::Uint32, "id_post", "ipost"); // While there are any bits left - os << "while(connectivityWord != 0)"; + groupEnv.getStream() << "while(connectivityWord != 0)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(groupEnv.getStream()); // Count leading zeros (as bits are indexed backwards this is index of next synapse) - os << "const int numLZ = gennCLZ(connectivityWord);" << std::endl;
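// Aside: a self-contained sketch of the bit-iteration idiom the generated code
// above implements, assuming 32-bit connectivity words; gennCLZ is GeNN's
// count-leading-zeros primitive, stood in for here by GCC/Clang's
// __builtin_clz purely for illustration (undefined for a zero word, which the
// loop condition guards against):
#include <cstdint>
template<typename Visit>
void forEachSetBit(uint32_t connectivityWord, unsigned int wordStart, Visit visit)
{
    unsigned int ipost = wordStart;
    while(connectivityWord != 0) {
        // Leading zeros give the distance to the next set bit
        // (bits are indexed backwards, from the most-significant end)
        const int numLZ = __builtin_clz(connectivityWord);
        // Shift off the zeros and the bit just found
        // (<< 32 would be undefined behaviour, hence the special case)
        connectivityWord = (numLZ == 31) ? 0 : (connectivityWord << (numLZ + 1));
        // Advance to the set bit and visit the corresponding synapse
        ipost += numLZ;
        visit(ipost);
        // The next CLZ is measured from the bit AFTER this synapse
        ipost++;
    }
}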
// Shift off zeros and the one just discovered // **NOTE** << 32 appears to result in undefined behaviour - os << "connectivityWord = (numLZ == 31) ? 0 : (connectivityWord << (numLZ + 1));" << std::endl; + groupEnv.getStream() << "connectivityWord = (numLZ == 31) ? 0 : (connectivityWord << (numLZ + 1));" << std::endl; // Add to ipost - os << "ipost += numLZ;" << std::endl; + groupEnv.getStream() << "ipost += numLZ;" << std::endl; // If we aren't in padding region // **TODO** don't bother checking if there is no padding - os << "if(ipost < group->numTrgNeurons)"; + groupEnv.getStream() << "if(ipost < group->numTrgNeurons)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); if(trueSpike) { - sg.generateSpikeUpdate(*this, os, modelMerged, synSubs); + sg.generateSpikeUpdate(*this, groupEnv, modelMerged); } else { - sg.generateSpikeEventUpdate(*this, os, modelMerged, synSubs); + sg.generateSpikeEventUpdate(*this, groupEnv, modelMerged); } } // Increment ipost to take into account the fact that the next CLZ will go from bit AFTER synapse - os << "ipost++;" << std::endl; + groupEnv.getStream() << "ipost++;" << std::endl; } } } // Otherwise (DENSE or BITMASK) else { - os << "for (unsigned int ipost = 0; ipost < group->numTrgNeurons; ipost++)"; + groupEnv.getStream() << "for (unsigned int ipost = 0; ipost < group->numTrgNeurons; ipost++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(groupEnv.getStream()); + EnvironmentGroupMergedField synEnv(groupEnv, sg); + synEnv.add(Type::Uint32, "id_post", "ipost"); if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { - os << "const uint64_t gid = (ipre * (uint64_t)group->numTrgNeurons + ipost);" << std::endl; - os << "if (B(group->gp[gid / 32], gid & 31))" << CodeStream::OB(20); - } + // **TODO** 64-bit index + synEnv.getStream() << "const uint64_t gid = (ipre * group->numTrgNeurons + ipost);" << std::endl; - os << "const unsigned int synAddress = (ipre * group->numTrgNeurons) + ipost;" << std::endl; + synEnv.getStream() << "if (B(group->gp[gid / 32], gid & 31))" << CodeStream::OB(20); + } + else { + synEnv.add(Type::Uint32, "id_syn", "idSyn", + {synEnv.addInitialiser("const unsigned int idSyn = (ipre * " + synEnv["num_post"] + ") + ipost;")}, + {"num_post"}); + } + if(trueSpike) { - sg.generateSpikeUpdate(*this, os, modelMerged, synSubs); + sg.generateSpikeUpdate(*this, synEnv, modelMerged); } else { - sg.generateSpikeEventUpdate(*this, os, modelMerged, synSubs); + sg.generateSpikeEventUpdate(*this, synEnv, modelMerged); } if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { - os << CodeStream::CB(20); + synEnv.getStream() << CodeStream::CB(20); } } } // If this is a spike-like event, close braces around threshold check if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { - os << CodeStream::CB(10); + groupEnv.getStream() << CodeStream::CB(10); } } } } //-------------------------------------------------------------------------- -void Backend::genEmitSpike(EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, bool trueSpike, bool recordingEnabled) const +void Backend::genEmitSpike(EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, bool trueSpike, bool recordingEnabled) const { // Determine if delay is required and thus, at what offset we should write into the spike queue const bool spikeDelayRequired = trueSpike ? 
(ng.getArchetype().isDelayRequired() && ng.getArchetype().isTrueSpikeRequired()) : ng.getArchetype().isDelayRequired(); diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index 1184370f0f..986306d3b4 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -10,6 +10,7 @@ #include "code_generator/backendBase.h" #include "code_generator/codeGenUtils.h" #include "code_generator/codeStream.h" +#include "code_generator/environment.h" using namespace GeNN; using namespace GeNN::CodeGenerator; @@ -36,30 +37,30 @@ NeuronSpikeQueueUpdateGroupMerged::NeuronSpikeQueueUpdateGroupMerged(size_t inde } } //---------------------------------------------------------------------------- -void NeuronSpikeQueueUpdateGroupMerged::genMergedGroupSpikeCountReset(CodeStream &os, unsigned int batchSize) const +void NeuronSpikeQueueUpdateGroupMerged::genMergedGroupSpikeCountReset(EnvironmentExternalBase &env, unsigned int batchSize) const { if(getArchetype().isSpikeEventRequired()) { if(getArchetype().isDelayRequired()) { - os << "group->spkCntEvnt[*group->spkQuePtr"; + env.getStream() << env["_spk_cnt_evnt"] << "[*" << env["_spk_que_ptr"]; if(batchSize > 1) { - os << " + (batch * " << getArchetype().getNumDelaySlots() << ")"; + env.getStream() << " + (batch * " << getArchetype().getNumDelaySlots() << ")"; } - os << "] = 0; " << std::endl; + env.getStream() << "] = 0; " << std::endl; } else { - os << "group->spkCntEvnt[" << ((batchSize > 1) ? "batch" : "0") << "] = 0;" << std::endl; + env.getStream() << env["_spk_cnt_evnt"] << "[" << ((batchSize > 1) ? "batch" : "0") << "] = 0;" << std::endl; } } if(getArchetype().isTrueSpikeRequired() && getArchetype().isDelayRequired()) { - os << "group->spkCnt[*group->spkQuePtr"; + env.getStream() << env["_spk_cnt"] << "[*" << env["_spk_que_ptr"]; if(batchSize > 1) { - os << " + (batch * " << getArchetype().getNumDelaySlots() << ")"; + env.getStream() << " + (batch * " << getArchetype().getNumDelaySlots() << ")"; } - os << "] = 0; " << std::endl; + env.getStream() << "] = 0; " << std::endl; } else { - os << "group->spkCnt[" << ((batchSize > 1) ? "batch" : "0") << "] = 0;" << std::endl; + env.getStream() << env["_spk_cnt"] << "[" << ((batchSize > 1) ? 
"batch" : "0") << "] = 0;" << std::endl; } } diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 56c79cc2cd..d5d9d9fe5a 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -1,6 +1,7 @@ #include "code_generator/neuronUpdateGroupMerged.h" // GeNN code generator includes +#include "code_generator/standardLibrary.h" #include "code_generator/groupMergedTypeEnvironment.h" #include "code_generator/modelSpecMerged.h" @@ -9,7 +10,6 @@ #include "transpiler/parser.h" #include "transpiler/prettyPrinter.h" #include "transpiler/scanner.h" -#include "transpiler/standardLibrary.h" #include "transpiler/typeChecker.h" using namespace GeNN; @@ -27,7 +27,7 @@ NeuronUpdateGroupMerged::CurrentSource::CurrentSource(size_t index, const Type:: const BackendBase &backend, const std::vector> &groups) : GroupMerged(index, typeContext, groups) { - const std::string suffix = "CS" + std::to_string(getIndex()); + /*const std::string suffix = "CS" + std::to_string(getIndex()); // Create type environment GroupMergedTypeEnvironment typeEnvironment(*this, &enclosingEnv); @@ -47,15 +47,11 @@ NeuronUpdateGroupMerged::CurrentSource::CurrentSource(size_t index, const Type:: typeEnvironment.defineVars(cm->getVars(), backend.getDeviceVarPrefix(), suffix); // Add EGPs - typeEnvironment.defineEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", suffix); + typeEnvironment.defineEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", suffix);*/ - // Scan, parse and type-check injection code - ErrorHandler errorHandler; - std::tie(m_InjectionStatements, m_InjectionResolvedTypes) = scanParseAndTypeCheckStatements(cm->getInjectionCode(), typeContext, - typeEnvironment, errorHandler); } //---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend, EnvironmentExternal &env, +void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend, EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) const { const std::string suffix = "CS" + std::to_string(getIndex()); @@ -112,7 +108,7 @@ NeuronUpdateGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContex const std::string suffix = "InSyn" + std::to_string(getIndex()); // Create type environment - GroupMergedTypeEnvironment typeEnvironment(*this, &enclosingEnv); + /*GroupMergedTypeEnvironment typeEnvironment(*this, &enclosingEnv); // Add pointer to insyn addField(getScalarType().createPointer(), "inSyn" + suffix, @@ -151,10 +147,10 @@ NeuronUpdateGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContex std::tie(m_DecayStatements, m_DecayResolvedTypes) = scanParseAndTypeCheckStatements(psm->getDecayCode(), typeContext, typeEnvironment, errorHandler); std::tie(m_ApplyInputStatements, m_ApplyInputResolvedTypes) = scanParseAndTypeCheckStatements(psm->getApplyInputCode(), typeContext, - typeEnvironment, errorHandler); + typeEnvironment, errorHandler);*/ } //---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, EnvironmentExternal &env, +void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) const { const 
std::string suffix = "InSyn" + std::to_string(getIndex()); @@ -528,38 +524,8 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC const std::vector> &groups) : NeuronGroupMergedBase(index, typeContext, backend, groups) { - using namespace Type; - - // Create type environment - StandardLibrary::FunctionTypes stdLibraryEnv; - GroupMergedTypeEnvironment typeEnvironment(*this, &stdLibraryEnv); - - // Add RNG - if(backend.isPopulationRNGRequired() && getArchetype().isSimRNGRequired()) { - // **TODO** inject RNG types into environment - - addPointerField(*backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); - } - - // Add heterogeneous neuron model parameters - const auto *nm = getArchetype().getNeuronModel(); - typeEnvironment.defineHeterogeneousParams(nm->getParamNames(), "", - &NeuronGroupInternal::getParams, - &NeuronUpdateGroupMerged::isParamHeterogeneous); - - // Add heterogeneous weight update model derived parameters - typeEnvironment.defineHeterogeneousDerivedParams(nm->getDerivedParams(), "", - &NeuronGroupInternal::getDerivedParams, - &NeuronUpdateGroupMerged::isDerivedParamHeterogeneous); - - // Add variables - typeEnvironment.defineVars(nm->getVars(), backend.getDeviceVarPrefix()); - - // Add EGPs - typeEnvironment.defineEGPs(nm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); - // Loop through neuron groups - std::vector> eventThresholdSGs; + /*std::vector> eventThresholdSGs; for(const auto &g : getGroups()) { // Reserve vector for this group's children eventThresholdSGs.emplace_back(); @@ -608,36 +574,7 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC } i++; } - } - - if(getArchetype().isSpikeRecordingEnabled()) { - // Add field for spike recording - addField(Uint32.createPointer(), "recordSpk", - [&backend](const auto &ng, size_t) - { - return backend.getDeviceVarPrefix() + "recordSpk" + ng.getName(); - }, - GroupMergedFieldType::DYNAMIC); - } - - if(getArchetype().isSpikeEventRecordingEnabled()) { - // Add field for spike event recording - addField(Uint32.createPointer(), "recordSpkEvent", - [&backend](const auto &ng, size_t) - { - return backend.getDeviceVarPrefix() + "recordSpkEvent" + ng.getName(); - }, - GroupMergedFieldType::DYNAMIC); - } - - // Parse code - ErrorHandler errorHandler; - std::tie(m_SimStatements, m_SimResolvedTypes) = scanParseAndTypeCheckStatements( - nm->getSimCode(), typeContext, typeEnvironment, errorHandler); - std::tie(m_ThresholdConditionExpression, m_ThresholdConditionResolvedTypes) = scanParseAndTypeCheckExpression( - nm->getThresholdConditionCode(), typeContext, typeEnvironment, errorHandler); - std::tie(m_ResetStatements, m_ResetResolvedTypes) = scanParseAndTypeCheckStatements( - nm->getResetCode(), typeContext, typeEnvironment, errorHandler); + }*/ // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group orderNeuronGroupChildren(m_MergedInSynPSMGroups, typeContext, typeEnvironment, backend, @@ -694,7 +631,7 @@ boost::uuids::detail::sha1::digest_type NeuronUpdateGroupMerged::getHashDigest() return hash.get_digest(); } //-------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged, +void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, 
BackendBase::GroupHandlerEnv genEmitTrueSpike, BackendBase::GroupHandlerEnv genEmitSpikeLikeEvent) const { @@ -702,44 +639,50 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E const unsigned int batchSize = model.getBatchSize(); const NeuronModels::Base *nm = getArchetype().getNeuronModel(); - - EnvironmentSubstitute neuronEnv(env); - neuronEnv.addSubstitution("Isyn", "Isyn", - {neuronEnv.addInitialiser("scalar Isyn = 0;")}); + EnvironmentGroupMergedField neuronEnv(env, *this); + + // Add field for spike recording + neuronEnv.add(Type::Uint32.createPointer(), "_record_spk", "recordSpk", + [&backend](const auto &ng, size_t) + { + return backend.getDeviceVarPrefix() + "recordSpk" + ng.getName(); + }, + "", GroupMergedFieldType::DYNAMIC); + + // Add field for spike event recording + neuronEnv.add(Type::Uint32.createPointer(), "_record_spk_event", "recordSpkEvent", + [&backend](const auto &ng, size_t) + { + return backend.getDeviceVarPrefix() + "recordSpkEvent" + ng.getName(); + }, + "", GroupMergedFieldType::DYNAMIC); + + // Add default input variable + neuronEnv.add(modelMerged.getModel().getPrecision(), "Isyn", "Isyn", + {neuronEnv.addInitialiser("scalar Isyn = 0;")}); // **NOTE** arbitrary code in param value to be deprecated for (const auto &v : nm->getAdditionalInputVars()) { - const std::string typeName = v.type.resolve(getTypeContext()).getName(); - neuronEnv.addSubstitution(v.name, v.value, - {neuronEnv.addInitialiser(typeName + " " + v.name + " = " + v.value + ";")}); + const auto resolvedType = v.type.resolve(getTypeContext()); + neuronEnv.add(resolvedType, v.name, v.name, + {neuronEnv.addInitialiser(resolvedType.getName() + " " + v.name + " = " + v.value + ";")}); } - neuronEnv.addParamValueSubstitution(nm->getParamNames(), getArchetype().getParams(), "", - [this](const std::string &p) { return isParamHeterogeneous(p); }); - neuronEnv.addVarValueSubstitution(nm->getDerivedParams(), getArchetype().getDerivedParams(), "", - [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); - neuronEnv.addVarNameSubstitution(nm->getExtraGlobalParams()); + // Substitute parameter and derived parameter names + neuronEnv.addParams(nm->getParamNames(), "", &NeuronGroupInternal::getParams, &NeuronUpdateGroupMerged::isParamHeterogeneous); + neuronEnv.addDerivedParams(nm->getDerivedParams(), "", &NeuronGroupInternal::getDerivedParams, &NeuronUpdateGroupMerged::isDerivedParamHeterogeneous); + neuronEnv.addEGPs(backend.getDeviceVarPrefix()); - if(getArchetype().isSpikeTimeRequired()) { - const size_t initialiser = neuronEnv.addInitialiser( - "const timepoint lsT = group->sT[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, neuronEnv["id"]) + "];"); - neuronEnv.addSubstitution("sT", "lsT", {initialiser}); - } - if(getArchetype().isPrevSpikeTimeRequired()) { - const size_t initialiser = neuronEnv.addInitialiser( - "const timepoint lprevST = group->prevST[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, neuronEnv["id"]) + "];"); - neuronEnv.addSubstitution("prev_sT", "lprevST", {initialiser}); - } - if(getArchetype().isSpikeEventTimeRequired()) { - const size_t initialiser = neuronEnv.addInitialiser( - "const timepoint lseT = group->seT[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, neuronEnv["id"]) + "];"); - neuronEnv.addSubstitution("seT", "lseT", {initialiser}); - } - 
if(getArchetype().isPrevSpikeEventTimeRequired()) { - const size_t initialiser = neuronEnv.addInitialiser( - "const timepoint lprevSET = group->prevSET[" + getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, neuronEnv["id"]) + "];"); - neuronEnv.addSubstitution("prev_seT", "lprevSET", {initialiser}); - } + // Substitute spike times + const std::string spikeTimeReadIndex = getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, neuronEnv["id"]); + neuronEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "sT", "lsT", + {neuronEnv.addInitialiser("const timepoint lsT = " + neuronEnv["_spk_time"] + "[" + spikeTimeReadIndex + "];")}); + neuronEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "prev_sT", "lprevST", + {neuronEnv.addInitialiser("const timepoint lprevST = " + neuronEnv["_prev_spk_time"] + "[" + spikeTimeReadIndex + "];")}); + neuronEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "seT", "lseT", + {neuronEnv.addInitialiser("const timepoint lseT = " + neuronEnv["_spk_evnt_time"] + "[" + spikeTimeReadIndex+ "];")}); + neuronEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "prev_seT", "lprevSET", + {neuronEnv.addInitialiser("const timepoint lprevSET = " + neuronEnv["_prev_spk_evnt_time"] + "[" + spikeTimeReadIndex + "];")}); // Create an environment which caches variables in local variables if they are accessed // **NOTE** we do this right at the top so that local copies can be used by child groups @@ -760,25 +703,25 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Loop through incoming synapse groups for(const auto &sg : getMergedInSynPSMGroups()) { CodeStream::Scope b(env.getStream()); - sg.generate(backend, env, *this, modelMerged); + sg.generate(backend, neuronVarEnv, *this, modelMerged); } // Loop through outgoing synapse groups with presynaptic output for (const auto &sg : getMergedOutSynPreOutputGroups()) { CodeStream::Scope b(env.getStream()); - sg.generate(env, *this, modelMerged); + sg.generate(neuronVarEnv, *this, modelMerged); } // Loop through all of neuron group's current sources for (const auto &cs : getMergedCurrentSourceGroups()) { CodeStream::Scope b(env.getStream()); - cs.generate(backend, env, *this, modelMerged); + cs.generate(backend, neuronVarEnv, *this, modelMerged); } // If a threshold condition is provided - if (m_ThresholdConditionExpression) { + if (!nm->getThresholdConditionCode().empty()) { neuronVarEnv.getStream() << "// test whether spike condition was fulfilled previously" << std::endl; //if (!nm->getSupportCode().empty() && !backend.supportsNamespace()) { @@ -964,7 +907,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E } } //-------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged) const +void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const { // Generate var update for outgoing synaptic populations with presynaptic update code for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index b8b9c0de71..b4d719d5bc 100644 --- 
a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -53,8 +53,8 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa kernelIndexStream << ";" << std::endl; // Add substitution - env.add(Type::Uint32, "id_kernel", "kernelInd", - {synEnv.addInitialiser(kernelIndexStream.str())}); + synEnv.add(Type::Uint32, "id_kernel", "kernelInd", + {synEnv.addInitialiser(kernelIndexStream.str())}); } // If weights are individual, substitute variables for values stored in global memory @@ -109,7 +109,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa assert(!sg.getArchetype().getKernelSize().empty()); synEnv.addVars(backend.getDeviceVarPrefix(), - [&sg, &synEnv, batchSize](VarAccess a, const std::string&) + [&synEnv, batchSize](VarAccess a, const std::string&) { return "[" + sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), synapseSubs["id_kernel"]) + "]"; }, From adf3d96789725cc44158414cb847c682583bbc79 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 15 Jun 2023 18:12:10 +0100 Subject: [PATCH 223/725] WIP upgrading of ``EnvironmentFieldPolicy`` to handle case where fields live in a different sort of group than the group itself --- .../genn/genn/code_generator/backendBase.h | 2 +- .../genn/genn/code_generator/environment.h | 32 +++++++++++++------ .../code_generator/neuronUpdateGroupMerged.cc | 1 + 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 8416eea854..8be1874822 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -35,7 +35,7 @@ class SynapseGroupInternal; namespace CodeGenerator { -template +template class EnvironmentGroupMergedField; class EnvironmentExternalBase; class ModelSpecMerged; diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 98ad8cb3da..270a179e36 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -137,12 +137,19 @@ struct EnvironmentSubstitutionPolicy //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentFieldPolicy //---------------------------------------------------------------------------- -template +template struct EnvironmentFieldPolicy { - using Payload = std::tuple>; + using Payload = std::tuple>; + using GetGroupsFn = const std::vector> &(F::*)() const; - EnvironmentFieldPolicy(G &group) : m_Group(group) + EnvironmentFieldPolicy(G &group, F &fieldGroup, GetGroupsFn getGroups) + : m_Group(group), m_FieldGroup(fieldGroup), m_GetGroups(getGroups) + { + } + + // **TODO** only enable if G == F + EnvironmentFieldPolicy(G &group) : EnvironmentFieldPolicy(group, group, &G::getGroups) { } @@ -165,16 +172,22 @@ struct EnvironmentFieldPolicy if (std::get<2>(payload) && !std::get<0>(payload)) { // Call function to add field to underlying merged group const auto &field = std::get<2>(payload).get(); - m_Group.addField(std::get<0>(field), std::get<1>(field), - std::get<2>(field), std::get<3>(field)); + m_FieldGroup.addField(std::get<0>(field), std::get<1>(field), + [this](const F &, size_t i) + { + const auto &childGroups = std::invoke(m_FieldGroup, m_GetGroups).at(m_Group.getIndex()); + return .get(); + }, + std::get<3>(field)); // Set flag so field doesn't get re-added 
std::get<0>(payload) = true; } } -private: + std::reference_wrapper m_FieldGroup; std::reference_wrapper m_Group; + GetGroupsFn m_GetGroups; }; //---------------------------------------------------------------------------- @@ -314,8 +327,8 @@ class EnvironmentExternal : public EnvironmentExternalDynamicBase -class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase> +template +class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase> { using GroupInternal = typename G::GroupInternal; using IsHeterogeneousFn = bool (G::*)(const std::string&) const; @@ -405,7 +418,8 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase synEnv(env, *this, ng, &NeuronUpdateGroupMerged::getMergedCurrentSourceGroups); // Create new substitution environment and add parameters, derived parameters and extra global parameters EnvironmentSubstitute envSubs(env); envSubs.getStream() << "// current source " << getIndex() << std::endl; From 1978bbe6ad9facfae236825c973f1d8af6709b41 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 16 Jun 2023 11:07:10 +0100 Subject: [PATCH 224/725] fixed up enough typos that SynapseUpdateGroupMerged compiles! --- .../genn/genn/code_generator/codeGenUtils.h | 15 ++- .../code_generator/customUpdateGroupMerged.h | 4 +- .../genn/genn/code_generator/environment.h | 112 +++++++++--------- .../genn/genn/code_generator/groupMerged.h | 6 +- .../genn/code_generator/initGroupMerged.h | 4 +- .../code_generator/synapseUpdateGroupMerged.h | 14 +-- include/genn/genn/synapseGroupInternal.h | 1 + .../synapseUpdateGroupMerged.cc | 49 ++++---- 8 files changed, 105 insertions(+), 100 deletions(-) diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 701b0789db..2dba711f78 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -211,23 +211,26 @@ std::string getKernelSize(const G *group, size_t dimensionIndex, K getKernelSize } template -void genKernelIndex(const G *group, std::ostream &os, const CodeGenerator::Substitutions &subs, - K getKernelSizeFn) +std::string getKernelIndex(const G *group, EnvironmentExternalBase &env, + K getKernelSizeFn) { // Loop through kernel dimensions to calculate array index const auto &kernelSize = getKernelSizeFn(group->getArchetype()); + std::ostringstream kernelIndex; for (size_t i = 0; i < kernelSize.size(); i++) { - os << "(" << subs["id_kernel_" + std::to_string(i)]; + kernelIndex << "(" << env["id_kernel_" + std::to_string(i)]; // Loop through remainining dimensions of kernel and multiply for (size_t j = i + 1; j < kernelSize.size(); j++) { - os << " * " << getKernelSize(group, j, getKernelSizeFn); + kernelIndex << " * " << getKernelSize(group, j, getKernelSizeFn); } - os << ")"; + kernelIndex << ")"; // If this isn't the last dimension, add + if (i != (kernelSize.size() - 1)) { - os << " + "; + kernelIndex << " + "; } } + + return kernelIndex.str(); } } // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index c96fcd3825..eca80747e2 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -88,9 +88,9 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged -struct EnvironmentFieldPolicy +class EnvironmentFieldPolicy { +protected: using Payload = std::tuple>; - using 
GetGroupsFn = const std::vector> &(F::*)() const; - - EnvironmentFieldPolicy(G &group, F &fieldGroup, GetGroupsFn getGroups) - : m_Group(group), m_FieldGroup(fieldGroup), m_GetGroups(getGroups) + + EnvironmentFieldPolicy(G &group, F &fieldGroup) + : m_Group(group), m_FieldGroup(fieldGroup) { } // **TODO** only enable if G == F - EnvironmentFieldPolicy(G &group) : EnvironmentFieldPolicy(group, group, &G::getGroups) + EnvironmentFieldPolicy(G &group) : EnvironmentFieldPolicy(group, group) { } - std::string getName(const Payload &payload) + std::string getNameInternal(const Payload &payload) { // If a field is specified if(std::get<2>(payload)) { - return "group->" + std::get<1>(std::get<2>(payload).get()) + std::get<1>(payload); + return "group->" + std::get<1>(std::get<2>(payload).value()) + std::get<1>(payload); } // Otherwise, use value directly else { @@ -170,31 +171,34 @@ struct EnvironmentFieldPolicy { // If a field is specified but it hasn't already been added if (std::get<2>(payload) && !std::get<0>(payload)) { - // Call function to add field to underlying merged group - const auto &field = std::get<2>(payload).get(); - m_FieldGroup.addField(std::get<0>(field), std::get<1>(field), - [this](const F &, size_t i) - { - const auto &childGroups = std::invoke(m_FieldGroup, m_GetGroups).at(m_Group.getIndex()); - return .get(); - }, - std::get<3>(field)); + // Extract field from payload + const auto &field = std::get<2>(payload).value(); + + // Add to field group using lambda function to potentially map from group to field + m_FieldGroup.get().addField(std::get<0>(field), std::get<1>(field), + [this, &field](const typename F::GroupInternal &, size_t i) + { + return std::get<2>(field)(getGroup().getGroups().at(i), i); + }, + std::get<3>(field)); // Set flag so field doesn't get re-added std::get<0>(payload) = true; } } + const G &getGroup() const{ return m_Group; } + +private: std::reference_wrapper m_FieldGroup; std::reference_wrapper m_Group; - GetGroupsFn m_GetGroups; }; //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentExternalDynamicBase //---------------------------------------------------------------------------- template -class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, private P +class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, protected P { public: template @@ -233,7 +237,7 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, private P } // Otherwise, get name from payload else { - return getName(std::get<3>(env->second)); + return getNameInternal(std::get<3>(env->second)); } } @@ -360,7 +364,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &initialisers = {}, const std::vector &dependents = {}) { - addInternal(type, name, std::make_tuple(false, indexSuffix, std::forward_as_tuple(std::in_place, fieldType, fieldName, getFieldValue, mergedFieldType)), + addInternal(type, name, std::make_tuple(false, indexSuffix, std::make_optional(std::make_tuple(fieldType, fieldName, getFieldValue, mergedFieldType))), initialisers, dependents); } @@ -374,8 +378,8 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &dependents = {}) { // Loop through variables - const A archetypeAdaptor(m_Group.getArchetype()); + const A archetypeAdaptor(getGroup().getArchetype()); for(const auto &v : archetypeAdaptor.getDefs()) { - const auto resolvedType = v.type.resolve(m_Group.getTypeContext()) + const auto resolvedType = 
v.type.resolve(getGroup().getTypeContext()); const auto qualifiedType = (getVarAccessMode(v.access) & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; add(qualifiedType, v.name, resolvedType.createPointer(), v.name + fieldSuffix, [arrayPrefix, getIndexFn, v](const auto &g, size_t) { - return prefix + v.name + A(g).getNameSuffix(); + return arrayPrefix + v.name + A(g).getNameSuffix(); }, getIndexFn(v.access, v.name), GroupMergedFieldType::STANDARD, {}, dependents); } @@ -504,10 +508,10 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &dependents = {}) { // Loop through variable references - const A archetypeAdaptor(m_Group.getArchetype()); + const A archetypeAdaptor(getGroup().getArchetype()); for(const auto &v : archetypeAdaptor.getDefs()) { // If variable access is read-only, qualify type with const - const auto resolvedType = v.type.resolve(m_Group.getTypeContext()); + const auto resolvedType = v.type.resolve(getGroup().getTypeContext()); const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; add(qualifiedType, v.name, resolvedType.createPointer(), v.name + fieldSuffix, @@ -532,16 +536,16 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase m_Group; - //! Environment mapping names to types to fields to pull values from std::unordered_map>> m_Environment; }; diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index e9f98bec5d..e33773e8de 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -29,7 +29,7 @@ namespace GeNN::CodeGenerator class CodeStream; } -namespace Transpiler::TypeChecker +namespace GeNN::Transpiler::TypeChecker { class EnvironmentBase; } @@ -752,9 +752,9 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged &getInitialisers() const{ return m_SG.getWUVarInitialisers(); } + const std::string &getNameSuffix() const{ return m_SG.getName(); } private: //---------------------------------------------------------------------------- // Members diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index b4d719d5bc..ca8485f7ef 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -16,13 +16,13 @@ namespace { template void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBase &env, std::string code, const std::string &errorContext, - const G &sg, const ModelSpecMerged &modelMerged, bool backendSupportsNamespace) + G &sg, const ModelSpecMerged &modelMerged, bool backendSupportsNamespace) { const ModelSpecInternal &model = modelMerged.getModel(); const unsigned int batchSize = model.getBatchSize(); const auto *wu = sg.getArchetype().getWUModel(); - EnvironmentGroupMergedField synEnv(sg, env); + EnvironmentGroupMergedField synEnv(env, sg); // Substitute parameter and derived parameter names synEnv.addParams(wu->getParamNames(), "", &SynapseGroupInternal::getWUParams, &G::isWUParamHeterogeneous); @@ -46,15 +46,10 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa // If this synapse group has a kernel if (!sg.getArchetype().getKernelSize().empty()) { - // Generate kernel index calculation - std::ostringstream kernelIndexStream; - kernelIndexStream << "const unsigned int kernelInd = "; - 
sg.genKernelIndex(kernelIndexStream, synEnv); - kernelIndexStream << ";" << std::endl; - // Add substitution + // **TODO** dependencies on kernel fields synEnv.add(Type::Uint32, "id_kernel", "kernelInd", - {synEnv.addInitialiser(kernelIndexStream.str())}); + {synEnv.addInitialiser("const unsigned int kernelInd = " + sg.getKernelIndex(synEnv) + ";")}); } // If weights are individual, substitute variables for values stored in global memory @@ -68,12 +63,14 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa } // Otherwise, if weights are procedural else if (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::PROCEDURAL) { - const auto vars = wu->getVars(); + assert(false); + /*const auto vars = wu->getVars(); for(const auto &var : vars) { const auto &varInit = sg.getArchetype().getWUVarInitialisers().at(var.name); - + // If this variable has any initialisation code if(!varInit.getSnippet()->getCode().empty()) { + // Configure variable substitutions CodeGenerator::Substitutions varSubs(&synapseSubs); varSubs.addVarSubstitution("value", "l" + var.name); @@ -102,24 +99,25 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa } // Substitute variables for newly-declared local variables - synEnv.add(vars, "", "l"); + synEnv.add(vars, "", "l");*/ } // Otherwise, if weights are kernels, use kernel index to index into variables else if(sg.getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL) { assert(!sg.getArchetype().getKernelSize().empty()); synEnv.addVars(backend.getDeviceVarPrefix(), - [&synEnv, batchSize](VarAccess a, const std::string&) + [&sg, &synEnv, batchSize](VarAccess a, const std::string&) { - return "[" + sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), synapseSubs["id_kernel"]) + "]"; + return "[" + sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), synEnv["id_kernel"]) + "]"; }, {"id_kernel"}); } // Otherwise, substitute variables for constant values else { - synapseSubs.addVarValueSubstitution(wu->getVars(), sg.getArchetype().getWUConstInitVals(), + assert(false); + /*synapseSubs.addVarValueSubstitution(wu->getVars(), sg.getArchetype().getWUConstInitVals(), [&sg](const std::string &v) { return sg.isWUGlobalVarHeterogeneous(v); }, - "", "group->"); + "", "group->");*/ } // Make presynaptic neuron substitutions @@ -177,7 +175,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa //---------------------------------------------------------------------------- const std::string PresynapticUpdateGroupMerged::name = "PresynapticUpdate"; //---------------------------------------------------------------------------- -void PresynapticUpdateGroupMerged::generateSpikeEventThreshold(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const +void PresynapticUpdateGroupMerged::generateSpikeEventThreshold(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { EnvironmentGroupMergedField synEnv(env, *this); @@ -206,29 +204,30 @@ void PresynapticUpdateGroupMerged::generateSpikeEventThreshold(const BackendBase prettyPrintStatements(wum->getEventThresholdConditionCode(), getTypeContext(), synEnv, errorHandler); } //---------------------------------------------------------------------------- -void PresynapticUpdateGroupMerged::generateSpikeEventUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const +void
PresynapticUpdateGroupMerged::generateSpikeEventUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { applySynapseSubstitutions(backend, env, getArchetype().getWUModel()->getEventCode(), "eventCode", *this, modelMerged, backend.supportsNamespace()); } //---------------------------------------------------------------------------- -void PresynapticUpdateGroupMerged::generateSpikeUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const +void PresynapticUpdateGroupMerged::generateSpikeUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { applySynapseSubstitutions(backend, env, getArchetype().getWUModel()->getSimCode(), "simCode", *this, modelMerged, backend.supportsNamespace()); } //---------------------------------------------------------------------------- -void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendBase&, EnvironmentExternalBase &env) const +void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendBase&, EnvironmentExternalBase &env) { const auto &connectInit = getArchetype().getConnectivityInitialiser(); EnvironmentGroupMergedField synEnv(env, *this); + assert(false); // Add substitutions //synEnv.addParams() //synEnv.addParams(wu->getParamNames(), "", &SynapseGroupInternal::getWUParams, &G::isWUParamHeterogeneous); //synEnv.addDerivedParams(wu->getDerivedParams(), "", &SynapseGroupInternal::getWUDerivedParams, &G::isWUDerivedParamHeterogeneous); - popSubs.addParamValueSubstitution(connectInit.getSnippet()->getParamNames(), connectInit.getParams(), + /*popSubs.addParamValueSubstitution(connectInit.getSnippet()->getParamNames(), connectInit.getParams(), [this](const std::string &p) { return isSparseConnectivityInitParamHeterogeneous(p); }, "", "group->"); popSubs.addVarValueSubstitution(connectInit.getSnippet()->getDerivedParams(), connectInit.getDerivedParams(), @@ -244,10 +243,10 @@ void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendB //pCode = ensureFtype(pCode, modelMerged.getModel().getPrecision()); // Write out code - os << pCode << std::endl; + os << pCode << std::endl;*/ } //---------------------------------------------------------------------------- -void PresynapticUpdateGroupMerged::generateToeplitzConnectivity(const BackendBase&, EnvironmentExternalBase &env) const +void PresynapticUpdateGroupMerged::generateToeplitzConnectivity(const BackendBase&, EnvironmentExternalBase &env) { // Pretty print code back to environment Transpiler::ErrorHandler errorHandler("toeplitzSparseConnectivity" + std::to_string(getIndex())); @@ -260,7 +259,7 @@ void PresynapticUpdateGroupMerged::generateToeplitzConnectivity(const BackendBas //---------------------------------------------------------------------------- const std::string PostsynapticUpdateGroupMerged::name = "PostsynapticUpdate"; //---------------------------------------------------------------------------- -void PostsynapticUpdateGroupMerged::generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const +void PostsynapticUpdateGroupMerged::generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { const auto *wum = getArchetype().getWUModel(); /*if (!wum->getLearnPostSupportCode().empty() && backend.supportsNamespace()) { @@ -276,7 +275,7 @@ void 
PostsynapticUpdateGroupMerged::generateSynapseUpdate(const BackendBase &bac //---------------------------------------------------------------------------- const std::string SynapseDynamicsGroupMerged::name = "SynapseDynamics"; //---------------------------------------------------------------------------- -void SynapseDynamicsGroupMerged::generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const +void SynapseDynamicsGroupMerged::generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { const auto *wum = getArchetype().getWUModel(); /*if (!wum->getSynapseDynamicsSuppportCode().empty() && backend.supportsNamespace()) { From b268db15f1fa0770544dd7429b0ee0cc23c38d8a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 16 Jun 2023 11:16:52 +0100 Subject: [PATCH 225/725] ChildGroupMerged - very minimal version with no fields --- .../genn/genn/code_generator/groupMerged.h | 202 ++++++++++-------- 1 file changed, 116 insertions(+), 86 deletions(-) diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index e33773e8de..d2e5785c7b 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -59,24 +59,135 @@ inline bool operator & (GroupMergedFieldType typeA, GroupMergedFieldType typeB) return (static_cast(typeA) & static_cast(typeB)) != 0; } +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::ChildGroupMerged +//---------------------------------------------------------------------------- +template +class ChildGroupMerged +{ +public: + //------------------------------------------------------------------------ + // Typedefines + //------------------------------------------------------------------------ + typedef G GroupInternal; + + ChildGroupMerged(size_t index, const std::vector> groups) + : m_Index(index), m_Groups(std::move(groups)) + {} + + ChildGroupMerged(const ChildGroupMerged&) = delete; + ChildGroupMerged(ChildGroupMerged&&) = default; + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + size_t getIndex() const { return m_Index; } + + //! Get 'archetype' neuron group - its properties represent those of all other merged neuron groups + const GroupInternal &getArchetype() const { return m_Groups.front().get(); } + + //! Gets access to underlying vector of neuron groups which have been merged + const std::vector> &getGroups() const{ return m_Groups; } + +protected: + //------------------------------------------------------------------------ + // Protected API + //------------------------------------------------------------------------ + //!
Helper to update hash with the hash of calling getHashableFn on each group + template + void updateHash(H getHashableFn, boost::uuids::detail::sha1 &hash) const + { + for(const auto &g : getGroups()) { + Utils::updateHash(getHashableFn(g.get()), hash); + } + } + + template + void updateParamHash(R isParamReferencedFn, V getValueFn, boost::uuids::detail::sha1 &hash) const + { + // Loop through parameters + const auto &archetypeParams = getValueFn(getArchetype()); + for(const auto &p : archetypeParams) { + // If any of the code strings reference the parameter + if((static_cast(this)->*isParamReferencedFn)(p.first)) { + // Loop through groups + for(const auto &g : getGroups()) { + // Update hash with parameter value + Utils::updateHash(getValueFn(g.get()).at(p.first), hash); + } + } + } + } + + template + void updateVarInitParamHash(R isParamReferencedFn, boost::uuids::detail::sha1 &hash) const + { + // Loop through variables + const auto &archetypeVarInitialisers = A(getArchetype()).getInitialisers(); + for(const auto &varInit : archetypeVarInitialisers) { + // Loop through parameters + for(const auto &p : varInit.second.getParams()) { + // If any of the code strings reference the parameter + if((static_cast(this)->*isParamReferencedFn)(varInit.first, p.first)) { + // Loop through groups + for(const auto &g : getGroups()) { + const auto &values = A(g.get()).getInitialisers().at(varInit.first).getParams(); + + // Update hash with parameter value + Utils::updateHash(values.at(p.first), hash); + } + } + } + } + } + + template + void updateVarInitDerivedParamHash(R isDerivedParamReferencedFn, boost::uuids::detail::sha1 &hash) const + { + // Loop through variables + const auto &archetypeVarInitialisers = A(getArchetype()).getInitialisers(); + for(const auto &varInit : archetypeVarInitialisers) { + // Loop through parameters + for(const auto &d : varInit.second.getDerivedParams()) { + // If any of the code strings reference the parameter + if((static_cast(this)->*isDerivedParamReferencedFn)(varInit.first, d.first)) { + // Loop through groups + for(const auto &g : getGroups()) { + const auto &values = A(g.get()).getInitialisers().at(varInit.first).getDerivedParams(); + + // Update hash with parameter value + Utils::updateHash(values.at(d.first), hash); + } + } + } + } + } + +private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + size_t m_Index; + std::vector> m_Groups; +}; + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::GroupMerged //---------------------------------------------------------------------------- //! 
Very thin wrapper around a number of groups which have been merged together template -class GroupMerged +class GroupMerged : public ChildGroupMerged { public: //------------------------------------------------------------------------ // Typedefines //------------------------------------------------------------------------ - typedef G GroupInternal; typedef std::function GetFieldValueFunc; typedef std::function GetFieldDoubleValueFunc; typedef std::tuple Field; GroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> groups) - : m_Index(index), m_TypeContext(typeContext), m_Groups(std::move(groups)) + : ChildGroupMerged(index, std::move(groups)), m_TypeContext(typeContext) {} GroupMerged(const GroupMerged&) = delete; @@ -85,20 +196,12 @@ class GroupMerged //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - size_t getIndex() const { return m_Index; } - - //! Get 'archetype' neuron group - it's properties represent those of all other merged neuron groups - const GroupInternal &getArchetype() const { return m_Groups.front().get(); } - //! Get type context used to resolve any types involved in this group const Type::TypeContext &getTypeContext() const{ return m_TypeContext; } //! Get name of memory space assigned to group const std::string &getMemorySpace() const { return m_MemorySpace; } - //! Gets access to underlying vector of neuron groups which have been merged - const std::vector> &getGroups() const{ return m_Groups; } - //! Get group fields const std::vector &getFields() const{ return m_Fields; } @@ -397,76 +500,6 @@ class GroupMerged } } - //! Helper to update hash with the hash of calling getHashableFn on each group - template - void updateHash(H getHashableFn, boost::uuids::detail::sha1 &hash) const - { - for(const auto &g : getGroups()) { - Utils::updateHash(getHashableFn(g.get()), hash); - } - } - - template - void updateParamHash(R isParamReferencedFn, V getValueFn, boost::uuids::detail::sha1 &hash) const - { - // Loop through parameters - const auto &archetypeParams = getValueFn(getArchetype()); - for(const auto &p : archetypeParams) { - // If any of the code strings reference the parameter - if((static_cast(this)->*isParamReferencedFn)(p.first)) { - // Loop through groups - for(const auto &g : getGroups()) { - // Update hash with parameter value - Utils::updateHash(getValueFn(g.get()).at(p.first), hash); - } - } - } - } - - template - void updateVarInitParamHash(R isParamReferencedFn, boost::uuids::detail::sha1 &hash) const - { - // Loop through variables - const auto &archetypeVarInitialisers = A(getArchetype()).getInitialisers(); - for(const auto &varInit : archetypeVarInitialisers) { - // Loop through parameters - for(const auto &p : varInit.second.getParams()) { - // If any of the code strings reference the parameter - if((static_cast(this)->*isParamReferencedFn)(varInit.first, p.first)) { - // Loop through groups - for(const auto &g : getGroups()) { - const auto &values = A(g.get()).getInitialisers().at(varInit.first).getParams(); - - // Update hash with parameter value - Utils::updateHash(values.at(p.first), hash); - } - } - } - } - } - - template - void updateVarInitDerivedParamHash(R isDerivedParamReferencedFn, boost::uuids::detail::sha1 &hash) const - { - // Loop through variables - const auto &archetypeVarInitialisers = A(getArchetype()).getInitialisers(); - for(const auto &varInit : archetypeVarInitialisers) { - // Loop through 
parameters - for(const auto &d : varInit.second.getDerivedParams()) { - // If any of the code strings reference the parameter - if((static_cast(this)->*isDerivedParamReferencedFn)(varInit.first, d.first)) { - // Loop through groups - for(const auto &g : getGroups()) { - const auto &values = A(g.get()).getInitialisers().at(varInit.first).getDerivedParams(); - - // Update hash with parameter value - Utils::updateHash(values.at(d.first), hash); - } - } - } - } - } - void generateRunnerBase(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, @@ -549,11 +582,9 @@ class GroupMerged //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - size_t m_Index; const Type::TypeContext &m_TypeContext; std::string m_MemorySpace; std::vector m_Fields; - std::vector> m_Groups; }; //---------------------------------------------------------------------------- @@ -641,8 +672,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged - void orderNeuronGroupChildren(std::vector &childGroups, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, G getVectorFunc, H getHashDigestFunc) const + void orderNeuronGroupChildren(std::vector &childGroups, G getVectorFunc, H getHashDigestFunc) const { const std::vector &archetypeChildren = (getArchetype().*getVectorFunc)(); @@ -683,7 +713,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged Date: Fri, 16 Jun 2023 11:38:40 +0100 Subject: [PATCH 226/725] restored over-zealously deleted SynapseGroup methods --- include/genn/genn/synapseGroup.h | 6 +++++ src/genn/genn/synapseGroup.cc | 45 ++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/include/genn/genn/synapseGroup.h b/include/genn/genn/synapseGroup.h index b6c4e3d21e..58dd7a39f4 100644 --- a/include/genn/genn/synapseGroup.h +++ b/include/genn/genn/synapseGroup.h @@ -201,6 +201,12 @@ class GENN_EXPORT SynapseGroup /*! This is only used by extra global parameters which are pointers*/ VarLocation getSparseConnectivityExtraGlobalParamLocation(const std::string ¶mName) const; + //! Does this synapse group require dendritic delay? + bool isDendriticDelayRequired() const; + + //! Does this synapse group define presynaptic output? + bool isPresynapticOutputRequired() const; + //! Does this synapse group require an RNG to generate procedural connectivity? 
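// [Editor's sketch, not part of the patch] The isDendriticDelayRequired() and
// isPresynapticOutputRequired() implementations restored by this commit each
// repeat the same std::string::find() test over several weight update model
// code strings. The shared scan pattern, shown as a self-contained helper
// (the helper name and signature are assumptions for illustration, mirroring
// the isParamReferenced() helper added to ChildGroupMerged in this series):
#include <algorithm>
#include <string>
#include <vector>

static bool isTokenReferenced(const std::vector<std::string> &codeStrings, const std::string &token)
{
    // True if any of the code strings contains the token, e.g. "$(addToPre"
    return std::any_of(codeStrings.cbegin(), codeStrings.cend(),
                       [&token](const std::string &c){ return c.find(token) != std::string::npos; });
}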
bool isProceduralConnectivityRNGRequired() const; diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc index 93815f9f03..960c1a6ea9 100644 --- a/src/genn/genn/synapseGroup.cc +++ b/src/genn/genn/synapseGroup.cc @@ -305,6 +305,51 @@ VarLocation SynapseGroup::getSparseConnectivityExtraGlobalParamLocation(const st return m_ConnectivityExtraGlobalParamLocation[m_SparseConnectivityInitialiser.getSnippet()->getExtraGlobalParamIndex(paramName)]; } //---------------------------------------------------------------------------- +bool SynapseGroup::isDendriticDelayRequired() const +{ + // If addToInSynDelay function is used in sim code, return true + if(getWUModel()->getSimCode().find("$(addToInSynDelay") != std::string::npos) { + return true; + } + + // If addToInSynDelay function is used in event code, return true + if(getWUModel()->getEventCode().find("$(addToInSynDelay") != std::string::npos) { + return true; + } + + // If addToInSynDelay function is used in synapse dynamics, return true + if(getWUModel()->getSynapseDynamicsCode().find("$(addToInSynDelay") != std::string::npos) { + return true; + } + + return false; +} +//---------------------------------------------------------------------------- +bool SynapseGroup::isPresynapticOutputRequired() const +{ + // If addToPre function is used in sim_code, return true + if(getWUModel()->getSimCode().find("$(addToPre") != std::string::npos) { + return true; + } + + // If addToPre function is used in learn_post_code, return true + if(getWUModel()->getLearnPostCode().find("$(addToPre") != std::string::npos) { + return true; + } + + // If addToPre function is used in event_code, return true + if(getWUModel()->getEventCode().find("$(addToPre") != std::string::npos) { + return true; + } + + // If addToPre function is used in synapse_dynamics, return true + if(getWUModel()->getSynapseDynamicsCode().find("$(addToPre") != std::string::npos) { + return true; + } + + return false; +} +//---------------------------------------------------------------------------- bool SynapseGroup::isProceduralConnectivityRNGRequired() const { if(m_MatrixType & SynapseMatrixConnectivity::PROCEDURAL) { From 3db5a5572fea3c0b32cb1951ad7af9ced83a76e0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 16 Jun 2023 12:39:40 +0100 Subject: [PATCH 227/725] EnvironmentLocalVarCache broken but, otherwise, NeuronUpdateGroupMerged almost compiling --- .../genn/genn/code_generator/environment.h | 72 +-- .../genn/genn/code_generator/groupMerged.h | 90 ++-- .../code_generator/neuronUpdateGroupMerged.h | 45 +- include/genn/genn/synapseGroupInternal.h | 23 + .../code_generator/neuronUpdateGroupMerged.cc | 491 +++++++----------- 5 files changed, 299 insertions(+), 422 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index b55979d066..c21c674189 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -359,31 +359,31 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &initialisers = {}, const std::vector &dependents = {}) + void addField(const GeNN::Type::ResolvedType &type, const std::string &name, + const GeNN::Type::ResolvedType &fieldType, const std::string &fieldName, typename G::GetFieldValueFunc getFieldValue, + const std::string &indexSuffix = "", GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD, + const std::vector &initialisers = {}, const std::vector &dependents = {}) { 
addInternal(type, name, std::make_tuple(false, indexSuffix, std::make_optional(std::make_tuple(fieldType, fieldName, getFieldValue, mergedFieldType))), initialisers, dependents); } //! Map a type (for type-checking) and a group merged field to back it to an identifier - void add(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &fieldName, - typename G::GetFieldValueFunc getFieldValue, const std::string &indexSuffix = "", GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD, - const std::vector &initialisers = {}, const std::vector &dependents = {}) + void addField(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &fieldName, + typename G::GetFieldValueFunc getFieldValue, const std::string &indexSuffix = "", GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD, + const std::vector &initialisers = {}, const std::vector &dependents = {}) { - add(type, name, type, fieldName, getFieldValue, indexSuffix, mergedFieldType, initialisers, dependents); + addField(type, name, type, fieldName, getFieldValue, indexSuffix, mergedFieldType, initialisers, dependents); } void addScalar(const std::string &name, const std::string &fieldSuffix, typename G::GetFieldDoubleValueFunc getFieldValue) { - add(getGroup().getScalarType().addConst(), name, - getGroup().getScalarType(), name + fieldSuffix, - [getFieldValue, this](const auto &g, size_t i) - { - return getScalarString(getFieldValue(g, i)); - }); + addField(getGroup().getScalarType().addConst(), name, + getGroup().getScalarType(), name + fieldSuffix, + [getFieldValue, this](const auto &g, size_t i) + { + return getScalarString(getFieldValue(g, i)); + }); } void addParams(const Snippet::Base::StringVec ¶mNames, const std::string &fieldSuffix, @@ -485,13 +485,13 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase GetFieldValueFunc; + typedef std::function GetFieldDoubleValueFunc; + typedef std::tuple Field; ChildGroupMerged(size_t index, const std::vector> groups) : m_Index(index), m_Groups(std::move(groups)) @@ -93,6 +96,46 @@ class ChildGroupMerged //------------------------------------------------------------------------ // Protected API //------------------------------------------------------------------------ + //! Helper to test whether parameter is referenced in vector of codestrings + bool isParamReferenced(const std::vector &codeStrings, const std::string ¶mName) const + { + return std::any_of(codeStrings.begin(), codeStrings.end(), + [¶mName](const std::string &c) + { + return (c.find("$(" + paramName + ")") != std::string::npos); + }); + } + + //! Helper to test whether parameter values are heterogeneous within merged group + template + bool isParamValueHeterogeneous(const std::string &name, P getParamValuesFn) const + { + // Get value of parameter in archetype group + const double archetypeValue = getParamValuesFn(getArchetype()).at(name); + + // Return true if any parameter values differ from the archetype value + return std::any_of(getGroups().cbegin(), getGroups().cend(), + [&name, archetypeValue, getParamValuesFn](const GroupInternal &g) + { + return (getParamValuesFn(g).at(name) != archetypeValue); + }); + } + + //! 
Helper to test whether parameter values are heterogeneous within merged group + template + bool isParamValueHeterogeneous(size_t index, P getParamValuesFn) const + { + // Get value of parameter in archetype group + const double archetypeValue = getParamValuesFn(getArchetype()).at(index); + + // Return true if any parameter values differ from the archetype value + return std::any_of(getGroups().cbegin(), getGroups().cend(), + [archetypeValue, index, getParamValuesFn](const GroupInternal &g) + { + return (getParamValuesFn(g).at(index) != archetypeValue); + }); + } + //! Helper to update hash with the hash of calling getHashableFn on each group template void updateHash(H getHashableFn, boost::uuids::detail::sha1 &hash) const @@ -179,13 +222,6 @@ template class GroupMerged : public ChildGroupMerged { public: - //------------------------------------------------------------------------ - // Typedefines - //------------------------------------------------------------------------ - typedef std::function GetFieldValueFunc; - typedef std::function GetFieldDoubleValueFunc; - typedef std::tuple Field; - GroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> groups) : ChildGroupMerged(index, std::move(groups)), m_TypeContext(typeContext) {} @@ -321,46 +357,6 @@ class GroupMerged : public ChildGroupMerged const Type::ResolvedType &getScalarType() const{ return m_TypeContext.at("scalar"); } const Type::ResolvedType &getTimeType() const{ return m_TypeContext.at("timepoint"); } - //! Helper to test whether parameter is referenced in vector of codestrings - bool isParamReferenced(const std::vector &codeStrings, const std::string ¶mName) const - { - return std::any_of(codeStrings.begin(), codeStrings.end(), - [¶mName](const std::string &c) - { - return (c.find("$(" + paramName + ")") != std::string::npos); - }); - } - - //! Helper to test whether parameter values are heterogeneous within merged group - template - bool isParamValueHeterogeneous(const std::string &name, P getParamValuesFn) const - { - // Get value of parameter in archetype group - const double archetypeValue = getParamValuesFn(getArchetype()).at(name); - - // Return true if any parameter values differ from the archetype value - return std::any_of(getGroups().cbegin(), getGroups().cend(), - [&name, archetypeValue, getParamValuesFn](const GroupInternal &g) - { - return (getParamValuesFn(g).at(name) != archetypeValue); - }); - } - - //! 
Helper to test whether parameter values are heterogeneous within merged group - template - bool isParamValueHeterogeneous(size_t index, P getParamValuesFn) const - { - // Get value of parameter in archetype group - const double archetypeValue = getParamValuesFn(getArchetype()).at(index); - - // Return true if any parameter values differ from the archetype value - return std::any_of(getGroups().cbegin(), getGroups().cend(), - [archetypeValue, index, getParamValuesFn](const GroupInternal &g) - { - return (getParamValuesFn(g).at(index) != archetypeValue); - }); - } - void addField(const Type::ResolvedType &type, const std::string &name, GetFieldValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) { // Add field to data structure diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index 5df32f81e3..d282302109 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -15,17 +15,16 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase // GeNN::CodeGenerator::NeuronUpdateGroupMerged::CurrentSource //---------------------------------------------------------------------------- //! Child group merged for current sources attached to this neuron update group - class CurrentSource : public GroupMerged + class CurrentSource : public ChildGroupMerged { public: - CurrentSource(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups); + using ChildGroupMerged::ChildGroupMerged; //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- void generate(const BackendBase &backend, EnvironmentExternalBase &env, - const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) const; + const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged); //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -48,17 +47,16 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase // GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynPSM //---------------------------------------------------------------------------- //! Child group merged for incoming synapse groups - class InSynPSM : public GroupMerged + class InSynPSM : public ChildGroupMerged { public: - InSynPSM(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups); + using ChildGroupMerged::ChildGroupMerged; //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- void generate(const BackendBase &backend, EnvironmentExternalBase &env, - const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) const; + const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged); //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -81,37 +79,35 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase // GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynPreOutput //---------------------------------------------------------------------------- //! 
Child group merged for outgoing synapse groups with $(addToPre) logic - class OutSynPreOutput : public GroupMerged + class OutSynPreOutput : public ChildGroupMerged { public: - OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups); + using ChildGroupMerged::ChildGroupMerged; //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged) const; + void generate(const BackendBase &backend, EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged); }; //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynWUMPostCode //---------------------------------------------------------------------------- //! Child group merged for incoming synapse groups with postsynaptic update/spike code - class InSynWUMPostCode : public GroupMerged + class InSynWUMPostCode : public ChildGroupMerged { public: - InSynWUMPostCode(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups); + using ChildGroupMerged::ChildGroupMerged; //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- void generate(const BackendBase &backend, EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) const; + const ModelSpecMerged &modelMerged, bool dynamicsNotSpike); void genCopyDelayedVars(EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged) const; + const ModelSpecMerged &modelMerged); //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -134,20 +130,19 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase // GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynWUMPreCode //---------------------------------------------------------------------------- //! Child group merged for outgoing synapse groups with presynaptic update/spike code - class OutSynWUMPreCode : public GroupMerged + class OutSynWUMPreCode : public ChildGroupMerged { public: - OutSynWUMPreCode(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups); + using ChildGroupMerged::ChildGroupMerged; //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- void generate(const BackendBase &backend, EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) const; + const ModelSpecMerged &modelMerged, bool dynamicsNotSpike); void genCopyDelayedVars(EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged) const; + const ModelSpecMerged &modelMerged); //! 
Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -186,9 +181,9 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase void generateNeuronUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, BackendBase::GroupHandlerEnv genEmitTrueSpike, - BackendBase::GroupHandlerEnv genEmitSpikeLikeEvent) const; + BackendBase::GroupHandlerEnv genEmitSpikeLikeEvent); - void generateWUVarUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const; + void generateWUVarUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged); std::string getVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; std::string getReadVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; diff --git a/include/genn/genn/synapseGroupInternal.h b/include/genn/genn/synapseGroupInternal.h index 9814da9d92..4d0e2b410b 100644 --- a/include/genn/genn/synapseGroupInternal.h +++ b/include/genn/genn/synapseGroupInternal.h @@ -106,6 +106,29 @@ class SynapsePSMVarAdapter const SynapseGroupInternal &m_SG; }; +//---------------------------------------------------------------------------- +// SynapsePSMEGPAdapter +//---------------------------------------------------------------------------- +class SynapsePSMEGPAdapter +{ +public: + SynapsePSMEGPAdapter(const SynapseGroupInternal &sg) : m_SG(sg) + {} + + //---------------------------------------------------------------------------- + // Public methods + //---------------------------------------------------------------------------- + VarLocation getLoc(const std::string &varName) const{ return m_SG.getPSExtraGlobalParamLocation(varName); } + + Snippet::Base::EGPVec getDefs() const{ return m_SG.getPSModel()->getExtraGlobalParams(); } + +private: + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + const SynapseGroupInternal &m_SG; +}; + //---------------------------------------------------------------------------- // SynapseWUVarAdapter //---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 1882fb3e22..14122c36c2 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -19,66 +19,37 @@ using namespace GeNN::Transpiler; //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::CurrentSource //---------------------------------------------------------------------------- -// **TODO** -// * field suffix (string) and value suffix (function to get suffix from group) common to everything in group - GroupMerged fields? 
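// [Editor's sketch, not part of the patch] The "pointers to members" idea in
// the TODO notes being deleted here (next line) is what the rewritten
// child-group generate() methods below rely on: a getter such as
// &CurrentSourceInternal::getParams is passed as a pointer-to-member-function
// and dispatched per group via std::invoke. A minimal self-contained
// illustration; the map type and all names are assumptions for the example:
#include <functional>
#include <string>
#include <unordered_map>

struct ExampleGroup
{
    const std::unordered_map<std::string, double> &getParams() const{ return params; }
    std::unordered_map<std::string, double> params;
};

template<typename G, typename V>
double getGroupParamValue(const G &group, V getValueFn, const std::string &name)
{
    // std::invoke handles pointer-to-member-functions and plain callables alike
    return std::invoke(getValueFn, group).at(name);
}
// Usage: getGroupParamValue(group, &ExampleGroup::getParams, "tau");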
-// * without nasty combined groups, getParams and getDerivedParams functions can use pointers to members -// * pre and post neuron stuff in synapse update group merged can also be child classes -NeuronUpdateGroupMerged::CurrentSource::CurrentSource(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, typeContext, groups) +void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend, EnvironmentExternalBase &env, + const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) { - /*const std::string suffix = "CS" + std::to_string(getIndex()); - - // Create type environment - GroupMergedTypeEnvironment typeEnvironment(*this, &enclosingEnv); - - // Add heterogeneous parameters + const std::string fieldSuffix = "CS" + std::to_string(getIndex()); const auto *cm = getArchetype().getCurrentSourceModel(); - typeEnvironment.defineHeterogeneousParams(cm->getParamNames(), suffix, - &CurrentSourceInternal::getParams, - &CurrentSource::isParamHeterogeneous); - - // Add heterogeneous derived parameters - typeEnvironment.defineHeterogeneousDerivedParams(cm->getDerivedParams(), suffix, - &CurrentSourceInternal::getDerivedParams, - &CurrentSource::isDerivedParamHeterogeneous); - - // Add variables - typeEnvironment.defineVars(cm->getVars(), backend.getDeviceVarPrefix(), suffix); - // Add EGPs - typeEnvironment.defineEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", suffix);*/ + // Create new environment to add current source fields to neuron update group + EnvironmentGroupMergedField csEnv(env, *this, ng); + + csEnv.getStream() << "// current source " << getIndex() << std::endl; -} -//---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend, EnvironmentExternalBase &env, - const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) const -{ - const std::string suffix = "CS" + std::to_string(getIndex()); - const auto *cm = getArchetype().getCurrentSourceModel(); + // Substitute parameter and derived parameter names + csEnv.addParams(cm->getParamNames(), fieldSuffix, &CurrentSourceInternal::getParams, &CurrentSource::isParamHeterogeneous); + csEnv.addDerivedParams(cm->getDerivedParams(), fieldSuffix, &CurrentSourceInternal::getDerivedParams, &CurrentSource::isDerivedParamHeterogeneous); + csEnv.addEGPs(backend.getDeviceVarPrefix(), "", fieldSuffix); - EnvironmentGroupMergedField synEnv(env, *this, ng, &NeuronUpdateGroupMerged::getMergedCurrentSourceGroups); - // Create new substitution environment and add parameters, derived parameters and extra global parameters - EnvironmentSubstitute envSubs(env); - envSubs.getStream() << "// current source " << getIndex() << std::endl; - envSubs.addParamValueSubstitution(cm->getParamNames(), getArchetype().getParams(), suffix, - [this](const std::string &p) { return isParamHeterogeneous(p); }); - envSubs.addVarValueSubstitution(cm->getDerivedParams(), getArchetype().getDerivedParams(), suffix, - [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); - envSubs.addVarNameSubstitution(cm->getExtraGlobalParams(), suffix); + // Define inject current function + csEnv.add(Type::ResolvedType::createFunction(Type::Void, {modelMerged.getModel().getPrecision()}), "injectCurrent", csEnv["Isyn"] + " += $(0)", + {}, {"Isyn"}); // Create an environment which caches variables in 
local variables if they are accessed - EnvironmentLocalVarCache varSubs( - getArchetype(), getTypeContext(), envSubs, "l", suffix, + EnvironmentLocalVarCache varEnv( + getArchetype(), modelMerged.getTypeContext(), envSubs, "l", fieldSuffix, [&envSubs, &modelMerged, &ng](const std::string&, const Models::VarInit&, VarAccess a) { return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), envSubs["id"]); }); - //currSourceSubs.addFuncSubstitution("injectCurrent", 1, "Isyn += $(0)"); - - // Pretty print previously parsed update statements - PrettyPrinter::print(m_InjectionStatements, varSubs, getTypeContext(), m_InjectionResolvedTypes); + // Pretty print code back to environment + Transpiler::ErrorHandler errorHandler("Current source injection" + std::to_string(getIndex())); + prettyPrintStatements(cm->getInjectionCode(), modelMerged.getTypeContext(), varEnv, errorHandler); } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::CurrentSource::updateHash(boost::uuids::detail::sha1 &hash) const @@ -102,113 +73,77 @@ bool NeuronUpdateGroupMerged::CurrentSource::isDerivedParamHeterogeneous( const //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynPSM //---------------------------------------------------------------------------- -NeuronUpdateGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, typeContext, groups) -{ - const std::string suffix = "InSyn" + std::to_string(getIndex()); - - // Create type environment - /*GroupMergedTypeEnvironment typeEnvironment(*this, &enclosingEnv); - - // Add pointer to insyn - addField(getScalarType().createPointer(), "inSyn" + suffix, - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "inSyn" + g.getFusedPSVarSuffix(); }); - - // Add pointer to dendritic delay buffer if required - if(getArchetype().isDendriticDelayRequired()) { - addField(getScalarType().createPointer(), "denDelay" + suffix, - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); - - addField(Type::Uint32.createPointer(), "denDelayPtr" + suffix, - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); - } - - // Add heterogeneous parameters - const auto *psm = getArchetype().getPSModel(); - typeEnvironment.defineHeterogeneousParams(psm->getParamNames(), suffix, - &SynapseGroupInternal::getPSParams, - &InSynPSM::isParamHeterogeneous); - - // Add heterogeneous derived parameters - typeEnvironment.defineHeterogeneousDerivedParams(psm->getDerivedParams(), suffix, - &SynapseGroupInternal::getPSDerivedParams, - &InSynPSM::isDerivedParamHeterogeneous); - - // Add variables - typeEnvironment.defineVars(psm->getVars(), backend.getDeviceVarPrefix(), - suffix, &SynapseGroupInternal::getFusedPSVarSuffix); - - // Add EGPs - typeEnvironment.defineEGPs(psm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", - suffix, &SynapseGroupInternal::getFusedPSVarSuffix); - - // Scan, parse and type-check decay and apply input code - ErrorHandler errorHandler; - std::tie(m_DecayStatements, m_DecayResolvedTypes) = scanParseAndTypeCheckStatements(psm->getDecayCode(), typeContext, - typeEnvironment, errorHandler); - 
std::tie(m_ApplyInputStatements, m_ApplyInputResolvedTypes) = scanParseAndTypeCheckStatements(psm->getApplyInputCode(), typeContext, - typeEnvironment, errorHandler);*/ -} -//---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, EnvironmentExternalBase &env, - const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) const + const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) { - const std::string suffix = "InSyn" + std::to_string(getIndex()); + const std::string fieldSuffix = "InSyn" + std::to_string(getIndex()); const auto *psm = getArchetype().getPSModel(); - // Create new substitution environment - EnvironmentSubstitute envSubs(env); + // Create new environment to add PSM fields to neuron update group + EnvironmentGroupMergedField psmEnv(env, *this, ng); + + // Add inSyn + const auto scalarType = modelMerged.getModel().getPrecision(); + psmEnv.addField(scalarType.createPointer(), "_out_post", "outPost" + fieldSuffix, + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); - envSubs.getStream() << "// current source " << getIndex() << std::endl; - envSubs.getStream() << "scalar linSyn = group->inSynInSyn" << getIndex() << "["; - envSubs.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, envSubs["id"]); - envSubs.getStream() << "];" << std::endl; + // Read into local variable + psmEnv.getStream() << "// postsynaptic model " << getIndex() << std::endl; + psmEnv.getStream() << "scalar linSyn = " << psmEnv["_out_post"] << "["; + psmEnv.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, psmEnv["id"]); + psmEnv.getStream() << "];" << std::endl; // If dendritic delay is required if (getArchetype().isDendriticDelayRequired()) { + // Add dendritic delay buffer and pointer into it + psmEnv.addField(scalarType.createPointer(), "_den_delay", "denDelay" + fieldSuffix, + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix();}); + psmEnv.addField(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr" + fieldSuffix, + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix();}); + // Get reference to dendritic delay buffer input for this timestep - envSubs.getStream() << backend.getPointerPrefix() << "scalar *denDelayFront = "; - envSubs.getStream() << "&group->denDelay" << suffix << "[(*group->denDelayPtr" << suffix << " * group->numNeurons) + "; - envSubs.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, envSubs["id"]); - envSubs.getStream() << "];" << std::endl; + psmEnv.getStream() << backend.getPointerPrefix() << "scalar *denDelayFront = "; + psmEnv.getStream() << "&" << psmEnv["_den_delay"] << "[(*" << psmEnv["_den_delay_ptr"] << " * " << psmEnv["num_neurons"] << ") + "; + psmEnv.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, psmEnv["id"]); + psmEnv.getStream() << "];" << std::endl; // Add delayed input from buffer into inSyn - envSubs.getStream() << "linSyn += *denDelayFront;" << std::endl; + psmEnv.getStream() << "linSyn += *denDelayFront;" << std::endl; // Zero delay buffer slot - envSubs.getStream() << "*denDelayFront = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; + 
psmEnv.getStream() << "*denDelayFront = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; } // Add parameters, derived parameters and extra global parameters to environment - envSubs.addParamValueSubstitution(psm->getParamNames(), getArchetype().getPSParams(), suffix, - [this](const std::string &p) { return isParamHeterogeneous(p); }); - envSubs.addVarValueSubstitution(psm->getDerivedParams(), getArchetype().getPSDerivedParams(), suffix, - [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); - envSubs.addVarNameSubstitution(psm->getExtraGlobalParams(), suffix); + psmEnv.addParams(psm->getParamNames(), fieldSuffix, &SynapseGroupInternal::getPSParams, &InSynPSM::isParamHeterogeneous); + psmEnv.addDerivedParams(psm->getDerivedParams(), fieldSuffix, &SynapseGroupInternal::getPSDerivedParams, &InSynPSM::isDerivedParamHeterogeneous); + psmEnv.addEGPs(backend.getDeviceVarPrefix(), "", fieldSuffix); // **TODO** naming convention - envSubs.addSubstitution("inSyn", "linSyn"); + psmEnv.add(modelMerged.getModel().getPrecision().addConst(), "inSyn", "linSyn"); // Allow synapse group's PS output var to override what Isyn points to - envSubs.addSubstitution("Isyn", getArchetype().getPSTargetVar()); + psmEnv.add(modelMerged.getModel().getPrecision().addConst(), "Isyn", getArchetype().getPSTargetVar()); // Create an environment which caches variables in local variables if they are accessed - EnvironmentLocalVarCache varSubs( - getArchetype(), getTypeContext(), envSubs, "l", suffix, - [&envSubs, &modelMerged, &ng](const std::string&, const Models::VarInit&, VarAccess a) + EnvironmentLocalVarCache varEnv( + getArchetype(), modelMerged.getTypeContext(), psmEnv, "l", fieldSuffix, + [&psmEnv, &modelMerged, &ng](const std::string&, const Models::VarInit&, VarAccess a) { - return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), envSubs["id"]); + return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), psmEnv["id"]); }); - // Pretty print previously parsed update statements - PrettyPrinter::print(m_ApplyInputStatements, varSubs, getTypeContext(), m_ApplyInputResolvedTypes); - PrettyPrinter::print(m_DecayStatements, varSubs, getTypeContext(), m_DecayResolvedTypes); + // Pretty print code back to environment + Transpiler::ErrorHandler applyInputErrorHandler("Postsynaptic model apply input" + std::to_string(getIndex())); + prettyPrintStatements(psm->getApplyInputCode(), modelMerged.getTypeContext(), varEnv, applyInputErrorHandler); + + Transpiler::ErrorHandler decayErrorHandler("Postsynaptic model decay" + std::to_string(getIndex())); + prettyPrintStatements(psm->getDecayCode(), modelMerged.getTypeContext(), varEnv, decayErrorHandler); // Write back linSyn - varSubs.getStream() << "group->inSyn" << suffix << "["; - varSubs.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, envSubs["id"]); - varSubs.getStream() << "] = linSyn;" << std::endl; + varEnv.getStream() << psmEnv["_out_post"] << "["; + varEnv.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, psmEnv["id"]); + varEnv.getStream() << "] = linSyn;" << std::endl; } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &hash) const @@ -232,104 +167,67 @@ bool NeuronUpdateGroupMerged::InSynPSM::isDerivedParamHeterogeneous( const std:: 
//---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynPreOutput //---------------------------------------------------------------------------- -NeuronUpdateGroupMerged::OutSynPreOutput::OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase&, - const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, typeContext, groups) +void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged) { - const std::string suffix = "OutSyn" + std::to_string(getIndex()); - - addField(getScalarType().createPointer(), "revInSyn" + suffix, - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "revInSyn" + g.getFusedPreOutputSuffix(); }); -} -//---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::OutSynPreOutput::generate(EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged) const -{ - const std::string suffix = "OutSyn" + std::to_string(getIndex()); - - env.getStream() << getArchetype().getPreTargetVar() << " += "; - env.getStream() << "group->revInSyn" << suffix << "["; - env.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, env["id"]); - env.getStream() << "];" << std::endl; - env.getStream() << "group->revInSyn" << suffix << "["; - env.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, env["id"]); - env.getStream() << "] = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; + const std::string fieldSuffix = "OutSyn" + std::to_string(getIndex()); + + // Create new environment to add out syn fields to neuron update group + EnvironmentGroupMergedField outSynEnv(env, *this, ng); + + outSynEnv.addField(modelMerged.getModel().getPrecision().createPointer(), "_out_pre", "outPre" + fieldSuffix, + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); + + // Add reverse insyn variable to + outSynEnv.getStream() << getArchetype().getPreTargetVar() << " += "; + outSynEnv.getStream() << outSynEnv["_out_pre"] << "["; + outSynEnv.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, env["id"]); + outSynEnv.getStream() << "];" << std::endl; + + // Zero it again + outSynEnv.getStream() << outSynEnv["_out_pre"] << "["; + outSynEnv.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, env["id"]); + outSynEnv.getStream() << "] = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; } //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynWUMPostCode //---------------------------------------------------------------------------- -NeuronUpdateGroupMerged::InSynWUMPostCode::InSynWUMPostCode(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, typeContext, groups) -{ - const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); - - // Create type environment - GroupMergedTypeEnvironment typeEnvironment(*this, &enclosingEnv); - - // Add heterogeneous 
parameters - const auto *wum = getArchetype().getWUModel(); - typeEnvironment.defineHeterogeneousParams(wum->getParamNames(), suffix, - &SynapseGroupInternal::getWUParams, - &InSynWUMPostCode::isParamHeterogeneous); - - // Add heterogeneous derived parameters - typeEnvironment.defineHeterogeneousDerivedParams(wum->getDerivedParams(), suffix, - &SynapseGroupInternal::getWUDerivedParams, - &InSynWUMPostCode::isDerivedParamHeterogeneous); - - // Add variables - typeEnvironment.defineVars(wum->getPostVars(), backend.getDeviceVarPrefix(), - suffix, &SynapseGroupInternal::getFusedWUPostVarSuffix); - - // Add EGPs - typeEnvironment.defineEGPs(wum->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", - suffix, &SynapseGroupInternal::getFusedWUPostVarSuffix); - - // Scan, parse and type-check dynamics and spike code - ErrorHandler errorHandler; - std::tie(m_DynamicsStatements, m_DynamicsResolvedTypes) = scanParseAndTypeCheckStatements(wum->getPostDynamicsCode(), typeContext, - typeEnvironment, errorHandler); - std::tie(m_SpikeStatements, m_SpikeResolvedTypes) = scanParseAndTypeCheckStatements(wum->getPostSpikeCode(), typeContext, - typeEnvironment, errorHandler); -} -//---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &backend, EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) const +void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &backend, EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) { - const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); + const std::string fieldSuffix = "InSynWUMPost" + std::to_string(getIndex()); const auto *wum = getArchetype().getWUModel(); const unsigned int batchSize = modelMerged.getModel().getBatchSize(); // If there are any statements to execute here - const auto &statements = dynamicsNotSpike ? m_DynamicsStatements : m_SpikeStatements; - const auto &resolvedTypes = dynamicsNotSpike ? m_DynamicsResolvedTypes : m_SpikeResolvedTypes; - if(!statements.empty()) { - // Create new substitution environment and add parameters, derived parameters and extra global parameters - EnvironmentSubstitute envSubs(env); - envSubs.getStream() << "// postsynaptic weight update " << getIndex() << std::endl; - envSubs.addParamValueSubstitution(wum->getParamNames(), getArchetype().getWUParams(), suffix, - [this](const std::string &p) { return isParamHeterogeneous(p); }); - envSubs.addVarValueSubstitution(wum->getDerivedParams(), getArchetype().getWUDerivedParams(), suffix, - [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); - envSubs.addVarNameSubstitution(wum->getExtraGlobalParams(), suffix); + const std::string code = dynamicsNotSpike ? 
wum->getPostDynamicsCode() : wum->getPostSpikeCode(); + if(!code.empty()) { + // Create new environment to add in syn fields to neuron update group + EnvironmentGroupMergedField synEnv(env, *this, ng); + + synEnv.getStream() << "// postsynaptic weight update " << getIndex() << std::endl; + + // Add parameters, derived parameters and extra global parameters to environment + synEnv.addParams(wum->getParamNames(), fieldSuffix, &SynapseGroupInternal::getWUParams, &InSynWUMPostCode::isParamHeterogeneous); + synEnv.addDerivedParams(wum->getDerivedParams(), fieldSuffix, &SynapseGroupInternal::getWUDerivedParams, &InSynWUMPostCode::isDerivedParamHeterogeneous); + synEnv.addEGPs(backend.getDeviceVarPrefix(), "", fieldSuffix); // Create an environment which caches variables in local variables if they are accessed const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); - EnvironmentLocalVarCache varSubs( - getArchetype(), getTypeContext(), envSubs, "l", suffix, - [batchSize, delayed, &envSubs, &ng](const std::string&, const Models::VarInit&, VarAccess a) + EnvironmentLocalVarCache varEnv( + getArchetype(), modelMerged.getTypeContext(), synEnv, "l", fieldSuffix, + [batchSize, delayed, &synEnv, &ng](const std::string&, const Models::VarInit&, VarAccess a) { - return ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), envSubs["id"]); + return ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), synEnv["id"]); }, - [batchSize, delayed, &envSubs, &ng](const std::string&, const Models::VarInit&, VarAccess a) + [batchSize, delayed, &synEnv, &ng](const std::string&, const Models::VarInit&, VarAccess a) { - return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), envSubs["id"]); + return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), synEnv["id"]); }); - /*neuronSubstitutionsInSynapticCode(subs, &ng.getArchetype(), "", "_post", "", "", "", dynamicsNotSpike, + /*neuronSubstitutionsInSynapticCode(varEnv, &ng.getArchetype(), "", "_post", "", "", "", dynamicsNotSpike, [&ng](const std::string &p) { return ng.isParamHeterogeneous(p); }, [&ng](const std::string &p) { return ng.isDerivedParamHeterogeneous(p); }, [&subs, &ng, batchSize](bool delay, VarAccessDuplication varDuplication) @@ -341,13 +239,13 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back return ng.getReadVarIndex(delay, batchSize, varDuplication, subs["id"]); });*/ - // Pretty print previously parsed statements - PrettyPrinter::print(statements, varSubs, getTypeContext(), resolvedTypes); + Transpiler::ErrorHandler errorHandler("Postsynaptic weight update model " + std::to_string(getIndex())); + prettyPrintStatements(code, modelMerged.getTypeContext(), varEnv, errorHandler); } } //---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged) const +void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged) { // If this group has a delay and no postsynaptic dynamics (which will already perform this copying) const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); @@ -355,11 +253,11 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentEx // Loop through variables and copy between read and write delay slots
for(const auto &v : getArchetype().getWUModel()->getPostVars()) { if(v.access & VarAccessMode::READ_WRITE) { - env.getStream() << "group->" << v.name << suffix << "["; + env.getStream() << env[v.name] << "["; env.getStream() << ng.getWriteVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), env["id"]); env.getStream() << "] = "; - env.getStream() << "group->" << v.name << suffix << "["; + env.getStream() << env[v.name] << "["; env.getStream() << ng.getReadVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), env["id"]); env.getStream() << "];" << std::endl; } @@ -388,68 +286,30 @@ bool NeuronUpdateGroupMerged::InSynWUMPostCode::isDerivedParamHeterogeneous( con //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynWUMPreCode //---------------------------------------------------------------------------- -NeuronUpdateGroupMerged::OutSynWUMPreCode::OutSynWUMPreCode(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, typeContext, groups) +void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &backend, EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) { - const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); - - // Create type environment - GroupMergedTypeEnvironment typeEnvironment(*this, &enclosingEnv); - - // Add heterogeneous parameters - const auto *wum = getArchetype().getWUModel(); - typeEnvironment.defineHeterogeneousParams(wum->getParamNames(), suffix, - &SynapseGroupInternal::getWUParams, - &OutSynWUMPreCode::isParamHeterogeneous); - - // Add heterogeneous derived parameters - typeEnvironment.defineHeterogeneousDerivedParams(wum->getDerivedParams(), suffix, - &SynapseGroupInternal::getWUDerivedParams, - &OutSynWUMPreCode::isDerivedParamHeterogeneous); - - // Add variables - typeEnvironment.defineVars(wum->getPreVars(), backend.getDeviceVarPrefix(), - suffix, &SynapseGroupInternal::getFusedWUPreVarSuffix); - - // Add EGPs - typeEnvironment.defineEGPs(wum->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", - suffix, &SynapseGroupInternal::getFusedWUPreVarSuffix); - - // Scan, parse and type-check dynamics and spike code - ErrorHandler errorHandler; - std::tie(m_DynamicsStatements, m_DynamicsResolvedTypes) = scanParseAndTypeCheckStatements(wum->getPreDynamicsCode(), typeContext, - typeEnvironment, errorHandler); - std::tie(m_SpikeStatements, m_SpikeResolvedTypes) = scanParseAndTypeCheckStatements(wum->getPreSpikeCode(), typeContext, - typeEnvironment, errorHandler); -} -//---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &backend, EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) const -{ - const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); + const std::string fieldSuffix = "OutSynWUMPre" + std::to_string(getIndex()); const auto *wum = getArchetype().getWUModel(); const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - - // If there are any statements to executre here - const auto &statements = dynamicsNotSpike ? 
m_DynamicsStatements : m_SpikeStatements; - const auto &resolvedTypes = dynamicsNotSpike ? m_DynamicsResolvedTypes : m_SpikeResolvedTypes; // If there are any statements to execute here - if(!statements.empty()) { - // Create new substitution environment and add parameters, derived parameters and extra global parameters - EnvironmentSubstitute envSubs(env); - envSubs.getStream() << "// presynaptic weight update " << getIndex() << std::endl; - envSubs.addParamValueSubstitution(wum->getParamNames(), getArchetype().getWUParams(), suffix, - [this](const std::string &p) { return isParamHeterogeneous(p); }); - envSubs.addVarValueSubstitution(wum->getDerivedParams(), getArchetype().getWUDerivedParams(), suffix, - [this](const std::string &p) { return isDerivedParamHeterogeneous(p); }); - envSubs.addVarNameSubstitution(wum->getExtraGlobalParams(), suffix); + const std::string code = dynamicsNotSpike ? wum->getPreDynamicsCode() : wum->getPreSpikeCode(); + if(!code.empty()) { + // Create new environment to add out syn fields to neuron update group + EnvironmentGroupMergedField synEnv(env, *this, ng); + + synEnv.getStream() << "// presynaptic weight update " << getIndex() << std::endl; + + // Add parameters, derived parameters and extra global parameters to environment + synEnv.addParams(wum->getParamNames(), fieldSuffix, &SynapseGroupInternal::getWUParams, &OutSynWUMPreCode::isParamHeterogeneous); + synEnv.addDerivedParams(wum->getDerivedParams(), fieldSuffix, &SynapseGroupInternal::getWUDerivedParams, &OutSynWUMPreCode::isDerivedParamHeterogeneous); + synEnv.addEGPs(backend.getDeviceVarPrefix(), "", fieldSuffix); // Create an environment which caches variables in local variables if they are accessed const bool delayed = (getArchetype().getDelaySteps() != NO_DELAY); - EnvironmentLocalVarCache varSubs( - getArchetype(), getTypeContext(), envSubs, "l", suffix, + EnvironmentLocalVarCache varEnv( + getArchetype(), modelMerged.getTypeContext(), synEnv, "l", fieldSuffix, [batchSize, delayed, &envSubs, &ng](const std::string&, const Models::VarInit&, VarAccess a) { return ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), envSubs["id"]); @@ -471,13 +331,13 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back return ng.getReadVarIndex(delay, batchSize, varDuplication, subs["id"]); });*/ - // Pretty print previously parsed statements - PrettyPrinter::print(statements, varSubs, getTypeContext(), resolvedTypes); + Transpiler::ErrorHandler errorHandler("Presynaptic weight update model " + std::to_string(getIndex())); + prettyPrintStatements(code, modelMerged.getTypeContext(), varEnv, errorHandler); } } //---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(EnvironmentExternal &env, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged) const +void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, + const ModelSpecMerged &modelMerged) { // If this group has a delay and no presynaptic dynamics (which will already perform this copying) const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); @@ -485,11 +345,11 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(EnvironmentEx // Loop through variables and copy between read and write delay slots for(const auto &v : getArchetype().getWUModel()->getPreVars()) { if(v.access & VarAccessMode::READ_WRITE) { - 
env.getStream() << "group->" << v.name << suffix << "["; + env.getStream() << env[v.name] << "["; env.getStream() << ng.getWriteVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), env["id"]); env.getStream() << "] = "; - env.getStream() << "group->" << v.name << suffix << "["; + env.getStream() << env[v.name] << "["; env.getStream() << ng.getReadVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), env["id"]); env.getStream() << "];" << std::endl; } @@ -578,27 +438,21 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC }*/ // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedInSynPSMGroups, typeContext, typeEnvironment, backend, - &NeuronGroupInternal::getFusedPSMInSyn, &SynapseGroupInternal::getPSHashDigest); + orderNeuronGroupChildren(m_MergedInSynPSMGroups, &NeuronGroupInternal::getFusedPSMInSyn, &SynapseGroupInternal::getPSHashDigest); // Build vector of vectors containing each child group's merged out syns with pre output, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedOutSynPreOutputGroups, typeContext, typeEnvironment, backend, - &NeuronGroupInternal::getFusedPreOutputOutSyn, &SynapseGroupInternal::getPreOutputHashDigest); + orderNeuronGroupChildren(m_MergedOutSynPreOutputGroups, &NeuronGroupInternal::getFusedPreOutputOutSyn, &SynapseGroupInternal::getPreOutputHashDigest); // Build vector of vectors containing each child group's current sources, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedCurrentSourceGroups, typeContext, typeEnvironment, backend, - &NeuronGroupInternal::getCurrentSources, &CurrentSourceInternal::getHashDigest); - + orderNeuronGroupChildren(m_MergedCurrentSourceGroups, &NeuronGroupInternal::getCurrentSources, &CurrentSourceInternal::getHashDigest); // Build vector of vectors containing each child group's incoming synapse groups // with postsynaptic updates, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedInSynWUMPostCodeGroups, typeContext, typeEnvironment, backend, - &NeuronGroupInternal::getFusedInSynWithPostCode, &SynapseGroupInternal::getWUPostHashDigest); + orderNeuronGroupChildren(m_MergedInSynWUMPostCodeGroups, &NeuronGroupInternal::getFusedInSynWithPostCode, &SynapseGroupInternal::getWUPostHashDigest); // Build vector of vectors containing each child group's outgoing synapse groups // with presynaptic synaptic updates, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedOutSynWUMPreCodeGroups, typeContext, typeEnvironment, backend, - &NeuronGroupInternal::getFusedOutSynWithPreCode, &SynapseGroupInternal::getWUPreHashDigest); + orderNeuronGroupChildren(m_MergedOutSynWUMPreCodeGroups, &NeuronGroupInternal::getFusedOutSynWithPreCode, &SynapseGroupInternal::getWUPreHashDigest); } //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type NeuronUpdateGroupMerged::getHashDigest() const @@ -634,7 +488,7 @@ boost::uuids::detail::sha1::digest_type NeuronUpdateGroupMerged::getHashDigest() //-------------------------------------------------------------------------- void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, BackendBase::GroupHandlerEnv genEmitTrueSpike, - 
BackendBase::GroupHandlerEnv genEmitSpikeLikeEvent) const + BackendBase::GroupHandlerEnv genEmitSpikeLikeEvent) { const ModelSpecInternal &model = modelMerged.getModel(); const unsigned int batchSize = model.getBatchSize(); @@ -643,20 +497,20 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E EnvironmentGroupMergedField neuronEnv(env, *this); // Add field for spike recording - neuronEnv.add(Type::Uint32.createPointer(), "_record_spk", "recordSpk", - [&backend](const auto &ng, size_t) - { - return backend.getDeviceVarPrefix() + "recordSpk" + ng.getName(); - }, - "", GroupMergedFieldType::DYNAMIC); + neuronEnv.addField(Type::Uint32.createPointer(), "_record_spk", "recordSpk", + [&backend](const auto &ng, size_t) + { + return backend.getDeviceVarPrefix() + "recordSpk" + ng.getName(); + }, + "", GroupMergedFieldType::DYNAMIC); // Add field for spike event recording - neuronEnv.add(Type::Uint32.createPointer(), "_record_spk_event", "recordSpkEvent", - [&backend](const auto &ng, size_t) - { - return backend.getDeviceVarPrefix() + "recordSpkEvent" + ng.getName(); - }, - "", GroupMergedFieldType::DYNAMIC); + neuronEnv.addField(Type::Uint32.createPointer(), "_record_spk_event", "recordSpkEvent", + [&backend](const auto &ng, size_t) + { + return backend.getDeviceVarPrefix() + "recordSpkEvent" + ng.getName(); + }, + "", GroupMergedFieldType::DYNAMIC); // Add default input variable neuronEnv.add(modelMerged.getModel().getPrecision(), "Isyn", "Isyn", @@ -702,25 +556,24 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Loop through incoming synapse groups - for(const auto &sg : getMergedInSynPSMGroups()) { + for(auto &sg : m_MergedInSynPSMGroups) { CodeStream::Scope b(env.getStream()); sg.generate(backend, neuronVarEnv, *this, modelMerged); } // Loop through outgoing synapse groups with presynaptic output - for (const auto &sg : getMergedOutSynPreOutputGroups()) { + for (auto &sg : m_MergedOutSynPreOutputGroups) { CodeStream::Scope b(env.getStream()); - sg.generate(neuronVarEnv, *this, modelMerged); + sg.generate(backend, neuronVarEnv, *this, modelMerged); } // Loop through all of neuron group's current sources - for (const auto &cs : getMergedCurrentSourceGroups()) { + for (auto &cs : m_MergedCurrentSourceGroups) { CodeStream::Scope b(env.getStream()); cs.generate(backend, neuronVarEnv, *this, modelMerged); } - // If a threshold condition is provided if (!nm->getThresholdConditionCode().empty()) { neuronVarEnv.getStream() << "// test whether spike condition was fulfilled previously" << std::endl; @@ -731,7 +584,10 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E if (nm->isAutoRefractoryRequired()) { neuronVarEnv.getStream() << "const bool oldSpike = ("; - PrettyPrinter::print(m_ThresholdConditionExpression, neuronVarEnv, getTypeContext(), m_ThresholdConditionResolvedTypes); + + Transpiler::ErrorHandler errorHandler("Neuron threshold condition " + std::to_string(getIndex())); + prettyPrintExpression(nm->getThresholdConditionCode(), modelMerged.getTypeContext(), neuronVarEnv, errorHandler); + neuronVarEnv.getStream() << ");" << std::endl; } } @@ -743,16 +599,18 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E }*/ neuronVarEnv.getStream() << "// calculate membrane potential" << std::endl; - PrettyPrinter::print(m_SimStatements, neuronVarEnv, getTypeContext(), m_SimResolvedTypes); + + Transpiler::ErrorHandler errorHandler("Neuron sim code " + 
std::to_string(getIndex())); + prettyPrintExpression(nm->getSimCode(), modelMerged.getTypeContext(), neuronVarEnv, errorHandler); // Generate var update for outgoing synaptic populations with presynaptic update code - for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { + for (auto &sg : m_MergedOutSynWUMPreCodeGroups) { CodeStream::Scope b(neuronVarEnv.getStream()); sg.generate(backend, neuronVarEnv, *this, modelMerged, true); } // Generate var update for incoming synaptic populations with postsynaptic code - for (const auto &sg : getMergedInSynWUMPostCodeGroups()) { + for (auto &sg : m_MergedInSynWUMPostCodeGroups) { CodeStream::Scope b(neuronVarEnv.getStream()); sg.generate(backend, neuronVarEnv, *this, modelMerged, true); } @@ -830,10 +688,13 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E }*/ // test for true spikes if condition is provided - if (m_ThresholdConditionExpression) { + if (!nm->getThresholdConditionCode().empty()) { neuronVarEnv.getStream() << "// test for and register a true spike" << std::endl; neuronVarEnv.getStream() << "if (("; - PrettyPrinter::print(m_ThresholdConditionExpression, neuronVarEnv, getTypeContext(), m_ThresholdConditionResolvedTypes); + + Transpiler::ErrorHandler errorHandler("Neuron threshold condition " + std::to_string(getIndex())); + prettyPrintExpression(nm->getThresholdConditionCode(), modelMerged.getTypeContext(), neuronVarEnv, errorHandler); + neuronVarEnv.getStream() << ")"; if (nm->isAutoRefractoryRequired()) { neuronVarEnv.getStream() << " && !oldSpike"; @@ -844,9 +705,11 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E genEmitTrueSpike(neuronVarEnv, *this); // add after-spike reset if provided - if (!m_ResetStatements.empty()) { + if (!nm->getResetCode().empty()) { neuronVarEnv.getStream() << "// spike reset code" << std::endl; - PrettyPrinter::print(m_ResetStatements, neuronVarEnv, getTypeContext(), m_ResetResolvedTypes); + + Transpiler::ErrorHandler errorHandler("Neuron reset code " + std::to_string(getIndex())); + prettyPrintExpression(nm->getResetCode(), modelMerged.getTypeContext(), neuronVarEnv, errorHandler); } } @@ -895,12 +758,12 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E } // Loop through outgoing synapse groups with some sort of presynaptic code - for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { + for (auto &sg : m_MergedOutSynWUMPreCodeGroups) { sg.genCopyDelayedVars(neuronVarEnv, *this, modelMerged); } // Loop through incoming synapse groups with some sort of presynaptic code - for (const auto &sg : getMergedInSynWUMPostCodeGroups()) { + for (auto &sg : m_MergedInSynWUMPostCodeGroups) { sg.genCopyDelayedVars(neuronVarEnv, *this, modelMerged); } } @@ -908,16 +771,16 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E } } //-------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const +void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { // Generate var update for outgoing synaptic populations with presynaptic update code - for (const auto &sg : getMergedOutSynWUMPreCodeGroups()) { + for (auto &sg : m_MergedOutSynWUMPreCodeGroups) { CodeStream::Scope b(env.getStream()); sg.generate(backend, env, *this, modelMerged, 
false); } // Generate var update for incoming synaptic populations with postsynaptic code - for (const auto &sg : getMergedInSynWUMPostCodeGroups()) { + for (auto &sg : m_MergedInSynWUMPostCodeGroups) { CodeStream::Scope b(env.getStream()); sg.generate(backend, env, *this, modelMerged, false); } From 4643194eb333947a9c843f7878d419dc9a377f22 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 16 Jun 2023 14:19:48 +0100 Subject: [PATCH 228/725] NeuronUpdateGroupMerged compiling --- .../genn/genn/code_generator/environment.h | 98 +++++++++++-------- .../genn/genn/code_generator/groupMerged.h | 24 ++--- .../code_generator/neuronUpdateGroupMerged.h | 12 +-- .../code_generator/neuronUpdateGroupMerged.cc | 88 ++++++++--------- 4 files changed, 117 insertions(+), 105 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index c21c674189..8dae65609a 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -142,7 +142,7 @@ template class EnvironmentFieldPolicy { protected: - using Payload = std::tuple>; + using Payload = std::tuple>; EnvironmentFieldPolicy(G &group, F &fieldGroup) : m_Group(group), m_FieldGroup(fieldGroup) @@ -570,34 +570,40 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase +template class EnvironmentLocalVarCache : public EnvironmentExternalBase { - //! Type of a single definition - using DefType = typename std::invoke_result_t::value_type; - - //! Type of a single initialiser - using InitialiserType = typename std::remove_reference_t>::mapped_type; - - //! Function used to provide index strings based on initialiser and access type - using GetIndexFn = std::function; + //! Function used to provide index strings based on var name and access type + using GetIndexFn = std::function; public: - EnvironmentLocalVarCache(const G &group, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, - const std::string &fieldSuffix, const std::string &localPrefix, + EnvironmentLocalVarCache(G &group, F &fieldGroup, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, + const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix, GetIndexFn getReadIndex, GetIndexFn getWriteIndex) - : EnvironmentExternalBase(enclosing), m_Group(group), m_Context(context), m_Contents(m_ContentsStream), - m_FieldSuffix(fieldSuffix), m_LocalPrefix(localPrefix), m_GetReadIndex(getReadIndex), m_GetWriteIndex(getWriteIndex) + : EnvironmentExternalBase(enclosing), m_Group(group), m_FieldGroup(fieldGroup), m_Context(context), m_Contents(m_ContentsStream), + m_ArrayPrefix(arrayPrefix), m_FieldSuffix(fieldSuffix), m_LocalPrefix(localPrefix), m_GetReadIndex(getReadIndex), m_GetWriteIndex(getWriteIndex) { - // Add name of each definition to map, initially with value set to value - const auto defs = A(m_Group).getDefs(); + // Copy variables into variables referenced, alongside boolean + const auto defs = A(m_Group.get().getArchetype()).getDefs(); std::transform(defs.cbegin(), defs.cend(), std::inserter(m_VariablesReferenced, m_VariablesReferenced.end()), - [](const auto &v){ return std::make_pair(v.name, false); }); + [](const auto &v){ return std::make_pair(v.name, std::make_pair(false, v)); }); + } + + EnvironmentLocalVarCache(G &group, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, + const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix, + GetIndexFn 
getReadIndex, GetIndexFn getWriteIndex) + : EnvironmentLocalVarCache(group, group, context, enclosing, arrayPrefix, fieldSuffix, localPrefix, getReadIndex, getWriteIndex) + {} + + EnvironmentLocalVarCache(G &group, F &fieldGroup, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, + const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix, GetIndexFn getIndex) + : EnvironmentLocalVarCache(group, fieldGroup, context, enclosing, arrayPrefix, fieldSuffix, localPrefix, getIndex, getIndex) + { } - EnvironmentLocalVarCache(const G &group, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, - const std::string &fieldSuffix, const std::string &localPrefix, GetIndexFn getIndex) - : EnvironmentLocalVarCache(group, context, enclosing, fieldSuffix, getIndex, getIndex) + EnvironmentLocalVarCache(G &group, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, + const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix, GetIndexFn getIndex) + : EnvironmentLocalVarCache(group, group, context, enclosing, arrayPrefix, fieldSuffix, localPrefix, getIndex, getIndex) { } @@ -605,27 +611,35 @@ class EnvironmentLocalVarCache : public EnvironmentExternalBase ~EnvironmentLocalVarCache() { - A adapter(m_Group); + A archetypeAdapter(m_Group.get().getArchetype()); - // Copy definitions which have been referenced into new vector - const auto defs = adapter.getDefs(); - std::remove_const_t referencedVars; - std::copy_if(defs.cbegin(), defs.cend(), std::back_inserter(referencedVars), - [this](const auto &v){ return m_VariablesReferenced.at(v.name); }); + // Copy definitions of variables which have been referenced into new vector + const auto varDefs = archetypeAdapter.getDefs(); + Models::Base::VarVec referencedVars; + std::copy_if(varDefs.cbegin(), varDefs.cend(), std::back_inserter(referencedVars), + [this](const auto &v){ return m_VariablesReferenced.at(v.name).first; }); // Loop through referenced variables - const auto &initialisers = adapter.getInitialisers(); for(const auto &v : referencedVars) { + const auto resolvedType = v.type.resolve(m_Context.get()); + + // Add field to underlying field group + m_FieldGroup.get().addField(resolvedType.createPointer(), v.name + m_FieldSuffix, + [this, v](const typename F::GroupInternal &, size_t i) + { + return m_ArrayPrefix + v.name + A(m_Group.get().getGroups().at(i)).getNameSuffix(); + }); + if(v.access & VarAccessMode::READ_ONLY) { getContextStream() << "const "; } - getContextStream() << v.type.resolve(m_Context).getName() << " " << m_LocalPrefix << v.name; + getContextStream() << resolvedType.getName() << " " << m_LocalPrefix << v.name; // If this isn't a reduction, read value from memory // **NOTE** by not initialising these variables for reductions, // compilers SHOULD emit a warning if user code doesn't set it to something if(!(v.access & VarAccessModeAttribute::REDUCE)) { - getContextStream() << " = group->" << v.name << m_FieldSuffix << "[" << m_GetReadIndex(v.name, initialisers.at(v.name), v.access) << "]"; + getContextStream() << " = group->" << v.name << m_FieldSuffix << "[" << m_GetReadIndex(v.name, v.access) << "]"; } getContextStream() << ";" << std::endl; } @@ -637,7 +651,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternalBase for(const auto &v : referencedVars) { // If variables are read-write if(v.access & VarAccessMode::READ_WRITE) { - getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << 
m_GetWriteIndex(v.name, initialisers.at(v.name), v.access) << "]"; + getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << m_GetWriteIndex(v.name, v.access) << "]"; getContextStream() << " = " << m_LocalPrefix << v.name << ";" << std::endl; } } @@ -656,16 +670,12 @@ class EnvironmentLocalVarCache : public EnvironmentExternalBase // Otherwise else { // Set flag to indicate that variable has been referenced - var->second = true; + var->second.first = true; - // Find corresponsing variable definition - const auto varDefs = A(m_Group).getDefs(); - auto varDef = std::find_if(varDefs.cbegin(), varDefs.cend(), - [](const auto &v){ return v.name == name.lexeme; }); - assert(varDef != varDefs.cend()); - - // Return it's resolved type - return {varDef->type.resolve(m_Context)}; + // Resolve type, add qualifier if required and return + const auto resolvedType = var->second.second.type.resolve(m_Context.get()); + const auto qualifiedType = (var->second.second.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; + return {qualifiedType}; } } @@ -682,7 +692,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternalBase // Otherwise else { // Set flag to indicate that variable has been referenced - var->second = true; + var->second.first = true; // Add local prefix to variable name return m_LocalPrefix + name; @@ -698,14 +708,16 @@ class EnvironmentLocalVarCache : public EnvironmentExternalBase //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - const G &m_Group; - const Type::TypeContext &m_Context; + std::reference_wrapper m_Group; + std::reference_wrapper m_FieldGroup; + std::reference_wrapper m_Context; std::ostringstream m_ContentsStream; CodeStream m_Contents; + std::string m_ArrayPrefix; std::string m_FieldSuffix; std::string m_LocalPrefix; GetIndexFn m_GetReadIndex; GetIndexFn m_GetWriteIndex; - std::unordered_map m_VariablesReferenced; + std::unordered_map> m_VariablesReferenced; }; } // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 7967f48b8f..957326449a 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -74,8 +74,8 @@ class ChildGroupMerged typedef std::function GetFieldDoubleValueFunc; typedef std::tuple Field; - ChildGroupMerged(size_t index, const std::vector> groups) - : m_Index(index), m_Groups(std::move(groups)) + ChildGroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> groups) + : m_Index(index), m_TypeContext(typeContext), m_Groups(std::move(groups)) {} ChildGroupMerged(const ChildGroupMerged&) = delete; @@ -84,6 +84,9 @@ class ChildGroupMerged //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ + //! Get type context used to resolve any types involved in this group + const Type::TypeContext &getTypeContext() const{ return m_TypeContext; } + size_t getIndex() const { return m_Index; } //! Get 'archetype' neuron group - it's properties represent those of all other merged neuron groups @@ -92,6 +95,9 @@ class ChildGroupMerged //! 
Gets access to underlying vector of neuron groups which have been merged const std::vector> &getGroups() const{ return m_Groups; } + const Type::ResolvedType &getScalarType() const{ return m_TypeContext.at("scalar"); } + const Type::ResolvedType &getTimeType() const{ return m_TypeContext.at("timepoint"); } + protected: //------------------------------------------------------------------------ // Protected API @@ -211,6 +217,7 @@ class ChildGroupMerged // Members //------------------------------------------------------------------------ size_t m_Index; + const Type::TypeContext &m_TypeContext; std::vector> m_Groups; }; @@ -223,7 +230,7 @@ class GroupMerged : public ChildGroupMerged { public: GroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> groups) - : ChildGroupMerged(index, std::move(groups)), m_TypeContext(typeContext) + : ChildGroupMerged(index, typeContext, std::move(groups)) {} GroupMerged(const GroupMerged&) = delete; @@ -232,9 +239,6 @@ class GroupMerged : public ChildGroupMerged //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - //! Get type context used to resolve any types involved in this group - const Type::TypeContext &getTypeContext() const{ return m_TypeContext; } - //! Get name of memory space assigned to group const std::string &getMemorySpace() const { return m_MemorySpace; } @@ -354,9 +358,6 @@ class GroupMerged : public ChildGroupMerged //------------------------------------------------------------------------ // Protected methods //------------------------------------------------------------------------ - const Type::ResolvedType &getScalarType() const{ return m_TypeContext.at("scalar"); } - const Type::ResolvedType &getTimeType() const{ return m_TypeContext.at("timepoint"); } - void addField(const Type::ResolvedType &type, const std::string &name, GetFieldValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) { // Add field to data structure @@ -578,7 +579,6 @@ class GroupMerged : public ChildGroupMerged //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - const Type::TypeContext &m_TypeContext; std::string m_MemorySpace; std::vector m_Fields; }; @@ -668,7 +668,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged - void orderNeuronGroupChildren(std::vector &childGroups, G getVectorFunc, H getHashDigestFunc) const + void orderNeuronGroupChildren(std::vector &childGroups, const Type::TypeContext &typeContext, G getVectorFunc, H getHashDigestFunc) const { const std::vector &archetypeChildren = (getArchetype().*getVectorFunc)(); @@ -709,7 +709,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged varEnv( - getArchetype(), modelMerged.getTypeContext(), envSubs, "l", fieldSuffix, - [&envSubs, &modelMerged, &ng](const std::string&, const Models::VarInit&, VarAccess a) + EnvironmentLocalVarCache varEnv( + *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, + [&csEnv, &modelMerged, &ng](const std::string&, VarAccess a) { - return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), envSubs["id"]); + return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), csEnv["id"]); }); // Pretty print code back to environment Transpiler::ErrorHandler errorHandler("Current source 
injection" + std::to_string(getIndex())); - prettyPrintStatements(cm->getInjectionCode(), modelMerged.getTypeContext(), varEnv, errorHandler); + prettyPrintStatements(cm->getInjectionCode(), getTypeContext(), varEnv, errorHandler); } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::CurrentSource::updateHash(boost::uuids::detail::sha1 &hash) const @@ -74,7 +74,7 @@ bool NeuronUpdateGroupMerged::CurrentSource::isDerivedParamHeterogeneous( const // GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynPSM //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, EnvironmentExternalBase &env, - const NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) + NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) { const std::string fieldSuffix = "InSyn" + std::to_string(getIndex()); const auto *psm = getArchetype().getPSModel(); @@ -126,19 +126,19 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env psmEnv.add(modelMerged.getModel().getPrecision().addConst(), "Isyn", getArchetype().getPSTargetVar()); // Create an environment which caches variables in local variables if they are accessed - EnvironmentLocalVarCache varEnv( - getArchetype(), modelMerged.getTypeContext(), psmEnv, "l", fieldSuffix, - [&psmEnv, &modelMerged, &ng](const std::string&, const Models::VarInit&, VarAccess a) + EnvironmentLocalVarCache varEnv( + *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, + [&psmEnv, &modelMerged, &ng](const std::string&, VarAccess a) { return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), psmEnv["id"]); }); // Pretty print code back to environment Transpiler::ErrorHandler applyInputErrorHandler("Postsynaptic model apply input" + std::to_string(getIndex())); - prettyPrintStatements(psm->getApplyInputCode(), modelMerged.getTypeContext(), varEnv, applyInputErrorHandler); + prettyPrintStatements(psm->getApplyInputCode(), getTypeContext(), varEnv, applyInputErrorHandler); Transpiler::ErrorHandler decayErrorHandler("Postsynaptic model decay" + std::to_string(getIndex())); - prettyPrintStatements(psm->getDecayCode(), modelMerged.getTypeContext(), varEnv, decayErrorHandler); + prettyPrintStatements(psm->getDecayCode(), getTypeContext(), varEnv, decayErrorHandler); // Write back linSyn varEnv.getStream() << psmEnv["_out_post"] << "["; @@ -167,7 +167,7 @@ bool NeuronUpdateGroupMerged::InSynPSM::isDerivedParamHeterogeneous( const std:: //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynPreOutput //---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, +void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) { const std::string fieldSuffix = "OutSyn" + std::to_string(getIndex()); @@ -193,7 +193,7 @@ void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backe //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynWUMPostCode 
//---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &backend, EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, +void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &backend, EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) { const std::string fieldSuffix = "InSynWUMPost" + std::to_string(getIndex()); @@ -216,13 +216,13 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back // Create an environment which caches variables in local variables if they are accessed const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); - EnvironmentLocalVarCache varEnv( - getArchetype(), modelMerged.getTypeContext(), synEnv, "l", fieldSuffix, - [batchSize, delayed, &synEnv, &ng](const std::string&, const Models::VarInit&, VarAccess a) + EnvironmentLocalVarCache varEnv( + *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, + [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccess a) { return ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), synEnv["id"]); }, - [batchSize, delayed, &synEnv, &ng](const std::string&, const Models::VarInit&, VarAccess a) + [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccess a) { return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), synEnv["id"]); }); @@ -240,7 +240,7 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back });*/ Transpiler::ErrorHandler errorHandler("Postsynaptic weight update model " + std::to_string(getIndex())); - prettyPrintStatements(code, modelMerged.getTypeContext(), varEnv, errorHandler); + prettyPrintStatements(code, getTypeContext(), varEnv, errorHandler); } } //---------------------------------------------------------------------------- @@ -286,7 +286,7 @@ bool NeuronUpdateGroupMerged::InSynWUMPostCode::isDerivedParamHeterogeneous( con //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynWUMPreCode //---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &backend, EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, +void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &backend, EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) { const std::string fieldSuffix = "OutSynWUMPre" + std::to_string(getIndex()); @@ -308,15 +308,15 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back // Create an environment which caches variables in local variables if they are accessed const bool delayed = (getArchetype().getDelaySteps() != NO_DELAY); - EnvironmentLocalVarCache varEnv( - getArchetype(), modelMerged.getTypeContext(), synEnv, "l", fieldSuffix, - [batchSize, delayed, &envSubs, &ng](const std::string&, const Models::VarInit&, VarAccess a) + EnvironmentLocalVarCache varEnv( + *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, + [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccess a) { - return ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), envSubs["id"]); + return ng.getReadVarIndex(delayed, batchSize, 
getVarAccessDuplication(a), synEnv["id"]); }, - [batchSize, delayed, &envSubs, &ng](const std::string&, const Models::VarInit&, VarAccess a) + [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccess a) { - return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), envSubs["id"]); + return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), synEnv["id"]); }); /*neuronSubstitutionsInSynapticCode(subs, &ng.getArchetype(), "", "_pre", "", "", "", dynamicsNotSpike, @@ -332,7 +332,7 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back });*/ Transpiler::ErrorHandler errorHandler("Presynaptic weight update model " + std::to_string(getIndex())); - prettyPrintStatements(code, modelMerged.getTypeContext(), varEnv, errorHandler); + prettyPrintStatements(code, getTypeContext(), varEnv, errorHandler); } } //---------------------------------------------------------------------------- @@ -438,21 +438,21 @@ NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeC }*/ // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedInSynPSMGroups, &NeuronGroupInternal::getFusedPSMInSyn, &SynapseGroupInternal::getPSHashDigest); + orderNeuronGroupChildren(m_MergedInSynPSMGroups, getTypeContext(), &NeuronGroupInternal::getFusedPSMInSyn, &SynapseGroupInternal::getPSHashDigest); // Build vector of vectors containing each child group's merged out syns with pre output, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedOutSynPreOutputGroups, &NeuronGroupInternal::getFusedPreOutputOutSyn, &SynapseGroupInternal::getPreOutputHashDigest); + orderNeuronGroupChildren(m_MergedOutSynPreOutputGroups, getTypeContext(), &NeuronGroupInternal::getFusedPreOutputOutSyn, &SynapseGroupInternal::getPreOutputHashDigest); // Build vector of vectors containing each child group's current sources, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedCurrentSourceGroups, &NeuronGroupInternal::getCurrentSources, &CurrentSourceInternal::getHashDigest); + orderNeuronGroupChildren(m_MergedCurrentSourceGroups, getTypeContext(), &NeuronGroupInternal::getCurrentSources, &CurrentSourceInternal::getHashDigest); // Build vector of vectors containing each child group's incoming synapse groups // with postsynaptic updates, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedInSynWUMPostCodeGroups, &NeuronGroupInternal::getFusedInSynWithPostCode, &SynapseGroupInternal::getWUPostHashDigest); + orderNeuronGroupChildren(m_MergedInSynWUMPostCodeGroups, getTypeContext(), &NeuronGroupInternal::getFusedInSynWithPostCode, &SynapseGroupInternal::getWUPostHashDigest); // Build vector of vectors containing each child group's outgoing synapse groups // with presynaptic synaptic updates, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedOutSynWUMPreCodeGroups, &NeuronGroupInternal::getFusedOutSynWithPreCode, &SynapseGroupInternal::getWUPreHashDigest); + orderNeuronGroupChildren(m_MergedOutSynWUMPreCodeGroups, getTypeContext(), &NeuronGroupInternal::getFusedOutSynWithPreCode, &SynapseGroupInternal::getWUPreHashDigest); } //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type NeuronUpdateGroupMerged::getHashDigest() const @@ -530,25 +530,25 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const 
BackendBase &backend, E // Substitute spike times const std::string spikeTimeReadIndex = getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, neuronEnv["id"]); - neuronEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "sT", "lsT", + neuronEnv.add(getTimeType().addConst(), "sT", "lsT", {neuronEnv.addInitialiser("const timepoint lsT = " + neuronEnv["_spk_time"] + "[" + spikeTimeReadIndex + "];")}); - neuronEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "prev_sT", "lprevST", + neuronEnv.add(getTimeType().addConst(), "prev_sT", "lprevST", {neuronEnv.addInitialiser("const timepoint lprevST = " + neuronEnv["_prev_spk_time"] + "[" + spikeTimeReadIndex + "];")}); - neuronEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "seT", "lseT", + neuronEnv.add(getTimeType().addConst(), "seT", "lseT", {neuronEnv.addInitialiser("const timepoint lseT = " + neuronEnv["_spk_evnt_time"] + "[" + spikeTimeReadIndex+ "];")}); - neuronEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "prev_seT", "lprevSET", + neuronEnv.add(getTimeType().addConst(), "prev_seT", "lprevSET", {neuronEnv.addInitialiser("const timepoint lprevSET = " + neuronEnv["_prev_spk_evnt_time"] + "[" + spikeTimeReadIndex + "];")}); // Create an environment which caches variables in local variables if they are accessed // **NOTE** we do this right at the top so that local copies can be used by child groups - EnvironmentLocalVarCache neuronVarEnv( - getArchetype(), getTypeContext(), neuronEnv, "l", "", - [batchSize, &neuronEnv, this](const std::string &varName, const Models::VarInit&, VarAccess a) + EnvironmentLocalVarCache neuronVarEnv( + *this, getTypeContext(), neuronEnv, backend.getDeviceVarPrefix(), "l", "", + [batchSize, &neuronEnv, this](const std::string &varName, VarAccess a) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); return getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), neuronEnv["id"]) ; }, - [batchSize, &neuronEnv, this](const std::string &varName, const Models::VarInit&, VarAccess a) + [batchSize, &neuronEnv, this](const std::string &varName, VarAccess a) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); return getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), neuronEnv["id"]) ; @@ -586,7 +586,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E neuronVarEnv.getStream() << "const bool oldSpike = ("; Transpiler::ErrorHandler errorHandler("Neuron threshold condition " + std::to_string(getIndex())); - prettyPrintExpression(nm->getThresholdConditionCode(), modelMerged.getTypeContext(), neuronVarEnv, errorHandler); + prettyPrintExpression(nm->getThresholdConditionCode(), getTypeContext(), neuronVarEnv, errorHandler); neuronVarEnv.getStream() << ");" << std::endl; } @@ -601,7 +601,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E neuronVarEnv.getStream() << "// calculate membrane potential" << std::endl; Transpiler::ErrorHandler errorHandler("Neuron sim code " + std::to_string(getIndex())); - prettyPrintExpression(nm->getSimCode(), modelMerged.getTypeContext(), neuronVarEnv, errorHandler); + prettyPrintExpression(nm->getSimCode(), getTypeContext(), neuronVarEnv, errorHandler); // Generate var update for outgoing synaptic populations with presynaptic update code for (auto &sg : m_MergedOutSynWUMPreCodeGroups) { @@ -693,7 +693,7 @@ void 
NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E neuronVarEnv.getStream() << "if (("; Transpiler::ErrorHandler errorHandler("Neuron threshold condition " + std::to_string(getIndex())); - prettyPrintExpression(nm->getThresholdConditionCode(), modelMerged.getTypeContext(), neuronVarEnv, errorHandler); + prettyPrintExpression(nm->getThresholdConditionCode(), getTypeContext(), neuronVarEnv, errorHandler); neuronVarEnv.getStream() << ")"; if (nm->isAutoRefractoryRequired()) { @@ -709,7 +709,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E neuronVarEnv.getStream() << "// spike reset code" << std::endl; Transpiler::ErrorHandler errorHandler("Neuron reset code " + std::to_string(getIndex())); - prettyPrintExpression(nm->getResetCode(), modelMerged.getTypeContext(), neuronVarEnv, errorHandler); + prettyPrintExpression(nm->getResetCode(), getTypeContext(), neuronVarEnv, errorHandler); } } From d8b4148476d858aa6e46e757c4ffe8b9855b6830 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 16 Jun 2023 14:32:00 +0100 Subject: [PATCH 229/725] extra global parameter functionality in environment was over-engineered --- .../genn/genn/code_generator/environment.h | 33 +++++++++---------- .../code_generator/neuronUpdateGroupMerged.cc | 10 +++--- .../synapseUpdateGroupMerged.cc | 4 +-- 3 files changed, 23 insertions(+), 24 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 8dae65609a..c149b86bcb 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -428,6 +428,22 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase void addVarInitParams(IsVarInitHeterogeneousFn isHeterogeneous, const std::string &fieldSuffix = "") { @@ -531,23 +547,6 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase(arrayPrefix, [&index](VarAccess a, const std::string &) { return index; }, fieldSuffix, dependents); } - - template - void addEGPs(const std::string &arrayPrefix, const std::string &varName = "", const std::string &fieldSuffix = "") - { - // Loop through EGPs - const A archetypeAdaptor(getGroup().getArchetype()); - for(const auto &e : archetypeAdaptor.getDefs()) { - const auto pointerType = e.type.resolve(getGroup().getTypeContext()).createPointer(); - addField(pointerType, e.name, - pointerType, e.name + varName + fieldSuffix, - [arrayPrefix, e, varName](const auto &g, size_t) - { - return arrayPrefix + e.name + varName + g.getName(); - }, - "", GroupMergedFieldType::DYNAMIC); - } - } private: //------------------------------------------------------------------------ diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 8e2d3547f6..6939ed9719 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -33,7 +33,7 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend // Substitute parameter and derived parameter names csEnv.addParams(cm->getParamNames(), fieldSuffix, &CurrentSourceInternal::getParams, &CurrentSource::isParamHeterogeneous); csEnv.addDerivedParams(cm->getDerivedParams(), fieldSuffix, &CurrentSourceInternal::getDerivedParams, &CurrentSource::isDerivedParamHeterogeneous); - csEnv.addEGPs(backend.getDeviceVarPrefix(), "", fieldSuffix); + 
csEnv.addExtraGlobalParams(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", fieldSuffix); // Define inject current function csEnv.add(Type::ResolvedType::createFunction(Type::Void, {modelMerged.getModel().getPrecision()}), "injectCurrent", csEnv["Isyn"] + " += $(0)", @@ -117,7 +117,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Add parameters, derived parameters and extra global parameters to environment psmEnv.addParams(psm->getParamNames(), fieldSuffix, &SynapseGroupInternal::getPSParams, &InSynPSM::isParamHeterogeneous); psmEnv.addDerivedParams(psm->getDerivedParams(), fieldSuffix, &SynapseGroupInternal::getPSDerivedParams, &InSynPSM::isDerivedParamHeterogeneous); - psmEnv.addEGPs(backend.getDeviceVarPrefix(), "", fieldSuffix); + psmEnv.addExtraGlobalParams(psm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", fieldSuffix); // **TODO** naming convention psmEnv.add(modelMerged.getModel().getPrecision().addConst(), "inSyn", "linSyn"); @@ -212,7 +212,7 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back // Add parameters, derived parameters and extra global parameters to environment synEnv.addParams(wum->getParamNames(), fieldSuffix, &SynapseGroupInternal::getWUParams, &InSynWUMPostCode::isParamHeterogeneous); synEnv.addDerivedParams(wum->getDerivedParams(), fieldSuffix, &SynapseGroupInternal::getWUDerivedParams, &InSynWUMPostCode::isDerivedParamHeterogeneous); - synEnv.addEGPs(backend.getDeviceVarPrefix(), "", fieldSuffix); + synEnv.addExtraGlobalParams(wum->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", fieldSuffix); // Create an environment which caches variables in local variables if they are accessed const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); @@ -304,7 +304,7 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back // Add parameters, derived parameters and extra global parameters to environment synEnv.addParams(wum->getParamNames(), fieldSuffix, &SynapseGroupInternal::getWUParams, &OutSynWUMPreCode::isParamHeterogeneous); synEnv.addDerivedParams(wum->getDerivedParams(), fieldSuffix, &SynapseGroupInternal::getWUDerivedParams, &OutSynWUMPreCode::isDerivedParamHeterogeneous); - synEnv.addEGPs(backend.getDeviceVarPrefix(), "", fieldSuffix); + synEnv.addExtraGlobalParams(wum->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", fieldSuffix); // Create an environment which caches variables in local variables if they are accessed const bool delayed = (getArchetype().getDelaySteps() != NO_DELAY); @@ -526,7 +526,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Substitute parameter and derived parameter names neuronEnv.addParams(nm->getParamNames(), "", &NeuronGroupInternal::getParams, &NeuronUpdateGroupMerged::isParamHeterogeneous); neuronEnv.addDerivedParams(nm->getDerivedParams(), "", &NeuronGroupInternal::getDerivedParams, &NeuronUpdateGroupMerged::isDerivedParamHeterogeneous); - neuronEnv.addEGPs(backend.getDeviceVarPrefix()); + neuronEnv.addExtraGlobalParams(nm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); // Substitute spike times const std::string spikeTimeReadIndex = getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, neuronEnv["id"]); diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index ca8485f7ef..2011814fc7 100644 --- 
a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -27,7 +27,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa // Substitute parameter and derived parameter names synEnv.addParams(wu->getParamNames(), "", &SynapseGroupInternal::getWUParams, &G::isWUParamHeterogeneous); synEnv.addDerivedParams(wu->getDerivedParams(), "", &SynapseGroupInternal::getWUDerivedParams, &G::isWUDerivedParamHeterogeneous); - synEnv.addEGPs(backend.getDeviceVarPrefix()); + synEnv.addExtraGlobalParams(wu->getExtraGlobalParams(), backend.getDeviceVarPrefix()); // Substitute names of pre and postsynaptic weight update variable synEnv.addVars(backend.getDeviceVarPrefix(), @@ -183,7 +183,7 @@ void PresynapticUpdateGroupMerged::generateSpikeEventThreshold(const BackendBase const auto *wum = getArchetype().getWUModel(); synEnv.addParams(wum->getParamNames(), "", &SynapseGroupInternal::getWUParams, &PresynapticUpdateGroupMerged::isWUParamHeterogeneous); synEnv.addDerivedParams(wum->getDerivedParams(), "", &SynapseGroupInternal::getWUDerivedParams, &PresynapticUpdateGroupMerged::isWUDerivedParamHeterogeneous); - synEnv.addEGPs(backend.getDeviceVarPrefix()); + synEnv.addExtraGlobalParams(wum->getExtraGlobalParams(), backend.getDeviceVarPrefix()); // Substitute in presynaptic neuron properties /*const unsigned int batchSize = modelMerged.getModel().getBatchSize(); From 93e49c2db4688c19489af9ab550b50178bdf50e9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 16 Jun 2023 15:35:30 +0100 Subject: [PATCH 230/725] fixed a couple of typos --- .../genn/genn/code_generator/backendBase.h | 112 +++++++++--------- .../backends/single_threaded_cpu/backend.cc | 2 +- .../code_generator/neuronUpdateGroupMerged.cc | 3 +- 3 files changed, 58 insertions(+), 59 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 8be1874822..c79c1e2dee 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -492,28 +492,28 @@ class GENN_EXPORT BackendBase template void genNeuronIndexCalculation(EnvironmentGroupMergedField &env, unsigned int batchSize) const { - env.add(Type::Uint32.addConst(), "num_neurons", - Type::Uint32, "numNeurons", - [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); - env.add(Type::Uint32.createPointer(), "_spk_cnt", "spkCnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getName(); }); - env.add(Type::Uint32.createPointer(), "_spk", "spk", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getName(); }); - env.add(Type::Uint32.createPointer(), "_spk_cnt_evnt", "spkCntEvnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCntEvnt" + g.getName(); }); - env.add(Type::Uint32.createPointer(), "_spk_evnt", "spkEvnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkEvnt" + g.getName(); }); - env.add(Type::Uint32.createPointer(), "_spk_que_ptr", "spkQuePtr", - [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getName(); }); - - env.add(env.getGroup().getTimeType().createPointer(), "_spk_time", "sT", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "sT" + g.getName(); }); - env.add(env.getGroup().getTimeType().createPointer(), "_spk_evnt_time", "seT", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "seT" + 
g.getName(); }); - env.add(env.getGroup().getTimeType().createPointer(), "_prev_spk_time", "prevST", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "prevST" + g.getName(); }); - env.add(env.getGroup().getTimeType().createPointer(), "_prev_spk_evnt_time", "prevSET", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "prevSET" + g.getName(); }); + env.addField(Type::Uint32.addConst(), "num_neurons", + Type::Uint32, "numNeurons", + [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); + env.addField(Type::Uint32.createPointer(), "_spk_cnt", "spkCnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getName(); }); + env.addField(Type::Uint32.createPointer(), "_spk", "spk", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getName(); }); + env.addField(Type::Uint32.createPointer(), "_spk_cnt_evnt", "spkCntEvnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCntEvnt" + g.getName(); }); + env.addField(Type::Uint32.createPointer(), "_spk_evnt", "spkEvnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkEvnt" + g.getName(); }); + env.addField(Type::Uint32.createPointer(), "_spk_que_ptr", "spkQuePtr", + [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getName(); }); + + env.addField(env.getGroup().getTimeType().createPointer(), "_spk_time", "sT", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "sT" + g.getName(); }); + env.addField(env.getGroup().getTimeType().createPointer(), "_spk_evnt_time", "seT", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "seT" + g.getName(); }); + env.addField(env.getGroup().getTimeType().createPointer(), "_prev_spk_time", "prevST", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "prevST" + g.getName(); }); + env.addField(env.getGroup().getTimeType().createPointer(), "_prev_spk_evnt_time", "prevSET", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "prevSET" + g.getName(); }); // If batching is enabled, calculate batch offset @@ -573,49 +573,49 @@ class GENN_EXPORT BackendBase void genSynapseIndexCalculation(EnvironmentGroupMergedField &env, unsigned int batchSize) const { // Synapse group fields - groupEnv.add(Type::Uint32.addConst(), "num_pre", - Type::Uint32, "numSrcNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); - groupEnv.add(Type::Uint32.addConst(), "num_post", - Type::Uint32, "numTrgNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); - groupEnv.add(Type::Uint32, "_row_stride", "rowStride", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(getSynapticMatrixRowStride(sg)); }); - groupEnv.add(Type::Uint32, "_col_stride", "colStride", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getMaxSourceConnections()); }); + groupEnv.addField(Type::Uint32.addConst(), "num_pre", + Type::Uint32, "numSrcNeurons", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); + groupEnv.addField(Type::Uint32.addConst(), "num_post", + Type::Uint32, "numTrgNeurons", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); + groupEnv.addField(Type::Uint32, "_row_stride", "rowStride", + [](const SynapseGroupInternal &sg, size_t) { return 
std::to_string(getSynapticMatrixRowStride(sg)); }); + groupEnv.addField(Type::Uint32, "_col_stride", "colStride", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getMaxSourceConnections()); }); // Postsynaptic model fields - groupEnv.add(modelMerged.getModel().getPrecision().createPointer(), "_out_post", "outPost", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); - groupEnv.add(modelMerged.getModel().getPrecision().createPointer(), "_den_delay", "denDelay", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); - groupEnv.add(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); + groupEnv.addField(modelMerged.getModel().getPrecision().createPointer(), "_out_post", "outPost", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); + groupEnv.addField(modelMerged.getModel().getPrecision().createPointer(), "_den_delay", "denDelay", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); + groupEnv.addField(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); // Presynaptic output fields - groupEnv.add(modelMerged.getModel().getPrecision().createPointer(), "_out_pre", "outPre", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); + groupEnv.addField(modelMerged.getModel().getPrecision().createPointer(), "_out_pre", "outPre", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); // Source neuron fields - groupEnv.add(Type::Uint32.createPointer(), "_src_spk_que_ptr", "srcSpkQuePtr", - [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getSrcNeuronGroup()->getName(); }); - groupEnv.add(Type::Uint32.createPointer(), "_src_spk_cnt", "srcSpkCnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getSrcNeuronGroup()->getName(); }); - groupEnv.add(Type::Uint32.createPointer(), "_src_spk", "srcSpk", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getSrcNeuronGroup()->getName(); }); - groupEnv.add(Type::Uint32.createPointer(), "_src_spk_evnt_cnt", "srcSpkCntEvnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCntEvnt" + g.getSrcNeuronGroup()->getName(); }); - groupEnv.add(Type::Uint32.createPointer(), "_src_spk_evnt", "srcSpkEvnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkEvnt" + g.getSrcNeuronGroup()->getName(); }); + groupEnv.addField(Type::Uint32.createPointer(), "_src_spk_que_ptr", "srcSpkQuePtr", + [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getSrcNeuronGroup()->getName(); }); + groupEnv.addField(Type::Uint32.createPointer(), "_src_spk_cnt", "srcSpkCnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getSrcNeuronGroup()->getName(); }); + groupEnv.addField(Type::Uint32.createPointer(), "_src_spk", "srcSpk", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getSrcNeuronGroup()->getName(); }); + groupEnv.addField(Type::Uint32.createPointer(), "_src_spk_evnt_cnt", "srcSpkCntEvnt", + [this](const 
auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCntEvnt" + g.getSrcNeuronGroup()->getName(); }); + groupEnv.addField(Type::Uint32.createPointer(), "_src_spk_evnt", "srcSpkEvnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkEvnt" + g.getSrcNeuronGroup()->getName(); }); // Target neuron fields - groupEnv.add(Type::Uint32.createPointer(), "_trg_spk_que_ptr", "trgSpkQuePtr", - [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getTrgNeuronGroup()->getName(); }); - groupEnv.add(Type::Uint32.createPointer(), "_trg_spk_cnt", "trgSpkCnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getTrgNeuronGroup()->getName(); }); - groupEnv.add(Type::Uint32.createPointer(), "_trg_spk", "trgSpk", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getTrgNeuronGroup()->getName(); }); + groupEnv.addField(Type::Uint32.createPointer(), "_trg_spk_que_ptr", "trgSpkQuePtr", + [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getTrgNeuronGroup()->getName(); }); + groupEnv.addField(Type::Uint32.createPointer(), "_trg_spk_cnt", "trgSpkCnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getTrgNeuronGroup()->getName(); }); + groupEnv.addField(Type::Uint32.createPointer(), "_trg_spk", "trgSpk", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getTrgNeuronGroup()->getName(); }); // If batching is enabled if(batchSize > 1) { diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index aedc56ae14..f2aeac37c3 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1445,7 +1445,7 @@ void Backend::genPopVariableInit(EnvironmentExternalBase &env, HandlerEnv handle void Backend::genVariableInit(EnvironmentExternalBase &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const { // **TODO** loops like this should be generated like CUDA threads - env.getStream() << "for (unsigned int i = 0; i < (" << count << "); i++)"; + env.getStream() << "for (unsigned int i = 0; i < (" << env[count] << "); i++)"; { CodeStream::Scope b(env.getStream()); diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 6939ed9719..1772196976 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -1,7 +1,6 @@ #include "code_generator/neuronUpdateGroupMerged.h" // GeNN code generator includes -#include "code_generator/standardLibrary.h" #include "code_generator/groupMergedTypeEnvironment.h" #include "code_generator/modelSpecMerged.h" @@ -709,7 +708,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E neuronVarEnv.getStream() << "// spike reset code" << std::endl; Transpiler::ErrorHandler errorHandler("Neuron reset code " + std::to_string(getIndex())); - prettyPrintExpression(nm->getResetCode(), getTypeContext(), neuronVarEnv, errorHandler); + prettyPrintStatements(nm->getResetCode(), getTypeContext(), neuronVarEnv, errorHandler); } } From 05bddaaa0f78dc9fca6e5e163f68d9ebe4181d23 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 16 Jun 2023 15:35:37 +0100 Subject: [PATCH 231/725] neuron initialisation updates --- .../genn/code_generator/initGroupMerged.h | 83 +--- 
.../genn/code_generator/initGroupMerged.cc | 441 +++++++----------- 2 files changed, 181 insertions(+), 343 deletions(-) diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index f89759c4a1..90506bf392 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -11,23 +11,18 @@ namespace GeNN::CodeGenerator class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase { public: - using VarInitAST = std::unordered_map>; - //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::CurrentSource //---------------------------------------------------------------------------- //! Child group merged for current sources attached to this neuron update group - class CurrentSource : public GroupMerged + class CurrentSource : public ChildGroupMerged { public: - CurrentSource(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups); - //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, EnvironmentExternal &env, - const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const; + void generate(const BackendBase &backend, EnvironmentExternalBase &env, + NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged); //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -44,12 +39,6 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- //! Is the var init parameter referenced? bool isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const; - - //---------------------------------------------------------------------------- - // Members - //---------------------------------------------------------------------------- - //! Parsed statements and resolved types for initialising each variable - VarInitAST m_VarInitASTs; }; //---------------------------------------------------------------------------- @@ -59,14 +48,11 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase class InSynPSM : public GroupMerged { public: - InSynPSM(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups); - - //---------------------------------------------------------------------------- + //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, EnvironmentExternal &env, - const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const; + void generate(const BackendBase &backend, EnvironmentExternalBase &env, + NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged); //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -83,12 +69,6 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- //! Is the var init parameter referenced? 
bool isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const; - - //---------------------------------------------------------------------------- - // Members - //---------------------------------------------------------------------------- - //! Parsed statements and resolved types for initialising each variable - VarInitAST m_VarInitASTs; }; //---------------------------------------------------------------------------- @@ -98,14 +78,11 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase class OutSynPreOutput : public GroupMerged { public: - OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups); - //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, EnvironmentExternal &env, - const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const; + void generate(const BackendBase &backend, EnvironmentExternalBase &env, + NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged); }; //---------------------------------------------------------------------------- @@ -115,14 +92,11 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase class InSynWUMPostVars : public GroupMerged { public: - InSynWUMPostVars(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups); - //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, EnvironmentExternal &env, - const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const; + void generate(const BackendBase &backend, EnvironmentExternalBase &env, + NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged); //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -139,12 +113,6 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- //! Is the var init parameter referenced? bool isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const; - - //---------------------------------------------------------------------------- - // Members - //---------------------------------------------------------------------------- - //! 
Parsed statements and resolved types for initialising each variable - VarInitAST m_VarInitASTs; }; //---------------------------------------------------------------------------- @@ -154,14 +122,11 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase class OutSynWUMPreVars: public GroupMerged { public: - OutSynWUMPreVars(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups); - //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - void generate(const BackendBase &backend, EnvironmentExternal &env, - const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const; + void generate(const BackendBase &backend, EnvironmentExternalBase &env, + NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged); //! Update hash with child groups void updateHash(boost::uuids::detail::sha1 &hash) const; @@ -178,12 +143,6 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //---------------------------------------------------------------------------- //! Is the var init parameter referenced? bool isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const; - - //---------------------------------------------------------------------------- - // Members - //---------------------------------------------------------------------------- - //! Parsed statements and resolved types for initialising each variable - VarInitAST m_VarInitASTs; }; NeuronInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, @@ -204,7 +163,7 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase runnerVarDecl, runnerMergedStructAlloc, name); } - void generateInit(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged) const; + void generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged); const std::vector &getMergedCurrentSourceGroups() const { return m_MergedCurrentSourceGroups; } const std::vector &getMergedInSynPSMGroups() const { return m_MergedInSynPSMGroups; } @@ -221,11 +180,11 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase //------------------------------------------------------------------------ // Private methods //------------------------------------------------------------------------ - void genInitSpikeCount(const BackendBase &backend, EnvironmentExternal &env, bool spikeEvent, unsigned int batchSize) const; + void genInitSpikeCount(const BackendBase &backend, EnvironmentExternalBase &env, bool spikeEvent, unsigned int batchSize); - void genInitSpikes(const BackendBase &backend, EnvironmentExternal &env, bool spikeEvent, unsigned int batchSize) const; + void genInitSpikes(const BackendBase &backend, EnvironmentExternalBase &env, bool spikeEvent, unsigned int batchSize); - void genInitSpikeTime(const BackendBase &backend, EnvironmentExternal &env, const std::string &varName, unsigned int batchSize) const; + void genInitSpikeTime(const BackendBase &backend, EnvironmentExternalBase &env, const std::string &varName, unsigned int batchSize); //------------------------------------------------------------------------ // Members @@ -235,9 +194,6 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase std::vector m_MergedOutSynPreOutputGroups; 
std::vector m_MergedInSynWUMPostVarGroups; std::vector m_MergedOutSynWUMPreVarGroups; - - //! Parsed statements and resolved types for initialising each variable - VarInitAST m_VarInitASTs; }; @@ -514,13 +470,6 @@ class GENN_EXPORT CustomUpdateInitGroupMerged : public CustomUpdateInitGroupMerg // Static constants //---------------------------------------------------------------------------- static const std::string name; - -private: - //---------------------------------------------------------------------------- - // Members - //---------------------------------------------------------------------------- - //! Parsed statements and resolved types for initialising each variable - NeuronInitGroupMerged::VarInitAST m_VarInitASTs; }; diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 9240257650..183dc21ac0 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -9,7 +9,6 @@ #include "transpiler/parser.h" #include "transpiler/prettyPrinter.h" #include "transpiler/scanner.h" -#include "transpiler/standardLibrary.h" #include "transpiler/typeChecker.h" using namespace GeNN; @@ -21,7 +20,7 @@ using namespace GeNN::Transpiler; //-------------------------------------------------------------------------- namespace { -void genVariableFill(CodeStream &os, const std::string &fieldName, const std::string &value, const std::string &idx, const std::string &stride, +void genVariableFill(EnvironmentExternalBase &env, const std::string &target, const std::string &value, const std::string &idx, const std::string &stride, VarAccessDuplication varDuplication, unsigned int batchSize, bool delay = false, unsigned int numDelaySlots = 1) { // Determine number of values to fill in each thread @@ -29,19 +28,19 @@ void genVariableFill(CodeStream &os, const std::string &fieldName, const std::st // If there's only one, don't generate a loop if(numValues == 1) { - os << "group->" << fieldName << "[" << idx << "] = " << value << ";" << std::endl; + env.getStream() << env[target] << "[" << env[idx] << "] = " << value << ";" << std::endl; } // Otherwise else { - os << "for(unsigned int d = 0; d < " << numValues << "; d++)"; + env.getStream() << "for(unsigned int d = 0; d < " << numValues << "; d++)"; { - CodeStream::Scope b(os); - os << "group->" << fieldName << "[(d * " << stride << ") + " << idx << "] = " << value << ";" << std::endl; + CodeStream::Scope b(env.getStream()); + env.getStream() << env[target] << "[(d * " << stride << ") + " << env[idx] << "] = " << value << ";" << std::endl; } } } //-------------------------------------------------------------------------- -void genScalarFill(CodeStream &os, const std::string &fieldName, const std::string &value, +void genScalarFill(EnvironmentExternalBase &env, const std::string &target, const std::string &value, VarAccessDuplication varDuplication, unsigned int batchSize, bool delay = false, unsigned int numDelaySlots = 1) { // Determine number of values to fill in each thread @@ -49,115 +48,63 @@ void genScalarFill(CodeStream &os, const std::string &fieldName, const std::stri // If there's only one, don't generate a loop if(numValues == 1) { - os << "group->" << fieldName << "[0] = " << value << ";" << std::endl; + env.getStream() << env[target] << "[0] = " << value << ";" << std::endl; } // Otherwise else { - os << "for(unsigned int d = 0; d < " << numValues << "; d++)"; + env.getStream() << "for(unsigned int d = 0; d < " << numValues << "; d++)"; 
{ - CodeStream::Scope b(os); - os << "group->" << fieldName << "[d] = " << value << ";" << std::endl; - } - } -} - -template -NeuronInitGroupMerged::VarInitAST addInitNeuronVarFields(const BackendBase &backend, TypeChecker::EnvironmentBase &enclosingEnv, - const G &groupMerged, const std::string &fieldSuffix) -{ - // Loop through variables - NeuronInitGroupMerged::VarInitAST varInitAST; - A archetypeAdaptor(groupMerged.getArchetype()); - for (const auto &var : archetypeAdaptor.getDefs()) { - // If there is any initialisation code - const auto *snippet = archetypeAdaptor.getInitialisers().at(var.name).getSnippet(); - if (!snippet->getCode().empty()) { - // Create type environment for this variable's initialisation - GroupMergedTypeEnvironment typeEnvironment(*this, &enclosingEnv); - - // Add pointers to state variable itself - //typeEnvironment.definePointerField(var.type, var.name, backend.getDeviceVarPrefix(), - // getVarAccessMode(var.access), suffix, &SynapseGroupInternal::getFusedPSVarSuffix);*/ - const auto varResolvedType = var.type.resolve(groupMerged.getTypeContext()); - const auto varQualifiedType = (var.access & VarAccessModeAttribute::READ_ONLY) ? varResolvedType.addQualifier(Type::Qualifier::CONSTANT) : varResolvedType; - defineField(varQualifiedType, var.name, - varResolvedType.createPointer(), var.name + fieldSuffix, - [&backend, var](const auto &g, size_t) - { - return backend.getDeviceVarPrefix() + var.name + A(g).getNameSuffix(); - }); - - - // Add heterogeneous var init parameters - typeEnvironment.defineHeterogeneousVarInitParams(&G::isVarInitParamHeterogeneous, suffix); - typeEnvironment.defineHeterogeneousVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, suffix); - - // Add EGPs - typeEnvironment.defineEGPs(snippet->getExtraGlobalParams(), backend.getDeviceVarPrefix(), - var.name, suffix); - - // Scan, parse and type-check update code - ErrorHandler errorHandler; - const std::string code = upgradeCodeString(snippet->getCode()); - const auto tokens = Scanner::scanSource(code, typeContext, errorHandler); - - auto initStatements = Parser::parseBlockItemList(tokens, typeContext, errorHandler); - auto initTypes = TypeChecker::typeCheck(initStatements, typeEnvironment, errorHandler); - - // Add to map of per-variable initialistion AST - varInitAST.emplace(var.name, std::make_tuple(std::move(initStatements), std::move(initTypes))); + CodeStream::Scope b(env.getStream()); + env.getStream() << env[target] << "[d] = " << value << ";" << std::endl; } } - - return varInitAST; } //------------------------------------------------------------------------ -template -void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternal &env, - const G &groupMerged, const NeuronInitGroupMerged::VarInitAST &varInitAST, const std::string &fieldSuffix, - const std::string &countMember, size_t numDelaySlots, unsigned int batchSizet) +template +void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &env, + G &group, F &fieldGroup, const std::string &fieldSuffix, + const std::string &count, size_t numDelaySlots, unsigned int batchSize) { A adaptor(groupMerged.getArchetype()); - const std::string count = "group->" + countMember; for (const auto &var : adaptor.getDefs()) { + // If there is any initialisation code + const auto resolvedType = var.type.resolve(group.getTypeContext()); const auto &varInit = adaptor.getInitialisers().at(var.name); - const auto &varAST = varInitAST.at(var.name); - - // If there are any initialisation statements for this variable 
- if (!std::get<0>(varAST).empty()) { + const auto *snippet = adaptor.getInitialisers().at(var.name).getSnippet(); + if (!snippet->getCode().empty()) { CodeStream::Scope b(env.getStream()); - EnvironmentSubstitute varEnv(&env); + EnvironmentGroupMergedField varEnv(env, group, fieldGroup); // Substitute in parameters and derived parameters for initialising variables - varEnv.addParamValueSubstitution(varInit.getSnippet()->getParamNames(), varInit.getParams(), var.name + fieldSuffix, - [&groupMerged, &var, isParamHeterogeneousFn](const std::string &p) - { - return groupMerged.isParamHeterogeneous(groupMerged, var.name, p); - }); - varEnv.addVarValueSubstitution(varInit.getSnippet()->getDerivedParams(), varInit.getDerivedParams(), var.name + fieldSuffix, - [&groupMerged, &var, isDerivedParamHeterogeneousFn](const std::string &p) - { - return groupMerged.isDerivedParamHeterogeneous(groupMerged, var.name, p); - }); - varEnv.addVarNameSubstitution(varInit.getSnippet()->getExtraGlobalParams(), var.name + fieldSuffix); + varEnv.addVarInitParams(&G::isVarInitParamHeterogeneous, fieldSuffix); + varEnv.addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, fieldSuffix); + varEnv.addExtraGlobalParams(snippet->getExtraGlobalParameters(), backend.getDeviceVarPrefix(), var.name, fieldSuffix); + + // Add field for variable itself + varEnv.addField(resolvedType.createPointer(), "_value", var.name + fieldSuffix, + [&backend, var](const auto &g, size_t) + { + return backend.getDeviceVarPrefix() + var.name + A(g).getNameSuffix(); + }); // If variable is shared between neurons if (getVarAccessDuplication(var.access) == VarAccessDuplication::SHARED_NEURON) { backend.genPopVariableInit( - varEnvs, - [&adaptor, &groupMerged, &var, &varAST, &fieldSuffix, batchSize, numDelaySlots] - (EnvironmentExternal &varInitEnv) + varEnv, + [&adaptor, &fieldSuffix, &group, &resolvedType, &var, batchSize, numDelaySlots, snippet] + (EnvironmentExternalBase &varInitEnv) { // Generate initial value into temporary variable - varInitEnv.getStream() << var.type.resolve(groupMerged.getTypeContext()).getName() << " initVal;" << std::endl; - varInitEnv.addVarSubstitution("value", "initVal"); + varInitEnv.getStream() << resolvedType.getName() << " initVal;" << std::endl; + varInitEnv.add(resolvedType, "value", "initVal"); // Pretty print variable initialisation code - PrettyPrinter::print(std::get<0>(varAST), varInitEnv, groupMerged.getTypeContext(), std::get<1>(varAST)); + Transpiler::ErrorHandler errorHandler("Variable '" + var.name + "' init code" + std::to_string(group.getIndex())); + prettyPrintStatements(snippet->getCode(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all delay slots and batches - genScalarFill(varInitEnv.getStream(), var.name + fieldSuffix, "initVal", getVarAccessDuplication(var.access), + genScalarFill(varInitEnv, "_value", "initVal", getVarAccessDuplication(var.access), batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); }); } @@ -165,18 +112,19 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternal &env, else { backend.genVariableInit( varEnvs, count, "id", - [&adaptor, &groupMerged, &var, &varAST, &fieldSuffix, batchSize, count, numDelaySlots] + [&adaptor, &fieldSuffix, &group, &var, &resolvedType, batchSize, count, numDelaySlots] (EnvironmentExternal &varInitEnv) { // Generate initial value into temporary variable - varInitEnv.getStream() << var.type.resolve(groupMerged.getTypeContext()).getName() << " initVal;" << std::endl; - 
varInitEnv.addVarSubstitution("value", "initVal"); + varInitEnv.getStream() << resolvedType.getName() << " initVal;" << std::endl; + varInitEnv.add(resolvedType, "value", "initVal"); // Pretty print variable initialisation code - PrettyPrinter::print(std::get<0>(varAST), varInitEnv, groupMerged.getTypeContext(), std::get<1>(varAST)); + Transpiler::ErrorHandler errorHandler("Variable '" + var.name + "' init code" + std::to_string(group.getIndex())); + prettyPrintStatements(snippet->getCode(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all delay slots and batches - genVariableFill(varInitEnv.getStream(), var.name + fieldSuffix, "initVal", varInitSubs["id"], count, + genVariableFill(varInitEnv(), "_value", "initVal", "id", count, getVarAccessDuplication(var.access), batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); }); } @@ -185,6 +133,14 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternal &env, } } //------------------------------------------------------------------------ +template +void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &env, + G &group, const std::string &fieldSuffix, + const std::string &count, size_t numDelaySlots, unsigned int batchSize) +{ + genInitNeuronVarCode(backend, env, group, group, fieldSuffix, count, numDelaySlots, batchSize); +} +//------------------------------------------------------------------------ // Initialise one row of weight update model variables template void genInitWUVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, const Substitutions &popSubs, @@ -222,7 +178,7 @@ void genInitWUVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, const os << code << std::endl; // Fill value across all batches - genVariableFill(os, var.name, "initVal", varSubs["id_syn"], stride, + genVariableFill(os, var.name, "initVal", "id_syn", stride, getVarAccessDuplication(var.access), batchSize); }); } @@ -233,45 +189,12 @@ void genInitWUVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, const //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::CurrentSource //---------------------------------------------------------------------------- -NeuronInitGroupMerged::CurrentSource::CurrentSource(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, typeContext, groups) +void NeuronInitGroupMerged::CurrentSource::generate(const BackendBase &backend, EnvironmentExternalBase &env, + NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) { - const std::string suffix = "CS" + std::to_string(getIndex()); - - // Loop through variables - // **TODO** adaptor - const auto &varInit = getArchetype().getVarInitialisers(); - for(const auto &var : getArchetype().getCurrentSourceModel()->getVars()) { - // Add pointers to state variable - if(!varInit.at(var.name).getSnippet()->getCode().empty()) { - addPointerField(var.type, var.name + suffix, - backend.getDeviceVarPrefix() + var.name); - } - - // Add heterogeneous var init parameters - addHeterogeneousVarInitParams( - &CurrentSource::isVarInitParamHeterogeneous, suffix); - addHeterogeneousVarInitDerivedParams( - &CurrentSource::isVarInitDerivedParamHeterogeneous, suffix); - - // Add extra global parameters - for(const auto &e : varInit.at(var.name).getSnippet()->getExtraGlobalParams()) { - 
addField(e.type.resolve(getTypeContext()).createPointer(), e.name + var.name + suffix, - [&backend, e, suffix, var](const auto &g, size_t) - { - return backend.getDeviceVarPrefix() + e.name + var.name + g.getName(); - }, - GroupMergedFieldType::DYNAMIC); - } - } -} -//---------------------------------------------------------------------------- -void NeuronInitGroupMerged::CurrentSource::generate(const BackendBase &backend, EnvironmentExternal &env, - const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const -{ - genInitNeuronVarCode(backend, env, *this, m_VarInitASTs, "CS" + std::to_string(getIndex()), - "numNeurons", 0, modelMerged.getModel().getBatchSize()); + genInitNeuronVarCode( + backend, env, *this, ng, "CS" + std::to_string(getIndex()), + "num_neurons", 0, modelMerged.getModel().getBatchSize()); } //---------------------------------------------------------------------------- void NeuronInitGroupMerged::CurrentSource::updateHash(boost::uuids::detail::sha1 &hash) const @@ -301,66 +224,52 @@ bool NeuronInitGroupMerged::CurrentSource::isVarInitParamReferenced(const std::s //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::InSynPSM //---------------------------------------------------------------------------- -NeuronInitGroupMerged::InSynPSM::InSynPSM(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, typeContext, groups) +void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, EnvironmentExternalBase &env, + NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) { - const std::string suffix = "InSyn" + std::to_string(getIndex()); - - // Add pointer to insyn - addField(getScalarType().createPointer(), "inSyn" + suffix, - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "inSyn" + g.getFusedPSVarSuffix(); }); - - // Add pointer to dendritic delay buffer if required - if(getArchetype().isDendriticDelayRequired()) { - addField(getScalarType().createPointer(), "denDelay" + suffix, - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); + const std::string fieldSuffix = "InSyn" + std::to_string(getIndex()); - addField(Type::Uint32.createPointer(), "denDelayPtr" + suffix, - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); - } - - // Add fields required to initialise PSM variables and get AST - m_VarInitASTs = addInitNeuronVarFields(backend, enclosingEnv, *this, suffix); -} -//---------------------------------------------------------------------------- -void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, EnvironmentExternal &env, - const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const -{ - const std::string suffix = "InSyn" + std::to_string(getIndex()); + // Create environment for group + EnvironmentGroupMergedField groupEnv(env, *this, ng); - // Zero InSyn - backend.genVariableInit(env, "group->numNeurons", "id", - [&modelMerged, &suffix] (EnvironmentExternal &varEnv) + // Add field for InSyn and zero + groupEnv.addField(getScalarType().createPointer(), "_out_post", "outPost", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); + backend.genVariableInit(env, "num_neurons", "id", + 
[&modelMerged] (EnvironmentExternalBase &varEnv) { - genVariableFill(varEnv.getStream(), "inSyn" + suffix, modelMerged.scalarExpr(0.0), - varEnv["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, + genVariableFill(varEnv, "_out_post", modelMerged.scalarExpr(0.0), + "id", "num_neurons", VarAccessDuplication::DUPLICATE, modelMerged.getModel().getBatchSize()); }); // If dendritic delays are required if(getArchetype().isDendriticDelayRequired()) { - // Zero dendritic delay buffer - backend.genVariableInit(env, "group->numNeurons", "id", - [&modelMerged, &suffix, this](EnvironmentExternal &varEnv) + // Add field for dendritic delay buffer and zero + groupEnv.addField(getScalarType().createPointer(), "_den_delay", "denDelay", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); + backend.genVariableInit(env, "num_neurons", "id", + [&modelMerged, this](EnvironmentExternalBase &varEnv) { - genVariableFill(varEnv.getStream(), "denDelay" + suffix, modelMerged.scalarExpr(0.0), - varEnv["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, + genVariableFill(varEnv, "_den_delay", modelMerged.scalarExpr(0.0), + "id", "num_neurons", VarAccessDuplication::DUPLICATE, modelMerged.getModel().getBatchSize(), true, getArchetype().getMaxDendriticDelayTimesteps()); }); - // Zero dendritic delay pointer + // Add field for dendritic delay pointer and zero + groupEnv.addField(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); backend.genPopVariableInit(env, - [&suffix](EnvironmentExternal &varEnv) + [](EnvironmentExternalBase &varEnv) { - varEnv.getStream() << "*group->denDelayPtr" << suffix << " = 0;" << std::endl; + varEnv.getStream() << "*" << varEnv["_den_delay_ptr"] << " = 0;" << std::endl; }); } - genInitNeuronVarCode(backend, env, *this, m_VarInitASTs, suffix, - "numNeurons", 1, modelMerged.getModel().getBatchSize()); + genInitNeuronVarCode( + backend, groupEnv, *this, ng, fieldSuffix, "num_neurons", 0, modelMerged.getModel().getBatchSize()); } //---------------------------------------------------------------------------- void NeuronInitGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &hash) const @@ -390,26 +299,22 @@ bool NeuronInitGroupMerged::InSynPSM::isVarInitParamReferenced(const std::string //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynPreOutput //---------------------------------------------------------------------------- -NeuronInitGroupMerged::OutSynPreOutput::OutSynPreOutput(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &, - const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, typeContext, groups) +void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, EnvironmentExternalBase &env, + NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) { const std::string suffix = "OutSyn" + std::to_string(getIndex()); - addField(getScalarType().createPointer(), "revInSyn" + suffix, - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "revInSyn" + g.getFusedPreOutputSuffix(); }); -} -//---------------------------------------------------------------------------- -void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, EnvironmentExternal &env, - const 
NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const -{ - const std::string suffix = "OutSyn" + std::to_string(getIndex()); + // Create environment for group + EnvironmentGroupMergedField groupEnv(env, *this, ng); - backend.genVariableInit(env, "group->numNeurons", "id", - [&modelMerged, suffix] (EnvironmentExternal &varEnv) + // Add + groupEnv.addField(getScalarType().createPointer(), "_out_pre", "outPre", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); + backend.genVariableInit(env, "num_neurons", "id", + [&modelMerged] (EnvironmentExternalBase &varEnv) { - genVariableFill(varEnv.getStream(), "revInSyn" + suffix, modelMerged.scalarExpr(0.0), - varEnv["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, + genVariableFill(varEnv, "_out_pre", modelMerged.scalarExpr(0.0), + "id", "num_neurons", VarAccessDuplication::DUPLICATE, modelMerged.getModel().getBatchSize()); }); } @@ -417,20 +322,11 @@ void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::InSynWUMPostVars //---------------------------------------------------------------------------- -NeuronInitGroupMerged::InSynWUMPostVars::InSynWUMPostVars(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, typeContext, groups) +void NeuronInitGroupMerged::InSynWUMPostVars::generate(const BackendBase &backend, EnvironmentExternalBase &env, + NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) { - // Add fields required to initialise PSM variables and get AST - m_VarInitASTs = addInitNeuronVarFields(backend, enclosingEnv, *this, - "InSynWUMPost" + std::to_string(getIndex())); -} -//---------------------------------------------------------------------------- -void NeuronInitGroupMerged::InSynWUMPostVars::generate(const BackendBase &backend, EnvironmentExternal &env, - const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const -{ - genInitNeuronVarCode(backend, env, *this, m_VarInitASTs, "InSynWUMPost" + std::to_string(getIndex()), - "numNeurons", 1, modelMerged.getModel().getBatchSize()); + genInitNeuronVarCode( + backend, env, *this, ng, "InSynWUMPost" + std::to_string(getIndex()), "num_neurons", 0, modelMerged.getModel().getBatchSize()); } //---------------------------------------------------------------------------- void NeuronInitGroupMerged::InSynWUMPostVars::updateHash(boost::uuids::detail::sha1 &hash) const @@ -460,20 +356,11 @@ bool NeuronInitGroupMerged::InSynWUMPostVars::isVarInitParamReferenced(const std //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynWUMPreVars //---------------------------------------------------------------------------- -NeuronInitGroupMerged::OutSynWUMPreVars::OutSynWUMPreVars(size_t index, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &enclosingEnv, - const BackendBase &backend, const std::vector> &groups) -: GroupMerged(index, typeContext, groups) +void NeuronInitGroupMerged::OutSynWUMPreVars::generate(const BackendBase &backend, EnvironmentExternalBase &env, + NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) { - // Add fields required to initialise PSM variables and get AST - 
m_VarInitASTs = addInitNeuronVarFields(backend, enclosingEnv, *this, - "OutSynWUMPre" + std::to_string(getIndex())); -} -//---------------------------------------------------------------------------- -void NeuronInitGroupMerged::OutSynWUMPreVars::generate(const BackendBase &backend, EnvironmentExternal &env, - const NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) const -{ - genInitNeuronVarCode(backend, env, *this, m_VarInitASTs, "OutSynWUMPre" + std::to_string(getIndex()), - "numNeurons", 1, modelMerged.getModel().getBatchSize()); + genInitNeuronVarCode( + backend, env, *this, ng, "OutSynWUMPre" + std::to_string(getIndex()), "num_neurons", 0, modelMerged.getModel().getBatchSize()); } //---------------------------------------------------------------------------- void NeuronInitGroupMerged::OutSynWUMPreVars::updateHash(boost::uuids::detail::sha1 &hash) const @@ -509,44 +396,30 @@ NeuronInitGroupMerged::NeuronInitGroupMerged(size_t index, const Type::TypeConte const std::vector> &groups) : NeuronGroupMergedBase(index, typeContext, backend, groups) { - // Create type environment - StandardLibrary::FunctionTypes stdLibraryEnv; - GroupMergedTypeEnvironment typeEnvironment(*this, &stdLibraryEnv); - - if(backend.isPopulationRNGRequired() && getArchetype().isSimRNGRequired() - && backend.isPopulationRNGInitialisedOnDevice()) - { - // **TODO** inject RNG types into environment - addPointerField(*backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rng"); - } - - // Add fields required to initialise PSM variables and get AST - m_VarInitASTs = addInitNeuronVarFields(backend, typeEnvironment, *this, ""); - // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedInSynPSMGroups, typeContext, typeEnvironment, backend, + orderNeuronGroupChildren(m_MergedInSynPSMGroups, typeContext, &NeuronGroupInternal::getFusedPSMInSyn, &SynapseGroupInternal::getPSInitHashDigest ); // Build vector of vectors containing each child group's merged out syns with pre output, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedOutSynPreOutputGroups, typeContext, typeEnvironment, backend, + orderNeuronGroupChildren(m_MergedOutSynPreOutputGroups, typeContext, &NeuronGroupInternal::getFusedPreOutputOutSyn, &SynapseGroupInternal::getPreOutputInitHashDigest ); // Build vector of vectors containing each child group's current sources, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedCurrentSourceGroups, typeContext, typeEnvironment, backend, + orderNeuronGroupChildren(m_MergedCurrentSourceGroups, typeContext, &NeuronGroupInternal::getCurrentSources, &CurrentSourceInternal::getInitHashDigest ); // Build vector of vectors containing each child group's incoming synapse groups // with postsynaptic weight update model variable, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedInSynWUMPostVarGroups, typeContext, typeEnvironment, backend, + orderNeuronGroupChildren(m_MergedInSynWUMPostVarGroups, typeContext, &NeuronGroupInternal::getFusedInSynWithPostVars, &SynapseGroupInternal::getWUPostInitHashDigest); // Build vector of vectors containing each child group's outgoing synapse groups // with presynaptic weight update model variables, ordered to match those of the archetype group - orderNeuronGroupChildren(m_MergedOutSynWUMPreVarGroups, typeContext, typeEnvironment, backend, + 
orderNeuronGroupChildren(m_MergedOutSynWUMPreVarGroups, typeContext, &NeuronGroupInternal::getFusedOutSynWithPreVars, &SynapseGroupInternal::getWUPreInitHashDigest); } @@ -582,36 +455,39 @@ boost::uuids::detail::sha1::digest_type NeuronInitGroupMerged::getHashDigest() c return hash.get_digest(); } //---------------------------------------------------------------------------- -void NeuronInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged) const +void NeuronInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { const auto &model = modelMerged.getModel(); + // Create environment for group + EnvironmentGroupMergedField groupEnv(env, *this); + // Initialise spike counts - genInitSpikeCount(backend, env, false, model.getBatchSize()); - genInitSpikeCount(backend, env, true, model.getBatchSize()); + genInitSpikeCount(backend, groupEnv, false, model.getBatchSize()); + genInitSpikeCount(backend, groupEnv, true, model.getBatchSize()); // Initialise spikes - genInitSpikes(backend, env, false, model.getBatchSize()); - genInitSpikes(backend, env, true, model.getBatchSize()); + genInitSpikes(backend, groupEnv, false, model.getBatchSize()); + genInitSpikes(backend, groupEnv, true, model.getBatchSize()); // Initialize spike times if(getArchetype().isSpikeTimeRequired()) { - genInitSpikeTime(backend, env, "sT", model.getBatchSize()); + genInitSpikeTime(backend, groupEnv, "sT", model.getBatchSize()); } // Initialize previous spike times if(getArchetype().isPrevSpikeTimeRequired()) { - genInitSpikeTime( backend, env, "prevST", model.getBatchSize()); + genInitSpikeTime( backend, groupEnv, "prevST", model.getBatchSize()); } // Initialize spike-like-event times if(getArchetype().isSpikeEventTimeRequired()) { - genInitSpikeTime(backend, env, "seT", model.getBatchSize()); + genInitSpikeTime(backend, groupEnv, "seT", model.getBatchSize()); } // Initialize previous spike-like-event times if(getArchetype().isPrevSpikeEventTimeRequired()) { - genInitSpikeTime(backend, env, "prevSET", model.getBatchSize()); + genInitSpikeTime(backend, groupEnv, "prevSET", model.getBatchSize()); } // If neuron group requires delays, zero spike queue pointer @@ -624,87 +500,100 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, Environment } // Initialise neuron variables - genInitNeuronVarCode(backend, env, *this, m_VarInitASTs, "", - "numNeurons", 1, modelMerged.getModel().getBatchSize()); - + genInitNeuronVarCode( + backend, env, *this, "", "num_neurons", 0, modelMerged.getModel().getBatchSize()); + // Generate initialisation code for child groups - for (const auto &cs : getMergedCurrentSourceGroups()) { + for (auto &cs : m_MergedCurrentSourceGroups) { cs.generate(backend, env, *this, modelMerged); } - for(const auto &sg : getMergedInSynPSMGroups()) { + for(auto &sg : m_MergedInSynPSMGroups) { sg.generate(backend, env, *this, modelMerged); } - for (const auto &sg : getMergedOutSynPreOutputGroups()) { + for (auto &sg : m_MergedOutSynPreOutputGroups) { sg.generate(backend, env, *this, modelMerged); } - for (const auto &sg : getMergedOutSynWUMPreVarGroups()) { + for (auto &sg : m_MergedOutSynWUMPreVarGroups) { sg.generate(backend, env, *this, modelMerged); } - for (const auto &sg : getMergedInSynWUMPostVarGroups()) { + for (auto &sg : m_MergedInSynWUMPostVarGroups) { sg.generate(backend, env, *this, modelMerged); } } 
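The fill helpers used throughout this initialisation code reduce to one pattern: write "value" into "target[idx]" once, or once per batch/delay-slot copy. After this patch they receive an environment rather than a raw CodeStream, so env[target] and env[idx] resolve logical names such as "_value", "id" and "num_neurons" to generated identifiers. The standalone C++ sketch below shows the shape of the code they emit; it is a simplification rather than GeNN's actual helper, and it assumes numValues has already been derived from the batch size and, for delayed variables, the number of delay slots:

    #include <iostream>
    #include <string>

    // Hypothetical stand-in for genVariableFill: broadcasts "value" into
    // "target[idx]" across numValues batch/delay-slot copies. In GeNN the
    // target and index are first looked up in the code-generation
    // environment; here they arrive as final strings.
    void sketchVariableFill(std::ostream &os, const std::string &target,
                            const std::string &value, const std::string &idx,
                            const std::string &stride, unsigned int numValues)
    {
        if(numValues == 1) {
            // Only one copy required - no loop
            os << target << "[" << idx << "] = " << value << ";\n";
        }
        else {
            // One copy per duplicate, each offset by the per-copy stride
            os << "for(unsigned int d = 0; d < " << numValues << "; d++) {\n";
            os << "    " << target << "[(d * " << stride << ") + " << idx
               << "] = " << value << ";\n";
            os << "}\n";
        }
    }

    int main()
    {
        // e.g. a DUPLICATE variable with batch size 2 and 4 delay slots
        sketchVariableFill(std::cout, "group->V", "initVal",
                           "id", "group->numNeurons", 2 * 4);
        return 0;
    }

genScalarFill follows the same logic with the per-copy index reduced to a plain "[d]" (or "[0]" in the single-copy case), which is why shared-neuron variables can be initialised once into a temporary and then broadcast.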
//-------------------------------------------------------------------------- -void NeuronInitGroupMerged::genInitSpikeCount(const BackendBase &backend, EnvironmentExternal &env, - bool spikeEvent, unsigned int batchSize) const +void NeuronInitGroupMerged::genInitSpikeCount(const BackendBase &backend, EnvironmentExternalBase &env, + bool spikeEvent, unsigned int batchSize) { // Is initialisation required at all - const bool initRequired = spikeEvent ? getArchetype().isSpikeEventRequired() : true; - if(initRequired) { + const bool required = spikeEvent ? getArchetype().isSpikeEventRequired() : true; + if(required) { + // Add spike count field + const std::string suffix = spikeEvent ? "Evnt" : ""; + EnvironmentGroupMergedField spikeCountEnv(env, *this); + spikeCountEnv.addField(Type::Uint32.createPointer(), "_spk_cnt", "spkCnt" + suffix, + [&backend, &suffix](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkCnt" + suffix + g.getName(); }); + // Generate variable initialisation code backend.genPopVariableInit(env, - [batchSize, spikeEvent, this] (EnvironmentExternal &spikeCountEnv) + [batchSize, spikeEvent, this] (EnvironmentExternalBase &spikeCountEnv) { - // Get variable name - const char *spikeCntName = spikeEvent ? "spkCntEvnt" : "spkCnt"; - // Is delay required const bool delayRequired = spikeEvent ? getArchetype().isDelayRequired() : (getArchetype().isTrueSpikeRequired() && getArchetype().isDelayRequired()); // Zero across all delay slots and batches - genScalarFill(spikeCountEnv.getStream(), spikeCntName, "0", VarAccessDuplication::DUPLICATE, batchSize, delayRequired, getArchetype().getNumDelaySlots()); + genScalarFill(spikeCountEnv, "_spk_cnt", "0", VarAccessDuplication::DUPLICATE, batchSize, delayRequired, getArchetype().getNumDelaySlots()); }); } } //-------------------------------------------------------------------------- -void NeuronInitGroupMerged::genInitSpikes(const BackendBase &backend, EnvironmentExternal &env, - bool spikeEvent, unsigned int batchSize) const +void NeuronInitGroupMerged::genInitSpikes(const BackendBase &backend, EnvironmentExternalBase &env, + bool spikeEvent, unsigned int batchSize) { // Is initialisation required at all - const bool initRequired = spikeEvent ? getArchetype().isSpikeEventRequired() : true; - if(initRequired) { + const bool required = spikeEvent ? getArchetype().isSpikeEventRequired() : true; + if(required) { + // Add spike count field + const std::string suffix = spikeEvent ? "Evnt" : ""; + EnvironmentGroupMergedField spikeEnv(env, *this); + spikeEnv.addField(Type::Uint32.createPointer(), "_spk", "spk" + suffix, + [&backend, &suffix](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpk" + suffix + g.getName(); }); + + // Generate variable initialisation code - backend.genVariableInit(env, "group->numNeurons", "id", - [batchSize, spikeEvent, this] (EnvironmentExternal &varEnv) + backend.genVariableInit(spikeEnv, "num_neurons", "id", + [batchSize, spikeEvent, this] (EnvironmentExternalBase &varEnv) { - // Get variable name - const char *spikeName = spikeEvent ? "spkEvnt" : "spk"; - + // Is delay required const bool delayRequired = spikeEvent ? 
getArchetype().isDelayRequired() : (getArchetype().isTrueSpikeRequired() && getArchetype().isDelayRequired()); // Zero across all delay slots and batches - genVariableFill(varEnv.getStream(), spikeName, "0", varEnv["id"], "group->numNeurons", + genVariableFill(varEnv, "_spk", "0", "id", "num_neurons", VarAccessDuplication::DUPLICATE, batchSize, delayRequired, getArchetype().getNumDelaySlots()); }); } } //------------------------------------------------------------------------ -void NeuronInitGroupMerged::genInitSpikeTime(const BackendBase &backend, EnvironmentExternal &env, - const std::string &varName, unsigned int batchSize) const +void NeuronInitGroupMerged::genInitSpikeTime(const BackendBase &backend, EnvironmentExternalBase &env, + const std::string &varName, unsigned int batchSize) { + // Add spike time field + EnvironmentGroupMergedField timeEnv(env, *this); + timeEnv.addField(getTimeType().createPointer(), "_time", varName, + [&backend, varName](const auto &g, size_t) { return backend.getDeviceVarPrefix() + varName + g.getName(); }); + + // Generate variable initialisation code - backend.genVariableInit(env, "group->numNeurons", "id", - [batchSize, varName, this] (EnvironmentExternal &varEnv) + backend.genVariableInit(timeEnv, "num_neurons", "id", + [batchSize, varName, this] (EnvironmentExternalBase &varEnv) { - genVariableFill(varEnv.getStream(), varName, "-TIME_MAX", varEnv["id"], "group->numNeurons", VarAccessDuplication::DUPLICATE, + genVariableFill(varEnv, "_time", "-TIME_MAX", "id", "num_neurons", VarAccessDuplication::DUPLICATE, batchSize, getArchetype().isDelayRequired(), getArchetype().getNumDelaySlots()); - }); } From 2cad1871d490ee76d1d1d90846a1dba55b32bee0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 16 Jun 2023 16:14:17 +0100 Subject: [PATCH 232/725] deleting group merged crud --- .../genn/genn/code_generator/groupMerged.h | 159 +----------------- .../genn/code_generator/initGroupMerged.h | 6 + .../code_generator/neuronUpdateGroupMerged.h | 6 + src/genn/genn/code_generator/groupMerged.cc | 61 ------- .../genn/code_generator/initGroupMerged.cc | 89 ++++++---- .../code_generator/neuronUpdateGroupMerged.cc | 12 +- 6 files changed, 83 insertions(+), 250 deletions(-) diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 957326449a..f11bcc2b90 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -354,149 +354,16 @@ class GroupMerged : public ChildGroupMerged } } -//protected: - //------------------------------------------------------------------------ - // Protected methods - //------------------------------------------------------------------------ void addField(const Type::ResolvedType &type, const std::string &name, GetFieldValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) { // Add field to data structure m_Fields.emplace_back(type, name, getFieldValue, fieldType); } - void addScalarField(const std::string &name, GetFieldDoubleValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) - { - addField(getScalarType(), name, - [getFieldValue, this](const G &g, size_t i) - { - return Utils::writePreciseString(getFieldValue(g, i), getScalarType().getNumeric().maxDigits10) + getScalarType().getNumeric().literalSuffix; - }, - fieldType); - } - - void addPointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix, - GroupMergedFieldType
fieldType = GroupMergedFieldType::STANDARD) - { - assert(type.isValue()); - addField(type.createPointer(), name, - [prefix](const G &g, size_t) { return prefix + g.getName(); }, - fieldType); - } - - void addPointerField(const Type::UnresolvedType &type, const std::string &name, const std::string &prefix, - GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) - { - addPointerField(type.resolve(getTypeContext()), name, prefix, fieldType); - } - - - void addVars(const Models::Base::VarVec &vars, const std::string &arrayPrefix) - { - // Loop through variables - for(const auto &v : vars) { - addPointerField(v.type, v.name, arrayPrefix + v.name); - } - } - - template - void addVarReferences(const Models::Base::VarRefVec &varReferences, const std::string &arrayPrefix, V getVarRefFn) - { - // Loop through variables - for(const auto &v : varReferences) { - addField(v.type.resolve(getTypeContext()).createPointer(), v.name, - [getVarRefFn, arrayPrefix, v](const G &g, size_t) - { - const auto varRef = getVarRefFn(g).at(v.name); - return arrayPrefix + varRef.getVar().name + varRef.getTargetName(); - }); - } - } - - void addEGPs(const Snippet::Base::EGPVec &egps, const std::string &arrayPrefix, const std::string &varName = "") - { - for(const auto &e : egps) { - addField(e.type.resolve(getTypeContext()).createPointer(), e.name + varName, - [e, arrayPrefix, varName](const G &g, size_t) { return arrayPrefix + e.name + varName + g.getName(); }, - GroupMergedFieldType::DYNAMIC); - } - } - - template - void addHeterogeneousParams(const Snippet::Base::StringVec ¶mNames, const std::string &suffix, - P getParamValues, H isHeterogeneous) - { - // Loop through params - for(const auto &p : paramNames) { - // If parameters is heterogeneous - // **TODO** std::invoke - if((static_cast(this)->*isHeterogeneous)(p)) { - // Add field - addScalarField(p + suffix, - [p, getParamValues](const G &g, size_t) - { - return getParamValues(g).at(p); - }); - } - } - } - - template - void addHeterogeneousDerivedParams(const Snippet::Base::DerivedParamVec &derivedParams, const std::string &suffix, - D getDerivedParamValues, H isHeterogeneous) - { - // Loop through derived params - for(const auto &d : derivedParams) { - // If parameters isn't homogeneous - // **TODO** std::invoke - if((static_cast(this)->*isHeterogeneous)(d.name)) { - // Add field - addScalarField(d.name + suffix, - [d, getDerivedParamValues](const G &g, size_t) - { - return getDerivedParamValues(g).at(d.name); - }); - } - } - } - - template - void addHeterogeneousVarInitParams(H isHeterogeneous, const std::string &suffix = "") - { - // Loop through weight update model variables - const A archetypeAdaptor(getArchetype()); - for(const auto &v : archetypeAdaptor.getDefs()) { - // Loop through parameters - for(const auto &p : archetypeAdaptor.getInitialisers().at(v.name).getParams()) { - if((static_cast(this)->*isHeterogeneous)(v.name, p.first)) { - addScalarField(p.first + v.name + suffix, - [p, v](const G &g, size_t) - { - return A(g).getInitialisers().at(v.name).getParams().at(p.first); - }); - } - } - } - } - - template - void addHeterogeneousVarInitDerivedParams(H isHeterogeneous, const std::string &suffix = "") - { - // Loop through weight update model variables - const A archetypeAdaptor(getArchetype()); - for(const auto &v : archetypeAdaptor.getDefs()) { - // Loop through parameters - for(const auto &p : archetypeAdaptor.getInitialisers().at(v.name).getDerivedParams()) { - if((static_cast(this)->*isHeterogeneous)(v.name, p.first)) { - 
addScalarField(p.first + v.name + suffix, - [p, v](const G &g, size_t) - { - return A(g).getInitialisers().at(v.name).getDerivedParams().at(p.first); - }); - } - } - } - } - +protected: + //------------------------------------------------------------------------ + // Protected methods + //------------------------------------------------------------------------ void generateRunnerBase(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, @@ -645,25 +512,9 @@ class GENN_EXPORT NeuronPrevSpikeTimeUpdateGroupMerged : public GroupMerged { public: - //------------------------------------------------------------------------ - // Public API - //------------------------------------------------------------------------ - //! Should the parameter be implemented heterogeneously? - bool isParamHeterogeneous(const std::string ¶mName) const; - - //! Should the derived parameter be implemented heterogeneously? - bool isDerivedParamHeterogeneous(const std::string ¶mName) const; - - //! Should the var init parameter be implemented heterogeneously? - bool isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const; - - //! Should the var init derived parameter be implemented heterogeneously? - bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const; + using GroupMerged::GroupMerged; protected: - NeuronGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); - //------------------------------------------------------------------------ // Protected API //------------------------------------------------------------------------ diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index 90506bf392..e958e1015f 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -186,6 +186,12 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase void genInitSpikeTime(const BackendBase &backend, EnvironmentExternalBase &env, const std::string &varName, unsigned int batchSize); + //! Should the var init parameter be implemented heterogeneously? + bool isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const; + + //! Should the var init derived parameter be implemented heterogeneously? + bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const; + //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index 2a8a2d792a..b8c176846b 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -201,6 +201,12 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase static const std::string name; private: + //! Should the parameter be implemented heterogeneously? + bool isParamHeterogeneous(const std::string ¶mName) const; + + //! Should the derived parameter be implemented heterogeneously? 
+ bool isDerivedParamHeterogeneous(const std::string ¶mName) const; + //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index 986306d3b4..e748ba2f8e 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -103,67 +103,6 @@ NeuronPrevSpikeTimeUpdateGroupMerged::NeuronPrevSpikeTimeUpdateGroupMerged(size_ //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronGroupMergedBase //---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isParamHeterogeneous(const std::string ¶mName) const -{ - return isParamValueHeterogeneous(paramName, [](const NeuronGroupInternal &ng) { return ng.getParams(); }); -} -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isDerivedParamHeterogeneous(const std::string ¶mName) const -{ - return isParamValueHeterogeneous(paramName, [](const NeuronGroupInternal &ng) { return ng.getDerivedParams(); }); -} -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const -{ - return (isVarInitParamReferenced(varName, paramName) && - isParamValueHeterogeneous(paramName, - [varName](const NeuronGroupInternal &sg) { return sg.getVarInitialisers().at(varName).getParams(); })); -} -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const -{ - return (isVarInitParamReferenced(varName, paramName) && - isParamValueHeterogeneous(paramName, - [varName](const NeuronGroupInternal &sg){ return sg.getVarInitialisers().at(varName).getDerivedParams(); })); -} -//---------------------------------------------------------------------------- -NeuronGroupMergedBase::NeuronGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) -: GroupMerged(index, typeContext, groups) -{ - using namespace Type; - - addField(Uint32, "numNeurons", - [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); - - addPointerField(Uint32, "spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); - addPointerField(Uint32, "spk", backend.getDeviceVarPrefix() + "glbSpk"); - - if(getArchetype().isSpikeEventRequired()) { - addPointerField(Uint32, "spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); - addPointerField(Uint32, "spkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt"); - } - - if(getArchetype().isDelayRequired()) { - addPointerField(Uint32, "spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); - } - - // **TODO** add to type environment for update - if(getArchetype().isSpikeTimeRequired()) { - addPointerField(getTimeType(), "sT", backend.getDeviceVarPrefix() + "sT"); - } - if(getArchetype().isSpikeEventTimeRequired()) { - addPointerField(getTimeType(), "seT", backend.getDeviceVarPrefix() + "seT"); - } - - if(getArchetype().isPrevSpikeTimeRequired()) { - addPointerField(getTimeType(), "prevST", backend.getDeviceVarPrefix() + "prevST"); - } - if(getArchetype().isPrevSpikeEventTimeRequired()) { - 
addPointerField(getTimeType(), "prevSET", backend.getDeviceVarPrefix() + "prevSET"); - } -} -//---------------------------------------------------------------------------- bool NeuronGroupMergedBase::isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const { const auto *varInitSnippet = getArchetype().getVarInitialisers().at(varName).getSnippet(); diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 183dc21ac0..c5ecdab0c9 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -70,13 +70,12 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e // If there is any initialisation code const auto resolvedType = var.type.resolve(group.getTypeContext()); const auto &varInit = adaptor.getInitialisers().at(var.name); - const auto *snippet = adaptor.getInitialisers().at(var.name).getSnippet(); + const auto *snippet = varInit.getSnippet(); if (!snippet->getCode().empty()) { CodeStream::Scope b(env.getStream()); - EnvironmentGroupMergedField varEnv(env, group, fieldGroup); - // Substitute in parameters and derived parameters for initialising variables + EnvironmentGroupMergedField varEnv(env, group, fieldGroup); varEnv.addVarInitParams(&G::isVarInitParamHeterogeneous, fieldSuffix); varEnv.addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, fieldSuffix); varEnv.addExtraGlobalParams(snippet->getExtraGlobalParameters(), backend.getDeviceVarPrefix(), var.name, fieldSuffix); @@ -113,7 +112,7 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e backend.genVariableInit( varEnvs, count, "id", [&adaptor, &fieldSuffix, &group, &var, &resolvedType, batchSize, count, numDelaySlots] - (EnvironmentExternal &varInitEnv) + (EnvironmentExternalBase &varInitEnv) { // Generate initial value into temporary variable varInitEnv.getStream() << resolvedType.getName() << " initVal;" << std::endl; @@ -142,43 +141,50 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e } //------------------------------------------------------------------------ // Initialise one row of weight update model variables -template -void genInitWUVarCode(CodeStream &os, const ModelSpecMerged &modelMerged, const Substitutions &popSubs, +template +void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, G &group, const Models::Base::VarVec &vars, const std::unordered_map &varInitialisers, - const std::string &stride, const size_t groupIndex, unsigned int batchSize, - P isParamHeterogeneousFn, D isDerivedParamHeterogeneousFn, G genSynapseVariableRowInitFn) + const std::string &stride, unsigned int batchSize, + EnvironmentGroupMergedField::IsVarInitHeterogeneousFn isParamHeterogeneousFn, + EnvironmentGroupMergedField::IsVarInitHeterogeneousFn isDerivedParamHeterogeneousFn, + V genSynapseVariableRowInitFn) { for (const auto &var : vars) { + // If this variable has any initialisation code and doesn't require a kernel + const auto resolvedType = var.type.resolve(group.getTypeContext()); const auto &varInit = varInitialisers.at(var.name); + const auto *snippet = adaptor.varInit.getSnippet(); + if(!snippet->getCode().empty() && !varInit.getSnippet()->requiresKernel()) { + CodeStream::Scope b(env.getStream()); - // If this variable has any initialisation code and doesn't require a kernel - if(!varInit.getSnippet()->getCode().empty() && !varInit.getSnippet()->requiresKernel()) { - 
CodeStream::Scope b(os); + // Substitute in parameters and derived parameters for initialising variables + EnvironmentGroupMergedField varEnv(env, group); + varEnv.addVarInitParams(isParamHeterogeneousFn); + varEnv.addVarInitDerivedParams(isDerivedParamHeterogeneousFn); + varEnv.addExtraGlobalParams(snippet->getExtraGlobalParameters(), backend.getDeviceVarPrefix(), var.name); - // Generate target-specific code to initialise variable - genSynapseVariableRowInitFn(os, popSubs, - [&var, &varInit, &stride, &modelMerged, batchSize, groupIndex, isParamHeterogeneousFn, isDerivedParamHeterogeneousFn] - (CodeStream &os, Substitutions &varSubs) - { - varSubs.addParamValueSubstitution(varInit.getSnippet()->getParamNames(), varInit.getParams(), - [&var, isParamHeterogeneousFn](const std::string &p) { return isParamHeterogeneousFn(var.name, p); }, - "", "group->", var.name); - varSubs.addVarValueSubstitution(varInit.getSnippet()->getDerivedParams(), varInit.getDerivedParams(), - [&var, isDerivedParamHeterogeneousFn](const std::string &p) { return isDerivedParamHeterogeneousFn(var.name, p); }, - "", "group->", var.name); - varSubs.addVarNameSubstitution(varInit.getSnippet()->getExtraGlobalParams(), - "", "group->", var.name); + // Add field for variable itself + varEnv.addField(resolvedType.createPointer(), "_value", var.name, + [&backend, var](const auto &g, size_t) + { + return backend.getDeviceVarPrefix() + var.name + g.getName(); + }); + // Generate target-specific code to initialise variable + genSynapseVariableRowInitFn(varEnv, + [&group, &resolvedType, &stride, &var, batchSize, snippet] + (EnvironmentExternalBase &varInitEnv) + { // Generate initial value into temporary variable - os << var.type.resolve(modelMerged.getTypeContext()).getName() << " initVal;" << std::endl; - varSubs.addVarSubstitution("value", "initVal"); - std::string code = varInit.getSnippet()->getCode(); - varSubs.applyCheckUnreplaced(code, "initVar : merged" + var.name + std::to_string(groupIndex)); - //code = ensureFtype(code, scalarType); - os << code << std::endl; + varInitEnv.getStream() << resolvedType.getName() << " initVal;" << std::endl; + varInitEnv.add(resolvedType, "value", "initVal"); + + // Pretty print variable initialisation code + Transpiler::ErrorHandler errorHandler("Variable '" + var.name + "' init code" + std::to_string(group.getIndex())); + prettyPrintStatements(snippet->getCode(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all batches - genVariableFill(os, var.name, "initVal", "id_syn", stride, + genVariableFill(varInitEnv, "_value", "initVal", "id_syn", stride, getVarAccessDuplication(var.access), batchSize); }); } @@ -490,12 +496,15 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, Environment genInitSpikeTime(backend, groupEnv, "prevSET", model.getBatchSize()); } - // If neuron group requires delays, zero spike queue pointer + // If neuron group requires delays if(getArchetype().isDelayRequired()) { + // Add spike queue pointer field and zero + groupEnv.addField(Type::Uint32.createPointer(), "_spk_que_ptr", "spkQuePtr", + [&backend](const auto &g, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + g.getName(); }); backend.genPopVariableInit(env, - [](CodeStream &os, Substitutions &) + [](EnvironmentExternalBase &varEnv) { - os << "*group->spkQuePtr = 0;" << std::endl; + varEnv.getStream() << "*" << varEnv["_spk_que_ptr"] << " = 0;" << std::endl; }); } @@ -596,6 +605,18
@@ void NeuronInitGroupMerged::genInitSpikeTime(const BackendBase &backend, Environ batchSize, getArchetype().isDelayRequired(), getArchetype().getNumDelaySlots()); }); } +//---------------------------------------------------------------------------- +bool NeuronInitGroupMerged::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const +{ + return (isParamValueHeterogeneous(paramName, + [varName](const NeuronGroupInternal &sg) { return sg.getVarInitialisers().at(varName).getParams(); })); +} +//---------------------------------------------------------------------------- +bool NeuronInitGroupMerged::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const +{ + return (isParamValueHeterogeneous(paramName, + [varName](const NeuronGroupInternal &sg){ return sg.getVarInitialisers().at(varName).getDerivedParams(); })); +} //---------------------------------------------------------------------------- // GeNN::CodeGenerator::SynapseInitGroupMerged diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 1772196976..38a511b445 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -382,7 +382,7 @@ const std::string NeuronUpdateGroupMerged::name = "NeuronUpdate"; //---------------------------------------------------------------------------- NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: NeuronGroupMergedBase(index, typeContext, backend, groups) +: NeuronGroupMergedBase(index, typeContext, groups) { // Loop through neuron groups /*std::vector> eventThresholdSGs; @@ -834,3 +834,13 @@ std::string NeuronUpdateGroupMerged::getWriteVarIndex(bool delay, unsigned int b return getVarIndex(batchSize, varDuplication, index); } } +//---------------------------------------------------------------------------- +bool NeuronUpdateGroupMerged::isParamHeterogeneous(const std::string ¶mName) const +{ + return isParamValueHeterogeneous(paramName, [](const NeuronGroupInternal &ng) { return ng.getParams(); }); +} +//---------------------------------------------------------------------------- +bool NeuronUpdateGroupMerged::isDerivedParamHeterogeneous(const std::string ¶mName) const +{ + return isParamValueHeterogeneous(paramName, [](const NeuronGroupInternal &ng) { return ng.getDerivedParams(); }); +} From 082ffbcbbf944524cf31f53d8021a0790c100f8e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 16 Jun 2023 17:55:21 +0100 Subject: [PATCH 233/725] some gnarliness remains but local cached environment now re-extended to handle variable references --- .../genn/genn/code_generator/backendBase.h | 2 +- .../code_generator/customUpdateGroupMerged.h | 25 +- .../genn/genn/code_generator/environment.h | 153 +++++++--- .../genn/code_generator/modelSpecMerged.h | 54 ++-- .../backends/single_threaded_cpu/backend.cc | 280 +++++++++--------- src/genn/genn/code_generator/backendBase.cc | 42 ++- .../code_generator/customUpdateGroupMerged.cc | 146 +++------ .../code_generator/neuronUpdateGroupMerged.cc | 34 +-- 8 files changed, 401 insertions(+), 335 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index c79c1e2dee..b339769a5d 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ 
b/include/genn/genn/code_generator/backendBase.h @@ -729,7 +729,7 @@ class GENN_EXPORT BackendBase } } } - void genCustomUpdateIndexCalculation(CodeStream &os, const CustomUpdateGroupMerged &cu) const; + void genCustomUpdateIndexCalculation(EnvironmentGroupMergedField &env) const; void genCustomConnectivityUpdateIndexCalculation(CodeStream &os, const CustomConnectivityUpdateGroupMerged &cu) const; diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index eca80747e2..bc841df5c6 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -13,15 +13,9 @@ namespace GeNN::CodeGenerator class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged { public: - CustomUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); - //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- - bool isParamHeterogeneous(const std::string ¶mName) const; - bool isDerivedParamHeterogeneous(const std::string ¶mName) const; - boost::uuids::detail::sha1::digest_type getHashDigest() const; void generateRunner(const BackendBase &backend, @@ -33,13 +27,11 @@ class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged //GroupMergedTypeEnvironment typeEnvironment(*this, Type::parseNumeric(precision)); // Loop through variables and add pointers if they are reduction targets - const CustomUpdateModels::Base *cm = this->getArchetype().getCustomUpdateModel(); + /*const CustomUpdateModels::Base *cm = this->getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { if(v.access & VarAccessModeAttribute::REDUCE) { this->addPointerField(v.type, v.name, backend.getDeviceVarPrefix() + v.name); @@ -205,7 +194,7 @@ class CustomUpdateHostReductionGroupMergedBase : public GroupMerged if(v.access & VarAccessModeAttribute::REDUCE) { this->addPointerField(v.type, v.name, backend.getDeviceVarPrefix() + v.name); } - } + }*/ } }; diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index c149b86bcb..ef6f710b43 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -565,22 +565,100 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase>> m_Environment; }; +template +class VarCachePolicy +{ +public: + using GroupInternal = typename G::GroupInternal; + using GetIndexFn = std::function; + + VarCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex) + : m_GetReadIndex(getReadIndex), m_GetWriteIndex(getWriteIndex) + {} + + VarCachePolicy(GetIndexFn getIndex) + : m_GetReadIndex(getIndex), m_GetWriteIndex(getIndex) + {} + + std::string getVarSuffix(const GroupInternal &g, const Models::Base::Var &var) + { + return A(g).getNameSuffix(); + } + + std::string getReadIndex(G &g, const Models::Base::Var &var) + { + return m_GetReadIndex(var.name, getVarAccessDuplication(var.access)); + } + + std::string getWriteIndex(G &g, const Models::Base::Var &var) + { + return m_GetWriteIndex(var.name, getVarAccessDuplication(var.access)); + } + +private: + GetIndexFn m_GetReadIndex; + GetIndexFn m_GetWriteIndex; +}; + +template +class VarRefCachePolicy +{ +protected: + using GroupInternal = typename G::GroupInternal; + using Initialiser = typename 
std::remove_reference_t>::mapped_type; + using GetIndexFn = std::function; + + + VarRefCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex) + : m_GetReadIndex(getReadIndex), m_GetWriteIndex(getWriteIndex) + {} + + VarRefCachePolicy(GetIndexFn getIndex) + : m_GetReadIndex(getIndex), m_GetWriteIndex(getIndex) + {} + + std::string getVarSuffix(const GroupInternal &g, const Models::Base::VarRef &var) + { + return A(g).getInitialisers().at(var.name).getTargetName(); + } + + std::string getReadIndex(G &g, const Models::Base::VarRef &var) + { + return m_GetReadIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name)); + } + + std::string getWriteIndex(G &g, const Models::Base::VarRef &var) + { + return m_GetWriteIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name)); + } + +private: + GetIndexFn m_GetReadIndex; + GetIndexFn m_GetWriteIndex; +}; + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentLocalVarCache //---------------------------------------------------------------------------- //! Pretty printing environment which caches used variables in local variables -template -class EnvironmentLocalVarCache : public EnvironmentExternalBase +template +class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P { - //! Function used to provide index strings based on var name and - using GetIndexFn = std::function; + //! Type of a single definition + using Def = typename std::invoke_result_t::value_type; public: - EnvironmentLocalVarCache(G &group, F &fieldGroup, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, - const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix, - GetIndexFn getReadIndex, GetIndexFn getWriteIndex) - : EnvironmentExternalBase(enclosing), m_Group(group), m_FieldGroup(fieldGroup), m_Context(context), m_Contents(m_ContentsStream), - m_ArrayPrefix(arrayPrefix), m_FieldSuffix(fieldSuffix), m_LocalPrefix(localPrefix), m_GetReadIndex(getReadIndex), m_GetWriteIndex(getWriteIndex) + /*template + EnvironmentExternalDynamicBase(EnvironmentExternalBase &enclosing, PolicyArgs&&... policyArgs) + : EnvironmentExternalBase(enclosing), P(std::forward(policyArgs)...) + {}*/ + + template + EnvironmentLocalCacheBase(G &group, F &fieldGroup, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, + const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix, + PolicyArgs&&... 
policyArgs) + : EnvironmentExternalBase(enclosing), P(std::forward(policyArgs)...), m_Group(group), m_FieldGroup(fieldGroup), + m_Context(context), m_Contents(m_ContentsStream), m_ArrayPrefix(arrayPrefix), m_FieldSuffix(fieldSuffix), m_LocalPrefix(localPrefix) { // Copy variables into variables referenced, alongside boolean const auto defs = A(m_Group.get().getArchetype()).getDefs(); @@ -588,45 +666,35 @@ class EnvironmentLocalVarCache : public EnvironmentExternalBase [](const auto &v){ return std::make_pair(v.name, std::make_pair(false, v)); }); } - EnvironmentLocalVarCache(G &group, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, - const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix, - GetIndexFn getReadIndex, GetIndexFn getWriteIndex) - : EnvironmentLocalVarCache(group, group, context, enclosing, arrayPrefix, fieldSuffix, localPrefix, getReadIndex, getWriteIndex) - {} - - EnvironmentLocalVarCache(G &group, F &fieldGroup, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, - const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix, GetIndexFn getIndex) - : EnvironmentLocalVarCache(group, fieldGroup, context, enclosing, arrayPrefix, fieldSuffix, localPrefix, getIndex, getIndex) - { - } - - EnvironmentLocalVarCache(G &group, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, - const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix, GetIndexFn getIndex) - : EnvironmentLocalVarCache(group, group, context, enclosing, arrayPrefix, fieldSuffix, localPrefix, getIndex, getIndex) - { - } + /*template + EnvironmentLocalCacheBase(G &group, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, + const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix, + PolicyArgs&&... policyArgs) + : EnvironmentLocalVarCache(group, group, context, enclosing, arrayPrefix, fieldSuffix, localPrefix, std::forward(policyArgs)...) 
+ {}*/ + - EnvironmentLocalVarCache(const EnvironmentLocalVarCache&) = delete; + EnvironmentLocalCacheBase(const EnvironmentLocalCacheBase&) = delete; - ~EnvironmentLocalVarCache() + ~EnvironmentLocalCacheBase() { A archetypeAdapter(m_Group.get().getArchetype()); // Copy definitions of variables which have been referenced into new vector const auto varDefs = archetypeAdapter.getDefs(); - Models::Base::VarVec referencedVars; - std::copy_if(varDefs.cbegin(), varDefs.cend(), std::back_inserter(referencedVars), + std::vector referencedDefs; + std::copy_if(varDefs.cbegin(), varDefs.cend(), std::back_inserter(referencedDefs), [this](const auto &v){ return m_VariablesReferenced.at(v.name).first; }); - // Loop through referenced variables - for(const auto &v : referencedVars) { + // Loop through referenced definitions + for(const auto &v : referencedDefs) { const auto resolvedType = v.type.resolve(m_Context.get()); // Add field to underlying field group m_FieldGroup.get().addField(resolvedType.createPointer(), v.name + m_FieldSuffix, [this, v](const typename F::GroupInternal &, size_t i) { - return m_ArrayPrefix + v.name + A(m_Group.get().getGroups().at(i)).getNameSuffix(); + return m_ArrayPrefix + v.name + getVarSuffix(m_Group.get().getGroups().at(i), v); }); if(v.access & VarAccessMode::READ_ONLY) { @@ -638,7 +706,7 @@ class EnvironmentLocalVarCache : public EnvironmentExternalBase // **NOTE** by not initialising these variables for reductions, // compilers SHOULD emit a warning if user code doesn't set it to something if(!(v.access & VarAccessModeAttribute::REDUCE)) { - getContextStream() << " = group->" << v.name << m_FieldSuffix << "[" << m_GetReadIndex(v.name, v.access) << "]"; + getContextStream() << " = group->" << v.name << m_FieldSuffix << "[" << getReadIndex(m_Group.get(), v) << "]"; } getContextStream() << ";" << std::endl; } @@ -646,11 +714,11 @@ class EnvironmentLocalVarCache : public EnvironmentExternalBase // Write contents to context stream getContextStream() << m_ContentsStream.str(); - // Loop through referenced variables again - for(const auto &v : referencedVars) { + // Loop through referenced definitions again + for(const auto &v : referencedDefs) { // If variables are read-write if(v.access & VarAccessMode::READ_WRITE) { - getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << m_GetWriteIndex(v.name, v.access) << "]"; + getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << getWriteIndex(m_Group.get(), v) << "]"; getContextStream() << " = " << m_LocalPrefix << v.name << ";" << std::endl; } } @@ -715,8 +783,13 @@ class EnvironmentLocalVarCache : public EnvironmentExternalBase std::string m_ArrayPrefix; std::string m_FieldSuffix; std::string m_LocalPrefix; - GetIndexFn m_GetReadIndex; - GetIndexFn m_GetWriteIndex; - std::unordered_map> m_VariablesReferenced; + std::unordered_map> m_VariablesReferenced; }; + +template +using EnvironmentLocalVarCache = EnvironmentLocalCacheBase, A, G, F>; + +template +using EnvironmentLocalVarRefCache = EnvironmentLocalCacheBase, A, G, F>; + } // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 380dc61143..d636e71f35 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -197,58 +197,76 @@ class GENN_EXPORT ModelSpecMerged } template - void genMergedCustomUpdateGroups(const BackendBase &backend, G generateGroup) + void 
genMergedCustomUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, G generateGroup) { createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateGroups, - [](const CustomUpdateInternal &) { return true; }, + [&updateGroupName](const CustomUpdateInternal &cg) { return cg.getUpdateGroupName() == updateGroupName; }, &CustomUpdateInternal::getHashDigest, generateGroup); } template - void genMergedCustomUpdateWUGroups(const BackendBase &backend, G generateGroup) + void genMergedCustomUpdateWUGroups(const BackendBase &backend, const std::string &updateGroupName, G generateGroup) { createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomUpdateWUGroups, - [](const CustomUpdateWUInternal &cg) { return !cg.isTransposeOperation(); }, + [&updateGroupName](const CustomUpdateWUInternal &cg) + { + return (!cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName); + }, &CustomUpdateWUInternal::getHashDigest, generateGroup); } template - void genMergedCustomUpdateTransposeWUGroups(const BackendBase &backend, G generateGroup) + void genMergedCustomUpdateTransposeWUGroups(const BackendBase &backend, const std::string &updateGroupName, G generateGroup) { createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomUpdateTransposeWUGroups, - [](const CustomUpdateWUInternal &cg) { return cg.isTransposeOperation(); }, + [&updateGroupName](const CustomUpdateWUInternal &cg) + { + return (cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName); + }, &CustomUpdateWUInternal::getHashDigest, generateGroup); } template - void genMergedCustomUpdateHostReductionGroups(const BackendBase &backend, G generateGroup) + void genMergedCustomUpdateHostReductionGroups(const BackendBase &backend, const std::string &updateGroupName, G generateGroup) { createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateHostReductionGroups, - [](const CustomUpdateInternal &cg) { return cg.isBatchReduction(); }, + [&updateGroupName](const CustomUpdateInternal &cg) + { + return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName); + }, &CustomUpdateInternal::getHashDigest, generateGroup, true); } template - void genMergedCustomWUUpdateHostReductionGroups(const BackendBase &backend, G generateGroup) + void genMergedCustomWUUpdateHostReductionGroups(const BackendBase &backend, const std::string &updateGroupName, G generateGroup) { createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateHostReductionGroups, - [](const CustomUpdateWUInternal &cg) { return cg.isBatchReduction(); }, + [&updateGroupName](const CustomUpdateWUInternal &cg) + { + return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName); + }, &CustomUpdateWUInternal::getHashDigest, generateGroup, true); } template - void genMergedCustomConnectivityUpdateGroups(const BackendBase &backend, G generateGroup) + void genMergedCustomConnectivityUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, G generateGroup) { createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateGroups, - [](const CustomConnectivityUpdateInternal &cg) { return !cg.getCustomConnectivityUpdateModel()->getRowUpdateCode().empty(); }, + [&updateGroupName](const CustomConnectivityUpdateInternal &cg) + { + return (!cg.getCustomConnectivityUpdateModel()->getRowUpdateCode().empty() && cg.getUpdateGroupName() == updateGroupName); + },
&CustomConnectivityUpdateInternal::getHashDigest, generateGroup); } template - void genMergedCustomConnectivityHostUpdateGroups(BackendBase &backend, G generateGroup) + void genMergedCustomConnectivityHostUpdateGroups(BackendBase &backend, const std::string &updateGroupName, G generateGroup) { createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityHostUpdateGroups, - [](const CustomConnectivityUpdateInternal &cg) { return !cg.getCustomConnectivityUpdateModel()->getHostUpdateCode().empty(); }, + [&updateGroupName](const CustomConnectivityUpdateInternal &cg) + { + return (!cg.getCustomConnectivityUpdateModel()->getHostUpdateCode().empty() && cg.getUpdateGroupName() == updateGroupName); + }, &CustomConnectivityUpdateInternal::getHashDigest, generateGroup, true); } @@ -540,7 +558,7 @@ class GENN_EXPORT ModelSpecMerged } // Reserve final merged groups vector - mergedGroups.reserve(protoMergedGroups.size()); + mergedGroups.reserve(mergedGroups.size() + protoMergedGroups.size()); // Loop through resultant merged groups size_t i = 0; @@ -571,10 +589,10 @@ class GENN_EXPORT ModelSpecMerged } } - template + template void createMergedGroups(const BackendBase &backend, const std::map &groups, std::vector &mergedGroups, - F filter, U updateHash, G generateGroup, bool host = false) + F filter, D getHashDigest, G generateGroup, bool host = false) { // Build temporary vector of references to groups that pass filter std::vector> unmergedGroups; @@ -585,7 +603,7 @@ class GENN_EXPORT ModelSpecMerged } // Merge filtered vector - createMergedGroups(backend, unmergedGroups, mergedGroups, updateHash, generateGroup, host); + createMergedGroups(backend, unmergedGroups, mergedGroups, getHashDigest, generateGroup, host); } //-------------------------------------------------------------------------- diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index f2aeac37c3..3939cb0d61 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -543,26 +543,11 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos os << synapseUpdateStream.str(); } //-------------------------------------------------------------------------- -void Backend::genCustomUpdate(CodeStream &os_, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); - // Generate struct definitions - modelMerged.genMergedCustomUpdateStructs(os_, *this); - modelMerged.genMergedCustomUpdateWUStructs(os_, *this); - modelMerged.genMergedCustomUpdateTransposeWUStructs(os_, *this); - modelMerged.genMergedCustomConnectivityUpdateStructs(os_, *this); - - // Generate arrays of merged structs and functions to set them - genMergedStructArrayPush(os_, modelMerged.getMergedCustomUpdateGroups()); - genMergedStructArrayPush(os_, modelMerged.getMergedCustomUpdateWUGroups()); - genMergedStructArrayPush(os_, modelMerged.getMergedCustomUpdateTransposeWUGroups()); - genMergedStructArrayPush(os_, modelMerged.getMergedCustomConnectivityUpdateGroups()); - - // Generate preamble - preambleHandler(os_); - - // Build set containing union of all custom update groupsnames + // Build set containing names of all custom update groups std::set customUpdateGroups; std::transform(model.getCustomUpdates().cbegin(),
model.getCustomUpdates().cend(), std::inserter(customUpdateGroups, customUpdateGroups.end()), @@ -574,16 +559,23 @@ void Backend::genCustomUpdate(CodeStream &os_, ModelSpecMerged &modelMerged, Hos std::inserter(customUpdateGroups, customUpdateGroups.end()), [](const ModelSpec::CustomConnectivityUpdateValueType &v) { return v.second.getUpdateGroupName(); }); + // Generate stream with custom update code + std::ostringstream customUpdateStream; + CodeStream customUpdate(customUpdateStream); + + // Begin environment with standard library + EnvironmentLibrary customUpdateEnv(customUpdate, StandardLibrary::getFunctions()); + // Loop through custom update groups for(const auto &g : customUpdateGroups) { - os_ << "void update" << g << "()"; + customUpdateEnv.getStream() << "void update" << g << "()"; { - CodeStream::Scope b(os_); + CodeStream::Scope b(customUpdateEnv.getStream()); + + EnvironmentExternal funcEnv(customUpdateEnv); + funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); + funcEnv.add(Type::Uint32.addConst(), "batch", "0"); - StandardLibrary::Environment stdEnv(os_); - EnvironmentSubstitute funcEnv(stdEnv); - funcEnv.addSubstitution("t", "t"); - funcEnv.addSubstitution("batch", "0"); // Loop through host update groups and generate code for those in this custom update group for (const auto &cg : modelMerged.getMergedCustomConnectivityHostUpdateGroups()) { @@ -595,145 +587,149 @@ void Backend::genCustomUpdate(CodeStream &os_, ModelSpecMerged &modelMerged, Hos { Timer t(funcEnv.getStream(), "customUpdate" + g, model.isTimingEnabled()); - - // Loop through merged custom update groups - for(const auto &c : modelMerged.getMergedCustomUpdateGroups()) { - // If this update group isn't for current group, skip - if(c.getArchetype().getUpdateGroupName() != g) { - continue; - } - - CodeStream::Scope b(funcEnv.getStream()); - funcEnv.getStream() << "// merged custom update group " << c.getIndex() << std::endl; - funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + modelMerged.genMergedCustomUpdateGroups( + *this, g, + [this, &funcEnv](auto &c) { CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged custom update group " << c.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + { + CodeStream::Scope b(funcEnv.getStream()); - // Get reference to group - funcEnv.getStream() << "const auto *group = &mergedCustomUpdateGroup" << c.getIndex() << "[g]; " << std::endl; - - genCustomUpdateIndexCalculation(funcEnv.getStream(), c); + // Get reference to group + funcEnv.getStream() << "const auto *group = &mergedCustomUpdateGroup" << c.getIndex() << "[g]; " << std::endl; + + // Create matching environment + EnvironmentGroupMergedField groupEnv(funcEnv, c); + + genCustomUpdateIndexCalculation(groupEnv, c); - if (c.getArchetype().isNeuronReduction()) { - // Initialise reduction targets - const auto reductionTargets = genInitReductionTargets(funcEnv.getStream(), c); + if (c.getArchetype().isNeuronReduction()) { + // Initialise reduction targets + // **TODO** these should be provided with some sort of caching mechanism + const auto reductionTargets = genInitReductionTargets(groupEnv.getStream(), c); - // Loop through group members - funcEnv.getStream() << "for(unsigned int i = 0; i < group->size; i++)"; - { - CodeStream::Scope b(funcEnv.getStream()); + // Loop through group members + groupEnv.getStream() << "for(unsigned int i = 0; i < group->size; i++)"; + { + 
CodeStream::Scope b(groupEnv.getStream()); - // Generate custom update - EnvironmentSubstitute env(funcEnv); - env.addSubstitution("id", "i"); - c.generateCustomUpdate(*this, env); + // Generate custom update + EnvironmentGroupMergedField memberEnv(groupEnv, c); + memberEnv.add(Type::Uint32.addConst(), "id", "i"); + c.generateCustomUpdate(*this, memberEnv); + + // Loop through reduction targets and generate reduction + // **TODO** reduction should be automatically implemented by transpiler + for (const auto &r : reductionTargets) { + memberEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; + } + } - // Loop through reduction targets and generate reduction + // Write back reductions for (const auto &r : reductionTargets) { - env.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; + groupEnv.getStream() << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; } } + else { + // Loop through group members + groupEnv.getStream() << "for(unsigned int i = 0; i < group->size; i++)"; + { + CodeStream::Scope b(groupEnv.getStream()); - // Write back reductions - for (const auto &r : reductionTargets) { - funcEnv.getStream() << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; - } - } - else { - // Loop through group members - funcEnv.getStream() << "for(unsigned int i = 0; i < group->size; i++)"; - { - CodeStream::Scope b(funcEnv.getStream()); - - // Generate custom update - EnvironmentSubstitute env(funcEnv); - env.addSubstitution("id", "i"); - c.generateCustomUpdate(*this, env); + // Generate custom update + EnvironmentGroupMergedField memberEnv(groupEnv, c); + memberEnv.add(Type::Uint32.addConst(), "id", "i"); + c.generateCustomUpdate(*this, memberEnv); - // Write back reductions - genWriteBackReductions(env, c, "id"); + // Write back reductions + genWriteBackReductions(memberEnv, c, "id"); + } } } - } - } + }); // Loop through merged custom WU update groups - for(const auto &c : modelMerged.getMergedCustomUpdateWUGroups()) { - // If this update group isn't for current group, skip - if(c.getArchetype().getUpdateGroupName() != g) { - continue; - } - - CodeStream::Scope b(funcEnv.getStream()); - funcEnv.getStream() << "// merged custom WU update group " << c.getIndex() << std::endl; - funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + modelMerged.genMergedCustomUpdateWUGroups( + *this, g, + [this, &funcEnv](auto &c) { CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged custom WU update group " << c.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + { + CodeStream::Scope b(funcEnv.getStream()); - // Get reference to group - funcEnv.getStream() << "const auto *group = &mergedCustomUpdateWUGroup" << c.getIndex() << "[g]; " << std::endl; + // Get reference to group + funcEnv.getStream() << "const auto *group = &mergedCustomUpdateWUGroup" << c.getIndex() << "[g]; " << std::endl; - const SynapseGroupInternal *sg = c.getArchetype().getSynapseGroup(); - EnvironmentSubstitute synEnv(funcEnv); - if (sg->getMatrixType() & SynapseMatrixWeight::KERNEL) { - genKernelIteration(synEnv, c, c.getArchetype().getSynapseGroup()->getKernelSize().size(), - [&c, this](EnvironmentExternal &env) - { - // Call custom update handler - c.generateCustomUpdate(*this, env); + // Create matching environment + EnvironmentGroupMergedField
groupEnv(funcEnv, c); - // Write back reductions - genWriteBackReductions(env, c, "id_syn"); - }); - } - else { - // Loop through presynaptic neurons - synEnv.getStream() << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; - { - // If this synapse group has sparse connectivity, loop through length of this row - CodeStream::Scope b(synEnv.getStream()); - if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - synEnv.getStream() << "for(unsigned int s = 0; s < group->rowLength[i]; s++)"; - } - // Otherwise, if it's dense, loop through each postsynaptic neuron - else if (sg->getMatrixType() & SynapseMatrixConnectivity::DENSE) { - synEnv.getStream() << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)"; - } - else { - throw std::runtime_error("Only DENSE and SPARSE format connectivity can be used for custom updates"); - } + // **TODO** add fields + const SynapseGroupInternal *sg = c.getArchetype().getSynapseGroup(); + if (sg->getMatrixType() & SynapseMatrixWeight::KERNEL) { + genKernelIteration(groupEnv, c, c.getArchetype().getSynapseGroup()->getKernelSize().size(), + [&c, this](EnvironmentExternalBase &env) + { + // Call custom update handler + c.generateCustomUpdate(*this, env); + + // Write back reductions + genWriteBackReductions(env, c, "id_syn"); + }); + } + else { + // Loop through presynaptic neurons + groupEnv.getStream() << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; { + // If this synapse group has sparse connectivity, loop through length of this row CodeStream::Scope b(synEnv.getStream()); - - // Add presynaptic index to substitutions - synEnv.addSubstitution("id_pre", "i"); - - // If connectivity is sparse if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - // Add initialisers to calculate synaptic index and thus lookup postsynaptic index - const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * group->rowStride) + s;"); - const size_t jInit = synEnv.addInitialiser("const unsigned int j = group->ind[idSyn];"); - - // Add substitutions - synEnv.addSubstitution("id_syn", "idSyn", {idSynInit}); - synEnv.addSubstitution("id_post", "j", {jInit, idSynInit}); + groupEnv.getStream() << "for(unsigned int s = 0; s < group->rowLength[i]; s++)"; + } + // Otherwise, if it's dense, loop through each postsynaptic neuron + else if (sg->getMatrixType() & SynapseMatrixConnectivity::DENSE) { + groupEnv.getStream() << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)"; } else { - synEnv.addSubstitution("id_post", "j"); - - const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * group->numTrgNeurons) + j;"); - synEnv.addSubstitution("id_syn", "idSyn", {idSynInit}); + throw std::runtime_error("Only DENSE and SPARSE format connectivity can be used for custom updates"); + } + { + CodeStream::Scope b(groupEnv.getStream()); + + // Add presynaptic index to substitutions + EnvironmentGroupMergedField synEnv(groupEnv, c); + synEnv.add(Type::Uint32.addConst(), "id_pre", "i"); + + // If connectivity is sparse + if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { + // Add initialisers to calculate synaptic index and thus lookup postsynaptic index + const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * group->rowStride) + s;"); + const size_t jInit = synEnv.addInitialiser("const unsigned int j = group->ind[idSyn];"); + + // Add substitutions + synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}, {"_row_stride"}); + synEnv.add(Type::Uint32.addConst(), 
"id_post", "j", {jInit, idSynInit}, {"_ind", "_row_stride"}); + } + else { + synEnv.add(Type::Uint32.addConst(), "id_post", "j"); + + const size_t idSynInit = ; + synEnv.addSubstitution("id_syn", "idSyn", + {synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["num_post"] + ") + j;")}, + } + + // Generate custom update + c.generateCustomUpdate(*this, synEnv); + + // Write back reductions + genWriteBackReductions(synEnv, c, "id_syn"); } - - // Generate custom update - c.generateCustomUpdate(*this, synEnv); - - // Write back reductions - genWriteBackReductions(synEnv, c, "id_syn"); } } - } - } + }); } // Loop through merged custom connectivity update groups @@ -827,6 +823,24 @@ void Backend::genCustomUpdate(CodeStream &os_, ModelSpecMerged &modelMerged, Hos } } } + + // Generate struct definitions + modelMerged.genMergedCustomUpdateStructs(os, *this); + modelMerged.genMergedCustomUpdateWUStructs(os, *this); + modelMerged.genMergedCustomUpdateTransposeWUStructs(os, *this); + modelMerged.genMergedCustomConnectivityUpdateStructs(os, *this); + + // Generate arrays of merged structs and functions to set them + genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateWUGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateTransposeWUGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdateGroups()); + + // Generate preamble + preambleHandler(os); + + os << customUpdateStream.str(); + } //-------------------------------------------------------------------------- void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 57c257e270..746f2d0ecd 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -33,25 +33,47 @@ bool BackendBase::areSixtyFourBitSynapseIndicesRequired(const SynapseGroupMerged return ((maxSynapses & 0xFFFFFFFF00000000ULL) != 0); } //----------------------------------------------------------------------- -void BackendBase::genCustomUpdateIndexCalculation(CodeStream &os, const CustomUpdateGroupMerged &cu) const +void BackendBase::genCustomUpdateIndexCalculation(EnvironmentGroupMergedField &env) const { + // Add size field + env.addField(Type::Uint32, "size", "size", + [](const auto &c, size_t) { return std::to_string(c.getSize()); }); + // If batching is enabled, calculate batch offset - if(cu.getArchetype().isBatched()) { - os << "const unsigned int batchOffset = group->size * batch;" << std::endl; + if(env.getGroup().getArchetype().isBatched()) { + env.add(Type::Uint32.addConst(), "_batch_offset", "batchOffset", + {env.addInitialiser("const unsigned int batchOffset = " + env["size"] + " * batch;")}, + {"size"}); } // If axonal delays are required - if(cu.getArchetype().getDelayNeuronGroup() != nullptr) { - // We should read from delay slot pointed to be spkQuePtr - os << "const unsigned int delaySlot = *group->spkQuePtr;" << std::endl; - os << "const unsigned int delayOffset = (delaySlot * group->size);" << std::endl; + if(env.getGroup().getArchetype().getDelayNeuronGroup() != nullptr) { + // Add spike queue pointer field + env.addField(Type::Uint32.createPointer(), "_spk_que_ptr", "spkQuePtr", + [this](const auto &cg, size_t) + { + return getScalarAddressPrefix() + "spkQuePtr" + cg.getDelayNeuronGroup()->getName(); + }); + + // We 
+       env.add(Type::Uint32.addConst(), "_delay_slot", "delaySlot",
+               {env.addInitialiser("const unsigned int delaySlot = *" + env["_spk_que_ptr"] + ";")},
+               {"_spk_que_ptr"});
+       env.add(Type::Uint32.addConst(), "_delay_offset", "delayOffset",
+               {env.addInitialiser("const unsigned int delayOffset = delaySlot * " + env["size"] + ";")},
+               {"size", "_delay_slot"});

        // If batching is also enabled, calculate offset including delay and batch
-       if(cu.getArchetype().isBatched()) {
-           os << "const unsigned int batchDelaySlot = (batch * " << cu.getArchetype().getDelayNeuronGroup()->getNumDelaySlots() << ") + delaySlot;" << std::endl;
+       if(env.getGroup().getArchetype().isBatched()) {
+           const std::string numDelaySlotsStr = std::to_string(env.getGroup().getArchetype().getDelayNeuronGroup()->getNumDelaySlots());
+           env.add(Type::Uint32.addConst(), "_batch_delay_slot", "batchDelaySlot",
+                   {env.addInitialiser("const unsigned int batchDelaySlot = (batch * " + numDelaySlotsStr + ") + delaySlot;")},
+                   {"_delay_slot"});

            // Calculate current batch offset
-           os << "const unsigned int batchDelayOffset = delayOffset + (batchOffset * " << cu.getArchetype().getDelayNeuronGroup()->getNumDelaySlots() << ");" << std::endl;
+           env.add(Type::Uint32.addConst(), "_batch_delay_offset", "batchDelayOffset",
+                   {env.addInitialiser("const unsigned int batchDelayOffset = delayOffset + (batchOffset * " + numDelaySlotsStr + ");")},
+                   {"_batch_offset", "_delay_offset"});
        }
    }
}
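For reference, the initialisers registered above are only emitted into the generated code when the corresponding substitution is actually used. For a hypothetical batched custom update whose delay neuron group has N delay slots (N standing in for the getNumDelaySlots() value baked in above), the code they pull in looks roughly like this sketch:

    const unsigned int batchOffset = group->size * batch;
    const unsigned int delaySlot = *group->spkQuePtr;
    const unsigned int delayOffset = delaySlot * group->size;
    const unsigned int batchDelaySlot = (batch * N) + delaySlot;
    const unsigned int batchDelayOffset = delayOffset + (batchOffset * N);

The dependency lists passed as the final arguments are what guarantee that, for example, delaySlot is defined before delayOffset uses it.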
diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc
index 122619cb3d..0f6076b6e5 100644
--- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc
@@ -13,7 +13,6 @@
 #include "transpiler/parser.h"
 #include "transpiler/prettyPrinter.h"
 #include "transpiler/scanner.h"
-#include "transpiler/standardLibrary.h"
 #include "transpiler/typeChecker.h"
@@ -27,65 +26,6 @@ using namespace GeNN::Transpiler;
//----------------------------------------------------------------------------
const std::string CustomUpdateGroupMerged::name = "CustomUpdate";
//----------------------------------------------------------------------------
-CustomUpdateGroupMerged::CustomUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
-                                                const std::vector> &groups)
-:  GroupMerged(index, typeContext, groups)
-{
-   using namespace Type;
-
-   // Create type environment
-   StandardLibrary::FunctionTypes stdLibraryEnv;
-   GroupMergedTypeEnvironment typeEnvironment(*this, &stdLibraryEnv);
-
-   addField(Uint32, "size", [](const auto &c, size_t) { return std::to_string(c.getSize()); });
-
-   // If some variables are delayed, add delay pointer
-   if(getArchetype().getDelayNeuronGroup() != nullptr) {
-       addField(Uint32.createPointer(), "spkQuePtr",
-                [&backend](const auto &cg, size_t)
-                {
-                    return backend.getScalarAddressPrefix() + "spkQuePtr" + cg.getDelayNeuronGroup()->getName();
-                });
-   }
-
-   // Add heterogeneous custom update model parameters
-   const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel();
-   typeEnvironment.defineHeterogeneousParams(
-       cm->getParamNames(), "",
-       &CustomUpdateInternal::getParams,
-       &CustomUpdateGroupMerged::isParamHeterogeneous);
-
-   // Add heterogeneous custom update model derived parameters
-   typeEnvironment.defineHeterogeneousDerivedParams(
-       cm->getDerivedParams(), "",
-       &CustomUpdateInternal::getDerivedParams,
-       &CustomUpdateGroupMerged::isDerivedParamHeterogeneous);
-
-   // Add variables to struct
-   typeEnvironment.defineVars(cm->getVars(), backend.getDeviceVarPrefix());
-
-   // Add variable references to struct
-   typeEnvironment.defineVarReferences(cm->getVarRefs(), backend.getDeviceVarPrefix());
-
-   // Add EGPs to struct
-   typeEnvironment.defineEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix());
-
-   // Scan, parse and type-check update code
-   ErrorHandler errorHandler;
-   std::tie(m_UpdateStatements, m_ResolvedTypes) = scanParseAndTypeCheckStatements(cm->getUpdateCode(), typeContext,
-                                                                                   typeEnvironment, errorHandler);
-}
-//----------------------------------------------------------------------------
-bool CustomUpdateGroupMerged::isParamHeterogeneous(const std::string &paramName) const
-{
-   return isParamValueHeterogeneous(paramName, [](const auto &cg) { return cg.getParams(); });
-}
-//----------------------------------------------------------------------------
-bool CustomUpdateGroupMerged::isDerivedParamHeterogeneous(const std::string &paramName) const
-{
-   return isParamValueHeterogeneous(paramName, [](const auto &cg) { return cg.getDerivedParams(); });
-}
-//----------------------------------------------------------------------------
 boost::uuids::detail::sha1::digest_type CustomUpdateGroupMerged::getHashDigest() const
 {
     boost::uuids::detail::sha1 hash;
@@ -104,37 +44,37 @@ boost::uuids::detail::sha1::digest_type CustomUpdateGroupMerged::getHashDigest()
    return hash.get_digest();
}
//----------------------------------------------------------------------------
-void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, EnvironmentExternal &env) const
+void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, EnvironmentExternalBase &env)
{
    // Add parameters, derived parameters and EGPs to environment
-   EnvironmentSubstitute envSubs(env);
+   EnvironmentGroupMergedField cuEnv(env, *this);
+
+   // Substitute parameter and derived parameter names
    const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel();
-   envSubs.addParamValueSubstitution(cm->getParamNames(), getArchetype().getParams(),
-                                     [this](const std::string &p) { return isParamHeterogeneous(p); });
-   envSubs.addVarValueSubstitution(cm->getDerivedParams(), getArchetype().getDerivedParams(),
-                                   [this](const std::string &p) { return isDerivedParamHeterogeneous(p); });
-   envSubs.addVarNameSubstitution(cm->getExtraGlobalParams());
+   cuEnv.addParams(cm->getParamNames(), "", &CustomUpdateInternal::getParams, &CustomUpdateGroupMerged::isParamHeterogeneous);
+   cuEnv.addDerivedParams(cm->getDerivedParams(), "", &CustomUpdateInternal::getDerivedParams, &CustomUpdateGroupMerged::isDerivedParamHeterogeneous);
+   cuEnv.addExtraGlobalParams(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix());

    // Create an environment which caches variables in local variables if they are accessed
-   EnvironmentLocalVarCache varSubs(
-       getArchetype(), getTypeContext(), envSubs,
-       [this, &envSubs](const std::string&, const Models::VarInit&, VarAccess a)
+   EnvironmentLocalVarCache varEnv(
+       *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l",
+       [this, &cuEnv](const std::string&, VarAccess a)
        {
-           return getVarIndex(getVarAccessDuplication(a), envSubs["id"]);
+           return getVarIndex(getVarAccessDuplication(a), cuEnv["id"]);
        });
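    // As a rough sketch of what this caching produces for a hypothetical
    // ReadWrite model variable "sum" (unbatched, with the "l" local prefix
    // passed above), the generated update is framed as:
    //
    //     scalar lsum = group->sum[id];   // copied into a local on first access
    //     /* ...user update code reads and writes lsum... */
    //     group->sum[id] = lsum;          // written back afterwards
    //
    // so repeated accesses in user code hit the local copy rather than memory.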
    // Create an environment which caches variable references in local variables if they are accessed
-   EnvironmentLocalVarCache varRefSubs(
-       getArchetype(), getTypeContext(), varSubs,
-       [this, &envSubs](const std::string&, const Models::VarReference &v, VarAccessMode)
+   EnvironmentLocalVarRefCache varRefEnv(
+       *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l",
+       [this, &varEnv](const std::string&, const Models::VarReference &v)
        {
            return getVarRefIndex(v.getDelayNeuronGroup() != nullptr,
                                  getVarAccessDuplication(v.getVar().access),
-                                 envSubs["id"]);
+                                 varEnv["id"]);
        });

-   // Pretty print previously parsed update statements
-   PrettyPrinter::print(getUpdateStatements(), varRefSubs, getTypeContext(), m_ResolvedTypes);
+   Transpiler::ErrorHandler errorHandler("Custom update code " + std::to_string(getIndex()));
+   prettyPrintExpression(cm->getUpdateCode(), getTypeContext(), varRefEnv, errorHandler);
}
//----------------------------------------------------------------------------
std::string CustomUpdateGroupMerged::getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const
@@ -173,6 +113,16 @@ std::string CustomUpdateGroupMerged::getVarRefIndex(bool delay, VarAccessDuplica
        return getVarIndex(varDuplication, index);
    }
}
+//----------------------------------------------------------------------------
+bool CustomUpdateGroupMerged::isParamHeterogeneous(const std::string &paramName) const
+{
+   return isParamValueHeterogeneous(paramName, [](const auto &cg) { return cg.getParams(); });
+}
+//----------------------------------------------------------------------------
+bool CustomUpdateGroupMerged::isDerivedParamHeterogeneous(const std::string &paramName) const
+{
+   return isParamValueHeterogeneous(paramName, [](const auto &cg) { return cg.getDerivedParams(); });
+}

// ----------------------------------------------------------------------------
// GeNN::CodeGenerator::CustomUpdateWUGroupMergedBase
//----------------------------------------------------------------------------
@@ -213,36 +163,36 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWUGroupMergedBase::getHashDi
    return hash.get_digest();
}
//----------------------------------------------------------------------------
-void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &backend, EnvironmentExternal &env) const
+void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &backend, EnvironmentExternalBase &env)
{
-   // Add parameters, derived parameters and EGPs to environment
-   EnvironmentSubstitute envSubs(env);
+   // Add parameters, derived parameters and EGPs to environment
+   EnvironmentGroupMergedField cuEnv(env, *this);
+
+   // Substitute parameter and derived parameter names
    const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel();
-   envSubs.addParamValueSubstitution(cm->getParamNames(), getArchetype().getParams(),
-                                     [this](const std::string &p) { return isParamHeterogeneous(p); });
-   envSubs.addVarValueSubstitution(cm->getDerivedParams(), getArchetype().getDerivedParams(),
-                                   [this](const std::string &p) { return isDerivedParamHeterogeneous(p); });
-   envSubs.addVarNameSubstitution(cm->getExtraGlobalParams());
+   cuEnv.addParams(cm->getParamNames(), "", &CustomUpdateInternal::getParams, &CustomUpdateWUGroupMergedBase::isParamHeterogeneous);
+   cuEnv.addDerivedParams(cm->getDerivedParams(), "", &CustomUpdateInternal::getDerivedParams, &CustomUpdateWUGroupMergedBase::isDerivedParamHeterogeneous);
+   cuEnv.addExtraGlobalParams(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix());

    // Create an environment which caches variables in local variables if they are accessed
-   EnvironmentLocalVarCache varSubs(
-       getArchetype(), getTypeContext(), envSubs,
-       [&envSubs, this](const std::string&, const Models::VarInit&, VarAccess a)
+   EnvironmentLocalVarCache
varEnv( + *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", + [this, &cuEnv](const std::string&, VarAccess a) { - return getVarIndex(getVarAccessDuplication(a), envSubs["id_syn"]); + return getVarIndex(getVarAccessDuplication(a), cuEnv["id_syn"]); }); // Create an environment which caches variable references in local variables if they are accessed - EnvironmentLocalVarCache varRefSubs( - getArchetype(), getTypeContext(), varSubs, - [&envSubs, this](const std::string&, const Models::WUVarReference &v, VarAccessMode) + EnvironmentLocalVarRefCache varRefEnv( + *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", + [this, &varEnv](const std::string&, const Models::WUVarReference &v) { - return getVarRefIndex(getVarAccessDuplication(v.getVar().access), - envSubs["id_syn"]); + return getVarRefIndex(getVarAccessDuplication(v.getVar().access), + varEnv["id_syn"]); }); - // Pretty print previously parsed update statements - PrettyPrinter::print(getUpdateStatements(), varRefSubs, getTypeContext(), getResolvedTypes()); + Transpiler::ErrorHandler errorHandler("Custom WU update code " + std::to_string(getIndex())); + prettyPrintExpression(cm->getUpdateCode(), getTypeContext(), varRefEnv, errorHandler); } //---------------------------------------------------------------------------- std::string CustomUpdateWUGroupMergedBase::getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const @@ -261,7 +211,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const const std::vector> &groups) : GroupMerged(index, typeContext, groups) { - using namespace Type; + /*using namespace Type; // Create type environment StandardLibrary::FunctionTypes stdLibraryEnv; @@ -359,7 +309,7 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const // Scan, parse and type-check update code ErrorHandler errorHandler; std::tie(m_UpdateStatements, m_ResolvedTypes) = scanParseAndTypeCheckStatements(cm->getUpdateCode(), typeContext, - typeEnvironment, errorHandler); + typeEnvironment, errorHandler);*/ } // ---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 38a511b445..46558b40df 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -41,9 +41,9 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, - [&csEnv, &modelMerged, &ng](const std::string&, VarAccess a) + [&csEnv, &modelMerged, &ng](const std::string&, VarAccessDuplication d) { - return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), csEnv["id"]); + return ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, csEnv["id"]); }); // Pretty print code back to environment @@ -127,9 +127,9 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, - [&psmEnv, &modelMerged, &ng](const std::string&, VarAccess a) + [&psmEnv, 
&modelMerged, &ng](const std::string&, VarAccessDuplication d) { - return ng.getVarIndex(modelMerged.getModel().getBatchSize(), getVarAccessDuplication(a), psmEnv["id"]); + return ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, psmEnv["id"]); }); // Pretty print code back to environment @@ -217,13 +217,13 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, - [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccess a) + [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) { - return ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), synEnv["id"]); + return ng.getReadVarIndex(delayed, batchSize, d, synEnv["id"]); }, - [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccess a) + [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) { - return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), synEnv["id"]); + return ng.getWriteVarIndex(delayed, batchSize, d, synEnv["id"]); }); /*neuronSubstitutionsInSynapticCode(varEnv, &ng.getArchetype(), "", "_post", "", "", "", dynamicsNotSpike, @@ -309,13 +309,13 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back const bool delayed = (getArchetype().getDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, - [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccess a) + [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) { - return ng.getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), synEnv["id"]); + return ng.getReadVarIndex(delayed, batchSize, d, synEnv["id"]); }, - [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccess a) + [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) { - return ng.getWriteVarIndex(delayed, batchSize, getVarAccessDuplication(a), synEnv["id"]); + return ng.getWriteVarIndex(delayed, batchSize, d, synEnv["id"]); }); /*neuronSubstitutionsInSynapticCode(subs, &ng.getArchetype(), "", "_pre", "", "", "", dynamicsNotSpike, @@ -541,16 +541,16 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Create an environment which caches variables in local variables if they are accessed // **NOTE** we do this right at the top so that local copies can be used by child groups EnvironmentLocalVarCache neuronVarEnv( - *this, getTypeContext(), neuronEnv, backend.getDeviceVarPrefix(), "l", "", - [batchSize, &neuronEnv, this](const std::string &varName, VarAccess a) + *this, *this, getTypeContext(), neuronEnv, backend.getDeviceVarPrefix(), "l", "", + [batchSize, &neuronEnv, this](const std::string &varName, VarAccessDuplication d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); - return getReadVarIndex(delayed, batchSize, getVarAccessDuplication(a), neuronEnv["id"]) ; + return getReadVarIndex(delayed, batchSize, d, neuronEnv["id"]) ; }, - [batchSize, &neuronEnv, this](const std::string &varName, VarAccess a) + [batchSize, &neuronEnv, this](const std::string &varName, VarAccessDuplication d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); - return getWriteVarIndex(delayed, 
batchSize, getVarAccessDuplication(a), neuronEnv["id"]) ; + return getWriteVarIndex(delayed, batchSize, d, neuronEnv["id"]) ; }); From 776bbcfb4569506912aa88f9d1bd2812ba6b7aba Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 19 Jun 2023 09:31:23 +0100 Subject: [PATCH 234/725] removed commented out code --- include/genn/genn/code_generator/environment.h | 5 ----- 1 file changed, 5 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index ef6f710b43..77ac5ce136 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -648,11 +648,6 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P using Def = typename std::invoke_result_t::value_type; public: - /*template - EnvironmentExternalDynamicBase(EnvironmentExternalBase &enclosing, PolicyArgs&&... policyArgs) - : EnvironmentExternalBase(enclosing), P(std::forward(policyArgs)...) - {}*/ - template EnvironmentLocalCacheBase(G &group, F &fieldGroup, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix, From 522bf66e5ef12347d96611508c91b890c5008cb5 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 19 Jun 2023 09:31:40 +0100 Subject: [PATCH 235/725] fixed typo --- src/genn/genn/code_generator/customUpdateGroupMerged.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 0f6076b6e5..3d3e17042f 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -58,9 +58,9 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", - [this, &cuEnv](const std::string&, VarAccess a) + [this, &cuEnv](const std::string&, VarAccessDuplication d) { - return getVarIndex(getVarAccessDuplication(a), cuEnv["id"]); + return getVarIndex(d, cuEnv["id"]); }); // Create an environment which caches variable references in local variables if they are accessed @@ -177,9 +177,9 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", - [this, &cuEnv](const std::string&, VarAccess a) + [this, &cuEnv](const std::string&, VarAccessDuplication d) { - return getVarIndex(getVarAccessDuplication(a), cuEnv["id_syn"]); + return getVarIndex(d, cuEnv["id_syn"]); }); // Create an environment which caches variable references in local variables if they are accessed From a4f063197fecbcaaf64184b56a96a48f93b334bf Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 19 Jun 2023 12:30:52 +0100 Subject: [PATCH 236/725] fixed up custom update group merged --- .../code_generator/customUpdateGroupMerged.h | 80 +++--- .../code_generator/customUpdateGroupMerged.cc | 245 ++++++++---------- 2 files changed, 148 insertions(+), 177 deletions(-) diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h 
index bc841df5c6..a6c2c3f66e 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -13,6 +13,8 @@ namespace GeNN::CodeGenerator class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged { public: + using GroupMerged::GroupMerged; + //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- @@ -51,6 +53,8 @@ class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged { public: + using GroupMerged::GroupMerged; + //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- @@ -59,8 +63,6 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged> &groups); + void generateCustomUpdateBase(const BackendBase &backend, EnvironmentExternalBase &env); private: static const std::vector& getGroupKernelSize(const CustomUpdateWUInternal& g) { return g.getSynapseGroup()->getKernelSize(); } - - //---------------------------------------------------------------------------- - // Members - //---------------------------------------------------------------------------- - //! List of statements parsed and type-checked in constructor; and used to generate code - Transpiler::Statement::StatementList m_UpdateStatements; - - //! Resolved types used to generate code - Transpiler::TypeChecker::ResolvedTypeMap m_ResolvedTypes; }; // ---------------------------------------------------------------------------- @@ -111,11 +100,7 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged> &groups) - : CustomUpdateWUGroupMergedBase(index, typeContext, backend, groups) - { - } + using CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase; //---------------------------------------------------------------------------- // Public API @@ -129,6 +114,11 @@ class GENN_EXPORT CustomUpdateWUGroupMerged : public CustomUpdateWUGroupMergedBa runnerVarDecl, runnerMergedStructAlloc, name); } + void generateCustomUpdate(const BackendBase &backend, EnvironmentExternalBase &env) + { + generateCustomUpdateBase(backend, env); + } + //---------------------------------------------------------------------------- // Static constants //---------------------------------------------------------------------------- @@ -141,11 +131,7 @@ class GENN_EXPORT CustomUpdateWUGroupMerged : public CustomUpdateWUGroupMergedBa class GENN_EXPORT CustomUpdateTransposeWUGroupMerged : public CustomUpdateWUGroupMergedBase { public: - CustomUpdateTransposeWUGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) - : CustomUpdateWUGroupMergedBase(index, typeContext, backend, groups) - { - } + using CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase; //---------------------------------------------------------------------------- // Public API @@ -159,6 +145,8 @@ class GENN_EXPORT CustomUpdateTransposeWUGroupMerged : public CustomUpdateWUGrou runnerVarDecl, runnerMergedStructAlloc, name); } + void generateCustomUpdate(const BackendBase &backend, EnvironmentExternalBase &env); + //---------------------------------------------------------------------------- // Static constants //---------------------------------------------------------------------------- @@ -173,28 +161,36 @@ template class CustomUpdateHostReductionGroupMergedBase : public GroupMerged { 
protected: - CustomUpdateHostReductionGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) - : GroupMerged(index, typeContext, groups) + using GroupMerged::GroupMerged; + + template + void generateCustomUpdateBase(const BackendBase &backend, EnvironmentGroupMergedField &env) { - // Create type environment - // **TEMP** parse precision to get scalar type - //GroupMergedTypeEnvironment typeEnvironment(*this, Type::parseNumeric(precision)); - // Loop through variables and add pointers if they are reduction targets - /*const CustomUpdateModels::Base *cm = this->getArchetype().getCustomUpdateModel(); + const auto *cm = this->getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { if(v.access & VarAccessModeAttribute::REDUCE) { - this->addPointerField(v.type, v.name, backend.getDeviceVarPrefix() + v.name); + const auto fieldType = v.type.resolve(getTypeContext()).createPointer(); + env.addField(fieldType, v.name, v.name, + [&backend, v](const auto &g, size_t) + { + return backend.getDeviceVarPrefix() + v.name + g.getName(); + }); } } // Loop through variable references and add pointers if they are reduction targets for(const auto &v : cm->getVarRefs()) { if(v.access & VarAccessModeAttribute::REDUCE) { - this->addPointerField(v.type, v.name, backend.getDeviceVarPrefix() + v.name); + const auto fieldType = v.type.resolve(getTypeContext()).createPointer(); + env.addField(fieldType, v.name, v.name, + [&backend, v](const auto &g, size_t) + { + const auto varRef = g.getVarReferences().at(v.name); + return backend.getDeviceVarPrefix() + v.name + varRef.getTargetName(); + }); } - }*/ + } } }; @@ -204,8 +200,7 @@ class CustomUpdateHostReductionGroupMergedBase : public GroupMerged class GENN_EXPORT CustomUpdateHostReductionGroupMerged : public CustomUpdateHostReductionGroupMergedBase { public: - CustomUpdateHostReductionGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + using CustomUpdateHostReductionGroupMergedBase::CustomUpdateHostReductionGroupMergedBase; //------------------------------------------------------------------------ // Public API @@ -219,6 +214,8 @@ class GENN_EXPORT CustomUpdateHostReductionGroupMerged : public CustomUpdateHost runnerVarDecl, runnerMergedStructAlloc, name, true); } + void generateCustomUpdate(const BackendBase &backend, EnvironmentGroupMergedField &env); + //---------------------------------------------------------------------------- // Static constants //---------------------------------------------------------------------------- @@ -231,8 +228,7 @@ class GENN_EXPORT CustomUpdateHostReductionGroupMerged : public CustomUpdateHost class GENN_EXPORT CustomWUUpdateHostReductionGroupMerged : public CustomUpdateHostReductionGroupMergedBase { public: - CustomWUUpdateHostReductionGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + using CustomUpdateHostReductionGroupMergedBase::CustomUpdateHostReductionGroupMergedBase; //------------------------------------------------------------------------ // Public API @@ -246,6 +242,8 @@ class GENN_EXPORT CustomWUUpdateHostReductionGroupMerged : public CustomUpdateHo runnerVarDecl, runnerMergedStructAlloc, name, true); } + void generateCustomUpdate(const BackendBase &backend, EnvironmentGroupMergedField &env); + //---------------------------------------------------------------------------- // Static 
constants //---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 3d3e17042f..47d30477eb 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -49,6 +49,16 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E // Add parameters, derived parameters and EGPs to environment EnvironmentGroupMergedField cuEnv(env, *this); + cuEnv.addField(Type::Uint32.addConst(), "size", + Type::Uint32, "size", + [](const CustomUpdateInternal &c, size_t) { return std::to_string(c.getSize()); }); + + // If some variables are delayed, add delay pointer + if(getArchetype().getDelayNeuronGroup() != nullptr) { + cuEnv.addField(Type::Uint32.createPointer(), "_spk_que_ptr", "spkQuePtr", + [&backend](const auto &g, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + g.getDelayNeuronGroup()->getName(); }); + } + // Substitute parameter and derived parameter names const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); cuEnv.addParams(cm->getParamNames(), "", &CustomUpdateInternal::getParams, &CustomUpdateGroupMerged::isParamHeterogeneous); @@ -163,38 +173,6 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWUGroupMergedBase::getHashDi return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &backend, EnvironmentExternalBase &env) -{ - // Add parameters, derived parameters and EGPs to environment - EnvironmentGroupMergedField cuEnv(env, *this); - - // Substitute parameter and derived parameter names - const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); - cuEnv.addParams(cm->getParamNames(), "", &CustomUpdateInternal::getParams, &CustomUpdateWUGroupMergedBase::isParamHeterogeneous); - cuEnv.addDerivedParams(cm->getDerivedParams(), "", &CustomUpdateInternal::getDerivedParams, &CustomUpdateWUGroupMergedBase::isDerivedParamHeterogeneous); - cuEnv.addExtraGlobalParams(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); - - // Create an environment which caches variables in local variables if they are accessed - EnvironmentLocalVarCache varEnv( - *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", - [this, &cuEnv](const std::string&, VarAccessDuplication d) - { - return getVarIndex(d, cuEnv["id_syn"]); - }); - - // Create an environment which caches variable references in local variables if they are accessed - EnvironmentLocalVarRefCache varRefEnv( - *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", - [this, &varEnv](const std::string&, const Models::WUVarReference &v) - { - return getVarRefIndex(getVarAccessDuplication(v.getVar().access), - varEnv["id_syn"]); - }); - - Transpiler::ErrorHandler errorHandler("Custom WU update code " + std::to_string(getIndex())); - prettyPrintExpression(cm->getUpdateCode(), getTypeContext(), varRefEnv, errorHandler); -} -//---------------------------------------------------------------------------- std::string CustomUpdateWUGroupMergedBase::getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? 
@@ -207,15 +185,10 @@ std::string CustomUpdateWUGroupMergedBase::getVarRefIndex(VarAccessDuplication v return ((varDuplication == VarAccessDuplication::SHARED || !getArchetype().isBatched()) ? "" : "batchOffset + ") + index; } //---------------------------------------------------------------------------- -CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) -: GroupMerged(index, typeContext, groups) +void CustomUpdateWUGroupMergedBase::generateCustomUpdateBase(const BackendBase &backend, EnvironmentExternalBase &env) { - /*using namespace Type; - - // Create type environment - StandardLibrary::FunctionTypes stdLibraryEnv; - GroupMergedTypeEnvironment typeEnvironment(*this, &stdLibraryEnv); + // Add parameters, derived parameters and EGPs to environment + EnvironmentGroupMergedField cuEnv(env, *this); // If underlying synapse group has kernel weights if (getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL) { @@ -223,95 +196,79 @@ CustomUpdateWUGroupMergedBase::CustomUpdateWUGroupMergedBase(size_t index, const for (size_t d = 0; d < getArchetype().getSynapseGroup()->getKernelSize().size(); d++) { // If this dimension has a heterogeneous size, add it to struct if (isKernelSizeHeterogeneous(d)) { - addField(Uint32, "kernelSize" + std::to_string(d), - [d](const auto &cu, size_t) - { - return std::to_string(cu.getSynapseGroup()->getKernelSize().at(d)); - }); + cuEnv.addField(Type::Uint32, "_kernel_size_" + std::to_string(d), "kernelSize" + std::to_string(d), + [d](const auto &cu, size_t) + { + return std::to_string(cu.getSynapseGroup()->getKernelSize().at(d)); + }); } } } // Otherwise else { - addField(Uint32, "rowStride", - [&backend](const auto &cg, size_t) - { - const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); - return std::to_string(backend.getSynapticMatrixRowStride(*sgInternal)); - }); + cuEnv.addField(Type::Uint32, "_row_stride", "rowStride", + [&backend](const auto &cg, size_t) + { + const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); + return std::to_string(backend.getSynapticMatrixRowStride(*sgInternal)); + }); - addField(Uint32, "numSrcNeurons", - [](const auto &cg, size_t) - { - const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); - return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons()); - }); - - addField(Uint32, "numTrgNeurons", - [](const auto &cg, size_t) - { - const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); - return std::to_string(sgInternal->getTrgNeuronGroup()->getNumNeurons()); - }); + cuEnv.addField(Type::Uint32, "num_pre", "numSrcNeurons", + [](const auto &cg, size_t) + { + const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); + return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons()); + }); + + cuEnv.addField(Type::Uint32, "num_post", "numTrgNeurons", + [](const auto &cg, size_t) + { + const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); + return std::to_string(sgInternal->getTrgNeuronGroup()->getNumNeurons()); + }); // If synapse group has sparse connectivity if(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - addField(getArchetype().getSynapseGroup()->getSparseIndType().createPointer(), "ind", - [&backend](const auto &cg, size_t) - { - return backend.getDeviceVarPrefix() + "ind" + 
cg.getSynapseGroup()->getName(); - }); - - addField(Uint32.createPointer(), "rowLength", - [&backend](const auto &cg, size_t) - { - return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); - }); + cuEnv.addField(getArchetype().getSynapseGroup()->getSparseIndType().createPointer(), "_ind", "ind", + [&backend](const auto &cg, size_t) + { + return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); + }); + + cuEnv.addField(Type::Uint32.createPointer(), "_row_length", "rowLength", + [&backend](const auto &cg, size_t) + { + return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); + }); } } - // Add heterogeneous custom update model parameters + // Substitute parameter and derived parameter names const CustomUpdateModels::Base *cm = getArchetype().getCustomUpdateModel(); - typeEnvironment.defineHeterogeneousParams( - cm->getParamNames(), "", - &CustomUpdateWUInternal::getParams, - &CustomUpdateWUGroupMergedBase::isParamHeterogeneous); - - // Add heterogeneous weight update model derived parameters - typeEnvironment.defineHeterogeneousDerivedParams( - cm->getDerivedParams(), "", - &CustomUpdateWUInternal::getDerivedParams, - &CustomUpdateWUGroupMergedBase::isDerivedParamHeterogeneous); - - // Add variables to struct - typeEnvironment.defineVars(cm->getVars(), backend.getDeviceVarPrefix()); - - // Add variable references to struct - const auto varRefs = cm->getVarRefs(); - typeEnvironment.defineVarReferences(varRefs, backend.getDeviceVarPrefix()); + cuEnv.addParams(cm->getParamNames(), "", &CustomUpdateInternal::getParams, &CustomUpdateWUGroupMergedBase::isParamHeterogeneous); + cuEnv.addDerivedParams(cm->getDerivedParams(), "", &CustomUpdateInternal::getDerivedParams, &CustomUpdateWUGroupMergedBase::isDerivedParamHeterogeneous); + cuEnv.addExtraGlobalParams(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); - // Loop through variables - for(const auto &v : varRefs) { - // If variable has a transpose - if(getArchetype().getVarReferences().at(v.name).getTransposeSynapseGroup() != nullptr) { - // Add field with transpose suffix, pointing to transpose var - addField(v.type.resolve(getTypeContext()).createPointer(), v.name + "Transpose", - [&backend, v](const auto &g, size_t) - { - const auto varRef = g.getVarReferences().at(v.name); - return backend.getDeviceVarPrefix() + varRef.getTransposeVar().name + varRef.getTransposeTargetName(); - }); - } - } - // Add EGPs to struct - typeEnvironment.defineEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); + // Create an environment which caches variables in local variables if they are accessed + EnvironmentLocalVarCache varEnv( + *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", + [this, &cuEnv](const std::string&, VarAccessDuplication d) + { + return getVarIndex(d, cuEnv["id_syn"]); + }); + + // Create an environment which caches variable references in local variables if they are accessed + EnvironmentLocalVarRefCache varRefEnv( + *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", + [this, &varEnv](const std::string&, const Models::WUVarReference &v) + { + return getVarRefIndex(getVarAccessDuplication(v.getVar().access), + varEnv["id_syn"]); + }); - // Scan, parse and type-check update code - ErrorHandler errorHandler; - std::tie(m_UpdateStatements, m_ResolvedTypes) = scanParseAndTypeCheckStatements(cm->getUpdateCode(), typeContext, - typeEnvironment, errorHandler);*/ + Transpiler::ErrorHandler 
errorHandler("Custom WU update code " + std::to_string(getIndex())); + prettyPrintExpression(cm->getUpdateCode(), getTypeContext(), varRefEnv, errorHandler); } - // ---------------------------------------------------------------------------- // GeNN::CodeGenerator::CustomUpdateWUGroupMerged //---------------------------------------------------------------------------- @@ -321,45 +278,61 @@ const std::string CustomUpdateWUGroupMerged::name = "CustomUpdateWU"; // CustomUpdateTransposeWUGroupMerged //---------------------------------------------------------------------------- const std::string CustomUpdateTransposeWUGroupMerged::name = "CustomUpdateTransposeWU"; +// ---------------------------------------------------------------------------- +void CustomUpdateTransposeWUGroupMerged::generateCustomUpdate(const BackendBase &backend, EnvironmentExternalBase &env) +{ + // Add parameters, derived parameters and EGPs to environment + EnvironmentGroupMergedField cuEnv(env, *this); + + // Loop through variable references + const auto varRefs = getArchetype().getCustomUpdateModel()->getVarRefs(); + for(const auto &v : varRefs) { + const auto fieldType = v.type.resolve(getTypeContext()).createPointer(); + + // If variable has a transpose, add field with transpose suffix, pointing to transpose var + if(getArchetype().getVarReferences().at(v.name).getTransposeSynapseGroup() != nullptr) { + cuEnv.addField(fieldType, v.name + "_transpose", v.name + "Transpose", + [&backend, v](const auto &g, size_t) + { + const auto varRef = g.getVarReferences().at(v.name); + return backend.getDeviceVarPrefix() + varRef.getTransposeVar().name + varRef.getTransposeTargetName(); + }); + } + } + + generateCustomUpdateBase(backend, cuEnv); +} // ---------------------------------------------------------------------------- // CustomUpdateHostReductionGroupMerged //---------------------------------------------------------------------------- const std::string CustomUpdateHostReductionGroupMerged::name = "CustomUpdateHostReduction"; //---------------------------------------------------------------------------- -CustomUpdateHostReductionGroupMerged::CustomUpdateHostReductionGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) -: CustomUpdateHostReductionGroupMergedBase(index, typeContext, backend, groups) +void CustomUpdateHostReductionGroupMerged::generateCustomUpdate(const BackendBase &backend, EnvironmentGroupMergedField &env) { - using namespace Type; - - addField(Uint32, "size", [](const auto &c, size_t) { return std::to_string(c.getSize()); }); - + env.addField(Type::Uint32, "_size", "size", + [](const auto &c, size_t) { return std::to_string(c.getSize()); }); + // If some variables are delayed, add delay pointer - // **NOTE** this is HOST delay pointer if(getArchetype().getDelayNeuronGroup() != nullptr) { - addField(Uint32.createPointer(), "spkQuePtr", - [](const auto &cg, size_t) - { - return "spkQuePtr" + cg.getDelayNeuronGroup()->getName(); - }); + env.addField(Type::Uint32.createPointer(), "_spk_que_ptr", "spkQuePtr", + [&backend](const auto &g, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + g.getDelayNeuronGroup()->getName(); }); } -} + generateCustomUpdateBase(backend, env); +} // ---------------------------------------------------------------------------- // CustomWUUpdateHostReductionGroupMerged //---------------------------------------------------------------------------- const std::string 
CustomWUUpdateHostReductionGroupMerged::name = "CustomWUUpdateHostReduction"; //---------------------------------------------------------------------------- -CustomWUUpdateHostReductionGroupMerged::CustomWUUpdateHostReductionGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) -: CustomUpdateHostReductionGroupMergedBase(index, typeContext, backend, groups) +void CustomWUUpdateHostReductionGroupMerged::generateCustomUpdate(const BackendBase &backend, EnvironmentGroupMergedField &env) { - using namespace Type; + env.addField(Type::Uint32, "_size", "size", + [](const auto &c, size_t) + { + return std::to_string(c.getSynapseGroup()->getMaxConnections() * (size_t)c.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); + }); - addField(Uint32, "size", - [&backend](const auto &cg, size_t) - { - return std::to_string(cg.getSynapseGroup()->getMaxConnections() * (size_t)cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); - }); + generateCustomUpdateBase(backend, env); } From 6074b3c3c0d18f449c6820609df7bd041b236132 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 20 Jun 2023 10:10:11 +0100 Subject: [PATCH 237/725] work on merged initialisation groups --- .../genn/genn/code_generator/backendBase.h | 6 +- .../genn/genn/code_generator/environment.h | 58 ++- .../genn/genn/code_generator/groupMerged.h | 2 +- .../genn/code_generator/initGroupMerged.h | 60 +-- .../code_generator/synapseUpdateGroupMerged.h | 14 +- .../backends/single_threaded_cpu/backend.cc | 130 ++--- .../genn/code_generator/initGroupMerged.cc | 464 ++++++++---------- 7 files changed, 361 insertions(+), 373 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index b339769a5d..526fb2f2c9 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -318,8 +318,8 @@ class GENN_EXPORT BackendBase virtual void genPopVariableInit(EnvironmentExternalBase &env, HandlerEnv handler) const = 0; virtual void genVariableInit(EnvironmentExternalBase &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const = 0; - virtual void genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, Handler handler) const = 0; - virtual void genDenseSynapseVariableRowInit(EnvironmentExternalBase &env, Handler handler) const = 0; + virtual void genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const = 0; + virtual void genDenseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const = 0; virtual void genKernelSynapseVariableInit(EnvironmentExternalBase &env, const SynapseInitGroupMerged &sg, HandlerEnv handler) const = 0; virtual void genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, const CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const = 0; @@ -578,7 +578,7 @@ class GENN_EXPORT BackendBase [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); groupEnv.addField(Type::Uint32.addConst(), "num_post", Type::Uint32, "numTrgNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); }); groupEnv.addField(Type::Uint32, "_row_stride", "rowStride", [](const SynapseGroupInternal &sg, size_t) { return 
std::to_string(getSynapticMatrixRowStride(sg)); });
+       groupEnv.addField(Type::Uint32, "_col_stride", "colStride",
diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h
index 77ac5ce136..e5d98f6801 100644
--- a/include/genn/genn/code_generator/environment.h
+++ b/include/genn/genn/code_generator/environment.h
@@ -341,6 +341,10 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &(GroupInternal::*)(void) const;
+   template
+   using GetConnectivityFn = const Snippet::Init &(GroupInternal::*)(void) const;
+
    template
    using GetVarReferencesFn = const std::unordered_map &(GroupInternal::*)(void) const;
@@ -360,9 +364,9 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &initialisers = {}, const std::vector &dependents = {})
+            const GeNN::Type::ResolvedType &fieldType, const std::string &fieldName, typename G::GetFieldValueFunc getFieldValue,
+            const std::string &indexSuffix = "", GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD,
+            const std::vector &initialisers = {}, const std::vector &dependents = {})
    {
        addInternal(type, name, std::make_tuple(false, indexSuffix, std::make_optional(std::make_tuple(fieldType, fieldName, getFieldValue, mergedFieldType))), initialisers, dependents);
@@ -444,6 +448,54 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase
+   template
+   void addConnectInitParams(const std::string &fieldSuffix, GetConnectivityFn getConnectivity,
+                             IsHeterogeneousFn isHeterogeneous)
+   {
+       // Loop through params
+       const auto &connectInit = std::invoke(getConnectivity, getGroup().getArchetype());
+       const auto *snippet = connectInit.getSnippet();
+       for(const auto &p : snippet->getParamNames()) {
+           // If parameter is heterogeneous, add scalar field
+           if (std::invoke(isHeterogeneous, getGroup(), p)) {
+               addScalar(p, fieldSuffix,
+                         [p, getConnectivity](const auto &g, size_t)
+                         {
+                             return std::invoke(getConnectivity, g).getParams().at(p);
+                         });
+           }
+           // Otherwise, just add a const-qualified scalar to the type environment
+           else {
+               add(getGroup().getScalarType().addConst(), p,
+                   getScalarString(connectInit.getParams().at(p)));
+           }
+       }
+   }
+
+   template
+   void addConnectInitDerivedParams(const std::string &fieldSuffix, GetConnectivityFn getConnectivity,
+                                    IsHeterogeneousFn isHeterogeneous)
+   {
+       // Loop through derived params
+       const auto &connectInit = std::invoke(getConnectivity, getGroup().getArchetype());
+       const auto *snippet = connectInit.getSnippet();
+       for(const auto &d : snippet->getDerivedParams()) {
+           // If parameter is heterogeneous, add scalar field
+           if (std::invoke(isHeterogeneous, getGroup(), d.name)) {
+               addScalar(d.name, fieldSuffix,
+                         [d, getConnectivity](const auto &g, size_t)
+                         {
+                             return std::invoke(getConnectivity, g).getDerivedParams().at(d.name);
+                         });
+           }
+           // Otherwise, just add a const-qualified scalar to the type environment
+           else {
+               add(getGroup().getScalarType().addConst(), d.name,
+                   getScalarString(connectInit.getDerivedParams().at(d.name)));
+           }
+       }
+   }
+
    template
    void addVarInitParams(IsVarInitHeterogeneousFn isHeterogeneous, const std::string &fieldSuffix = "")
    {
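As a usage sketch (class and member names here are illustrative, not taken from this patch), a connectivity-initialisation group merged class could expose its snippet parameters through these helpers instead of hand-writing the heterogeneous/homogeneous split:

    EnvironmentGroupMergedField groupEnv(env, *this);
    groupEnv.addConnectInitParams("", &SynapseGroupInternal::getConnectivityInitialiser,
                                  &SynapseConnectivityInitGroupMerged::isSparseConnectivityInitParamHeterogeneous);
    groupEnv.addConnectInitDerivedParams("", &SynapseGroupInternal::getConnectivityInitialiser,
                                         &SynapseConnectivityInitGroupMerged::isSparseConnectivityInitDerivedParamHeterogeneous);

A parameter that differs between the merged groups becomes a scalar struct field; one that is identical everywhere is baked into the generated code as a literal.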
diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h
index f11bcc2b90..231af503ab 100644
--- a/include/genn/genn/code_generator/groupMerged.h
+++ b/include/genn/genn/code_generator/groupMerged.h
@@ -692,7 +692,7 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged> &groups)
:   GroupMerged(index, typeContext, groups), m_ArchetypeCode(archetypeCode)
    {}
diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h
index e958e1015f..a6d8fbc9f1 100644
--- a/include/genn/genn/code_generator/initGroupMerged.h
+++ b/include/genn/genn/code_generator/initGroupMerged.h
@@ -209,9 +209,9 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase
class GENN_EXPORT SynapseInitGroupMerged : public SynapseGroupMergedBase
{
public:
-   SynapseInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
+   SynapseInitGroupMerged(size_t index, const Type::TypeContext &typeContext,
                           const std::vector> &groups)
-   :   SynapseGroupMergedBase(index, typeContext, backend, SynapseGroupMergedBase::Role::Init, "", groups)
+   :   SynapseGroupMergedBase(index, typeContext, SynapseGroupMergedBase::Role::Init, "", groups)
    {}

    boost::uuids::detail::sha1::digest_type getHashDigest() const
@@ -228,7 +228,7 @@ class GENN_EXPORT SynapseInitGroupMerged : public SynapseGroupMergedBase
        runnerVarDecl, runnerMergedStructAlloc, name);
    }
-   void generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const;
+   void generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged);

    //----------------------------------------------------------------------------
    // Static constants
@@ -261,7 +261,7 @@ class GENN_EXPORT SynapseSparseInitGroupMerged : public SynapseGroupMergedBase
        runnerVarDecl, runnerMergedStructAlloc, name);
    }
-   void generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const;
+   void generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged);

    //----------------------------------------------------------------------------
    // Static constants
@@ -294,9 +294,9 @@ class GENN_EXPORT SynapseConnectivityInitGroupMerged : public SynapseGroupMerged
        runnerVarDecl, runnerMergedStructAlloc, name);
    }
-   void generateSparseRowInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const;
-   void generateSparseColumnInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const;
-   void generateKernelInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const;
+   void generateSparseRowInit(const BackendBase &backend, EnvironmentExternalBase &env);
+   void generateSparseColumnInit(const BackendBase &backend, EnvironmentExternalBase &env);
+   void generateKernelInit(const BackendBase &backend, EnvironmentExternalBase &env);

    //----------------------------------------------------------------------------
    // Static constants
@@ -308,7 +308,7 @@ class GENN_EXPORT SynapseConnectivityInitGroupMerged : public SynapseGroupMerged
    // Private methods
    //----------------------------------------------------------------------------
    //!
Generate either row or column connectivity init code - void genInitConnectivity(CodeStream &os, Substitutions &popSubs, bool rowNotColumns) const; + void genInitConnectivity(const BackendBase &backend, EnvironmentExternalBase &env, bool rowNotColumns); }; @@ -333,7 +333,7 @@ class GENN_EXPORT SynapseConnectivityHostInitGroupMerged : public GroupMerged class CustomUpdateInitGroupMergedBase : public GroupMerged { protected: - CustomUpdateInitGroupMergedBase(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) - : GroupMerged(index, typeContext, groups) - { - // Loop through variables - A archetypeAdaptor(this->getArchetype()); - for (const auto &var : archetypeAdaptor.getDefs()) { - // If we're not initialising or if there is initialization code for this variable - const auto &varInit = archetypeAdaptor.getInitialisers().at(var.name); - if (!varInit.getSnippet()->getCode().empty()) { - this->addPointerField(var.type, var.name, backend.getDeviceVarPrefix() + var.name); - } - - // Add any var init EGPs to structure - this->addEGPs(varInit.getSnippet()->getExtraGlobalParams(), backend.getDeviceVarPrefix(), var.name); - } - - this->template addHeterogeneousVarInitParams, A>( - &CustomUpdateInitGroupMergedBase::isVarInitParamHeterogeneous); - - this->template addHeterogeneousVarInitDerivedParams, A>( - &CustomUpdateInitGroupMergedBase::isVarInitDerivedParamHeterogeneous); - } - //---------------------------------------------------------------------------- // Protected methods //---------------------------------------------------------------------------- @@ -453,8 +429,7 @@ class CustomUpdateInitGroupMergedBase : public GroupMerged class GENN_EXPORT CustomUpdateInitGroupMerged : public CustomUpdateInitGroupMergedBase { public: - CustomUpdateInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + using CustomUpdateInitGroupMergedBase::CustomUpdateInitGroupMergedBase; //---------------------------------------------------------------------------- // Public API @@ -470,7 +445,7 @@ class GENN_EXPORT CustomUpdateInitGroupMerged : public CustomUpdateInitGroupMerg runnerVarDecl, runnerMergedStructAlloc, name); } - void generateInit(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged) const; + void generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged); //---------------------------------------------------------------------------- // Static constants @@ -503,7 +478,7 @@ class GENN_EXPORT CustomWUUpdateInitGroupMerged : public CustomUpdateInitGroupMe runnerVarDecl, runnerMergedStructAlloc, name); } - void generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged); //! Is kernel size heterogeneous in this dimension? 
bool isKernelSizeHeterogeneous(size_t dimensionIndex) const @@ -562,7 +537,7 @@ class GENN_EXPORT CustomWUUpdateSparseInitGroupMerged : public CustomUpdateInitG runnerVarDecl, runnerMergedStructAlloc, name); } - void generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged); //---------------------------------------------------------------------------- // Static constants @@ -594,7 +569,7 @@ class GENN_EXPORT CustomConnectivityUpdatePreInitGroupMerged : public CustomUpda runnerVarDecl, runnerMergedStructAlloc, name); } - void generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged); //---------------------------------------------------------------------------- // Static constants @@ -626,7 +601,7 @@ class GENN_EXPORT CustomConnectivityUpdatePostInitGroupMerged : public CustomUpd runnerVarDecl, runnerMergedStructAlloc, name); } - void generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged); //---------------------------------------------------------------------------- // Static constants @@ -641,8 +616,7 @@ class GENN_EXPORT CustomConnectivityUpdateSparseInitGroupMerged : public CustomU CustomConnectivityUpdateVarAdapter> { public: - CustomConnectivityUpdateSparseInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + using CustomUpdateInitGroupMergedBase::CustomUpdateInitGroupMergedBase; //---------------------------------------------------------------------------- // Public API @@ -658,7 +632,7 @@ class GENN_EXPORT CustomConnectivityUpdateSparseInitGroupMerged : public CustomU runnerVarDecl, runnerMergedStructAlloc, name); } - void generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const; + void generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged); //---------------------------------------------------------------------------- // Static constants diff --git a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h index 7d9be43909..b57943ae0e 100644 --- a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h @@ -11,9 +11,9 @@ namespace GeNN::CodeGenerator class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase { public: - PresynapticUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + PresynapticUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> &groups) - : SynapseGroupMergedBase(index, typeContext, backend, SynapseGroupMergedBase::Role::PresynapticUpdate, + : SynapseGroupMergedBase(index, typeContext, SynapseGroupMergedBase::Role::PresynapticUpdate, groups.front().get().getWUModel()->getSimCode() + groups.front().get().getWUModel()->getEventCode() + groups.front().get().getWUModel()->getEventThresholdConditionCode(), groups) 
{} @@ -49,9 +49,9 @@ class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT PostsynapticUpdateGroupMerged : public SynapseGroupMergedBase { public: - PostsynapticUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + PostsynapticUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> &groups) - : SynapseGroupMergedBase(index, typeContext, backend, SynapseGroupMergedBase::Role::PostsynapticUpdate, + : SynapseGroupMergedBase(index, typeContext, SynapseGroupMergedBase::Role::PostsynapticUpdate, groups.front().get().getWUModel()->getLearnPostCode(), groups) {} @@ -83,9 +83,9 @@ class GENN_EXPORT PostsynapticUpdateGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase { public: - SynapseDynamicsGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + SynapseDynamicsGroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> &groups) - : SynapseGroupMergedBase(index, typeContext, backend, SynapseGroupMergedBase::Role::SynapseDynamics, + : SynapseGroupMergedBase(index, typeContext, SynapseGroupMergedBase::Role::SynapseDynamics, groups.front().get().getWUModel()->getSynapseDynamicsCode(), groups) {} @@ -117,7 +117,7 @@ class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT SynapseDendriticDelayUpdateGroupMerged : public GroupMerged { public: - SynapseDendriticDelayUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + SynapseDendriticDelayUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> &group); //------------------------------------------------------------------------ diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 3939cb0d61..90cd075981 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -729,97 +729,99 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host } } } - }); - } + } + }); // Loop through merged custom connectivity update groups - for(const auto &c : modelMerged.getMergedCustomConnectivityUpdateGroups()) { - // If this update group isn't for current group, skip - if(c.getArchetype().getUpdateGroupName() != g) { - continue; - } - - CodeStream::Scope b(funcEnv.getStream()); - funcEnv.getStream() << "// merged custom connectivity update group " << c.getIndex() << std::endl; - funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + modelMerged.genMergedCustomConnectivityUpdateGroups( + *this, g, + [this, &funcEnv](auto &c) { CodeStream::Scope b(funcEnv.getStream()); - - // Get reference to group - funcEnv.getStream() << "const auto *group = &mergedCustomConnectivityUpdateGroup" << c.getIndex() << "[g]; " << std::endl; - - genCustomConnectivityUpdateIndexCalculation(funcEnv.getStream(), c); - - // Loop through presynaptic neurons - funcEnv.getStream() << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; + funcEnv.getStream() << "// merged custom connectivity update group " << c.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; { CodeStream::Scope b(funcEnv.getStream()); + + // Get reference to group + funcEnv.getStream() << "const auto *group = 
&mergedCustomConnectivityUpdateGroup" << c.getIndex() << "[g]; " << std::endl; - // Configure substitutions - EnvironmentSubstitute cuEnv(funcEnv); - cuEnv.addSubstitution("id_pre", "i"); - cuEnv.addSubstitution("rng", "hostRNG"); + // Create matching environment + EnvironmentGroupMergedField groupEnv(funcEnv, c); + + genCustomConnectivityUpdateIndexCalculation(funcEnv.getStream(), c); + + // Loop through presynaptic neurons + funcEnv.getStream() << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; + { + CodeStream::Scope b(funcEnv.getStream()); + + // Configure substitutions + groupEnv.add(Type::Uint32, "id_pre", "i"); - assert(false); - //c.generateUpdate(*this, cuEnv, model.getBatchSize()); + assert(false); + //c.generateUpdate(*this, cuEnv, model.getBatchSize()); + } } - } - } + }); } // Loop through merged custom WU transpose update groups { Timer t(funcEnv.getStream(), "customUpdate" + g + "Transpose", model.isTimingEnabled()); - for(const auto &c : modelMerged.getMergedCustomUpdateTransposeWUGroups()) { - // If this update group isn't for current group, skip - if(c.getArchetype().getUpdateGroupName() != g) { - continue; - } - - CodeStream::Scope b(funcEnv.getStream()); - funcEnv.getStream() << "// merged custom WU transpose update group " << c.getIndex() << std::endl; - funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + // Loop through merged custom connectivity update groups + modelMerged.genMergedCustomUpdateTransposeWUGroups( + *this, g, + [this, &funcEnv](auto &c) { CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged custom WU transpose update group " << c.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + { + CodeStream::Scope b(funcEnv.getStream()); - // Get reference to group - funcEnv.getStream() << "const auto *group = &mergedCustomUpdateTransposeWUGroup" << c.getIndex() << "[g]; " << std::endl; + // Get reference to group + funcEnv.getStream() << "const auto *group = &mergedCustomUpdateTransposeWUGroup" << c.getIndex() << "[g]; " << std::endl; - // Get index of variable being transposed - const size_t transposeVarIdx = std::distance(c.getArchetype().getVarReferences().cbegin(), - std::find_if(c.getArchetype().getVarReferences().cbegin(), c.getArchetype().getVarReferences().cend(), - [](const auto &v) { return v.second.getTransposeSynapseGroup() != nullptr; })); - const std::string transposeVarName = c.getArchetype().getCustomUpdateModel()->getVarRefs().at(transposeVarIdx).name; + // Create matching environment + EnvironmentGroupMergedField groupEnv(funcEnv, c); - // Loop through presynaptic neurons - funcEnv.getStream() << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; - { - CodeStream::Scope b(funcEnv.getStream()); + // Get index of variable being transposed + const size_t transposeVarIdx = std::distance(c.getArchetype().getVarReferences().cbegin(), + std::find_if(c.getArchetype().getVarReferences().cbegin(), c.getArchetype().getVarReferences().cend(), + [](const auto &v) { return v.second.getTransposeSynapseGroup() != nullptr; })); + const std::string transposeVarName = c.getArchetype().getCustomUpdateModel()->getVarRefs().at(transposeVarIdx).name; - // Loop through each postsynaptic neuron - funcEnv.getStream() << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)"; + // Loop through presynaptic neurons + groupEnv.getStream() << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; { - CodeStream::Scope 
b(funcEnv.getStream());
+                    CodeStream::Scope b(groupEnv.getStream());
 
-                // Add pre and postsynaptic indices to environment
-                EnvironmentSubstitute synEnv(funcEnv);
-                synEnv.addSubstitution("id_pre", "i");
-                synEnv.addSubstitution("id_post", "j");
+                    // Loop through each postsynaptic neuron
+                    groupEnv.getStream() << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)";
+                    {
+                        CodeStream::Scope b(groupEnv.getStream());
+
+                        // Add pre and postsynaptic indices to environment
+                        groupEnv.add(Type::Uint32, "id_pre", "i");
+                        groupEnv.add(Type::Uint32, "id_post", "j");
 
-                // Add conditional initialisation code to calculate synapse index
-                const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * group->numTrgNeurons) + j;");
-                synEnv.addSubstitution("id_syn", "idSyn", {idSynInit});
+                        // Add conditional initialisation code to calculate synapse index
+                        groupEnv.add(Type::Uint32, "id_syn", "idSyn",
+                                     {groupEnv.addInitialiser("const unsigned int idSyn = (i * " + groupEnv["num_post"] + ") + j;")},
+                                     {"num_post"});
 
-                // Generate custom update
-                c.generateCustomUpdate(*this, synEnv);
+                        // Generate custom update
+                        c.generateCustomUpdate(*this, groupEnv);
 
-                // Update transpose variable
-                synEnv.getStream() << "group->" << transposeVarName << "Transpose[(j * group->numSrcNeurons) + i] = l" << transposeVarName << ";" << std::endl;
+                        // Update transpose variable
+                        // **YUCK** this is sorta outside scope
+                        groupEnv.getStream() << groupEnv[transposeVarName + "_transpose"] << "[(j * " << groupEnv["num_pre"] << ") + i] = l" << transposeVarName << ";" << std::endl;
+                    }
                 }
             }
-        }
-    }
+        }
+        });
     }
 }
diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc
index c5ecdab0c9..b8f1cc6361 100644
--- a/src/genn/genn/code_generator/initGroupMerged.cc
+++ b/src/genn/genn/code_generator/initGroupMerged.cc
@@ -141,26 +141,24 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e
 }
 //------------------------------------------------------------------------
 // Initialise one row of weight update model variables
-template
-void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, G &group,
-                      const Models::Base::VarVec &vars, const std::unordered_map &varInitialisers,
+template
+void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, G &groupMerged,
                       const std::string &stride, unsigned int batchSize,
-                      EnvironmentGroupMergedField::IsVarInitHeterogeneousFn isParamHeterogeneousFn,
-                      EnvironmentGroupMergedField::IsVarInitHeterogeneousFn isDerivedParamHeterogeneousFn,
                       V genSynapseVariableRowInitFn)
 {
+    A adaptor(groupMerged.getArchetype());
-    for (const auto &var : vars) {
+    for (const auto &var : adaptor.getDefs()) {
         // If this variable has any initialisation code and doesn't require a kernel
-        const auto resolvedType = var.type.resolve(group.getTypeContext());
-        const auto &varInit = varInitialisers.at(var.name);
-        const auto *snippet = adaptor.varInit.getSnippet();
-        if(!snippet->getCode().empty() && !varInit.getSnippet()->requiresKernel()) {
+        const auto resolvedType = var.type.resolve(groupMerged.getTypeContext());
+        const auto &varInit = adaptor.getInitialisers().at(var.name);
+        const auto *snippet = varInit.getSnippet();
+        if(!snippet->getCode().empty() && !snippet->requiresKernel()) {
             CodeStream::Scope b(env.getStream());
 
             // Substitute in parameters and derived parameters for initialising variables
-            EnvironmentGroupMergedField varEnv(env, group);
-            varEnv.addVarInitParams(isParamHeterogeneousFn, fieldSuffix);
-
varEnv.addVarInitDerivedParams(isDerivedParamHeterogeneousFn, fieldSuffix); + EnvironmentGroupMergedField varEnv(env, groupMerged); + varEnv.addVarInitParams(&G::isVarInitParamHeterogeneous, fieldSuffix); + varEnv.addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, fieldSuffix); varEnv.addExtraGlobalParams(snippet->getExtraGlobalParameters(), backend.getDeviceVarPrefix(), var.name, fieldSuffix); // Add field for variable itself @@ -180,8 +178,8 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, varInitEnv.add(resolvedType, "value", "initVal"); // Pretty print variable initialisation code - Transpiler::ErrorHandler errorHandler("Variable '" + var.name + "' init code" + std::to_string(group.getIndex())); - prettyPrintStatements(snippet->getCode(), group.getTypeContext(), varInitEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Variable '" + var.name + "' init code" + std::to_string(groupMerged.getIndex())); + prettyPrintStatements(snippet->getCode(), groupMerged.getTypeContext(), varInitEnv, errorHandler); // Fill value across all batches genVariableFill(varInitEnv(), "_value", "initVal", "id_syn", stride, @@ -400,7 +398,7 @@ const std::string NeuronInitGroupMerged::name = "NeuronInit"; //---------------------------------------------------------------------------- NeuronInitGroupMerged::NeuronInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, const std::vector> &groups) -: NeuronGroupMergedBase(index, typeContext, backend, groups) +: NeuronGroupMergedBase(index, typeContext, groups) { // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group orderNeuronGroupChildren(m_MergedInSynPSMGroups, typeContext, @@ -623,51 +621,55 @@ bool NeuronInitGroupMerged::isVarInitDerivedParamHeterogeneous(const std::string //---------------------------------------------------------------------------- const std::string SynapseInitGroupMerged::name = "SynapseInit"; //---------------------------------------------------------------------------- -void SynapseInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void SynapseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { + // Create environment for group + EnvironmentGroupMergedField groupEnv(env, *this); + // If model is batched and has kernel weights const bool kernel = (getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL); if (kernel && modelMerged.getModel().getBatchSize() > 1) { // Loop through kernel dimensions and multiply together to calculate batch stride - os << "const unsigned int batchStride = "; + // **TODO** dependency for add + std::ostringstream batchStrideInit; + batchStrideInit << "const unsigned int batchStride = "; const auto &kernelSize = getArchetype().getKernelSize(); for (size_t i = 0; i < kernelSize.size(); i++) { - os << getKernelSize(i); + batchStrideInit << getKernelSize(i); if (i != (kernelSize.size() - 1)) { - os << " * "; + batchStrideInit << " * "; } } - os << ";" << std::endl;; + batchStrideInit << ";" << std::endl; + groupEnv.add(Type::Uint32.addConst(), "_batch_stride", "batchStride", + {groupEnv.addInitialiser(batchStrideInit.str())}); } // If we're using non-kernel weights, generate loop over source neurons if (!kernel) { - os << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; - os << 
CodeStream::OB(1); - popSubs.addVarSubstitution("id_pre", "i"); + groupEnv.getStream() << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; + groupEnv.getStream() << CodeStream::OB(1); + groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); } // Generate initialisation code - const std::string stride = kernel ? "batchStride" : "group->numSrcNeurons * group->rowStride"; - genInitWUVarCode(os, modelMerged, popSubs, getArchetype().getWUModel()->getVars(), - getArchetype().getWUVarInitialisers(), stride, getIndex(), modelMerged.getModel().getBatchSize(), - [this](const std::string &v, const std::string &p) { return isWUVarInitParamHeterogeneous(v, p); }, - [this](const std::string &v, const std::string &p) { return isWUVarInitDerivedParamHeterogeneous(v, p); }, - [&backend, kernel, this](CodeStream &os, const Substitutions &kernelSubs, BackendBase::Handler handler) - { - if (kernel) { - backend.genKernelSynapseVariableInit(os, *this, kernelSubs, handler); - } - else { - backend.genDenseSynapseVariableRowInit(os, kernelSubs, handler); - } - }); + const std::string stride = kernel ? groupEnv["_batch_stride"] : groupEnv["num_pre"] + " * " + groupEnv["_row_stride"]; + genInitWUVarCode(backend, groupEnv, *this, stride, modelMerged.getModel().getBatchSize(), + [&backend, kernel, this](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) + { + if (kernel) { + backend.genKernelSynapseVariableInit(varInitEnv, *this, handler); + } + else { + backend.genDenseSynapseVariableRowInit(varInitEnv, handler); + } + }); // If we're using non-kernel weights, close loop if (!kernel) { - os << CodeStream::CB(1); + groupEnv.getStream() << CodeStream::CB(1); } } @@ -676,15 +678,12 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, CodeStream //---------------------------------------------------------------------------- const std::string SynapseSparseInitGroupMerged::name = "SynapseSparseInit"; //---------------------------------------------------------------------------- -void SynapseSparseInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void SynapseSparseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { - genInitWUVarCode(os, modelMerged, popSubs, getArchetype().getWUModel()->getVars(), - getArchetype().getWUVarInitialisers(), "group->numSrcNeurons * group->rowStride", getIndex(), modelMerged.getModel().getBatchSize(), - [this](const std::string &v, const std::string &p) { return isWUVarInitParamHeterogeneous(v, p); }, - [this](const std::string &v, const std::string &p) { return isWUVarInitDerivedParamHeterogeneous(v, p); }, - [&backend](CodeStream &os, const Substitutions &kernelSubs, BackendBase::Handler handler) + genInitWUVarCode(backend, env, *this, env["num_pre"] + " * " + env["_row_stride"], modelMerged.getModel().getBatchSize(), + [&backend](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { - backend.genSparseSynapseVariableRowInit(os, kernelSubs, handler); + backend.genSparseSynapseVariableRowInit(varInitEnv, handler); }); } @@ -693,23 +692,25 @@ void SynapseSparseInitGroupMerged::generateInit(const BackendBase &backend, Code //---------------------------------------------------------------------------- const std::string SynapseConnectivityInitGroupMerged::name = "SynapseConnectivityInit"; //---------------------------------------------------------------------------- -void 
SynapseConnectivityInitGroupMerged::generateSparseRowInit(const BackendBase&, CodeStream &os, const ModelSpecMerged &, Substitutions &popSubs) const +void SynapseConnectivityInitGroupMerged::generateSparseRowInit(const BackendBase &backend, EnvironmentExternalBase &env) { - genInitConnectivity(os, popSubs, true); + genInitConnectivity(backend, env, true); } //---------------------------------------------------------------------------- -void SynapseConnectivityInitGroupMerged::generateSparseColumnInit(const BackendBase&, CodeStream &os, const ModelSpecMerged &, Substitutions &popSubs) const +void SynapseConnectivityInitGroupMerged::generateSparseColumnInit(const BackendBase &backend, EnvironmentExternalBase &env) { - genInitConnectivity(os, popSubs, false); + genInitConnectivity(backend, env, false); } //---------------------------------------------------------------------------- -void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase&, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase &backend, EnvironmentExternalBase &env) { - // Generate kernel index and add to substitutions - os << "const unsigned int kernelInd = "; - genKernelIndex(os, popSubs); - os << ";" << std::endl; - popSubs.addVarSubstitution("id_kernel", "kernelInd"); + // Create environment for group + EnvironmentGroupMergedField groupEnv(env, *this); + + // Add substitution + // **TODO** dependencies on kernel fields + groupEnv.add(Type::Uint32, "id_kernel", "kernelInd", + {groupEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(groupEnv) + ";")}); for(const auto &var : getArchetype().getWUModel()->getVars()) { const auto &varInit = getArchetype().getWUVarInitialisers().at(var.name); @@ -743,44 +744,45 @@ void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase&, } } //---------------------------------------------------------------------------- -void SynapseConnectivityInitGroupMerged::genInitConnectivity(CodeStream &os, Substitutions &popSubs, bool rowNotColumns) const +void SynapseConnectivityInitGroupMerged::genInitConnectivity(const BackendBase &backend, EnvironmentExternalBase &env, bool rowNotColumns) { const auto &connectInit = getArchetype().getConnectivityInitialiser(); const auto *snippet = connectInit.getSnippet(); - // Add substitutions - popSubs.addFuncSubstitution(rowNotColumns ? "endRow" : "endCol", 0, "break"); - popSubs.addParamValueSubstitution(snippet->getParamNames(), connectInit.getParams(), - [this](const std::string &p) { return isSparseConnectivityInitParamHeterogeneous(p); }, - "", "group->"); - popSubs.addVarValueSubstitution(snippet->getDerivedParams(), connectInit.getDerivedParams(), - [this](const std::string &p) { return isSparseConnectivityInitDerivedParamHeterogeneous(p); }, - "", "group->"); - popSubs.addVarNameSubstitution(snippet->getExtraGlobalParams(), "", "group->"); + // Create environment for group + EnvironmentGroupMergedField groupEnv(env, *this); + + // Add substitution for end function + // **TODO** remove + groupEnv.add(Type::ResolvedType::createFunction(Type::Void, {}), rowNotColumns ? 
"endRow" : "endCol", "break;"); + + // Substitute in parameters and derived parameters for initialising variables + groupEnv.addConnectInitParams("", &SynapseGroupInternal::getConnectivityInitialiser, + &SynapseConnectivityInitGroupMerged::isSparseConnectivityInitParamHeterogeneous); + groupEnv.addConnectInitDerivedParams("", &SynapseGroupInternal::getConnectivityInitialiser, + &SynapseConnectivityInitGroupMerged::isSparseConnectivityInitDerivedParamHeterogeneous); + groupEnv.addExtraGlobalParams(snippet->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", ""); + // Initialise state variables and loop on generated code to initialise sparse connectivity - os << "// Build sparse connectivity" << std::endl; + groupEnv.getStream() << "// Build sparse connectivity" << std::endl; const auto stateVars = rowNotColumns ? snippet->getRowBuildStateVars() : snippet->getColBuildStateVars(); + const std::string context = rowNotColumns ? "row" : "column"; for(const auto &a : stateVars) { - // Apply substitutions to value - std::string value = a.value; - popSubs.applyCheckUnreplaced(value, "initSparseConnectivity state var : merged" + std::to_string(getIndex())); - //value = ensureFtype(value, ftype); + + groupEnv.getStream() << a.type.resolve(getTypeContext()).getName() << " " << a.name << " = "; - os << a.type.resolve(getTypeContext()).getName() << " " << a.name << " = " << value << ";" << std::endl; + Transpiler::ErrorHandler errorHandler("Connectivity init " + context + " build state var" + std::to_string(getIndex())); + prettyPrintExpression(a.value, getTypeContext(), groupEnv, errorHandler); + + groupEnv.getStream() << ";" << std::endl; } - os << "while(true)"; + groupEnv.getStream() << "while(true)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(groupEnv.getStream()); - // Apply substitutions to row build code - std::string code = rowNotColumns ? snippet->getRowBuildCode() : snippet->getColBuildCode(); - popSubs.addVarNameSubstitution(stateVars); - popSubs.applyCheckUnreplaced(code, "initSparseConnectivity : merged" + std::to_string(getIndex())); - //code = ensureFtype(code, ftype); - - // Write out code - os << code << std::endl; + Transpiler::ErrorHandler errorHandler("Connectivity init " + context + " build" + std::to_string(getIndex())); + prettyPrintStatements(rowNotColumns ? 
snippet->getRowBuildCode() : snippet->getColBuildCode(), getTypeContext(), groupEnv, errorHandler); } } @@ -794,8 +796,7 @@ SynapseConnectivityHostInitGroupMerged::SynapseConnectivityHostInitGroupMerged(s const std::vector> &groups) : GroupMerged(index, typeContext, groups) { - using namespace Type; - + // **TODO** these could be generic addField(Uint32, "numSrcNeurons", [](const auto &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); @@ -843,17 +844,20 @@ SynapseConnectivityHostInitGroupMerged::SynapseConnectivityHostInitGroupMerged(s } } //------------------------------------------------------------------------- -void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged) const +void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { - CodeStream::Scope b(os); - os << "// merged synapse connectivity host init group " << getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << getGroups().size() << "; g++)"; + + CodeStream::Scope b(env.getStream()); + env.getStream() << "// merged synapse connectivity host init group " << getIndex() << std::endl; + env.getStream() << "for(unsigned int g = 0; g < " << getGroups().size() << "; g++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); // Get reference to group - os << "const auto *group = &mergedSynapseConnectivityHostInitGroup" << getIndex() << "[g]; " << std::endl; - + env.getStream() << "const auto *group = &mergedSynapseConnectivityHostInitGroup" << getIndex() << "[g]; " << std::endl; + + // Create environment for group + EnvironmentGroupMergedField groupEnv(env, *this); const auto &connectInit = getArchetype().getConnectivityInitialiser(); // If matrix type is procedural then initialized connectivity init snippet will potentially be used with multiple threads per spike. @@ -861,19 +865,21 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac const size_t numThreads = (getArchetype().getMatrixType() & SynapseMatrixConnectivity::PROCEDURAL) ? 
getArchetype().getNumThreadsPerSpike() : 1; // Create substitutions - Substitutions subs; - subs.addVarSubstitution("rng", "hostRNG"); - subs.addVarSubstitution("num_pre", "group->numSrcNeurons"); - subs.addVarSubstitution("num_post", "group->numTrgNeurons"); - subs.addVarSubstitution("num_threads", std::to_string(numThreads)); - subs.addVarNameSubstitution(connectInit.getSnippet()->getExtraGlobalParams(), "", "*group->"); - subs.addParamValueSubstitution(connectInit.getSnippet()->getParamNames(), connectInit.getParams(), - [this](const std::string &p) { return isConnectivityInitParamHeterogeneous(p); }, - "", "group->"); - subs.addVarValueSubstitution(connectInit.getSnippet()->getDerivedParams(), connectInit.getDerivedParams(), - [this](const std::string &p) { return isConnectivityInitDerivedParamHeterogeneous(p); }, - "", "group->"); - + groupEnv.addField(Type::Uint32.addConst(), "num_pre", + Type::Uint32, "numSrcNeurons", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); + groupEnv.addField(Type::Uint32.addConst(), "num_post", + Type::Uint32, "numTrgNeurons", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); }); + groupEnv.add(Type::Uint32.addConst(), "num_threads", std::to_string(numThreads)); + + groupEnv.addConnectInitParams("", &SynapseGroupInternal::getConnectivityInitialiser, + &SynapseConnectivityHostInitGroupMerged::isConnectivityInitParamHeterogeneous); + groupEnv.addConnectInitDerivedParams("", &SynapseGroupInternal::getConnectivityInitialiser, + &SynapseConnectivityHostInitGroupMerged::isSparseConnectivityInitDerivedParamHeterogeneous); + + //subs.addVarNameSubstitution(connectInit.getSnippet()->getExtraGlobalParams(), "", "*group->"); + // Loop through EGPs for(const auto &egp : connectInit.getSnippet()->getExtraGlobalParams()) { // If EGP is located on the host @@ -889,7 +895,7 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac loc, "$(0)", "group->"); // Add substitution - subs.addFuncSubstitution("allocate" + egp.name, 1, allocStream.str()); + groupEnv.add(Type::ResolvedType::createFunction(Type::Void, {Type::Uint32}), "allocate" + egp.name, allocStream.str()); // Generate code to push this EGP with count specified by $(0) std::stringstream pushStream; @@ -900,15 +906,11 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac // Add substitution - subs.addFuncSubstitution("push" + egp.name, 1, pushStream.str()); + groupEnv.add(Type::ResolvedType::createFunction(Type::Void, {Type::Uint32}), "push" + egp.name, pushStream.str()); } } - std::string code = connectInit.getSnippet()->getHostInitCode(); - subs.applyCheckUnreplaced(code, "hostInitSparseConnectivity : merged" + std::to_string(getIndex())); - //code = ensureFtype(code, modelMerged.getModel().getPrecision()); - - // Write out code - os << code << std::endl; + Transpiler::ErrorHandler errorHandler("Connectivity host init" + std::to_string(getIndex())); + prettyPrintStatements(connectInit.getSnippet()->getHostInitCode(), getTypeContext(), groupEnv, errorHandler); } } //---------------------------------------------------------------------------- @@ -936,14 +938,6 @@ bool SynapseConnectivityHostInitGroupMerged::isSparseConnectivityInitParamRefere //---------------------------------------------------------------------------- const std::string CustomUpdateInitGroupMerged::name = "CustomUpdateInit"; 
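The add()/addInitialiser() calls used throughout these hunks register a substitution together with setup statements that are only emitted when the generated code actually references the substitution. The following minimal, self-contained sketch illustrates that lazy-initialiser idea; LazyEnv and its members are illustrative stand-ins, not GeNN's actual EnvironmentGroupMergedField API:

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    // Toy environment: names map to code strings plus indices of setup statements
    class LazyEnv
    {
    public:
        size_t addInitialiser(std::string statement)
        {
            m_Initialisers.emplace_back(std::move(statement), false);
            return m_Initialisers.size() - 1;
        }

        void add(const std::string &name, std::string value, std::vector<size_t> initialisers = {})
        {
            m_Substitutions.emplace(name, std::make_pair(std::move(value), std::move(initialisers)));
        }

        // Referencing a name marks its initialisers as required
        const std::string &operator[](const std::string &name)
        {
            auto &sub = m_Substitutions.at(name);
            for(size_t i : sub.second) {
                m_Initialisers[i].second = true;
            }
            return sub.first;
        }

        // Emit only the initialisers that were actually used
        void emitUsedInitialisers(std::ostream &os) const
        {
            for(const auto &init : m_Initialisers) {
                if(init.second) {
                    os << init.first << std::endl;
                }
            }
        }

    private:
        std::vector<std::pair<std::string, bool>> m_Initialisers;
        std::unordered_map<std::string, std::pair<std::string, std::vector<size_t>>> m_Substitutions;
    };

    int main()
    {
        LazyEnv env;
        env.add("id_kernel", "kernelInd",
                {env.addInitialiser("const unsigned int kernelInd = /* kernel index expression */ 0;")});
        std::cout << env["id_kernel"] << std::endl;   // the lookup marks the initialiser as used
        env.emitUsedInitialisers(std::cout);
        return 0;
    }

Looking up env["id_kernel"] is what triggers emission of the kernelInd statement, mirroring how substitutions that generated code never touches cost nothing at runtime.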
//---------------------------------------------------------------------------- -CustomUpdateInitGroupMerged::CustomUpdateInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) -: CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups) -{ - addField(Type::Uint32, "size", - [](const auto &c, size_t) { return std::to_string(c.getSize()); }); -} -//---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type CustomUpdateInitGroupMerged::getHashDigest() const { boost::uuids::detail::sha1 hash; @@ -957,11 +951,11 @@ boost::uuids::detail::sha1::digest_type CustomUpdateInitGroupMerged::getHashDige return hash.get_digest(); } // ---------------------------------------------------------------------------- -void CustomUpdateInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternal &env, const ModelSpecMerged &modelMerged) const +void CustomUpdateInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { // Initialise custom update variables - genInitNeuronVarCode(backend, env, *this, m_VarInitASTs, ""), - "size", 1, getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1); + genInitNeuronVarCode(backend, env, *this, "", "size", 1, + getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1); } // ---------------------------------------------------------------------------- @@ -969,32 +963,6 @@ void CustomUpdateInitGroupMerged::generateInit(const BackendBase &backend, Envir //---------------------------------------------------------------------------- const std::string CustomWUUpdateInitGroupMerged::name = "CustomWUUpdateInit"; //---------------------------------------------------------------------------- -CustomWUUpdateInitGroupMerged::CustomWUUpdateInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) -: CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups) -{ - using namespace Type; - - if(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL) { - // Loop through kernel size dimensions - for (size_t d = 0; d < getArchetype().getSynapseGroup()->getKernelSize().size(); d++) { - // If this dimension has a heterogeneous size, add it to struct - if (isKernelSizeHeterogeneous(d)) { - addField(Uint32, "kernelSize" + std::to_string(d), - [d](const auto &g, size_t) { return std::to_string(g.getSynapseGroup()->getKernelSize().at(d)); }); - } - } - } - else { - addField(Uint32, "rowStride", - [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); - addField(Uint32, "numSrcNeurons", - [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); - addField(Uint32, "numTrgNeurons", - [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); - } -} -//---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type CustomWUUpdateInitGroupMerged::getHashDigest() const { boost::uuids::detail::sha1 hash; @@ -1030,6 +998,25 @@ boost::uuids::detail::sha1::digest_type CustomWUUpdateInitGroupMerged::getHashDi // ---------------------------------------------------------------------------- void 
CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { + /*if(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL) { + // Loop through kernel size dimensions + for (size_t d = 0; d < getArchetype().getSynapseGroup()->getKernelSize().size(); d++) { + // If this dimension has a heterogeneous size, add it to struct + if (isKernelSizeHeterogeneous(d)) { + addField(Uint32, "kernelSize" + std::to_string(d), + [d](const auto &g, size_t) { return std::to_string(g.getSynapseGroup()->getKernelSize().at(d)); }); + } + } + } + else { + addField(Uint32, "rowStride", + [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); + addField(Uint32, "numSrcNeurons", + [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); + addField(Uint32, "numTrgNeurons", + [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); + }*/ + const bool kernel = (getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL); if(kernel && modelMerged.getModel().getBatchSize() > 1) { // Loop through kernel dimensions and multiply together to calculate batch stride @@ -1079,34 +1066,6 @@ void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, Cod //---------------------------------------------------------------------------- const std::string CustomWUUpdateSparseInitGroupMerged::name = "CustomWUUpdateSparseInit"; //---------------------------------------------------------------------------- -CustomWUUpdateSparseInitGroupMerged::CustomWUUpdateSparseInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) -: CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups) -{ - using namespace Type; - - addField(Uint32, "rowStride", - [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); - - addField(Uint32, "numSrcNeurons", - [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); - addField(Uint32, "numTrgNeurons", - [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); - - addField(Uint32.createPointer(), "rowLength", - [&backend](const auto &cg, size_t) - { - const SynapseGroupInternal *sg = cg.getSynapseGroup(); - return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); - }); - addField(getArchetype().getSynapseGroup()->getSparseIndType().createPointer(), "ind", - [&backend](const auto &cg, size_t) - { - const SynapseGroupInternal *sg = cg.getSynapseGroup(); - return backend.getDeviceVarPrefix() + "ind" + sg->getName(); - }); -} -//---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type CustomWUUpdateSparseInitGroupMerged::getHashDigest() const { boost::uuids::detail::sha1 hash; @@ -1135,6 +1094,27 @@ boost::uuids::detail::sha1::digest_type CustomWUUpdateSparseInitGroupMerged::get // ---------------------------------------------------------------------------- void CustomWUUpdateSparseInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const { + /* addField(Uint32, 
"rowStride", + [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); + + addField(Uint32, "numSrcNeurons", + [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); + addField(Uint32, "numTrgNeurons", + [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); + + addField(Uint32.createPointer(), "rowLength", + [&backend](const auto &cg, size_t) + { + const SynapseGroupInternal *sg = cg.getSynapseGroup(); + return backend.getDeviceVarPrefix() + "rowLength" + sg->getName(); + }); + addField(getArchetype().getSynapseGroup()->getSparseIndType().createPointer(), "ind", + [&backend](const auto &cg, size_t) + { + const SynapseGroupInternal *sg = cg.getSynapseGroup(); + return backend.getDeviceVarPrefix() + "ind" + sg->getName(); + });*/ + genInitWUVarCode(os, modelMerged, popSubs, getArchetype().getCustomUpdateModel()->getVars(), getArchetype().getVarInitialisers(), "group->numSrcNeurons * group->rowStride", getIndex(), getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1, @@ -1151,22 +1131,6 @@ void CustomWUUpdateSparseInitGroupMerged::generateInit(const BackendBase &backen //---------------------------------------------------------------------------- const std::string CustomConnectivityUpdatePreInitGroupMerged::name = "CustomConnectivityUpdatePreInit"; //---------------------------------------------------------------------------- -CustomConnectivityUpdatePreInitGroupMerged::CustomConnectivityUpdatePreInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) -: CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups) -{ - addField(Type::Uint32, "size", - [](const auto &c, size_t) - { - return std::to_string(c.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); - }); - - // If this backend initialises population RNGs on device and this group requires one for simulation - if(backend.isPopulationRNGRequired() && getArchetype().isRowSimRNGRequired() && backend.isPopulationRNGInitialisedOnDevice()) { - addPointerField(*backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rowRNG"); - } -} -//---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePreInitGroupMerged::getHashDigest() const { boost::uuids::detail::sha1 hash; @@ -1183,14 +1147,21 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePreInitGroupMerg return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomConnectivityUpdatePreInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void CustomConnectivityUpdatePreInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { + // Create environment for group + EnvironmentGroupMergedField groupEnv(env, *this); + + groupEnv.addField(Type::Uint32.addConst(), "size", + Type::Uint32, "size", + [](const auto &c, size_t) + { + return std::to_string(c.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); + }); + // Initialise presynaptic custom connectivity update variables - // **TODO** adaptor - genInitNeuronVarCode(os, modelMerged, backend, popSubs, 
getArchetype().getCustomConnectivityUpdateModel()->getPreVars(), getArchetype().getPreVarInitialisers(),
-                         "", "size", getIndex(), 1,
-                         [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); },
-                         [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); });
+    genInitNeuronVarCode(
+        backend, groupEnv, *this, "", "size", 0, modelMerged.getModel().getBatchSize());
 }
 
 // ----------------------------------------------------------------------------
@@ -1198,17 +1169,6 @@ void CustomConnectivityUpdatePreInitGroupMerged::generateInit(const BackendBase
 //----------------------------------------------------------------------------
 const std::string CustomConnectivityUpdatePostInitGroupMerged::name = "CustomConnectivityUpdatePostInit";
 //----------------------------------------------------------------------------
-CustomConnectivityUpdatePostInitGroupMerged::CustomConnectivityUpdatePostInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
-                                                                                         const std::vector> &groups)
-: CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups)
-{
-    addField(Type::Uint32, "size",
-             [](const auto &c, size_t)
-             {
-                 return std::to_string(c.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons());
-             });
-}
-//----------------------------------------------------------------------------
 boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePostInitGroupMerged::getHashDigest() const
 {
     boost::uuids::detail::sha1 hash;
@@ -1225,14 +1185,21 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePostInitGroupMer
     return hash.get_digest();
 }
 //----------------------------------------------------------------------------
-void CustomConnectivityUpdatePostInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const
+void CustomConnectivityUpdatePostInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged)
 {
+    // Create environment for group
+    EnvironmentGroupMergedField groupEnv(env, *this);
+
+    groupEnv.addField(Type::Uint32.addConst(), "size",
+                      Type::Uint32, "size",
+                      [](const auto &c, size_t)
+                      {
+                          return std::to_string(c.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons());
+                      });
+
     // Initialise postsynaptic custom connectivity update variables
-    // **TODO** adapter
-    genInitNeuronVarCode(os, modelMerged, backend, popSubs, getArchetype().getCustomConnectivityUpdateModel()->getPostVars(), getArchetype().getPostVarInitialisers(),
-                         "", "size", getIndex(), 1,
-                         [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); },
-                         [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); });
+    genInitNeuronVarCode(
+        backend, groupEnv, *this, "", "size", 0, modelMerged.getModel().getBatchSize());
 }
 
 // ----------------------------------------------------------------------------
@@ -1240,34 +1207,6 @@ void CustomConnectivityUpdatePostInitGroupMerged::generateInit(const BackendBase
 //----------------------------------------------------------------------------
 const std::string CustomConnectivityUpdateSparseInitGroupMerged::name = "CustomConnectivityUpdateSparseInit";
 //----------------------------------------------------------------------------
-CustomConnectivityUpdateSparseInitGroupMerged::CustomConnectivityUpdateSparseInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
-                                                                                             const std::vector> &groups)
-: CustomUpdateInitGroupMergedBase(index, typeContext, backend, groups)
-{
-    using namespace Type;
-
-    addField(Uint32, "rowStride",
-             [&backend](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); });
-
-    addField(Uint32, "numSrcNeurons",
-             [](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); });
-    addField(Uint32, "numTrgNeurons",
-             [](const CustomConnectivityUpdateInternal &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); });
-
-    addField(Uint32.createPointer(), "rowLength",
-             [&backend](const CustomConnectivityUpdateInternal &cg, size_t)
-             {
-                 const SynapseGroupInternal *sg = cg.getSynapseGroup();
-                 return backend.getDeviceVarPrefix() + "rowLength" + sg->getName();
-             });
-    addField(getArchetype().getSynapseGroup()->getSparseIndType().createPointer(), "ind",
-             [&backend](const auto &cg, size_t)
-             {
-                 const SynapseGroupInternal *sg = cg.getSynapseGroup();
-                 return backend.getDeviceVarPrefix() + "ind" + sg->getName();
-             });
-}
-//----------------------------------------------------------------------------
 boost::uuids::detail::sha1::digest_type CustomConnectivityUpdateSparseInitGroupMerged::getHashDigest() const
 {
     boost::uuids::detail::sha1 hash;
@@ -1294,15 +1233,36 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdateSparseInitGroupM
     return hash.get_digest();
 }
 //----------------------------------------------------------------------------
-void CustomConnectivityUpdateSparseInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const
+void CustomConnectivityUpdateSparseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged)
 {
+    // Create environment for group
+    EnvironmentGroupMergedField groupEnv(env, *this);
+
+    groupEnv.addField(Type::Uint32.addConst(), "num_pre",
+                      Type::Uint32, "numSrcNeurons",
+                      [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); });
+    groupEnv.addField(Type::Uint32.addConst(), "num_post",
+                      Type::Uint32, "numTrgNeurons",
+                      [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); });
+    groupEnv.addField(Type::Uint32, "_row_stride", "rowStride",
+                      [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); });
+
+    groupEnv.addField(Type::Uint32.createPointer(), "_row_length", "rowLength",
+                      [&backend](const CustomConnectivityUpdateInternal &cg, size_t)
+                      {
+                          return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName();
+                      });
+    groupEnv.addField(getArchetype().getSynapseGroup()->getSparseIndType().createPointer(), "_ind", "ind",
+                      [&backend](const auto &cg, size_t)
+                      {
+                          return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName();
+                      });
+
     // Initialise custom connectivity update variables
-    genInitWUVarCode(os, modelMerged, popSubs, getArchetype().getCustomConnectivityUpdateModel()->getVars(),
-                     getArchetype().getVarInitialisers(), "group->numSrcNeurons * group->rowStride", getIndex(), 1,
-                     [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); },
-
[this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }, - [&backend](CodeStream &os, const Substitutions &kernelSubs, BackendBase::Handler handler) - { - return backend.genSparseSynapseVariableRowInit(os, kernelSubs, handler); - }); + genInitWUVarCode( + backend, groupEnv, *this, "group->numSrcNeurons * group->rowStride", 1, + [&backend](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) + { + return backend.genSparseSynapseVariableRowInit(varInitEnv, handler); + }); } From b9682c59fe679eecaef1ae9be79c39366a8bc8ec Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 20 Jun 2023 13:47:18 +0100 Subject: [PATCH 238/725] init group merged now compiles --- .../genn/genn/code_generator/environment.h | 10 +- .../genn/genn/code_generator/groupMerged.h | 4 +- .../genn/code_generator/initGroupMerged.h | 57 +-- include/genn/genn/currentSourceInternal.h | 2 + .../genn/customConnectivityUpdateInternal.h | 8 + include/genn/genn/customUpdate.h | 2 + include/genn/genn/type.h | 1 + .../backends/single_threaded_cpu/backend.cc | 38 +- src/genn/genn/code_generator/groupMerged.cc | 4 +- .../genn/code_generator/initGroupMerged.cc | 371 ++++++++---------- 10 files changed, 250 insertions(+), 247 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index e5d98f6801..811017506a 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -437,7 +437,9 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetDerivedParams()) { // If parameter is heterogeneous, add scalar field if (std::invoke(isHeterogeneous, getGroup(), d.name)) { - addScalar(d, fieldSuffix, + addScalar(d.name, fieldSuffix, [d, getConnectivity](const auto &g, size_t) { - return std::invoke(getConnectivity, g).at(d.name).getDerivedParams()(); + return std::invoke(getConnectivity, g).getDerivedParams().at(d.name); }); } // Otherwise, just add a const-qualified scalar to the type environment diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 231af503ab..31635dd955 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -587,10 +587,10 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged { public: + using ChildGroupMerged::ChildGroupMerged; + //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- @@ -45,9 +47,11 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase // GeNN::CodeGenerator::NeuronInitGroupMerged::InSynPSM //---------------------------------------------------------------------------- //! Child group merged for incoming synapse groups - class InSynPSM : public GroupMerged + class InSynPSM : public ChildGroupMerged { public: + using ChildGroupMerged::ChildGroupMerged; + //---------------------------------------------------------------------------- // Public API //---------------------------------------------------------------------------- @@ -75,9 +79,11 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase // GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynPreOutput //---------------------------------------------------------------------------- //! 
Child group merged for outgoing synapse groups with $(addToPre) logic
-    class OutSynPreOutput : public GroupMerged
+    class OutSynPreOutput : public ChildGroupMerged
     {
     public:
+        using ChildGroupMerged::ChildGroupMerged;
+
         //----------------------------------------------------------------------------
         // Public API
         //----------------------------------------------------------------------------
@@ -89,9 +95,11 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase
     // GeNN::CodeGenerator::NeuronInitGroupMerged::InSynWUMPostCode
     //----------------------------------------------------------------------------
     //! Child group merged for incoming synapse groups with postsynaptic variables
-    class InSynWUMPostVars : public GroupMerged
+    class InSynWUMPostVars : public ChildGroupMerged
     {
     public:
+        using ChildGroupMerged::ChildGroupMerged;
+
         //----------------------------------------------------------------------------
         // Public API
         //----------------------------------------------------------------------------
@@ -119,9 +127,11 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase
     // GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynWUMPreVars
     //----------------------------------------------------------------------------
     //! Child group merged for outgoing synapse groups with presynaptic variables
-    class OutSynWUMPreVars: public GroupMerged
+    class OutSynWUMPreVars: public ChildGroupMerged
     {
     public:
+        using ChildGroupMerged::ChildGroupMerged;
+
         //----------------------------------------------------------------------------
         // Public API
         //----------------------------------------------------------------------------
@@ -171,6 +181,12 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase
     const std::vector &getMergedInSynWUMPostVarGroups() const { return m_MergedInSynWUMPostVarGroups; }
     const std::vector &getMergedOutSynWUMPreVarGroups() const { return m_MergedOutSynWUMPreVarGroups; }
 
+    //! Should the var init parameter be implemented heterogeneously?
+    bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const;
+
+    //! Should the var init derived parameter be implemented heterogeneously?
+    bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const;
+
     //----------------------------------------------------------------------------
     // Static constants
     //----------------------------------------------------------------------------
@@ -185,13 +201,7 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase
     void genInitSpikes(const BackendBase &backend, EnvironmentExternalBase &env, bool spikeEvent, unsigned int batchSize);
     void genInitSpikeTime(const BackendBase &backend, EnvironmentExternalBase &env, const std::string &varName, unsigned int batchSize);
-
-    //! Should the var init parameter be implemented heterogeneously?
-    bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const;
-
-    //! Should the var init derived parameter be implemented heterogeneously?
-    bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const;
-
+
     //------------------------------------------------------------------------
     // Members
     //------------------------------------------------------------------------
@@ -242,9 +252,9 @@ class GENN_EXPORT SynapseInitGroupMerged : public SynapseGroupMergedBase
 class GENN_EXPORT SynapseSparseInitGroupMerged : public SynapseGroupMergedBase
 {
 public:
-    SynapseSparseInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
+    SynapseSparseInitGroupMerged(size_t index, const Type::TypeContext &typeContext,
                                  const std::vector> &groups)
-    :   SynapseGroupMergedBase(index, typeContext, backend, SynapseGroupMergedBase::Role::SparseInit, "", groups)
+    :   SynapseGroupMergedBase(index, typeContext, SynapseGroupMergedBase::Role::SparseInit, "", groups)
     {}
 
     boost::uuids::detail::sha1::digest_type getHashDigest() const
@@ -275,9 +285,9 @@ class GENN_EXPORT SynapseSparseInitGroupMerged : public SynapseGroupMergedBase
 class GENN_EXPORT SynapseConnectivityInitGroupMerged : public SynapseGroupMergedBase
 {
 public:
-    SynapseConnectivityInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
+    SynapseConnectivityInitGroupMerged(size_t index, const Type::TypeContext &typeContext,
                                        const std::vector> &groups)
-    :   SynapseGroupMergedBase(index, typeContext, backend, SynapseGroupMergedBase::Role::ConnectivityInit, "", groups)
+    :   SynapseGroupMergedBase(index, typeContext, SynapseGroupMergedBase::Role::ConnectivityInit, "", groups)
     {}
 
     boost::uuids::detail::sha1::digest_type getHashDigest() const
@@ -296,7 +306,7 @@ class GENN_EXPORT SynapseConnectivityInitGroupMerged : public SynapseGroupMerged
     void generateSparseRowInit(const BackendBase &backend, EnvironmentExternalBase &env);
     void generateSparseColumnInit(const BackendBase &backend, EnvironmentExternalBase &env);
-    void generateKernelInit(const BackendBase &backend, EnvironmentExternalBase &env);
+    void generateKernelInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged);
 
     //----------------------------------------------------------------------------
     // Static constants
@@ -318,8 +328,7 @@ class GENN_EXPORT SynapseConnectivityInitGroupMerged : public SynapseGroupMerged
 class GENN_EXPORT SynapseConnectivityHostInitGroupMerged : public GroupMerged
 {
 public:
-    SynapseConnectivityHostInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
-                                           const std::vector> &groups);
+    using GroupMerged::GroupMerged;
 
     //------------------------------------------------------------------------
     // Public API
@@ -361,9 +370,9 @@ class GENN_EXPORT SynapseConnectivityHostInitGroupMerged : public GroupMerged
 class CustomUpdateInitGroupMergedBase : public GroupMerged
 {
-protected:
+public:
     //----------------------------------------------------------------------------
-    // Protected methods
+    // Public API
     //----------------------------------------------------------------------------
     //! Should the var init parameter be implemented heterogeneously?
     bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const
@@ -388,7 +397,10 @@ class CustomUpdateInitGroupMergedBase : public GroupMerged
                                 return archetypeAdaptor.getInitialisers().at(varName).getDerivedParams();
                             }));
     }
-
+protected:
+    //----------------------------------------------------------------------------
+    // Protected methods
+    //----------------------------------------------------------------------------
     void updateBaseHash(boost::uuids::detail::sha1 &hash) const
     {
         // Update hash with archetype's hash digest
@@ -520,8 +532,7 @@ class GENN_EXPORT CustomWUUpdateSparseInitGroupMerged : public CustomUpdateInitG
                                                                             CustomUpdateVarAdapter>
 {
 public:
-    CustomWUUpdateSparseInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
-                                        const std::vector> &groups);
+    using CustomUpdateInitGroupMergedBase::CustomUpdateInitGroupMergedBase;
 
     //----------------------------------------------------------------------------
     // Public API
diff --git a/include/genn/genn/currentSourceInternal.h b/include/genn/genn/currentSourceInternal.h
index ba6a7628a2..263d7c4548 100644
--- a/include/genn/genn/currentSourceInternal.h
+++ b/include/genn/genn/currentSourceInternal.h
@@ -49,6 +49,8 @@ class CurrentSourceVarAdapter
 
     const std::unordered_map &getInitialisers() const{ return m_CS.getVarInitialisers(); }
 
+    bool isVarDelayed(const std::string&) const{ return false; }
+
     const std::string &getNameSuffix() const{ return m_CS.getName(); }
 
 private:
diff --git a/include/genn/genn/customConnectivityUpdateInternal.h b/include/genn/genn/customConnectivityUpdateInternal.h
index 04a218cc9e..15febd6e23 100644
--- a/include/genn/genn/customConnectivityUpdateInternal.h
+++ b/include/genn/genn/customConnectivityUpdateInternal.h
@@ -84,6 +84,10 @@ class CustomConnectivityUpdatePreVarAdapter
 
     const std::unordered_map &getInitialisers() const{ return m_CU.getPreVarInitialisers(); }
 
+    bool isVarDelayed(const std::string &) const { return false; }
+
+    const std::string &getNameSuffix() const{ return m_CU.getName(); }
+
 private:
     //----------------------------------------------------------------------------
     // Members
@@ -109,6 +113,10 @@ class CustomConnectivityUpdatePostVarAdapter
 
     const std::unordered_map &getInitialisers() const{ return m_CU.getPostVarInitialisers(); }
 
+    bool isVarDelayed(const std::string &) const { return false; }
+
+    const std::string &getNameSuffix() const{ return m_CU.getName(); }
+
 private:
     //----------------------------------------------------------------------------
     // Members
diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h
index 37d5364bb1..84c1526729 100644
--- a/include/genn/genn/customUpdate.h
+++ b/include/genn/genn/customUpdate.h
@@ -185,6 +185,8 @@ class CustomUpdateVarAdapter
 
     const std::unordered_map &getInitialisers() const{ return m_CU.getVarInitialisers(); }
 
+    bool isVarDelayed(const std::string &) const { return false; }
+
     const std::string &getNameSuffix() const{ return m_CU.getName(); }
 
 private:
diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h
index cb14ea5250..868fd341a3 100644
--- a/include/genn/genn/type.h
+++ b/include/genn/genn/type.h
@@ -220,6 +220,7 @@ struct ResolvedType
     //------------------------------------------------------------------------
     bool isValue() const{ return std::holds_alternative(detail); }
     bool isPointer() const{ return std::holds_alternative(detail); }
+    bool isPointerToPointer() const{ return isPointer() && getPointer().valueType->isPointer(); }
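The new isPointerToPointer() query simply composes the existing variant-based checks. As a rough, standalone illustration of that sum-type pattern (a cut-down stand-in, not the real GeNN::Type::ResolvedType):

    #include <iostream>
    #include <memory>
    #include <variant>

    struct Ty;

    struct Value { bool numeric; };
    struct Pointer { std::shared_ptr<const Ty> valueType; };

    struct Ty
    {
        std::variant<Value, Pointer> detail;

        bool isValue() const { return std::holds_alternative<Value>(detail); }
        bool isPointer() const { return std::holds_alternative<Pointer>(detail); }
        const Pointer &getPointer() const { return std::get<Pointer>(detail); }

        // True for a pointer whose pointee is itself a pointer (e.g. uint32_t**)
        bool isPointerToPointer() const { return isPointer() && getPointer().valueType->isPointer(); }
    };

    int main()
    {
        auto uint32 = std::make_shared<const Ty>(Ty{Value{true}});
        auto ptr = std::make_shared<const Ty>(Ty{Pointer{uint32}});
        const Ty ptrPtr{Pointer{ptr}};
        std::cout << ptrPtr.isPointerToPointer() << std::endl;   // prints 1
        return 0;
    }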
bool isFunction() const{ return std::holds_alternative(detail); } bool isNumeric() const{ return isValue() && getValue().numeric; } bool isVoid() const{ return std::holds_alternative(detail); } diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 90cd075981..93408f32e4 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -97,8 +97,8 @@ void genKernelIteration(EnvironmentExternal &env, const G &g, size_t numKernelDi // Generate kernel index and use as "synapse" index // **TODO** rename assert(false); - //const size_t kernelInit = loopEnv.addInitialiser("const unsigned int kernelInd = " + g.genKernelIndex(loopEnv) + ";"); - //loopEnv.addVarSubstitution("id_syn", "kernelInd", kernelInit); + //const size_t addSynapse = loopEnv.addInitialiser("const unsigned int kernelInd = " + g.genKernelIndex(loopEnv) + ";"); + //loopEnv.addVarSubstitution("id_syn", "kernelInd", addSynapse); // Call handler handler(loopEnv); @@ -1046,21 +1046,21 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler CodeStream::Scope b(os); // Create new stream to generate addSynapse function which initializes all kernel variables - std::ostringstream kernelInitStream; - CodeStream kernelInit(kernelInitStream); + std::ostringstream addSynapseStream; + CodeStream addSynapse(addSynapseStream); // Use classic macro trick to turn block of initialization code into statement and 'eat' semicolon - kernelInit << "do"; + addSynapse << "do"; { - CodeStream::Scope b(kernelInit); + CodeStream::Scope b(addSynapse); // Calculate index in data structure of this synapse if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { if(!snippet->getRowBuildCode().empty()) { - kernelInit << "const unsigned int idx = " << "(" + popSubs["id_pre"] + " * group->rowStride) + group->rowLength[i];" << std::endl; + addSynapse << "const unsigned int idx = " << "(" + popSubs["id_pre"] + " * group->rowStride) + group->rowLength[i];" << std::endl; } else { - kernelInit << "const unsigned int idx = " << "(($(0)) * group->rowStride) + group->rowLength[$(0)];" << std::endl; + addSynapse << "const unsigned int idx = " << "(($(0)) * group->rowStride) + group->rowLength[$(0)];" << std::endl; } } @@ -1088,39 +1088,39 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler } // Call handler to initialize variables - s.generateKernelInit(*this, kernelInit, modelMerged, kernelInitSubs); + s.generateKernelInit(*this, addSynapse, modelMerged, kernelInitSubs); } // If there is row-building code in this snippet if(!snippet->getRowBuildCode().empty()) { // If matrix is sparse, add function to increment row length and insert synapse into ind array if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - kernelInit << "group->ind[idx] = $(0);" << std::endl; - kernelInit << "group->rowLength[i]++;" << std::endl; + addSynapse << "group->ind[idx] = $(0);" << std::endl; + addSynapse << "group->rowLength[i]++;" << std::endl; } // Otherwise, add function to set correct bit in bitmask else { - kernelInit << "const int64_t rowStartGID = i * group->rowStride;" << std::endl; - kernelInit << "setB(group->gp[(rowStartGID + ($(0))) / 32], (rowStartGID + $(0)) & 31);" << std::endl; + addSynapse << "const int64_t rowStartGID = i * group->rowStride;" << std::endl; + addSynapse << "setB(group->gp[(rowStartGID + ($(0))) / 32], (rowStartGID + $(0)) & 31);" << std::endl; } } 
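// [Editor's sketch] The addSynapse stream built above is wrapped in the
// classic do { ... } while(false) idiom: it turns a multi-statement block of
// generated code into a single statement that still "eats" a trailing
// semicolon, so the $(addSynapse, ...) substitution expands safely inside
// braceless if/else branches. A standalone illustration; the ADD_SYNAPSE
// macro and variables here are hypothetical, not GeNN code:
#include <cstdio>

#define ADD_SYNAPSE(post)      \
    do {                       \
        ind[idx] = (post);     \
        rowLength[i]++;        \
    } while(false)

int main()
{
    unsigned int rowLength[1] = {0};
    unsigned int ind[4] = {};
    const unsigned int i = 0, idx = 0;

    // Without the do/while wrapper, the two statements in the macro body
    // would break this braceless if/else.
    if(rowLength[i] < 4)
        ADD_SYNAPSE(42);
    else
        std::printf("row full\n");

    std::printf("rowLength=%u ind[0]=%u\n", rowLength[0], ind[0]);
    return 0;
}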
             // Otherwise
             else {
                 // If matrix is sparse, add function to increment row length and insert synapse into ind array
                 if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
-                    kernelInit << "group->ind[idx] = " << popSubs["id_post"] << ";" << std::endl;
-                    kernelInit << "group->rowLength[$(0)]++;" << std::endl;
+                    addSynapse << "group->ind[idx] = " << popSubs["id_post"] << ";" << std::endl;
+                    addSynapse << "group->rowLength[$(0)]++;" << std::endl;
                 }
                 else {
-                    kernelInit << "const int64_t colStartGID = j;" << std::endl;
-                    kernelInit << "setB(group->gp[(colStartGID + (($(0)) * group->rowStride)) / 32], ((colStartGID + (($(0)) * group->rowStride)) & 31));" << std::endl;
+                    addSynapse << "const int64_t colStartGID = j;" << std::endl;
+                    addSynapse << "setB(group->gp[(colStartGID + (($(0)) * group->rowStride)) / 32], ((colStartGID + (($(0)) * group->rowStride)) & 31));" << std::endl;
                 }
             }
         }
-        kernelInit << "while(false)";
+        addSynapse << "while(false)";

         popSubs.addFuncSubstitution("addSynapse", 1 + (unsigned int)s.getArchetype().getKernelSize().size(),
-                                    kernelInitStream.str());
+                                    addSynapseStream.str());

         // Call appropriate connectivity handler
         if(!snippet->getRowBuildCode().empty()) {
diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc
index e748ba2f8e..0c42754cde 100644
--- a/src/genn/genn/code_generator/groupMerged.cc
+++ b/src/genn/genn/code_generator/groupMerged.cc
@@ -130,13 +130,13 @@ bool SynapseGroupMergedBase::isWUGlobalVarHeterogeneous(const std::string &varNa
             isParamValueHeterogeneous(varName, [](const SynapseGroupInternal &sg) { return sg.getWUConstInitVals(); }));
 }
 //----------------------------------------------------------------------------
-bool SynapseGroupMergedBase::isWUVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const
+bool SynapseGroupMergedBase::isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const
 {
     return (isWUVarInitParamReferenced(varName, paramName) &&
             isParamValueHeterogeneous(paramName, [varName](const SynapseGroupInternal &sg){ return sg.getWUVarInitialisers().at(varName).getParams(); }));
 }
 //----------------------------------------------------------------------------
-bool SynapseGroupMergedBase::isWUVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const
+bool SynapseGroupMergedBase::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const
 {
     return (isWUVarInitParamReferenced(varName, paramName) &&
             isParamValueHeterogeneous(paramName, [varName](const SynapseGroupInternal &sg) { return sg.getWUVarInitialisers().at(varName).getDerivedParams(); }));
diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc
index b8f1cc6361..2004864a76 100644
--- a/src/genn/genn/code_generator/initGroupMerged.cc
+++ b/src/genn/genn/code_generator/initGroupMerged.cc
@@ -60,12 +60,12 @@ void genScalarFill(EnvironmentExternalBase &env, const std::string &target, cons
     }
 }
 //------------------------------------------------------------------------
-template
+template
 void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &env,
                           G &group, F &fieldGroup, const std::string &fieldSuffix, const std::string &count,
                           size_t numDelaySlots, unsigned int batchSize)
 {
-    A adaptor(groupMerged.getArchetype());
+    A adaptor(group.getArchetype());
     for (const auto &var : adaptor.getDefs()) {
         // If there is any initialisation
code const auto resolvedType = var.type.resolve(group.getTypeContext()); @@ -78,7 +78,7 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e EnvironmentGroupMergedField varEnv(env, group, fieldGroup); varEnv.addVarInitParams(&G::isVarInitParamHeterogeneous, fieldSuffix); varEnv.addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, fieldSuffix); - varEnv.addExtraGlobalParams(snippet->getExtraGlobalParameters(), backend.getDeviceVarPrefix(), var.name, fieldSuffix); + varEnv.addExtraGlobalParams(snippet->getExtraGlobalParams(), backend.getDeviceVarPrefix(), var.name, fieldSuffix); // Add field for variable itself varEnv.addField(resolvedType.createPointer(), "_value", var.name + fieldSuffix, @@ -91,10 +91,11 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e if (getVarAccessDuplication(var.access) == VarAccessDuplication::SHARED_NEURON) { backend.genPopVariableInit( varEnv, - [&adaptor, &fieldSuffix, &group, &resolvedType, &var, batchSize, numDelaySlots, snippet] - (EnvironmentExternalBase &varInitEnv) + [&adaptor, &fieldGroup, &fieldSuffix, &group, &resolvedType, &var, batchSize, numDelaySlots, snippet] + (EnvironmentExternalBase &env) { // Generate initial value into temporary variable + EnvironmentGroupMergedField varInitEnv(env, group, fieldGroup); varInitEnv.getStream() << resolvedType.getName() << " initVal;" << std::endl; varInitEnv.add(resolvedType, "value", "initVal"); @@ -103,18 +104,19 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e prettyPrintStatements(snippet->getCode(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all delay slots and batches - genScalarFill(varInitEnv, "_value", "initVal", getVarAccessDuplication(var.access), + genScalarFill(varInitEnv, "value", "initVal", getVarAccessDuplication(var.access), batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); }); } // Otherwise else { backend.genVariableInit( - varEnvs, count, "id", - [&adaptor, &fieldSuffix, &group, &var, &resolvedType, batchSize, count, numDelaySlots] - (EnvironmentExternalBase &varInitEnv) + varEnv, count, "id", + [&adaptor, &fieldGroup, &fieldSuffix, &group, &var, &resolvedType, batchSize, count, numDelaySlots, snippet] + (EnvironmentExternalBase &env) { // Generate initial value into temporary variable + EnvironmentGroupMergedField varInitEnv(env, group, fieldGroup); varInitEnv.getStream() << resolvedType.getName() << " initVal;" << std::endl; varInitEnv.add(resolvedType, "value", "initVal"); @@ -123,7 +125,7 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e prettyPrintStatements(snippet->getCode(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all delay slots and batches - genVariableFill(varInitEnv(), "_value", "initVal", "id", count, + genVariableFill(varInitEnv, "value", "initVal", "id", count, getVarAccessDuplication(var.access), batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); }); } @@ -137,32 +139,32 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e G &group, const std::string &fieldSuffix, const std::string &count, size_t numDelaySlots, unsigned int batchSize) { - genInitNeuronVarCode(backend, env, group, group, fieldSuffix, count, numDelaySlots, batchSize); + genInitNeuronVarCode(backend, env, group, group, fieldSuffix, count, numDelaySlots, batchSize); } //------------------------------------------------------------------------ // Initialise one row of 
weight update model variables template -void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, G &groupMerged, +void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, G &group, const std::string &stride, unsigned int batchSize, V genSynapseVariableRowInitFn) { - A adaptor(groupMerged.getArchetype()); - for (const auto &var : vars) { + A adaptor(group.getArchetype()); + for (const auto &var : adaptor.getDefs()) { // If this variable has any initialisation code and doesn't require a kernel - const auto resolvedType = var.type.resolve(groupMerged.getTypeContext()); + const auto resolvedType = var.type.resolve(group.getTypeContext()); const auto &varInit = adaptor.getInitialisers().at(var.name); const auto *snippet = varInit.getSnippet(); if(!snippet->getCode().empty() && !snippet->requiresKernel()) { CodeStream::Scope b(env.getStream()); // Substitute in parameters and derived parameters for initialising variables - EnvironmentGroupMergedField varEnv(env, groupMerged); - varEnv.addVarInitParams(&G::isVarInitParamHeterogeneous, fieldSuffix); - varEnv.addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, fieldSuffix); - varEnv.addExtraGlobalParams(snippet->getExtraGlobalParameters(), backend.getDeviceVarPrefix(), var.name, fieldSuffix); + EnvironmentGroupMergedField varEnv(env, group); + varEnv.addVarInitParams(&G::isVarInitParamHeterogeneous); + varEnv.addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous); + varEnv.addExtraGlobalParams(snippet->getExtraGlobalParams(), backend.getDeviceVarPrefix(), var.name); // Add field for variable itself - varEnv.addField(resolvedType.createPointer(), "_value", var.name + fieldSuffix, + varEnv.addField(resolvedType.createPointer(), "_value", var.name, [&backend, var](const auto &g, size_t) { return backend.getDeviceVarPrefix() + var.name + g.getName(); @@ -170,19 +172,20 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, // Generate target-specific code to initialise variable genSynapseVariableRowInitFn(varEnv, - [&group, &modelMerged, &resolvedType, &stride, &var, batchSize, snippet] - (EnvironmentExternalBase &varInitEnv) + [&group, &resolvedType, &stride, &var, batchSize, snippet] + (EnvironmentExternalBase &env) { // Generate initial value into temporary variable + EnvironmentGroupMergedField varInitEnv(env, group); varInitEnv.getStream() << resolvedType.getName() << " initVal;" << std::endl; varInitEnv.add(resolvedType, "value", "initVal"); // Pretty print variable initialisation code - Transpiler::ErrorHandler errorHandler("Variable '" + var.name + "' init code" + std::to_string(groupMerged.getIndex())); - prettyPrintStatements(snippet->getCode(), groupMerged.getTypeContext(), varInitEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Variable '" + var.name + "' init code" + std::to_string(group.getIndex())); + prettyPrintStatements(snippet->getCode(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all batches - genVariableFill(varInitEnv(), "_value", "initVal", "id_syn", stride, + genVariableFill(varInitEnv, "value", "initVal", "id_syn", stride, getVarAccessDuplication(var.access), batchSize); }); } @@ -196,7 +199,7 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, void NeuronInitGroupMerged::CurrentSource::generate(const BackendBase &backend, EnvironmentExternalBase &env, NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) { - genInitNeuronVarCode( + genInitNeuronVarCode( 
         backend, env, *this, ng, "CS" + std::to_string(getIndex()), "num_neurons", 0, modelMerged.getModel().getBatchSize());
 }
@@ -272,7 +275,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir
         });
     }

-    genInitNeuronVarCode(
+    genInitNeuronVarCode(
         backend, groupEnv, *this, ng, fieldSuffix, "num_neurons", 0, modelMerged.getModel().getBatchSize());
 }
 //----------------------------------------------------------------------------
@@ -329,7 +332,7 @@ void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend
 void NeuronInitGroupMerged::InSynWUMPostVars::generate(const BackendBase &backend, EnvironmentExternalBase &env,
                                                        NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged)
 {
-    genInitNeuronVarCode(
+    genInitNeuronVarCode(
         backend, env, *this, ng, "InSynWUMPost" + std::to_string(getIndex()), "num_neurons", 0, modelMerged.getModel().getBatchSize());
 }
 //----------------------------------------------------------------------------
@@ -363,7 +366,7 @@ bool NeuronInitGroupMerged::InSynWUMPostVars::isVarInitParamReferenced(const std
 void NeuronInitGroupMerged::OutSynWUMPreVars::generate(const BackendBase &backend, EnvironmentExternalBase &env,
                                                        NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged)
 {
-    genInitNeuronVarCode(
+    genInitNeuronVarCode(
         backend, env, *this, ng, "OutSynWUMPre" + std::to_string(getIndex()), "num_neurons", 0, modelMerged.getModel().getBatchSize());
 }
 //----------------------------------------------------------------------------
@@ -507,8 +510,7 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, Environment
     }

     // Initialise neuron variables
-    genInitNeuronVarCode(
-        backend, env, *this, "", "num_neurons", 0, modelMerged.getModel().getBatchSize());
+    genInitNeuronVarCode(backend, env, *this, "", "num_neurons", 0, modelMerged.getModel().getBatchSize());

     // Generate initialisation code for child groups
     for (auto &cs : m_MergedCurrentSourceGroups) {
@@ -527,6 +529,18 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, Environment
         sg.generate(backend, env, *this, modelMerged);
     }
 }
+//----------------------------------------------------------------------------
+bool NeuronInitGroupMerged::isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const
+{
+    return (isParamValueHeterogeneous(paramName,
+                                      [varName](const NeuronGroupInternal &sg) { return sg.getVarInitialisers().at(varName).getParams(); }));
+}
+//----------------------------------------------------------------------------
+bool NeuronInitGroupMerged::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const
+{
+    return (isParamValueHeterogeneous(paramName,
+                                      [varName](const NeuronGroupInternal &sg){ return sg.getVarInitialisers().at(varName).getDerivedParams(); }));
+}
 //--------------------------------------------------------------------------
 void NeuronInitGroupMerged::genInitSpikeCount(const BackendBase &backend, EnvironmentExternalBase &env,
                                               bool spikeEvent, unsigned int batchSize)
@@ -603,18 +617,6 @@ void NeuronInitGroupMerged::genInitSpikeTime(const BackendBase &backend, Environ
                          batchSize, getArchetype().isDelayRequired(), getArchetype().getNumDelaySlots());
     });
 }
-//----------------------------------------------------------------------------
-bool NeuronInitGroupMerged::isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const
-{
-    return (isParamValueHeterogeneous(paramName,
-                                      [varName](const NeuronGroupInternal &sg) { return sg.getVarInitialisers().at(varName).getParams(); }));
-}
-//----------------------------------------------------------------------------
-bool NeuronInitGroupMerged::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const
-{
-    return (isParamValueHeterogeneous(paramName,
-                                      [varName](const NeuronGroupInternal &sg){ return sg.getVarInitialisers().at(varName).getDerivedParams(); }));
-}

 //----------------------------------------------------------------------------
 // GeNN::CodeGenerator::SynapseInitGroupMerged
@@ -702,7 +704,7 @@ void SynapseConnectivityInitGroupMerged::generateSparseColumnInit(const BackendB
     genInitConnectivity(backend, env, false);
 }
 //----------------------------------------------------------------------------
-void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase &backend, EnvironmentExternalBase &env)
+void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged)
 {
     // Create environment for group
     EnvironmentGroupMergedField groupEnv(env, *this);
@@ -712,36 +714,12 @@ void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase &b
     groupEnv.add(Type::Uint32, "id_kernel", "kernelInd",
                  {groupEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(groupEnv) + ";")});

-    for(const auto &var : getArchetype().getWUModel()->getVars()) {
-        const auto &varInit = getArchetype().getWUVarInitialisers().at(var.name);
-
-        // If this variable require a kernel
-        if(varInit.getSnippet()->requiresKernel()) {
-            CodeStream::Scope b(os);
-
-            popSubs.addParamValueSubstitution(varInit.getSnippet()->getParamNames(), varInit.getParams(),
-                                              [&var, this](const std::string &p) { return isWUVarInitParamHeterogeneous(var.name, p); },
-                                              "", "group->", var.name);
-            popSubs.addVarValueSubstitution(varInit.getSnippet()->getDerivedParams(), varInit.getDerivedParams(),
-                                            [&var, this](const std::string &p) { return isWUVarInitDerivedParamHeterogeneous(var.name, p); },
-                                            "", "group->", var.name);
-            popSubs.addVarNameSubstitution(varInit.getSnippet()->getExtraGlobalParams(),
-                                           "", "group->", var.name);
-
-            // Generate initial value into temporary variable
-            os << var.type.resolve(getTypeContext()).getName() << " initVal;" << std::endl;
-            popSubs.addVarSubstitution("value", "initVal");
-            std::string code = varInit.getSnippet()->getCode();
-            //popSubs.applyCheckUnreplaced(code, "initVar : merged" + vars[k].name + std::to_string(sg.getIndex()));
-            popSubs.apply(code);
-            //code = ensureFtype(code, modelMerged.getModel().getPrecision());
-            os << code << std::endl;
-
-            // Fill value across all batches
-            genVariableFill(os, var.name, "initVal", popSubs["id_syn"], "group->numSrcNeurons * group->rowStride",
-                            getVarAccessDuplication(var.access), modelMerged.getModel().getBatchSize());
-        }
-    }
+    // Initialise single (hence empty lambda function) synapse variable
+    genInitWUVarCode(backend, groupEnv, *this, groupEnv["num_pre"] + " * " + groupEnv["_row_stride"], modelMerged.getModel().getBatchSize(),
+                     [](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler)
+                     {
+                         handler(varInitEnv);
+                     });
 }
 //----------------------------------------------------------------------------
 void SynapseConnectivityInitGroupMerged::genInitConnectivity(const BackendBase &backend, EnvironmentExternalBase &env, bool rowNotColumns)
@@ -791,58 +769,6 @@ void SynapseConnectivityInitGroupMerged::genInitConnectivity(const BackendBase &
 //
CodeGenerator::SynapseConnectivityHostInitGroupMerged //---------------------------------------------------------------------------- const std::string SynapseConnectivityHostInitGroupMerged::name = "SynapseConnectivityHostInit"; -//------------------------------------------------------------------------ -SynapseConnectivityHostInitGroupMerged::SynapseConnectivityHostInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) -: GroupMerged(index, typeContext, groups) -{ - - // **TODO** these could be generic - addField(Uint32, "numSrcNeurons", - [](const auto &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); - addField(Uint32, "numTrgNeurons", - [](const auto &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); }); - addField(Uint32, "rowStride", - [&backend](const auto &sg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(sg)); }); - - // Add heterogeneous connectivity initialiser model parameters - addHeterogeneousParams( - getArchetype().getConnectivityInitialiser().getSnippet()->getParamNames(), "", - [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getParams(); }, - &SynapseConnectivityHostInitGroupMerged::isConnectivityInitParamHeterogeneous); - - // Add heterogeneous connectivity initialiser derived parameters - addHeterogeneousDerivedParams( - getArchetype().getConnectivityInitialiser().getSnippet()->getDerivedParams(), "", - [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getDerivedParams(); }, - &SynapseConnectivityHostInitGroupMerged::isConnectivityInitDerivedParamHeterogeneous); - - // Add EGP pointers to struct for both host and device EGPs if they are seperate - const auto egps = getArchetype().getConnectivityInitialiser().getSnippet()->getExtraGlobalParams(); - for(const auto &e : egps) { - const auto &pointerToPointerToEGP = e.type.resolve(getTypeContext()).createPointer().createPointer(); - addField(pointerToPointerToEGP, e.name, - [e](const SynapseGroupInternal &g, size_t) { return "&" + e.name + g.getName(); }, - GroupMergedFieldType::HOST_DYNAMIC); - - if(!backend.getDeviceVarPrefix().empty()) { - addField(pointerToPointerToEGP, backend.getDeviceVarPrefix() + e.name, - [e, &backend](const SynapseGroupInternal &g, size_t) - { - return "&" + backend.getDeviceVarPrefix() + e.name + g.getName(); - }, - GroupMergedFieldType::DYNAMIC); - } - if(!backend.getHostVarPrefix().empty()) { - addField(pointerToPointerToEGP, backend.getHostVarPrefix() + e.name, - [e, &backend](const SynapseGroupInternal &g, size_t) - { - return "&" + backend.getHostVarPrefix() + e.name + g.getName(); - }, - GroupMergedFieldType::DYNAMIC); - } - } -} //------------------------------------------------------------------------- void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { @@ -876,19 +802,69 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac groupEnv.addConnectInitParams("", &SynapseGroupInternal::getConnectivityInitialiser, &SynapseConnectivityHostInitGroupMerged::isConnectivityInitParamHeterogeneous); groupEnv.addConnectInitDerivedParams("", &SynapseGroupInternal::getConnectivityInitialiser, - &SynapseConnectivityHostInitGroupMerged::isSparseConnectivityInitDerivedParamHeterogeneous); + 
                                           &SynapseConnectivityHostInitGroupMerged::isConnectivityInitDerivedParamHeterogeneous);

-    //subs.addVarNameSubstitution(connectInit.getSnippet()->getExtraGlobalParams(), "", "*group->");
-
+    /*const auto &pointerToPointerToEGP = e.type.resolve(getTypeContext()).createPointer().createPointer();
+    addField(pointerToPointerToEGP, e.name,
+             [e](const SynapseGroupInternal &g, size_t) { return "&" + e.name + g.getName(); },
+             GroupMergedFieldType::HOST_DYNAMIC);
+
+    if(!backend.getDeviceVarPrefix().empty()) {
+        addField(pointerToPointerToEGP, backend.getDeviceVarPrefix() + e.name,
+                 [e, &backend](const SynapseGroupInternal &g, size_t)
+                 {
+                     return "&" + backend.getDeviceVarPrefix() + e.name + g.getName();
+                 },
+                 GroupMergedFieldType::DYNAMIC);
+    }
+    if(!backend.getHostVarPrefix().empty()) {
+        addField(pointerToPointerToEGP, backend.getHostVarPrefix() + e.name,
+                 [e, &backend](const SynapseGroupInternal &g, size_t)
+                 {
+                     return "&" + backend.getHostVarPrefix() + e.name + g.getName();
+                 },
+                 GroupMergedFieldType::DYNAMIC);
+    }*/
     // Loop through EGPs
     for(const auto &egp : connectInit.getSnippet()->getExtraGlobalParams()) {
         // If EGP is located on the host
         const auto loc = getArchetype().getSparseConnectivityExtraGlobalParamLocation(egp.name);
         if(loc & VarLocation::HOST) {
+            const auto resolvedType = egp.type.resolve(getTypeContext());
+            assert(!resolvedType.isPointer());
+            const auto pointerType = resolvedType.createPointer();
+            const auto pointerToPointerType = pointerType.createPointer();
+
+            // Add field for host pointer
+            // **NOTE** use [0] to dereference on access to obta
+            groupEnv.addField(pointerType, egp.name,
+                              pointerToPointerType, egp.name,
+                              [egp](const auto &g, size_t) { return "&" + egp.name + g.getName(); },
+                              "[0]", GroupMergedFieldType::HOST_DYNAMIC);
+
+            // If backend requires separate device variables, add additional (private) field
+            if(!backend.getDeviceVarPrefix().empty()) {
+                groupEnv.addField(pointerToPointerType, "_" + backend.getDeviceVarPrefix() + egp.name,
+                                  backend.getDeviceVarPrefix() + egp.name,
+                                  [egp](const auto &g, size_t) { return "&" + egp.name + g.getName(); },
+                                  "", GroupMergedFieldType::DYNAMIC);
+            }
+
+            // If backend requires separate host variables, add additional (private) field
+            if(!backend.getHostVarPrefix().empty()) {
+                groupEnv.addField(pointerToPointerType, "_" + backend.getHostVarPrefix() + egp.name,
+                                  backend.getHostVarPrefix() + egp.name,
+                                  [egp, &backend](const SynapseGroupInternal &g, size_t)
+                                  {
+                                      return "&" + backend.getHostVarPrefix() + egp.name + g.getName();
+                                  },
+                                  "", GroupMergedFieldType::DYNAMIC);
+            }
+
             // Generate code to allocate this EGP with count specified by $(0)
             // **NOTE** we generate these with a pointer type as the fields are pointer to pointer
             std::stringstream allocStream;
-            const auto &pointerToEGP = egp.type.resolve(getTypeContext()).createPointer();
+            const auto &pointerToEGP = resolvedType.createPointer();
             CodeGenerator::CodeStream alloc(allocStream);
             backend.genVariableDynamicAllocation(alloc,
                                                  pointerToEGP, egp.name,
@@ -954,8 +930,8 @@ boost::uuids::detail::sha1::digest_type CustomUpdateInitGroupMerged::getHashDige
 void CustomUpdateInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged)
 {
     // Initialise custom update variables
-    genInitNeuronVarCode(backend, env, *this, "", "size", 1,
-                         getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1);
+    genInitNeuronVarCode(backend, env, *this, "", "size", 1,
                          getArchetype().isBatched() ?
modelMerged.getModel().getBatchSize() : 1); } // ---------------------------------------------------------------------------- @@ -996,68 +972,71 @@ boost::uuids::detail::sha1::digest_type CustomWUUpdateInitGroupMerged::getHashDi return hash.get_digest(); } // ---------------------------------------------------------------------------- -void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { - /*if(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL) { + EnvironmentGroupMergedField groupEnv(env, *this); + + const bool kernel = (getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL); + if(kernel) { // Loop through kernel size dimensions for (size_t d = 0; d < getArchetype().getSynapseGroup()->getKernelSize().size(); d++) { // If this dimension has a heterogeneous size, add it to struct if (isKernelSizeHeterogeneous(d)) { - addField(Uint32, "kernelSize" + std::to_string(d), - [d](const auto &g, size_t) { return std::to_string(g.getSynapseGroup()->getKernelSize().at(d)); }); + groupEnv.addField(Type::Uint32, "_kernel_size_" + std::to_string(d), "kernelSize" + std::to_string(d), + [d](const auto &g, size_t) { return std::to_string(g.getSynapseGroup()->getKernelSize().at(d)); }); } } - } - else { - addField(Uint32, "rowStride", - [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); - addField(Uint32, "numSrcNeurons", - [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); - addField(Uint32, "numTrgNeurons", - [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); - }*/ - const bool kernel = (getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL); - if(kernel && modelMerged.getModel().getBatchSize() > 1) { - // Loop through kernel dimensions and multiply together to calculate batch stride - os << "const unsigned int batchStride = "; - const auto &kernelSize = getArchetype().getSynapseGroup()->getKernelSize(); - for (size_t i = 0; i < kernelSize.size(); i++) { - os << getKernelSize(i); - - if (i != (kernelSize.size() - 1)) { - os << " * "; + if(modelMerged.getModel().getBatchSize() > 1) { + // Loop through kernel dimensions and multiply together to calculate batch stride + std::ostringstream batchStrideInit; + batchStrideInit << "const unsigned int batchStride = "; + const auto &kernelSize = getArchetype().getSynapseGroup()->getKernelSize(); + for (size_t i = 0; i < kernelSize.size(); i++) { + batchStrideInit << getKernelSize(i); + + if (i != (kernelSize.size() - 1)) { + batchStrideInit << " * "; + } } + batchStrideInit << ";" << std::endl; + groupEnv.add(Type::Uint32.addConst(), "_batch_stride", "batchStride", + {groupEnv.addInitialiser(batchStrideInit.str())}); } - os << ";" << std::endl; } - - if(!kernel) { - os << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; - os << CodeStream::OB(3); - popSubs.addVarSubstitution("id_pre", "i"); + else { + groupEnv.addField(Type::Uint32.addConst(), "num_pre", + Type::Uint32, "numSrcNeurons", + [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); + 
groupEnv.addField(Type::Uint32.addConst(), "num_post", + Type::Uint32, "numTrgNeurons", + [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); + groupEnv.addField(Type::Uint32, "_row_stride", "rowStride", + [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); + + + groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["num_pre"] << "; i++)"; + groupEnv.getStream() << CodeStream::OB(3); + groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); } // Loop through rows - const std::string stride = kernel ? "batchStride" : "group->numSrcNeurons * group->rowStride"; - genInitWUVarCode(os, modelMerged, popSubs, getArchetype().getCustomUpdateModel()->getVars(), - getArchetype().getVarInitialisers(), stride, getIndex(), - getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1, - [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, - [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }, - [&backend, kernel, this](CodeStream &os, const Substitutions &kernelSubs, BackendBase::Handler handler) - { - if (kernel) { - backend.genKernelCustomUpdateVariableInit(os, *this, kernelSubs, handler); - } - else { - backend.genDenseSynapseVariableRowInit(os, kernelSubs, handler); - } + const std::string stride = kernel ? groupEnv["_batch_stride"] : groupEnv["num_pre"] + " * " + groupEnv["_row_stride"]; + genInitWUVarCode( + backend, groupEnv, *this, stride, getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1, + [&backend, kernel, this](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) + { + if (kernel) { + backend.genKernelCustomUpdateVariableInit(varInitEnv, *this, handler); + } + else { + backend.genDenseSynapseVariableRowInit(varInitEnv, handler); + } - }); - + }); + if(!kernel) { - os << CodeStream::CB(3); + groupEnv.getStream() << CodeStream::CB(3); } } @@ -1092,8 +1071,11 @@ boost::uuids::detail::sha1::digest_type CustomWUUpdateSparseInitGroupMerged::get return hash.get_digest(); } // ---------------------------------------------------------------------------- -void CustomWUUpdateSparseInitGroupMerged::generateInit(const BackendBase &backend, CodeStream &os, const ModelSpecMerged &modelMerged, Substitutions &popSubs) const +void CustomWUUpdateSparseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { + // Create environment for group + EnvironmentGroupMergedField groupEnv(env, *this); + /* addField(Uint32, "rowStride", [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); @@ -1115,15 +1097,12 @@ void CustomWUUpdateSparseInitGroupMerged::generateInit(const BackendBase &backen return backend.getDeviceVarPrefix() + "ind" + sg->getName(); });*/ - genInitWUVarCode(os, modelMerged, popSubs, getArchetype().getCustomUpdateModel()->getVars(), - getArchetype().getVarInitialisers(), "group->numSrcNeurons * group->rowStride", getIndex(), - getArchetype().isBatched() ? 
modelMerged.getModel().getBatchSize() : 1, - [this](const std::string &v, const std::string &p) { return isVarInitParamHeterogeneous(v, p); }, - [this](const std::string &v, const std::string &p) { return isVarInitDerivedParamHeterogeneous(v, p); }, - [&backend](CodeStream &os, const Substitutions &kernelSubs, BackendBase::Handler handler) - { - return backend.genSparseSynapseVariableRowInit(os, kernelSubs, handler); - }); + genInitWUVarCode(backend, groupEnv, *this, groupEnv["num_pre"] + " * " + groupEnv["_row_stride"], + getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1, + [&backend](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) + { + return backend.genSparseSynapseVariableRowInit(varInitEnv, handler); + }); } // ---------------------------------------------------------------------------- @@ -1160,8 +1139,7 @@ void CustomConnectivityUpdatePreInitGroupMerged::generateInit(const BackendBase }); // Initialise presynaptic custom connectivity update variables - genInitNeuronVarCode( - backend, groupEnv, *this, "", "size", 0, modelMerged.getModel().getBatchSize()); + genInitNeuronVarCode(backend, groupEnv, *this, "", "size", 0, modelMerged.getModel().getBatchSize()); } // ---------------------------------------------------------------------------- @@ -1188,7 +1166,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePostInitGroupMer void CustomConnectivityUpdatePostInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { // Create environment for group - EnvironmentGroupMergedField groupEnv(env, *this); + EnvironmentGroupMergedField groupEnv(env, *this); groupEnv.addField(Type::Uint32.addConst(), "size", Type::Uint32, "size", @@ -1198,8 +1176,7 @@ void CustomConnectivityUpdatePostInitGroupMerged::generateInit(const BackendBase }); // Initialise presynaptic custom connectivity update variables - genInitNeuronVarCode( - backend, groupEnv, *this, "", "size", 0, modelMerged.getModel().getBatchSize()); + genInitNeuronVarCode(backend, groupEnv, *this, "", "size", 0, modelMerged.getModel().getBatchSize()); } // ---------------------------------------------------------------------------- @@ -1240,10 +1217,10 @@ void CustomConnectivityUpdateSparseInitGroupMerged::generateInit(const BackendBa groupEnv.addField(Type::Uint32.addConst(), "num_pre", Type::Uint32, "numSrcNeurons", - [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup().getSrcNeuronGroup()->getNumNeurons()); }); + [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); groupEnv.addField(Type::Uint32.addConst(), "num_post", Type::Uint32, "numTrgNeurons", - [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup().getTrgNeuronGroup()->getNumNeurons()); }); + [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); groupEnv.addField(Type::Uint32, "_row_stride", "rowStride", [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); @@ -1259,8 +1236,8 @@ void CustomConnectivityUpdateSparseInitGroupMerged::generateInit(const BackendBa }); // Initialise custom connectivity update variables - genInitWUVarCode( - backend, groupEnv, *this, "group->numSrcNeurons * group->rowStride", 1, + genInitWUVarCode( + backend, groupEnv, *this, groupEnv["num_pre"] + " * " + groupEnv["_row_stride"], 1, 
[&backend](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { return backend.genSparseSynapseVariableRowInit(varInitEnv, handler); From c9880fece5eae3dd764322b52d95d21db46afb0f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 20 Jun 2023 15:34:35 +0100 Subject: [PATCH 239/725] boilerplate purge --- .../genn/genn/code_generator/codeGenUtils.h | 29 +- .../code_generator/customUpdateGroupMerged.h | 24 -- .../genn/genn/code_generator/groupMerged.h | 27 +- .../genn/code_generator/initGroupMerged.h | 323 ++++++------------ include/genn/genn/customUpdate.h | 2 + include/genn/genn/customUpdateInternal.h | 1 + .../genn/code_generator/initGroupMerged.cc | 169 ++------- 7 files changed, 161 insertions(+), 414 deletions(-) diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 2dba711f78..357d8e92fc 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -182,46 +182,45 @@ void neuronSubstitutionsInSynapticCode(CodeGenerator::Substitutions &substitutio substitutions.addVarNameSubstitution(nm->getExtraGlobalParams(), sourceSuffix, "group->", destSuffix); } -template -bool isKernelSizeHeterogeneous(const G *group, size_t dimensionIndex, K getKernelSizeFn) +template +bool isKernelSizeHeterogeneous(const G &group, size_t dimensionIndex) { // Get size of this kernel dimension for archetype - const unsigned archetypeValue = getKernelSizeFn(group->getArchetype()).at(dimensionIndex); + const unsigned archetypeValue = group.getArchetype().getKernelSize().at(dimensionIndex); // Return true if any of the other groups have a different value - return std::any_of(group->getGroups().cbegin(), group->getGroups().cend(), - [archetypeValue, dimensionIndex, getKernelSizeFn] + return std::any_of(group.getGroups().cbegin(), group.getGroups().cend(), + [archetypeValue, dimensionIndex] (const typename G::GroupInternal& g) { - return (getKernelSizeFn(g).at(dimensionIndex) != archetypeValue); + return (g.getKernelSize().at(dimensionIndex) != archetypeValue); }); } -template -std::string getKernelSize(const G *group, size_t dimensionIndex, K getKernelSizeFn) +template +std::string getKernelSize(const G &group, size_t dimensionIndex) { // If kernel size if heterogeneous in this dimension, return group structure entry - if (isKernelSizeHeterogeneous(group, dimensionIndex, getKernelSizeFn)) { + if (isKernelSizeHeterogeneous(group, dimensionIndex)) { return "group->kernelSize" + std::to_string(dimensionIndex); } // Otherwise, return literal else { - return std::to_string(getKernelSizeFn(group->getArchetype()).at(dimensionIndex)); + return std::to_string(group.getArchetype().getKernelSize().at(dimensionIndex)); } } -template -std::string getKernelIndex(const G *group, EnvironmentExternalBase &env, - K getKernelSizeFn) +template +std::string getKernelIndex(const G &group, EnvironmentExternalBase &env) { // Loop through kernel dimensions to calculate array index - const auto &kernelSize = getKernelSizeFn(group->getArchetype()); + const auto &kernelSize = group.getArchetype().getKernelSize(); std::ostringstream kernelIndex; for (size_t i = 0; i < kernelSize.size(); i++) { kernelIndex << "(" << env["id_kernel_" + std::to_string(i)]; // Loop through remainining dimensions of kernel and multiply for (size_t j = i + 1; j < kernelSize.size(); j++) { - kernelIndex << " * " << getKernelSize(group, j, getKernelSizeFn); + kernelIndex << " * " << getKernelSize(group, j); } kernelIndex << ")"; 
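// [Editor's sketch] getKernelIndex() above emits code that performs standard
// row-major flattening: each dimension's kernel id is multiplied by the
// product of all remaining dimension sizes. The arithmetic the generated code
// computes at runtime is equivalent to this standalone function (the name
// flattenKernelIndex and the example values are illustrative, not GeNN API):
#include <cstddef>
#include <vector>

unsigned int flattenKernelIndex(const std::vector<unsigned int> &id,
                                const std::vector<unsigned int> &kernelSize)
{
    unsigned int flat = 0;
    for(size_t i = 0; i < kernelSize.size(); i++) {
        // Product of the sizes of all remaining dimensions
        unsigned int stride = 1;
        for(size_t j = i + 1; j < kernelSize.size(); j++) {
            stride *= kernelSize[j];
        }
        flat += id[i] * stride;
    }
    return flat;    // e.g. id={1,2}, kernelSize={3,4} gives 1*4 + 2 = 6
}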
diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index a6c2c3f66e..1c5c111079 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -66,32 +66,8 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMergedkernelSizeXXX) - std::string getKernelSize(size_t dimensionIndex) const - { - return CodeGenerator::getKernelSize(this, dimensionIndex, getGroupKernelSize); - } - - //! Generate an index into a kernel based on the id_kernel_XXX variables in subs - std::string getKernelIndex(EnvironmentExternalBase &env) const - { - return CodeGenerator::getKernelIndex(this, env, getGroupKernelSize); - } - protected: void generateCustomUpdateBase(const BackendBase &backend, EnvironmentExternalBase &env); - -private: - static const std::vector& getGroupKernelSize(const CustomUpdateWUInternal& g) - { - return g.getSynapseGroup()->getKernelSize(); - } }; // ---------------------------------------------------------------------------- diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 31635dd955..5b133b9182 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -128,7 +128,7 @@ class ChildGroupMerged } //! Helper to test whether parameter values are heterogeneous within merged group - template + /*template bool isParamValueHeterogeneous(size_t index, P getParamValuesFn) const { // Get value of parameter in archetype group @@ -140,7 +140,7 @@ class ChildGroupMerged { return (getParamValuesFn(g).at(index) != archetypeValue); }); - } + }*/ //! Helper to update hash with the hash of calling getHashableFn on each group template @@ -616,24 +616,6 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMergedkernelSizeXXX) - std::string getKernelSize(size_t dimensionIndex) const - { - return CodeGenerator::getKernelSize(this, dimensionIndex, getGroupKernelSize); - } - - //! 
Generate an index into a kernel based on the id_kernel_XXX variables in subs
-    std::string getKernelIndex(EnvironmentExternalBase &env) const
-    {
-        return CodeGenerator::getKernelIndex(this, env, getGroupKernelSize);
-    }
-
     std::string getPreSlot(unsigned int batchSize) const;
     std::string getPostSlot(unsigned int batchSize) const;
@@ -731,11 +713,6 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged<SynapseGroupInternal>
-    static const std::vector<unsigned int>& getGroupKernelSize(const SynapseGroupInternal& g)
-    {
-        return g.getKernelSize();
-    }
-
     //------------------------------------------------------------------------
     // Members
     //------------------------------------------------------------------------
diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h
index f08c15b820..9d6bf52db4 100644
--- a/include/genn/genn/code_generator/initGroupMerged.h
+++ b/include/genn/genn/code_generator/initGroupMerged.h
@@ -4,21 +4,80 @@
 #include "code_generator/groupMerged.h"

 //----------------------------------------------------------------------------
-// GeNN::CodeGenerator::NeuronInitGroupMerged
+// GeNN::CodeGenerator::InitGroupMergedBase
 //----------------------------------------------------------------------------
 namespace GeNN::CodeGenerator
 {
-class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase
+template<typename B, typename A>
+class GENN_EXPORT InitGroupMergedBase : public B
+{
+public:
+    using B::B;
+
+    //----------------------------------------------------------------------------
+    // Public API
+    //----------------------------------------------------------------------------
+    //! Should the var init parameter be implemented heterogeneously?
+    bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const
+    {
+        return (isVarInitParamReferenced(varName, paramName) &&
+                this->isParamValueHeterogeneous(paramName,
+                                                [&varName](const auto &g)
+                                                {
+                                                    return A(g).getInitialisers().at(varName).getParams();
+                                                }));
+    }
+
+    //! Should the var init derived parameter be implemented heterogeneously?
+    bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const
+    {
+        return (isVarInitParamReferenced(varName, paramName) &&
+                this->isParamValueHeterogeneous(paramName,
+                                                [&varName](const auto &g)
+                                                {
+                                                    return A(g).getInitialisers().at(varName).getDerivedParams();
+                                                }));
+    }
+protected:
+    //----------------------------------------------------------------------------
+    // Protected methods
+    //----------------------------------------------------------------------------
+    void updateBaseHash(boost::uuids::detail::sha1 &hash) const
+    {
+        // Update hash with each group's variable initialisation parameters and derived parameters
+        this->template updateVarInitParamHash<InitGroupMergedBase<B, A>, A>(
+            &InitGroupMergedBase::isVarInitParamHeterogeneous, hash);
+
+        this->template updateVarInitDerivedParamHash<InitGroupMergedBase<B, A>, A>(
+            &InitGroupMergedBase::isVarInitDerivedParamHeterogeneous, hash);
+    }
+
+private:
+    //----------------------------------------------------------------------------
+    // Private methods
+    //----------------------------------------------------------------------------
+    //! Is the var init parameter referenced?
+    bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const
+    {
+        const auto *varInitSnippet = A(this->getArchetype()).getInitialisers().at(varName).getSnippet();
+        return this->isParamReferenced({varInitSnippet->getCode()}, paramName);
+    }
+};
+
+//----------------------------------------------------------------------------
+// GeNN::CodeGenerator::NeuronInitGroupMerged
+//----------------------------------------------------------------------------
+class GENN_EXPORT NeuronInitGroupMerged : public InitGroupMergedBase<NeuronGroupMergedBase, NeuronVarAdapter>
 {
 public:
     //----------------------------------------------------------------------------
     // GeNN::CodeGenerator::NeuronInitGroupMerged::CurrentSource
     //----------------------------------------------------------------------------
     //! Child group merged for current sources attached to this neuron update group
-    class CurrentSource : public ChildGroupMerged<CurrentSourceInternal>
+    class CurrentSource : public InitGroupMergedBase<ChildGroupMerged<CurrentSourceInternal>, CurrentSourceVarAdapter>
     {
     public:
-        using ChildGroupMerged::ChildGroupMerged;
+        using InitGroupMergedBase::InitGroupMergedBase;

         //----------------------------------------------------------------------------
         // Public API
@@ -27,30 +86,22 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase
                       NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged);

         //! Update hash with child groups
-        void updateHash(boost::uuids::detail::sha1 &hash) const;
-
-        //! Should the var init parameter be implemented heterogeneously?
-        bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const;
-
-        //! Should the var init derived parameter be implemented heterogeneously?
-        bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const;
+        void updateHash(boost::uuids::detail::sha1 &hash) const
+        {
+            updateBaseHash(hash);
+            Utils::updateHash(getArchetype().getInitHashDigest(), hash);

-    private:
-        //----------------------------------------------------------------------------
-        // Private methods
-        //----------------------------------------------------------------------------
-        //! Is the var init parameter referenced?
-        bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const;
+        }
     };

     //----------------------------------------------------------------------------
     // GeNN::CodeGenerator::NeuronInitGroupMerged::InSynPSM
     //----------------------------------------------------------------------------
     //! Child group merged for incoming synapse groups
-    class InSynPSM : public ChildGroupMerged<SynapseGroupInternal>
+    class InSynPSM : public InitGroupMergedBase<ChildGroupMerged<SynapseGroupInternal>, SynapsePSMVarAdapter>
     {
     public:
-        using ChildGroupMerged::ChildGroupMerged;
+        using InitGroupMergedBase::InitGroupMergedBase;

         //----------------------------------------------------------------------------
         // Public API
@@ -59,20 +110,11 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase
                       NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged);

         //! Update hash with child groups
-        void updateHash(boost::uuids::detail::sha1 &hash) const;
-
-        //! Should the var init parameter be implemented heterogeneously?
-        bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const;
-
-        //! Should the var init derived parameter be implemented heterogeneously?
-        bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const;
-
-    private:
-        //----------------------------------------------------------------------------
-        // Private methods
-        //----------------------------------------------------------------------------
-        //! Is the var init parameter referenced?
-        bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const;
+        void updateHash(boost::uuids::detail::sha1 &hash) const
+        {
+            updateBaseHash(hash);
+            Utils::updateHash(getArchetype().getPSInitHashDigest(), hash);
+        }
     };

     //----------------------------------------------------------------------------
@@ -95,10 +137,10 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase
     // GeNN::CodeGenerator::NeuronInitGroupMerged::InSynWUMPostCode
     //----------------------------------------------------------------------------
     //! Child group merged for incoming synapse groups with postsynaptic variables
-    class InSynWUMPostVars : public ChildGroupMerged<SynapseGroupInternal>
+    class InSynWUMPostVars : public InitGroupMergedBase<ChildGroupMerged<SynapseGroupInternal>, SynapseWUPostVarAdapter>
     {
     public:
-        using ChildGroupMerged::ChildGroupMerged;
+        using InitGroupMergedBase::InitGroupMergedBase;

         //----------------------------------------------------------------------------
         // Public API
@@ -107,30 +149,21 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase
                       NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged);

         //! Update hash with child groups
-        void updateHash(boost::uuids::detail::sha1 &hash) const;
-
-        //! Should the var init parameter be implemented heterogeneously?
-        bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const;
-
-        //! Should the var init derived parameter be implemented heterogeneously?
-        bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const;
-
-    private:
-        //----------------------------------------------------------------------------
-        // Private methods
-        //----------------------------------------------------------------------------
-        //! Is the var init parameter referenced?
-        bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const;
+        void updateHash(boost::uuids::detail::sha1 &hash) const
+        {
+            updateBaseHash(hash);
+            Utils::updateHash(getArchetype().getWUPostInitHashDigest(), hash);
+        }
     };

     //----------------------------------------------------------------------------
     // GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynWUMPreVars
     //----------------------------------------------------------------------------
     //! Child group merged for outgoing synapse groups with presynaptic variables
-    class OutSynWUMPreVars: public ChildGroupMerged<SynapseGroupInternal>
+    class OutSynWUMPreVars: public InitGroupMergedBase<ChildGroupMerged<SynapseGroupInternal>, SynapseWUPreVarAdapter>
     {
     public:
-        using ChildGroupMerged::ChildGroupMerged;
+        using InitGroupMergedBase::InitGroupMergedBase;

         //----------------------------------------------------------------------------
         // Public API
@@ -139,23 +172,14 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase
                       NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged);

         //! Update hash with child groups
-        void updateHash(boost::uuids::detail::sha1 &hash) const;
-
-        //! Should the var init parameter be implemented heterogeneously?
-        bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const;
-
-        //! Should the var init derived parameter be implemented heterogeneously?
-        bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const;
-
-    private:
-        //----------------------------------------------------------------------------
-        // Private methods
-        //----------------------------------------------------------------------------
-        //! Is the var init parameter referenced?
-        bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const;
+        void updateHash(boost::uuids::detail::sha1 &hash) const
+        {
+            updateBaseHash(hash);
+            Utils::updateHash(getArchetype().getWUPreInitHashDigest(), hash);
+        }
     };

-    NeuronInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
+    NeuronInitGroupMerged(size_t index, const Type::TypeContext &typeContext,
                           const std::vector<std::reference_wrapper<const NeuronGroupInternal>> &groups);

    //----------------------------------------------------------------------------
@@ -181,12 +205,6 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase
     const std::vector<InSynWUMPostVars> &getMergedInSynWUMPostVarGroups() const { return m_MergedInSynWUMPostVarGroups; }
     const std::vector<OutSynWUMPreVars> &getMergedOutSynWUMPreVarGroups() const { return m_MergedOutSynWUMPreVarGroups; }

-    //! Should the var init parameter be implemented heterogeneously?
-    bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const;
-
-    //! Should the var init derived parameter be implemented heterogeneously?
-    bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const;
-
     //----------------------------------------------------------------------------
     // Static constants
     //----------------------------------------------------------------------------
@@ -216,18 +234,12 @@ class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase
 //----------------------------------------------------------------------------
 // GeNN::CodeGenerator::SynapseInitGroupMerged
 //----------------------------------------------------------------------------
-class GENN_EXPORT SynapseInitGroupMerged : public SynapseGroupMergedBase
+class GENN_EXPORT SynapseInitGroupMerged : public InitGroupMergedBase, SynapseWUVarAdapter>
 {
 public:
-    SynapseInitGroupMerged(size_t index, const Type::TypeContext &typeContext,
-                           const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
-    :   SynapseGroupMergedBase(index, typeContext, SynapseGroupMergedBase::Role::Init, "", groups)
-    {}
+    using InitGroupMergedBase::InitGroupMergedBase;

-    boost::uuids::detail::sha1::digest_type getHashDigest() const
-    {
-        return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::Init);
-    }
+    boost::uuids::detail::sha1::digest_type getHashDigest() const;

     void generateRunner(const BackendBase &backend,
                         CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc,
@@ -249,18 +261,10 @@ class GENN_EXPORT SynapseInitGroupMerged : public SynapseGroupMergedBase
 //----------------------------------------------------------------------------
 // GeNN::CodeGenerator::SynapseSparseInitGroupMerged
 //----------------------------------------------------------------------------
-class GENN_EXPORT SynapseSparseInitGroupMerged : public SynapseGroupMergedBase
+class GENN_EXPORT SynapseSparseInitGroupMerged : public InitGroupMergedBase, SynapseWUVarAdapter>
 {
 public:
-    SynapseSparseInitGroupMerged(size_t index, const Type::TypeContext &typeContext,
-                                 const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
-    :   SynapseGroupMergedBase(index, typeContext, SynapseGroupMergedBase::Role::SparseInit, "", groups)
-    {}
-
-    boost::uuids::detail::sha1::digest_type getHashDigest() const
-
@@ -249,18 +261,10 @@ class GENN_EXPORT SynapseInitGroupMerged : public SynapseGroupMergedBase
 //----------------------------------------------------------------------------
 // GeNN::CodeGenerator::SynapseSparseInitGroupMerged
 //----------------------------------------------------------------------------
-class GENN_EXPORT SynapseSparseInitGroupMerged : public SynapseGroupMergedBase
+class GENN_EXPORT SynapseSparseInitGroupMerged : public InitGroupMergedBase<SynapseGroupMergedBase, SynapseWUVarAdapter>
 {
 public:
-    SynapseSparseInitGroupMerged(size_t index, const Type::TypeContext &typeContext,
-                                 const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
-    :   SynapseGroupMergedBase(index, typeContext, SynapseGroupMergedBase::Role::SparseInit, "", groups)
-    {}
-
-    boost::uuids::detail::sha1::digest_type getHashDigest() const
-    {
-        return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::SparseInit);
-    }
+    using InitGroupMergedBase::InitGroupMergedBase;
+
+    boost::uuids::detail::sha1::digest_type getHashDigest() const;

     void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal,
                         CodeStream &definitionsInternalFunc,
@@ -363,85 +367,14 @@ class GENN_EXPORT SynapseConnectivityHostInitGroupMerged : public GroupMerged<SynapseGroupInternal>
-template<typename G, typename A>
-class CustomUpdateInitGroupMergedBase : public GroupMerged<G>
-{
-public:
-    //----------------------------------------------------------------------------
-    // Public API
-    //----------------------------------------------------------------------------
-    //! Should the var init parameter be implemented heterogeneously?
-    bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const
-    {
-        return (isVarInitParamReferenced(varName, paramName) &&
-                this->isParamValueHeterogeneous(paramName,
-                                                [&varName](const G &cg)
-                                                {
-                                                    A archetypeAdaptor(cg);
-                                                    return archetypeAdaptor.getInitialisers().at(varName).getParams();
-                                                }));
-    }
-
-    //! Should the var init derived parameter be implemented heterogeneously?
-    bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const
-    {
-        return (isVarInitParamReferenced(varName, paramName) &&
-                this->isParamValueHeterogeneous(paramName,
-                                                [&varName](const G &cg)
-                                                {
-                                                    A archetypeAdaptor(cg);
-                                                    return archetypeAdaptor.getInitialisers().at(varName).getDerivedParams();
-                                                }));
-    }
-protected:
-    //----------------------------------------------------------------------------
-    // Protected methods
-    //----------------------------------------------------------------------------
-    void updateBaseHash(boost::uuids::detail::sha1 &hash) const
-    {
-        // Update hash with archetype's hash digest
-        Utils::updateHash(this->getArchetype().getInitHashDigest(), hash);
-
-        // Update hash with each group's variable initialisation parameters and derived parameters
-        this->template updateVarInitParamHash<CustomUpdateInitGroupMergedBase<G, A>, A>(
-            &CustomUpdateInitGroupMergedBase::isVarInitParamHeterogeneous, hash);
-
-        this->template updateVarInitDerivedParamHash<CustomUpdateInitGroupMergedBase<G, A>, A>(
-            &CustomUpdateInitGroupMergedBase::isVarInitDerivedParamHeterogeneous, hash);
-    }
-
-private:
-    //----------------------------------------------------------------------------
-    // Private methods
-    //----------------------------------------------------------------------------
-    //! Is the var init parameter referenced?
-    bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const
-    {
-        A archetypeAdaptor(this->getArchetype());
-        const auto *varInitSnippet = archetypeAdaptor.getInitialisers().at(varName).getSnippet();
-        return this->isParamReferenced({varInitSnippet->getCode()}, paramName);
-    }
-
-    //! Is the var init derived parameter referenced?
-    bool isVarInitDerivedParamReferenced(const std::string &varName, const std::string &paramName) const
-    {
-        A archetypeAdaptor(this->getArchetype());
-        const auto *varInitSnippet = archetypeAdaptor.getInitialisers().at(varName).getSnippet();
-        return this->isParamReferenced({varInitSnippet->getCode()}, paramName);
-    }
-};
-
 // ----------------------------------------------------------------------------
 // GeNN::CodeGenerator::CustomUpdateInitGroupMerged
 //----------------------------------------------------------------------------
-class GENN_EXPORT CustomUpdateInitGroupMerged : public CustomUpdateInitGroupMergedBase<CustomUpdateInternal, CustomUpdateVarAdapter>
+class GENN_EXPORT CustomUpdateInitGroupMerged : public InitGroupMergedBase<GroupMerged<CustomUpdateInternal>,
+                                                                           CustomUpdateVarAdapter>
 {
 public:
-    using CustomUpdateInitGroupMergedBase::CustomUpdateInitGroupMergedBase;
+    using InitGroupMergedBase::InitGroupMergedBase;

     //----------------------------------------------------------------------------
     // Public API
@@ -469,12 +402,11 @@ class GENN_EXPORT CustomUpdateInitGroupMerged : public CustomUpdateInitGroupMerg
 // ----------------------------------------------------------------------------
 // GeNN::CodeGenerator::CustomWUUpdateInitGroupMerged
 //----------------------------------------------------------------------------
-class GENN_EXPORT CustomWUUpdateInitGroupMerged : public CustomUpdateInitGroupMergedBase<CustomUpdateWUInternal, CustomUpdateVarAdapter>
+class GENN_EXPORT CustomWUUpdateInitGroupMerged : public InitGroupMergedBase<GroupMerged<CustomUpdateWUInternal>,
+                                                                             CustomUpdateVarAdapter>
 {
 public:
-    CustomWUUpdateInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
-                                  const std::vector<std::reference_wrapper<const CustomUpdateWUInternal>> &groups);
+    using InitGroupMergedBase::InitGroupMergedBase;

     //----------------------------------------------------------------------------
     // Public API
@@ -492,47 +424,20 @@ class GENN_EXPORT CustomWUUpdateInitGroupMerged : public CustomUpdateInitGroupMe
     void generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged);

-    //! Is kernel size heterogeneous in this dimension?
-    bool isKernelSizeHeterogeneous(size_t dimensionIndex) const
-    {
-        return CodeGenerator::isKernelSizeHeterogeneous(this, dimensionIndex, getGroupKernelSize);
-    }
-
-    //! Get expression for kernel size in dimension (may be literal or group->kernelSizeXXX)
-    std::string getKernelSize(size_t dimensionIndex) const
-    {
-        return CodeGenerator::getKernelSize(this, dimensionIndex, getGroupKernelSize);
-    }
-
-    //! Generate an index into a kernel based on the id_kernel_XXX variables in subs
-    std::string getKernelIndex(EnvironmentExternalBase &env) const
-    {
-        return CodeGenerator::getKernelIndex(this, env, getGroupKernelSize);
-    }
-
     //----------------------------------------------------------------------------
     // Static constants
     //----------------------------------------------------------------------------
     static const std::string name;
-
-private:
-    //----------------------------------------------------------------------------
-    // Private static methods
-    //----------------------------------------------------------------------------
-    static const std::vector<unsigned int> &getGroupKernelSize(const CustomUpdateWUInternal &g)
-    {
-        return g.getSynapseGroup()->getKernelSize();
-    }
 };

 // ----------------------------------------------------------------------------
 // GeNN::CodeGenerator::CustomWUUpdateSparseInitGroupMerged
 //----------------------------------------------------------------------------
-class GENN_EXPORT CustomWUUpdateSparseInitGroupMerged : public CustomUpdateInitGroupMergedBase<CustomUpdateWUInternal, CustomUpdateVarAdapter>
+class GENN_EXPORT CustomWUUpdateSparseInitGroupMerged : public InitGroupMergedBase<GroupMerged<CustomUpdateWUInternal>,
+                                                                                   CustomUpdateVarAdapter>
 {
 public:
-    using CustomUpdateInitGroupMergedBase::CustomUpdateInitGroupMergedBase;
+    using InitGroupMergedBase::InitGroupMergedBase;

     //----------------------------------------------------------------------------
     // Public API
@@ -559,12 +464,11 @@ class GENN_EXPORT CustomWUUpdateSparseInitGroupMerged : public CustomUpdateInitG
 //----------------------------------------------------------------------------
 // GeNN::CodeGenerator::CustomConnectivityUpdatePreInitGroupMerged
 //----------------------------------------------------------------------------
-class GENN_EXPORT CustomConnectivityUpdatePreInitGroupMerged : public CustomUpdateInitGroupMergedBase<CustomConnectivityUpdateInternal, CustomConnectivityUpdatePreVarAdapter>
+class GENN_EXPORT CustomConnectivityUpdatePreInitGroupMerged : public InitGroupMergedBase<GroupMerged<CustomConnectivityUpdateInternal>,
+                                                                                          CustomConnectivityUpdatePreVarAdapter>
 {
 public:
-    CustomConnectivityUpdatePreInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
-                                               const std::vector<std::reference_wrapper<const CustomConnectivityUpdateInternal>> &groups);
+    using InitGroupMergedBase::InitGroupMergedBase;

     //----------------------------------------------------------------------------
     // Public API
@@ -591,12 +495,11 @@ class GENN_EXPORT CustomConnectivityUpdatePreInitGroupMerged : public CustomUpda
 //----------------------------------------------------------------------------
 // GeNN::CodeGenerator::CustomConnectivityUpdatePostInitGroupMerged
 //----------------------------------------------------------------------------
-class GENN_EXPORT CustomConnectivityUpdatePostInitGroupMerged : public CustomUpdateInitGroupMergedBase<CustomConnectivityUpdateInternal, CustomConnectivityUpdatePostVarAdapter>
+class GENN_EXPORT CustomConnectivityUpdatePostInitGroupMerged : public InitGroupMergedBase<GroupMerged<CustomConnectivityUpdateInternal>,
+                                                                                           CustomConnectivityUpdatePostVarAdapter>
 {
 public:
-    CustomConnectivityUpdatePostInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
-                                                const std::vector<std::reference_wrapper<const CustomConnectivityUpdateInternal>> &groups);
+    using InitGroupMergedBase::InitGroupMergedBase;

     //----------------------------------------------------------------------------
     // Public API
@@ -623,11 +526,11 @@ class GENN_EXPORT CustomConnectivityUpdatePostInitGroupMerged : public CustomUpd
 //----------------------------------------------------------------------------
 // GeNN::CodeGenerator::CustomConnectivityUpdateSparseInitGroupMerged
 //----------------------------------------------------------------------------
-class GENN_EXPORT CustomConnectivityUpdateSparseInitGroupMerged : public CustomUpdateInitGroupMergedBase<CustomConnectivityUpdateInternal, CustomConnectivityUpdateVarAdapter>
+class GENN_EXPORT CustomConnectivityUpdateSparseInitGroupMerged : public InitGroupMergedBase<GroupMerged<CustomConnectivityUpdateInternal>,
+                                                                                             CustomConnectivityUpdateVarAdapter>
 {
 public:
-    using CustomUpdateInitGroupMergedBase::CustomUpdateInitGroupMergedBase;
+    using InitGroupMergedBase::InitGroupMergedBase;

     //----------------------------------------------------------------------------
     // Public API
diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h
index 84c1526729..c88971c31e 100644
--- a/include/genn/genn/customUpdate.h
+++ b/include/genn/genn/customUpdate.h
@@ -297,6 +297,8 @@ class GENN_EXPORT CustomUpdateWU : public CustomUpdateBase

     SynapseGroupInternal *getSynapseGroup() const { return m_SynapseGroup; }

+    const std::vector<unsigned int> &getKernelSize() const { return getSynapseGroup()->getKernelSize(); }
+
     //! Updates hash with custom update
     /*! NOTE: this can only be called after model is finalized */
     boost::uuids::detail::sha1::digest_type getHashDigest() const;
diff --git a/include/genn/genn/customUpdateInternal.h b/include/genn/genn/customUpdateInternal.h
index 983c6e245a..81aa4921df 100644
--- a/include/genn/genn/customUpdateInternal.h
+++ b/include/genn/genn/customUpdateInternal.h
@@ -87,6 +87,7 @@ class CustomUpdateWUInternal : public CustomUpdateWU
     using CustomUpdateWU::getHashDigest;
     using CustomUpdateWU::getInitHashDigest;
     using CustomUpdateWU::getSynapseGroup;
+    using CustomUpdateWU::getKernelSize;
     using CustomUpdateWU::isBatchReduction;
     using CustomUpdateWU::isTransposeOperation;
 };
diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc
index 2004864a76..c20e79d875 100644
--- a/src/genn/genn/code_generator/initGroupMerged.cc
+++ b/src/genn/genn/code_generator/initGroupMerged.cc
@@ -203,30 +203,7 @@ void NeuronInitGroupMerged::CurrentSource::generate(const BackendBase &backend,
                         backend, env, *this, ng, "CS" + std::to_string(getIndex()), "num_neurons", 0, modelMerged.getModel().getBatchSize());
 }
-//----------------------------------------------------------------------------
-void NeuronInitGroupMerged::CurrentSource::updateHash(boost::uuids::detail::sha1 &hash) const
-{
-    updateVarInitParamHash(&CurrentSource::isVarInitParamReferenced, hash);
-    updateVarInitDerivedParamHash(&CurrentSource::isVarInitParamReferenced, hash);
-}
-//----------------------------------------------------------------------------
-bool NeuronInitGroupMerged::CurrentSource::isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const
-{
-    return isParamValueHeterogeneous(paramName,
-                                     [varName](const auto &cs){ return cs.getVarInitialisers().at(varName).getParams(); });
-}
-//----------------------------------------------------------------------------
-bool NeuronInitGroupMerged::CurrentSource::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const
-{
-    return isParamValueHeterogeneous(paramName,
-                                     [varName](const auto &cs){ return cs.getVarInitialisers().at(varName).getDerivedParams(); });
-}
-//----------------------------------------------------------------------------
-bool NeuronInitGroupMerged::CurrentSource::isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const
-{
-    const auto *varInitSnippet = getArchetype().getVarInitialisers().at(varName).getSnippet();
-    return isParamReferenced({varInitSnippet->getCode()}, paramName);
-}
+
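// All the getHashDigest()/updateHash() machinery above accumulates state into
// boost's SHA-1 implementation (GeNN::Utils::updateHash is a type-aware wrapper
// around process_bytes). A minimal sketch of the raw mechanics, assuming only
// boost; the header path is the one modern boost ships, and hashValues is a
// hypothetical helper, not a GeNN function:
//
//     #include <boost/uuid/detail/sha1.hpp>
//
//     inline void hashValues(unsigned int a, unsigned int b,
//                            boost::uuids::detail::sha1::digest_type &digest)
//     {
//         boost::uuids::detail::sha1 hash;
//
//         // Feed the raw bytes of each value into the running hash
//         hash.process_bytes(&a, sizeof(a));
//         hash.process_bytes(&b, sizeof(b));
//
//         // Extract the 160-bit digest
//         hash.get_digest(digest);
//     }
//
// Two merged groups compare equal for merging purposes exactly when these
// accumulated digests match, which is why every parameter that could change
// the generated code must be fed into the hash.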
//---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::InSynPSM @@ -278,30 +255,6 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir genInitNeuronVarCode( backend, groupEnv, *this, ng, fieldSuffix, "num_neurons", 0, modelMerged.getModel().getBatchSize()); } -//---------------------------------------------------------------------------- -void NeuronInitGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &hash) const -{ - updateVarInitParamHash(&InSynPSM::isVarInitParamReferenced, hash); - updateVarInitDerivedParamHash(&InSynPSM::isVarInitParamReferenced, hash); -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::InSynPSM::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const -{ - return isParamValueHeterogeneous(paramName, - [varName](const auto &sg){ return sg.getPSVarInitialisers().at(varName).getParams(); }); -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::InSynPSM::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const -{ - return isParamValueHeterogeneous(paramName, - [varName](const auto &sg){ return sg.getPSVarInitialisers().at(varName).getDerivedParams(); }); -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::InSynPSM::isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const -{ - const auto *varInitSnippet = getArchetype().getPSVarInitialisers().at(varName).getSnippet(); - return isParamReferenced({varInitSnippet->getCode()}, paramName); -} //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynPreOutput @@ -335,30 +288,6 @@ void NeuronInitGroupMerged::InSynWUMPostVars::generate(const BackendBase &backen genInitNeuronVarCode( backend, env, *this, ng, "InSynWUMPost" + std::to_string(getIndex()), "num_neurons", 0, modelMerged.getModel().getBatchSize()); } -//---------------------------------------------------------------------------- -void NeuronInitGroupMerged::InSynWUMPostVars::updateHash(boost::uuids::detail::sha1 &hash) const -{ - updateVarInitParamHash(&InSynWUMPostVars::isVarInitParamReferenced, hash); - updateVarInitDerivedParamHash(&InSynWUMPostVars::isVarInitParamReferenced, hash); -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::InSynWUMPostVars::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const -{ - return isParamValueHeterogeneous(paramName, - [varName](const auto &sg){ return sg.getWUPostVarInitialisers().at(varName).getParams(); }); -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::InSynWUMPostVars::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const -{ - return isParamValueHeterogeneous(paramName, - [varName](const auto &sg){ return sg.getWUPostVarInitialisers().at(varName).getDerivedParams(); }); -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::InSynWUMPostVars::isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const -{ - const auto *varInitSnippet = 
getArchetype().getWUPostVarInitialisers().at(varName).getSnippet(); - return isParamReferenced({varInitSnippet->getCode()}, paramName); -} //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynWUMPreVars @@ -369,39 +298,15 @@ void NeuronInitGroupMerged::OutSynWUMPreVars::generate(const BackendBase &backen genInitNeuronVarCode( backend, env, *this, ng, "OutSynWUMPre" + std::to_string(getIndex()), "num_neurons", 0, modelMerged.getModel().getBatchSize()); } -//---------------------------------------------------------------------------- -void NeuronInitGroupMerged::OutSynWUMPreVars::updateHash(boost::uuids::detail::sha1 &hash) const -{ - updateVarInitParamHash(&OutSynWUMPreVars::isVarInitParamReferenced, hash); - updateVarInitDerivedParamHash(&OutSynWUMPreVars::isVarInitParamReferenced, hash); -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::OutSynWUMPreVars::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const -{ - return isParamValueHeterogeneous(paramName, - [varName](const auto &sg){ return sg.getWUPreVarInitialisers().at(varName).getParams(); }); -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::OutSynWUMPreVars::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const -{ - return isParamValueHeterogeneous(paramName, - [varName](const auto &sg){ return sg.getWUPreVarInitialisers().at(varName).getDerivedParams(); }); -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::OutSynWUMPreVars::isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const -{ - const auto *varInitSnippet = getArchetype().getWUPreVarInitialisers().at(varName).getSnippet(); - return isParamReferenced({varInitSnippet->getCode()}, paramName); -} //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged //---------------------------------------------------------------------------- const std::string NeuronInitGroupMerged::name = "NeuronInit"; //---------------------------------------------------------------------------- -NeuronInitGroupMerged::NeuronInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, +NeuronInitGroupMerged::NeuronInitGroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> &groups) -: NeuronGroupMergedBase(index, typeContext, groups) +: InitGroupMergedBase(index, typeContext, groups) { // Build vector of vectors containing each child group's merged in syns, ordered to match those of the archetype group orderNeuronGroupChildren(m_MergedInSynPSMGroups, typeContext, @@ -435,16 +340,15 @@ boost::uuids::detail::sha1::digest_type NeuronInitGroupMerged::getHashDigest() c { boost::uuids::detail::sha1 hash; - /// Update hash with each group's neuron count + // Update hash with standard archetype hash and var init parameters and derived parameters + updateBaseHash(hash); + + // Update hash with each group's neuron count updateHash([](const NeuronGroupInternal &g) { return g.getNumNeurons(); }, hash); // Update hash with archetype's hash digest Utils::updateHash(getArchetype().getInitHashDigest(), hash); - // Update hash with each group's variable initialisation parameters and derived parameters - 
updateVarInitParamHash(&NeuronInitGroupMerged::isVarInitParamReferenced, hash); - updateVarInitDerivedParamHash(&NeuronInitGroupMerged::isVarInitParamReferenced, hash); - // Update hash with child groups for (const auto &cs : getMergedCurrentSourceGroups()) { cs.updateHash(hash); @@ -529,18 +433,6 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, Environment sg.generate(backend, env, *this, modelMerged); } } -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const -{ - return (isParamValueHeterogeneous(paramName, - [varName](const NeuronGroupInternal &sg) { return sg.getVarInitialisers().at(varName).getParams(); })); -} -//---------------------------------------------------------------------------- -bool NeuronInitGroupMerged::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const -{ - return (isParamValueHeterogeneous(paramName, - [varName](const NeuronGroupInternal &sg){ return sg.getVarInitialisers().at(varName).getDerivedParams(); })); -} //-------------------------------------------------------------------------- void NeuronInitGroupMerged::genInitSpikeCount(const BackendBase &backend, EnvironmentExternalBase &env, bool spikeEvent, unsigned int batchSize) @@ -637,7 +529,7 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, Environmen batchStrideInit << "const unsigned int batchStride = "; const auto &kernelSize = getArchetype().getKernelSize(); for (size_t i = 0; i < kernelSize.size(); i++) { - batchStrideInit << getKernelSize(i); + batchStrideInit << getKernelSize(*this, i); if (i != (kernelSize.size() - 1)) { batchStrideInit << " * "; @@ -712,7 +604,7 @@ void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase &b // Add substitution // **TODO** dependencies on kernel fields groupEnv.add(Type::Uint32, "id_kernel", "kernelInd", - {groupEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(groupEnv) + ";")}); + {groupEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(*this, groupEnv) + ";")}); // Initialise single (hence empty lambda function) synapse variable genInitWUVarCode(backend, groupEnv, *this, groupEnv["num_pre"] + " * " + groupEnv["_row_stride"], modelMerged.getModel().getBatchSize(), @@ -804,27 +696,6 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac groupEnv.addConnectInitDerivedParams("", &SynapseGroupInternal::getConnectivityInitialiser, &SynapseConnectivityHostInitGroupMerged::isConnectivityInitDerivedParamHeterogeneous); - /*const auto &pointerToPointerToEGP = e.type.resolve(getTypeContext()).createPointer().createPointer(); - addField(pointerToPointerToEGP, e.name, - [e](const SynapseGroupInternal &g, size_t) { return "&" + e.name + g.getName(); }, - GroupMergedFieldType::HOST_DYNAMIC); - - if(!backend.getDeviceVarPrefix().empty()) { - addField(pointerToPointerToEGP, backend.getDeviceVarPrefix() + e.name, - [e, &backend](const SynapseGroupInternal &g, size_t) - { - return "&" + backend.getDeviceVarPrefix() + e.name + g.getName(); - }, - GroupMergedFieldType::DYNAMIC); - } - if(!backend.getHostVarPrefix().empty()) { - addField(pointerToPointerToEGP, backend.getHostVarPrefix() + e.name, - [e, &backend](const SynapseGroupInternal &g, size_t) - { - return "&" + backend.getHostVarPrefix() + e.name + g.getName(); - }, - GroupMergedFieldType::DYNAMIC); - }*/ // Loop 
through EGPs for(const auto &egp : connectInit.getSnippet()->getExtraGlobalParams()) { // If EGP is located on the host @@ -921,6 +792,9 @@ boost::uuids::detail::sha1::digest_type CustomUpdateInitGroupMerged::getHashDige // Update hash with generic custom update init data updateBaseHash(hash); + // Update hash with archetype's hash digest + Utils::updateHash(getArchetype().getInitHashDigest(), hash); + // Update hash with size of custom update updateHash([](const CustomUpdateInternal &cg) { return cg.getSize(); }, hash); @@ -946,6 +820,9 @@ boost::uuids::detail::sha1::digest_type CustomWUUpdateInitGroupMerged::getHashDi // Update hash with generic custom update init data updateBaseHash(hash); + // Update hash with archetype's hash digest + Utils::updateHash(getArchetype().getInitHashDigest(), hash); + // If underlying synapse group has kernel weights, update hash with kernel size if(getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL) { updateHash([](const auto &g) { return g.getSynapseGroup()->getKernelSize(); }, hash); @@ -981,7 +858,7 @@ void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, Env // Loop through kernel size dimensions for (size_t d = 0; d < getArchetype().getSynapseGroup()->getKernelSize().size(); d++) { // If this dimension has a heterogeneous size, add it to struct - if (isKernelSizeHeterogeneous(d)) { + if (isKernelSizeHeterogeneous(*this, d)) { groupEnv.addField(Type::Uint32, "_kernel_size_" + std::to_string(d), "kernelSize" + std::to_string(d), [d](const auto &g, size_t) { return std::to_string(g.getSynapseGroup()->getKernelSize().at(d)); }); } @@ -993,7 +870,7 @@ void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, Env batchStrideInit << "const unsigned int batchStride = "; const auto &kernelSize = getArchetype().getSynapseGroup()->getKernelSize(); for (size_t i = 0; i < kernelSize.size(); i++) { - batchStrideInit << getKernelSize(i); + batchStrideInit << getKernelSize(*this, i); if (i != (kernelSize.size() - 1)) { batchStrideInit << " * "; @@ -1052,6 +929,9 @@ boost::uuids::detail::sha1::digest_type CustomWUUpdateSparseInitGroupMerged::get // Update hash with generic custom update init data updateBaseHash(hash); + // Update hash with archetype's hash digest + Utils::updateHash(getArchetype().getInitHashDigest(), hash); + // Update hash with sizes of pre and postsynaptic neuron groups; and max row length updateHash([](const auto &cg) { @@ -1117,6 +997,9 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePreInitGroupMerg // Update hash with generic custom update init data updateBaseHash(hash); + // Update hash with archetype's hash digest + Utils::updateHash(getArchetype().getInitHashDigest(), hash); + // Update hash with size of custom update updateHash([](const auto &cg) { @@ -1154,6 +1037,9 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePostInitGroupMer // Update hash with generic custom update init data updateBaseHash(hash); + // Update hash with archetype's hash digest + Utils::updateHash(getArchetype().getInitHashDigest(), hash); + // Update hash with size of custom update updateHash([](const auto &cg) { @@ -1191,6 +1077,9 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdateSparseInitGroupM // Update hash with generic custom update init data updateBaseHash(hash); + // Update hash with archetype's hash digest + Utils::updateHash(getArchetype().getInitHashDigest(), hash); + // Update hash with sizes of pre and postsynaptic neuron groups; 
and max row length updateHash([](const CustomConnectivityUpdateInternal &cg) { From 26ddd2581b91f7fd3e4d0eb8e6d60c2ecfb99ecf Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 20 Jun 2023 18:33:30 +0100 Subject: [PATCH 240/725] slowly approaching a CPU backend which compiles --- .../backends/single_threaded_cpu/backend.cc | 688 +++++++++--------- .../code_generator/customUpdateGroupMerged.cc | 2 +- 2 files changed, 358 insertions(+), 332 deletions(-) diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 93408f32e4..f9b8b099ee 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -845,416 +845,442 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host } //-------------------------------------------------------------------------- -void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genInit(CodeStream &_os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); if(model.getBatchSize() != 1) { throw std::runtime_error("The single-threaded CPU backend only supports simulations with a batch size of 1"); } - // Generate struct definitions - modelMerged.genMergedNeuronInitGroupStructs(os, *this); - modelMerged.genMergedSynapseInitGroupStructs(os, *this); - modelMerged.genMergedCustomUpdateInitGroupStructs(os, *this); - modelMerged.genMergedCustomWUUpdateInitGroupStructs(os, *this); - modelMerged.genMergedSynapseConnectivityInitGroupStructs(os, *this); - modelMerged.genMergedSynapseSparseInitGroupStructs(os, *this); - modelMerged.genMergedCustomWUUpdateSparseInitGroupStructs(os, *this); - modelMerged.genMergedCustomConnectivityUpdatePreInitStructs(os, *this); - modelMerged.genMergedCustomConnectivityUpdatePostInitStructs(os, *this); - modelMerged.genMergedCustomConnectivityUpdateSparseInitStructs(os, *this); + // Generate stream with neuron update code + std::ostringstream initStream; + CodeStream init(initStream); - // Generate arrays of merged structs and functions to set them - genMergedStructArrayPush(os, modelMerged.getMergedNeuronInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedSynapseInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomWUUpdateInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedSynapseConnectivityInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedSynapseSparseInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomWUUpdateSparseInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdatePreInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdatePostInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups()); - - // Generate preamble - preambleHandler(os); + // Begin environment with standard library + EnvironmentLibrary initEnv(init, StandardLibrary::getFunctions()); - os << "void initialize()"; + initEnv.getStream() << "void initialize()"; { - CodeStream::Scope b(os); - Substitutions funcSubs(getFunctionTemplates(model.getPrecision().getName())); + CodeStream::Scope b(initEnv.getStream()); + EnvironmentExternal funcEnv(initEnv); - Timer t(os, "init", model.isTimingEnabled()); - - // If model requires 
a host RNG, add RNG to substitutions
-        if(isGlobalHostRNGRequired(modelMerged)) {
-            funcSubs.addVarSubstitution("rng", "hostRNG");
-        }
+        Timer t(funcEnv.getStream(), "init", model.isTimingEnabled());

-        os << "// ------------------------------------------------------------------------" << std::endl;
-        os << "// Neuron groups" << std::endl;
-        for(const auto &n : modelMerged.getMergedNeuronInitGroups()) {
-            CodeStream::Scope b(os);
-            os << "// merged neuron init group " << n.getIndex() << std::endl;
-            os << "for(unsigned int g = 0; g < " << n.getGroups().size() << "; g++)";
+        funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl;
+        funcEnv.getStream() << "// Neuron groups" << std::endl;
+        modelMerged.genMergedNeuronInitGroups(
+            *this,
+            [this, &funcEnv, &modelMerged](auto &n)
             {
-                CodeStream::Scope b(os);
+                CodeStream::Scope b(funcEnv.getStream());
+                funcEnv.getStream() << "// merged neuron init group " << n.getIndex() << std::endl;
+                funcEnv.getStream() << "for(unsigned int g = 0; g < " << n.getGroups().size() << "; g++)";
+                {
+                    CodeStream::Scope b(funcEnv.getStream());

-                // Get reference to group
-                os << "const auto *group = &mergedNeuronInitGroup" << n.getIndex() << "[g]; " << std::endl;
-                Substitutions popSubs(&funcSubs);
-                n.generateInit(*this, os, modelMerged, popSubs);
-            }
-        }
+                    // Get reference to group
+                    funcEnv.getStream() << "const auto *group = &mergedNeuronInitGroup" << n.getIndex() << "[g]; " << std::endl;
+                    n.generateInit(*this, funcEnv, modelMerged);
+                }
+            });

-        os << "// ------------------------------------------------------------------------" << std::endl;
-        os << "// Synapse groups" << std::endl;
-        for(const auto &s : modelMerged.getMergedSynapseInitGroups()) {
-            CodeStream::Scope b(os);
-            os << "// merged synapse init group " << s.getIndex() << std::endl;
-            os << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)";
+        funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl;
+        funcEnv.getStream() << "// Synapse groups" << std::endl;
+        modelMerged.genMergedSynapseInitGroups(
+            *this,
+            [this, &funcEnv, &modelMerged](auto &s)
             {
-                CodeStream::Scope b(os);
+                CodeStream::Scope b(funcEnv.getStream());
+                funcEnv.getStream() << "// merged synapse init group " << s.getIndex() << std::endl;
+                funcEnv.getStream() << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)";
+                {
+                    CodeStream::Scope b(funcEnv.getStream());

-                // Get reference to group
-                os << "const auto *group = &mergedSynapseInitGroup" << s.getIndex() << "[g]; " << std::endl;
-                Substitutions popSubs(&funcSubs);
-                s.generateInit(*this, os, modelMerged, popSubs);
-            }
-        }
+                    // Get reference to group
+                    funcEnv.getStream() << "const auto *group = &mergedSynapseInitGroup" << s.getIndex() << "[g]; " << std::endl;
+                    s.generateInit(*this, funcEnv, modelMerged);
+                }
+            });

-        os << "// ------------------------------------------------------------------------" << std::endl;
-        os << "// Custom update groups" << std::endl;
-        for(const auto &c : modelMerged.getMergedCustomUpdateInitGroups()) {
-            CodeStream::Scope b(os);
-            os << "// merged custom init group " << c.getIndex() << std::endl;
-            os << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)";
+        funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl;
+        funcEnv.getStream() << "// Custom update groups" << std::endl;
+        modelMerged.genMergedCustomUpdateInitGroups(
+            *this,
+ [this, &funcEnv, &modelMerged](auto &c) { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged custom init group " << c.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + { + CodeStream::Scope b(funcEnv.getStream()); - // Get reference to group - os << "const auto *group = &mergedCustomUpdateInitGroup" << c.getIndex() << "[g]; " << std::endl; - Substitutions popSubs(&funcSubs); - c.generateInit(*this, os, modelMerged, popSubs); - } - } + // Get reference to group + funcEnv.getStream() << "const auto *group = &mergedCustomUpdateInitGroup" << c.getIndex() << "[g]; " << std::endl; + c.generateInit(*this, funcEnv, modelMerged); + } + }); - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Custom connectivity presynaptic update groups" << std::endl; - for(const auto &c : modelMerged.getMergedCustomConnectivityUpdatePreInitGroups()) { - CodeStream::Scope b(os); - os << "// merged custom connectivity presynaptic init group " << c.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; + funcEnv.getStream() << "// Custom connectivity presynaptic update groups" << std::endl; + modelMerged.genMergedCustomConnectivityUpdatePreInitGroups( + *this, + [this, &funcEnv, &modelMerged](auto &c) { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged custom connectivity presynaptic init group " << c.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + { + CodeStream::Scope b(funcEnv.getStream()); - // Get reference to group - os << "const auto *group = &mergedCustomConnectivityUpdatePreInitGroup" << c.getIndex() << "[g]; " << std::endl; - Substitutions popSubs(&funcSubs); - c.generateInit(*this, os, modelMerged, popSubs); - } - } + // Get reference to group + funcEnv.getStream() << "const auto *group = &mergedCustomConnectivityUpdatePreInitGroup" << c.getIndex() << "[g]; " << std::endl; + c.generateInit(*this, funcEnv, modelMerged); + } + }); - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Custom connectivity postsynaptic update groups" << std::endl; - for(const auto &c : modelMerged.getMergedCustomConnectivityUpdatePostInitGroups()) { - CodeStream::Scope b(os); - os << "// merged custom connectivity postsynaptic init group " << c.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; - { - CodeStream::Scope b(os); - - // Get reference to group - os << "const auto *group = &mergedCustomConnectivityUpdatePostInitGroup" << c.getIndex() << "[g]; " << std::endl; - Substitutions popSubs(&funcSubs); - c.generateInit(*this, os, modelMerged, popSubs); - } - } - - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Custom WU update groups" << std::endl; - for(const auto &c : modelMerged.getMergedCustomWUUpdateInitGroups()) { - CodeStream::Scope b(os); - os << "// merged custom WU update group " << c.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + funcEnv.getStream() << "// 
------------------------------------------------------------------------" << std::endl;
+        funcEnv.getStream() << "// Custom connectivity postsynaptic update groups" << std::endl;
+        modelMerged.genMergedCustomConnectivityUpdatePostInitGroups(
+            *this,
+            [this, &funcEnv, &modelMerged](auto &c)
             {
-                CodeStream::Scope b(os);
+                CodeStream::Scope b(funcEnv.getStream());
+                funcEnv.getStream() << "// merged custom connectivity postsynaptic init group " << c.getIndex() << std::endl;
+                funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)";
+                {
+                    CodeStream::Scope b(funcEnv.getStream());

-                // Get reference to group
-                os << "const auto *group = &mergedCustomConnectivityUpdatePostInitGroup" << c.getIndex() << "[g]; " << std::endl;
-                Substitutions popSubs(&funcSubs);
-                c.generateInit(*this, os, modelMerged, popSubs);
-            }
-        }
+                    // Get reference to group
+                    funcEnv.getStream() << "const auto *group = &mergedCustomConnectivityUpdatePostInitGroup" << c.getIndex() << "[g]; " << std::endl;
+                    c.generateInit(*this, funcEnv, modelMerged);
+                }
+            });

-        os << "// ------------------------------------------------------------------------" << std::endl;
-        os << "// Custom WU update groups" << std::endl;
-        for(const auto &c : modelMerged.getMergedCustomWUUpdateInitGroups()) {
-            CodeStream::Scope b(os);
-            os << "// merged custom WU update group " << c.getIndex() << std::endl;
-            os << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)";
+        funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl;
+        funcEnv.getStream() << "// Custom WU update groups" << std::endl;
+        modelMerged.genMergedCustomWUUpdateInitGroups(
+            *this,
+            [this, &funcEnv, &modelMerged](auto &c)
             {
-                CodeStream::Scope b(os);
+                CodeStream::Scope b(funcEnv.getStream());
+                funcEnv.getStream() << "// merged custom WU update group " << c.getIndex() << std::endl;
+                funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)";
+                {
+                    CodeStream::Scope b(funcEnv.getStream());

-                // Get reference to group
-                os << "const auto *group = &mergedCustomWUUpdateInitGroup" << c.getIndex() << "[g]; " << std::endl;
-                Substitutions popSubs(&funcSubs);
-                c.generateInit(*this, os, modelMerged, popSubs);
-            }
-        }
+                    // Get reference to group
+                    funcEnv.getStream() << "const auto *group = &mergedCustomWUUpdateInitGroup" << c.getIndex() << "[g]; " << std::endl;
+                    c.generateInit(*this, funcEnv, modelMerged);
+                }
+            });
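// The functions generated by each of these sections share one shape: an array
// of merged-group structs, one element per population sharing an archetype,
// walked by a plain loop. A sketch of the *generated* output under that
// pattern (struct name and fields are hypothetical, for illustration only):
//
//     struct MergedNeuronInitGroup0
//     {
//         unsigned int numNeurons;    // population size
//         float *V;                   // pointer to state-variable storage
//     };
//
//     static MergedNeuronInitGroup0 mergedNeuronInitGroup0[2];
//
//     void initializeSketch()
//     {
//         for(unsigned int g = 0; g < 2; g++) {
//             const auto *group = &mergedNeuronInitGroup0[g];
//             for(unsigned int i = 0; i < group->numNeurons; i++) {
//                 group->V[i] = 0.0f;  // variable-initialisation code is substituted here
//             }
//         }
//     }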
"0"); - popSubs.addVarSubstitution("num_threads", "1"); - popSubs.addVarSubstitution("num_pre", "group->numSrcNeurons"); - popSubs.addVarSubstitution("num_post", "group->numTrgNeurons"); - } - // Otherwise - else { - assert(!snippet->getColBuildCode().empty()); - - // Loop through target neurons - os << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)"; - - // Configure substitutions - popSubs.addVarSubstitution("id_post", "j"); - popSubs.addVarSubstitution("id_pre_begin", "0"); - popSubs.addVarSubstitution("id_thread", "0"); - popSubs.addVarSubstitution("num_threads", "1"); - popSubs.addVarSubstitution("num_pre", "group->numSrcNeurons"); - popSubs.addVarSubstitution("num_post", "group->numTrgNeurons"); - } + funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; + funcEnv.getStream() << "// Synapse sparse connectivity" << std::endl; + modelMerged.genMergedSynapseConnectivityInitGroups( + *this, + [this, &funcEnv, &modelMerged](auto &c) + { + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged synapse connectivity init group " << s.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); - // Create new stream to generate addSynapse function which initializes all kernel variables - std::ostringstream addSynapseStream; - CodeStream addSynapse(addSynapseStream); + // Get reference to group + funcEnv.getStream() << "const auto *group = &mergedSynapseConnectivityInitGroup" << s.getIndex() << "[g]; " << std::endl; - // Use classic macro trick to turn block of initialization code into statement and 'eat' semicolon - addSynapse << "do"; - { - CodeStream::Scope b(addSynapse); + // If matrix connectivity is ragged + if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { + // Zero row lengths + funcEnv.getStream() << "std::fill_n(group->rowLength, group->numSrcNeurons, 0);" << std::endl; + } + else if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { + funcEnv.getStream() << "const size_t gpSize = ((((size_t)group->numSrcNeurons * (size_t)group->rowStride) + 32 - 1) / 32);" << std::endl; + funcEnv.getStream() << "std::fill(group->gp, gpSize, 0);" << std::endl; + } + else { + throw std::runtime_error("Only BITMASK and SPARSE format connectivity can be generated using a connectivity initialiser"); + } - // Calculate index in data structure of this synapse - if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - if(!snippet->getRowBuildCode().empty()) { - addSynapse << "const unsigned int idx = " << "(" + popSubs["id_pre"] + " * group->rowStride) + group->rowLength[i];" << std::endl; - } - else { - addSynapse << "const unsigned int idx = " << "(($(0)) * group->rowStride) + group->rowLength[$(0)];" << std::endl; - } - } + // If there is row-building code in this snippet + EnvironmentGroupMergedField groupEnv(funcEnv, c); + const auto *snippet = s.getArchetype().getConnectivityInitialiser().getSnippet(); + if(!snippet->getRowBuildCode().empty()) { + // Generate loop through source neurons + groupEnv.getStream() << "for (unsigned int i = 0; i < group->numSrcNeurons; i++)"; + + // Configure substitutions + groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); + groupEnv.add(Type::Uint32.addConst(), "id_post_begin", "0"); + groupEnv.add(Type::Uint32.addConst(), "id_thread", "0"); + groupEnv.add(Type::Uint32.addConst(), 
"num_threads", "1"); + //groupEnv.add("num_pre", "group->numSrcNeurons"); + //groupEnv.add("num_post", "group->numTrgNeurons"); + } + // Otherwise + else { + assert(!snippet->getColBuildCode().empty()); + + // Loop through target neurons + groupEnv.getStream() << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)"; + + // Configure substitutions + groupEnv.add(Type::Uint32.addConst(), "id_post", "j"); + groupEnv.add(Type::Uint32.addConst(), "id_pre_begin", "0"); + groupEnv.add(Type::Uint32.addConst(), "id_thread", "0"); + groupEnv.add(Type::Uint32.addConst(), "num_threads", "1"); + //popSubs.addVarSubstitution("num_pre", "group->numSrcNeurons"); + //popSubs.addVarSubstitution("num_post", "group->numTrgNeurons"); + } + { + CodeStream::Scope b(os); - // If there is a kernel - if(!s.getArchetype().getKernelSize().empty()) { - Substitutions kernelInitSubs(&popSubs); + // Create new stream to generate addSynapse function which initializes all kernel variables + std::ostringstream addSynapseStream; + CodeStream addSynapse(addSynapseStream); - // Replace $(id_post) with first 'function' parameter as simulation code is - // going to be, in turn, substituted into procedural connectivity generation code - if(!snippet->getRowBuildCode().empty()) { - kernelInitSubs.addVarSubstitution("id_post", "$(0)"); - } - else { - kernelInitSubs.addVarSubstitution("id_pre", "$(0)"); - } + // Use classic macro trick to turn block of initialization code into statement and 'eat' semicolon + addSynapse << "do"; + { + CodeStream::Scope b(addSynapse); - // Add index of synapse + // Calculate index in data structure of this synapse if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - kernelInitSubs.addVarSubstitution("id_syn", "idx"); + if(!snippet->getRowBuildCode().empty()) { + addSynapse << "const unsigned int idx = " << "(" + groupEnv["id_pre"] + " * " << groupEnv["_row_stride"] << ") + " << groupEnv["_row_length"] << "[i];" << std::endl; + } + else { + addSynapse << "const unsigned int idx = " << "(($(0)) * " << groupEnv["_row_stride"] << ") + groupEnv["_row_length"][$(0)];" << std::endl; + } } - // Replace kernel indices with the subsequent 'function' parameters - for(size_t i = 0; i < s.getArchetype().getKernelSize().size(); i++) { - kernelInitSubs.addVarSubstitution("id_kernel_" + std::to_string(i), "$(" + std::to_string(i + 1) + ")"); - } + // If there is a kernel + if(!s.getArchetype().getKernelSize().empty()) { + EnvironmentGroupMergedField kernelInitEnv(groupEnv, c); - // Call handler to initialize variables - s.generateKernelInit(*this, addSynapse, modelMerged, kernelInitSubs); - } + // Replace $(id_post) with first 'function' parameter as simulation code is + // going to be, in turn, substituted into procedural connectivity generation code + assert(false); + if(!snippet->getRowBuildCode().empty()) { + kernelInitEnv.add(Type::Uint32.addConst(), "id_post", "$(0)"); + } + else { + kernelInitEnv.add(Type::Uint32.addConst(), "id_pre", "$(0)"); + } - // If there is row-building code in this snippet - if(!snippet->getRowBuildCode().empty()) { - // If matrix is sparse, add function to increment row length and insert synapse into ind array - if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - addSynapse << "group->ind[idx] = $(0);" << std::endl; - addSynapse << "group->rowLength[i]++;" << std::endl; - } - // Otherwise, add function to set correct bit in bitmask - else { - addSynapse << "const int64_t rowStartGID = i * group->rowStride;" << std::endl; - 
addSynapse << "setB(group->gp[(rowStartGID + ($(0))) / 32], (rowStartGID + $(0)) & 31);" << std::endl; + // Add index of synapse + if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { + kernelInitEnv.add(Type::Uint32.addConst(), "id_syn", "idx"); + } + + // Replace kernel indices with the subsequent 'function' parameters + for(size_t i = 0; i < s.getArchetype().getKernelSize().size(); i++) { + kernelInitEnv.add(Type::Uint32.addConst(), "id_kernel_" + std::to_string(i), "$(" + std::to_string(i + 1) + ")"); + } + + // Call handler to initialize variables + s.generateKernelInit(*this, addSynapse, modelMerged, kernelInitSubs); } - } - // Otherwise - else { - // If matrix is sparse, add function to increment row length and insert synapse into ind array - if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - addSynapse << "group->ind[idx] = " << popSubs["id_post"] << ";" << std::endl; - addSynapse << "group->rowLength[$(0)]++;" << std::endl; + + // If there is row-building code in this snippet + if(!snippet->getRowBuildCode().empty()) { + // If matrix is sparse, add function to increment row length and insert synapse into ind array + if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { + addSynapse << groupEnv["_ind"] << "[idx] = $(0);" << std::endl; + addSynapse << groupEnv["_row_length"] << "[i]++;" << std::endl; + } + // Otherwise, add function to set correct bit in bitmask + else { + addSynapse << "const int64_t rowStartGID = i * " << groupEnv["_row_stride"] << ";" << std::endl; + addSynapse << "setB(group->gp[(rowStartGID + ($(0))) / 32], (rowStartGID + $(0)) & 31);" << std::endl; + } } + // Otherwise else { - addSynapse << "const int64_t colStartGID = j;" << std::endl; - addSynapse << "setB(group->gp[(colStartGID + (($(0)) * group->rowStride)) / 32], ((colStartGID + (($(0)) * group->rowStride)) & 31));" << std::endl; + // If matrix is sparse, add function to increment row length and insert synapse into ind array + if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { + addSynapse << groupEnv["_ind"] << "[idx] = " << groupEnv["id_post"] << ";" << std::endl; + addSynapse << groupEnv["_row_length"] << "[$(0)]++;" << std::endl; + } + else { + addSynapse << "const int64_t colStartGID = j;" << std::endl; + addSynapse << "setB(" << groupEnv["_gp"] << "[(colStartGID + (($(0)) * " << groupEnv["_row_stride"] << ")) / 32], ((colStartGID + (($(0)) * " << groupEnv["_row_stride"] << ")) & 31));" << std::endl; + } } } - } - addSynapse << "while(false)"; + addSynapse << "while(false)"; - popSubs.addFuncSubstitution("addSynapse", 1 + (unsigned int)s.getArchetype().getKernelSize().size(), - addSynapseStream.str()); + const auto addSynapseType = Type::ResolvedType::createFunction(Type::Void, std::vector{1ull + s.getArchetype().getKernelSize().size(), Type::Uint32}); + groupEnv.add(addSynapseType, "addSynapse", addSynapseStream.str()); - // Call appropriate connectivity handler - if(!snippet->getRowBuildCode().empty()) { - s.generateSparseRowInit(*this, os, modelMerged, popSubs); - } - else { - s.generateSparseColumnInit(*this, os, modelMerged, popSubs); + // Call appropriate connectivity handler + if(!snippet->getRowBuildCode().empty()) { + s.generateSparseRowInit(*this, groupEnv, modelMerged); + } + else { + s.generateSparseColumnInit(*this, groupEnv, modelMerged); + } } } - } - } + }); } - os << std::endl; - os << "void initializeSparse()"; + initEnv.getStream() << std::endl; + initEnv.getStream() << "void 
initializeSparse()"; { - CodeStream::Scope b(os); - Substitutions funcSubs(getFunctionTemplates(model.getPrecision().getName())); + CodeStream::Scope b(initEnv.getStream()); + EnvironmentExternal funcEnv(initEnv); - Timer t(os, "initSparse", model.isTimingEnabled()); - - // If model requires RNG, add it to substitutions - if(isGlobalHostRNGRequired(modelMerged)) { - funcSubs.addVarSubstitution("rng", "hostRNG"); - } + Timer t(funcEnv.getStream(), "initSparse", model.isTimingEnabled()); - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Synapse groups with sparse connectivity" << std::endl; - for(const auto &s : modelMerged.getMergedSynapseSparseInitGroups()) { - CodeStream::Scope b(os); - os << "// merged sparse synapse init group " << s.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; + funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; + funcEnv.getStream() << "// Synapse groups with sparse connectivity" << std::endl; + modelMerged.genMergedSynapseSparseInitGroups( + *this, + [this, &funcEnv, &modelMerged](auto &s) { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged sparse synapse init group " << s.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; + { + CodeStream::Scope b(funcEnv.getStream()); - // Get reference to group - os << "const auto *group = &mergedSynapseSparseInitGroup" << s.getIndex() << "[g]; " << std::endl; + // Get reference to group + funcEnv.getStream() << "const auto *group = &mergedSynapseSparseInitGroup" << s.getIndex() << "[g]; " << std::endl; - // If postsynaptic learning is required, initially zero column lengths - if (!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { - os << "// Zero column lengths" << std::endl; - os << "std::fill_n(group->colLength, group->numTrgNeurons, 0);" << std::endl; - } + // If postsynaptic learning is required, initially zero column lengths + if (!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { + funcEnv.getStream() << "// Zero column lengths" << std::endl; + funcEnv.getStream() << "std::fill_n(group->colLength, group->numTrgNeurons, 0);" << std::endl; + } - os << "// Loop through presynaptic neurons" << std::endl; - os << "for (unsigned int i = 0; i < group->numSrcNeurons; i++)" << std::endl; - { - CodeStream::Scope b(os); + funcEnv.getStream() << "// Loop through presynaptic neurons" << std::endl; + funcEnv.getStream() << "for (unsigned int i = 0; i < group->numSrcNeurons; i++)" << std::endl; + { + CodeStream::Scope b(funcEnv.getStream()); - // Generate sparse initialisation code - if(s.getArchetype().isWUVarInitRequired()) { - Substitutions popSubs(&funcSubs); - popSubs.addVarSubstitution("id_pre", "i"); - popSubs.addVarSubstitution("row_len", "group->rowLength[i]"); - s.generateInit(*this, os, modelMerged, popSubs); - } + // Generate sparse initialisation code + if(s.getArchetype().isWUVarInitRequired()) { + Substitutions popSubs(&funcSubs); + popSubs.addVarSubstitution("id_pre", "i"); + popSubs.addVarSubstitution("row_len", "group->rowLength[i]"); + s.generateInit(*this, os, modelMerged, popSubs); + } - // If postsynaptic learning is required - if(!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { - os << "// Loop through synapses in corresponding matrix row" << std::endl; - os << "for(unsigned int j 
= 0; j < group->rowLength[i]; j++)" << std::endl; - { - CodeStream::Scope b(os); - - // If postsynaptic learning is required, calculate column length and remapping - if(!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { - os << "// Calculate index of this synapse in the row-major matrix" << std::endl; - os << "const unsigned int rowMajorIndex = (i * group->rowStride) + j;" << std::endl; - os << "// Using this, lookup postsynaptic target" << std::endl; - os << "const unsigned int postIndex = group->ind[rowMajorIndex];" << std::endl; - os << "// From this calculate index of this synapse in the column-major matrix" << std::endl; - os << "const unsigned int colMajorIndex = (postIndex * group->colStride) + group->colLength[postIndex];" << std::endl; - os << "// Increment column length corresponding to this postsynaptic neuron" << std::endl; - os << "group->colLength[postIndex]++;" << std::endl; - os << "// Add remapping entry" << std::endl; - os << "group->remap[colMajorIndex] = rowMajorIndex;" << std::endl; + // If postsynaptic learning is required + if(!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { + os << "// Loop through synapses in corresponding matrix row" << std::endl; + os << "for(unsigned int j = 0; j < group->rowLength[i]; j++)" << std::endl; + { + CodeStream::Scope b(os); + + // If postsynaptic learning is required, calculate column length and remapping + if(!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { + os << "// Calculate index of this synapse in the row-major matrix" << std::endl; + os << "const unsigned int rowMajorIndex = (i * group->rowStride) + j;" << std::endl; + os << "// Using this, lookup postsynaptic target" << std::endl; + os << "const unsigned int postIndex = group->ind[rowMajorIndex];" << std::endl; + os << "// From this calculate index of this synapse in the column-major matrix" << std::endl; + os << "const unsigned int colMajorIndex = (postIndex * group->colStride) + group->colLength[postIndex];" << std::endl; + os << "// Increment column length corresponding to this postsynaptic neuron" << std::endl; + os << "group->colLength[postIndex]++;" << std::endl; + os << "// Add remapping entry" << std::endl; + os << "group->remap[colMajorIndex] = rowMajorIndex;" << std::endl; + } } } } } - } - } + }); - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Custom sparse WU update groups" << std::endl; - for(const auto &c : modelMerged.getMergedCustomWUUpdateSparseInitGroups()) { - CodeStream::Scope b(os); - os << "// merged custom sparse WU update group " << c.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; + funcEnv.getStream() << "// Custom sparse WU update groups" << std::endl; + modelMerged.genMergedCustomWUUpdateSparseInitGroups( + *this, + [this, &funcEnv, &modelMerged](auto &c) { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged custom sparse WU update group " << c.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + { + CodeStream::Scope b(funcEnv.getStream()); - // Get reference to group - os << "const auto *group = &mergedCustomWUUpdateSparseInitGroup" << c.getIndex() << "[g]; " << std::endl; + // Get reference to group + funcEnv.getStream() << "const auto *group = 
&mergedCustomWUUpdateSparseInitGroup" << c.getIndex() << "[g]; " << std::endl; - os << "// Loop through presynaptic neurons" << std::endl; - os << "for (unsigned int i = 0; i < group->numSrcNeurons; i++)" << std::endl; - { - CodeStream::Scope b(os); + os << "// Loop through presynaptic neurons" << std::endl; + os << "for (unsigned int i = 0; i < group->numSrcNeurons; i++)" << std::endl; + { + CodeStream::Scope b(os); - // Generate initialisation code - Substitutions popSubs(&funcSubs); - popSubs.addVarSubstitution("id_pre", "i"); - popSubs.addVarSubstitution("row_len", "group->rowLength[i]"); - c.generateInit(*this, os, modelMerged, popSubs); + // Generate initialisation code + Substitutions popSubs(&funcSubs); + popSubs.addVarSubstitution("id_pre", "i"); + popSubs.addVarSubstitution("row_len", "group->rowLength[i]"); + c.generateInit(*this, os, modelMerged, popSubs); + } } - } - } + }); - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Custom connectivity update sparse init groups" << std::endl; - for(const auto &c : modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups()) { - CodeStream::Scope b(os); - os << "// merged custom connectivity update sparse init group " << c.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; + funcEnv.getStream() << "// Custom connectivity update sparse init groups" << std::endl; + modelMerged.genMergedCustomConnectivityUpdateSparseInitGroups( + *this, + [this, &funcEnv, &modelMerged](auto &c) { CodeStream::Scope b(os); - - // Get reference to group - os << "const auto *group = &mergedCustomConnectivityUpdateSparseInitGroup" << c.getIndex() << "[g]; " << std::endl; - - os << "// Loop through presynaptic neurons" << std::endl; - os << "for (unsigned int i = 0; i < group->numSrcNeurons; i++)" << std::endl; + os << "// merged custom connectivity update sparse init group " << c.getIndex() << std::endl; + os << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; { CodeStream::Scope b(os); - // Generate initialisation code - Substitutions popSubs(&funcSubs); - popSubs.addVarSubstitution("id_pre", "i"); - popSubs.addVarSubstitution("row_len", "group->rowLength[i]"); - c.generateInit(*this, os, modelMerged, popSubs); + // Get reference to group + os << "const auto *group = &mergedCustomConnectivityUpdateSparseInitGroup" << c.getIndex() << "[g]; " << std::endl; + + os << "// Loop through presynaptic neurons" << std::endl; + os << "for (unsigned int i = 0; i < group->numSrcNeurons; i++)" << std::endl; + { + CodeStream::Scope b(os); + + // Generate initialisation code + Substitutions popSubs(&funcSubs); + popSubs.addVarSubstitution("id_pre", "i"); + popSubs.addVarSubstitution("row_len", "group->rowLength[i]"); + c.generateInit(*this, os, modelMerged, popSubs); + } } - } - } + }); } + + + // Generate struct definitions + modelMerged.genMergedNeuronInitGroupStructs(os, *this); + modelMerged.genMergedSynapseInitGroupStructs(os, *this); + modelMerged.genMergedCustomUpdateInitGroupStructs(os, *this); + modelMerged.genMergedCustomWUUpdateInitGroupStructs(os, *this); + modelMerged.genMergedSynapseConnectivityInitGroupStructs(os, *this); + modelMerged.genMergedSynapseSparseInitGroupStructs(os, *this); + modelMerged.genMergedCustomWUUpdateSparseInitGroupStructs(os, *this); + 
modelMerged.genMergedCustomConnectivityUpdatePreInitStructs(os, *this); + modelMerged.genMergedCustomConnectivityUpdatePostInitStructs(os, *this); + modelMerged.genMergedCustomConnectivityUpdateSparseInitStructs(os, *this); + + // Generate arrays of merged structs and functions to set them + genMergedStructArrayPush(os, modelMerged.getMergedNeuronInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedSynapseInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomWUUpdateInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedSynapseConnectivityInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedSynapseSparseInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomWUUpdateSparseInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdatePreInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdatePostInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups()); + + // Generate preamble + preambleHandler(os); + + os << initStream.str(); + } //-------------------------------------------------------------------------- size_t Backend::getSynapticMatrixRowStride(const SynapseGroupInternal &sg) const diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 47d30477eb..2084d69ded 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -195,7 +195,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdateBase(const BackendBase & // Loop through kernel size dimensions for (size_t d = 0; d < getArchetype().getSynapseGroup()->getKernelSize().size(); d++) { // If this dimension has a heterogeneous size, add it to struct - if (isKernelSizeHeterogeneous(d)) { + if (isKernelSizeHeterogeneous(*this, d)) { cuEnv.addField(Type::Uint32, "_kernel_size_" + std::to_string(d), "kernelSize" + std::to_string(d), [d](const auto &cu, size_t) { From 1dd323174116cf31476fb424ceea94e0e3211a44 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 21 Jun 2023 09:58:45 +0100 Subject: [PATCH 241/725] closer to a CPU backend which compiles --- .../backends/single_threaded_cpu/backend.h | 4 +- .../genn/code_generator/modelSpecMerged.h | 16 +- .../backends/single_threaded_cpu/backend.cc | 138 +++++++++--------- 3 files changed, 80 insertions(+), 78 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 7caf718a06..cb8eec02fe 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -186,9 +186,9 @@ class BACKEND_EXPORT Backend : public BackendBase //-------------------------------------------------------------------------- // Private methods //-------------------------------------------------------------------------- - void genPresynapticUpdate(EnvironmentExternalBase &env, const PresynapticUpdateGroupMerged &sg, const ModelSpecMerged &modelMerged, bool trueSpike) const; + void genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, const ModelSpecMerged &modelMerged, bool trueSpike) const; - void genEmitSpike(EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, bool trueSpike, bool recordingEnabled) const; + void 
genEmitSpike(EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng, bool trueSpike, bool recordingEnabled) const; template void genMergedStructArrayPush(CodeStream &os, const std::vector &groups) const diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index d636e71f35..9b31cff712 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -200,7 +200,7 @@ class GENN_EXPORT ModelSpecMerged void genMergedCustomUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, G generateGroup) { createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateGroups, - [&updateGroupName](const CustomUpdateInternal &cg) { return cg.getUpdateGroupName() == updateGroupName(); }, + [&updateGroupName](const CustomUpdateInternal &cg) { return cg.getUpdateGroupName() == updateGroupName; }, &CustomUpdateInternal::getHashDigest, generateGroup); } @@ -210,7 +210,7 @@ class GENN_EXPORT ModelSpecMerged createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomUpdateWUGroups, [&updateGroupName](const CustomUpdateWUInternal &cg) { - return (!cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName()); + return (!cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName); }, &CustomUpdateWUInternal::getHashDigest, generateGroup); } @@ -221,7 +221,7 @@ class GENN_EXPORT ModelSpecMerged createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomUpdateTransposeWUGroups, [&updateGroupName](const CustomUpdateWUInternal &cg) { - return (cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName()); + return (cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName); }, &CustomUpdateWUInternal::getHashDigest, generateGroup); } @@ -232,7 +232,7 @@ class GENN_EXPORT ModelSpecMerged createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateHostReductionGroups, [&updateGroupName](const CustomUpdateInternal &cg) { - return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName()); + return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName); }, &CustomUpdateInternal::getHashDigest, generateGroup, true); } @@ -243,7 +243,7 @@ class GENN_EXPORT ModelSpecMerged createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateHostReductionGroups, [&updateGroupName](const CustomUpdateWUInternal &cg) { - return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName()); + return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName); }, &CustomUpdateWUInternal::getHashDigest, generateGroup, true); } @@ -254,9 +254,9 @@ class GENN_EXPORT ModelSpecMerged createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateGroups, [&updateGroupName](const CustomConnectivityUpdateInternal &cg) { - return (!cg.getCustomConnectivityUpdateModel()->getRowUpdateCode().empty() && cg.getUpdateGroupName() == updateGroupName()); + return (!cg.getCustomConnectivityUpdateModel()->getRowUpdateCode().empty() && cg.getUpdateGroupName() == updateGroupName); }, - &CustomConnectivityUpdateInternal::getHashDigest, genereateGroup); + &CustomConnectivityUpdateInternal::getHashDigest, generateGroup); } template @@ -265,7 +265,7 @@ class GENN_EXPORT ModelSpecMerged createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), 
m_MergedCustomConnectivityHostUpdateGroups, [&updateGroupName](const CustomConnectivityUpdateInternal &cg) { - return (!cg.getCustomConnectivityUpdateModel()->getHostUpdateCode().empty() && cg.getUpdateGroupName() == updateGroupName()); + return (!cg.getCustomConnectivityUpdateModel()->getHostUpdateCode().empty() && cg.getUpdateGroupName() == updateGroupName); }, &CustomConnectivityUpdateInternal::getHashDigest, generateGroup, true); } diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index f9b8b099ee..2c7c0c7b1c 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -264,12 +264,12 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host groupEnv.add(Type::Uint32, "id", "i"); // Add RNG libray - EnvironmentLibrary rngEnv(groupEnv, (modelMerged.getModel().getPrecision() == Type::Float) ? cpuSinglePrecisionFunctions : cpuDoublePrecisionFunctions; + EnvironmentLibrary rngEnv(groupEnv, (modelMerged.getModel().getPrecision() == Type::Float) ? cpuSinglePrecisionFunctions : cpuDoublePrecisionFunctions); // Generate neuron update n.generateNeuronUpdate(*this, rngEnv, modelMerged, // Emit true spikes - [&modelMerged, this](EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng) + [&modelMerged, this](EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng) { // Insert code to update WU vars ng.generateWUVarUpdate(*this, env, modelMerged); @@ -278,7 +278,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host genEmitSpike(env, ng, true, ng.getArchetype().isSpikeRecordingEnabled()); }, // Emit spike-like events - [this](EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng) + [this](EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng) { // Insert code to emit spike-like events genEmitSpike(env, ng, false, ng.getArchetype().isSpikeEventRecordingEnabled()); @@ -302,7 +302,6 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host preambleHandler(os); os << neuronUpdateStream.str(); - } //-------------------------------------------------------------------------- void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const @@ -396,7 +395,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos synEnv.add(Type::AddToPost, "addToPost", synEnv["_out_post"] + "[" + s.getPostISynIndex(1, "j") + "] += $(0)", {}, {"_out_post"}); synEnv.add(Type::AddToPre, "addToPre", synEnv["_out_pre"] + "[" + s.getPreISynIndex(1, synEnv["id_pre"]) + "] += $(0)", - {}, {"id_pre"})); + {}, {"id_pre"}); // Call synapse dynamics handler s.generateSynapseUpdate(*this, synEnv, modelMerged); @@ -411,7 +410,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos Timer t(funcEnv.getStream(), "presynapticUpdate", modelMerged.getModel().isTimingEnabled()); modelMerged.genMergedPresynapticUpdateGroups( *this, - [this, &funcEnv](auto &s) + [this, &funcEnv, &modelMerged](auto &s) { CodeStream::Scope b(funcEnv.getStream()); funcEnv.getStream() << "// merged presynaptic update group " << s.getIndex() << std::endl; @@ -446,7 +445,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos Timer t(funcEnv.getStream(), "postsynapticUpdate", modelMerged.getModel().isTimingEnabled()); modelMerged.genMergedPostsynapticUpdateGroups( *this, - [this, &funcEnv](auto &s) + [this, 
&funcEnv, &modelMerged](auto &s) { CodeStream::Scope b(funcEnv.getStream()); funcEnv.getStream() << "// merged postsynaptic update group " << s.getIndex() << std::endl; @@ -611,7 +610,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host const auto reductionTargets = genInitReductionTargets(groupEnv.getStream(), c); // Loop through group members - groupEnv.getStream() << "for(unsigned int i = 0; i < group->size; i++)"; + groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["size"] << "; i++)"; { CodeStream::Scope b(groupEnv.getStream()); @@ -634,7 +633,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host } else { // Loop through group members - groupEnv.getStream() << "for(unsigned int i = 0; i < group->size; i++)"; + groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["size"] << "; i++)"; { CodeStream::Scope b(groupEnv.getStream()); @@ -682,16 +681,16 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host } else { // Loop through presynaptic neurons - groupEnv.getStream() << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; + groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["num_pre"] << "; i++)"; { // If this synapse group has sparse connectivity, loop through length of this row CodeStream::Scope b(synEnv.getStream()); if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - groupEnv.getStream() << "for(unsigned int s = 0; s < group->rowLength[i]; s++)"; + groupEnv.getStream() << "for(unsigned int s = 0; s < " << groupEnv["_row_length"] << "[i]; s++)"; } // Otherwise, if it's dense, loop through each postsynaptic neuron else if (sg->getMatrixType() & SynapseMatrixConnectivity::DENSE) { - groupEnv.getStream() << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)"; + groupEnv.getStream() << "for (unsigned int j = 0; j < " << groupEnv["size"] << "; j++)"; } else { throw std::runtime_error("Only DENSE and SPARSE format connectivity can be used for custom updates"); @@ -718,7 +717,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host const size_t idSynInit = ; synEnv.addSubstitution("id_syn", "idSyn", - {synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["num_post"] + ") + j;")}, + {synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["num_post"] + ") + j;")}), } // Generate custom update @@ -845,7 +844,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host } //-------------------------------------------------------------------------- -void Backend::genInit(CodeStream &_os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); if(model.getBatchSize() != 1) { @@ -894,7 +893,7 @@ void Backend::genInit(CodeStream &_os, ModelSpecMerged &modelMerged, HostHandler funcEnv.getStream() << "// merged synapse init group " << s.getIndex() << std::endl; funcEnv.getStream() << "for(unsigned int g = 0; g < " << s.getGroups().size() << "; g++)"; { - CodeStream::Scope b(osfuncEnv.getStream(); + CodeStream::Scope b(funcEnv.getStream()); // Get reference to group funcEnv.getStream() << "const auto *group = &mergedSynapseInitGroup" << s.getIndex() << "[g]; " << std::endl; @@ -952,7 +951,7 @@ void Backend::genInit(CodeStream &_os, ModelSpecMerged &modelMerged, HostHandler // Get reference to 
group funcEnv.getStream() << "const auto *group = &mergedCustomConnectivityUpdatePostInitGroup" << c.getIndex() << "[g]; " << std::endl; - c.generateInit(*this, os, modelMerged, popSubs); + c.generateInit(*this, funcEnv, modelMerged); } }); @@ -966,11 +965,11 @@ void Backend::genInit(CodeStream &_os, ModelSpecMerged &modelMerged, HostHandler funcEnv.getStream() << "// merged custom WU update group " << c.getIndex() << std::endl; funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); // Get reference to group funcEnv.getStream() << "const auto *group = &mergedCustomWUUpdateInitGroup" << c.getIndex() << "[g]; " << std::endl; - c.generateInit(*this, os, modelMerged, popSubs); + c.generateInit(*this, funcEnv, modelMerged); } }); @@ -988,26 +987,26 @@ void Backend::genInit(CodeStream &_os, ModelSpecMerged &modelMerged, HostHandler // Get reference to group funcEnv.getStream() << "const auto *group = &mergedSynapseConnectivityInitGroup" << s.getIndex() << "[g]; " << std::endl; + EnvironmentGroupMergedField groupEnv(funcEnv, c); // If matrix connectivity is ragged if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // Zero row lengths - funcEnv.getStream() << "std::fill_n(group->rowLength, group->numSrcNeurons, 0);" << std::endl; + funcEnv.getStream() << "std::fill_n(" << groupEnv["_row_length"] << ", " << groupEnv["num_pre"] << ", 0);" << std::endl; } else if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { - funcEnv.getStream() << "const size_t gpSize = ((((size_t)group->numSrcNeurons * (size_t)group->rowStride) + 32 - 1) / 32);" << std::endl; - funcEnv.getStream() << "std::fill(group->gp, gpSize, 0);" << std::endl; + funcEnv.getStream() << "const size_t gpSize = ((((size_t)" << groupEnv["num_pre"] << " * (size_t)" << groupEnv["_row_stride"] << ") + 32 - 1) / 32);" << std::endl; + funcEnv.getStream() << "std::fill(" << groupEnv["_num_gp"] << ", gpSize, 0);" << std::endl; } else { throw std::runtime_error("Only BITMASK and SPARSE format connectivity can be generated using a connectivity initialiser"); } // If there is row-building code in this snippet - EnvironmentGroupMergedField groupEnv(funcEnv, c); const auto *snippet = s.getArchetype().getConnectivityInitialiser().getSnippet(); if(!snippet->getRowBuildCode().empty()) { // Generate loop through source neurons - groupEnv.getStream() << "for (unsigned int i = 0; i < group->numSrcNeurons; i++)"; + groupEnv.getStream() << "for (unsigned int i = 0; i <" << groupEnv["num_pre"] << "; i++)"; // Configure substitutions groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); @@ -1033,7 +1032,7 @@ void Backend::genInit(CodeStream &_os, ModelSpecMerged &modelMerged, HostHandler //popSubs.addVarSubstitution("num_post", "group->numTrgNeurons"); } { - CodeStream::Scope b(os); + CodeStream::Scope b(groupEnv.getStream()); // Create new stream to generate addSynapse function which initializes all kernel variables std::ostringstream addSynapseStream; @@ -1146,45 +1145,46 @@ void Backend::genInit(CodeStream &_os, ModelSpecMerged &modelMerged, HostHandler // Get reference to group funcEnv.getStream() << "const auto *group = &mergedSynapseSparseInitGroup" << s.getIndex() << "[g]; " << std::endl; + EnvironmentGroupMergedField groupEnv(funcEnv, c); // If postsynaptic learning is required, initially zero column lengths if (!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { - funcEnv.getStream() << "// 
Zero column lengths" << std::endl; - funcEnv.getStream() << "std::fill_n(group->colLength, group->numTrgNeurons, 0);" << std::endl; + groupEnv.getStream() << "// Zero column lengths" << std::endl; + groupEnv.getStream() << "std::fill_n(" << groupEnv["_col_length"] << ", " << groupEnv["num_post"] << ", 0);" << std::endl; } - funcEnv.getStream() << "// Loop through presynaptic neurons" << std::endl; - funcEnv.getStream() << "for (unsigned int i = 0; i < group->numSrcNeurons; i++)" << std::endl; + groupEnv.getStream() << "// Loop through presynaptic neurons" << std::endl; + groupEnv.getStream() << "for (unsigned int i = 0; i < " << groupEnv["num_pre"] << "; i++)" << std::endl; { - CodeStream::Scope b(funcEnv.getStream()); + CodeStream::Scope b(groupEnv.getStream()); // Generate sparse initialisation code + groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); if(s.getArchetype().isWUVarInitRequired()) { - Substitutions popSubs(&funcSubs); - popSubs.addVarSubstitution("id_pre", "i"); - popSubs.addVarSubstitution("row_len", "group->rowLength[i]"); - s.generateInit(*this, os, modelMerged, popSubs); + groupEnv.add(Type::Uint32.addConst(), "row_len", groupEnv["_row_length"] + "[i]", + {"_row_length"}); + s.generateInit(*this, groupEnv, modelMerged); } // If postsynaptic learning is required if(!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { - os << "// Loop through synapses in corresponding matrix row" << std::endl; - os << "for(unsigned int j = 0; j < group->rowLength[i]; j++)" << std::endl; + groupEnv.getStream() << "// Loop through synapses in corresponding matrix row" << std::endl; + groupEnv.getStream() << "for(unsigned int j = 0; j < " << groupEnv["_row_length"] << "[i]; j++)" << std::endl; { - CodeStream::Scope b(os); + CodeStream::Scope b(groupEnv.getStream()); // If postsynaptic learning is required, calculate column length and remapping if(!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { - os << "// Calculate index of this synapse in the row-major matrix" << std::endl; - os << "const unsigned int rowMajorIndex = (i * group->rowStride) + j;" << std::endl; - os << "// Using this, lookup postsynaptic target" << std::endl; - os << "const unsigned int postIndex = group->ind[rowMajorIndex];" << std::endl; - os << "// From this calculate index of this synapse in the column-major matrix" << std::endl; - os << "const unsigned int colMajorIndex = (postIndex * group->colStride) + group->colLength[postIndex];" << std::endl; - os << "// Increment column length corresponding to this postsynaptic neuron" << std::endl; - os << "group->colLength[postIndex]++;" << std::endl; - os << "// Add remapping entry" << std::endl; - os << "group->remap[colMajorIndex] = rowMajorIndex;" << std::endl; + groupEnv.getStream() << "// Calculate index of this synapse in the row-major matrix" << std::endl; + groupEnv.getStream() << "const unsigned int rowMajorIndex = (i * " << groupEnv["_row_stride"] << ") + j;" << std::endl; + groupEnv.getStream() << "// Using this, lookup postsynaptic target" << std::endl; + groupEnv.getStream() << "const unsigned int postIndex = " << groupEnv["_ind"] << "[rowMajorIndex];" << std::endl; + groupEnv.getStream() << "// From this calculate index of this synapse in the column-major matrix" << std::endl; + groupEnv.getStream() << "const unsigned int colMajorIndex = (postIndex * " << groupEnv["_col_stride"] << ") + " << groupEnv["_col_length"] << "[postIndex];" << std::endl; + groupEnv.getStream() << "// Increment column length corresponding to this postsynaptic 
neuron" << std::endl; + groupEnv.getStream() << groupEnv["_col_length"] << "[postIndex]++;" << std::endl; + groupEnv.getStream() << "// Add remapping entry" << std::endl; + groupEnv.getStream() << groupEnv["_remap"] << "p[colMajorIndex] = rowMajorIndex;" << std::endl; } } } @@ -1206,17 +1206,18 @@ void Backend::genInit(CodeStream &_os, ModelSpecMerged &modelMerged, HostHandler // Get reference to group funcEnv.getStream() << "const auto *group = &mergedCustomWUUpdateSparseInitGroup" << c.getIndex() << "[g]; " << std::endl; + EnvironmentGroupMergedField groupEnv(funcEnv, c); - os << "// Loop through presynaptic neurons" << std::endl; - os << "for (unsigned int i = 0; i < group->numSrcNeurons; i++)" << std::endl; + groupEnv.getStream() << "// Loop through presynaptic neurons" << std::endl; + groupEnv.getStream() << "for (unsigned int i = 0; i < " << groupEnv["num_pre"] << "; i++)" << std::endl; { - CodeStream::Scope b(os); + CodeStream::Scope b(groupEnv.getStream()); // Generate initialisation code - Substitutions popSubs(&funcSubs); - popSubs.addVarSubstitution("id_pre", "i"); - popSubs.addVarSubstitution("row_len", "group->rowLength[i]"); - c.generateInit(*this, os, modelMerged, popSubs); + groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); + groupEnv.add(Type::Uint32.addConst(), "row_len", groupEnv["_row_length"] + "[i]", + {"_row_length"}); + c.generateInit(*this, groupEnv, modelMerged); } } }); @@ -1227,25 +1228,26 @@ void Backend::genInit(CodeStream &_os, ModelSpecMerged &modelMerged, HostHandler *this, [this, &funcEnv, &modelMerged](auto &c) { - CodeStream::Scope b(os); - os << "// merged custom connectivity update sparse init group " << c.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; + CodeStream::Scope b(funcEnv.getStream()); + funcEnv.getStream() << "// merged custom connectivity update sparse init group " << c.getIndex() << std::endl; + funcEnv.getStream() << "for(unsigned int g = 0; g < " << c.getGroups().size() << "; g++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(funcEnv.getStream()); // Get reference to group - os << "const auto *group = &mergedCustomConnectivityUpdateSparseInitGroup" << c.getIndex() << "[g]; " << std::endl; + funcEnv.getStream() << "const auto *group = &mergedCustomConnectivityUpdateSparseInitGroup" << c.getIndex() << "[g]; " << std::endl; + EnvironmentGroupMergedField groupEnv(funcEnv, c); - os << "// Loop through presynaptic neurons" << std::endl; - os << "for (unsigned int i = 0; i < group->numSrcNeurons; i++)" << std::endl; + groupEnv.getStream() << "// Loop through presynaptic neurons" << std::endl; + groupEnv.getStream() << "for (unsigned int i = 0; i < " << groupEnv["num_pre"] << "; i++)" << std::endl; { - CodeStream::Scope b(os); + CodeStream::Scope b(groupEnv.getStream()); // Generate initialisation code - Substitutions popSubs(&funcSubs); - popSubs.addVarSubstitution("id_pre", "i"); - popSubs.addVarSubstitution("row_len", "group->rowLength[i]"); - c.generateInit(*this, os, modelMerged, popSubs); + groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); + groupEnv.add(Type::Uint32.addConst(), "row_len", groupEnv["_row_length"] + "[i]", + {"_row_length"}); + c.generateInit(*this, groupEnv, modelMerged); } } }); @@ -1509,7 +1511,7 @@ void Backend::genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, Hand {varEnv.addInitialiser("const unsigned int idSyn = (" + varEnv["id_pre"] + " * " + varEnv["_row_stride"] + ") + j;")}, {"id_pre", "_rowStride"}); varEnv.add(Type::Uint32, 
"id_post", "idPost", - {varEnv.addInitialiser("const unsigned int idPost = (" + varEnv["_ind"] + "[(" + varEnv["id_pre"] + " * " + varEnv["_row_stride"] + ") + j]"); + {varEnv.addInitialiser("const unsigned int idPost = (" + varEnv["_ind"] + "[(" + varEnv["id_pre"] + " * " + varEnv["_row_stride"] + ") + j]")}); handler(varEnv); } } @@ -1727,7 +1729,7 @@ boost::uuids::detail::sha1::digest_type Backend::getHashDigest() const return hash.get_digest(); } //-------------------------------------------------------------------------- -void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, const PresynapticUpdateGroupMerged &sg, const ModelSpecMerged &modelMerged, bool trueSpike) const +void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, const ModelSpecMerged &modelMerged, bool trueSpike) const { // Get suffix based on type of events const std::string eventSuffix = trueSpike ? "" : "Evnt"; @@ -1997,7 +1999,7 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, const Presynapt } } //-------------------------------------------------------------------------- -void Backend::genEmitSpike(EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, bool trueSpike, bool recordingEnabled) const +void Backend::genEmitSpike(EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng, bool trueSpike, bool recordingEnabled) const { // Determine if delay is required and thus, at what offset we should write into the spike queue const bool spikeDelayRequired = trueSpike ? (ng.getArchetype().isDelayRequired() && ng.getArchetype().isTrueSpikeRequired()) : ng.getArchetype().isDelayRequired(); From 0759eb1af824aa254dc3187e022f84dbd1044bfd Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 21 Jun 2023 16:34:45 +0100 Subject: [PATCH 242/725] fixed some typos --- include/genn/genn/code_generator/backendBase.h | 2 +- include/genn/genn/code_generator/modelSpecMerged.h | 2 +- src/genn/backends/single_threaded_cpu/backend.cc | 14 +++++++------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 526fb2f2c9..26df023082 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -195,7 +195,7 @@ class GENN_EXPORT BackendBase using GroupHandler = std::function ; template - using GroupHandlerEnv = std::function ; + using GroupHandlerEnv = std::function ; //! 
Vector of prefixes required to allocate in memory space and size of memory space typedef std::vector> MemorySpaces; diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 9b31cff712..a2b53d0efd 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -554,7 +554,7 @@ class GENN_EXPORT ModelSpecMerged // Add unmerged groups to correct vector for(const auto &g : unmergedGroups) { - protoMergedGroups[std::invoke(g.get(), getHashDigest)].push_back(g); + protoMergedGroups[std::invoke(getHashDigest, g.get())].push_back(g); } // Reserve final merged groups vector diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 2c7c0c7b1c..d202375516 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -428,12 +428,12 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // generate the code for processing spike-like events if (s.getArchetype().isSpikeEventRequired()) { - genPresynapticUpdate(groupEnv, modelMerged, s, false); + genPresynapticUpdate(groupEnv, s, modelMerged, false); } // generate the code for processing true spike events if (s.getArchetype().isTrueSpikeRequired()) { - genPresynapticUpdate(groupEnv, modelMerged, s, true); + genPresynapticUpdate(groupEnv, s, modelMerged, true); } funcEnv.getStream() << std::endl; } @@ -498,19 +498,19 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos const size_t idPreInit = synEnv.addInitialiser("const unsigned int idPre = rowMajorIndex / " + synEnv["_row_stride"] + ";"); // Add presynaptic and synapse index to environment - synEnv.add("id_pre", "idPre", {colMajorIdxInit, rowMajorIdxInit, idPreInit}, {"_col_stride", "_row_stride", "_remap"}); - synEnv.add("id_syn", "rowMajorIndex", {colMajorIdxInit, rowMajorIdxInit}, {"_col_stride", "_remap"}); + synEnv.add(Type::Uint32.addConst(), "id_pre", "idPre", {colMajorIdxInit, rowMajorIdxInit, idPreInit}, {"_col_stride", "_row_stride", "_remap"}); + synEnv.add(Type::Uint32.addConst(), "id_syn", "rowMajorIndex", {colMajorIdxInit, rowMajorIdxInit}, {"_col_stride", "_remap"}); } else { // Add initialiser to calculate synaptic index const size_t idSynInit = groupEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["num_post"] + ") + spike;"); // Add presynaptic and synapse index to environment - synEnv.add(Type::Uint32, "id_pre", "i"); - synEnv.add(Type::Uint32, "id_syn", "idSyn", {idSynInit}, {"num_post"}); + synEnv.add(Type::Uint32.addConst(), "id_pre", "i"); + synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}, {"num_post"}); } - synEnv.add(Type::Uint32, "id_post", "spike"); + synEnv.add(Type::Uint32.addConst(), "id_post", "spike"); synEnv.add(Type::AddToPre, "addToPre", synEnv["_out_pre"] + "[" + s.getPreISynIndex(1, synEnv["id_pre"]) + "] += $(0)"); s.generateSynapseUpdate(*this, synEnv, modelMerged); From b8aff4a0e53ee017d1d5ebe9620632af87c7a4f3 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 22 Jun 2023 12:30:46 +0100 Subject: [PATCH 243/725] more fixing --- .../backends/single_threaded_cpu/backend.h | 10 +- .../genn/genn/code_generator/backendBase.h | 115 ++++++++++-------- .../genn/genn/code_generator/environment.h | 12 +- .../genn/genn/code_generator/groupMerged.h | 6 +- .../genn/code_generator/initGroupMerged.h | 2 + 
.../genn/code_generator/modelSpecMerged.h | 2 +- .../code_generator/neuronUpdateGroupMerged.h | 2 +- .../backends/single_threaded_cpu/backend.cc | 67 +++++----- src/genn/genn/code_generator/groupMerged.cc | 48 -------- .../genn/code_generator/initGroupMerged.cc | 38 ++++++ .../code_generator/neuronUpdateGroupMerged.cc | 2 +- 11 files changed, 156 insertions(+), 148 deletions(-) diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index cb8eec02fe..aa1478e2ef 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -219,21 +219,21 @@ class BACKEND_EXPORT Backend : public BackendBase //! Helper to generate code to copy reduced custom update group variables back to memory /*! Because reduction operations are unnecessary in unbatched single-threaded CPU models so there's no need to actually reduce */ - void genWriteBackReductions(EnvironmentExternal &env, const CustomUpdateGroupMerged &cg, const std::string &idxName) const; + void genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateGroupMerged &cg, const std::string &idxName) const; //! Helper to generate code to copy reduced custom weight update group variables back to memory /*! Because reduction operations are unnecessary in unbatched single-threaded CPU models so there's no need to actually reduce */ - void genWriteBackReductions(EnvironmentExternal &env, const CustomUpdateWUGroupMerged &cg, const std::string &idxName) const; + void genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateWUGroupMerged &cg, const std::string &idxName) const; template - void genWriteBackReductions(EnvironmentExternal &env, const G &cg, const std::string &idxName, R getVarRefIndexFn) const + void genWriteBackReductions(EnvironmentExternalBase &env, G &cg, const std::string &idxName, R getVarRefIndexFn) const { const auto *cm = cg.getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { // If variable is a reduction target, copy value from register straight back into global memory if(v.access & VarAccessModeAttribute::REDUCE) { const std::string idx = env.getName(idxName); - env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(getVarAccessDuplication(v.access), idx) << "] = l" << v.name << ";" << std::endl; + env.getStream() << "group->" << v.name << "[" << cg.getVarIndex(getVarAccessDuplication(v.access), idx) << "] = " << env[v.name] << ";" << std::endl; } } @@ -244,7 +244,7 @@ class BACKEND_EXPORT Backend : public BackendBase // If variable reference is a reduction target, copy value from register straight back into global memory if(modelVarRef.access & VarAccessModeAttribute::REDUCE) { const std::string idx = env.getName(idxName); - env.getStream() << "group->" << modelVarRef.name << "[" << getVarRefIndexFn(varRef, idx) << "] = l" << modelVarRef.name << ";" << std::endl; + env.getStream() << "group->" << modelVarRef.name << "[" << getVarRefIndexFn(varRef, idx) << "] = " << env[modelVarRef.name] << ";" << std::endl; } } } diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 26df023082..61b59b9fa1 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -573,49 +573,49 @@ class GENN_EXPORT BackendBase void genSynapseIndexCalculation(EnvironmentGroupMergedField &env, unsigned int batchSize) const { // Synapse group fields - 
groupEnv.addField(Type::Uint32.addConst(), "num_pre", - Type::Uint32, "numSrcNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); - groupEnv.addField(Type::Uint32.addConst(), "num_post", - Type::Uint32, "numTrgNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); }); - groupEnv.addField(Type::Uint32, "_row_stride", "rowStride", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(getSynapticMatrixRowStride(sg)); }); - groupEnv.addField(Type::Uint32, "_col_stride", "colStride", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getMaxSourceConnections()); }); + env.addField(Type::Uint32.addConst(), "num_pre", + Type::Uint32, "numSrcNeurons", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); + env.addField(Type::Uint32.addConst(), "num_post", + Type::Uint32, "numTrgNeurons", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); }); + env.addField(Type::Uint32, "_row_stride", "rowStride", + [this](const SynapseGroupInternal &sg, size_t) { return std::to_string(getSynapticMatrixRowStride(sg)); }); + env.addField(Type::Uint32, "_col_stride", "colStride", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getMaxSourceConnections()); }); // Postsynaptic model fields - groupEnv.addField(modelMerged.getModel().getPrecision().createPointer(), "_out_post", "outPost", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); - groupEnv.addField(modelMerged.getModel().getPrecision().createPointer(), "_den_delay", "denDelay", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); - groupEnv.addField(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); + env.addField(env.getGroup().getScalarType().createPointer(), "_out_post", "outPost", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); + env.addField(env.getGroup().getScalarType().createPointer(), "_den_delay", "denDelay", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); + env.addField(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); // Presynaptic output fields - groupEnv.addField(modelMerged.getModel().getPrecision().createPointer(), "_out_pre", "outPre", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); + env.addField(env.getGroup().getScalarType().createPointer(), "_out_pre", "outPre", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); // Source neuron fields - groupEnv.addField(Type::Uint32.createPointer(), "_src_spk_que_ptr", "srcSpkQuePtr", - [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getSrcNeuronGroup()->getName(); }); - groupEnv.addField(Type::Uint32.createPointer(), "_src_spk_cnt", "srcSpkCnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getSrcNeuronGroup()->getName(); }); - 
groupEnv.addField(Type::Uint32.createPointer(), "_src_spk", "srcSpk", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getSrcNeuronGroup()->getName(); }); - groupEnv.addField(Type::Uint32.createPointer(), "_src_spk_evnt_cnt", "srcSpkCntEvnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCntEvnt" + g.getSrcNeuronGroup()->getName(); }); - groupEnv.addField(Type::Uint32.createPointer(), "_src_spk_evnt", "srcSpkEvnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkEvnt" + g.getSrcNeuronGroup()->getName(); }); + env.addField(Type::Uint32.createPointer(), "_src_spk_que_ptr", "srcSpkQuePtr", + [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getSrcNeuronGroup()->getName(); }); + env.addField(Type::Uint32.createPointer(), "_src_spk_cnt", "srcSpkCnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getSrcNeuronGroup()->getName(); }); + env.addField(Type::Uint32.createPointer(), "_src_spk", "srcSpk", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getSrcNeuronGroup()->getName(); }); + env.addField(Type::Uint32.createPointer(), "_src_spk_evnt_cnt", "srcSpkCntEvnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCntEvnt" + g.getSrcNeuronGroup()->getName(); }); + env.addField(Type::Uint32.createPointer(), "_src_spk_evnt", "srcSpkEvnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkEvnt" + g.getSrcNeuronGroup()->getName(); }); // Target neuron fields - groupEnv.addField(Type::Uint32.createPointer(), "_trg_spk_que_ptr", "trgSpkQuePtr", - [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getTrgNeuronGroup()->getName(); }); - groupEnv.addField(Type::Uint32.createPointer(), "_trg_spk_cnt", "trgSpkCnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getTrgNeuronGroup()->getName(); }); - groupEnv.addField(Type::Uint32.createPointer(), "_trg_spk", "trgSpk", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getTrgNeuronGroup()->getName(); }); + env.addField(Type::Uint32.createPointer(), "_trg_spk_que_ptr", "trgSpkQuePtr", + [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getTrgNeuronGroup()->getName(); }); + env.addField(Type::Uint32.createPointer(), "_trg_spk_cnt", "trgSpkCnt", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getTrgNeuronGroup()->getName(); }); + env.addField(Type::Uint32.createPointer(), "_trg_spk", "trgSpk", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getTrgNeuronGroup()->getName(); }); // If batching is enabled if(batchSize > 1) { @@ -646,7 +646,7 @@ class GENN_EXPORT BackendBase std::ostringstream kernBatchOffsetInit; kernBatchOffsetInit << "const unsigned int kernBatchOffset = "; for(size_t i = 0; i < kernelSize.size(); i++) { - kernBatchOffsetInit << env.getGroup().getKernelSize(i) << " * "; + kernBatchOffsetInit << getKernelSize(env.getGroup(), i) << " * "; } // And finally by batch @@ -681,9 +681,6 @@ class GENN_EXPORT BackendBase env.add(Type::Uint32, "_pre_batch_delay_slot", "preBatchDelaySlot", {env.addInitialiser("const unsigned int preBatchDelaySlot = preDelaySlot + (batch * " + std::to_string(numSrcDelaySlots) + ");")}, {"_pre_delay_slot"}); - - os << << std::endl; - env.add(Type::Uint32, "_pre_batch_delay_offset", "preBatchDelayOffset", {env.addInitialiser("const unsigned int 
preBatchDelayOffset = preDelayOffset + (preBatchOffset * " + std::to_string(numSrcDelaySlots) + ");")}, {"_pre_delay_offset", "_pre_batch_offset"}); @@ -692,10 +689,15 @@ class GENN_EXPORT BackendBase if(env.getGroup().getArchetype().getWUModel()->isPrevPreSpikeTimeRequired() || env.getGroup().getArchetype().getWUModel()->isPrevPreSpikeEventTimeRequired()) { - os << "const unsigned int prePrevSpikeTimeDelayOffset = " << "((*group->srcSpkQuePtr + " << (numSrcDelaySlots - numDelaySteps - 1) << ") % " << numSrcDelaySlots << ")" << " * group->numSrcNeurons;" << std::endl; + env.add(Type::Uint32, "_pre_prev_spike_time_delay_offset", "prePrevSpikeTimeDelayOffset", + {env.addInitialiser("const unsigned int prePrevSpikeTimeDelayOffset = ((*" + env["_src_spk_que_ptr"] + " + " + + std::to_string(numSrcDelaySlots - numDelaySteps - 1) + ") % " + std::to_string(numSrcDelaySlots) + ") * " + env["num_pre"] + ";")}, + {"_src_spk_que_ptr", "num_pre"}); if(batchSize > 1) { - os << "const unsigned int prePrevSpikeTimeBatchDelayOffset = prePrevSpikeTimeDelayOffset + (preBatchOffset * " << numSrcDelaySlots << ");" << std::endl; + env.add(Type::Uint32, "_pre_prev_spike_time_batch_delay_offset", "prePrevSpikeTimeBatchDelayOffset", + {env.addInitialiser("const unsigned int prePrevSpikeTimeBatchDelayOffset =prePrevSpikeTimeDelayOffset + (" + env["_pre_batch_offset"] + " * " + std::to_string(numSrcDelaySlots) + ");")}, + {"_pre_prev_spike_time_delay_offset", "_pre_batch_offset"}); } } } @@ -705,25 +707,40 @@ class GENN_EXPORT BackendBase const unsigned int numBackPropDelaySteps = env.getGroup().getArchetype().getBackPropDelaySteps(); const unsigned int numTrgDelaySlots = env.getGroup().getArchetype().getTrgNeuronGroup()->getNumDelaySlots(); - os << "const unsigned int postDelaySlot = "; + std::ostringstream postDelaySlotInit; + postDelaySlotInit << "const unsigned int postDelaySlot = "; if(numBackPropDelaySteps == 0) { - os << "*group->trgSpkQuePtr;" << std::endl; + postDelaySlotInit << "*" << env["_trg_spk_que_ptr"] << ";" << std::endl; } else { - os << "(*group->trgSpkQuePtr + " << (numTrgDelaySlots - numBackPropDelaySteps) << ") % " << numTrgDelaySlots << ";" << std::endl; + postDelaySlotInit << "(*" << env["_trg_spk_que_ptr"] << " + " << (numTrgDelaySlots - numBackPropDelaySteps) << ") % " << numTrgDelaySlots << ";" << std::endl; } - os << "const unsigned int postDelayOffset = postDelaySlot * group->numTrgNeurons;" << std::endl; + env.add(Type::Uint32, "_post_delay_slot", "postDelaySlot", + {env.addInitialiser(postDelaySlotInit.str())}, {"_trg_spk_que_ptr"}); + + env.add(Type::Uint32, "_post_delay_offset", "postDelayOffset", + {env.addInitialiser("const unsigned int postDelayOffset = postDelaySlot * " + env["num_post"] + ";")}, + {"num_post", "_post_delay_slot"}); if(batchSize > 1) { - os << "const unsigned int postBatchDelaySlot = postDelaySlot + (batch * " << numTrgDelaySlots << ");" << std::endl; - os << "const unsigned int postBatchDelayOffset = postDelayOffset + (postBatchOffset * " << numTrgDelaySlots << ");" << std::endl; + env.add(Type::Uint32, "_post_batch_delay_slot", "postBatchDelaySlot", + {env.addInitialiser("const unsigned int postBatchDelaySlot = postDelaySlot + (batch * " + std::to_string(numTrgDelaySlots) + ");")}, + {"_post_delay_slot"}); + env.add(Type::Uint32, "_post_batch_delay_offset", "postBatchDelayOffset", + {env.addInitialiser("const unsigned int postBatchDelayOffset = postDelayOffset + (postBatchOffset * " + std::to_string(numTrgDelaySlots) + ");")}, + {"_post_delay_offset", 
"_post_batch_offset"}); } if(env.getGroup().getArchetype().getWUModel()->isPrevPostSpikeTimeRequired()) { - os << "const unsigned int postPrevSpikeTimeDelayOffset = " << "((*group->trgSpkQuePtr + " << (numTrgDelaySlots - numBackPropDelaySteps - 1) << ") % " << numTrgDelaySlots << ")" << " * group->numTrgNeurons;" << std::endl; - + env.add(Type::Uint32, "_post_prev_spike_time_delay_offset", "postPrevSpikeTimeDelayOffset", + {env.addInitialiser("const unsigned int postPrevSpikeTimeDelayOffset = ((*" + env["_trg_spk_que_ptr"] + " + " + + std::to_string(numTrgDelaySlots - numBackPropDelaySteps - 1) + ") % " + std::to_string(numTrgDelaySlots) + ") * " + env["num_post"] + ";")}, + {"_trg_spk_que_ptr", "num_post"}); + if(batchSize > 1) { - os << "const unsigned int postPrevSpikeTimeBatchDelayOffset = postPrevSpikeTimeDelayOffset + (postBatchOffset * " << numTrgDelaySlots << ");" << std::endl; + env.add(Type::Uint32, "_post_prev_spike_time_batch_delay_offset", "postPrevSpikeTimeBatchDelayOffset", + {env.addInitialiser("const unsigned int postPrevSpikeTimeBatchDelayOffset = postPrevSpikeTimeDelayOffset + (" + env["_post_batch_offset"] + " * " + std::to_string(numTrgDelaySlots) + ");")}, + {"_post_prev_spike_time_delay_offset", "_post_batch_offset"}); } } diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 811017506a..2bd78cc992 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -141,9 +141,15 @@ class EnvironmentSubstitutionPolicy template class EnvironmentFieldPolicy { +public: + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + const G &getGroup() const{ return m_Group; } + protected: using Payload = std::tuple>; - + EnvironmentFieldPolicy(G &group, F &fieldGroup) : m_Group(group), m_FieldGroup(fieldGroup) { @@ -187,8 +193,6 @@ class EnvironmentFieldPolicy } } - const G &getGroup() const{ return m_Group; } - private: std::reference_wrapper m_FieldGroup; std::reference_wrapper m_Group; @@ -198,7 +202,7 @@ class EnvironmentFieldPolicy // GeNN::CodeGenerator::EnvironmentExternalDynamicBase //---------------------------------------------------------------------------- template -class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, protected P +class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P { public: template diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 5b133b9182..a198425667 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -456,8 +456,7 @@ class GroupMerged : public ChildGroupMerged class GENN_EXPORT NeuronSpikeQueueUpdateGroupMerged : public GroupMerged { public: - NeuronSpikeQueueUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + using GroupMerged::GroupMerged; //------------------------------------------------------------------------ // Public API @@ -485,8 +484,7 @@ class GENN_EXPORT NeuronSpikeQueueUpdateGroupMerged : public GroupMerged { public: - NeuronPrevSpikeTimeUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + using GroupMerged::GroupMerged; 
//------------------------------------------------------------------------ // Public API diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index 9d6bf52db4..efe63f8931 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -264,6 +264,8 @@ class GENN_EXPORT SynapseInitGroupMerged : public InitGroupMergedBase, SynapseWUVarAdapter> { public: + using InitGroupMergedBase::InitGroupMergedBase; + boost::uuids::detail::sha1::digest_type getHashDigest() const; void generateRunner(const BackendBase &backend, diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index a2b53d0efd..3259caa554 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -564,7 +564,7 @@ class GENN_EXPORT ModelSpecMerged size_t i = 0; for(const auto &p : protoMergedGroups) { // Add group to vector - mergedGroups.emplace_back(i, m_TypeContext, backend, p.second); + mergedGroups.emplace_back(i, m_TypeContext, p.second); generateGroup(mergedGroups.back()); // Loop through fields diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index b8c176846b..7a4e5dd410 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -161,7 +161,7 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase bool isParamReferenced(const std::string ¶mName) const; }; - NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> &groups); //------------------------------------------------------------------------ diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index d202375516..1b969756aa 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -72,33 +72,30 @@ class Timer //----------------------------------------------------------------------- template -void genKernelIteration(EnvironmentExternal &env, const G &g, size_t numKernelDims, std::function/*BackendBase::Handler*/ handler) +void genKernelIteration(EnvironmentExternalBase &env, const G &g, size_t numKernelDims, BackendBase::HandlerEnv handler) { - EnvironmentSubstitute varEnv(env); - // Define recursive function to generate nested kernel initialisation loops // **NOTE** this is a std::function as type of auto lambda couldn't be determined inside for recursive call std::function generateRecursive = - [&handler, &varEnv, &g, &generateRecursive, numKernelDims] + [&handler, &env, &g, &generateRecursive, numKernelDims] (size_t depth) { // Loop through this kernel dimensions const std::string idxVar = "k" + std::to_string(depth); - varEnv.getStream() << "for(unsigned int " << idxVar << " = 0; " << idxVar << " < " << g.getKernelSize(depth) << "; " << idxVar << "++)"; + env.getStream() << "for(unsigned int " << idxVar << " = 0; " << idxVar << " < " << getKernelSize(g, depth) << "; " << idxVar << "++)"; { - CodeStream::Scope b(varEnv.getStream()); - EnvironmentSubstitute loopEnv(varEnv); + CodeStream::Scope b(env.getStream()); + EnvironmentGroupMergedField loopEnv(env, g); // Add substitution for 
this kernel index - loopEnv.addSubstitution("id_kernel_" + std::to_string(depth), idxVar); + loopEnv.add(Type::Uint32.addConst(), "id_kernel_" + std::to_string(depth), idxVar); // If we've recursed through all dimensions if (depth == (numKernelDims - 1)) { // Generate kernel index and use as "synapse" index // **TODO** rename - assert(false); - //const size_t addSynapse = loopEnv.addInitialiser("const unsigned int kernelInd = " + g.genKernelIndex(loopEnv) + ";"); - //loopEnv.addVarSubstitution("id_syn", "kernelInd", addSynapse); + loopEnv.add(Type::Uint32.addConst(), "id_syn", "kernelInd", + {loopEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(g, loopEnv) + ";")}); // Call handler handler(loopEnv); @@ -602,7 +599,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, c); - genCustomUpdateIndexCalculation(groupEnv, c); + genCustomUpdateIndexCalculation(groupEnv); if (c.getArchetype().isNeuronReduction()) { // Initialise reduction targets @@ -616,7 +613,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Generate custom update EnvironmentGroupMergedField memberEnv(groupEnv, c); - memberEnv.addSubstitution("id", "i"); + memberEnv.add(Type::Uint32.addConst(), "id", "i"); c.generateCustomUpdate(*this, memberEnv); // Loop through reduction targets and generate reduction @@ -684,7 +681,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["num_pre"] << "; i++)"; { // If this synapse group has sparse connectivity, loop through length of this row - CodeStream::Scope b(synEnv.getStream()); + CodeStream::Scope b(groupEnv.getStream()); if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { groupEnv.getStream() << "for(unsigned int s = 0; s < " << groupEnv["_row_length"] << "[i]; s++)"; } @@ -715,9 +712,9 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host else { synEnv.add(Type::Uint32.addConst(), "id_post", "j"); - const size_t idSynInit = ; - synEnv.addSubstitution("id_syn", "idSyn", - {synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["num_post"] + ") + j;")}), + synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", + {synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["num_post"] + ") + j;")}, + {"num_post"}); } // Generate custom update @@ -806,16 +803,16 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host groupEnv.add(Type::Uint32, "id_post", "j"); // Add conditional initialisation code to calculate synapse index - groupEnv.addSubstitution(Type::Uint32, "id_syn", "idSyn", - {groupEnv.addInitialiser("const unsigned int idSyn = (i * " + groupEnv["num_post"] + ") + j;")}, - {"num_post"}); + groupEnv.add(Type::Uint32, "id_syn", "idSyn", + {groupEnv.addInitialiser("const unsigned int idSyn = (i * " + groupEnv["num_post"] + ") + j;")}, + {"num_post"}); // Generate custom update - c.generateCustomUpdate(*this, synEnv); + c.generateCustomUpdate(*this, groupEnv); // Update transpose variable // **YUCK** this is sorta outside scope - synEnv.getStream() << groupEnv[transposeVarName + "_transpose"] << "[(j * " << groupEnv["num_pre"] << ") + i] = l" << transposeVarName << ";" << std::endl; + groupEnv.getStream() << groupEnv[transposeVarName + "_transpose"] << "[(j * " << groupEnv["num_pre"] << ") + i] = l" << transposeVarName << ";" << 
std::endl; } } @@ -977,7 +974,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler funcEnv.getStream() << "// Synapse sparse connectivity" << std::endl; modelMerged.genMergedSynapseConnectivityInitGroups( *this, - [this, &funcEnv, &modelMerged](auto &c) + [this, &funcEnv, &modelMerged](auto &s) { CodeStream::Scope b(funcEnv.getStream()); funcEnv.getStream() << "// merged synapse connectivity init group " << s.getIndex() << std::endl; @@ -987,7 +984,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // Get reference to group funcEnv.getStream() << "const auto *group = &mergedSynapseConnectivityInitGroup" << s.getIndex() << "[g]; " << std::endl; - EnvironmentGroupMergedField groupEnv(funcEnv, c); + EnvironmentGroupMergedField groupEnv(funcEnv, s); // If matrix connectivity is ragged if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { @@ -1049,13 +1046,13 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler addSynapse << "const unsigned int idx = " << "(" + groupEnv["id_pre"] + " * " << groupEnv["_row_stride"] << ") + " << groupEnv["_row_length"] << "[i];" << std::endl; } else { - addSynapse << "const unsigned int idx = " << "(($(0)) * " << groupEnv["_row_stride"] << ") + groupEnv["_row_length"][$(0)];" << std::endl; + addSynapse << "const unsigned int idx = " << "(($(0)) * " << groupEnv["_row_stride"] << ") + " << groupEnv["_row_length"] << "[$(0)];" << std::endl; } } // If there is a kernel if(!s.getArchetype().getKernelSize().empty()) { - EnvironmentGroupMergedField kernelInitEnv(groupEnv, c); + EnvironmentGroupMergedField kernelInitEnv(groupEnv, s); // Replace $(id_post) with first 'function' parameter as simulation code is // going to be, in turn, substituted into procedural connectivity generation code @@ -1078,7 +1075,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler } // Call handler to initialize variables - s.generateKernelInit(*this, addSynapse, modelMerged, kernelInitSubs); + s.generateKernelInit(*this, kernelInitEnv, modelMerged); } // If there is row-building code in this snippet @@ -1114,10 +1111,10 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // Call appropriate connectivity handler if(!snippet->getRowBuildCode().empty()) { - s.generateSparseRowInit(*this, groupEnv, modelMerged); + s.generateSparseRowInit(*this, groupEnv); } else { - s.generateSparseColumnInit(*this, groupEnv, modelMerged); + s.generateSparseColumnInit(*this, groupEnv); } } } @@ -1145,7 +1142,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // Get reference to group funcEnv.getStream() << "const auto *group = &mergedSynapseSparseInitGroup" << s.getIndex() << "[g]; " << std::endl; - EnvironmentGroupMergedField groupEnv(funcEnv, c); + EnvironmentGroupMergedField groupEnv(funcEnv, s); // If postsynaptic learning is required, initially zero column lengths if (!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { @@ -1162,7 +1159,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); if(s.getArchetype().isWUVarInitRequired()) { groupEnv.add(Type::Uint32.addConst(), "row_len", groupEnv["_row_length"] + "[i]", - {"_row_length"}); + {}, {"_row_length"}); s.generateInit(*this, groupEnv, modelMerged); } @@ -1216,7 +1213,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // 
Generate initialisation code groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); groupEnv.add(Type::Uint32.addConst(), "row_len", groupEnv["_row_length"] + "[i]", - {"_row_length"}); + {}, {"_row_length"}); c.generateInit(*this, groupEnv, modelMerged); } } @@ -1246,7 +1243,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // Generate initialisation code groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); groupEnv.add(Type::Uint32.addConst(), "row_len", groupEnv["_row_length"] + "[i]", - {"_row_length"}); + {}, {"_row_length"}); c.generateInit(*this, groupEnv, modelMerged); } } @@ -2032,7 +2029,7 @@ void Backend::genEmitSpike(EnvironmentExternalBase &env, NeuronUpdateGroupMerged } } //-------------------------------------------------------------------------- -void Backend::genWriteBackReductions(EnvironmentExternal &env, const CustomUpdateGroupMerged &cg, const std::string &idxName) const +void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateGroupMerged &cg, const std::string &idxName) const { genWriteBackReductions(env, cg, idxName, [&cg](const Models::VarReference &varRef, const std::string &index) @@ -2043,7 +2040,7 @@ void Backend::genWriteBackReductions(EnvironmentExternal &env, const CustomUpdat }); } //-------------------------------------------------------------------------- -void Backend::genWriteBackReductions(EnvironmentExternal &env, const CustomUpdateWUGroupMerged &cg, const std::string &idxName) const +void Backend::genWriteBackReductions(EnvironmentExternalBase &env, CustomUpdateWUGroupMerged &cg, const std::string &idxName) const { genWriteBackReductions(env, cg, idxName, [&cg](const Models::WUVarReference &varRef, const std::string &index) diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index 0c42754cde..250b4c1b42 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -20,23 +20,6 @@ using namespace GeNN::CodeGenerator; //---------------------------------------------------------------------------- const std::string NeuronSpikeQueueUpdateGroupMerged::name = "NeuronSpikeQueueUpdate"; //---------------------------------------------------------------------------- -NeuronSpikeQueueUpdateGroupMerged::NeuronSpikeQueueUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) -: GroupMerged(index, typeContext, groups) -{ - using namespace Type; - - if(getArchetype().isDelayRequired()) { - addPointerField(Uint32, "spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); - } - - addPointerField(Uint32, "spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); - - if(getArchetype().isSpikeEventRequired()) { - addPointerField(Uint32, "spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); - } -} -//---------------------------------------------------------------------------- void NeuronSpikeQueueUpdateGroupMerged::genMergedGroupSpikeCountReset(EnvironmentExternalBase &env, unsigned int batchSize) const { if(getArchetype().isSpikeEventRequired()) { @@ -68,37 +51,6 @@ void NeuronSpikeQueueUpdateGroupMerged::genMergedGroupSpikeCountReset(Environmen // GeNN::CodeGenerator::NeuronPrevSpikeTimeUpdateGroupMerged //---------------------------------------------------------------------------- const std::string NeuronPrevSpikeTimeUpdateGroupMerged::name = "NeuronPrevSpikeTimeUpdate"; 
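// **NOTE** a minimal sketch (group type and names illustrative, not part of the patch)
// of the pattern the hunks above converge on: typed environment entries, whose
// initialiser statements are only emitted if the entry is actually referenced,
// replace the old raw text substitutions:
EnvironmentGroupMergedField<CustomUpdateWUGroupMerged> synEnv(groupEnv, cg);
synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn",
           {synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["num_post"] + ") + j;")},
           {"num_post"});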
-//---------------------------------------------------------------------------- -NeuronPrevSpikeTimeUpdateGroupMerged::NeuronPrevSpikeTimeUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups) -: GroupMerged(index, typeContext, groups) -{ - using namespace Type; - - if(getArchetype().isDelayRequired()) { - addPointerField(Uint32, "spkQuePtr", backend.getScalarAddressPrefix() + "spkQuePtr"); - } - - addPointerField(Uint32, "spkCnt", backend.getDeviceVarPrefix() + "glbSpkCnt"); - - if(getArchetype().isSpikeEventRequired()) { - addPointerField(Uint32, "spkCntEvnt", backend.getDeviceVarPrefix() + "glbSpkCntEvnt"); - } - - if(getArchetype().isPrevSpikeTimeRequired()) { - addPointerField(Uint32, "spk", backend.getDeviceVarPrefix() + "glbSpk"); - addPointerField(getTimeType(), "prevST", backend.getDeviceVarPrefix() + "prevST"); - } - if(getArchetype().isPrevSpikeEventTimeRequired()) { - addPointerField(Uint32, "spkEvnt", backend.getDeviceVarPrefix() + "glbSpkEvnt"); - addPointerField(getTimeType(), "prevSET", backend.getDeviceVarPrefix() + "prevSET"); - } - - if(getArchetype().isDelayRequired()) { - addField(Uint32, "numNeurons", - [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); - } -} //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronGroupMergedBase diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index c20e79d875..fac6caf81b 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -515,6 +515,25 @@ void NeuronInitGroupMerged::genInitSpikeTime(const BackendBase &backend, Environ //---------------------------------------------------------------------------- const std::string SynapseInitGroupMerged::name = "SynapseInit"; //---------------------------------------------------------------------------- +boost::uuids::detail::sha1::digest_type SynapseInitGroupMerged::getHashDigest() const +{ + boost::uuids::detail::sha1 hash; + + // Update hash with standard archetype hash and var init parameters and derived parameters + updateBaseHash(hash); + + // Update hash with archetype's hash digest + Utils::updateHash(getArchetype().getWUInitHashDigest(), hash); + + // Update hash with number of neurons in pre and postsynaptic population + updateHash([](const SynapseGroupInternal &g) { return g.getSrcNeuronGroup()->getNumNeurons(); }, hash); + updateHash([](const SynapseGroupInternal &g) { return g.getTrgNeuronGroup()->getNumNeurons(); }, hash); + updateHash([](const SynapseGroupInternal &g) { return g.getMaxConnections(); }, hash); + updateHash([](const SynapseGroupInternal &g) { return g.getMaxSourceConnections(); }, hash); + + return hash.get_digest(); +} +//---------------------------------------------------------------------------- void SynapseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { // Create environment for group @@ -572,6 +591,25 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, Environmen //---------------------------------------------------------------------------- const std::string SynapseSparseInitGroupMerged::name = "SynapseSparseInit"; //---------------------------------------------------------------------------- +boost::uuids::detail::sha1::digest_type SynapseSparseInitGroupMerged::getHashDigest() const +{ + 
boost::uuids::detail::sha1 hash; + + // Update hash with standard archetype hash and var init parameters and derived parameters + updateBaseHash(hash); + + // Update hash with archetype's hash digest + Utils::updateHash(getArchetype().getWUInitHashDigest(), hash); + + // Update hash with number of neurons in pre and postsynaptic population + updateHash([](const SynapseGroupInternal &g) { return g.getSrcNeuronGroup()->getNumNeurons(); }, hash); + updateHash([](const SynapseGroupInternal &g) { return g.getTrgNeuronGroup()->getNumNeurons(); }, hash); + updateHash([](const SynapseGroupInternal &g) { return g.getMaxConnections(); }, hash); + updateHash([](const SynapseGroupInternal &g) { return g.getMaxSourceConnections(); }, hash); + + return hash.get_digest(); +} +//---------------------------------------------------------------------------- void SynapseSparseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { genInitWUVarCode(backend, env, *this, env["num_pre"] + " * " + env["_row_stride"], modelMerged.getModel().getBatchSize(), diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 46558b40df..6dd61d467d 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -380,7 +380,7 @@ bool NeuronUpdateGroupMerged::OutSynWUMPreCode::isDerivedParamHeterogeneous( con //---------------------------------------------------------------------------- const std::string NeuronUpdateGroupMerged::name = "NeuronUpdate"; //---------------------------------------------------------------------------- -NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, +NeuronUpdateGroupMerged::NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> &groups) : NeuronGroupMergedBase(index, typeContext, groups) { From e348ef4d7719292b35ac21939ee0a899bfb77948 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 22 Jun 2023 14:29:15 +0100 Subject: [PATCH 244/725] started updating custom connectivity update group merged --- .../genn/code_generator/customConnectivityUpdateGroupMerged.h | 4 ++-- .../code_generator/customConnectivityUpdateGroupMerged.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index 6ed6483cac..dcbe480269 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -49,7 +49,7 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public CustomConnectivit runnerVarDecl, runnerMergedStructAlloc, name); } - void generateUpdate(const BackendBase &backend, CodeStream &os, unsigned int batchSize, Substitutions &popSubs) const; + void generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const; //! Get sorted vector of variable names, types and duplication modes which //! 
need updating when synapses are added and removed, belonging to archetype group @@ -89,7 +89,7 @@ class GENN_EXPORT CustomConnectivityHostUpdateGroupMerged : public CustomConnect runnerVarDecl, runnerMergedStructAlloc, name, true); } - void generateUpdate(const BackendBase &backend, CodeStream &os) const; + void generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const; //---------------------------------------------------------------------------- // Static constants diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 8df17d8b7c..8362ca4a66 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -204,7 +204,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdateGroupMerged::get return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &backend, CodeStream &os, unsigned int batchSize, Substitutions &popSubs) const +void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const { Substitutions updateSubs(&popSubs); From 2e893d56d63fcf4bb5a226a158197a71e1888626 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 22 Jun 2023 18:19:09 +0100 Subject: [PATCH 245/725] first go at adding some custom syntax for_each_synapse{} --- .../genn/genn/code_generator/codeGenUtils.h | 7 +- include/genn/genn/transpiler/prettyPrinter.h | 6 +- include/genn/genn/transpiler/statement.h | 20 +++ include/genn/genn/transpiler/token.h | 2 +- include/genn/genn/transpiler/typeChecker.h | 13 +- src/genn/genn/code_generator/codeGenUtils.cc | 13 +- .../customConnectivityUpdateGroupMerged.cc | 120 +++++++++++------- src/genn/genn/transpiler/parser.cc | 9 +- src/genn/genn/transpiler/prettyPrinter.cc | 32 ++++- src/genn/genn/transpiler/scanner.cc | 1 + src/genn/genn/transpiler/typeChecker.cc | 43 +++++-- 11 files changed, 195 insertions(+), 71 deletions(-) diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 357d8e92fc..252cfbc466 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -108,7 +108,8 @@ GENN_EXPORT std::string upgradeCodeString(const std::string &codeString); */ //-------------------------------------------------------------------------- GENN_EXPORT std::tuple scanParseAndTypeCheckStatements( - const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment, Transpiler::ErrorHandlerBase &errorHandler); + const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment, + Transpiler::ErrorHandlerBase &errorHandler, Transpiler::TypeChecker::StatementHandler forEachSynapseHandler = nullptr); //-------------------------------------------------------------------------- /*! \brief This function uses the transpiler to scan, parse and type check expression contained in a code string @@ -128,7 +129,9 @@ GENN_EXPORT void prettyPrintExpression(const std::string &code, const Type::Type /*! 
\brief This function uses the transpiler to scan, parse, type check and pretty print statements contained in a code string */
 //--------------------------------------------------------------------------
-GENN_EXPORT void prettyPrintStatements(const std::string &code, const Type::TypeContext &typeContext, EnvironmentExternalBase &env, Transpiler::ErrorHandlerBase &errorHandler);
+GENN_EXPORT void prettyPrintStatements(const std::string &code, const Type::TypeContext &typeContext, EnvironmentExternalBase &env,
+                                       Transpiler::ErrorHandlerBase &errorHandler, Transpiler::TypeChecker::StatementHandler forEachSynapseTypeCheckHandler = nullptr,
+                                       Transpiler::PrettyPrinter::StatementHandler forEachSynapsePrettyPrintHandler = nullptr);
 //-------------------------------------------------------------------------
 /*!
diff --git a/include/genn/genn/transpiler/prettyPrinter.h b/include/genn/genn/transpiler/prettyPrinter.h
index 456868471b..2d02332316 100644
--- a/include/genn/genn/transpiler/prettyPrinter.h
+++ b/include/genn/genn/transpiler/prettyPrinter.h
@@ -1,6 +1,7 @@
 #pragma once

 // Standard C++ includes
+#include <functional>
 #include <string>

 // GeNN includes
@@ -37,11 +38,14 @@ class EnvironmentBase
     virtual CodeGenerator::CodeStream &getStream() = 0;
 };

+typedef std::function<void(EnvironmentBase&, std::function<void()>)> StatementHandler;
+
 //---------------------------------------------------------------------------
 // Free functions
 //---------------------------------------------------------------------------
 void print(const Statement::StatementList &statements, EnvironmentBase &environment,
-           const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes);
+           const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes,
+           StatementHandler forEachSynapseHandler = nullptr);
 void print(const Expression::ExpressionPtr &expression, EnvironmentBase &environment,
            const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes);
 }
diff --git a/include/genn/genn/transpiler/statement.h b/include/genn/genn/transpiler/statement.h
index 97a171b813..eb8ae5c412 100644
--- a/include/genn/genn/transpiler/statement.h
+++ b/include/genn/genn/transpiler/statement.h
@@ -19,6 +19,7 @@ class Continue;
 class Do;
 class Expression;
 class For;
+class ForEachSynapse;
 class If;
 class Labelled;
 class Switch;
@@ -40,6 +41,7 @@ class Visitor
     virtual void visit(const Do &doStatement) = 0;
     virtual void visit(const Expression &expression) = 0;
     virtual void visit(const For &forStatement) = 0;
+    virtual void visit(const ForEachSynapse &forEachSynapseStatement) = 0;
     virtual void visit(const If &ifStatement) = 0;
     virtual void visit(const Labelled &labelled) = 0;
     virtual void visit(const Switch &switchStatement) = 0;
@@ -180,6 +182,24 @@ class For : public Acceptable
     StatementPtr m_Body;
 };

+//---------------------------------------------------------------------------
+// GeNN::Transpiler::Statement::ForEachSynapse
+//---------------------------------------------------------------------------
+class ForEachSynapse : public Acceptable
+{
+public:
+    ForEachSynapse(Token forEachSynapse, StatementPtr body)
+    :   m_ForEachSynapse(forEachSynapse), m_Body(std::move(body))
+    {}
+
+    const Token &getForEachSynapse() const { return m_ForEachSynapse; }
+    const Base *getBody() const { return m_Body.get(); }
+
+private:
+    Token m_ForEachSynapse;
+    StatementPtr m_Body;
+};
+
 //---------------------------------------------------------------------------
 // GeNN::Transpiler::Statement::If
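 //---------------------------------------------------------------------------

// **NOTE** for reference, a sketch of the user-facing syntax this AST node enables in
// custom connectivity update row-update code; 'g' and its threshold are illustrative,
// while remove_synapse is the function substitution defined elsewhere in this patch:
//
//     for_each_synapse {
//         // id_post and id_syn are only defined within this scope
//         if (g < 0.0001) {
//             remove_synapse();
//         }
//     }

 //---------------------------------------------------------------------------
 // GeNN::Transpiler::Statement::If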
 //---------------------------------------------------------------------------
diff --git a/include/genn/genn/transpiler/token.h b/include/genn/genn/transpiler/token.h
index 66b3cbcf13..e7d66cbf5f 100644
--- a/include/genn/genn/transpiler/token.h
+++ b/include/genn/genn/transpiler/token.h
@@ -44,7 +44,7 @@ struct Token
         TYPE_QUALIFIER,

         // Keywords
-        DO, ELSE, FALSE, FOR, IF, TRUE, WHILE, SWITCH, CONTINUE, BREAK, CASE, DEFAULT,
+        DO, ELSE, FALSE, FOR, FOR_EACH_SYNAPSE, IF, TRUE, WHILE, SWITCH, CONTINUE, BREAK, CASE, DEFAULT,
         END_OF_FILE,
     };
diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h
index c073641d35..3e7800b741 100644
--- a/include/genn/genn/transpiler/typeChecker.h
+++ b/include/genn/genn/transpiler/typeChecker.h
@@ -1,6 +1,7 @@
 #pragma once

 // Standard C++ includes
+#include <functional>
 #include <stdexcept>
 #include <string>
 #include <unordered_map>
@@ -32,8 +33,6 @@ class TypeCheckError : public std::runtime_error
     }
 };

-typedef std::unordered_map<const Expression::Base*, Type::ResolvedType> ResolvedTypeMap;
-
 //---------------------------------------------------------------------------
 // GeNN::Transpiler::TypeChecker::EnvironmentBase
 //---------------------------------------------------------------------------
@@ -52,12 +51,18 @@ class EnvironmentBase
     Type::ResolvedType getType(const Token &name, ErrorHandlerBase &errorHandler);
 };

+//---------------------------------------------------------------------------
+// Typedefines
+//---------------------------------------------------------------------------
+typedef std::unordered_map<const Expression::Base*, Type::ResolvedType> ResolvedTypeMap;
+typedef std::function<void(EnvironmentBase&, ErrorHandlerBase&)> StatementHandler;
+
 //---------------------------------------------------------------------------
 // Free functions
 //---------------------------------------------------------------------------
 ResolvedTypeMap typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment,
-                          ErrorHandlerBase &errorHandler);
+                          ErrorHandlerBase &errorHandler, StatementHandler forEachSynapseHandler = nullptr);
 ResolvedTypeMap typeCheck(const Expression::Base *expression, EnvironmentBase &environment,
                           ErrorHandlerBase &errorHandler);
-} // namespace MiniParse::GeNN::Transpiler
+} // namespace GeNN::Transpiler::TypeChecker
diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc
index 3d0b7ea25a..ec4c310d43 100644
--- a/src/genn/genn/code_generator/codeGenUtils.cc
+++ b/src/genn/genn/code_generator/codeGenUtils.cc
@@ -487,7 +487,8 @@ std::string upgradeCodeString(const std::string &codeString)
 }
 //----------------------------------------------------------------------------
 std::tuple<Transpiler::Statement::StatementList, Transpiler::TypeChecker::ResolvedTypeMap> scanParseAndTypeCheckStatements(
-    const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment, Transpiler::ErrorHandlerBase &errorHandler)
+    const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment,
+    Transpiler::ErrorHandlerBase &errorHandler, Transpiler::TypeChecker::StatementHandler forEachSynapseHandler)
 {
     using namespace Transpiler;
@@ -501,7 +502,7 @@ std::tuple(expressionTypes), env, typeContext, std::get<1>(expressionTypes));
 }
 //--------------------------------------------------------------------------
-void prettyPrintStatements(const std::string &code, const Type::TypeContext &typeContext, EnvironmentExternalBase &env, Transpiler::ErrorHandlerBase &errorHandler)
+void prettyPrintStatements(const std::string &code, const Type::TypeContext &typeContext, EnvironmentExternalBase &env,
+                           Transpiler::ErrorHandlerBase &errorHandler, Transpiler::TypeChecker::StatementHandler forEachSynapseTypeCheckHandler,
+                           Transpiler::PrettyPrinter::StatementHandler forEachSynapsePrettyPrintHandler)
 {
     // Scan, parse and type check statements
-    auto statementTypes = scanParseAndTypeCheckStatements(code, typeContext, env, errorHandler);
+    auto statementTypes = scanParseAndTypeCheckStatements(code, typeContext, env, errorHandler, forEachSynapseTypeCheckHandler);

     // Pretty print
-    Transpiler::PrettyPrinter::print(std::get<0>(statementTypes), env, typeContext, std::get<1>(statementTypes));
+    Transpiler::PrettyPrinter::print(std::get<0>(statementTypes), env, typeContext, std::get<1>(statementTypes), forEachSynapsePrettyPrintHandler);
 }
 } // namespace GeNN::CodeGenerator
diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc
index 8362ca4a66..05a48304cc 100644
--- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc
@@ -6,6 +6,9 @@
 // GeNN code generator includes
 #include "code_generator/modelSpecMerged.h"

+// GeNN transpiler includes
+#include "transpiler/errorHandler.h"
+
 using namespace GeNN;
 using namespace GeNN::CodeGenerator;

@@ -102,7 +105,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t
                             }));

-    addField(Uint32, "rowStride",
+    /*addField(Uint32, "rowStride",
              [&backend](const auto &cg, size_t)
              {
                  const SynapseGroupInternal *sgInternal = static_cast<const SynapseGroupInternal*>(cg.getSynapseGroup());
@@ -173,7 +176,7 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t
                      const auto &varRef = m_SortedDependentVars[g][i];
                      return backend.getDeviceVarPrefix() + varRef.getVar().name + varRef.getTargetName();
                  });
-    }
+    }*/
 }
 //----------------------------------------------------------------------------
 boost::uuids::detail::sha1::digest_type CustomConnectivityUpdateGroupMerged::getHashDigest() const
@@ -206,40 +209,49 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdateGroupMerged::get
 //----------------------------------------------------------------------------
 void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const
 {
-    Substitutions updateSubs(&popSubs);
-
-    // Add substitutions for number of pre and postsynaptic neurons
-    updateSubs.addVarSubstitution("num_pre", "group->numSrcNeurons");
-    updateSubs.addVarSubstitution("num_post", "group->numTrgNeurons");
-
-    // Define synapse loop function
-    // **NOTE** index is signed integer so generated code can safely use j-- to process same synapse again
-    // **YUCK** ideally id_post, id_syn, remove_synapse and all synaptic and postsynaptic variable substitutions would only be allowable within this scope but that's not currently possible
-    updateSubs.addFuncSubstitution("for_each_synapse", 1, "for(int j = 0; j < group->rowLength[" + updateSubs["id_pre"] + "]; j++){ const unsigned int idx = rowStartIdx + j; $(0) }");
-
-    updateSubs.addVarSubstitution("id_post", "group->ind[rowStartIdx + j]");
-    updateSubs.addVarSubstitution("id_syn", "idx");
-
-    // Get variables which will need to be manipulated when adding and removing synapses
+    // Create new environment to add custom connectivity update fields
+    EnvironmentGroupMergedField<CustomConnectivityUpdateGroupMerged> updateEnv(env, *this);
+
+    // Add fields for number of pre and postsynaptic neurons
updateEnv.addField(Type::Uint32.addConst(), "num_pre", + Type::Uint32, "numSrcNeurons", + [](const auto &cg, size_t) + { + const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); + return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons()); + }); + updateEnv.addField(Type::Uint32.addConst(), "num_post", + Type::Uint32, "numTrgNeurons", + [](const auto &cg, size_t) + { + const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); + return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons()); + }); + + // Substitute parameter and derived parameter names const auto *cm = getArchetype().getCustomConnectivityUpdateModel(); + updateEnv.addParams(cm->getParamNames(), "", &CustomConnectivityUpdateInternal::getParams, + &CustomConnectivityUpdateGroupMerged::isParamHeterogeneous); + updateEnv.addDerivedParams(cm->getDerivedParams(), "", &CustomConnectivityUpdateInternal::getDerivedParams, + &CustomConnectivityUpdateGroupMerged::isDerivedParamHeterogeneous); + updateEnv.addExtraGlobalParams(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); + + // Get variables which will need to be manipulated when adding and removing synapses const auto &ccuVars = cm->getVars(); const auto &ccuVarRefs = cm->getVarRefs(); const auto &dependentVars = getSortedArchetypeDependentVars(); - // Determine if any - const bool modelBatched = (batchSize > 1); - const bool anyBatched = (modelBatched && (std::any_of(getArchetype().getVarReferences().cbegin(), getArchetype().getVarReferences().cend(), - [](const auto &v){ return v.second.isDuplicated(); }) - || std::any_of(dependentVars.cbegin(), dependentVars.cend(), - [](const auto &v){ return v.isDuplicated(); }))); - // Calculate index of start of row - os << "const unsigned int rowStartIdx = " << updateSubs["id_pre"] << " * group->rowStride;" << std::endl; + updateEnv.add(Type::Uint32.addConst(), "_row_start_idx", "rowStartIdx", + {updateEnv.addInitialiser("const unsigned int rowStartIdx = " + updateEnv["id_pre"] + " * " + updateEnv["_row_stride"] + ";")}, + {"id_pre", "_row_stride"}); - // If any variables are batched - if (anyBatched) { - os << "const unsigned int synStride = group->numSrcNeurons * group->rowStride;" << std::endl; - } + updateEnv.add(Type::Uint32.addConst(), "_syn_stride", "synStride", + {updateEnv.addInitialiser("const unsigned int synStride = " + updateEnv["num_pre"] + " * " + updateEnv["_row_stride"] + ";")}, + {"num_pre", "_row_stride"}); + + std::vector addSynapseTypes{Type::Uint32}; + addSynapseTypes.reserve(1 + ccuVars.size() + ccuVarRefs.size() + dependentVars.size()); // Generate code to add a synapse to this row std::stringstream addSynapseStream; @@ -248,27 +260,28 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back CodeStream::Scope b(addSynapse); // Assert that there is space to add synapse - backend.genAssert(addSynapse, "group->rowLength[" + updateSubs["id_pre"] + "] < group->rowStride"); + backend.genAssert(addSynapse, updateEnv["_row_length"] + "[" + updateEnv["id_pre"] + "] < " + updateEnv["_row_stride"]); // Calculate index to insert synapse - addSynapse << "const unsigned newIdx = rowStartIdx + group->rowLength[" << updateSubs["id_pre"] << "];" << std::endl; + addSynapse << "const unsigned newIdx = rowStartIdx + " + updateEnv["_row_length"] + "[" << updateEnv["id_pre"] << "];" << std::endl; // Set postsynaptic target to parameter 0 - addSynapse << "group->ind[newIdx] = $(0);" << std::endl; + addSynapse << updateEnv["_ind"] + "[newIdx] = 
$(0);" << std::endl; // Use subsequent parameters to initialise new synapse's custom connectivity update model variables for (size_t i = 0; i < ccuVars.size(); i++) { addSynapse << "group->" << ccuVars[i].name << "[newIdx] = $(" << (1 + i) << ");" << std::endl; + addSynapseTypes.push_back(ccuVars[i].type.resolve(getTypeContext())); } // Use subsequent parameters to initialise new synapse's variables referenced via the custom connectivity update for (size_t i = 0; i < ccuVarRefs.size(); i++) { // If model is batched and this variable is duplicated - if ((batchSize > 1) && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) + if ((modelMerged.getModel().getBatchSize() > 1) && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) { // Copy parameter into a register (just incase it's e.g. a RNG call) and copy into all batches addSynapse << "const " << ccuVarRefs[i].type.resolve(getTypeContext()).getName() << " _" << ccuVarRefs[i].name << "Val = $(" << (1 + ccuVars.size() + i) << ");" << std::endl; - addSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; + addSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { CodeStream::Scope b(addSynapse); addSynapse << "group->" << ccuVarRefs[i].name << "[(b * synStride) + newIdx] = _" << ccuVarRefs[i].name << "Val;" << std::endl; @@ -278,15 +291,17 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back else { addSynapse << "group->" << ccuVarRefs[i].name << "[newIdx] = $(" << (1 + ccuVars.size() + i) << ");" << std::endl; } + + addSynapseTypes.push_back(ccuVarRefs[i].type.resolve(getTypeContext())); } // Loop through any other dependent variables for (size_t i = 0; i < dependentVars.size(); i++) { // If model is batched and this dependent variable is duplicated - if ((batchSize > 1) && dependentVars.at(i).isDuplicated()) + if ((modelMerged.getModel().getBatchSize() > 1) && dependentVars.at(i).isDuplicated()) { // Loop through all batches and zero - addSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; + addSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { CodeStream::Scope b(addSynapse); addSynapse << "group->_dependentVar" << i << "[(b * synStride) + newIdx] = 0;" << std::endl; @@ -296,15 +311,17 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back else { addSynapse << "group->_dependentVar" << i << "[newIdx] = 0;" << std::endl; } + + addSynapseTypes.push_back(dependentVars.at(i).getVar().type.resolve(getTypeContext())); } // Increment row length // **NOTE** this will also effect any forEachSynapse loop currently in operation - addSynapse << "group->rowLength[" << updateSubs["id_pre"] << "]++;" << std::endl; + addSynapse << updateEnv["_row_length"] + "[" << updateEnv["id_pre"] << "]++;" << std::endl; } // Add function substitution with parameters to initialise custom connectivity update variables and variable references - updateSubs.addFuncSubstitution("add_synapse", 1 + ccuVars.size() + ccuVarRefs.size(), addSynapseStream.str()); + updateEnv.add(Type::ResolvedType::createFunction(Type::Void, addSynapseTypes), "add_synapse", addSynapseStream.str()); // Generate code to remove a synapse from this row std::stringstream removeSynapseStream; @@ -313,7 +330,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back CodeStream::Scope b(removeSynapse); // Calculate index we want to copy synapse from - removeSynapse << "const unsigned lastIdx = 
rowStartIdx + group->rowLength[" << updateSubs["id_pre"] << "] - 1;" << std::endl; + removeSynapse << "const unsigned lastIdx = rowStartIdx + " + updateEnv["_row_length"] + "[" << updateEnv["id_pre"] << "] - 1;" << std::endl; // Copy postsynaptic target from end of row over synapse to be deleted removeSynapse << "group->ind[idx] = group->ind[lastIdx];" << std::endl; @@ -416,11 +433,26 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back "", "group->"); updateSubs.addVarNameSubstitution(cm->getExtraGlobalParams(), "", "group->"); - // Apply substitutons to row update code and write out - std::string code = cm->getRowUpdateCode(); - updateSubs.applyCheckUnreplaced(code, "custom connectivity update : merged" + std::to_string(getIndex())); - //code = ensureFtype(code, Type::modelMerged.getModel().getPrecision()); - os << code; + // Pretty print code back to environment + Transpiler::ErrorHandler errorHandler("Current source injection" + std::to_string(getIndex())); + prettyPrintStatements(cm->getRowUpdateCode(), getTypeContext(), updateEnv, errorHandler, + // Within for_each_synapse loops, define the following types + [](auto &env, auto &errorHandler) + { + env.define(Transpiler::Token{Transpiler::Token::Type::IDENTIFIER, "id_post", 0}, Type::Uint32.addConst(), errorHandler); + env.define(Transpiler::Token{Transpiler::Token::Type::IDENTIFIER, "id_syn", 0}, Type::Uint32.addConst(), errorHandler); + }, + [](auto &env, auto generateBody) + { + env.getStream() << "for(int j = 0; j < " << env["_row_length"] + "[" + env["id_pre"] + "]; j++)"; + { + CodeStream::Scope b(env.getStream()); + env.getStream() << "const unsigned int idx = rowStartIdx + j;" << std::endl; + //pdateSubs.addVarSubstitution("id_post", "group->ind[rowStartIdx + j]"); + //updateSubs.addVarSubstitution("id_syn", "idx"); + generateBody(); + } + }); } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 36fb3f7dd3..c43c60f694 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -631,6 +631,7 @@ Statement::StatementPtr parseIterationStatement(ParserState &parserState) // iteration-statement ::= // "while" "(" expression ")" statement // "do" statement "while" "(" expression ")" ";" + // "for" statement // "for" "(" expression? ";" expression? ";" expression? ")" statement // "for" "(" declaration expression? ";" expression? 
")" statement @@ -655,6 +656,12 @@ Statement::StatementPtr parseIterationStatement(ParserState &parserState) return std::make_unique(std::move(condition), std::move(body)); } + // Otherwise, if this is a for_each_synapse statement + else if(parserState.previous().type == Token::Type::FOR_EACH_SYNAPSE) { + auto body = parseStatement(parserState); + return std::make_unique(parserState.previous(), + std::move(body)); + } // Otherwise, it's a for statement else { parserState.consume(Token::Type::LEFT_PAREN, "Expect '(' after 'for'"); @@ -687,7 +694,7 @@ Statement::StatementPtr parseIterationStatement(ParserState &parserState) } parserState.consume(Token::Type::RIGHT_PAREN, "Expect ')' after for clauses"); - Statement::StatementPtr body = parseStatement(parserState); + auto body = parseStatement(parserState); // Return for statement // **NOTE** we could "de-sugar" into a while statement but this makes pretty-printing easier diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index 2f0a4fa597..8897cf8114 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -123,8 +123,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { public: Visitor(const Statement::StatementList &statements, EnvironmentInternal &environment, - const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes) - : m_Environment(environment), m_Context(context), m_ResolvedTypes(resolvedTypes) + const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes, + StatementHandler forEachSynapseHandler) + : m_Environment(environment), m_Context(context), m_ResolvedTypes(resolvedTypes), m_ForEachSynapseHandler(forEachSynapseHandler) { for(auto &s : statements) { s.get()->accept(*this); @@ -134,7 +135,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor Visitor(const Expression::ExpressionPtr &expression, EnvironmentInternal &environment, const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes) - : m_Environment(environment), m_Context(context), m_ResolvedTypes(resolvedTypes) + : m_Environment(environment), m_Context(context), m_ResolvedTypes(resolvedTypes) , m_ForEachSynapseHandler(nullptr) { expression.get()->accept(*this); } @@ -402,6 +403,25 @@ class Visitor : public Expression::Visitor, public Statement::Visitor m_Environment = oldEnvironment; } + virtual void visit(const Statement::ForEachSynapse &forEachSynapseStatement) final + { + // Cache reference to current reference + std::reference_wrapper oldEnvironment = m_Environment; + + // Create new environment and set to current + EnvironmentInternal environment(m_Environment); + m_Environment = environment; + + m_ForEachSynapseHandler(m_Environment, + [this, &forEachSynapseStatement]() + { + forEachSynapseStatement.getBody()->accept(*this); + }); + // Restore old environment + m_Environment = oldEnvironment; + } + + virtual void visit(const Statement::If &ifStatement) final { m_Environment.get().getStream() << "if("; @@ -462,6 +482,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor std::reference_wrapper m_Environment; const Type::TypeContext &m_Context; const TypeChecker::ResolvedTypeMap &m_ResolvedTypes; + StatementHandler m_ForEachSynapseHandler; std::stack> m_CallArguments; }; } // Anonymous namespace @@ -470,10 +491,11 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // GeNN::Transpiler::PrettyPrinter 
//--------------------------------------------------------------------------- void GeNN::Transpiler::PrettyPrinter::print(const Statement::StatementList &statements, EnvironmentBase &environment, - const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes) + const Type::TypeContext &context, const TypeChecker::ResolvedTypeMap &resolvedTypes, + StatementHandler forEachSynapseHandler) { EnvironmentInternal internalEnvironment(environment); - Visitor visitor(statements, internalEnvironment, context, resolvedTypes); + Visitor visitor(statements, internalEnvironment, context, resolvedTypes, forEachSynapseHandler); } //--------------------------------------------------------------------------- void GeNN::Transpiler::PrettyPrinter::print(const Expression::ExpressionPtr &expression, EnvironmentBase &environment, diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc index cc36196fa2..7cfa95480e 100644 --- a/src/genn/genn/transpiler/scanner.cc +++ b/src/genn/genn/transpiler/scanner.cc @@ -28,6 +28,7 @@ const std::unordered_map keywords{ {"else", Token::Type::ELSE}, {"false", Token::Type::FALSE}, {"for", Token::Type::FOR}, + {"for_each_synapse", Token::Type::FOR_EACH_SYNAPSE}, {"if", Token::Type::IF}, {"true", Token::Type::TRUE}, {"while", Token::Type::WHILE}, diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index e83ca1f5e8..75a54f88bb 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -176,8 +176,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { public: Visitor(const Statement::StatementList &statements, EnvironmentInternal &environment, - ResolvedTypeMap &resolvedTypes, ErrorHandlerBase &errorHandler) - : Visitor(environment, resolvedTypes, errorHandler) + ResolvedTypeMap &resolvedTypes, ErrorHandlerBase &errorHandler, StatementHandler forEachSynapseHandler) + : Visitor(environment, resolvedTypes, errorHandler, forEachSynapseHandler) { for (auto &s : statements) { s.get()->accept(*this); @@ -186,15 +186,15 @@ class Visitor : public Expression::Visitor, public Statement::Visitor Visitor(const Expression::Base *expression, EnvironmentInternal &environment, ResolvedTypeMap &resolvedTypes, ErrorHandlerBase &errorHandler) - : Visitor(environment, resolvedTypes, errorHandler) + : Visitor(environment, resolvedTypes, errorHandler, nullptr) { expression->accept(*this); } private: - Visitor(EnvironmentInternal &environment, - ResolvedTypeMap &resolvedTypes, ErrorHandlerBase &errorHandler) - : m_Environment(environment), m_ErrorHandler(errorHandler), + Visitor(EnvironmentInternal &environment, ResolvedTypeMap &resolvedTypes, + ErrorHandlerBase &errorHandler, StatementHandler forEachSynapseHandler) + : m_Environment(environment), m_ErrorHandler(errorHandler), m_ForEachSynapseHandler(forEachSynapseHandler), m_ResolvedTypes(resolvedTypes), m_InLoop(false), m_InSwitch(false) { } @@ -740,6 +740,31 @@ class Visitor : public Expression::Visitor, public Statement::Visitor m_Environment = oldEnvironment; } + virtual void visit(const Statement::ForEachSynapse &forEachSynapseStatement) final + { + if(!m_ForEachSynapseHandler) { + m_ErrorHandler.error(forEachSynapseStatement.getForEachSynapse(), + "Not supported in this context"); + throw TypeCheckError(); + } + // Cache reference to current reference + std::reference_wrapper oldEnvironment = m_Environment; + + // Create new environment and set to current + EnvironmentInternal 
environment(m_Environment); + m_Environment = environment; + + // Call handler to define anything required in environment + m_ForEachSynapseHandler(m_Environment, m_ErrorHandler); + + m_InLoop = true; + forEachSynapseStatement.getBody()->accept(*this); + m_InLoop = false; + + // Restore old environment + m_Environment = oldEnvironment; + } + virtual void visit(const Statement::If &ifStatement) final { ifStatement.getCondition()->accept(*this); @@ -830,6 +855,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- std::reference_wrapper m_Environment; ErrorHandlerBase &m_ErrorHandler; + StatementHandler m_ForEachSynapseHandler; ResolvedTypeMap &m_ResolvedTypes; std::stack> m_CallArguments; bool m_InLoop; @@ -856,11 +882,12 @@ Type::ResolvedType EnvironmentBase::getType(const Token &name, ErrorHandlerBase // GeNN::Transpiler::TypeChecker //--------------------------------------------------------------------------- ResolvedTypeMap GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment, - ErrorHandlerBase &errorHandler) + ErrorHandlerBase &errorHandler, StatementHandler forEachSynapseHandler) { ResolvedTypeMap expressionTypes; EnvironmentInternal internalEnvironment(environment); - Visitor visitor(statements, internalEnvironment, expressionTypes, errorHandler); + Visitor visitor(statements, internalEnvironment, expressionTypes, errorHandler, + forEachSynapseHandler); return expressionTypes; } //--------------------------------------------------------------------------- From c682f9e991ead89f264aaf738d773807428b4cc1 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 23 Jun 2023 15:23:52 +0100 Subject: [PATCH 246/725] missing include --- include/genn/genn/code_generator/codeGenUtils.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 252cfbc466..8b8b2fec99 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -22,6 +22,7 @@ #include "teeStream.h" // GeNN transpiler includes +#include "transpiler/prettyPrinter.h" #include "transpiler/statement.h" #include "transpiler/typeChecker.h" From a790c559328575bb520007df50143d859ceb4c7b Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 23 Jun 2023 15:24:50 +0100 Subject: [PATCH 247/725] fixed typo --- src/genn/genn/code_generator/backendBase.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 746f2d0ecd..093a8ba22b 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -65,7 +65,7 @@ void BackendBase::genCustomUpdateIndexCalculation(EnvironmentGroupMergedFieldgetNumDelaySlots()); + const std::string numDelaySlotsStr = std::to_string(env.getGroup().getArchetype().getDelayNeuronGroup()->getNumDelaySlots()); env.add(Type::Uint32.addConst(), "_batch_delay_slot", "batchDelaySlot", {env.addInitialiser("const unsigned int batchDelaySlot = (batch * " + numDelaySlotsStr + ") + delaySlot;")}, {"_delay_slot"}); From 673410effa775cc2ce97a80d6b4e1e1c6348efbf Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 23 Jun 2023 15:25:06 +0100 Subject: [PATCH 248/725] simple lazy string class for initialisers --- include/genn/genn/code_generator/lazyString.h | 73 
+++++++++++++++++++
 src/genn/genn/code_generator/lazyString.cc    | 37 ++++++++++
 src/genn/genn/genn.vcxproj                    |  2 +
 3 files changed, 112 insertions(+)
 create mode 100644 include/genn/genn/code_generator/lazyString.h
 create mode 100644 src/genn/genn/code_generator/lazyString.cc

diff --git a/include/genn/genn/code_generator/lazyString.h b/include/genn/genn/code_generator/lazyString.h
new file mode 100644
index 0000000000..1e8561bf98
--- /dev/null
+++ b/include/genn/genn/code_generator/lazyString.h
@@ -0,0 +1,73 @@
+#pragma once
+
+// Standard C++ includes
+#include <string>
+#include <variant>
+#include <vector>
+
+// Forward declarations
+namespace GeNN::CodeGenerator
+{
+class EnvironmentExternalBase;
+}
+
+//----------------------------------------------------------------------------
+// GeNN::CodeGenerator::LazyString
+//----------------------------------------------------------------------------
+//! String class whose environment lookups are deferred until str() is called by the code generator
+namespace GeNN::CodeGenerator
+{
+class LazyString
+{
+public:
+    typedef std::variant<std::string, std::pair<std::reference_wrapper<EnvironmentExternalBase>, std::string>> Element;
+    typedef std::vector<Element> Payload;
+
+    LazyString(const std::string &str) : m_Payload{str}
+    {}
+    LazyString(const char *chr) : m_Payload{chr}
+    {}
+    LazyString(EnvironmentExternalBase &env, const std::string &name) : m_Payload{std::make_pair(std::ref(env), name)}
+    {}
+
+    //----------------------------------------------------------------------------
+    // Public API
+    //----------------------------------------------------------------------------
+    //! Evaluate lazy string
+    std::string str() const;
+
+
+private:
+    LazyString(const Payload &payload) : m_Payload(payload){}
+
+    //----------------------------------------------------------------------------
+    // Friends
+    //----------------------------------------------------------------------------
+    friend LazyString operator + (const LazyString& lhs, const LazyString &rhs);
+
+    //----------------------------------------------------------------------------
+    // Members
+    //----------------------------------------------------------------------------
+    Payload m_Payload;
+};
+
+//----------------------------------------------------------------------------
+// Operators
+//----------------------------------------------------------------------------
+inline LazyString operator + (const LazyString& lhs, const LazyString &rhs)
+{
+    std::vector<LazyString::Element> payload(lhs.m_Payload);
+    payload.insert(payload.end(), rhs.m_Payload.cbegin(), rhs.m_Payload.cend());
+    return LazyString(payload);
+}
+
+inline LazyString operator + (const char *lhs, const LazyString &rhs)
+{
+    return LazyString(lhs) + rhs;
+}
+
+inline LazyString operator + (const LazyString &lhs, const char *rhs)
+{
+    return lhs + LazyString(rhs);
+}
+} // namespace GeNN::CodeGenerator
\ No newline at end of file
diff --git a/src/genn/genn/code_generator/lazyString.cc b/src/genn/genn/code_generator/lazyString.cc
new file mode 100644
index 0000000000..7b8d8ce563
--- /dev/null
+++ b/src/genn/genn/code_generator/lazyString.cc
@@ -0,0 +1,37 @@
+#include "code_generator/lazyString.h"
+
+// Standard C++ includes
+#include <sstream>
+
+// GeNN includes
+#include "gennUtils.h"
+
+// GeNN code generator includes
+#include "code_generator/environment.h"
+
+using namespace GeNN;
+using namespace GeNN::CodeGenerator;
+
+//----------------------------------------------------------------------------
+// GeNN::CodeGenerator::LazyString
+//----------------------------------------------------------------------------
+std::string LazyString::str() const
+{
+    
std::ostringstream stream; + for(const auto &e : m_Payload) + { + std::visit( + Utils::Overload{ + [&stream](const std::string &str) + { + stream << str; + }, + [&stream](const std::pair, std::string> &env) + { + stream << env.first.get()[env.second]; + }}, + e); + } + return stream.str(); +} + \ No newline at end of file diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj index 6db5c9083f..f5f836a3ec 100644 --- a/src/genn/genn/genn.vcxproj +++ b/src/genn/genn/genn.vcxproj @@ -33,6 +33,7 @@ + @@ -87,6 +88,7 @@ + From b321da4595f34a9465a9e3d2eee9eb9b99089ff8 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 23 Jun 2023 15:25:47 +0100 Subject: [PATCH 249/725] var reference adaptors for custom connectivity update --- .../genn/customConnectivityUpdateInternal.h | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/include/genn/genn/customConnectivityUpdateInternal.h b/include/genn/genn/customConnectivityUpdateInternal.h index 15febd6e23..3c7aacaea1 100644 --- a/include/genn/genn/customConnectivityUpdateInternal.h +++ b/include/genn/genn/customConnectivityUpdateInternal.h @@ -141,6 +141,75 @@ class CustomConnectivityUpdateEGPAdapter Snippet::Base::EGPVec getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getExtraGlobalParams(); } +private: + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + const CustomConnectivityUpdateInternal &m_CU; +}; + +//---------------------------------------------------------------------------- +// CustomConnectivityUpdateVarRefAdapter +//---------------------------------------------------------------------------- +class CustomConnectivityUpdateVarRefAdapter +{ +public: + CustomConnectivityUpdateVarRefAdapter(const CustomConnectivityUpdateInternal &cu) : m_CU(cu) + {} + + //---------------------------------------------------------------------------- + // Public methods + //---------------------------------------------------------------------------- + Models::Base::VarRefVec getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getVarRefs(); } + + const std::unordered_map &getInitialisers() const{ return m_CU.getVarReferences(); } + +private: + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + const CustomConnectivityUpdateInternal &m_CU; +}; + +//---------------------------------------------------------------------------- +// CustomConnectivityUpdatePreVarRefAdapter +//---------------------------------------------------------------------------- +class CustomConnectivityUpdatePreVarRefAdapter +{ +public: + CustomConnectivityUpdatePreVarRefAdapter(const CustomConnectivityUpdateInternal &cu) : m_CU(cu) + {} + + //---------------------------------------------------------------------------- + // Public methods + //---------------------------------------------------------------------------- + Models::Base::VarRefVec getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getPreVarRefs(); } + + const std::unordered_map &getInitialisers() const{ return m_CU.getPreVarReferences(); } + +private: + //---------------------------------------------------------------------------- + // Members + //---------------------------------------------------------------------------- + const CustomConnectivityUpdateInternal &m_CU; +}; + 
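// **NOTE** these adapters expose a common duck-typed getDefs()/getInitialisers()
// interface; a minimal sketch (helper name and body illustrative, not part of the
// patch) of how generic code can consume any of them:
template<typename Adapter>
void forEachVarRef(const CustomConnectivityUpdateInternal &cu)
{
    Adapter adapter(cu);
    for(const auto &v : adapter.getDefs()) {
        // Look up the reference bound to each declared variable reference
        const auto &varRef = adapter.getInitialisers().at(v.name);
        // ...e.g. build a field name from the variable and its target population...
        const std::string fieldName = v.name + varRef.getTargetName();
    }
}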
+//---------------------------------------------------------------------------- +// CustomConnectivityUpdatePostVarRefAdapter +//---------------------------------------------------------------------------- +class CustomConnectivityUpdatePostVarRefAdapter +{ +public: + CustomConnectivityUpdatePostVarRefAdapter(const CustomConnectivityUpdateInternal &cu) : m_CU(cu) + {} + + //---------------------------------------------------------------------------- + // Public methods + //---------------------------------------------------------------------------- + Models::Base::VarRefVec getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getPostVarRefs(); } + + const std::unordered_map &getInitialisers() const{ return m_CU.getPostVarReferences(); } + private: //---------------------------------------------------------------------------- // Members From 869d1577b7dcfbad7499059d33e90eda3f60377a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 23 Jun 2023 15:26:31 +0100 Subject: [PATCH 250/725] * add slightly hacky additional constructor to EnvironmentExternalBase so it can be created with only a pretty printing environment above it * also switched initialisers to use LazyString --- .../genn/genn/code_generator/environment.h | 49 +++++++++-- src/genn/genn/code_generator/environment.cc | 81 +++++-------------- 2 files changed, 63 insertions(+), 67 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 2bd78cc992..c856d29cc0 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -13,6 +13,7 @@ // GeNN code generator includes #include "code_generator/codeStream.h" #include "code_generator/groupMerged.h" +#include "code_generator/lazyString.h" // GeNN transpiler includes #include "transpiler/prettyPrinter.h" @@ -36,7 +37,17 @@ class EnvironmentExternalBase : public Transpiler::PrettyPrinter::EnvironmentBas { public: explicit EnvironmentExternalBase(EnvironmentExternalBase &enclosing) - : m_Context(enclosing) + : m_Context(std::make_pair(&enclosing, &enclosing)) + { + } + + explicit EnvironmentExternalBase(Transpiler::PrettyPrinter::EnvironmentBase &enclosing) + : m_Context(std::make_pair(nullptr, &enclosing)) + { + } + + explicit EnvironmentExternalBase(Transpiler::TypeChecker::EnvironmentBase &enclosing) + : m_Context(std::make_pair(&enclosing, nullptr)) { } @@ -83,7 +94,8 @@ class EnvironmentExternalBase : public Transpiler::PrettyPrinter::EnvironmentBas //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ - std::variant, std::reference_wrapper> m_Context; + std::variant, + std::reference_wrapper> m_Context; }; //---------------------------------------------------------------------------- @@ -94,11 +106,21 @@ class EnvironmentLibrary : public EnvironmentExternalBase public: using Library = std::unordered_multimap>; - EnvironmentLibrary(EnvironmentExternalBase &enclosing, const Library &library) + explicit EnvironmentLibrary(EnvironmentExternalBase &enclosing, const Library &library) : EnvironmentExternalBase(enclosing), m_Library(library) {} - EnvironmentLibrary(CodeStream &os, const Library &library) + explicit EnvironmentLibrary(Transpiler::PrettyPrinter::EnvironmentBase &enclosing, const Library &library) + : EnvironmentExternalBase(enclosing), m_Library(library) + { + } + + explicit EnvironmentLibrary(Transpiler::TypeChecker::EnvironmentBase 
&enclosing, const Library &library) + : EnvironmentExternalBase(enclosing), m_Library(library) + { + } + + explicit EnvironmentLibrary(CodeStream &os, const Library &library) : EnvironmentExternalBase(os), m_Library(library) {} @@ -210,6 +232,16 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P : EnvironmentExternalBase(enclosing), P(std::forward(policyArgs)...) {} + template + EnvironmentExternalDynamicBase(Transpiler::PrettyPrinter::EnvironmentBase &enclosing, PolicyArgs&&... policyArgs) + : EnvironmentExternalBase(enclosing), P(std::forward(policyArgs)...) + {} + + template + EnvironmentExternalDynamicBase(Transpiler::TypeChecker::EnvironmentBase &enclosing, PolicyArgs&&... policyArgs) + : EnvironmentExternalBase(enclosing), P(std::forward(policyArgs)...) + {} + template EnvironmentExternalDynamicBase(CodeStream &os, PolicyArgs&&... policyArgs) : EnvironmentExternalBase(os), P(std::forward(policyArgs)...) @@ -221,7 +253,7 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P for(const auto &i : m_Initialisers) { // If variable requiring initialiser has been referenced, write out initialiser if (i.first) { - getContextStream() << i.second << std::endl; + getContextStream() << i.second.str() << std::endl; } } @@ -280,7 +312,7 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P } - size_t addInitialiser(const std::string &initialiser) + size_t addInitialiser(const LazyString &initialiser) { m_Initialisers.emplace_back(false, initialiser); return (m_Initialisers.size() - 1); @@ -307,7 +339,7 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P CodeStream m_Contents; std::unordered_map, std::vector, typename P::Payload>> m_Environment; - std::vector> m_Initialisers; + std::vector> m_Initialisers; }; //---------------------------------------------------------------------------- @@ -594,7 +626,8 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase enclosing)->CodeStream& { return enclosing.get().getStream(); }, - [](std::reference_wrapper os)->CodeStream& { return os.get(); }}, + [](std::pair enclosing)->CodeStream& + { + assert(enclosing.second != nullptr); + return enclosing.second->getStream(); + }, + [](std::reference_wrapper os)->CodeStream& + { + return os.get(); + }}, m_Context); } //---------------------------------------------------------------------------- @@ -43,8 +50,15 @@ std::string EnvironmentExternalBase::getContextName(const std::string &name, std { return std::visit( Utils::Overload{ - [&name, type](std::reference_wrapper enclosing)->std::string { return enclosing.get().getName(name, type); }, - [&name](std::reference_wrapper)->std::string { throw std::runtime_error("Identifier '" + name + "' undefined"); }}, + [&name, &type](std::pair enclosing)->std::string + { + assert(enclosing.second != nullptr); + return enclosing.second->getName(name, type); + }, + [&name](std::reference_wrapper)->std::string + { + throw std::runtime_error("Identifier '" + name + "' undefined"); + }}, m_Context); } //---------------------------------------------------------------------------- @@ -52,9 +66,10 @@ std::vector EnvironmentExternalBase::getContextTypes(const T { return std::visit( Utils::Overload{ - [&errorHandler, &name](std::reference_wrapper enclosing)->std::vector - { - return enclosing.get().getTypes(name, errorHandler); + [&errorHandler, &name](std::pair enclosing)->std::vector + { + assert(enclosing.first != nullptr); + return 
enclosing.first->getTypes(name, errorHandler); }, [&errorHandler, &name](std::reference_wrapper)->std::vector { @@ -103,56 +118,4 @@ std::string EnvironmentLibrary::getName(const std::string &name, std::optional type) -{ - // If there isn't a substitution for this name, try and get name from context - auto var = m_VarSubstitutions.find(name); - if(var == m_VarSubstitutions.end()) { - return getContextName(name, type); - } - // Otherwise, return substitution - else { - // If this variable relies on any initialiser statements, mark these initialisers as required - for(const auto i : var->second.second) { - m_Initialisers.at(i).first = true; - } - - return var->second.first; - } -} -//------------------------------------------------------------------------ -void EnvironmentSubstitute::addSubstitution(const std::string &source, const std::string &destination, - std::vector initialisers) -{ - assert(std::all_of(initialisers.cbegin(), initialisers.cend(), - [this](size_t i) { return i < m_Initialisers.size(); })); - - if(!m_VarSubstitutions.try_emplace(source, destination, initialisers).second) { - throw std::runtime_error("Redeclaration of substitution '" + source + "'"); - } -} -//------------------------------------------------------------------------ -size_t EnvironmentSubstitute::addInitialiser(const std::string &initialiser) -{ - m_Initialisers.emplace_back(false, initialiser); - return (m_Initialisers.size() - 1); } \ No newline at end of file From db97467ea82cc65ac8a3ac53962d8698916bfeaa Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 23 Jun 2023 15:26:49 +0100 Subject: [PATCH 251/725] more custom connectivity update hacking --- .../customConnectivityUpdateGroupMerged.cc | 163 ++++++++++-------- 1 file changed, 91 insertions(+), 72 deletions(-) diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 05a48304cc..31963b66fa 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -236,10 +236,22 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back &CustomConnectivityUpdateGroupMerged::isDerivedParamHeterogeneous); updateEnv.addExtraGlobalParams(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); - // Get variables which will need to be manipulated when adding and removing synapses - const auto &ccuVars = cm->getVars(); - const auto &ccuVarRefs = cm->getVarRefs(); - const auto &dependentVars = getSortedArchetypeDependentVars(); + // Add presynaptic variables and variable references + // **TODO** var references to batched variables should be private + // **THINK** what about batched pre var references? 
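(Aside: the addVarRefs() lambda added just below encodes a single indexing rule: a presynaptic variable reference whose target neuron group is delayed is read through "[_pre_delay_offset + id_pre]", and through "[id_pre]" otherwise. A standalone sketch of that rule, with hypothetical names; only the two index forms are taken from the patch:

    #include <string>

    // Choose the array-index suffix for a presynaptic variable reference,
    // depending on whether the referenced neuron group is delayed
    std::string preVarRefIndex(bool delayed, const std::string &delayOffset,
                               const std::string &idPre)
    {
        return delayed ? ("[" + delayOffset + " + " + idPre + "]")
                       : ("[" + idPre + "]");
    }
    // e.g. preVarRefIndex(ref.getDelayNeuronGroup() != nullptr,
    //                     "preDelayOffset", "idPre")
)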
+ updateEnv.addVars(backend.getDeviceVarPrefix(), updateEnv["id_pre"], "", + {"id_pre"}); + updateEnv.addVarRefs(backend.getDeviceVarPrefix(), + [&updateEnv](VarAccessMode, const Models::VarReference &v) + { + if(v.getDelayNeuronGroup() != nullptr) { + return "[" + updateSubs["_pre_delay_offset"] + " + " + updateSubs["id_pre"] + "]"; + } + else { + return "[" + updateSubs["id_pre"] + "]"; + } + }, "", + {"id_pre"}); // Calculate index of start of row updateEnv.add(Type::Uint32.addConst(), "_row_start_idx", "rowStartIdx", @@ -250,6 +262,10 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back {updateEnv.addInitialiser("const unsigned int synStride = " + updateEnv["num_pre"] + " * " + updateEnv["_row_stride"] + ";")}, {"num_pre", "_row_stride"}); + // Get variables which will need to be manipulated when adding and removing synapses + const auto ccuVars = cm->getVars(); + const auto ccuVarRefs = cm->getVarRefs(); + const auto &dependentVars = getSortedArchetypeDependentVars(); std::vector addSynapseTypes{Type::Uint32}; addSynapseTypes.reserve(1 + ccuVars.size() + ccuVarRefs.size() + dependentVars.size()); @@ -263,7 +279,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back backend.genAssert(addSynapse, updateEnv["_row_length"] + "[" + updateEnv["id_pre"] + "] < " + updateEnv["_row_stride"]); // Calculate index to insert synapse - addSynapse << "const unsigned newIdx = rowStartIdx + " + updateEnv["_row_length"] + "[" << updateEnv["id_pre"] << "];" << std::endl; + addSynapse << "const unsigned newIdx = " + updateEnv["_row_start_idx"] + " + " + updateEnv["_row_length"] + "[" << updateEnv["id_pre"] << "];" << std::endl; // Set postsynaptic target to parameter 0 addSynapse << updateEnv["_ind"] + "[newIdx] = $(0);" << std::endl; @@ -284,7 +300,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back addSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { CodeStream::Scope b(addSynapse); - addSynapse << "group->" << ccuVarRefs[i].name << "[(b * synStride) + newIdx] = _" << ccuVarRefs[i].name << "Val;" << std::endl; + addSynapse << "group->" << ccuVarRefs[i].name << "[(b * " << updateEnv["_syn_stride"] << ") + newIdx] = _" << ccuVarRefs[i].name << "Val;" << std::endl; } } // Otherwise, write parameter straight into var reference @@ -304,7 +320,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back addSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { CodeStream::Scope b(addSynapse); - addSynapse << "group->_dependentVar" << i << "[(b * synStride) + newIdx] = 0;" << std::endl; + addSynapse << "group->_dependentVar" << i << "[(b * " << updateEnv["_syn_stride"] << ") + newIdx] = 0;" << std::endl; } } // Otherwise, zero var reference @@ -330,10 +346,10 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back CodeStream::Scope b(removeSynapse); // Calculate index we want to copy synapse from - removeSynapse << "const unsigned lastIdx = rowStartIdx + " + updateEnv["_row_length"] + "[" << updateEnv["id_pre"] << "] - 1;" << std::endl; + removeSynapse << "const unsigned lastIdx = " + updateEnv["_row_start_idx"] + " + " + updateEnv["_row_length"] + "[" << updateEnv["id_pre"] << "] - 1;" << std::endl; // Copy postsynaptic target from end of row over synapse to be deleted - removeSynapse << "group->ind[idx] = group->ind[lastIdx];" << std::endl; + removeSynapse << updateEnv["_ind"] << "[idx] = 
" << updateEnv["_ind"] << "[lastIdx];" << std::endl; // Copy custom connectivity update variables from end of row over synapse to be deleted for (size_t i = 0; i < ccuVars.size(); i++) { @@ -343,14 +359,15 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Loop through variable references for (size_t i = 0; i < ccuVarRefs.size(); i++) { // If model is batched and this variable is duplicated - if ((batchSize > 1) + if ((modelMerged.getModel().getBatchSize() > 1) && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) { // Loop through all batches and copy custom connectivity update variable references from end of row over synapse to be deleted - removeSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; + removeSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { CodeStream::Scope b(addSynapse); - removeSynapse << "group->" << ccuVarRefs[i].name << "[(b * synStride) + idx] = group->" << ccuVarRefs[i].name << "[(b * synStride) + lastIdx];" << std::endl; + removeSynapse << "group->" << ccuVarRefs[i].name << "[(b * " << updateEnv["_syn_stride"] << ") + idx] = "; + removeSynapse << "group->" << ccuVarRefs[i].name << "[(b * " << updateEnv["_syn_stride"] << ") + lastIdx];" << std::endl; } } // Otherwise, copy custom connectivity update variable references from end of row over synapse to be deleted @@ -362,12 +379,13 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Loop through any other dependent variables for (size_t i = 0; i < dependentVars.size(); i++) { // If model is batched and this dependent variable is duplicated - if ((batchSize > 1) && dependentVars.at(i).isDuplicated()) { + if ((modelMerged.getModel().getBatchSize() > 1) && dependentVars.at(i).isDuplicated()) { // Loop through all batches and copy dependent variable from end of row over synapse to be deleted - removeSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; + removeSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { CodeStream::Scope b(removeSynapse); - removeSynapse << "group->_dependentVar" << i << "[(b * synStride) + idx] = group->_dependentVar" << i << "[(b * synStride) + lastIdx];" << std::endl; + removeSynapse << "group->_dependentVar" << i << "[(b * " << updateEnv["_syn_stride"] << ") + idx] = "; + removeSynapse << "group->_dependentVar" << i << "[(b * " << updateEnv["_syn_stride"] << ") + lastIdx];" << std::endl; } } // Otherwise, copy dependent variable from end of row over synapse to be deleted @@ -378,60 +396,14 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Decrement row length // **NOTE** this will also effect any forEachSynapse loop currently in operation - removeSynapse << "group->rowLength[" << updateSubs["id_pre"] << "]--;" << std::endl; + removeSynapse << updateSubs["_row_length"] << "[" << updateSubs["id_pre"] << "]--;" << std::endl; // Decrement loop counter so synapse j will get processed removeSynapse << "j--;" << std::endl; } - updateSubs.addFuncSubstitution("remove_synapse", 0, removeSynapseStream.str()); - - // **TODO** presynaptic variables and variable references could be read into registers at start of row - updateSubs.addVarNameSubstitution(cm->getVars(), "", "group->", "[" + updateSubs["id_syn"] + "]"); - updateSubs.addVarNameSubstitution(cm->getPreVars(), "", "group->", "[" + updateSubs["id_pre"] + "]"); - updateSubs.addVarNameSubstitution(cm->getPostVars(), "", "group->", "[" + 
updateSubs["id_post"] + "]"); - - // Substitute in variable references, filtering out those which are duplicated - const auto &variableRefs = getArchetype().getVarReferences(); - updateSubs.addVarNameSubstitution(cm->getVarRefs(), "", "group->", - [&updateSubs](VarAccessMode, const std::string&) { return "[" + updateSubs["id_syn"] + "]"; }, - [modelBatched, &variableRefs](VarAccessMode, const std::string &name) - { - return !modelBatched || !variableRefs.at(name).isDuplicated(); - }); - - // Substitute in (potentially delayed) presynaptic variable references - const auto &preVariableRefs = getArchetype().getPreVarReferences(); - updateSubs.addVarNameSubstitution(cm->getPreVarRefs(), "", "group->", - [&preVariableRefs, &updateSubs](VarAccessMode, const std::string &name) - { - if(preVariableRefs.at(name).getDelayNeuronGroup() != nullptr) { - return "[preDelayOffset + " + updateSubs["id_pre"] + "]"; - } - else { - return "[" + updateSubs["id_pre"] + "]"; - } - }); - - // Substitute in (potentially delayed) postsynaptic variable references - const auto &postVariableRefs = getArchetype().getPostVarReferences(); - updateSubs.addVarNameSubstitution(cm->getPostVarRefs(), "", "group->", - [&postVariableRefs, &updateSubs](VarAccessMode, const std::string &name) - { - if(postVariableRefs.at(name).getDelayNeuronGroup() != nullptr) { - return "[postDelayOffset + " + updateSubs["id_post"] + "]"; - } - else { - return "[" + updateSubs["id_post"] + "]"; - } - }); - - updateSubs.addParamValueSubstitution(cm->getParamNames(), getArchetype().getParams(), - [this](const std::string &name) { return isParamHeterogeneous(name); }, - "", "group->"); - updateSubs.addVarValueSubstitution(cm->getDerivedParams(), getArchetype().getDerivedParams(), - [this](const std::string &name) { return isDerivedParamHeterogeneous(name); }, - "", "group->"); - updateSubs.addVarNameSubstitution(cm->getExtraGlobalParams(), "", "group->"); + + // Add function substitution with parameters to initialise custom connectivity update variables and variable references + updateEnv.add(Type::ResolvedType::createFunction(Type::Void, {}), "remove_synapse", removeSynapseStream.str()); // Pretty print code back to environment Transpiler::ErrorHandler errorHandler("Current source injection" + std::to_string(getIndex())); @@ -441,16 +413,63 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back { env.define(Transpiler::Token{Transpiler::Token::Type::IDENTIFIER, "id_post", 0}, Type::Uint32.addConst(), errorHandler); env.define(Transpiler::Token{Transpiler::Token::Type::IDENTIFIER, "id_syn", 0}, Type::Uint32.addConst(), errorHandler); + + // **TODO** variable types }, - [](auto &env, auto generateBody) + [&backend, &modelMerged, this](auto &env, auto generateBody) { - env.getStream() << "for(int j = 0; j < " << env["_row_length"] + "[" + env["id_pre"] + "]; j++)"; + EnvironmentGroupMergedField bodyEnv(env, *this); + bodyEnv.getStream() << "for(int j = 0; j < " << bodyEnv["_row_length"] + "[" + bodyEnv["id_pre"] + "]; j++)"; { - CodeStream::Scope b(env.getStream()); - env.getStream() << "const unsigned int idx = rowStartIdx + j;" << std::endl; - //pdateSubs.addVarSubstitution("id_post", "group->ind[rowStartIdx + j]"); - //updateSubs.addVarSubstitution("id_syn", "idx"); - generateBody(); + CodeStream::Scope b(bodyEnv.getStream()); + + // Add postsynaptic and synaptic indices + bodyEnv.add(Type::Uint32.addConst(), "id_post", bodyEnv["_ind"] + "[" + bodyEnv["_row_start_idx"] + " + j]", + {}, {"_ind", "_row_start_idx"}); + 
bodyEnv.add(Type::Uint32.addConst(), "id_syn", "idx", + {bodyEnv.addInitialiser("const unsigned int idx = " + bodyEnv["_row_start_idx"] + " + j;")}, + {"_row_start_idx"}); + + // Add postsynaptic and synaptic variables + bodyEnv.addVars(backend.getDeviceVarPrefix(), bodyEnv["id_syn"], "", + {"id_syn"}); + bodyEnv.addVars(backend.getDeviceVarPrefix(), bodyEnv["id_post"], "", + {"id_post"}); + + // Add postsynaptic and synaptic var references + // **TODO** + bodyEnv.addVarRefs(backend.getDeviceVarPrefix(), + [modelMerged, this](const std::string &ma, const Models::VarReference &v) + { + return (modelMerged.getModel().getBatchSize() == 1) || !v.isDuplicated(); + }); + bodyEnv.addVarRefs(backend.getDeviceVarPrefix(), + [modelMerged, this](const std::string &ma, const Models::WUVarReference &v) + { + return (modelMerged.getModel().getBatchSize() == 1) || !v.isDuplicated(); + }); + // Substitute in variable references, filtering out those which are duplicated + const auto &variableRefs = getArchetype().getVarReferences(); + updateSubs.addVarNameSubstitution(cm->getVarRefs(), "", "group->", + [&updateSubs](VarAccessMode, const std::string&) { return "[" + updateSubs["id_syn"] + "]"; }, + [modelBatched, &variableRefs](VarAccessMode, const std::string &name) + { + return !modelBatched || !variableRefs.at(name).isDuplicated(); + }); + + // Substitute in (potentially delayed) postsynaptic variable references + const auto &postVariableRefs = getArchetype().getPostVarReferences(); + updateSubs.addVarNameSubstitution(cm->getPostVarRefs(), "", "group->", + [&postVariableRefs, &updateSubs](VarAccessMode, const std::string &name) + { + if(postVariableRefs.at(name).getDelayNeuronGroup() != nullptr) { + return "[postDelayOffset + " + updateSubs["id_post"] + "]"; + } + else { + return "[" + updateSubs["id_post"] + "]"; + } + }); + generateBody(bodyEnv); } }); } From a2b4215e075671d8026208c380bd3e3811c0f9ef Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 23 Jun 2023 17:01:18 +0100 Subject: [PATCH 252/725] started using LazyString in addInitialiser calls --- .../genn/genn/code_generator/backendBase.h | 69 +++++++------------ .../code_generator/synapseUpdateGroupMerged.h | 3 +- .../backends/single_threaded_cpu/backend.cc | 53 +++++++------- src/genn/genn/code_generator/backendBase.cc | 17 ++--- .../code_generator/neuronUpdateGroupMerged.cc | 10 +-- .../synapseUpdateGroupMerged.cc | 15 +--- 6 files changed, 69 insertions(+), 98 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 61b59b9fa1..f2080fbbb8 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -492,6 +492,7 @@ class GENN_EXPORT BackendBase template void genNeuronIndexCalculation(EnvironmentGroupMergedField &env, unsigned int batchSize) const { + using LS = LazyString; env.addField(Type::Uint32.addConst(), "num_neurons", Type::Uint32, "numNeurons", [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); @@ -519,8 +520,7 @@ class GENN_EXPORT BackendBase // If batching is enabled, calculate batch offset if(batchSize > 1) { env.add(Type::Uint32.addConst(), "_batchOffset", "batchOffset", - {env.addInitialiser("const unsigned int batchOffset = " + env["num_neurons"] + " * batch;")}, - {"num_neurons"}); + {env.addInitialiser("const unsigned int batchOffset = " + LS(env, "num_neurons") + " * batch;")}); } // If axonal delays are required @@ -529,42 +529,33 @@ class GENN_EXPORT BackendBase const 
unsigned int numDelaySlots = env.getGroup().getArchetype().getNumDelaySlots(); const std::string numDelaySlotsStr = std::to_string(numDelaySlots); env.add(Type::Uint32.addConst(), "_read_delay_slot", "readDelaySlot", - {env.addInitialiser("const unsigned int readDelaySlot = (*" + env["_spk_que_ptr"] + " + " + std::to_string(numDelaySlots - 1) + ") % " + numDelaySlotsStr+ ";")}, - {"_spk_que_ptr"}); + {env.addInitialiser("const unsigned int readDelaySlot = (*" + LS(env, "_spk_que_ptr") + " + " + std::to_string(numDelaySlots - 1) + ") % " + numDelaySlotsStr+ ";")}); env.add(Type::Uint32.addConst(), "_read_delay_offset", "readDelayOffset", - {env.addInitialiser("const unsigned int readDelayOffset = readDelaySlot * " + env["num_neurons"] + ";")}, - {"num_neurons", "_read_delay_slot"}); + {env.addInitialiser("const unsigned int readDelayOffset = " + LS(env, "_read_delay_slot") + " * " + LS(env, "num_neurons") + ";")}); // And we should WRITE to delay slot pointed to be spkQuePtr env.add(Type::Uint32.addConst(), "_write_delay_slot", "writeDelaySlot", - {env.addInitialiser("const unsigned int writeDelaySlot = *" + env["_spk_que_ptr"] + ";")}, - {"_spk_que_ptr"}); + {env.addInitialiser("const unsigned int writeDelaySlot = *" + LS(env, "_spk_que_ptr") + ";")}); env.add(Type::Uint32.addConst(), "_write_delay_offset", "writeDelayOffset", - {env.addInitialiser("const unsigned int writeDelayOffset = writeDelaySlot * " + env["num_neurons"] + ";")}, - {"num_neurons", "_write_delay_slot"}); + {env.addInitialiser("const unsigned int writeDelayOffset = " + LS(env, "_write_delay_slot") + " * " + LS(env, "num_neurons") + ";")}); // If batching is also enabled if(batchSize > 1) { // Calculate batched delay slots env.add(Type::Uint32.addConst(), "_read_batch_delay_slot", "readBatchDelaySlot", - {env.addInitialiser("const unsigned int readBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + readDelaySlot;")}, - {"_read_delay_slot"}); + {env.addInitialiser("const unsigned int readBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + " + LS(env, "_read_delay_slot") + ";")}); env.add(Type::Uint32.addConst(), "_write_batch_delay_slot", "writeBatchDelaySlot", - {env.addInitialiser("const unsigned int writeBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + writeDelaySlot;")}, - {"_write_delay_slot"}); + {env.addInitialiser("const unsigned int writeBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + " + LS(env, "_write_delay_slot") + ";")}); // Calculate current batch offset env.add(Type::Uint32.addConst(), "_batch_delay_offset", "batchDelayOffset", - {env.addInitialiser("const unsigned int batchDelayOffset = batchOffset * " + numDelaySlotsStr + ";")}, - {"_batch_offset"}); + {env.addInitialiser("const unsigned int batchDelayOffset = " + LS(env, "_batch_offset") + " * " + numDelaySlotsStr + ";")}); // Calculate further offsets to include delay and batch env.add(Type::Uint32.addConst(), "_read_batch_delay_offset", "readBatchDelayOffset", - {env.addInitialiser("const unsigned int readBatchDelayOffset = readDelayOffset + batchDelayOffset;")}, - {"_read_delay_offset", "_batchDelayOffset"}); + {env.addInitialiser("const unsigned int readBatchDelayOffset = " + LS(env, "_read_delay_offset") + " + " + LS(env, "_batch_delay_offset") + ";")}); env.add(Type::Uint32.addConst(), "_write_batch_delay_offset", "writeBatchDelayOffset", - {env.addInitialiser("const unsigned int writeBatchDelayOffset = writeDelayOffset + batchDelayOffset;")}, - {"_write_delay_offset", "_batchDelayOffset"}); + {env.addInitialiser("const unsigned int 
writeBatchDelayOffset = " + LS(env, "_write_delay_offset") + " + " + LS(env, "_batch_delay_offset") + ";")}); } } } @@ -572,6 +563,7 @@ class GENN_EXPORT BackendBase template void genSynapseIndexCalculation(EnvironmentGroupMergedField &env, unsigned int batchSize) const { + using LS = LazyString; // Synapse group fields env.addField(Type::Uint32.addConst(), "num_pre", Type::Uint32, "numSrcNeurons", @@ -621,11 +613,9 @@ class GENN_EXPORT BackendBase if(batchSize > 1) { // Calculate batch offsets into pre and postsynaptic populations env.add(Type::Uint32.addConst(), "_pre_batch_offset", "preBatchOffset", - {env.addInitialiser("const unsigned int preBatchOffset = " + env["num_pre"] + " * batch;")}, - {"num_pre"}); + {env.addInitialiser("const unsigned int preBatchOffset = " + LS(env, "num_pre") + " * batch;")}); env.add(Type::Uint32.addConst(), "_post_batch_offset", "postBatchOffset", - {env.addInitialiser("const unsigned int preBatchOffset = " + env["num_post"] + " * batch;")}, - {"num_post"}); + {env.addInitialiser("const unsigned int preBatchOffset = " + LS(env, "num_post") + " * batch;")}); // Calculate batch offsets into synapse arrays, using 64-bit arithmetic if necessary if(areSixtyFourBitSynapseIndicesRequired(env.getGroup())) { @@ -634,8 +624,7 @@ class GENN_EXPORT BackendBase } else { env.add(Type::Uint32.addConst(), "_syn_batch_offset", "synBatchOffset", - {env.addInitialiser("const unsigned int synBatchOffset = " + env["_pre_batch_offset"] + " * " + env["_row_stride"] + ";")}, - {"_pre_batch_offset", "_row_stride"}); + {env.addInitialiser("const unsigned int synBatchOffset = " + LS(env, "_pre_batch_offset") + " * " + LS(env, "_row_stride") + ";")}); } // If synapse group has kernel @@ -674,8 +663,7 @@ class GENN_EXPORT BackendBase {env.addInitialiser(preDelaySlotInit.str())}, {"_src_spk_que_ptr"}); env.add(Type::Uint32, "_pre_delay_offset", "preDelayOffset", - {env.addInitialiser("const unsigned int preDelayOffset = preDelaySlot * " + env["num_pre"] + ";")}, - {"num_pre", "_pre_delay_slot"}); + {env.addInitialiser("const unsigned int preDelayOffset = " + LS(env, "_pre_delay_slot") + " * " + LS(env, "num_pre") + ";")}); if(batchSize > 1) { env.add(Type::Uint32, "_pre_batch_delay_slot", "preBatchDelaySlot", @@ -690,14 +678,12 @@ class GENN_EXPORT BackendBase || env.getGroup().getArchetype().getWUModel()->isPrevPreSpikeEventTimeRequired()) { env.add(Type::Uint32, "_pre_prev_spike_time_delay_offset", "prePrevSpikeTimeDelayOffset", - {env.addInitialiser("const unsigned int prePrevSpikeTimeDelayOffset = ((*" + env["_src_spk_que_ptr"] + " + " - + std::to_string(numSrcDelaySlots - numDelaySteps - 1) + ") % " + std::to_string(numSrcDelaySlots) + ") * " + env["num_pre"] + ";")}, - {"_src_spk_que_ptr", "num_pre"}); + {env.addInitialiser("const unsigned int prePrevSpikeTimeDelayOffset = ((*" + LS(env, "_src_spk_que_ptr") + " + " + + std::to_string(numSrcDelaySlots - numDelaySteps - 1) + ") % " + std::to_string(numSrcDelaySlots) + ") * " + LS(env, "num_pre") + ";")}); if(batchSize > 1) { env.add(Type::Uint32, "_pre_prev_spike_time_batch_delay_offset", "prePrevSpikeTimeBatchDelayOffset", - {env.addInitialiser("const unsigned int prePrevSpikeTimeBatchDelayOffset =prePrevSpikeTimeDelayOffset + (" + env["_pre_batch_offset"] + " * " + std::to_string(numSrcDelaySlots) + ");")}, - {"_pre_prev_spike_time_delay_offset", "_pre_batch_offset"}); + {env.addInitialiser("const unsigned int prePrevSpikeTimeBatchDelayOffset = " + LS(env, "_pre_prev_spike_time_delay_offset") + " + (" + LS(env, 
"_pre_batch_offset") + " * " + std::to_string(numSrcDelaySlots) + ");")}); } } } @@ -719,28 +705,23 @@ class GENN_EXPORT BackendBase {env.addInitialiser(postDelaySlotInit.str())}, {"_trg_spk_que_ptr"}); env.add(Type::Uint32, "_post_delay_offset", "postDelayOffset", - {env.addInitialiser("const unsigned int postDelayOffset = postDelaySlot * " + env["num_post"] + ";")}, - {"num_post", "_post_delay_slot"}); + {env.addInitialiser("const unsigned int postDelayOffset = " + LS(env, "_post_delay_slot") + " * " + LS(env, "num_post") + ";")}); if(batchSize > 1) { env.add(Type::Uint32, "_post_batch_delay_slot", "postBatchDelaySlot", - {env.addInitialiser("const unsigned int postBatchDelaySlot = postDelaySlot + (batch * " + std::to_string(numTrgDelaySlots) + ");")}, - {"_post_delay_slot"}); + {env.addInitialiser("const unsigned int postBatchDelaySlot = " + LS(env, "_post_delay_slot") + " + (batch * " + std::to_string(numTrgDelaySlots) + ");")}); env.add(Type::Uint32, "_post_batch_delay_offset", "postBatchDelayOffset", - {env.addInitialiser("const unsigned int postBatchDelayOffset = postDelayOffset + (postBatchOffset * " + std::to_string(numTrgDelaySlots) + ");")}, - {"_post_delay_offset", "_post_batch_offset"}); + {env.addInitialiser("const unsigned int postBatchDelayOffset = " + LS(env, "_post_delay_offset") + " + (" + LS(env, "_post_batch_offset") + " * " + std::to_string(numTrgDelaySlots) + ");")}); } if(env.getGroup().getArchetype().getWUModel()->isPrevPostSpikeTimeRequired()) { env.add(Type::Uint32, "_post_prev_spike_time_delay_offset", "postPrevSpikeTimeDelayOffset", - {env.addInitialiser("const unsigned int postPrevSpikeTimeDelayOffset = ((*" + env["_trg_spk_que_ptr"] + " + " - + std::to_string(numTrgDelaySlots - numBackPropDelaySteps - 1) + ") % " + std::to_string(numTrgDelaySlots) + ") * " + env["num_post"] + ";")}, - {"_trg_spk_que_ptr", "num_post"}); + {env.addInitialiser("const unsigned int postPrevSpikeTimeDelayOffset = ((*" + LS(env, "_trg_spk_que_ptr") + " + " + + std::to_string(numTrgDelaySlots - numBackPropDelaySteps - 1) + ") % " + std::to_string(numTrgDelaySlots) + ") * " + LS(env, "num_post") + ";")}); if(batchSize > 1) { env.add(Type::Uint32, "_post_prev_spike_time_batch_delay_offset", "postPrevSpikeTimeBatchDelayOffset", - {env.addInitialiser("const unsigned int postPrevSpikeTimeBatchDelayOffset = postPrevSpikeTimeDelayOffset + (" + env["_post_batch_offset"] + " * " + std::to_string(numTrgDelaySlots) + ");")}, - {"_post_prev_spike_time_delay_offset", "_post_batch_offset"}); + {env.addInitialiser("const unsigned int postPrevSpikeTimeBatchDelayOffset = " + LS(env, "_post_prev_spike_time_delay_offset") + " + (" + LS(env, "_post_batch_offset") + " * " + std::to_string(numTrgDelaySlots) + ");")}); } } diff --git a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h index b57943ae0e..13c5aa8935 100644 --- a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h @@ -117,8 +117,7 @@ class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT SynapseDendriticDelayUpdateGroupMerged : public GroupMerged { public: - SynapseDendriticDelayUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, - const std::vector> &group); + using GroupMerged::GroupMerged; //------------------------------------------------------------------------ // Public API diff --git a/src/genn/backends/single_threaded_cpu/backend.cc 
b/src/genn/backends/single_threaded_cpu/backend.cc index 1b969756aa..d77702210b 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -303,6 +303,8 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host //-------------------------------------------------------------------------- void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { + using LS = LazyString; + if (modelMerged.getModel().getBatchSize() != 1) { throw std::runtime_error("The single-threaded CPU backend only supports simulations with a batch size of 1"); } @@ -368,8 +370,8 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos if (s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // Add initialiser strings to calculate synaptic and presynaptic index - const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["_row_stride"] + ") + s;"); - const size_t idPostInit = synEnv.addInitialiser("const unsigned int idPost = " + synEnv["_ind"] + "[idSyn];"); + const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * " + LS(synEnv, "_row_stride") + ") + s;"); + const size_t idPostInit = synEnv.addInitialiser("const unsigned int idPost = " + LS(synEnv, "_ind") + "[idSyn];"); // **TODO** id_syn can be 64-bit synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}, {"_row_stride"}); @@ -380,7 +382,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos synEnv.add(Type::Uint32.addConst(), "id_post", "j"); // Add initialiser to calculate synaptic index - const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["num_post"] + ") + j;"); + const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * " + LS(synEnv, "num_post") + ") + j;"); // **TODO** id_syn can be 64-bit synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}); @@ -490,9 +492,9 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // Add initialisers to calculate column and row-major indices // **TODO** fast divide optimisations - const size_t colMajorIdxInit = synEnv.addInitialiser("const unsigned int colMajorIndex = (spike * " + synEnv["_col_stride"] + ") + i;"); - const size_t rowMajorIdxInit = synEnv.addInitialiser("const unsigned int rowMajorIndex = " + synEnv["_remap"] + "[colMajorIndex];"); - const size_t idPreInit = synEnv.addInitialiser("const unsigned int idPre = rowMajorIndex / " + synEnv["_row_stride"] + ";"); + const size_t colMajorIdxInit = synEnv.addInitialiser("const unsigned int colMajorIndex = (spike * " + LS(synEnv, "_col_stride") + ") + i;"); + const size_t rowMajorIdxInit = synEnv.addInitialiser("const unsigned int rowMajorIndex = " + LS(synEnv. 
"_remap") + "[colMajorIndex];"); + const size_t idPreInit = synEnv.addInitialiser("const unsigned int idPre = rowMajorIndex / " + LS(synEnv, "_row_stride") + ";"); // Add presynaptic and synapse index to environment synEnv.add(Type::Uint32.addConst(), "id_pre", "idPre", {colMajorIdxInit, rowMajorIdxInit, idPreInit}, {"_col_stride", "_row_stride", "_remap"}); @@ -500,7 +502,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos } else { // Add initialiser to calculate synaptic index - const size_t idSynInit = groupEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["num_post"] + ") + spike;"); + const size_t idSynInit = groupEnv.addInitialiser("const unsigned int idSyn = (i * " + LS(synEnv, "num_post") + ") + spike;"); // Add presynaptic and synapse index to environment synEnv.add(Type::Uint32.addConst(), "id_pre", "i"); @@ -541,6 +543,8 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos //-------------------------------------------------------------------------- void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { + using LS = LazyString; + const ModelSpecInternal &model = modelMerged.getModel(); // Build set containing names of all custom update groups @@ -713,8 +717,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host synEnv.add(Type::Uint32.addConst(), "id_post", "j"); synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", - {synEnv.addInitialiser("const unsigned int idSyn = (i * " + synEnv["num_post"] + ") + j;")}, - {"num_post"}); + {synEnv.addInitialiser("const unsigned int idSyn = (i * " + LS(synEnv, "num_post") + ") + j;")}); } // Generate custom update @@ -804,15 +807,14 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Add conditional initialisation code to calculate synapse index groupEnv.add(Type::Uint32, "id_syn", "idSyn", - {groupEnv.addInitialiser("const unsigned int idSyn = (i * " + groupEnv["num_post"] + ") + j;")}, - {"num_post"}); + {groupEnv.addInitialiser("const unsigned int idSyn = (i * " + LS(groupEnv, "num_post") + ") + j;")}); // Generate custom update c.generateCustomUpdate(*this, groupEnv); // Update transpose variable // **YUCK** this is sorta outside scope - groupEnv.getStream() << groupEnv[transposeVarName + "_transpose"] << "[(j * " << groupEnv["num_pre"] << ") + i] = l" << transposeVarName << ";" << std::endl; + groupEnv.getStream() << groupEnv[transposeVarName + "_transpose"] << "[(j * " << LS(groupEnv, "num_pre") << ") + i] = l" << transposeVarName << ";" << std::endl; } } @@ -1498,6 +1500,8 @@ void Backend::genVariableInit(EnvironmentExternalBase &env, const std::string &c //-------------------------------------------------------------------------- void Backend::genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const { + using LS = LazyString; + env.getStream() << "for (unsigned int j = 0; j < group->rowLength[" << env["id_pre"] << "]; j++)"; { CodeStream::Scope b(env.getStream()); @@ -1505,16 +1509,17 @@ void Backend::genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, Hand EnvironmentExternal varEnv(env); // **TODO** 64-bit varEnv.add(Type::Uint32, "id_syn", "idSyn", - {varEnv.addInitialiser("const unsigned int idSyn = (" + varEnv["id_pre"] + " * " + varEnv["_row_stride"] + ") + j;")}, - {"id_pre", "_rowStride"}); + {varEnv.addInitialiser("const unsigned int idSyn = (" + LS(varEnv, "id_pre") + " * " + 
LS(varEnv, "_row_stride") + ") + j;")}); varEnv.add(Type::Uint32, "id_post", "idPost", - {varEnv.addInitialiser("const unsigned int idPost = (" + varEnv["_ind"] + "[(" + varEnv["id_pre"] + " * " + varEnv["_row_stride"] + ") + j]")}); + {varEnv.addInitialiser("const unsigned int idPost = (" + LS(varEnv, "_ind") + "[(" + LS(varEnv, "id_pre") + " * " + LS(varEnv, "_row_stride") + ") + j]")}); handler(varEnv); } } //-------------------------------------------------------------------------- void Backend::genDenseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const { + using LS = LazyString; + env.getStream() << "for (unsigned int j = 0; j < " << env["num_post"] << "; j++)"; { CodeStream::Scope b(env.getStream()); @@ -1522,8 +1527,7 @@ void Backend::genDenseSynapseVariableRowInit(EnvironmentExternalBase &env, Handl EnvironmentExternal varEnv(env); // **TODO** 64-bit varEnv.add(Type::Uint32, "id_syn", "idSyn", - {varEnv.addInitialiser("const unsigned int idSyn = (" + varEnv["id_pre"] + " * " + varEnv["_row_stride"] + ") + j;")}, - {"id_pre", "_rowStride"}); + {varEnv.addInitialiser("const unsigned int idSyn = (" + LS(varEnv, "id_pre") + " * " + LS(varEnv, "_row_stride") + ") + j;")}); varEnv.add(Type::Uint32, "id_post", "j"); handler(varEnv); } @@ -1728,6 +1732,8 @@ boost::uuids::detail::sha1::digest_type Backend::getHashDigest() const //-------------------------------------------------------------------------- void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, const ModelSpecMerged &modelMerged, bool trueSpike) const { + using LS = LazyString; + // Get suffix based on type of events const std::string eventSuffix = trueSpike ? "" : "Evnt"; const auto *wu = sg.getArchetype().getWUModel(); @@ -1860,7 +1866,7 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda const std::string queueOffset = sg.getArchetype().getSrcNeuronGroup()->isDelayRequired() ? 
"preDelayOffset + " : ""; groupEnv.add(Type::Uint32, "id_pre", "idPre", - {groupEnv.addInitialiser("const unsigned int ipre = group->srcSpk" + eventSuffix + "[" + queueOffset + "i];")}); + {groupEnv.addInitialiser("const unsigned int idPre = group->srcSpk" + eventSuffix + "[" + queueOffset + "i];")}); // If this is a spike-like event, insert threshold check for this presynaptic neuron if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { @@ -1891,11 +1897,9 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda // **TODO** 64-bit id_syn synEnv.add(Type::Uint32, "id_syn", "idSyn", - {synEnv.addInitialiser("const unsigned int idSyn = (ipre * " + env["_row_stride"] + ") + j;")}, - {"_row_stride"}); + {synEnv.addInitialiser("const unsigned int idSyn = (" + LS(env, "id_pre") + " * " + LS(env, "_row_stride") + ") + j;")}); synEnv.add(Type::Uint32, "id_post", "idPost", - {synEnv.addInitialiser("const unsigned int idPost = " + env["_ind"] + "[idSyn];")}, - {"_ind", "id_syn"}); + {synEnv.addInitialiser("const unsigned int idPost = " + LS(env, "_ind") + "[" + LS(env, "id_syn") + "];")}); if(trueSpike) { sg.generateSpikeUpdate(*this, synEnv, modelMerged); @@ -1965,14 +1969,13 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { // **TODO** 64-bit index - synEnv.getStream() << "const uint64_t gid = (ipre * group->numTrgNeurons + ipost);" << std::endl; + synEnv.getStream() << "const uint64_t gid = (" << synEnv["id_pre"] << " * " << synEnv["num_post"] << ") + " << synEnv["id_post"] + ";" << std::endl; synEnv.getStream() << "if (B(group->gp[gid / 32], gid & 31))" << CodeStream::OB(20); } else { synEnv.add(Type::Uint32, "id_syn", "idSyn", - {synEnv.addInitialiser("const unsigned int idSyn = (ipre * " + synEnv["num_post"] + ") + ipost;")}, - {"num_post"}); + {synEnv.addInitialiser("const unsigned int idSyn = (" + LS(synEnv, "id_pre") + " * " + LS(synEnv, "num_post") + ") + " + LS(synEnv, "id_post") + ";")}); } diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 093a8ba22b..28ba6ef6a3 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -35,6 +35,8 @@ bool BackendBase::areSixtyFourBitSynapseIndicesRequired(const SynapseGroupMerged //----------------------------------------------------------------------- void BackendBase::genCustomUpdateIndexCalculation(EnvironmentGroupMergedField &env) const { + using LS = LazyString; + // Add size field env.addField(Type::Uint32, "size", "size", [](const auto &c, size_t) { return std::to_string(c.getSize()); }); @@ -42,8 +44,7 @@ void BackendBase::genCustomUpdateIndexCalculation(EnvironmentGroupMergedFieldgetNumDelaySlots()); env.add(Type::Uint32.addConst(), "_batch_delay_slot", "batchDelaySlot", - {env.addInitialiser("const unsigned int batchDelaySlot = (batch * " + numDelaySlotsStr + ") + delaySlot;")}, - {"_delay_slot"}); + {env.addInitialiser("const unsigned int batchDelaySlot = (batch * " + numDelaySlotsStr + ") + " + LS(env, "_delay_slot") + ";")}); // Calculate current batch offset env.add(Type::Uint32.addConst(), "_batch_delay_offset", "batchDelayOffset", - {env.addInitialiser("const unsigned int batchDelayOffset = batchOffset * " + numDelaySlotsStr + ";")}, - {"_batch_offset"}); + {env.addInitialiser("const unsigned int batchDelayOffset = " + LS(env, "_batch_offset") + " * " + numDelaySlotsStr + 
";")}); } } } diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 6dd61d467d..4a001b323b 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -489,6 +489,8 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E BackendBase::GroupHandlerEnv genEmitTrueSpike, BackendBase::GroupHandlerEnv genEmitSpikeLikeEvent) { + using LS = LazyString; + const ModelSpecInternal &model = modelMerged.getModel(); const unsigned int batchSize = model.getBatchSize(); const NeuronModels::Base *nm = getArchetype().getNeuronModel(); @@ -530,13 +532,13 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Substitute spike times const std::string spikeTimeReadIndex = getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, neuronEnv["id"]); neuronEnv.add(getTimeType().addConst(), "sT", "lsT", - {neuronEnv.addInitialiser("const timepoint lsT = " + neuronEnv["_spk_time"] + "[" + spikeTimeReadIndex + "];")}); + {neuronEnv.addInitialiser("const timepoint lsT = " + LS(neuronEnv, "_spk_time") + "[" + spikeTimeReadIndex + "];")}); neuronEnv.add(getTimeType().addConst(), "prev_sT", "lprevST", - {neuronEnv.addInitialiser("const timepoint lprevST = " + neuronEnv["_prev_spk_time"] + "[" + spikeTimeReadIndex + "];")}); + {neuronEnv.addInitialiser("const timepoint lprevST = " + LS(neuronEnv, "_prev_spk_time") + "[" + spikeTimeReadIndex + "];")}); neuronEnv.add(getTimeType().addConst(), "seT", "lseT", - {neuronEnv.addInitialiser("const timepoint lseT = " + neuronEnv["_spk_evnt_time"] + "[" + spikeTimeReadIndex+ "];")}); + {neuronEnv.addInitialiser("const timepoint lseT = " + LS(neuronEnv, "_spk_evnt_time") + "[" + spikeTimeReadIndex+ "];")}); neuronEnv.add(getTimeType().addConst(), "prev_seT", "lprevSET", - {neuronEnv.addInitialiser("const timepoint lprevSET = " + neuronEnv["_prev_spk_evnt_time"] + "[" + spikeTimeReadIndex + "];")}); + {neuronEnv.addInitialiser("const timepoint lprevSET = " + LS(neuronEnv, "_prev_spk_evnt_time") + "[" + spikeTimeReadIndex + "];")}); // Create an environment which caches variables in local variables if they are accessed // **NOTE** we do this right at the top so that local copies can be used by child groups diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 2011814fc7..efed795f20 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -49,7 +49,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa // Add substitution // **TODO** dependencies on kernel fields synEnv.add(Type::Uint32, "id_kernel", "kernelInd", - {synEnv.addInitialiser("const unsigned int kernelInd = " + sg.getKernelIndex(synEnv) + ";")}); + {synEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(sg, synEnv) + ";")}); } // If weights are individual, substitute variables for values stored in global memory @@ -290,15 +290,4 @@ void SynapseDynamicsGroupMerged::generateSynapseUpdate(const BackendBase &backen //---------------------------------------------------------------------------- // CodeGenerator::SynapseDendriticDelayUpdateGroupMerged //---------------------------------------------------------------------------- -const std::string 
SynapseDendriticDelayUpdateGroupMerged::name = "SynapseDendriticDelayUpdate";
-//----------------------------------------------------------------------------
-SynapseDendriticDelayUpdateGroupMerged::SynapseDendriticDelayUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend,
-                                                                               const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
-:   GroupMerged<SynapseGroupInternal>(index, typeContext, groups)
-{
-    addField(Type::Uint32.createPointer(), "denDelayPtr",
-             [&backend](const SynapseGroupInternal &sg, size_t)
-             {
-                 return backend.getScalarAddressPrefix() + "denDelayPtr" + sg.getFusedPSVarSuffix();
-             });
-}
+const std::string SynapseDendriticDelayUpdateGroupMerged::name = "SynapseDendriticDelayUpdate";
\ No newline at end of file

From dbc4b6a8687d6c28ccd67c66218fe4598132abd1 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Sun, 25 Jun 2023 13:48:19 +0100
Subject: [PATCH 253/725] lazy string print functionality

---
 include/genn/genn/code_generator/lazyString.h |  2 +
 src/genn/genn/code_generator/lazyString.cc    | 37 ++++++++++++++++++-
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/include/genn/genn/code_generator/lazyString.h b/include/genn/genn/code_generator/lazyString.h
index 1e8561bf98..bc3ec874de 100644
--- a/include/genn/genn/code_generator/lazyString.h
+++ b/include/genn/genn/code_generator/lazyString.h
@@ -36,6 +36,8 @@ class LazyString
     //! Evaluate lazy string
     std::string str() const;
 
+    // Static API
+    static LazyString print(const std::string &format, EnvironmentExternalBase &env);
 
 private:
     LazyString(const Payload &payload) : m_Payload(payload){}

diff --git a/src/genn/genn/code_generator/lazyString.cc b/src/genn/genn/code_generator/lazyString.cc
index 7b8d8ce563..01e379974e 100644
--- a/src/genn/genn/code_generator/lazyString.cc
+++ b/src/genn/genn/code_generator/lazyString.cc
@@ -34,4 +34,39 @@ std::string LazyString::str() const
     }
     return stream.str();
 }
- 
\ No newline at end of file
+//----------------------------------------------------------------------------
+LazyString LazyString::print(const std::string &format, EnvironmentExternalBase &env)
+{
+    // Create regex iterator to iterate over $(XXX) style variables in format string
+    std::regex regex("\\$\\(([\\w]+)\\)");
+    std::sregex_iterator matchesBegin(format.cbegin(), format.cend(), regex);
+    std::sregex_iterator matchesEnd;
+
+    // If there are no matches, leave format unmodified and return
+    if(matchesBegin == matchesEnd) {
+        return LazyString(format);
+    }
+    // Otherwise
+    else {
+        // Loop through matches to build lazy string payload
+        Payload payload;
+        for(std::sregex_iterator m = matchesBegin;;) {
+            // Copy the non-matched subsequence (m->prefix()) onto payload
+            payload.push_back(std::string{m->prefix().first, m->prefix().second});
+
+            // Add lazy environment reference for $(XXX) variable to payload
+            payload.push_back(std::make_pair(std::ref(env), (*m)[1]));
+
+            // If there are no subsequent matches, add the remaining non-matched
+            // characters onto payload, construct lazy string and return
+            if(std::next(m) == matchesEnd) {
+                payload.push_back(std::string{m->suffix().first, m->suffix().second});
+                return LazyString(payload);
+            }
+            // Otherwise go onto next match
+            else {
+                m++;
+            }
+        }
+    }
+}
\ No newline at end of file

From 6be0942684edb033a76fc8cd75648d39c6bb2b13 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Sun, 25 Jun 2023 13:52:10 +0100
Subject: [PATCH 254/725] started updating - way nicer

---
 include/genn/genn/code_generator/backendBase.h |  2 +-
 src/genn/genn/code_generator/backendBase.cc    | 10
+++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index f2080fbbb8..b849a64871 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -520,7 +520,7 @@ class GENN_EXPORT BackendBase // If batching is enabled, calculate batch offset if(batchSize > 1) { env.add(Type::Uint32.addConst(), "_batchOffset", "batchOffset", - {env.addInitialiser("const unsigned int batchOffset = " + LS(env, "num_neurons") + " * batch;")}); + {env.addInitialiser(LS::print("const unsigned int batchOffset = $(num_neurons) * batch;", env))}); } // If axonal delays are required diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 28ba6ef6a3..1ef8ecc2bd 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -44,7 +44,7 @@ void BackendBase::genCustomUpdateIndexCalculation(EnvironmentGroupMergedFieldgetNumDelaySlots()); env.add(Type::Uint32.addConst(), "_batch_delay_slot", "batchDelaySlot", - {env.addInitialiser("const unsigned int batchDelaySlot = (batch * " + numDelaySlotsStr + ") + " + LS(env, "_delay_slot") + ";")}); + {env.addInitialiser(LS::print("const unsigned int batchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_delay_slot);", env))}); // Calculate current batch offset env.add(Type::Uint32.addConst(), "_batch_delay_offset", "batchDelayOffset", - {env.addInitialiser("const unsigned int batchDelayOffset = " + LS(env, "_batch_offset") + " * " + numDelaySlotsStr + ";")}); + {env.addInitialiser(LS::print("const unsigned int batchDelayOffset = $(_batch_offset) * " + numDelaySlotsStr + ";", env))}); } } } From 4343d13482b7c4947018f805aeda608d2d62bdb4 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 27 Jun 2023 09:12:02 +0100 Subject: [PATCH 255/725] removed manual dependency marking everywhere --- .../genn/genn/code_generator/backendBase.h | 70 ++++---- .../genn/genn/code_generator/codeGenUtils.h | 8 +- .../genn/genn/code_generator/environment.h | 128 ++++++-------- include/genn/genn/code_generator/lazyString.h | 2 + .../backends/single_threaded_cpu/backend.cc | 166 ++++++++---------- src/genn/genn/code_generator/backendBase.cc | 12 +- src/genn/genn/code_generator/codeGenUtils.cc | 36 ++++ .../customConnectivityUpdateGroupMerged.cc | 60 +++---- src/genn/genn/code_generator/groupMerged.cc | 26 +-- .../genn/code_generator/initGroupMerged.cc | 39 ++-- src/genn/genn/code_generator/lazyString.cc | 19 +- .../code_generator/neuronUpdateGroupMerged.cc | 66 ++++--- .../synapseUpdateGroupMerged.cc | 23 +-- 13 files changed, 325 insertions(+), 330 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index b849a64871..23dfd6665c 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -520,7 +520,7 @@ class GENN_EXPORT BackendBase // If batching is enabled, calculate batch offset if(batchSize > 1) { env.add(Type::Uint32.addConst(), "_batchOffset", "batchOffset", - {env.addInitialiser(LS::print("const unsigned int batchOffset = $(num_neurons) * batch;", env))}); + {env.addInitialiser("const unsigned int batchOffset = $(num_neurons) * batch;", env)}); } // If axonal delays are required @@ -529,33 +529,33 @@ class GENN_EXPORT BackendBase const unsigned int numDelaySlots = 
env.getGroup().getArchetype().getNumDelaySlots(); const std::string numDelaySlotsStr = std::to_string(numDelaySlots); env.add(Type::Uint32.addConst(), "_read_delay_slot", "readDelaySlot", - {env.addInitialiser("const unsigned int readDelaySlot = (*" + LS(env, "_spk_que_ptr") + " + " + std::to_string(numDelaySlots - 1) + ") % " + numDelaySlotsStr+ ";")}); + {env.addInitialiser("const unsigned int readDelaySlot = (*$(_spk_que_ptr) + " + std::to_string(numDelaySlots - 1) + ") % " + numDelaySlotsStr+ ";", env)}); env.add(Type::Uint32.addConst(), "_read_delay_offset", "readDelayOffset", - {env.addInitialiser("const unsigned int readDelayOffset = " + LS(env, "_read_delay_slot") + " * " + LS(env, "num_neurons") + ";")}); + {env.addInitialiser("const unsigned int readDelayOffset = $(_read_delay_slot) * $(num_neurons);", env)}); // And we should WRITE to delay slot pointed to be spkQuePtr env.add(Type::Uint32.addConst(), "_write_delay_slot", "writeDelaySlot", - {env.addInitialiser("const unsigned int writeDelaySlot = *" + LS(env, "_spk_que_ptr") + ";")}); + {env.addInitialiser("const unsigned int writeDelaySlot = * $(_spk_que_ptr);", env)}); env.add(Type::Uint32.addConst(), "_write_delay_offset", "writeDelayOffset", - {env.addInitialiser("const unsigned int writeDelayOffset = " + LS(env, "_write_delay_slot") + " * " + LS(env, "num_neurons") + ";")}); + {env.addInitialiser("const unsigned int writeDelayOffset = $(_write_delay_slot) * $(num_neurons);", env)}); // If batching is also enabled if(batchSize > 1) { // Calculate batched delay slots env.add(Type::Uint32.addConst(), "_read_batch_delay_slot", "readBatchDelaySlot", - {env.addInitialiser("const unsigned int readBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + " + LS(env, "_read_delay_slot") + ";")}); + {env.addInitialiser("const unsigned int readBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_read_delay_slot);", env)}); env.add(Type::Uint32.addConst(), "_write_batch_delay_slot", "writeBatchDelaySlot", - {env.addInitialiser("const unsigned int writeBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + " + LS(env, "_write_delay_slot") + ";")}); + {env.addInitialiser("const unsigned int writeBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_write_delay_slot);", env)}); // Calculate current batch offset env.add(Type::Uint32.addConst(), "_batch_delay_offset", "batchDelayOffset", - {env.addInitialiser("const unsigned int batchDelayOffset = " + LS(env, "_batch_offset") + " * " + numDelaySlotsStr + ";")}); + {env.addInitialiser("const unsigned int batchDelayOffset = $(_batch_offset) * " + numDelaySlotsStr + ";", env)}); // Calculate further offsets to include delay and batch env.add(Type::Uint32.addConst(), "_read_batch_delay_offset", "readBatchDelayOffset", - {env.addInitialiser("const unsigned int readBatchDelayOffset = " + LS(env, "_read_delay_offset") + " + " + LS(env, "_batch_delay_offset") + ";")}); + {env.addInitialiser("const unsigned int readBatchDelayOffset = $(_read_delay_offset) + $(_batch_delay_offset);", env)}); env.add(Type::Uint32.addConst(), "_write_batch_delay_offset", "writeBatchDelayOffset", - {env.addInitialiser("const unsigned int writeBatchDelayOffset = " + LS(env, "_write_delay_offset") + " + " + LS(env, "_batch_delay_offset") + ";")}); + {env.addInitialiser("const unsigned int writeBatchDelayOffset = $(_write_delay_offset)+ $(_batch_delay_offset);", env)}); } } } @@ -563,7 +563,6 @@ class GENN_EXPORT BackendBase template void genSynapseIndexCalculation(EnvironmentGroupMergedField &env, unsigned int 
batchSize) const { - using LS = LazyString; // Synapse group fields env.addField(Type::Uint32.addConst(), "num_pre", Type::Uint32, "numSrcNeurons", @@ -588,7 +587,6 @@ class GENN_EXPORT BackendBase env.addField(env.getGroup().getScalarType().createPointer(), "_out_pre", "outPre", [this](const auto &g, size_t) { return getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); });
- // Source neuron fields env.addField(Type::Uint32.createPointer(), "_src_spk_que_ptr", "srcSpkQuePtr", [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getSrcNeuronGroup()->getName(); });
@@ -613,9 +611,9 @@ class GENN_EXPORT BackendBase if(batchSize > 1) { // Calculate batch offsets into pre and postsynaptic populations env.add(Type::Uint32.addConst(), "_pre_batch_offset", "preBatchOffset",
- {env.addInitialiser("const unsigned int preBatchOffset = " + LS(env, "num_pre") + " * batch;")});
+ {env.addInitialiser("const unsigned int preBatchOffset = $(num_pre) * $(batch);", env)});
env.add(Type::Uint32.addConst(), "_post_batch_offset", "postBatchOffset",
- {env.addInitialiser("const unsigned int preBatchOffset = " + LS(env, "num_post") + " * batch;")});
+ {env.addInitialiser("const unsigned int postBatchOffset = $(num_post) * $(batch);", env)});
// Calculate batch offsets into synapse arrays, using 64-bit arithmetic if necessary if(areSixtyFourBitSynapseIndicesRequired(env.getGroup())) { @@ -624,7 +622,7 @@ class GENN_EXPORT BackendBase } else { env.add(Type::Uint32.addConst(), "_syn_batch_offset", "synBatchOffset",
- {env.addInitialiser("const unsigned int synBatchOffset = " + LS(env, "_pre_batch_offset") + " * " + LS(env, "_row_stride") + ";")});
+ {env.addInitialiser("const unsigned int synBatchOffset = $(_pre_batch_offset) * $(_row_stride);", env)});
} // If synapse group has kernel @@ -639,10 +637,10 @@ class GENN_EXPORT BackendBase } // And finally by batch
- kernBatchOffsetInit << "batch;" << std::endl;
+ kernBatchOffsetInit << "$(batch);" << std::endl;
env.add(Type::Uint32.addConst(), "_kern_batch_offset", "kernBatchOffset",
- {env.addInitialiser(kernBatchOffsetInit.str())});
+ {env.addInitialiser(kernBatchOffsetInit.str(), env)});
} } @@ -654,36 +652,34 @@ class GENN_EXPORT BackendBase std::ostringstream preDelaySlotInit; preDelaySlotInit << "const unsigned int preDelaySlot = "; if(numDelaySteps == 0) {
- preDelaySlotInit << "*" << env["_src_spk_que_ptr"] << ";" << std::endl;
+ preDelaySlotInit << "*$(_src_spk_que_ptr);" << std::endl;
} else {
- preDelaySlotInit << "(*" << env["_src_spk_que_ptr"] << " + " << (numSrcDelaySlots - numDelaySteps) << ") % " << numSrcDelaySlots << ";" << std::endl;
+ preDelaySlotInit << "(*$(_src_spk_que_ptr) + " << (numSrcDelaySlots - numDelaySteps) << ") % " << numSrcDelaySlots << ";" << std::endl;
} env.add(Type::Uint32, "_pre_delay_slot", "preDelaySlot",
- {env.addInitialiser(preDelaySlotInit.str())}, {"_src_spk_que_ptr"});
+ {env.addInitialiser(preDelaySlotInit.str(), env)});
env.add(Type::Uint32, "_pre_delay_offset", "preDelayOffset",
- {env.addInitialiser("const unsigned int preDelayOffset = " + LS(env, "_pre_delay_slot") + " * " + LS(env, "num_pre") + ";")});
+ {env.addInitialiser("const unsigned int preDelayOffset = $(_pre_delay_slot) * $(num_pre);", env)});
if(batchSize > 1) { env.add(Type::Uint32, "_pre_batch_delay_slot", "preBatchDelaySlot",
- {env.addInitialiser("const unsigned int preBatchDelaySlot = preDelaySlot + (batch * " + std::to_string(numSrcDelaySlots) + ");")}, - {"_pre_delay_slot"});
+ {env.addInitialiser("const
unsigned int preBatchDelaySlot = $(_pre_delay_slot) + ($(batch) * " + std::to_string(numSrcDelaySlots) + ");", env)});
env.add(Type::Uint32, "_pre_batch_delay_offset", "preBatchDelayOffset",
- {env.addInitialiser("const unsigned int preBatchDelayOffset = preDelayOffset + (preBatchOffset * " + std::to_string(numSrcDelaySlots) + ");")}, - {"_pre_delay_offset", "_pre_batch_offset"});
+ {env.addInitialiser("const unsigned int preBatchDelayOffset = $(_pre_delay_offset) + ($(_pre_batch_offset) * " + std::to_string(numSrcDelaySlots) + ");", env)});
} if(env.getGroup().getArchetype().getWUModel()->isPrevPreSpikeTimeRequired() || env.getGroup().getArchetype().getWUModel()->isPrevPreSpikeEventTimeRequired()) { env.add(Type::Uint32, "_pre_prev_spike_time_delay_offset", "prePrevSpikeTimeDelayOffset",
- {env.addInitialiser("const unsigned int prePrevSpikeTimeDelayOffset = ((*" + LS(env, "_src_spk_que_ptr") + " + " - + std::to_string(numSrcDelaySlots - numDelaySteps - 1) + ") % " + std::to_string(numSrcDelaySlots) + ") * " + LS(env, "num_pre") + ";")});
+ {env.addInitialiser("const unsigned int prePrevSpikeTimeDelayOffset = ((*$(_src_spk_que_ptr) + " + + std::to_string(numSrcDelaySlots - numDelaySteps - 1) + ") % " + std::to_string(numSrcDelaySlots) + ") * $(num_pre);", env)});
if(batchSize > 1) { env.add(Type::Uint32, "_pre_prev_spike_time_batch_delay_offset", "prePrevSpikeTimeBatchDelayOffset",
- {env.addInitialiser("const unsigned int prePrevSpikeTimeBatchDelayOffset = " + LS(env, "_pre_prev_spike_time_delay_offset") + " + (" + LS(env, "_pre_batch_offset") + " * " + std::to_string(numSrcDelaySlots) + ");")});
+ {env.addInitialiser("const unsigned int prePrevSpikeTimeBatchDelayOffset = $(_pre_prev_spike_time_delay_offset) + ($(_pre_batch_offset) * " + std::to_string(numSrcDelaySlots) + ");", env)});
} } } @@ -696,32 +692,32 @@ class GENN_EXPORT BackendBase std::ostringstream postDelaySlotInit; postDelaySlotInit << "const unsigned int postDelaySlot = "; if(numBackPropDelaySteps == 0) {
- postDelaySlotInit << "*" << env["_trg_spk_que_ptr"] << ";" << std::endl;
+ postDelaySlotInit << "*$(_trg_spk_que_ptr);" << std::endl;
} else {
- postDelaySlotInit << "(*" << env["_trg_spk_que_ptr"] << " + " << (numTrgDelaySlots - numBackPropDelaySteps) << ") % " << numTrgDelaySlots << ";" << std::endl;
+ postDelaySlotInit << "(*$(_trg_spk_que_ptr) + " << (numTrgDelaySlots - numBackPropDelaySteps) << ") % " << numTrgDelaySlots << ";" << std::endl;
} env.add(Type::Uint32, "_post_delay_slot", "postDelaySlot",
- {env.addInitialiser(postDelaySlotInit.str())}, {"_trg_spk_que_ptr"});
+ {env.addInitialiser(postDelaySlotInit.str(), env)});
env.add(Type::Uint32, "_post_delay_offset", "postDelayOffset",
- {env.addInitialiser("const unsigned int postDelayOffset = " + LS(env, "_post_delay_slot") + " * " + LS(env, "num_post") + ";")});
+ {env.addInitialiser("const unsigned int postDelayOffset = $(_post_delay_slot) * $(num_post);", env)});
if(batchSize > 1) { env.add(Type::Uint32, "_post_batch_delay_slot", "postBatchDelaySlot",
- {env.addInitialiser("const unsigned int postBatchDelaySlot = " + LS(env, "_post_delay_slot") + " + (batch * " + std::to_string(numTrgDelaySlots) + ");")});
+ {env.addInitialiser("const unsigned int postBatchDelaySlot = $(_post_delay_slot) + (batch * " + std::to_string(numTrgDelaySlots) + ");", env)});
env.add(Type::Uint32, "_post_batch_delay_offset", "postBatchDelayOffset",
- {env.addInitialiser("const unsigned int postBatchDelayOffset = " + LS(env, "_post_delay_offset") + " + (" + LS(env,
"_post_batch_offset") + " * " + std::to_string(numTrgDelaySlots) + ");")}); + {env.addInitialiser("const unsigned int postBatchDelayOffset = $(_post_delay_offset) + ($(_post_batch_offset) * " + std::to_string(numTrgDelaySlots) + ");", env)}); } if(env.getGroup().getArchetype().getWUModel()->isPrevPostSpikeTimeRequired()) { env.add(Type::Uint32, "_post_prev_spike_time_delay_offset", "postPrevSpikeTimeDelayOffset", - {env.addInitialiser("const unsigned int postPrevSpikeTimeDelayOffset = ((*" + LS(env, "_trg_spk_que_ptr") + " + " - + std::to_string(numTrgDelaySlots - numBackPropDelaySteps - 1) + ") % " + std::to_string(numTrgDelaySlots) + ") * " + LS(env, "num_post") + ";")}); + {env.addInitialiser("const unsigned int postPrevSpikeTimeDelayOffset = ((*$(_trg_spk_que_ptr) + " + + std::to_string(numTrgDelaySlots - numBackPropDelaySteps - 1) + ") % " + std::to_string(numTrgDelaySlots) + ") * $(num_post);", env)}); if(batchSize > 1) { env.add(Type::Uint32, "_post_prev_spike_time_batch_delay_offset", "postPrevSpikeTimeBatchDelayOffset", - {env.addInitialiser("const unsigned int postPrevSpikeTimeBatchDelayOffset = " + LS(env, "_post_prev_spike_time_delay_offset") + " + (" + LS(env, "_post_batch_offset") + " * " + std::to_string(numTrgDelaySlots) + ");")}); + {env.addInitialiser("const unsigned int postPrevSpikeTimeBatchDelayOffset = $(_post_prev_spike_time_delay_offset) + ($(_post_batch_offset) * " + std::to_string(numTrgDelaySlots) + ");", env)}); } } diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 8b8b2fec99..976d9e3b22 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -18,6 +18,7 @@ // GeNN code generator includes #include "backendBase.h" #include "codeStream.h" +#include "lazyString.h" #include "substitutions.h" #include "teeStream.h" @@ -134,6 +135,7 @@ GENN_EXPORT void prettyPrintStatements(const std::string &code, const Type::Type Transpiler::ErrorHandlerBase &errorHandler, Transpiler::TypeChecker::StatementHandler forEachSynapseTypeCheckHandler = nullptr, Transpiler::PrettyPrinter::StatementHandler forEachSynapsePrettyPrintHandler = nullptr); +GENN_EXPORT std::string printSubs(const std::string &format, EnvironmentExternalBase &env); //------------------------------------------------------------------------- /*! \brief Function for performing the code and value substitutions necessary to insert neuron related variables, parameters, and extraGlobal parameters into synaptic code. 
@@ -206,7 +208,7 @@ std::string getKernelSize(const G &group, size_t dimensionIndex) { // If kernel size is heterogeneous in this dimension, return group structure entry if (isKernelSizeHeterogeneous(group, dimensionIndex)) {
- return "group->kernelSize" + std::to_string(dimensionIndex);
+ return "$(_kernel_size_" + std::to_string(dimensionIndex) + ")";
} // Otherwise, return literal else { @@ -215,13 +217,13 @@ std::string getKernelSize(const G &group, size_t dimensionIndex) } template
-std::string getKernelIndex(const G &group, EnvironmentExternalBase &env)
+std::string getKernelIndex(const G &group)
{ // Loop through kernel dimensions to calculate array index const auto &kernelSize = group.getArchetype().getKernelSize(); std::ostringstream kernelIndex; for (size_t i = 0; i < kernelSize.size(); i++) {
- kernelIndex << "(" << env["id_kernel_" + std::to_string(i)];
+ kernelIndex << "($(id_kernel_" << i << ")";
// Loop through remaining dimensions of kernel and multiply for (size_t j = i + 1; j < kernelSize.size(); j++) { kernelIndex << " * " << getKernelSize(group, j);
diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index c856d29cc0..e4f88b3556 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -46,11 +46,6 @@ class EnvironmentExternalBase : public Transpiler::PrettyPrinter::EnvironmentBas { }
- explicit EnvironmentExternalBase(Transpiler::TypeChecker::EnvironmentBase &enclosing) - : m_Context(std::make_pair(&enclosing, nullptr)) - { - } -
explicit EnvironmentExternalBase(CodeStream &os) : m_Context(os) { @@ -115,11 +110,6 @@ class EnvironmentLibrary : public EnvironmentExternalBase { }
- explicit EnvironmentLibrary(Transpiler::TypeChecker::EnvironmentBase &enclosing, const Library &library) - : EnvironmentExternalBase(enclosing), m_Library(library) - { - } -
explicit EnvironmentLibrary(CodeStream &os, const Library &library) : EnvironmentExternalBase(os), m_Library(library) {} @@ -145,14 +135,14 @@ class EnvironmentLibrary : public EnvironmentExternalBase class EnvironmentSubstitutionPolicy { protected:
- using Payload = std::string;
+ using Payload = LazyString;
- std::string getNameInternal(const std::string &payload)
+ std::string getNameInternal(const LazyString &payload)
{
- return payload;
+ return payload.str();
}
- void setRequired(std::string&)
+ void setRequired(LazyString&)
{ } }; @@ -170,7 +160,7 @@ class EnvironmentFieldPolicy const G &getGroup() const{ return m_Group; } protected:
- using Payload = std::tuple>;
+ using Payload = std::tuple>;
EnvironmentFieldPolicy(G &group, F &fieldGroup) : m_Group(group), m_FieldGroup(fieldGroup) @@ -185,13 +175,21 @@ class EnvironmentFieldPolicy std::string getNameInternal(const Payload &payload) { // If a field is specified
EnvironmentExternalBase(enclosing), P(std::forward(policyArgs)...) {} - template - EnvironmentExternalDynamicBase(Transpiler::TypeChecker::EnvironmentBase &enclosing, PolicyArgs&&... policyArgs) - : EnvironmentExternalBase(enclosing), P(std::forward(policyArgs)...) - {} - template EnvironmentExternalDynamicBase(CodeStream &os, PolicyArgs&&... policyArgs) : EnvironmentExternalBase(os), P(std::forward(policyArgs)...) @@ -273,7 +266,7 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P } // Otherwise, get name from payload else { - return getNameInternal(std::get<3>(env->second)); + return getNameInternal(std::get<2>(env->second)); } } @@ -296,15 +289,8 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P m_Initialisers.at(i).first = true; } - // If this identifier relies on any others, get their types - // **YUCK** - for(const std::string &id : std::get<2>(env->second)) { - getTypes(Transpiler::Token{Transpiler::Token::Type::IDENTIFIER, id, 0}, - errorHandler); - } - // Perform any type-specific logic to mark this identifier as required - setRequired(std::get<3>(env->second)); + setRequired(std::get<2>(env->second)); // Return type of variables return {std::get<0>(env->second)}; @@ -318,15 +304,20 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P return (m_Initialisers.size() - 1); } + size_t addInitialiser(const std::string &format, EnvironmentExternalBase &env) + { + return addInitialiser(LazyString::print(format, env)); + } + protected: //------------------------------------------------------------------------ // Protected API //------------------------------------------------------------------------ - //! Map an identifier to a type (for type-checking), lists of initialisers and dependencies and a payload + //! 
Map an identifier to a type (for type-checking), lists of initialisers and a payload void addInternal(const GeNN::Type::ResolvedType &type, const std::string &name, const typename P::Payload &payload, - const std::vector &initialisers = {}, const std::vector &dependents = {}) + const std::vector &initialisers = {}) { - if(!m_Environment.try_emplace(name, type, initialisers, dependents, payload).second) { + if(!m_Environment.try_emplace(name, type, initialisers, payload).second) { throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); } } @@ -338,7 +329,7 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P std::ostringstream m_ContentsStream; CodeStream m_Contents; - std::unordered_map, std::vector, typename P::Payload>> m_Environment; + std::unordered_map, typename P::Payload>> m_Environment; std::vector> m_Initialisers; }; @@ -356,10 +347,10 @@ class EnvironmentExternal : public EnvironmentExternalDynamicBase &initialisers = {}, const std::vector &dependents = {}) + void add(const GeNN::Type::ResolvedType &type, const std::string &name, const LazyString &value, + const std::vector &initialisers = {}) { - addInternal(type, name, value, initialisers, dependents); + addInternal(type, name, value, initialisers); } }; @@ -391,29 +382,28 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &initialisers = {}, const std::vector &dependents = {}) + void add(const GeNN::Type::ResolvedType &type, const std::string &name, const LazyString &value, + const std::vector &initialisers = {}) { - addInternal(type, name, std::make_tuple(false, value, std::nullopt), - initialisers, dependents); + addInternal(type, name, std::make_tuple(false, value, std::nullopt), initialisers); } //! Map a type (for type-checking) and a group merged field to back it to an identifier void addField(const GeNN::Type::ResolvedType &type, const std::string &name, const GeNN::Type::ResolvedType &fieldType, const std::string &fieldName, typename G::GetFieldValueFunc getFieldValue, - const std::string &indexSuffix = "", GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD, - const std::vector &initialisers = {}, const std::vector &dependents = {}) + const LazyString &indexSuffix = "", GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD, + const std::vector &initialisers = {}) { addInternal(type, name, std::make_tuple(false, indexSuffix, std::make_optional(std::make_tuple(fieldType, fieldName, getFieldValue, mergedFieldType))), - initialisers, dependents); + initialisers); } //! 
Map a type (for type-checking) and a group merged field to back it to an identifier void addField(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &fieldName, - typename G::GetFieldValueFunc getFieldValue, const std::string &indexSuffix = "", GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD, - const std::vector &initialisers = {}, const std::vector &dependents = {}) + typename G::GetFieldValueFunc getFieldValue, const LazyString &indexSuffix = "", GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD, + const std::vector &initialisers = {}) { - addField(type, name, type, fieldName, getFieldValue, indexSuffix, mergedFieldType, initialisers, dependents); + addField(type, name, type, fieldName, getFieldValue, indexSuffix, mergedFieldType, initialisers); } void addScalar(const std::string &name, const std::string &fieldSuffix, typename G::GetFieldDoubleValueFunc getFieldValue) @@ -583,8 +573,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase - void addVars(const std::string &arrayPrefix, I getIndexFn, const std::string &fieldSuffix = "", - const std::vector &dependents = {}) + void addVars(const std::string &arrayPrefix, I getIndexFn, const std::string &fieldSuffix = "") { // Loop through variables const A archetypeAdaptor(getGroup().getArchetype()); @@ -597,21 +586,19 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase - void addVars(const std::string &arrayPrefix, const std::string &index, const std::string &fieldSuffix = "", - const std::vector &dependents = {}) + void addVars(const std::string &arrayPrefix, const LazyString &indexSuffix, const std::string &fieldSuffix = "") { - addVars(arrayPrefix, [&index](VarAccess a, const std::string &) { return index; }, - fieldSuffix, dependents); + addVars(arrayPrefix, [&indexSuffix](VarAccess a, const std::string &) { return indexSuffix; }, + fieldSuffix); } template - void addVarRefs(const std::string &arrayPrefix, I getIndexFn, const std::string &fieldSuffix = "", - const std::vector &dependents = {}) + void addVarRefs(const std::string &arrayPrefix, I getIndexFn, const std::string &fieldSuffix = "") { // Loop through variable references const A archetypeAdaptor(getGroup().getArchetype()); @@ -627,16 +614,15 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase - void addVarRefs(const std::string &arrayPrefix, const std::string &index, const std::string &fieldSuffix = "", - const std::vector &dependents = {}) + void addVarRefs(const std::string &arrayPrefix, const LazyString &indexSuffix, const std::string &fieldSuffix = "") { - addVarRefs(arrayPrefix, [&index](VarAccess a, const std::string &) { return index; }, - fieldSuffix, dependents); + addVarRefs(arrayPrefix, [&indexSuffix](VarAccess a, const std::string &) { return indexSuffix; }, + fieldSuffix); } private: @@ -661,7 +647,7 @@ class VarCachePolicy { public: using GroupInternal = typename G::GroupInternal; - using GetIndexFn = std::function; + using GetIndexFn = std::function; VarCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex) : m_GetReadIndex(getReadIndex), m_GetWriteIndex(getWriteIndex) @@ -676,12 +662,12 @@ class VarCachePolicy return A(g).getNameSuffix(); } - std::string getReadIndex(G &g, const Models::Base::Var &var) + LazyString getReadIndex(G &g, const Models::Base::Var &var) { return m_GetReadIndex(var.name, getVarAccessDuplication(var.access)); } - std::string getWriteIndex(G &g, const Models::Base::Var &var) + LazyString 
getWriteIndex(G &g, const Models::Base::Var &var) { return m_GetWriteIndex(var.name, getVarAccessDuplication(var.access)); } @@ -697,7 +683,7 @@ class VarRefCachePolicy protected: using GroupInternal = typename G::GroupInternal; using Initialiser = typename std::remove_reference_t>::mapped_type; - using GetIndexFn = std::function; + using GetIndexFn = std::function; VarRefCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex) @@ -713,12 +699,12 @@ class VarRefCachePolicy return A(g).getInitialisers().at(var.name).getTargetName(); } - std::string getReadIndex(G &g, const Models::Base::VarRef &var) + LazyString getReadIndex(G &g, const Models::Base::VarRef &var) { return m_GetReadIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name)); } - std::string getWriteIndex(G &g, const Models::Base::VarRef &var) + LazyString getWriteIndex(G &g, const Models::Base::VarRef &var) { return m_GetWriteIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name)); } @@ -792,7 +778,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P // **NOTE** by not initialising these variables for reductions, // compilers SHOULD emit a warning if user code doesn't set it to something if(!(v.access & VarAccessModeAttribute::REDUCE)) { - getContextStream() << " = group->" << v.name << m_FieldSuffix << "[" << getReadIndex(m_Group.get(), v) << "]"; + getContextStream() << " = group->" << v.name << m_FieldSuffix << "[" << getReadIndex(m_Group.get(), v).str() << "]"; } getContextStream() << ";" << std::endl; } @@ -804,7 +790,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P for(const auto &v : referencedDefs) { // If variables are read-write if(v.access & VarAccessMode::READ_WRITE) { - getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << getWriteIndex(m_Group.get(), v) << "]"; + getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << getWriteIndex(m_Group.get(), v).str() << "]"; getContextStream() << " = " << m_LocalPrefix << v.name << ";" << std::endl; } } diff --git a/include/genn/genn/code_generator/lazyString.h b/include/genn/genn/code_generator/lazyString.h index bc3ec874de..5f86ce4a53 100644 --- a/include/genn/genn/code_generator/lazyString.h +++ b/include/genn/genn/code_generator/lazyString.h @@ -36,7 +36,9 @@ class LazyString //! 
Evaluate lazy string std::string str() const; + //---------------------------------------------------------------------------- // Static API + //---------------------------------------------------------------------------- static LazyString print(const std::string &format, EnvironmentExternalBase &env); private:
diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index d77702210b..9f18ac31a4 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -82,7 +82,7 @@ void genKernelIteration(EnvironmentExternalBase &env, const G &g, size_t numKern { // Loop through this kernel dimensions const std::string idxVar = "k" + std::to_string(depth);
- env.getStream() << "for(unsigned int " << idxVar << " = 0; " << idxVar << " < " << getKernelSize(g, depth) << "; " << idxVar << "++)";
+ env.getStream() << "for(unsigned int " << idxVar << " = 0; " << idxVar << " < " << printSubs(getKernelSize(g, depth), env) << "; " << idxVar << "++)";
{ CodeStream::Scope b(env.getStream()); EnvironmentGroupMergedField loopEnv(env, g); @@ -95,7 +95,7 @@ void genKernelIteration(EnvironmentExternalBase &env, const G &g, size_t numKern // Generate kernel index and use as "synapse" index // **TODO** rename loopEnv.add(Type::Uint32.addConst(), "id_syn", "kernelInd",
- {loopEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(g, loopEnv) + ";")});
+ {loopEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(g) + ";", loopEnv)});
// Call handler handler(loopEnv); @@ -163,28 +163,28 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host if(n.getArchetype().isDelayRequired()) { if(n.getArchetype().isPrevSpikeTimeRequired()) { // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep
- groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["_spk_cnt"] << "[" << groupEnv["_read_delay_slot"] << "]; i++)";
+ groupEnv.getStream() << printSubs("for(unsigned int i = 0; i < $(_spk_cnt)[$(_read_delay_slot)]; i++)", groupEnv);
{ CodeStream::Scope b(groupEnv.getStream());
- groupEnv.getStream() << groupEnv["_prev_spk_time"] << "[" << groupEnv["_read_delay_offset"] << " + " << groupEnv["_spk"] << "[" << groupEnv["_read_delay_offset"] << " + i]] = t - DT;" << std::endl;
+ groupEnv.getStream() << printSubs("$(_prev_spk_time)[$(_read_delay_offset) + $(_spk)[$(_read_delay_offset) + i]] = t - DT;", groupEnv) << std::endl;
} } if(n.getArchetype().isPrevSpikeEventTimeRequired()) { // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep
- groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["_spk_cnt_evnt"] << "[" << groupEnv["_read_delay_slot"] << "]; i++)";
+ groupEnv.getStream() << printSubs("for(unsigned int i = 0; i < $(_spk_cnt_evnt)[$(_read_delay_slot)]; i++)", groupEnv);
{ CodeStream::Scope b(groupEnv.getStream());
- groupEnv.getStream() << groupEnv["_prev_spk_evnt_time"] << "[" << groupEnv["_read_delay_offset"] << " + " << groupEnv["_spk_evnt"] << "[" << groupEnv["_read_delay_offset"] << " + i]] = t - DT;" << std::endl;
+ groupEnv.getStream() << printSubs("$(_prev_spk_evnt_time)[$(_read_delay_offset) + $(_spk_evnt)[$(_read_delay_offset) + i]] = t - DT;", groupEnv) << std::endl;
} } } else { if(n.getArchetype().isPrevSpikeTimeRequired()) { // Loop through neurons which spiked last timestep and set their spike time to time of previous
timestep - groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["_spk_cnt"] << "[0]; i++)"; + groupEnv.getStream() << printSubs("for(unsigned int i = 0; i < $(_spk_cnt)[0]; i++)", groupEnv); { CodeStream::Scope b(groupEnv.getStream()); - groupEnv.getStream() << groupEnv["_prev_spk_time"] << "[" << groupEnv["_spk"] << "[i]] = t - DT;" << std::endl; + groupEnv.getStream() << printSubs("$(_prev_spk_time)[$(_spk)[i]] = t - DT;", groupEnv) << std::endl; } } if(n.getArchetype().isPrevSpikeEventTimeRequired()) { @@ -303,8 +303,6 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host //-------------------------------------------------------------------------- void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { - using LS = LazyString; - if (modelMerged.getModel().getBatchSize() != 1) { throw std::runtime_error("The single-threaded CPU backend only supports simulations with a batch size of 1"); } @@ -370,31 +368,27 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos if (s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // Add initialiser strings to calculate synaptic and presynaptic index - const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * " + LS(synEnv, "_row_stride") + ") + s;"); - const size_t idPostInit = synEnv.addInitialiser("const unsigned int idPost = " + LS(synEnv, "_ind") + "[idSyn];"); + const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * $(_row_stride)) + s;", synEnv); + const size_t idPostInit = synEnv.addInitialiser("const unsigned int idPost = $(_ind)[$(id_syn)];", synEnv); // **TODO** id_syn can be 64-bit - synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}, {"_row_stride"}); - synEnv.add(Type::Uint32.addConst(), "id_post", "idPost", {idPostInit, idSynInit}, {"_ind"}); + synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}); + synEnv.add(Type::Uint32.addConst(), "id_post", "idPost", {idPostInit, idSynInit}); } else { // Add postsynaptic index to substitutions synEnv.add(Type::Uint32.addConst(), "id_post", "j"); // Add initialiser to calculate synaptic index - const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * " + LS(synEnv, "num_post") + ") + j;"); - // **TODO** id_syn can be 64-bit - synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}); + synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", + {synEnv.addInitialiser("const unsigned int idSyn = (i * $(num_post)) + j;", synEnv)}); } // Add correct functions for apply synaptic input - synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", synEnv["_den_delay"] + "[" + s.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)", - {}, {"_den_delay"}); - synEnv.add(Type::AddToPost, "addToPost", synEnv["_out_post"] + "[" + s.getPostISynIndex(1, "j") + "] += $(0)", - {}, {"_out_post"}); - synEnv.add(Type::AddToPre, "addToPre", synEnv["_out_pre"] + "[" + s.getPreISynIndex(1, synEnv["id_pre"]) + "] += $(0)", - {}, {"id_pre"}); + synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", LazyString::print("$(_den_delay)[" + s.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)", synEnv)); + synEnv.add(Type::AddToPost, "addToPost", LazyString::print("$(_out_post)[" + s.getPostISynIndex(1, "j") + "] += $(0)", synEnv)); + synEnv.add(Type::AddToPre, "addToPre", LazyString::print("$(_out_pre)[" + s.getPreISynIndex(1, "id_pre") + "] += $(0)", synEnv)); // Call synapse dynamics 
handler s.generateSynapseUpdate(*this, synEnv, modelMerged); @@ -462,10 +456,10 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // Get number of postsynaptic spikes if (s.getArchetype().getTrgNeuronGroup()->isDelayRequired() && s.getArchetype().getTrgNeuronGroup()->isTrueSpikeRequired()) {
- groupEnv.getStream() << "const unsigned int numSpikes = group->trgSpkCnt[postDelaySlot];" << std::endl;
+ groupEnv.getStream() << printSubs("const unsigned int numSpikes = $(_trg_spk_cnt)[$(_post_delay_slot)];", groupEnv) << std::endl;
} else {
- groupEnv.getStream() << "const unsigned int numSpikes = group->trgSpkCnt[0];" << std::endl;
+ groupEnv.getStream() << printSubs("const unsigned int numSpikes = $(_trg_spk_cnt)[0];", groupEnv) << std::endl;
} // Loop through postsynaptic spikes @@ -474,12 +468,12 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos CodeStream::Scope b(groupEnv.getStream()); // **TODO** prod types
- const std::string offsetTrueSpkPost = (s.getArchetype().getTrgNeuronGroup()->isTrueSpikeRequired() && s.getArchetype().getTrgNeuronGroup()->isDelayRequired()) ? "postDelayOffset + " : ""; - groupEnv.getStream() << "const unsigned int spike = group->trgSpk[" << offsetTrueSpkPost << "j];" << std::endl;
+ const std::string offsetTrueSpkPost = (s.getArchetype().getTrgNeuronGroup()->isTrueSpikeRequired() && s.getArchetype().getTrgNeuronGroup()->isDelayRequired()) ? "$(_post_delay_offset) + " : ""; + groupEnv.getStream() << printSubs("const unsigned int spike = $(_trg_spk)[" + offsetTrueSpkPost + "j];", groupEnv) << std::endl;
// Loop through column of presynaptic neurons if (s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
- groupEnv.getStream() << "const unsigned int npre = group->colLength[spike];" << std::endl;
+ groupEnv.getStream() << printSubs("const unsigned int npre = $(_col_length)[spike];", groupEnv) << std::endl;
groupEnv.getStream() << "for (unsigned int i = 0; i < npre; i++)"; } else { @@ -492,25 +486,23 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // Add initialisers to calculate column and row-major indices // **TODO** fast divide optimisations
- const size_t colMajorIdxInit = synEnv.addInitialiser("const unsigned int colMajorIndex = (spike * " + LS(synEnv, "_col_stride") + ") + i;"); - const size_t rowMajorIdxInit = synEnv.addInitialiser("const unsigned int rowMajorIndex = " + LS(synEnv,
"_remap") + "[colMajorIndex];"); - const size_t idPreInit = synEnv.addInitialiser("const unsigned int idPre = rowMajorIndex / " + LS(synEnv, "_row_stride") + ";"); + const size_t colMajorIdxInit = synEnv.addInitialiser("const unsigned int colMajorIndex = (spike * $(_col_stride)) + i;", synEnv); + const size_t rowMajorIdxInit = synEnv.addInitialiser("const unsigned int rowMajorIndex = $(_remap)[colMajorIndex];", synEnv); + const size_t idPreInit = synEnv.addInitialiser("const unsigned int idPre = rowMajorIndex / $(_row_stride);", synEnv); // Add presynaptic and synapse index to environment - synEnv.add(Type::Uint32.addConst(), "id_pre", "idPre", {colMajorIdxInit, rowMajorIdxInit, idPreInit}, {"_col_stride", "_row_stride", "_remap"}); - synEnv.add(Type::Uint32.addConst(), "id_syn", "rowMajorIndex", {colMajorIdxInit, rowMajorIdxInit}, {"_col_stride", "_remap"}); + synEnv.add(Type::Uint32.addConst(), "id_pre", "idPre", {colMajorIdxInit, rowMajorIdxInit, idPreInit}); + synEnv.add(Type::Uint32.addConst(), "id_syn", "rowMajorIndex", {colMajorIdxInit, rowMajorIdxInit}); } else { - // Add initialiser to calculate synaptic index - const size_t idSynInit = groupEnv.addInitialiser("const unsigned int idSyn = (i * " + LS(synEnv, "num_post") + ") + spike;"); - // Add presynaptic and synapse index to environment synEnv.add(Type::Uint32.addConst(), "id_pre", "i"); - synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}, {"num_post"}); + synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", + {synEnv.addInitialiser("const unsigned int idSyn = (i * $(num_post)) + spike;", synEnv)}); } synEnv.add(Type::Uint32.addConst(), "id_post", "spike"); - synEnv.add(Type::AddToPre, "addToPre", synEnv["_out_pre"] + "[" + s.getPreISynIndex(1, synEnv["id_pre"]) + "] += $(0)"); + synEnv.add(Type::AddToPre, "addToPre", LazyString::print("$(_out_pre)[" + s.getPreISynIndex(1, "id_pre") + "] += $(0)", synEnv)); s.generateSynapseUpdate(*this, synEnv, modelMerged); } @@ -543,8 +535,6 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos //-------------------------------------------------------------------------- void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { - using LS = LazyString; - const ModelSpecInternal &model = modelMerged.getModel(); // Build set containing names of all custom update groups @@ -562,7 +552,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Generate stream with custom update code std::ostringstream customUpdateStream; CodeStream customUpdate(customUpdateStream); - + // Begin environment with standard library EnvironmentLibrary customUpdateEnv(customUpdate, StandardLibrary::getFunctions()); @@ -706,18 +696,18 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // If connectivity is sparse if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // Add initialisers to calculate synaptic index and thus lookup postsynaptic index - const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * group->rowStride) + s;"); - const size_t jInit = synEnv.addInitialiser("const unsigned int j = group->ind[idSyn];"); + const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * $(_row_stride)) + s;", synEnv); + const size_t jInit = synEnv.addInitialiser("const unsigned int j = $(_ind)[idSyn];", synEnv); // Add substitutions - synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}, 
{"_row_stride"}); - synEnv.add(Type::Uint32.addConst(), "id_post", "j", {jInit, idSynInit}, {"_ind", "_row_stride"}); + synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}); + synEnv.add(Type::Uint32.addConst(), "id_post", "j", {jInit, idSynInit}); } else { synEnv.add(Type::Uint32.addConst(), "id_post", "j"); synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", - {synEnv.addInitialiser("const unsigned int idSyn = (i * " + LS(synEnv, "num_post") + ") + j;")}); + {synEnv.addInitialiser("const unsigned int idSyn = (i * $(num_post)) + j;", synEnv)}); } // Generate custom update @@ -751,12 +741,12 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host genCustomConnectivityUpdateIndexCalculation(funcEnv.getStream(), c); // Loop through presynaptic neurons - funcEnv.getStream() << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; + funcEnv.getStream() << "for(unsigned int i = 0; i < " << funcEnv["num_pre"] << "; i++)"; { CodeStream::Scope b(funcEnv.getStream()); // Configure substitutions - groupEnv.add(Type::Uint32, "id_pre", "i"); + groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); assert(false); //c.generateUpdate(*this, cuEnv, model.getBatchSize()); @@ -797,24 +787,24 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host CodeStream::Scope b(groupEnv.getStream()); // Loop through each postsynaptic neuron - groupEnv.getStream() << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)"; + groupEnv.getStream() << "for (unsigned int j = 0; j < " << funcEnv["groupEnv"] << "; j++)"; { CodeStream::Scope b(groupEnv.getStream()); // Add pre and postsynaptic indices to environment - groupEnv.add(Type::Uint32, "id_pre", "i"); - groupEnv.add(Type::Uint32, "id_post", "j"); + groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); + groupEnv.add(Type::Uint32.addConst(), "id_post", "j"); // Add conditional initialisation code to calculate synapse index - groupEnv.add(Type::Uint32, "id_syn", "idSyn", - {groupEnv.addInitialiser("const unsigned int idSyn = (i * " + LS(groupEnv, "num_post") + ") + j;")}); + groupEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", + {groupEnv.addInitialiser("const unsigned int idSyn = (i * $(num_post)) + j;", groupEnv)}); // Generate custom update c.generateCustomUpdate(*this, groupEnv); // Update transpose variable // **YUCK** this is sorta outside scope - groupEnv.getStream() << groupEnv[transposeVarName + "_transpose"] << "[(j * " << LS(groupEnv, "num_pre") << ") + i] = l" << transposeVarName << ";" << std::endl; + groupEnv.getStream() << printSubs("$(" + transposeVarName + "_transpose)[(j * $(num_pre)) + i] = l" + transposeVarName + ";", groupEnv) << std::endl; } } @@ -991,10 +981,10 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // If matrix connectivity is ragged if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // Zero row lengths - funcEnv.getStream() << "std::fill_n(" << groupEnv["_row_length"] << ", " << groupEnv["num_pre"] << ", 0);" << std::endl; + funcEnv.getStream() << printSubs("std::fill_n($(_row_length), $(num_pre), 0);", funcEnv) << std::endl; } else if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { - funcEnv.getStream() << "const size_t gpSize = ((((size_t)" << groupEnv["num_pre"] << " * (size_t)" << groupEnv["_row_stride"] << ") + 32 - 1) / 32);" << std::endl; + funcEnv.getStream() << printSubs("const size_t gpSize = ((((size_t)$(num_pre) * (size_t)$(_row_stride)) + 32 - 1) / 32);", 
funcEnv) << std::endl; funcEnv.getStream() << "std::fill_n(" << groupEnv["_gp"] << ", gpSize, 0);" << std::endl; } else { @@ -1020,7 +1010,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler assert(!snippet->getColBuildCode().empty()); // Loop through target neurons
- groupEnv.getStream() << "for (unsigned int j = 0; j < group->numTrgNeurons; j++)";
+ groupEnv.getStream() << "for (unsigned int j = 0; j < " << groupEnv["num_post"] << "; j++)";
// Configure substitutions groupEnv.add(Type::Uint32.addConst(), "id_post", "j"); @@ -1045,10 +1035,10 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // Calculate index in data structure of this synapse if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { if(!snippet->getRowBuildCode().empty()) {
- addSynapse << "const unsigned int idx = " << "(" + groupEnv["id_pre"] + " * " << groupEnv["_row_stride"] << ") + " << groupEnv["_row_length"] << "[i];" << std::endl;
+ addSynapse << "const unsigned int idx = " << "($(id_pre) * $(_row_stride)) + $(_row_length)[i];" << std::endl;
} else {
- addSynapse << "const unsigned int idx = " << "(($(0)) * " << groupEnv["_row_stride"] << ") + " << groupEnv["_row_length"] << "[$(0)];" << std::endl;
+ addSynapse << "const unsigned int idx = " << "(($(0)) * $(_row_stride)) + $(_row_length)[$(0)];" << std::endl;
} } @@ -1084,12 +1074,12 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler if(!snippet->getRowBuildCode().empty()) { // If matrix is sparse, add function to increment row length and insert synapse into ind array if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
- addSynapse << groupEnv["_ind"] << "[idx] = $(0);" << std::endl; - addSynapse << groupEnv["_row_length"] << "[i]++;" << std::endl;
+ addSynapse << "$(_ind)[idx] = $(0);" << std::endl; + addSynapse << "$(_row_length)[i]++;" << std::endl;
} // Otherwise, add function to set correct bit in bitmask else {
- addSynapse << "const int64_t rowStartGID = i * " << groupEnv["_row_stride"] << ";" << std::endl;
+ addSynapse << "const int64_t rowStartGID = i * $(_row_stride);" << std::endl;
addSynapse << "setB(group->gp[(rowStartGID + ($(0))) / 32], (rowStartGID + $(0)) & 31);" << std::endl; } } @@ -1097,19 +1087,19 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler else { // If matrix is sparse, add function to increment row length and insert synapse into ind array if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
- addSynapse << groupEnv["_ind"] << "[idx] = " << groupEnv["id_post"] << ";" << std::endl; - addSynapse << groupEnv["_row_length"] << "[$(0)]++;" << std::endl;
+ addSynapse << "$(_ind)[idx] = $(id_post);" << std::endl; + addSynapse << "$(_row_length)[$(0)]++;" << std::endl;
} else { addSynapse << "const int64_t colStartGID = j;" << std::endl;
- addSynapse << "setB(" << groupEnv["_gp"] << "[(colStartGID + (($(0)) * " << groupEnv["_row_stride"] << ")) / 32], ((colStartGID + (($(0)) * " << groupEnv["_row_stride"] << ")) & 31));" << std::endl;
+ addSynapse << "setB($(_gp)[(colStartGID + (($(0)) * $(_row_stride))) / 32], ((colStartGID + (($(0)) * $(_row_stride))) & 31));" << std::endl;
} } } addSynapse << "while(false)"; const auto addSynapseType = Type::ResolvedType::createFunction(Type::Void, std::vector{1ull + s.getArchetype().getKernelSize().size(), Type::Uint32});
groupEnv.add(addSynapseType, "addSynapse", LazyString::print(addSynapseStream.str(), groupEnv)); // Call appropriate connectivity handler if(!snippet->getRowBuildCode().empty()) { @@ -1149,7 +1139,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // If postsynaptic learning is required, initially zero column lengths if (!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { groupEnv.getStream() << "// Zero column lengths" << std::endl;
- groupEnv.getStream() << "std::fill_n(" << groupEnv["_col_length"] << ", " << groupEnv["num_post"] << ", 0);" << std::endl;
+ groupEnv.getStream() << printSubs("std::fill_n($(_col_length), $(num_post), 0);", groupEnv) << std::endl;
} groupEnv.getStream() << "// Loop through presynaptic neurons" << std::endl; @@ -1160,8 +1150,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // Generate sparse initialisation code groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); if(s.getArchetype().isWUVarInitRequired()) {
- groupEnv.add(Type::Uint32.addConst(), "row_len", groupEnv["_row_length"] + "[i]", - {}, {"_row_length"});
+ groupEnv.add(Type::Uint32.addConst(), "row_len", LazyString::print("$(_row_length)[i]", groupEnv));
s.generateInit(*this, groupEnv, modelMerged); } @@ -1214,8 +1203,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // Generate initialisation code groupEnv.add(Type::Uint32.addConst(), "id_pre", "i");
- groupEnv.add(Type::Uint32.addConst(), "row_len", groupEnv["_row_length"] + "[i]", - {}, {"_row_length"});
+ groupEnv.add(Type::Uint32.addConst(), "row_len", LazyString::print("$(_row_length)[i]", groupEnv));
c.generateInit(*this, groupEnv, modelMerged); } } @@ -1244,8 +1232,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // Generate initialisation code groupEnv.add(Type::Uint32.addConst(), "id_pre", "i");
- groupEnv.add(Type::Uint32.addConst(), "row_len", groupEnv["_row_length"] + "[i]", - {}, {"_row_length"});
+ groupEnv.add(Type::Uint32.addConst(), "row_len", LazyString::print("$(_row_length)[i]", groupEnv));
c.generateInit(*this, groupEnv, modelMerged); } } @@ -1500,26 +1487,22 @@ void Backend::genVariableInit(EnvironmentExternalBase &env, const std::string &c //-------------------------------------------------------------------------- void Backend::genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const {
- using LS = LazyString; - - env.getStream() << "for (unsigned int j = 0; j < group->rowLength[" << env["id_pre"] << "]; j++)";
+ env.getStream() << printSubs("for (unsigned int j = 0; j < $(_row_length)[$(id_pre)]; j++)", env);
{ CodeStream::Scope b(env.getStream()); EnvironmentExternal varEnv(env); // **TODO** 64-bit varEnv.add(Type::Uint32, "id_syn", "idSyn",
- {varEnv.addInitialiser("const unsigned int idSyn = (" + LS(varEnv, "id_pre") + " * " + LS(varEnv, "_row_stride") + ") + j;")});
+ {varEnv.addInitialiser("const unsigned int idSyn = ($(id_pre) * $(_row_stride)) + j;", varEnv)});
varEnv.add(Type::Uint32, "id_post", "idPost",
- {varEnv.addInitialiser("const unsigned int idPost = (" + LS(varEnv, "_ind") + "[(" + LS(varEnv, "id_pre") + " * " + LS(varEnv, "_row_stride") + ") + j]")});
+ {varEnv.addInitialiser("const unsigned int idPost = $(_ind)[($(id_pre) * $(_row_stride)) + j];", varEnv)});
handler(varEnv); } } //-------------------------------------------------------------------------- void Backend::genDenseSynapseVariableRowInit(EnvironmentExternalBase &env,
HandlerEnv handler) const { - using LS = LazyString; - env.getStream() << "for (unsigned int j = 0; j < " << env["num_post"] << "; j++)"; { CodeStream::Scope b(env.getStream()); @@ -1527,7 +1510,7 @@ void Backend::genDenseSynapseVariableRowInit(EnvironmentExternalBase &env, Handl EnvironmentExternal varEnv(env); // **TODO** 64-bit varEnv.add(Type::Uint32, "id_syn", "idSyn",
- {varEnv.addInitialiser("const unsigned int idSyn = (" + LS(varEnv, "id_pre") + " * " + LS(varEnv, "_row_stride") + ") + j;")});
+ {varEnv.addInitialiser("const unsigned int idSyn = ($(id_pre) * $(_row_stride)) + j;", varEnv)});
varEnv.add(Type::Uint32, "id_post", "j"); handler(varEnv); } @@ -1732,8 +1715,6 @@ boost::uuids::detail::sha1::digest_type Backend::getHashDigest() const //-------------------------------------------------------------------------- void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, const ModelSpecMerged &modelMerged, bool trueSpike) const { - using LS = LazyString; // Get suffix based on type of events const std::string eventSuffix = trueSpike ? "" : "Evnt"; const auto *wu = sg.getArchetype().getWUModel(); @@ -1880,16 +1861,13 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda } // Add correct functions for apply synaptic input
- groupEnv.add(Type::AddToPostDenDelay, "addToPostDelay", env["_den_delay"] + "[" + sg.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)", - {}, {"_den_delay"});
- groupEnv.add(Type::AddToPost, "addToPost", env["_out_post"] + "[" + sg.getPostISynIndex(1, "j") + "] += $(0)", - {}, {"_out_post"});
- groupEnv.add(Type::AddToPre, "addToPre", env["_out_pre"] + "[" + sg.getPreISynIndex(1, env["id_pre"]) + "] += $(0)", - {}, {"id_pre"});
+ groupEnv.add(Type::AddToPostDenDelay, "addToPostDelay", LazyString::print("$(_den_delay)[" + sg.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)", groupEnv));
+ groupEnv.add(Type::AddToPost, "addToPost", LazyString::print("$(_out_post)[" + sg.getPostISynIndex(1, "j") + "] += $(0)", groupEnv));
+ groupEnv.add(Type::AddToPre, "addToPre", LazyString::print("$(_out_pre)[" + sg.getPreISynIndex(1, "id_pre") + "] += $(0)", groupEnv));
// If connectivity is sparse if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
- groupEnv.getStream() << "const unsigned int npost = group->rowLength[ipre];" << std::endl;
+ groupEnv.getStream() << "const unsigned int npost = " << groupEnv["_row_length"] << "[ipre];" << std::endl;
groupEnv.getStream() << "for (unsigned int j = 0; j < npost; j++)"; { CodeStream::Scope b(groupEnv.getStream()); @@ -1897,9 +1875,9 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda // **TODO** 64-bit id_syn synEnv.add(Type::Uint32, "id_syn", "idSyn",
- {synEnv.addInitialiser("const unsigned int idSyn = (" + LS(env, "id_pre") + " * " + LS(env, "_row_stride") + ") + j;")});
+ {synEnv.addInitialiser("const unsigned int idSyn = ($(id_pre) * $(_row_stride)) + j;", synEnv)});
synEnv.add(Type::Uint32, "id_post", "idPost",
- {synEnv.addInitialiser("const unsigned int idPost = " + LS(env, "_ind") + "[" + LS(env, "id_syn") + "];")});
+ {synEnv.addInitialiser("const unsigned int idPost = $(_ind)[$(id_syn)];", synEnv)});
if(trueSpike) { sg.generateSpikeUpdate(*this, synEnv, modelMerged); @@ -1969,13 +1947,13 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { // **TODO** 64-bit index -
synEnv.getStream() << "const uint64_t gid = (" << synEnv["id_pre"] << " * " << synEnv["num_post"] << ") + " << synEnv["id_post"] + ";" << std::endl; + synEnv.getStream() << printSubs("const uint64_t gid = ($(id_pre) * $(num_post)) + $(id_post);", synEnv) << std::endl; synEnv.getStream() << "if (B(group->gp[gid / 32], gid & 31))" << CodeStream::OB(20); } else { synEnv.add(Type::Uint32, "id_syn", "idSyn", - {synEnv.addInitialiser("const unsigned int idSyn = (" + LS(synEnv, "id_pre") + " * " + LS(synEnv, "num_post") + ") + " + LS(synEnv, "id_post") + ";")}); + {synEnv.addInitialiser("const unsigned int idSyn = ($(id_pre) * $(num_post)) + $(id_post);", synEnv)}); } diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 1ef8ecc2bd..5c43c861fb 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -35,8 +35,6 @@ bool BackendBase::areSixtyFourBitSynapseIndicesRequired(const SynapseGroupMerged //----------------------------------------------------------------------- void BackendBase::genCustomUpdateIndexCalculation(EnvironmentGroupMergedField &env) const { - using LS = LazyString; - // Add size field env.addField(Type::Uint32, "size", "size", [](const auto &c, size_t) { return std::to_string(c.getSize()); }); @@ -44,7 +42,7 @@ void BackendBase::genCustomUpdateIndexCalculation(EnvironmentGroupMergedFieldgetNumDelaySlots()); env.add(Type::Uint32.addConst(), "_batch_delay_slot", "batchDelaySlot", - {env.addInitialiser(LS::print("const unsigned int batchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_delay_slot);", env))}); + {env.addInitialiser("const unsigned int batchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_delay_slot);", env)}); // Calculate current batch offset env.add(Type::Uint32.addConst(), "_batch_delay_offset", "batchDelayOffset", - {env.addInitialiser(LS::print("const unsigned int batchDelayOffset = $(_batch_offset) * " + numDelaySlotsStr + ";", env))}); + {env.addInitialiser("const unsigned int batchDelayOffset = $(_batch_offset) * " + numDelaySlotsStr + ";", env)}); } } } diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index ec4c310d43..f9aad04da2 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -548,4 +548,40 @@ void prettyPrintStatements(const std::string &code, const Type::TypeContext &typ // Pretty print Transpiler::PrettyPrinter::print(std::get<0>(statementTypes), env, typeContext, std::get<1>(statementTypes), forEachSynapsePrettyPrintHandler); } +std::string printSubs(const std::string &format, EnvironmentExternalBase &env) +{ + // Create regex iterator to iterate over $(XXX) style varibles in format string + std::regex regex("\\$\\(([\\w]+)\\)"); + std::sregex_iterator matchesBegin(format.cbegin(), format.cend(), regex); + std::sregex_iterator matchesEnd; + + // If there are no matches, leave format unmodified and return + if(matchesBegin == matchesEnd) { + return format; + } + // Otherwise + else { + // Loop through matches to build lazy string payload + std::string output; + for(std::sregex_iterator m = matchesBegin;;) { + // Copy the non-matched subsequence (m->prefix()) onto output + std::copy(m->prefix().first, m->prefix().second, std::back_inserter(output)); + + // Add environment value of $(XXX) to output + output += env[(*m)[1]]; + + // If there are no subsequent matches, add the remaining + // non-matched characters onto output and 
return
+ if(std::next(m) == matchesEnd) {
+ // Copy the non-matched subsequence (m->suffix()) onto output
+ std::copy(m->suffix().first, m->suffix().second, std::back_inserter(output));
+ return output;
+ }
+ // Otherwise go onto next match
+ else {
+ m++;
+ }
+ }
+ }
+} } // namespace GeNN::CodeGenerator
diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 31963b66fa..c088a14e1a 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -239,28 +239,24 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Add presynaptic variables and variable references // **TODO** var references to batched variables should be private // **THINK** what about batched pre var references?
- updateEnv.addVars(backend.getDeviceVarPrefix(), updateEnv["id_pre"], "", - {"id_pre"});
+ updateEnv.addVars(backend.getDeviceVarPrefix(), LazyString{updateEnv, "id_pre"}, "");
updateEnv.addVarRefs(backend.getDeviceVarPrefix(), [&updateEnv](VarAccessMode, const Models::VarReference &v) { if(v.getDelayNeuronGroup() != nullptr) {
- return "[" + updateSubs["_pre_delay_offset"] + " + " + updateSubs["id_pre"] + "]";
+ return LazyString::print("$(_pre_delay_offset) + $(id_pre)", updateEnv);
} else {
- return "[" + updateSubs["id_pre"] + "]";
+ return LazyString{updateEnv, "id_pre"};
}
- }, "", - {"id_pre"});
+ }, "");
// Calculate index of start of row updateEnv.add(Type::Uint32.addConst(), "_row_start_idx", "rowStartIdx",
- {updateEnv.addInitialiser("const unsigned int rowStartIdx = " + updateEnv["id_pre"] + " * " + updateEnv["_row_stride"] + ";")}, - {"id_pre", "_row_stride"});
+ {updateEnv.addInitialiser("const unsigned int rowStartIdx = $(id_pre) * $(_row_stride);", updateEnv)});
updateEnv.add(Type::Uint32.addConst(), "_syn_stride", "synStride",
- {updateEnv.addInitialiser("const unsigned int synStride = " + updateEnv["num_pre"] + " * " + updateEnv["_row_stride"] + ";")}, - {"num_pre", "_row_stride"});
+ {updateEnv.addInitialiser("const unsigned int synStride = $(num_pre) * $(_row_stride);", updateEnv)});
// Get variables which will need to be manipulated when adding and removing synapses const auto ccuVars = cm->getVars(); @@ -276,13 +272,13 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back CodeStream::Scope b(addSynapse); // Assert that there is space to add synapse
- backend.genAssert(addSynapse, updateEnv["_row_length"] + "[" + updateEnv["id_pre"] + "] < " + updateEnv["_row_stride"]);
+ backend.genAssert(addSynapse, "$(_row_length)[$(id_pre)] < $(_row_stride)");
// Calculate index to insert synapse
- addSynapse << "const unsigned newIdx = " + updateEnv["_row_start_idx"] + " + " + updateEnv["_row_length"] + "[" << updateEnv["id_pre"] << "];" << std::endl;
+ addSynapse << "const unsigned newIdx = $(_row_start_idx) + $(_row_length)[$(id_pre)];" << std::endl;
// Set postsynaptic target to parameter 0
- addSynapse << updateEnv["_ind"] + "[newIdx] = $(0);" << std::endl;
+ addSynapse << "$(_ind)[newIdx] = $(0);" << std::endl;
// Use subsequent parameters to initialise new synapse's custom connectivity update model variables for (size_t i = 0; i < ccuVars.size(); i++) { @@ -300,7 +296,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back addSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; {
CodeStream::Scope b(addSynapse);
-                        addSynapse << "group->" << ccuVarRefs[i].name << "[(b * " << updateEnv["_syn_stride"] << ") + newIdx] = _" << ccuVarRefs[i].name << "Val;" << std::endl;
+                        addSynapse << "group->" << ccuVarRefs[i].name << "[(b * $(_syn_stride)) + newIdx] = _" << ccuVarRefs[i].name << "Val;" << std::endl;
                 }
             }
             // Otherwise, write parameter straight into var reference
@@ -320,7 +316,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back
                 addSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)";
                 {
                     CodeStream::Scope b(addSynapse);
-                    addSynapse << "group->_dependentVar" << i << "[(b * " << updateEnv["_syn_stride"] << ") + newIdx] = 0;" << std::endl;
+                    addSynapse << "group->_dependentVar" << i << "[(b * $(_syn_stride)) + newIdx] = 0;" << std::endl;
                 }
             }
             // Otherwise, zero var reference
@@ -333,11 +329,11 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back

         // Increment row length
         // **NOTE** this will also affect any forEachSynapse loop currently in operation
-        addSynapse << updateEnv["_row_length"] + "[" << updateEnv["id_pre"] << "]++;" << std::endl;
+        addSynapse << "$(_row_length)[$(id_pre)]++;" << std::endl;
     }

     // Add function substitution with parameters to initialise custom connectivity update variables and variable references
-    updateEnv.add(Type::ResolvedType::createFunction(Type::Void, addSynapseTypes), "add_synapse", addSynapseStream.str());
+    updateEnv.add(Type::ResolvedType::createFunction(Type::Void, addSynapseTypes), "add_synapse", LazyString(addSynapseStream.str(), updateEnv));

     // Generate code to remove a synapse from this row
     std::stringstream removeSynapseStream;
@@ -346,10 +342,10 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back
         CodeStream::Scope b(removeSynapse);

         // Calculate index we want to copy synapse from
-        removeSynapse << "const unsigned lastIdx = " + updateEnv["_row_start_idx"] + " + " + updateEnv["_row_length"] + "[" << updateEnv["id_pre"] << "] - 1;" << std::endl;
+        removeSynapse << "const unsigned lastIdx = $(_row_start_idx) + $(_row_length)[$(id_pre)] - 1;" << std::endl;

         // Copy postsynaptic target from end of row over synapse to be deleted
-        removeSynapse << updateEnv["_ind"] << "[idx] = " << updateEnv["_ind"] << "[lastIdx];" << std::endl;
+        removeSynapse << "$(_ind)[idx] = $(_ind)[lastIdx];" << std::endl;

         // Copy custom connectivity update variables from end of row over synapse to be deleted
         for (size_t i = 0; i < ccuVars.size(); i++) {
@@ -366,8 +362,8 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back
                 removeSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)";
                 {
                     CodeStream::Scope b(removeSynapse);
-                    removeSynapse << "group->" << ccuVarRefs[i].name << "[(b * " << updateEnv["_syn_stride"] << ") + idx] = ";
-                    removeSynapse << "group->" << ccuVarRefs[i].name << "[(b * " << updateEnv["_syn_stride"] << ") + lastIdx];" << std::endl;
+                    removeSynapse << "group->" << ccuVarRefs[i].name << "[(b * $(_syn_stride)) + idx] = ";
+                    removeSynapse << "group->" << ccuVarRefs[i].name << "[(b * $(_syn_stride)) + lastIdx];" << std::endl;
                 }
             }
             // Otherwise, copy custom connectivity update variable references from end of row over synapse to be deleted
@@ -384,8 +380,8 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back
                 removeSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)";
                 {
                     CodeStream::Scope b(removeSynapse);
-                    removeSynapse << "group->_dependentVar" << i << "[(b * " << updateEnv["_syn_stride"] << ") + idx] = ";
-                    removeSynapse << "group->_dependentVar" << i << "[(b * " << updateEnv["_syn_stride"] << ") + lastIdx];" << std::endl;
+                    removeSynapse << "group->_dependentVar" << i << "[(b * $(_syn_stride)) + idx] = ";
+                    removeSynapse << "group->_dependentVar" << i << "[(b * $(_syn_stride)) + lastIdx];" << std::endl;
                 }
             }
             // Otherwise, copy dependent variable from end of row over synapse to be deleted
@@ -396,14 +392,14 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back

         // Decrement row length
         // **NOTE** this will also affect any forEachSynapse loop currently in operation
-        removeSynapse << updateSubs["_row_length"] << "[" << updateSubs["id_pre"] << "]--;" << std::endl;
+        removeSynapse << "$(_row_length)[$(id_pre)]--;" << std::endl;

         // Decrement loop counter so synapse j will get processed
         removeSynapse << "j--;" << std::endl;
     }

     // Add function substitution with parameters to remove a synapse from this row
-    updateEnv.add(Type::ResolvedType::createFunction(Type::Void, {}), "remove_synapse", removeSynapseStream.str());
+    updateEnv.add(Type::ResolvedType::createFunction(Type::Void, {}), "remove_synapse", LazyString{updateEnv, removeSynapseStream.str()});

     // Pretty print code back to environment
     Transpiler::ErrorHandler errorHandler("Custom connectivity update" + std::to_string(getIndex()));
@@ -419,22 +415,18 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back
                     [&backend, &modelMerged, this](auto &env, auto generateBody)
                     {
                         EnvironmentGroupMergedField bodyEnv(env, *this);
-                        bodyEnv.getStream() << "for(int j = 0; j < " << bodyEnv["_row_length"] + "[" + bodyEnv["id_pre"] + "]; j++)";
+                        bodyEnv.getStream() << printSubs("for(int j = 0; j < $(_row_length)[$(id_pre)]; j++)", bodyEnv);
                         {
                             CodeStream::Scope b(bodyEnv.getStream());

                             // Add postsynaptic and synaptic indices
-                            bodyEnv.add(Type::Uint32.addConst(), "id_post", bodyEnv["_ind"] + "[" + bodyEnv["_row_start_idx"] + " + j]",
-                                        {}, {"_ind", "_row_start_idx"});
+                            bodyEnv.add(Type::Uint32.addConst(), "id_post", LazyString::print("$(_ind)[$(_row_start_idx) + j]", bodyEnv));
                             bodyEnv.add(Type::Uint32.addConst(), "id_syn", "idx",
-                                        {bodyEnv.addInitialiser("const unsigned int idx = " + bodyEnv["_row_start_idx"] + " + j;")},
-                                        {"_row_start_idx"});
+                                        {bodyEnv.addInitialiser("const unsigned int idx = $(_row_start_idx) + j;", bodyEnv)});

                             // Add postsynaptic and synaptic variables
-                            bodyEnv.addVars(backend.getDeviceVarPrefix(), bodyEnv["id_syn"], "",
-                                            {"id_syn"});
-                            bodyEnv.addVars(backend.getDeviceVarPrefix(), bodyEnv["id_post"], "",
-                                            {"id_post"});
+                            bodyEnv.addVars(backend.getDeviceVarPrefix(), LazyString{bodyEnv, "id_syn"}, "");
+                            bodyEnv.addVars(backend.getDeviceVarPrefix(), LazyString{bodyEnv, "id_post"}, "");

                             // Add postsynaptic and synaptic var references
                             // **TODO**
diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc
index 250b4c1b42..28a2495f9f 100644
--- a/src/genn/genn/code_generator/groupMerged.cc
+++ b/src/genn/genn/code_generator/groupMerged.cc
@@ -145,20 +145,20 @@ bool SynapseGroupMergedBase::isTrgNeuronDerivedParamHeterogeneous(const std::str
 std::string SynapseGroupMergedBase::getPreSlot(unsigned int batchSize) const
 {
     if(getArchetype().getSrcNeuronGroup()->isDelayRequired()) {
-        return (batchSize == 1) ? "preDelaySlot" : "preBatchDelaySlot";
+        return (batchSize == 1) ? "$(_pre_delay_slot)" : "$(_pre_batch_delay_slot)";
     }
     else {
-        return (batchSize == 1) ? "0" : "batch";
+        return (batchSize == 1) ? "0" : "$(batch)";
     }
 }
 //----------------------------------------------------------------------------
 std::string SynapseGroupMergedBase::getPostSlot(unsigned int batchSize) const
 {
     if(getArchetype().getTrgNeuronGroup()->isDelayRequired()) {
-        return (batchSize == 1) ? "postDelaySlot" : "postBatchDelaySlot";
+        return (batchSize == 1) ? "$(_post_delay_slot)" : "$(_post_batch_delay_slot)";
     }
     else {
-        return (batchSize == 1) ? "0" : "batch";
+        return (batchSize == 1) ? "0" : "$(batch)";
     }
 }
 //----------------------------------------------------------------------------
@@ -166,13 +166,13 @@ std::string SynapseGroupMergedBase::getPostDenDelayIndex(unsigned int batchSize,
 {
     assert(getArchetype().isDendriticDelayRequired());

-    const std::string batchID = ((batchSize == 1) ? "" : "postBatchOffset + ") + index;
+    const std::string batchID = ((batchSize == 1) ? "" : "$(_post_batch_offset) + ") + "$(" + index + ")";

     if(offset.empty()) {
-        return "(*group->denDelayPtr * group->numTrgNeurons) + " + batchID;
+        return "(*$(_den_delay_ptr) * $(num_post)) + " + batchID;
     }
     else {
-        return "(((*group->denDelayPtr + " + offset + ") % " + std::to_string(getArchetype().getMaxDendriticDelayTimesteps()) + ") * group->numTrgNeurons) + " + batchID;
+        return "(((*$(_den_delay_ptr) + " + offset + ") % " + std::to_string(getArchetype().getMaxDendriticDelayTimesteps()) + ") * $(num_post)) + " + batchID;
     }
 }
 //----------------------------------------------------------------------------
@@ -191,10 +191,10 @@ std::string SynapseGroupMergedBase::getPrePrevSpikeTimeIndex(bool delay, unsigne
     const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1);

     if(delay) {
-        return (singleBatch ? "prePrevSpikeTimeDelayOffset + " : "prePrevSpikeTimeBatchDelayOffset + ") + index;
+        return (singleBatch ? "$(_pre_prev_spike_time_delay_offset) + " : "$(_pre_prev_spike_time_batch_delay_offset) + ") + "$(" + index + ")";
     }
     else {
-        return (singleBatch ? "" : "preBatchOffset + ") + index;
+        return (singleBatch ? "" : "$(_pre_batch_offset) + ") + "$(" + index + ")";
     }
 }
 //--------------------------------------------------------------------------
@@ -203,23 +203,23 @@ std::string SynapseGroupMergedBase::getPostPrevSpikeTimeIndex(bool delay, unsign
     const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1);

     if(delay) {
-        return (singleBatch ? "postPrevSpikeTimeDelayOffset + " : "postPrevSpikeTimeBatchDelayOffset + ") + index;
+        return (singleBatch ? "$(_post_prev_spike_time_delay_offset) + " : "$(_post_prev_spike_time_batch_delay_offset) + ") + "$(" + index + ")";
     }
     else {
-        return (singleBatch ? "" : "postBatchOffset + ") + index;
+        return (singleBatch ? "" : "$(_post_batch_offset) + ") + "$(" + index + ")";
     }
 }
 //--------------------------------------------------------------------------
 std::string SynapseGroupMergedBase::getSynVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const
 {
     const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1);
-    return (singleBatch ? "" : "synBatchOffset + ") + index;
+    return (singleBatch ? "" : "$(_syn_batch_offset) + ") + "$(" + index + ")";
 }
 //--------------------------------------------------------------------------
 std::string SynapseGroupMergedBase::getKernelVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const
 {
     const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1);
-    return (singleBatch ? "" : "kernBatchOffset + ") + index;
+    return (singleBatch ? "" : "$(_kern_batch_offset) + ") + "$(" + index + ")";
 }
 //----------------------------------------------------------------------------
 boost::uuids::detail::sha1::digest_type SynapseGroupMergedBase::getHashDigest(Role role) const
diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc
index fac6caf81b..b957e9dcf8 100644
--- a/src/genn/genn/code_generator/initGroupMerged.cc
+++ b/src/genn/genn/code_generator/initGroupMerged.cc
@@ -28,14 +28,14 @@ void genVariableFill(EnvironmentExternalBase &env, const std::string &target, co
     // If there's only one, don't generate a loop
     if(numValues == 1) {
-        env.getStream() << env[target] << "[" << env[idx] << "] = " << value << ";" << std::endl;
+        env.getStream() << env[target] << "[" << env[idx] << "] = " << value << ";" << std::endl;
     }
     // Otherwise
     else {
         env.getStream() << "for(unsigned int d = 0; d < " << numValues << "; d++)";
         {
             CodeStream::Scope b(env.getStream());
-            env.getStream() << env[target] << "[(d * " << stride << ") + " << env[idx] << "] = " << value << ";" << std::endl;
+            env.getStream() << env[target] << "[(d * " << printSubs(stride, env) << ") + " << env[idx] << "] = " << value << ";" << std::endl;
         }
     }
 }
@@ -125,7 +125,7 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e
             prettyPrintStatements(snippet->getCode(), group.getTypeContext(), varInitEnv, errorHandler);

             // Fill value across all delay slots and batches
-            genVariableFill(varInitEnv, "value", "initVal", "id", count,
+            genVariableFill(varInitEnv, "value", "initVal", "id", "$(" + count + ")",
                             getVarAccessDuplication(var.access), batchSize, adaptor.isVarDelayed(var.name), numDelaySlots);
         });
 }
@@ -201,7 +201,7 @@ void NeuronInitGroupMerged::CurrentSource::generate(const BackendBase &backend,
 {
     genInitNeuronVarCode(
         backend, env, *this, ng, "CS" + std::to_string(getIndex()),
-        "num_neurons", 0, modelMerged.getModel().getBatchSize());
+        "$(num_neurons)", 0, modelMerged.getModel().getBatchSize());
 }
@@ -223,7 +223,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir
                        [&modelMerged] (EnvironmentExternalBase &varEnv)
                        {
                            genVariableFill(varEnv, "_out_post", modelMerged.scalarExpr(0.0),
-                                           "id", "num_neurons", VarAccessDuplication::DUPLICATE,
+                                           "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE,
                                            modelMerged.getModel().getBatchSize());
                        });
@@ -237,7 +237,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir
                            [&modelMerged, this](EnvironmentExternalBase &varEnv)
                            {
                                genVariableFill(varEnv, "_den_delay", modelMerged.scalarExpr(0.0),
-                                               "id", "num_neurons", VarAccessDuplication::DUPLICATE,
+                                               "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE,
                                                modelMerged.getModel().getBatchSize(), true, getArchetype().getMaxDendriticDelayTimesteps());
                            });
@@ -253,7 +253,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir
     }

     genInitNeuronVarCode(
-        backend, groupEnv, *this, ng, fieldSuffix, "num_neurons", 0, modelMerged.getModel().getBatchSize());
+        backend, groupEnv, *this, ng, 
fieldSuffix, "$(num_neurons)", 0, modelMerged.getModel().getBatchSize()); } //---------------------------------------------------------------------------- @@ -274,7 +274,7 @@ void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend [&modelMerged] (EnvironmentExternalBase &varEnv) { genVariableFill(varEnv, "_out_pre", modelMerged.scalarExpr(0.0), - "id", "num_neurons", VarAccessDuplication::DUPLICATE, + "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE, modelMerged.getModel().getBatchSize()); }); } @@ -486,7 +486,7 @@ void NeuronInitGroupMerged::genInitSpikes(const BackendBase &backend, Environmen (getArchetype().isTrueSpikeRequired() && getArchetype().isDelayRequired()); // Zero across all delay slots and batches - genVariableFill(varEnv, "_spk", "0", "id", "num_neurons", + genVariableFill(varEnv, "_spk", "0", "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE, batchSize, delayRequired, getArchetype().getNumDelaySlots()); }); } @@ -505,7 +505,7 @@ void NeuronInitGroupMerged::genInitSpikeTime(const BackendBase &backend, Environ backend.genVariableInit(env, "num_neurons", "id", [batchSize, varName, this] (EnvironmentExternalBase &varEnv) { - genVariableFill(varEnv, varName, "-TIME_MAX", "id", "num_neurons", VarAccessDuplication::DUPLICATE, + genVariableFill(varEnv, varName, "-TIME_MAX", "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE, batchSize, getArchetype().isDelayRequired(), getArchetype().getNumDelaySlots()); }); } @@ -543,7 +543,6 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, Environmen const bool kernel = (getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL); if (kernel && modelMerged.getModel().getBatchSize() > 1) { // Loop through kernel dimensions and multiply together to calculate batch stride - // **TODO** dependency for add std::ostringstream batchStrideInit; batchStrideInit << "const unsigned int batchStride = "; const auto &kernelSize = getArchetype().getKernelSize(); @@ -556,7 +555,7 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, Environmen } batchStrideInit << ";" << std::endl; groupEnv.add(Type::Uint32.addConst(), "_batch_stride", "batchStride", - {groupEnv.addInitialiser(batchStrideInit.str())}); + {groupEnv.addInitialiser(batchStrideInit.str(), groupEnv)}); } @@ -568,7 +567,7 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, Environmen } // Generate initialisation code - const std::string stride = kernel ? groupEnv["_batch_stride"] : groupEnv["num_pre"] + " * " + groupEnv["_row_stride"]; + const std::string stride = kernel ? 
"$(_batch_stride)" : "$(num_pre) * $(_row_stride)"; genInitWUVarCode(backend, groupEnv, *this, stride, modelMerged.getModel().getBatchSize(), [&backend, kernel, this](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { @@ -612,7 +611,7 @@ boost::uuids::detail::sha1::digest_type SynapseSparseInitGroupMerged::getHashDig //---------------------------------------------------------------------------- void SynapseSparseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { - genInitWUVarCode(backend, env, *this, env["num_pre"] + " * " + env["_row_stride"], modelMerged.getModel().getBatchSize(), + genInitWUVarCode(backend, env, *this, "$(num_pre) * $(_row_stride)", modelMerged.getModel().getBatchSize(), [&backend](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { backend.genSparseSynapseVariableRowInit(varInitEnv, handler); @@ -642,10 +641,10 @@ void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase &b // Add substitution // **TODO** dependencies on kernel fields groupEnv.add(Type::Uint32, "id_kernel", "kernelInd", - {groupEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(*this, groupEnv) + ";")}); + {groupEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(*this) + ";", groupEnv)}); // Initialise single (hence empty lambda function) synapse variable - genInitWUVarCode(backend, groupEnv, *this, groupEnv["num_pre"] + " * " + groupEnv["_row_stride"], modelMerged.getModel().getBatchSize(), + genInitWUVarCode(backend, groupEnv, *this, "$(num_pre) * $(_row_stride)", modelMerged.getModel().getBatchSize(), [](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { handler(varInitEnv); @@ -916,7 +915,7 @@ void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, Env } batchStrideInit << ";" << std::endl; groupEnv.add(Type::Uint32.addConst(), "_batch_stride", "batchStride", - {groupEnv.addInitialiser(batchStrideInit.str())}); + {groupEnv.addInitialiser(batchStrideInit.str(), groupEnv)}); } } else { @@ -936,7 +935,7 @@ void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, Env } // Loop through rows - const std::string stride = kernel ? groupEnv["_batch_stride"] : groupEnv["num_pre"] + " * " + groupEnv["_row_stride"]; + const std::string stride = kernel ? "$(_batch_stride)" : "$(num_pre) * $(_row_stride)"; genInitWUVarCode( backend, groupEnv, *this, stride, getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1, [&backend, kernel, this](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) @@ -1015,7 +1014,7 @@ void CustomWUUpdateSparseInitGroupMerged::generateInit(const BackendBase &backen return backend.getDeviceVarPrefix() + "ind" + sg->getName(); });*/ - genInitWUVarCode(backend, groupEnv, *this, groupEnv["num_pre"] + " * " + groupEnv["_row_stride"], + genInitWUVarCode(backend, groupEnv, *this, "$(num_pre) * $(_row_stride)", getArchetype().isBatched() ? 
modelMerged.getModel().getBatchSize() : 1,
                     [&backend](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler)
                     {
@@ -1164,7 +1163,7 @@ void CustomConnectivityUpdateSparseInitGroupMerged::generateInit(const BackendBa
     // Initialise custom connectivity update variables
     genInitWUVarCode(
-        backend, groupEnv, *this, groupEnv["num_pre"] + " * " + groupEnv["_row_stride"], 1,
+        backend, groupEnv, *this, "$(num_pre) * $(_row_stride)", 1,
         [&backend](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler)
         {
             return backend.genSparseSynapseVariableRowInit(varInitEnv, handler);
diff --git a/src/genn/genn/code_generator/lazyString.cc b/src/genn/genn/code_generator/lazyString.cc
index 01e379974e..b98849ec0b 100644
--- a/src/genn/genn/code_generator/lazyString.cc
+++ b/src/genn/genn/code_generator/lazyString.cc
@@ -26,15 +26,28 @@ std::string LazyString::str() const
                    {
                        stream << str;
                    },
-                   [&stream](const std::pair<std::reference_wrapper<EnvironmentExternalBase>, std::string> &env)
+                   [&stream](const EnvRef &envRef)
                    {
-                       stream << env.first.get()[env.second];
+                       stream << envRef.env.get()[envRef.name];
                    }},
                   e);
     }
     return stream.str();
 }
 //----------------------------------------------------------------------------
+LazyString& LazyString::operator += (const LazyString &rhs)
+{
+    // Add RHS's payload to ours
+    m_Payload.insert(m_Payload.end(), rhs.m_Payload.cbegin(), rhs.m_Payload.cend());
+
+    return *this;
+}
+//----------------------------------------------------------------------------
+LazyString& LazyString::operator += (const std::string &rhs)
+{
+    return operator += (LazyString{rhs});
+}
+//----------------------------------------------------------------------------
 LazyString LazyString::print(const std::string &format, EnvironmentExternalBase &env)
 {
     // Create regex iterator to iterate over $(XXX) style variables in format string
@@ -55,7 +68,7 @@ LazyString LazyString::print(const std::string &format, EnvironmentExternalBase
             payload.push_back(std::string{m->prefix().first, m->prefix().second});

             // Add lazy environment reference for $(XXX) variable to payload
-            payload.push_back(std::make_pair(std::ref(env), (*m)[1]));
+            payload.push_back(EnvRef{std::ref(env), (*m)[1]});

             // If there are no subsequent matches, add the remaining non-matched
             // characters onto payload, construct lazy string and return
diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
index 4a001b323b..6ffabe3733 100644
--- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
@@ -35,15 +35,15 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend
     csEnv.addExtraGlobalParams(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", fieldSuffix);

     // Define inject current function
-    csEnv.add(Type::ResolvedType::createFunction(Type::Void, {modelMerged.getModel().getPrecision()}), "injectCurrent", csEnv["Isyn"] + " += $(0)",
-              {}, {"Isyn"});
+    csEnv.add(Type::ResolvedType::createFunction(Type::Void, {modelMerged.getModel().getPrecision()}), "injectCurrent",
+              LazyString::print("$(Isyn) += $(0)", csEnv));

     // Create an environment which caches variables in local variables if they are accessed
     EnvironmentLocalVarCache varEnv(
         *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix,
         [&csEnv, &modelMerged, &ng](const std::string&, VarAccessDuplication d)
         {
-            return ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, csEnv["id"]);
+            return 
LazyString::print(ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, "id"), csEnv); }); // Pretty print code back to environment @@ -89,7 +89,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Read into local variable psmEnv.getStream() << "// postsynaptic model " << getIndex() << std::endl; psmEnv.getStream() << "scalar linSyn = " << psmEnv["_out_post"] << "["; - psmEnv.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, psmEnv["id"]); + psmEnv.getStream() << printSubs(ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"), psmEnv); psmEnv.getStream() << "];" << std::endl; // If dendritic delay is required @@ -103,7 +103,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Get reference to dendritic delay buffer input for this timestep psmEnv.getStream() << backend.getPointerPrefix() << "scalar *denDelayFront = "; psmEnv.getStream() << "&" << psmEnv["_den_delay"] << "[(*" << psmEnv["_den_delay_ptr"] << " * " << psmEnv["num_neurons"] << ") + "; - psmEnv.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, psmEnv["id"]); + psmEnv.getStream() << printSubs(ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"), psmEnv); psmEnv.getStream() << "];" << std::endl; // Add delayed input from buffer into inSyn @@ -129,7 +129,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, [&psmEnv, &modelMerged, &ng](const std::string&, VarAccessDuplication d) { - return ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, psmEnv["id"]); + return LazyString::print(ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, "id"), psmEnv); }); // Pretty print code back to environment @@ -141,7 +141,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Write back linSyn varEnv.getStream() << psmEnv["_out_post"] << "["; - varEnv.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, psmEnv["id"]); + varEnv.getStream() << printSubs(ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"), psmEnv); varEnv.getStream() << "] = linSyn;" << std::endl; } //---------------------------------------------------------------------------- @@ -180,12 +180,12 @@ void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backe // Add reverse insyn variable to outSynEnv.getStream() << getArchetype().getPreTargetVar() << " += "; outSynEnv.getStream() << outSynEnv["_out_pre"] << "["; - outSynEnv.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, env["id"]); + outSynEnv.getStream() << printSubs(ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"), env); outSynEnv.getStream() << "];" << std::endl; // Zero it again outSynEnv.getStream() << outSynEnv["_out_pre"] << "["; - outSynEnv.getStream() << ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, env["id"]); + outSynEnv.getStream() << printSubs(ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"), env); outSynEnv.getStream() << "] = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; } @@ -219,18 +219,18 @@ void 
NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back
         *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix,
         [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d)
         {
-            return ng.getReadVarIndex(delayed, batchSize, d, synEnv["id"]);
+            return LazyString::print(ng.getReadVarIndex(delayed, batchSize, d, "id"), synEnv);
         },
         [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d)
         {
-            return ng.getWriteVarIndex(delayed, batchSize, d, synEnv["id"]);
+            return LazyString::print(ng.getWriteVarIndex(delayed, batchSize, d, "id"), synEnv);
         });

     /*neuronSubstitutionsInSynapticCode(varEnv, &ng.getArchetype(), "", "_post", "", "", "", dynamicsNotSpike,
                                       [&ng](const std::string &p) { return ng.isParamHeterogeneous(p); },
                                       [&ng](const std::string &p) { return ng.isDerivedParamHeterogeneous(p); },
                                       [&subs, &ng, batchSize](bool delay, VarAccessDuplication varDuplication)
                                       {
                                           return ng.getReadVarIndex(delay, batchSize, varDuplication, subs["id"]);
                                       },
                                       [&subs, &ng, batchSize](bool delay, VarAccessDuplication varDuplication)
@@ -253,11 +253,11 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentEx
     for(const auto &v : getArchetype().getWUModel()->getPostVars()) {
         if(v.access & VarAccessMode::READ_WRITE) {
             env.getStream() << env[v.name] << "[";
-            env.getStream() << ng.getWriteVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), env["id"]);
+            env.getStream() << printSubs(ng.getWriteVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), "id"), env);
             env.getStream() << "] = ";
             env.getStream() << env[v.name] << "[";
-            env.getStream() << ng.getReadVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), env["id"]);
+            env.getStream() << printSubs(ng.getReadVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), "id"), env);
             env.getStream() << "];" << std::endl;
         }
     }
@@ -311,11 +311,11 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back
         *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix,
         [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d)
         {
-            return ng.getReadVarIndex(delayed, batchSize, d, synEnv["id"]);
+            return LazyString::print(ng.getReadVarIndex(delayed, batchSize, d, "id"), synEnv);
         },
         [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d)
         {
-            return ng.getWriteVarIndex(delayed, batchSize, d, synEnv["id"]);
+            return LazyString::print(ng.getWriteVarIndex(delayed, batchSize, d, "id"), synEnv);
         });

     /*neuronSubstitutionsInSynapticCode(subs, &ng.getArchetype(), "", "_pre", "", "", "", dynamicsNotSpike,
@@ -345,11 +345,11 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(EnvironmentEx
     for(const auto &v : getArchetype().getWUModel()->getPreVars()) {
         if(v.access & VarAccessMode::READ_WRITE) {
             env.getStream() << env[v.name] << "[";
-            env.getStream() << ng.getWriteVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), env["id"]);
+            env.getStream() << printSubs(ng.getWriteVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), "id"), env);
             env.getStream() << "] = ";
             env.getStream() << env[v.name] << "[";
-            env.getStream() << ng.getReadVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), env["id"]);
+            env.getStream() << 
printSubs(ng.getReadVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), "id"), env); env.getStream() << "];" << std::endl; } } @@ -489,8 +489,6 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E BackendBase::GroupHandlerEnv genEmitTrueSpike, BackendBase::GroupHandlerEnv genEmitSpikeLikeEvent) { - using LS = LazyString; - const ModelSpecInternal &model = modelMerged.getModel(); const unsigned int batchSize = model.getBatchSize(); const NeuronModels::Base *nm = getArchetype().getNeuronModel(); @@ -530,15 +528,15 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E neuronEnv.addExtraGlobalParams(nm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); // Substitute spike times - const std::string spikeTimeReadIndex = getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, neuronEnv["id"]); + const std::string spikeTimeReadIndex = getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "id"); neuronEnv.add(getTimeType().addConst(), "sT", "lsT", - {neuronEnv.addInitialiser("const timepoint lsT = " + LS(neuronEnv, "_spk_time") + "[" + spikeTimeReadIndex + "];")}); + {neuronEnv.addInitialiser("const timepoint lsT = $(_spk_time)[" + spikeTimeReadIndex + "];", neuronEnv)}); neuronEnv.add(getTimeType().addConst(), "prev_sT", "lprevST", - {neuronEnv.addInitialiser("const timepoint lprevST = " + LS(neuronEnv, "_prev_spk_time") + "[" + spikeTimeReadIndex + "];")}); + {neuronEnv.addInitialiser("const timepoint lprevST = $(_prev_spk_time)[" + spikeTimeReadIndex + "];", neuronEnv)}); neuronEnv.add(getTimeType().addConst(), "seT", "lseT", - {neuronEnv.addInitialiser("const timepoint lseT = " + LS(neuronEnv, "_spk_evnt_time") + "[" + spikeTimeReadIndex+ "];")}); + {neuronEnv.addInitialiser("const timepoint lseT = $(_spk_evnt_time)[" + spikeTimeReadIndex+ "];", neuronEnv)}); neuronEnv.add(getTimeType().addConst(), "prev_seT", "lprevSET", - {neuronEnv.addInitialiser("const timepoint lprevSET = " + LS(neuronEnv, "_prev_spk_evnt_time") + "[" + spikeTimeReadIndex + "];")}); + {neuronEnv.addInitialiser("const timepoint lprevSET = $(_prev_spk_evnt_time)[" + spikeTimeReadIndex + "];", neuronEnv)}); // Create an environment which caches variables in local variables if they are accessed // **NOTE** we do this right at the top so that local copies can be used by child groups @@ -791,13 +789,13 @@ std::string NeuronUpdateGroupMerged::getVarIndex(unsigned int batchSize, VarAcce { // **YUCK** there's a lot of duplication in these methods - do they belong elsewhere? if (varDuplication == VarAccessDuplication::SHARED_NEURON) { - return (batchSize == 1) ? "0" : "batch"; + return (batchSize == 1) ? "0" : "$(batch)"; } else if(varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { - return index; + return "$(" + index + ")"; } else { - return "batchOffset + " + index; + return "$(_batch_offset) + $(" + index + ")"; } } //-------------------------------------------------------------------------- @@ -805,13 +803,13 @@ std::string NeuronUpdateGroupMerged::getReadVarIndex(bool delay, unsigned int ba { if(delay) { if (varDuplication == VarAccessDuplication::SHARED_NEURON) { - return (batchSize == 1) ? "readDelaySlot" : "readBatchDelaySlot"; + return (batchSize == 1) ? 
"$(_read_delay_slot)" : "$(_read_batch_delay_slot)"; } else if (varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { - return "readDelayOffset + " + index; + return "$(_read_delay_offset) + $(" + index + ")"; } else { - return "readBatchDelayOffset + " + index; + return "$(_read_batch_delay_offset) + $(" + index + ")"; } } else { @@ -823,13 +821,13 @@ std::string NeuronUpdateGroupMerged::getWriteVarIndex(bool delay, unsigned int b { if(delay) { if (varDuplication == VarAccessDuplication::SHARED_NEURON) { - return (batchSize == 1) ? "writeDelaySlot" : "writeBatchDelaySlot"; + return (batchSize == 1) ? "$(_write_delay_slot)" : "$(_write_batch_delay_slot)"; } else if (varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { - return "writeDelayOffset + " + index; + return "$(_write_delay_offset) + $(" + index + ")"; } else { - return "writeBatchDelayOffset + " + index; + return "$(_write_batch_delay_offset) + $(" + index + ")"; } } else { diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index efed795f20..14098cabf1 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -33,23 +33,20 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa synEnv.addVars(backend.getDeviceVarPrefix(), [&sg, &synEnv, batchSize](VarAccess a, const std::string&) { - return "[" + sg.getPreWUVarIndex(batchSize, getVarAccessDuplication(a), synEnv["id_pre"]) + "]"; - }, - {"id_pre"}); + return LazyString::print(sg.getPreWUVarIndex(batchSize, getVarAccessDuplication(a), "id_pre"), synEnv); + }); synEnv.addVars(backend.getDeviceVarPrefix(), [&sg, &synEnv, batchSize](VarAccess a, const std::string&) { - return "[" + sg.getPostWUVarIndex(batchSize, getVarAccessDuplication(a), synEnv["id_post"]) + "]"; - }, - {"id_post"}); + return LazyString::print(sg.getPostWUVarIndex(batchSize, getVarAccessDuplication(a), "id_post"), synEnv); + }); // If this synapse group has a kernel if (!sg.getArchetype().getKernelSize().empty()) { // Add substitution - // **TODO** dependencies on kernel fields synEnv.add(Type::Uint32, "id_kernel", "kernelInd", - {synEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(sg, synEnv) + ";")}); + {synEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(sg) + ";", synEnv)}); } // If weights are individual, substitute variables for values stored in global memory @@ -57,9 +54,8 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa synEnv.addVars(backend.getDeviceVarPrefix(), [&sg, &synEnv, batchSize](VarAccess a, const std::string&) { - return "[" + sg.getSynVarIndex(batchSize, getVarAccessDuplication(a), synEnv["id_syn"]) + "]"; - }, - {"id_syn"}); + return LazyString::print(sg.getSynVarIndex(batchSize, getVarAccessDuplication(a), "id_syn"), synEnv); + }); } // Otherwise, if weights are procedual else if (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::PROCEDURAL) { @@ -108,9 +104,8 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa synEnv.addVars(backend.getDeviceVarPrefix(), [&sg, &synEnv, batchSize](VarAccess a, const std::string&) { - return "[" + sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), synEnv["id_kernel"]) + "]"; - }, - {"id_kernel"}); + return LazyString::print(sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), "id_kernel"), synEnv); + }); } // Otherwise, 

From 6216b393ea10cab1b309dbec1a02bf8b4356d7ef Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 27 Jun 2023 09:47:12 +0100
Subject: [PATCH 256/725] removed unnecessary environment parameter from
 addInitialiser - now all initialisers are scanned for $(XX) to lazily
 evaluate

---
 .../genn/genn/code_generator/backendBase.h    | 53 +++++++-------
 .../genn/genn/code_generator/environment.h    | 10 +--
 .../backends/single_threaded_cpu/backend.cc   | 72 +++++++++----------
 src/genn/genn/code_generator/backendBase.cc   | 10 +--
 .../genn/code_generator/initGroupMerged.cc    |  6 +-
 .../code_generator/neuronUpdateGroupMerged.cc |  8 +--
 .../synapseUpdateGroupMerged.cc               |  2 +-
 7 files changed, 76 insertions(+), 85 deletions(-)

diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h
index 23dfd6665c..7a93a12612 100644
--- a/include/genn/genn/code_generator/backendBase.h
+++ b/include/genn/genn/code_generator/backendBase.h
@@ -492,7 +492,6 @@ class GENN_EXPORT BackendBase
    template
    void genNeuronIndexCalculation(EnvironmentGroupMergedField &env, unsigned int batchSize) const
    {
-        using LS = LazyString;
        env.addField(Type::Uint32.addConst(), "num_neurons", Type::Uint32, "numNeurons",
                     [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); });
@@ -520,7 +519,7 @@ class GENN_EXPORT BackendBase
        // If batching is enabled, calculate batch offset
        if(batchSize > 1) {
            env.add(Type::Uint32.addConst(), "_batch_offset", "batchOffset",
-                   {env.addInitialiser("const unsigned int batchOffset = $(num_neurons) * batch;", env)});
+                   {env.addInitialiser("const unsigned int batchOffset = $(num_neurons) * $(batch);")});
        }

        // If axonal delays are required
@@ -529,33 +528,33 @@ class GENN_EXPORT BackendBase
            const unsigned int numDelaySlots = env.getGroup().getArchetype().getNumDelaySlots();
            const std::string numDelaySlotsStr = std::to_string(numDelaySlots);
            env.add(Type::Uint32.addConst(), "_read_delay_slot", "readDelaySlot",
-                   {env.addInitialiser("const unsigned int readDelaySlot = (*$(_spk_que_ptr) + " + std::to_string(numDelaySlots - 1) + ") % " + numDelaySlotsStr + ";", env)});
+                   {env.addInitialiser("const unsigned int readDelaySlot = (*$(_spk_que_ptr) + " + std::to_string(numDelaySlots - 1) + ") % " + numDelaySlotsStr + ";")});
            env.add(Type::Uint32.addConst(), "_read_delay_offset", "readDelayOffset",
-                   {env.addInitialiser("const unsigned int readDelayOffset = $(_read_delay_slot) * $(num_neurons);", env)});
+                   {env.addInitialiser("const unsigned int readDelayOffset = $(_read_delay_slot) * $(num_neurons);")});

            // And we should WRITE to delay slot pointed to by spkQuePtr
            env.add(Type::Uint32.addConst(), "_write_delay_slot", "writeDelaySlot",
-                   {env.addInitialiser("const unsigned int writeDelaySlot = *$(_spk_que_ptr);", env)});
+                   {env.addInitialiser("const unsigned int writeDelaySlot = *$(_spk_que_ptr);")});
            env.add(Type::Uint32.addConst(), "_write_delay_offset", "writeDelayOffset",
-                   {env.addInitialiser("const unsigned int writeDelayOffset = $(_write_delay_slot) * $(num_neurons);", env)});
+                   {env.addInitialiser("const unsigned int writeDelayOffset = $(_write_delay_slot) * $(num_neurons);")});

            // If batching is also enabled
            if(batchSize > 1) {
                // Calculate batched delay slots
                env.add(Type::Uint32.addConst(), "_read_batch_delay_slot", "readBatchDelaySlot",
-                       {env.addInitialiser("const unsigned int readBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_read_delay_slot);", env)});
+                       {env.addInitialiser("const unsigned int readBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_read_delay_slot);")});
                env.add(Type::Uint32.addConst(), "_write_batch_delay_slot", "writeBatchDelaySlot",
-                       {env.addInitialiser("const unsigned int writeBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_write_delay_slot);", env)});
+                       {env.addInitialiser("const unsigned int writeBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_write_delay_slot);")});

                // Calculate current batch offset
                env.add(Type::Uint32.addConst(), "_batch_delay_offset", "batchDelayOffset",
-                       {env.addInitialiser("const unsigned int batchDelayOffset = $(_batch_offset) * " + numDelaySlotsStr + ";", env)});
+                       {env.addInitialiser("const unsigned int batchDelayOffset = $(_batch_offset) * " + numDelaySlotsStr + ";")});

                // Calculate further offsets to include delay and batch
                env.add(Type::Uint32.addConst(), "_read_batch_delay_offset", "readBatchDelayOffset",
-                       {env.addInitialiser("const unsigned int readBatchDelayOffset = $(_read_delay_offset) + $(_batch_delay_offset);", env)});
+                       {env.addInitialiser("const unsigned int readBatchDelayOffset = $(_read_delay_offset) + $(_batch_delay_offset);")});
                env.add(Type::Uint32.addConst(), "_write_batch_delay_offset", "writeBatchDelayOffset",
-                       {env.addInitialiser("const unsigned int writeBatchDelayOffset = $(_write_delay_offset) + $(_batch_delay_offset);", env)});
+                       {env.addInitialiser("const unsigned int writeBatchDelayOffset = $(_write_delay_offset) + $(_batch_delay_offset);")});
            }
        }
    }
@@ -611,9 +610,9 @@ class GENN_EXPORT BackendBase
        if(batchSize > 1) {
            // Calculate batch offsets into pre and postsynaptic populations
            env.add(Type::Uint32.addConst(), "_pre_batch_offset", "preBatchOffset",
-                   {env.addInitialiser("const unsigned int preBatchOffset = $(num_pre) * $(batch);", env)});
+                   {env.addInitialiser("const unsigned int preBatchOffset = $(num_pre) * $(batch);")});
            env.add(Type::Uint32.addConst(), "_post_batch_offset", "postBatchOffset",
-                   {env.addInitialiser("const unsigned int preBatchOffset = $(num_post) * $(batch);", env)});
+                   {env.addInitialiser("const unsigned int postBatchOffset = $(num_post) * $(batch);")});

            // Calculate batch offsets into synapse arrays, using 64-bit arithmetic if necessary
            if(areSixtyFourBitSynapseIndicesRequired(env.getGroup())) {
@@ -622,7 +621,7 @@ class GENN_EXPORT BackendBase
            }
            else {
                env.add(Type::Uint32.addConst(), "_syn_batch_offset", "synBatchOffset",
-                       {env.addInitialiser("const unsigned int synBatchOffset = $(_pre_batch_offset) * $(_row_stride);", env)});
+                       {env.addInitialiser("const unsigned int synBatchOffset = $(_pre_batch_offset) * $(_row_stride);")});
            }

            // If synapse group has kernel
@@ -640,7 +639,7 @@ class GENN_EXPORT BackendBase
                kernBatchOffsetInit << "$(batch);" << std::endl;

                env.add(Type::Uint32.addConst(), "_kern_batch_offset", "kernBatchOffset",
-                       {env.addInitialiser(kernBatchOffsetInit.str(), env)});
+                       {env.addInitialiser(kernBatchOffsetInit.str())});
            }
        }
@@ -658,16 +657,16 @@ class GENN_EXPORT BackendBase
                preDelaySlotInit << "(*$(_src_spk_que_ptr) + " << (numSrcDelaySlots - numDelaySteps) << ") % " << numSrcDelaySlots << ";" << std::endl;
            }
            env.add(Type::Uint32, "_pre_delay_slot", "preDelaySlot",
-                   {env.addInitialiser(preDelaySlotInit.str(), env)});
+                   {env.addInitialiser(preDelaySlotInit.str())});

            env.add(Type::Uint32, "_pre_delay_offset", "preDelayOffset",
-                   {env.addInitialiser("const unsigned int preDelayOffset = $(_pre_delay_slot) * $(num_pre);", env)});
+                   {env.addInitialiser("const unsigned int preDelayOffset = $(_pre_delay_slot) * $(num_pre);")});
if(batchSize > 1) { env.add(Type::Uint32, "_pre_batch_delay_slot", "preBatchDelaySlot", - {env.addInitialiser("const unsigned int preBatchDelaySlot = $(_pre_delay_slot) + ($(batch) * " + std::to_string(numSrcDelaySlots) + ");", env)}); + {env.addInitialiser("const unsigned int preBatchDelaySlot = $(_pre_delay_slot) + ($(batch) * " + std::to_string(numSrcDelaySlots) + ");")}); env.add(Type::Uint32, "_pre_batch_delay_offset", "preBatchDelayOffset", - {env.addInitialiser("const unsigned int preBatchDelayOffset = $(_pre_delay_offset) + ($(_pre_batch_offset) * " + std::to_string(numSrcDelaySlots) + ");", env)}); + {env.addInitialiser("const unsigned int preBatchDelayOffset = $(_pre_delay_offset) + ($(_pre_batch_offset) * " + std::to_string(numSrcDelaySlots) + ");")}); } if(env.getGroup().getArchetype().getWUModel()->isPrevPreSpikeTimeRequired() @@ -675,11 +674,11 @@ class GENN_EXPORT BackendBase { env.add(Type::Uint32, "_pre_prev_spike_time_delay_offset", "prePrevSpikeTimeDelayOffset", {env.addInitialiser("const unsigned int prePrevSpikeTimeDelayOffset = ((*$(_src_spk_que_ptr) + " - + std::to_string(numSrcDelaySlots - numDelaySteps - 1) + ") % " + std::to_string(numSrcDelaySlots) + ") * $(num_pre);", env)}); + + std::to_string(numSrcDelaySlots - numDelaySteps - 1) + ") % " + std::to_string(numSrcDelaySlots) + ") * $(num_pre);")}); if(batchSize > 1) { env.add(Type::Uint32, "_pre_prev_spike_time_batch_delay_offset", "prePrevSpikeTimeBatchDelayOffset", - {env.addInitialiser("const unsigned int prePrevSpikeTimeBatchDelayOffset = $(_pre_prev_spike_time_delay_offset) + ($(_pre_batch_offset) * " + std::to_string(numSrcDelaySlots) + ");", env)}); + {env.addInitialiser("const unsigned int prePrevSpikeTimeBatchDelayOffset = $(_pre_prev_spike_time_delay_offset) + ($(_pre_batch_offset) * " + std::to_string(numSrcDelaySlots) + ");")}); } } } @@ -698,26 +697,26 @@ class GENN_EXPORT BackendBase postDelaySlotInit << "(*$(_trg_spk_que_ptr) + " << (numTrgDelaySlots - numBackPropDelaySteps) << ") % " << numTrgDelaySlots << ";" << std::endl; } env.add(Type::Uint32, "_post_delay_slot", "postDelaySlot", - {env.addInitialiser(postDelaySlotInit.str(), env)}); + {env.addInitialiser(postDelaySlotInit.str())}); env.add(Type::Uint32, "_post_delay_offset", "postDelayOffset", - {env.addInitialiser("const unsigned int postDelayOffset = $(_post_delay_slot) * $(num_post);", env)}); + {env.addInitialiser("const unsigned int postDelayOffset = $(_post_delay_slot) * $(num_post);")}); if(batchSize > 1) { env.add(Type::Uint32, "_post_batch_delay_slot", "postBatchDelaySlot", - {env.addInitialiser("const unsigned int postBatchDelaySlot =$(_post_delay_slot) + (batch * " + std::to_string(numTrgDelaySlots) + ");", env)}); + {env.addInitialiser("const unsigned int postBatchDelaySlot =$(_post_delay_slot) + (batch * " + std::to_string(numTrgDelaySlots) + ");")}); env.add(Type::Uint32, "_post_batch_delay_offset", "postBatchDelayOffset", - {env.addInitialiser("const unsigned int postBatchDelayOffset = $(_post_delay_offset) + ($(_post_batch_offset) * " + std::to_string(numTrgDelaySlots) + ");", env)}); + {env.addInitialiser("const unsigned int postBatchDelayOffset = $(_post_delay_offset) + ($(_post_batch_offset) * " + std::to_string(numTrgDelaySlots) + ");")}); } if(env.getGroup().getArchetype().getWUModel()->isPrevPostSpikeTimeRequired()) { env.add(Type::Uint32, "_post_prev_spike_time_delay_offset", "postPrevSpikeTimeDelayOffset", {env.addInitialiser("const unsigned int postPrevSpikeTimeDelayOffset = ((*$(_trg_spk_que_ptr) + " - + 
std::to_string(numTrgDelaySlots - numBackPropDelaySteps - 1) + ") % " + std::to_string(numTrgDelaySlots) + ") * $(num_post);", env)}); + + std::to_string(numTrgDelaySlots - numBackPropDelaySteps - 1) + ") % " + std::to_string(numTrgDelaySlots) + ") * $(num_post);")}); if(batchSize > 1) { env.add(Type::Uint32, "_post_prev_spike_time_batch_delay_offset", "postPrevSpikeTimeBatchDelayOffset", - {env.addInitialiser("const unsigned int postPrevSpikeTimeBatchDelayOffset = $(_post_prev_spike_time_delay_offset) + ($(_post_batch_offset) * " + std::to_string(numTrgDelaySlots) + ");", env)}); + {env.addInitialiser("const unsigned int postPrevSpikeTimeBatchDelayOffset = $(_post_prev_spike_time_delay_offset) + ($(_post_batch_offset) * " + std::to_string(numTrgDelaySlots) + ");")}); } } diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index e4f88b3556..056594c3eb 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -297,18 +297,12 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P } } - - size_t addInitialiser(const LazyString &initialiser) + size_t addInitialiser(const std::string &format) { - m_Initialisers.emplace_back(false, initialiser); + m_Initialisers.emplace_back(false, LazyString::print(format, *this)); return (m_Initialisers.size() - 1); } - size_t addInitialiser(const std::string &format, EnvironmentExternalBase &env) - { - return addInitialiser(LazyString::print(format, env)); - } - protected: //------------------------------------------------------------------------ // Protected API diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 9f18ac31a4..02a4835898 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -95,7 +95,7 @@ void genKernelIteration(EnvironmentExternalBase &env, const G &g, size_t numKern // Generate kernel index and use as "synapse" index // **TODO** rename loopEnv.add(Type::Uint32.addConst(), "id_syn", "kernelInd", - {loopEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(g) + ";", loopEnv)}); + {loopEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(g) + ";")}); // Call handler handler(loopEnv); @@ -368,8 +368,8 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos if (s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // Add initialiser strings to calculate synaptic and presynaptic index - const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * $(_row_stride)) + s;", synEnv); - const size_t idPostInit = synEnv.addInitialiser("const unsigned int idPost = $(_ind)[$(id_syn)];", synEnv); + const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * $(_row_stride)) + s;"); + const size_t idPostInit = synEnv.addInitialiser("const unsigned int idPost = $(_ind)[$(id_syn)];"); // **TODO** id_syn can be 64-bit synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}); @@ -382,7 +382,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // Add initialiser to calculate synaptic index // **TODO** id_syn can be 64-bit synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", - {synEnv.addInitialiser("const unsigned int idSyn = (i * $(num_post)) + j;", synEnv)}); + {synEnv.addInitialiser("const unsigned int idSyn = (i * $(num_post)) + 
j;")}); } // Add correct functions for apply synaptic input @@ -486,9 +486,9 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // Add initialisers to calculate column and row-major indices // **TODO** fast divide optimisations - const size_t colMajorIdxInit = synEnv.addInitialiser("const unsigned int colMajorIndex = (spike * $(_col_stride)) + i;", synEnv); - const size_t rowMajorIdxInit = synEnv.addInitialiser("const unsigned int rowMajorIndex = $(_remap)[colMajorIndex];", synEnv); - const size_t idPreInit = synEnv.addInitialiser("const unsigned int idPre = rowMajorIndex / $(_row_stride);", synEnv); + const size_t colMajorIdxInit = synEnv.addInitialiser("const unsigned int colMajorIndex = (spike * $(_col_stride)) + i;"); + const size_t rowMajorIdxInit = synEnv.addInitialiser("const unsigned int rowMajorIndex = $(_remap)[colMajorIndex];"); + const size_t idPreInit = synEnv.addInitialiser("const unsigned int idPre = rowMajorIndex / $(_row_stride);"); // Add presynaptic and synapse index to environment synEnv.add(Type::Uint32.addConst(), "id_pre", "idPre", {colMajorIdxInit, rowMajorIdxInit, idPreInit}); @@ -498,7 +498,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // Add presynaptic and synapse index to environment synEnv.add(Type::Uint32.addConst(), "id_pre", "i"); synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", - {synEnv.addInitialiser("const unsigned int idSyn = (i * $(num_post)) + spike;", synEnv)}); + {synEnv.addInitialiser("const unsigned int idSyn = (i * $(num_post)) + spike;")}); } synEnv.add(Type::Uint32.addConst(), "id_post", "spike"); @@ -696,8 +696,8 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // If connectivity is sparse if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // Add initialisers to calculate synaptic index and thus lookup postsynaptic index - const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * $(_row_stride)) + s;", synEnv); - const size_t jInit = synEnv.addInitialiser("const unsigned int j = $(_ind)[idSyn];", synEnv); + const size_t idSynInit = synEnv.addInitialiser("const unsigned int idSyn = (i * $(_row_stride)) + s;"); + const size_t jInit = synEnv.addInitialiser("const unsigned int j = $(_ind)[idSyn];"); // Add substitutions synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", {idSynInit}); @@ -707,7 +707,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host synEnv.add(Type::Uint32.addConst(), "id_post", "j"); synEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", - {synEnv.addInitialiser("const unsigned int idSyn = (i * $(num_post)) + j;", synEnv)}); + {synEnv.addInitialiser("const unsigned int idSyn = (i * $(num_post)) + j;")}); } // Generate custom update @@ -797,7 +797,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Add conditional initialisation code to calculate synapse index groupEnv.add(Type::Uint32.addConst(), "id_syn", "idSyn", - {groupEnv.addInitialiser("const unsigned int idSyn = (i * $(num_post)) + j;", groupEnv)}); + {groupEnv.addInitialiser("const unsigned int idSyn = (i * $(num_post)) + j;")}); // Generate custom update c.generateCustomUpdate(*this, groupEnv); @@ -1494,9 +1494,9 @@ void Backend::genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, Hand EnvironmentExternal varEnv(env); // **TODO** 64-bit varEnv.add(Type::Uint32, 
"id_syn", "idSyn", - {varEnv.addInitialiser("const unsigned int idSyn = ($(id_pre) * $(_row_stride)) + j;", varEnv)}); + {varEnv.addInitialiser("const unsigned int idSyn = ($(id_pre) * $(_row_stride)) + j;")}); varEnv.add(Type::Uint32, "id_post", "idPost", - {varEnv.addInitialiser("const unsigned int idPost = $(_ind)[($(id_pre) * $(_row_stride)) + j]", varEnv)}); + {varEnv.addInitialiser("const unsigned int idPost = $(_ind)[($(id_pre) * $(_row_stride)) + j]")}); handler(varEnv); } } @@ -1510,7 +1510,7 @@ void Backend::genDenseSynapseVariableRowInit(EnvironmentExternalBase &env, Handl EnvironmentExternal varEnv(env); // **TODO** 64-bit varEnv.add(Type::Uint32, "id_syn", "idSyn", - {varEnv.addInitialiser("const unsigned int idSyn = ($(id_pre) * $(_row_stride)) + j;", varEnv)}); + {varEnv.addInitialiser("const unsigned int idSyn = ($(id_pre) * $(_row_stride)) + j;")}); varEnv.add(Type::Uint32, "id_post", "j"); handler(varEnv); } @@ -1716,14 +1716,14 @@ boost::uuids::detail::sha1::digest_type Backend::getHashDigest() const void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, const ModelSpecMerged &modelMerged, bool trueSpike) const { // Get suffix based on type of events - const std::string eventSuffix = trueSpike ? "" : "Evnt"; + const std::string eventSuffix = trueSpike ? "" : "_evnt"; const auto *wu = sg.getArchetype().getWUModel(); if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::TOEPLITZ) { const auto &connectInit = sg.getArchetype().getToeplitzConnectivityInitialiser(); // Loop through Toeplitz matrix diagonals - env.getStream() << "for(unsigned int j = 0; j < group->rowStride; j++)"; + env.getStream() << "for(unsigned int j = 0; j < " << env["_row_stride"] << "; j++)"; { /*CodeStream::Scope b(env.getStream()); @@ -1832,10 +1832,10 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda // Detect spike events or spikes and do the update env.getStream() << "// process presynaptic events: " << (trueSpike ? "True Spikes" : "Spike type events") << std::endl; if(sg.getArchetype().getSrcNeuronGroup()->isDelayRequired()) { - env.getStream() << "for (unsigned int i = 0; i < group->srcSpkCnt" << eventSuffix << "[preDelaySlot]; i++)"; + env.getStream() << printSubs("for (unsigned int i = 0; i < $(_src_spk_cnt" + eventSuffix + ")[$(pre_delay_slot)]; i++)", env); } else { - env.getStream() << "for (unsigned int i = 0; i < group->srcSpkCnt" << eventSuffix << "[0]; i++)"; + env.getStream() << printSubs("for (unsigned int i = 0; i < $(_src_spk_cnt" + eventSuffix + ")[0]; i++)", env); } { CodeStream::Scope b(env.getStream()); @@ -1845,9 +1845,9 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda EnvironmentGroupMergedField groupEnv(env, sg); - const std::string queueOffset = sg.getArchetype().getSrcNeuronGroup()->isDelayRequired() ? "preDelayOffset + " : ""; + const std::string queueOffset = sg.getArchetype().getSrcNeuronGroup()->isDelayRequired() ? 
"$(pre_delay_offset) + " : ""; groupEnv.add(Type::Uint32, "id_pre", "idPre", - {groupEnv.addInitialiser("const unsigned int idPre = group->srcSpk" + eventSuffix + "[" + queueOffset + "i];")}); + {groupEnv.addInitialiser("const unsigned int idPre = $(_src_spk" + eventSuffix + ")[" + queueOffset + "i];")}); // If this is a spike-like event, insert threshold check for this presynaptic neuron if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { @@ -1875,9 +1875,9 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda // **TODO** 64-bit id_syn synEnv.add(Type::Uint32, "id_syn", "idSyn", - {synEnv.addInitialiser("const unsigned int idSyn = ($(id_pre) * $(_row_stride)) + j;", synEnv)}); + {synEnv.addInitialiser("const unsigned int idSyn = ($(id_pre) * $(_row_stride)) + j;")}); synEnv.add(Type::Uint32, "id_post", "idPost", - {synEnv.addInitialiser("const unsigned int idPost = $(_ind)[$(id_syn)];", synEnv)}); + {synEnv.addInitialiser("const unsigned int idPost = $(_ind)[$(id_syn)];")}); if(trueSpike) { sg.generateSpikeUpdate(*this, synEnv, modelMerged); @@ -1898,7 +1898,7 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda CodeStream::Scope b(groupEnv.getStream()); // Read row word - groupEnv.getStream() << "uint32_t connectivityWord = group->gp[(ipre * rowWords) + w];" << std::endl; + groupEnv.getStream() << "uint32_t connectivityWord = " << groupEnv["_gp"] << "[(ipre * rowWords) + w];" << std::endl; // Set ipost to first synapse in connectivity word groupEnv.getStream() << "unsigned int ipost = w * 32;" << std::endl; @@ -1921,7 +1921,7 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda // If we aren't in padding region // **TODO** don't bother checking if there is no padding - groupEnv.getStream() << "if(ipost < group->numTrgNeurons)"; + groupEnv.getStream() << "if(ipost < " << groupEnv["num_post"] << ")"; { CodeStream::Scope b(env.getStream()); if(trueSpike) { @@ -1949,11 +1949,11 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda // **TODO** 64-bit index synEnv.getStream() << printSubs("const uint64_t gid = ($(id_pre) * $(num_post)) + $(id_post);", synEnv) << std::endl; - synEnv.getStream() << "if (B(group->gp[gid / 32], gid & 31))" << CodeStream::OB(20); + synEnv.getStream() << "if (B(" << synEnv["_gp"] << "[gid / 32], gid & 31))" << CodeStream::OB(20); } else { synEnv.add(Type::Uint32, "id_syn", "idSyn", - {synEnv.addInitialiser("const unsigned int idSyn = ($(id_pre) * $(num_post)) + $(id_post);", synEnv)}); + {synEnv.addInitialiser("const unsigned int idSyn = ($(id_pre) * $(num_post)) + $(id_post);")}); } @@ -1981,12 +1981,12 @@ void Backend::genEmitSpike(EnvironmentExternalBase &env, NeuronUpdateGroupMerged { // Determine if delay is required and thus, at what offset we should write into the spike queue const bool spikeDelayRequired = trueSpike ? (ng.getArchetype().isDelayRequired() && ng.getArchetype().isTrueSpikeRequired()) : ng.getArchetype().isDelayRequired(); - const std::string spikeQueueOffset = spikeDelayRequired ? "writeDelayOffset + " : ""; + const std::string spikeQueueOffset = spikeDelayRequired ? "$(_write_delay_offset) + " : ""; - const std::string suffix = trueSpike ? "" : "Evnt"; - env.getStream() << "group->spk" << suffix << "[" << spikeQueueOffset << "group->spkCnt" << suffix; + const std::string suffix = trueSpike ? 
"" : "_evnt"; + env.getStream() << printSubs("$(_spk" + suffix + ")[" + spikeQueueOffset + "$(_spk_cnt" + suffix + ")", env); if(spikeDelayRequired) { // WITH DELAY - env.getStream() << "[*group->spkQuePtr]++]"; + env.getStream() << "[*" << env["_spk_que_ptr"] << "]++]"; } else { // NO DELAY env.getStream() << "[0]++]"; @@ -1994,19 +1994,17 @@ void Backend::genEmitSpike(EnvironmentExternalBase &env, NeuronUpdateGroupMerged env.getStream() << " = " << env["id"] << ";" << std::endl; // Reset spike and spike-like-event times - const std::string queueOffset = ng.getArchetype().isDelayRequired() ? "writeDelayOffset + " : ""; + const std::string queueOffset = ng.getArchetype().isDelayRequired() ? "$(_write_delay_offset) + " : ""; if(trueSpike && ng.getArchetype().isSpikeTimeRequired()) { - env.getStream() << "group->sT[" << queueOffset << env["id"] << "] = " << env["t"] << ";" << std::endl; + env.getStream() << printSubs("$(_spk_time)[" + queueOffset + "$(id)] = $(t);", env) << std::endl; } else if(!trueSpike && ng.getArchetype().isSpikeEventTimeRequired()) { - env.getStream() << "group->seT[" << queueOffset << env["id"] << "] = " << env["t"] << ";" << std::endl; + env.getStream() << printSubs("$(_spk_evnt_time)[" + queueOffset + "$(id)] = $(t);", env) << std::endl; } // If recording is enabled if(recordingEnabled) { - const std::string recordSuffix = trueSpike ? "" : "Event"; - env.getStream() << "group->recordSpk" << recordSuffix << "[(recordingTimestep * numRecordingWords) + (" << env["id"] << " / 32)]"; - env.getStream() << " |= (1 << (" << env["id"] << " % 32));" << std::endl; + env.getStream() << printSubs("$(_record_spk" + suffix + ")[(recordingTimestep * numRecordingWords) + ($(id) / 32)] |= (1 << ($(id) % 32));", env) << std::endl; } } //-------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 5c43c861fb..2749435d28 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -42,7 +42,7 @@ void BackendBase::genCustomUpdateIndexCalculation(EnvironmentGroupMergedFieldgetNumDelaySlots()); env.add(Type::Uint32.addConst(), "_batch_delay_slot", "batchDelaySlot", - {env.addInitialiser("const unsigned int batchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_delay_slot);", env)}); + {env.addInitialiser("const unsigned int batchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_delay_slot);")}); // Calculate current batch offset env.add(Type::Uint32.addConst(), "_batch_delay_offset", "batchDelayOffset", - {env.addInitialiser("const unsigned int batchDelayOffset = $(_batch_offset) * " + numDelaySlotsStr + ";", env)}); + {env.addInitialiser("const unsigned int batchDelayOffset = $(_batch_offset) * " + numDelaySlotsStr + ";")}); } } } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index b957e9dcf8..95551bda2c 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -555,7 +555,7 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, Environmen } batchStrideInit << ";" << std::endl; groupEnv.add(Type::Uint32.addConst(), "_batch_stride", "batchStride", - {groupEnv.addInitialiser(batchStrideInit.str(), groupEnv)}); + {groupEnv.addInitialiser(batchStrideInit.str())}); } @@ -641,7 +641,7 @@ void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase &b // Add 
substitution // **TODO** dependencies on kernel fields groupEnv.add(Type::Uint32, "id_kernel", "kernelInd", - {groupEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(*this) + ";", groupEnv)}); + {groupEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(*this) + ";")}); // Initialise single (hence empty lambda function) synapse variable genInitWUVarCode(backend, groupEnv, *this, "$(num_pre) * $(_row_stride)", modelMerged.getModel().getBatchSize(), @@ -915,7 +915,7 @@ void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, Env } batchStrideInit << ";" << std::endl; groupEnv.add(Type::Uint32.addConst(), "_batch_stride", "batchStride", - {groupEnv.addInitialiser(batchStrideInit.str(), groupEnv)}); + {groupEnv.addInitialiser(batchStrideInit.str())}); } } else { diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 6ffabe3733..2324d569da 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -530,13 +530,13 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Substitute spike times const std::string spikeTimeReadIndex = getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "id"); neuronEnv.add(getTimeType().addConst(), "sT", "lsT", - {neuronEnv.addInitialiser("const timepoint lsT = $(_spk_time)[" + spikeTimeReadIndex + "];", neuronEnv)}); + {neuronEnv.addInitialiser("const timepoint lsT = $(_spk_time)[" + spikeTimeReadIndex + "];")}); neuronEnv.add(getTimeType().addConst(), "prev_sT", "lprevST", - {neuronEnv.addInitialiser("const timepoint lprevST = $(_prev_spk_time)[" + spikeTimeReadIndex + "];", neuronEnv)}); + {neuronEnv.addInitialiser("const timepoint lprevST = $(_prev_spk_time)[" + spikeTimeReadIndex + "];")}); neuronEnv.add(getTimeType().addConst(), "seT", "lseT", - {neuronEnv.addInitialiser("const timepoint lseT = $(_spk_evnt_time)[" + spikeTimeReadIndex+ "];", neuronEnv)}); + {neuronEnv.addInitialiser("const timepoint lseT = $(_spk_evnt_time)[" + spikeTimeReadIndex+ "];")}); neuronEnv.add(getTimeType().addConst(), "prev_seT", "lprevSET", - {neuronEnv.addInitialiser("const timepoint lprevSET = $(_prev_spk_evnt_time)[" + spikeTimeReadIndex + "];", neuronEnv)}); + {neuronEnv.addInitialiser("const timepoint lprevSET = $(_prev_spk_evnt_time)[" + spikeTimeReadIndex + "];")}); // Create an environment which caches variables in local variables if they are accessed // **NOTE** we do this right at the top so that local copies can be used by child groups diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 14098cabf1..886880cd70 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -46,7 +46,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa if (!sg.getArchetype().getKernelSize().empty()) { // Add substitution synEnv.add(Type::Uint32, "id_kernel", "kernelInd", - {synEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(sg) + ";", synEnv)}); + {synEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(sg) + ";")}); } // If weights are individual, substitute variables for values stored in global memory From af69c9b83a02a9b50efb8f7c32ffd36e2e35f2ab Mon Sep 17 00:00:00 2001 From: 
neworderofjamie Date: Tue, 27 Jun 2023 10:35:00 +0100 Subject: [PATCH 257/725] Further simplification of ``LazyString`` usage * No longer exposed in Environment API - constructed everywhere from format strings * No need for all the operators etc - LazyString::print has become constructor and all operators removed --- .../genn/genn/code_generator/environment.h | 39 ++++++----- include/genn/genn/code_generator/lazyString.h | 44 ++---------- .../backends/single_threaded_cpu/backend.cc | 22 +++--- src/genn/genn/code_generator/lazyString.cc | 70 ++++++++----------- .../code_generator/neuronUpdateGroupMerged.cc | 24 +++---- .../synapseUpdateGroupMerged.cc | 16 ++--- 6 files changed, 84 insertions(+), 131 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 056594c3eb..b3cfe681b6 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -299,7 +299,7 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P size_t addInitialiser(const std::string &format) { - m_Initialisers.emplace_back(false, LazyString::print(format, *this)); + m_Initialisers.emplace_back(false, LazyString{format, *this}); return (m_Initialisers.size() - 1); } @@ -341,10 +341,10 @@ class EnvironmentExternal : public EnvironmentExternalDynamicBase &initialisers = {}) { - addInternal(type, name, value, initialisers); + addInternal(type, name, LazyString{value, *this}, initialisers); } }; @@ -376,25 +376,26 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &initialisers = {}) { - addInternal(type, name, std::make_tuple(false, value, std::nullopt), initialisers); + addInternal(type, name, std::make_tuple(false, LazyString{value, *this}, std::nullopt), initialisers); } //! Map a type (for type-checking) and a group merged field to back it to an identifier void addField(const GeNN::Type::ResolvedType &type, const std::string &name, const GeNN::Type::ResolvedType &fieldType, const std::string &fieldName, typename G::GetFieldValueFunc getFieldValue, - const LazyString &indexSuffix = "", GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD, + const std::string &indexSuffix = "", GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD, const std::vector &initialisers = {}) { - addInternal(type, name, std::make_tuple(false, indexSuffix, std::make_optional(std::make_tuple(fieldType, fieldName, getFieldValue, mergedFieldType))), + addInternal(type, name, std::make_tuple(false, LazyString{indexSuffix, *this}, + std::make_optional(std::make_tuple(fieldType, fieldName, getFieldValue, mergedFieldType))), initialisers); } //! 
Map a type (for type-checking) and a group merged field to back it to an identifier void addField(const GeNN::Type::ResolvedType &type, const std::string &name, const std::string &fieldName, - typename G::GetFieldValueFunc getFieldValue, const LazyString &indexSuffix = "", GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD, + typename G::GetFieldValueFunc getFieldValue, const std::string &indexSuffix = "", GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD, const std::vector &initialisers = {}) { addField(type, name, type, fieldName, getFieldValue, indexSuffix, mergedFieldType, initialisers); @@ -585,7 +586,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase - void addVars(const std::string &arrayPrefix, const LazyString &indexSuffix, const std::string &fieldSuffix = "") + void addVars(const std::string &arrayPrefix, const std::string &indexSuffix, const std::string &fieldSuffix = "") { addVars(arrayPrefix, [&indexSuffix](VarAccess a, const std::string &) { return indexSuffix; }, fieldSuffix); @@ -613,9 +614,9 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase - void addVarRefs(const std::string &arrayPrefix, const LazyString &indexSuffix, const std::string &fieldSuffix = "") + void addVarRefs(const std::string &arrayPrefix, const std::string &indexSuffix, const std::string &fieldSuffix = "") { - addVarRefs(arrayPrefix, [&indexSuffix](VarAccess a, const std::string &) { return indexSuffix; }, + addVarRefs(arrayPrefix, [&indexSuffix](VarAccess a, auto &) { return indexSuffix; }, fieldSuffix); } @@ -641,7 +642,7 @@ class VarCachePolicy { public: using GroupInternal = typename G::GroupInternal; - using GetIndexFn = std::function; + using GetIndexFn = std::function; VarCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex) : m_GetReadIndex(getReadIndex), m_GetWriteIndex(getWriteIndex) @@ -656,12 +657,12 @@ class VarCachePolicy return A(g).getNameSuffix(); } - LazyString getReadIndex(G &g, const Models::Base::Var &var) + std::string getReadIndex(G &g, const Models::Base::Var &var) { return m_GetReadIndex(var.name, getVarAccessDuplication(var.access)); } - LazyString getWriteIndex(G &g, const Models::Base::Var &var) + std::string getWriteIndex(G &g, const Models::Base::Var &var) { return m_GetWriteIndex(var.name, getVarAccessDuplication(var.access)); } @@ -677,7 +678,7 @@ class VarRefCachePolicy protected: using GroupInternal = typename G::GroupInternal; using Initialiser = typename std::remove_reference_t>::mapped_type; - using GetIndexFn = std::function; + using GetIndexFn = std::function; VarRefCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex) @@ -693,12 +694,12 @@ class VarRefCachePolicy return A(g).getInitialisers().at(var.name).getTargetName(); } - LazyString getReadIndex(G &g, const Models::Base::VarRef &var) + std::string getReadIndex(G &g, const Models::Base::VarRef &var) { return m_GetReadIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name)); } - LazyString getWriteIndex(G &g, const Models::Base::VarRef &var) + std::string getWriteIndex(G &g, const Models::Base::VarRef &var) { return m_GetWriteIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name)); } @@ -772,7 +773,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P // **NOTE** by not initialising these variables for reductions, // compilers SHOULD emit a warning if user code doesn't set it to something if(!(v.access & VarAccessModeAttribute::REDUCE)) { - 
getContextStream() << " = group->" << v.name << m_FieldSuffix << "[" << getReadIndex(m_Group.get(), v).str() << "]"; + getContextStream() << " = group->" << v.name << m_FieldSuffix << "[" << printSubs(getReadIndex(m_Group.get(), v), *this) << "]"; } getContextStream() << ";" << std::endl; } @@ -784,7 +785,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P for(const auto &v : referencedDefs) { // If variables are read-write if(v.access & VarAccessMode::READ_WRITE) { - getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << getWriteIndex(m_Group.get(), v).str() << "]"; + getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << printSubs(getWriteIndex(m_Group.get(), v), *this) << "]"; getContextStream() << " = " << m_LocalPrefix << v.name << ";" << std::endl; } } diff --git a/include/genn/genn/code_generator/lazyString.h b/include/genn/genn/code_generator/lazyString.h index 5f86ce4a53..f826491f23 100644 --- a/include/genn/genn/code_generator/lazyString.h +++ b/include/genn/genn/code_generator/lazyString.h @@ -14,21 +14,13 @@ class EnvironmentExternalBase; //---------------------------------------------------------------------------- // GeNN::CodeGenerator::LazyString //---------------------------------------------------------------------------- -//! Base class for external environments i.e. those defines OUTSIDE of transpiled code by code generator +//! Lazily-evaluated string class - constructed from a format string containing $(XX) references to variables in environment namespace GeNN::CodeGenerator { class LazyString { public: - typedef std::variant, std::string>> Element; - typedef std::vector Payload; - - LazyString(const std::string &str) : m_Payload{str} - {} - LazyString(const char *chr) : m_Payload{chr} - {} - LazyString(EnvironmentExternalBase &env, const std::string &name) : m_Payload{std::make_pair(std::ref(env), name)} - {} + LazyString(const std::string &format, EnvironmentExternalBase &env); //---------------------------------------------------------------------------- // Public API @@ -36,42 +28,16 @@ class LazyString //! 
Evaluate lazy string std::string str() const; - //---------------------------------------------------------------------------- - // Static API - //---------------------------------------------------------------------------- - static LazyString print(const std::string &format, EnvironmentExternalBase &env); - private: - LazyString(const Payload &payload) : m_Payload(payload){} - //---------------------------------------------------------------------------- - // Friends + // Typedefines //---------------------------------------------------------------------------- - friend LazyString operator + (const LazyString& lhs, const LazyString &rhs); + typedef std::variant, std::string>> Element; + typedef std::vector Payload; //---------------------------------------------------------------------------- // Members //---------------------------------------------------------------------------- Payload m_Payload; }; - -//---------------------------------------------------------------------------- -// Operators -//---------------------------------------------------------------------------- -inline LazyString operator + (const LazyString& lhs, const LazyString &rhs) -{ - std::vector payload(lhs.m_Payload); - payload.insert(payload.end(), rhs.m_Payload.cbegin(), rhs.m_Payload.cend()); - return LazyString(payload); -} - -inline LazyString operator + (const char *lhs, const LazyString &rhs) -{ - return LazyString(lhs) + rhs; -} - -inline LazyString operator + (const LazyString &lhs, const char *rhs) -{ - return lhs + LazyString(rhs); -} } // namespace GeNN::CodeGenerator \ No newline at end of file diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 02a4835898..ca41197023 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -386,9 +386,9 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos } // Add correct functions for apply synaptic input - synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", LazyString::print("$(_den_delay)[" + s.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)", synEnv)); - synEnv.add(Type::AddToPost, "addToPost", LazyString::print("$(_out_post)[" + s.getPostISynIndex(1, "j") + "] += $(0)", synEnv)); - synEnv.add(Type::AddToPre, "addToPre", LazyString::print("$(_out_pre)[" + s.getPreISynIndex(1, "id_pre") + "] += $(0)", synEnv)); + synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", "$(_den_delay)[" + s.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)"); + synEnv.add(Type::AddToPost, "addToPost", "$(_out_post)[" + s.getPostISynIndex(1, "j") + "] += $(0)"); + synEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + s.getPreISynIndex(1, "id_pre") + "] += $(0)"); // Call synapse dynamics handler s.generateSynapseUpdate(*this, synEnv, modelMerged); @@ -502,7 +502,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos } synEnv.add(Type::Uint32.addConst(), "id_post", "spike"); - synEnv.add(Type::AddToPre, "addToPre", LazyString::print("$(_out_pre)[" + s.getPreISynIndex(1, "id_pre") + "] += $(0)", synEnv)); + synEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + s.getPreISynIndex(1, "id_pre") + "] += $(0)"); s.generateSynapseUpdate(*this, synEnv, modelMerged); } @@ -1099,7 +1099,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler addSynapse << "while(false)"; const auto addSynapseType = Type::ResolvedType::createFunction(Type::Void, std::vector{1ull + 
s.getArchetype().getKernelSize().size(), Type::Uint32}); - groupEnv.add(addSynapseType, "addSynapse", LazyString::print(addSynapseStream.str(), groupEnv)); + groupEnv.add(addSynapseType, "addSynapse", addSynapseStream.str()); // Call appropriate connectivity handler if(!snippet->getRowBuildCode().empty()) { @@ -1150,7 +1150,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // Generate sparse initialisation code groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); if(s.getArchetype().isWUVarInitRequired()) { - groupEnv.add(Type::Uint32.addConst(), "row_len", LazyString::print("$(_row_length)[i]", groupEnv)); + groupEnv.add(Type::Uint32.addConst(), "row_len", "$(_row_length)[i]"); s.generateInit(*this, groupEnv, modelMerged); } @@ -1203,7 +1203,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // Generate initialisation code groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); - groupEnv.add(Type::Uint32.addConst(), "row_len", LazyString::print("$(_row_length)[i]", groupEnv)); + groupEnv.add(Type::Uint32.addConst(), "row_len", "$(_row_length)[i]"); c.generateInit(*this, groupEnv, modelMerged); } } @@ -1232,7 +1232,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // Generate initialisation code groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); - groupEnv.add(Type::Uint32.addConst(), "row_len", LazyString::print("$(_row_length)[i]", groupEnv)); + groupEnv.add(Type::Uint32.addConst(), "row_len", "$(_row_length)[i]"); c.generateInit(*this, groupEnv, modelMerged); } } @@ -1861,9 +1861,9 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda } // Add correct functions for apply synaptic input - groupEnv.add(Type::AddToPostDenDelay, "addToPostDelay", LazyString::print("$(_den_delay)[" + sg.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)", groupEnv)); - groupEnv.add(Type::AddToPost, "addToPost", LazyString::print("$(_out_post)[" + sg.getPostISynIndex(1, "j") + "] += $(0)", groupEnv)); - groupEnv.add(Type::AddToPre, "addToPre", LazyString::print("$(_out_pre)[" + sg.getPreISynIndex(1, env["id_pre"]) + "] += $(0)", groupEnv)); + groupEnv.add(Type::AddToPostDenDelay, "addToPostDelay", "$(_den_delay)[" + sg.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)"); + groupEnv.add(Type::AddToPost, "addToPost", "$(_out_post)[" + sg.getPostISynIndex(1, "j") + "] += $(0)"); + groupEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + sg.getPreISynIndex(1, env["id_pre"]) + "] += $(0)"); // If connectivity is sparse if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { diff --git a/src/genn/genn/code_generator/lazyString.cc b/src/genn/genn/code_generator/lazyString.cc index b98849ec0b..1ed631aad5 100644 --- a/src/genn/genn/code_generator/lazyString.cc +++ b/src/genn/genn/code_generator/lazyString.cc @@ -15,66 +15,32 @@ using namespace GeNN::CodeGenerator; //---------------------------------------------------------------------------- // GeNN::CodeGenerator::LazyString //---------------------------------------------------------------------------- -std::string LazyString::str() const -{ - std::ostringstream stream; - for(const auto &e : m_Payload) - { - std::visit( - Utils::Overload{ - [&stream](const std::string &str) - { - stream << str; - }, - [&stream](const EnvRef &envRef) - { - stream << envRef.env.get()[envRef.name]; - }}, - e); - } - return stream.str(); -} -//---------------------------------------------------------------------------- -LazyString& 
LazyString::operator += (const LazyString &rhs) -{ - // Add RHS's payload to ours - m_Payload.insert(m_Payload.end(), rhs.m_Payload.cbegin(), rhs.m_Payload.cend()); - - return *this; -} -//---------------------------------------------------------------------------- -LazyString& LazyString::operator += (const std::string &rhs) -{ - return operator += (LazyString{rhs}); -} -//---------------------------------------------------------------------------- -LazyString LazyString::print(const std::string &format, EnvironmentExternalBase &env) +LazyString::LazyString(const std::string &format, EnvironmentExternalBase &env) { // Create regex iterator to iterate over $(XXX) style variables in format string std::regex regex("\\$\\(([\\w]+)\\)"); std::sregex_iterator matchesBegin(format.cbegin(), format.cend(), regex); std::sregex_iterator matchesEnd; - // If there are no matches, leave format unmodified and return + // If there are no matches, use format directly as payload if(matchesBegin == matchesEnd) { - return LazyString(format); + m_Payload.push_back(format); } // Otherwise else { // Loop through matches to build lazy string payload - Payload payload; for(std::sregex_iterator m = matchesBegin;;) { // Copy the non-matched subsequence (m->prefix()) onto payload - payload.push_back(std::string{m->prefix().first, m->prefix().second}); + m_Payload.push_back(std::string{m->prefix().first, m->prefix().second}); // Add lazy environment reference for $(XXX) variable to payload - payload.push_back(EnvRef{std::ref(env), (*m)[1]}); + m_Payload.push_back(std::make_pair(std::ref(env), (*m)[1])); // If there are no subsequent matches, add the remaining non-matched - // characters onto payload, construct lazy string and return + // characters onto payload, construct lazy string and stop if(std::next(m) == matchesEnd) { - payload.push_back(std::string{m->suffix().first, m->suffix().second}); - return LazyString(payload); + m_Payload.push_back(std::string{m->suffix().first, m->suffix().second}); + break; } // Otherwise go onto next match else { @@ -82,4 +48,24 @@ LazyString LazyString::print(const std::string &format, EnvironmentExternalBase } } } +} +//---------------------------------------------------------------------------- +std::string LazyString::str() const +{ + std::ostringstream stream; + for(const auto &e : m_Payload) + { + std::visit( + Utils::Overload{ + [&stream](const std::string &str) + { + stream << str; + }, + [&stream](const std::pair<std::reference_wrapper<EnvironmentExternalBase>, std::string> &envRef) + { + stream << envRef.first.get()[envRef.second]; + }}, + e); + } + return stream.str(); } \ No newline at end of file diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 2324d569da..47b6112946 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -35,15 +35,15 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend csEnv.addExtraGlobalParams(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", fieldSuffix); // Define inject current function - csEnv.add(Type::ResolvedType::createFunction(Type::Void, {modelMerged.getModel().getPrecision()}), "injectCurrent", - LazyString::print("$(Isyn) += $(0)", csEnv)); + csEnv.add(Type::ResolvedType::createFunction(Type::Void, {modelMerged.getModel().getPrecision()}), + "injectCurrent", "$(Isyn) += $(0)");
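//----------------------------------------------------------------------------
// Illustrative sketch (not from the patch): the simplified LazyString is
// constructed straight from a format string, and its $(XXX) references are
// only resolved against the environment when str() is called, so names may be
// added to the environment after construction. Assuming an environment in
// which "_ind" resolves to the merged-group field "group->ind":
//
//     EnvironmentExternal env(parentEnv);
//     env.add(Type::Uint32.addConst(), "id_syn", "idSyn");
//     const LazyString index{"$(_ind)[$(id_syn)]", env};
//     const std::string code = index.str();   // -> "group->ind[idSyn]"
//----------------------------------------------------------------------------
// Create an environment which caches variables in local variables if they are accessed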
EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, - [&csEnv, &modelMerged, &ng](const std::string&, VarAccessDuplication d) + [&modelMerged, &ng](const std::string&, VarAccessDuplication d) { - return LazyString::print(ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, "id"), csEnv); + return ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, "id"); }); // Pretty print code back to environment @@ -127,9 +127,9 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, - [&psmEnv, &modelMerged, &ng](const std::string&, VarAccessDuplication d) + [&modelMerged, &ng](const std::string&, VarAccessDuplication d) { - return LazyString::print(ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, "id"), psmEnv); + return ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, "id"); }); // Pretty print code back to environment @@ -219,11 +219,11 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) { - return LazyString::print(ng.getReadVarIndex(delayed, batchSize, d, "id"), synEnv); + return ng.getReadVarIndex(delayed, batchSize, d, "id"); }, [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) { - return LazyString::print(ng.getWriteVarIndex(delayed, batchSize, d, "id"), synEnv); + return ng.getWriteVarIndex(delayed, batchSize, d, "id"); }); /*neuronSubstitutionsInSynapticCode(varEnv, &ng.getArchetype(), "", "_post", "", "", "", dynamicsNotSpike, @@ -309,13 +309,13 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back const bool delayed = (getArchetype().getDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, - [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) + [batchSize, delayed, &ng](const std::string&, VarAccessDuplication d) { - return LazyString::print(ng.getReadVarIndex(delayed, batchSize, d, "id"), synEnv); + return ng.getReadVarIndex(delayed, batchSize, d, "id"); }, - [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) + [batchSize, delayed, &ng](const std::string&, VarAccessDuplication d) { - return LazyString::print(ng.getWriteVarIndex(delayed, batchSize, d, "id"), synEnv); + return ng.getWriteVarIndex(delayed, batchSize, d, "id"); }); /*neuronSubstitutionsInSynapticCode(subs, &ng.getArchetype(), "", "_pre", "", "", "", dynamicsNotSpike, diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 886880cd70..2174078de6 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -31,14 +31,14 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa // Substitute names of pre and postsynaptic weight update variable synEnv.addVars(backend.getDeviceVarPrefix(), - [&sg, &synEnv, batchSize](VarAccess a, const std::string&) + [&sg, batchSize](VarAccess a, const std::string&) { - return 
LazyString::print(sg.getPreWUVarIndex(batchSize, getVarAccessDuplication(a), "id_pre"), synEnv); + return sg.getPreWUVarIndex(batchSize, getVarAccessDuplication(a), "id_pre"); }); synEnv.addVars(backend.getDeviceVarPrefix(), - [&sg, &synEnv, batchSize](VarAccess a, const std::string&) + [&sg, batchSize](VarAccess a, const std::string&) { - return LazyString::print(sg.getPostWUVarIndex(batchSize, getVarAccessDuplication(a), "id_post"), synEnv); + return sg.getPostWUVarIndex(batchSize, getVarAccessDuplication(a), "id_post"); }); @@ -52,9 +52,9 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa // If weights are individual, substitute variables for values stored in global memory if (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::INDIVIDUAL) { synEnv.addVars(backend.getDeviceVarPrefix(), - [&sg, &synEnv, batchSize](VarAccess a, const std::string&) + [&sg, batchSize](VarAccess a, const std::string&) { - return LazyString::print(sg.getSynVarIndex(batchSize, getVarAccessDuplication(a), "id_syn"), synEnv); + return sg.getSynVarIndex(batchSize, getVarAccessDuplication(a), "id_syn"); }); } // Otherwise, if weights are procedural @@ -102,9 +102,9 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa assert(!sg.getArchetype().getKernelSize().empty()); synEnv.addVars(backend.getDeviceVarPrefix(), - [&sg, &synEnv, batchSize](VarAccess a, const std::string&) + [&sg, batchSize](VarAccess a, const std::string&) { - return LazyString::print(sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), "id_kernel"), synEnv); + return sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), "id_kernel"); }); } // Otherwise, substitute variables for constant values From 019f37347c78cf8321f0e08d2456ec31f0bab4af Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 27 Jun 2023 10:37:33 +0100 Subject: [PATCH 258/725] fixed some backend typos --- src/genn/backends/single_threaded_cpu/backend.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index ca41197023..8e165a4384 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -72,7 +72,7 @@ class Timer //----------------------------------------------------------------------- template -void genKernelIteration(EnvironmentExternalBase &env, const G &g, size_t numKernelDims, BackendBase::HandlerEnv handler) +void genKernelIteration(EnvironmentExternalBase &env, G &g, size_t numKernelDims, BackendBase::HandlerEnv handler) { // Define recursive function to generate nested kernel initialisation loops // **NOTE** this is a std::function as type of auto lambda couldn't be determined inside for recursive call @@ -690,7 +690,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host CodeStream::Scope b(groupEnv.getStream()); // Add presynaptic index to substitutions - EnvironmentGroupMergedField synEnv(groupEnv, c); + EnvironmentGroupMergedField synEnv(groupEnv, c); synEnv.add(Type::Uint32.addConst(), "id_pre", "i"); // If connectivity is sparse From 8720260964366fed9debd7bce108c1e5fec1cf24 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 27 Jun 2023 12:32:47 +0100 Subject: [PATCH 259/725] custom connectivity work --- .../customConnectivityUpdateGroupMerged.h | 38 +- include/genn/genn/transpiler/prettyPrinter.h | 2 +- .../customConnectivityUpdateGroupMerged.cc | 324
++++++++++-------- src/genn/genn/transpiler/prettyPrinter.cc | 3 +- 4 files changed, 208 insertions(+), 159 deletions(-) diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index dcbe480269..a3e8c86910 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -8,31 +8,14 @@ #include "code_generator/groupMerged.h" //---------------------------------------------------------------------------- -// GeNN::CodeGenerator::CustomConnectivityUpdateGroupMergedBase +// GeNN::CodeGenerator::CustomConnectivityUpdateGroupMerged //---------------------------------------------------------------------------- namespace GeNN::CodeGenerator { -class GENN_EXPORT CustomConnectivityUpdateGroupMergedBase : public GroupMerged +class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public GroupMerged { public: - CustomConnectivityUpdateGroupMergedBase(size_t index, const Type::TypeContext &typeContext, - const std::vector> &groups); - -protected: - //---------------------------------------------------------------------------- - // Protected methods - //---------------------------------------------------------------------------- - bool isParamHeterogeneous(const std::string &name) const; - bool isDerivedParamHeterogeneous(const std::string &name) const; -}; - -//---------------------------------------------------------------------------- -// GeNN::CodeGenerator::CustomConnectivityUpdateGroupMerged -//---------------------------------------------------------------------------- -class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public CustomConnectivityUpdateGroupMergedBase -{ -public: - CustomConnectivityUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, + CustomConnectivityUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> &groups); //---------------------------------------------------------------------------- @@ -61,6 +44,12 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public CustomConnectivit static const std::string name; private: + //---------------------------------------------------------------------------- + // Private methods + //---------------------------------------------------------------------------- + bool isParamHeterogeneous(const std::string &name) const; + bool isDerivedParamHeterogeneous(const std::string &name) const; + //---------------------------------------------------------------------------- // Members //---------------------------------------------------------------------------- @@ -72,11 +61,11 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public CustomConnectivit //---------------------------------------------------------------------------- // GeNN::CodeGenerator::CustomConnectivityHostUpdateGroupMerged //---------------------------------------------------------------------------- -class GENN_EXPORT CustomConnectivityHostUpdateGroupMerged : public CustomConnectivityUpdateGroupMergedBase +class GENN_EXPORT CustomConnectivityHostUpdateGroupMerged : public GroupMerged { public: - CustomConnectivityHostUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, - const std::vector> &groups); + using GroupMerged::GroupMerged; + //---------------------------------------------------------------------------- // Public API 
//---------------------------------------------------------------------------- @@ -100,6 +89,9 @@ class GENN_EXPORT CustomConnectivityHostUpdateGroupMerged : public CustomConnect //---------------------------------------------------------------------------- // Private methods //---------------------------------------------------------------------------- + bool isParamHeterogeneous(const std::string &name) const; + bool isDerivedParamHeterogeneous(const std::string &name) const; + void addVarPushPullFuncSubs(const BackendBase &backend, Substitutions &subs, const Models::Base::VarVec &vars, const std::string &count, VarLocation(CustomConnectivityUpdateInternal:: *getVarLocationFn)(const std::string&) const) const; diff --git a/include/genn/genn/transpiler/prettyPrinter.h b/include/genn/genn/transpiler/prettyPrinter.h index 2d02332316..14cd30034c 100644 --- a/include/genn/genn/transpiler/prettyPrinter.h +++ b/include/genn/genn/transpiler/prettyPrinter.h @@ -38,7 +38,7 @@ class EnvironmentBase virtual CodeGenerator::CodeStream &getStream() = 0; }; -typedef std::function)> StatementHandler; +typedef std::function)> StatementHandler; //--------------------------------------------------------------------------- // Free functions diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index c088a14e1a..1a76f9ebf4 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -13,62 +13,98 @@ using namespace GeNN; using namespace GeNN::CodeGenerator; //---------------------------------------------------------------------------- -// CodeGenerator::CustomConnectivityUpdateGroupMergedBase +// Anonymous namespace //---------------------------------------------------------------------------- -CustomConnectivityUpdateGroupMergedBase::CustomConnectivityUpdateGroupMergedBase(size_t index, const Type::TypeContext &typeContext, - const std::vector> &groups) -: GroupMerged(index, typeContext, groups) +namespace { - using namespace Type; - - addField(Uint32, "numSrcNeurons", - [](const auto &cg, size_t) - { - const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); - return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons()); - }); - - addField(Uint32, "numTrgNeurons", - [](const auto &cg, size_t) - { - const SynapseGroupInternal *sgInternal = static_cast(cg.getSynapseGroup()); - return std::to_string(sgInternal->getTrgNeuronGroup()->getNumNeurons()); - }); - - // Add heterogeneous custom update model parameters - addHeterogeneousParams( - getArchetype().getCustomConnectivityUpdateModel()->getParamNames(), "", - [](const auto &cg) { return cg.getParams(); }, - &CustomConnectivityUpdateGroupMergedBase::isParamHeterogeneous); - - // Add heterogeneous weight update model CustomConnectivityUpdateGroupMerged parameters - addHeterogeneousDerivedParams( - getArchetype().getCustomConnectivityUpdateModel()->getDerivedParams(), "", - [](const auto &cg) { return cg.getDerivedParams(); }, - &CustomConnectivityUpdateGroupMergedBase::isDerivedParamHeterogeneous); +template +void addPrivateVarPointerFields(EnvironmentGroupMergedField &env, const std::string &arrayPrefix, const G &group) +{ + // Loop through variable references and add private pointer field + const A archetypeAdaptor(group.getArchetype()); + for(const auto &v : archetypeAdaptor.getDefs()) { + const auto resolvedType = 
v.type.resolve(group.getTypeContext()); env.addField(resolvedType.createPointer(), "_" + v.name, v.name + fieldSuffix, [arrayPrefix, v](const auto &g, size_t) { return arrayPrefix + v.name + A(g).getNameSuffix(); }); } } //---------------------------------------------------------------------------- template void addPrivateVarRefPointerFields(EnvironmentGroupMergedField &env, const std::string &arrayPrefix, const G &group) { // Loop through variable references and add private pointer field const A archetypeAdaptor(group.getArchetype()); for(const auto &v : archetypeAdaptor.getDefs()) { const auto resolvedType = v.type.resolve(group.getTypeContext()); env.addField(resolvedType.createPointer(), "_" + v.name, v.name + fieldSuffix, [arrayPrefix, v](const auto &g, size_t) { const auto varRef = A(g).getInitialisers().at(v.name); return arrayPrefix + varRef.getVar().name + varRef.getTargetName(); }); } } //---------------------------------------------------------------------------- template void addPrivateVarRefAccess(EnvironmentGroupMergedField &env, const G &group, unsigned int batchSize, I getIndexFn) { // Loop through variable references const A archetypeAdaptor(group.getArchetype()); for(const auto &v : archetypeAdaptor.getDefs()) { // If model isn't batched or variable isn't duplicated const auto &varRef = archetypeAdaptor.getInitialisers().at(v.name); if(batchSize == 1 || !varRef.isDuplicated()) { // Add field with qualified type which indexes private pointer field const auto resolvedType = v.type.resolve(group.getTypeContext()); const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; env.add(qualifiedType, v.name, "$(_" + v.name + ")[" + getIndexFn(v.access, varRef) + "]"); } } } //---------------------------------------------------------------------------- template void addPrivateVarRefAccess(EnvironmentGroupMergedField &env, const G &group, unsigned int batchSize, const std::string &indexSuffix) { addPrivateVarRefAccess(env, group, batchSize, [&indexSuffix](VarAccessMode, const Models::VarReference&){ return indexSuffix; }); } //----------------------------------------------------------------------------
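// Illustrative sketch (not from the patch) of what the helpers above emit:
// for a postsynaptic variable reference "V" in an unbatched model,
// addPrivateVarRefPointerFields() adds a pointer field reachable as $(_V) and
// addPrivateVarRefAccess() then exposes the value itself, roughly:
//
//     env.addField(Type::Float.createPointer(), "_V", "VPost", ...);   // $(_V)
//     env.add(Type::Float.addConst(), "V", "$(_V)[$(id_post)]");       // $(V)
//
// so user code reads $(V) while generated add/remove-synapse code can still
// index the raw $(_V) pointer directly ("VPost", Type::Float and the index
// are assumed for illustration).
//----------------------------------------------------------------------------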
template void addPrivateVarAccess(EnvironmentGroupMergedField &env, const G &group, unsigned int batchSize, const std::string &indexSuffix) { // Loop through variable references const A archetypeAdaptor(group.getArchetype()); for(const auto &v : archetypeAdaptor.getDefs()) { // Add field with qualified type which indexes private pointer field const auto resolvedType = v.type.resolve(group.getTypeContext()); const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; env.add(qualifiedType, v.name, "$(_" + v.name + ")[" + indexSuffix + "]"); } } //---------------------------------------------------------------------------- template void addTypes(Transpiler::TypeChecker::EnvironmentBase &env, const std::vector &vars, const Type::TypeContext &typeContext, Transpiler::ErrorHandler::ErrorHandlerBase &errorHandler) { for(const auto &v : vars) { const auto resolvedType = v.type.resolve(typeContext); const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; env.define(Transpiler::Token{Transpiler::Token::Type::IDENTIFIER, v.name, 0}, qualifiedType, errorHandler); } } } // Anonymous namespace //---------------------------------------------------------------------------- // CodeGenerator::CustomConnectivityUpdateGroupMerged //---------------------------------------------------------------------------- const std::string CustomConnectivityUpdateGroupMerged::name = "CustomConnectivityUpdate"; //---------------------------------------------------------------------------- -CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, +CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> &groups) -: CustomConnectivityUpdateGroupMergedBase(index, typeContext, groups) +: GroupMerged(index, typeContext, groups) { - using namespace Type; - // Reserve vector of vectors to hold variables to update for all custom connectivity update groups, in archetype order m_SortedDependentVars.reserve(getGroups().size()); @@ -144,27 +180,8 @@ CustomConnectivityUpdateGroupMerged::CustomConnectivityUpdateGroupMerged(size_t }); } - // If this backend requires per-population RNGs and this group requires one - if(backend.isPopulationRNGRequired() && getArchetype().isRowSimRNGRequired()){ - addPointerField(*backend.getMergedGroupSimRNGType(), "rng", backend.getDeviceVarPrefix() + "rowRNG"); - } // Add variables to struct - const auto *cm = getArchetype().getCustomConnectivityUpdateModel(); - addVars(cm->getVars(), backend.getDeviceVarPrefix()); - addVars(cm->getPreVars(), backend.getDeviceVarPrefix()); - addVars(cm->getPostVars(), backend.getDeviceVarPrefix()); - - // Add variable references to struct - addVarReferences(cm->getVarRefs(), backend.getDeviceVarPrefix(), - [](const auto &cg) { return cg.getVarReferences(); }); - addVarReferences(cm->getPreVarRefs(), backend.getDeviceVarPrefix(), - [](const auto &cg) { return cg.getPreVarReferences(); }); - addVarReferences(cm->getPostVarRefs(), backend.getDeviceVarPrefix(), - [](const auto &cg) { return cg.getPostVarReferences(); }); - - // Add EGPs to struct - this->addEGPs(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); // Loop through sorted dependent variables @@ -228,6 +245,13 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons()); }); + // Calculate index of start of row + updateEnv.add(Type::Uint32.addConst(), "_row_start_idx", "rowStartIdx", + {updateEnv.addInitialiser("const unsigned int rowStartIdx = $(id_pre) * $(_row_stride);")}); + + updateEnv.add(Type::Uint32.addConst(), "_syn_stride", "synStride", + {updateEnv.addInitialiser("const unsigned int synStride = $(num_pre) * $(_row_stride);")}); + // Substitute parameter and derived parameter names
const auto *cm = getArchetype().getCustomConnectivityUpdateModel(); updateEnv.addParams(cm->getParamNames(), "", &CustomConnectivityUpdateInternal::getParams, @@ -236,28 +260,52 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back &CustomConnectivityUpdateGroupMerged::isDerivedParamHeterogeneous); updateEnv.addExtraGlobalParams(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); - // Add presynaptic variables and variable references - // **TODO** var references to batched variables should be private - // **THINK** what about batched pre var references? - updateEnv.addVars(backend.getDeviceVarPrefix(), LazyString{updateEnv, "id_pre"}, ""); - updateEnv.addVarRefs(backend.getDeviceVarPrefix(), - [&updateEnv](VarAccessMode, const Models::VarReference &v) - { - if(v.getDelayNeuronGroup() != nullptr) { - return LazyString::print("$(_pre_delay_offset) + $(id_pre)", updateEnv); - } - else { - return LazyString{updateEnv, "id_pre"}; - } - }, ""); + // Add presynaptic variables + updateEnv.addVars(backend.getDeviceVarPrefix(), "$(id_pre)"); +
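// Illustrative sketch (not from the patch): for a presynaptic model variable
// "a", the addVars() call above roughly expands to one pointer field plus a
// value substitution indexed by $(id_pre):
//
//     updateEnv.addField(Type::Float.createPointer(), "_a", "a",
//                        [](const auto &g, size_t) { return "d_a" + g.getName(); });
//     updateEnv.add(Type::Float, "a", "$(_a)[$(id_pre)]");
//
// (Type::Float, the "d_" device prefix and the field-value lambda are assumed
// here for illustration.)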
+ // Loop through presynaptic variable references + for(const auto &v : getArchetype().getCustomConnectivityUpdateModel()->getPreVarRefs()) { + // If model isn't batched or variable isn't duplicated + const auto &varRef = getArchetype().getPreVarReferences().at(v.name); + if(modelMerged.getModel().getBatchSize() == 1 || !varRef.isDuplicated()) { + // Determine index + const std::string index = (varRef.getDelayNeuronGroup() != nullptr) ? "$(_pre_delay_offset) + $(id_pre)" : "$(id_pre)"; + + // If variable access is read-only, qualify type with const + const auto resolvedType = v.type.resolve(getTypeContext()); + const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; + + // Add field + updateEnv.addField(qualifiedType, v.name, + resolvedType.createPointer(), v.name, + [&backend, v](const auto &g, size_t) + { + const auto varRef = g.getPreVarReferences().at(v.name); + return backend.getDeviceVarPrefix() + varRef.getVar().name + varRef.getTargetName(); + }, + index); + } + } - // Calculate index of start of row - updateEnv.add(Type::Uint32.addConst(), "_row_start_idx", "rowStartIdx", - {updateEnv.addInitialiser("const unsigned int rowStartIdx = $(id_pre) * $(_row_stride);", updateEnv)}); + // Add fields and private $(_XXX) substitutions for postsynaptic and synaptic variables and variable references as, + // while these can only be accessed by user code inside loop, they can be used directly by add/remove synapse functions + addPrivateVarPointerFields(updateEnv, backend.getDeviceVarPrefix(), *this); + addPrivateVarPointerFields(updateEnv, backend.getDeviceVarPrefix(), *this); + addPrivateVarRefPointerFields(updateEnv, backend.getDeviceVarPrefix(), *this); + addPrivateVarRefPointerFields(updateEnv, backend.getDeviceVarPrefix(), *this); - updateEnv.add(Type::Uint32.addConst(), "_syn_stride", "synStride", - {updateEnv.addInitialiser("const unsigned int synStride = $(num_pre) * $(_row_stride);", updateEnv)}); + // Add private fields for dependent variables + for(size_t i = 0; i < getSortedArchetypeDependentVars().size(); i++) { + auto resolvedType = getSortedArchetypeDependentVars().at(i).getVar().type.resolve(getTypeContext()); + updateEnv.addField(resolvedType.createPointer(), "_dependent_var_" + std::to_string(i), "dependentVar" + std::to_string(i), + [i, &backend, this](const auto&, size_t g) + { + const auto &varRef = m_SortedDependentVars[g][i]; + return backend.getDeviceVarPrefix() + varRef.getVar().name + varRef.getTargetName(); + }); + } + // Get variables which will need to be manipulated when adding and removing synapses const auto ccuVars = cm->getVars(); const auto ccuVarRefs = cm->getVarRefs(); @@ -282,7 +330,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Use subsequent parameters to initialise new synapse's custom connectivity update model variables for (size_t i = 0; i < ccuVars.size(); i++) { - addSynapse << "group->" << ccuVars[i].name << "[newIdx] = $(" << (1 + i) << ");" << std::endl; + addSynapse << "$(_" << ccuVars[i].name << ")[newIdx] = $(" << (1 + i) << ");" << std::endl; addSynapseTypes.push_back(ccuVars[i].type.resolve(getTypeContext())); } @@ -296,12 +344,12 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back addSynapse << "for(int b = 0; b < " << 
modelMerged.getModel().getBatchSize() << "; b++)"; { CodeStream::Scope b(addSynapse); - addSynapse << "group->_dependentVar" << i << "[(b * $(_syn_stride)) + newIdx] = 0;" << std::endl; + addSynapse << "$(_dependent_var_" << i << ")[(b * $(_syn_stride)) + newIdx] = 0;" << std::endl; } } // Otherwise, zero var reference else { - addSynapse << "group->_dependentVar" << i << "[newIdx] = 0;" << std::endl; + addSynapse << "$(_dependent_var_" << i << ")[newIdx] = 0;" << std::endl; } addSynapseTypes.push_back(dependentVars.at(i).getVar().type.resolve(getTypeContext())); @@ -333,7 +381,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back } // Add function substitution with parameters to initialise custom connectivity update variables and variable references - updateEnv.add(Type::ResolvedType::createFunction(Type::Void, addSynapseTypes), "add_synapse", LazyString(addSynapseStream.str(), updateEnv)); + updateEnv.add(Type::ResolvedType::createFunction(Type::Void, addSynapseTypes), "add_synapse", addSynapseStream.str()); // Generate code to remove a synapse from this row std::stringstream removeSynapseStream; @@ -349,7 +397,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Copy custom connectivity update variables from end of row over synapse to be deleted for (size_t i = 0; i < ccuVars.size(); i++) { - removeSynapse << "group->" << ccuVars[i].name << "[idx] = group->" << ccuVars[i].name << "[lastIdx];" << std::endl; + removeSynapse << "$(_" << ccuVars[i].name << ")[idx] = $(_" << ccuVars[i].name << ")[lastIdx];" << std::endl; } // Loop through variable references @@ -362,13 +410,13 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back removeSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { CodeStream::Scope b(addSynapse); - removeSynapse << "group->" << ccuVarRefs[i].name << "[(b * $(_syn_stride)) + idx] = "; - removeSynapse << "group->" << ccuVarRefs[i].name << "[(b * $(_syn_stride)) + lastIdx];" << std::endl; + removeSynapse << "$(_" << ccuVarRefs[i].name << ")[(b * $(_syn_stride)) + idx] = "; + removeSynapse << "$(_" << ccuVarRefs[i].name << ")[(b * $(_syn_stride)) + lastIdx];" << std::endl; } } // Otherwise, copy custom connectivity update variable references from end of row over synapse to be deleted else { - removeSynapse << "group->" << ccuVarRefs[i].name << "[idx] = group->" << ccuVarRefs[i].name << "[lastIdx];" << std::endl; + removeSynapse << "$(_" << ccuVarRefs[i].name << ")[idx] = $(_" << ccuVarRefs[i].name << ")[lastIdx];" << std::endl; } } @@ -380,13 +428,13 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back removeSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { CodeStream::Scope b(removeSynapse); - removeSynapse << "group->_dependentVar" << i << "[(b * $(_syn_stride)) + idx] = "; - removeSynapse << "group->_dependentVar" << i << "[(b * $(_syn_stride)) + lastIdx];" << std::endl; + removeSynapse << "$(_dependent_var_" << i << ")[(b * $(_syn_stride)) + idx] = "; + removeSynapse << "$(_dependent_var_" << i << ")[(b * $(_syn_stride)) + lastIdx];" << std::endl; } } // Otherwise, copy dependent variable from end of row over synapse to be deleted else { - removeSynapse << "group->_dependentVar" << i << "[idx] = group->_dependentVar" << i << "[lastIdx];" << std::endl; + removeSynapse << "$(_dependent_var_" << i << ")[idx] = $(_dependent_var_" << i << ")[lastIdx];" << std::endl; } } @@ 
-399,18 +447,24 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back } // Add function substitution with parameters to initialise custom connectivity update variables and variable references - updateEnv.add(Type::ResolvedType::createFunction(Type::Void, {}), "remove_synapse", LazyString{updateEnv, removeSynapseStream.str()}); + updateEnv.add(Type::ResolvedType::createFunction(Type::Void, {}), "remove_synapse", removeSynapseStream.str()); // Pretty print code back to environment Transpiler::ErrorHandler errorHandler("Current source injection" + std::to_string(getIndex())); prettyPrintStatements(cm->getRowUpdateCode(), getTypeContext(), updateEnv, errorHandler, // Within for_each_synapse loops, define the following types - [](auto &env, auto &errorHandler) + [this](auto &env, auto &errorHandler) { + // Add typed indices env.define(Transpiler::Token{Transpiler::Token::Type::IDENTIFIER, "id_post", 0}, Type::Uint32.addConst(), errorHandler); env.define(Transpiler::Token{Transpiler::Token::Type::IDENTIFIER, "id_syn", 0}, Type::Uint32.addConst(), errorHandler); - // **TODO** variable types + // Add types for variables and variable references accessible within loop + // **TODO** filter + addTypes(env, getArchetype().getCustomConnectivityUpdateModel()->getVars(), getTypeContext(), errorHandler); + addTypes(env, getArchetype().getCustomConnectivityUpdateModel()->getPostVars(), getTypeContext(), errorHandler); + addTypes(env, getArchetype().getCustomConnectivityUpdateModel()->getVarRefs(), getTypeContext(), errorHandler); + addTypes(env, getArchetype().getCustomConnectivityUpdateModel()->getPostVarRefs(), getTypeContext(), errorHandler); }, [&backend, &modelMerged, this](auto &env, auto generateBody) { @@ -420,58 +474,50 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back CodeStream::Scope b(bodyEnv.getStream()); // Add postsynaptic and synaptic indices - bodyEnv.add(Type::Uint32.addConst(), "id_post", LazyString::print("$(_ind)[$(_row_start_idx) + j]", bodyEnv); + bodyEnv.add(Type::Uint32.addConst(), "id_post", "$(_ind)[$(_row_start_idx) + j]"); bodyEnv.add(Type::Uint32.addConst(), "id_syn", "idx", - {bodyEnv.addInitialiser("const unsigned int idx = $(_row_start_idx) + j;", bodyEnv)}); + {bodyEnv.addInitialiser("const unsigned int idx = $(_row_start_idx) + j;")}); // Add postsynaptic and synaptic variables - bodyEnv.addVars(backend.getDeviceVarPrefix(), LazyString{bodyEnv, "id_syn"}, ""); - bodyEnv.addVars(backend.getDeviceVarPrefix(), LazyString{bodyEnv, "id_post"}, ""); - - // Add postsynaptic and synaptic var references - // **TODO** - bodyEnv.addVarRefs(backend.getDeviceVarPrefix(), - [modelMerged, this](const std::string &ma, const Models::VarReference &v) - { - return (modelMerged.getModel().getBatchSize() == 1) || !v.isDuplicated(); - }); - bodyEnv.addVarRefs(backend.getDeviceVarPrefix(), - [modelMerged, this](const std::string &ma, const Models::WUVarReference &v) - { - return (modelMerged.getModel().getBatchSize() == 1) || !v.isDuplicated(); - }); - // Substitute in variable references, filtering out those which are duplicated - const auto &variableRefs = getArchetype().getVarReferences(); - updateSubs.addVarNameSubstitution(cm->getVarRefs(), "", "group->", - [&updateSubs](VarAccessMode, const std::string&) { return "[" + updateSubs["id_syn"] + "]"; }, - [modelBatched, &variableRefs](VarAccessMode, const std::string &name) - { - return !modelBatched || !variableRefs.at(name).isDuplicated(); - }); + 
bodyEnv.addVars(backend.getDeviceVarPrefix(), "$(id_syn)"); + bodyEnv.addVars(backend.getDeviceVarPrefix(), "$(id_post)"); + + // Add postsynaptic and synapse variable references, only exposing those that aren't batched + addPrivateVarRefAccess(bodyEnv, *this, modelMerged.getModel().getBatchSize(), "$(id_syn)"); + addPrivateVarRefAccess( + bodyEnv, *this, modelMerged.getModel().getBatchSize(), + [](VarAccessMode a, const Models::VarReference &varRef) + { + if(varRef.getDelayNeuronGroup() != nullptr) { + return "$(_post_delay_offset) + $(id_post)"; + } + else { + return "$(id_post)"; + } + }); - // Substitute in (potentially delayed) postsynaptic variable references - const auto &postVariableRefs = getArchetype().getPostVarReferences(); - updateSubs.addVarNameSubstitution(cm->getPostVarRefs(), "", "group->", - [&postVariableRefs, &updateSubs](VarAccessMode, const std::string &name) - { - if(postVariableRefs.at(name).getDelayNeuronGroup() != nullptr) { - return "[postDelayOffset + " + updateSubs["id_post"] + "]"; - } - else { - return "[" + updateSubs["id_post"] + "]"; - } - }); - generateBody(bodyEnv); + // Generate body of for_each_synapse loop within this new environment + generateBody(bodyEnv); } }); } +//---------------------------------------------------------------------------- +bool CustomConnectivityUpdateGroupMerged::isParamHeterogeneous(const std::string &name) const +{ + return isParamValueHeterogeneous(name, [](const CustomConnectivityUpdateInternal &cg) { return cg.getParams(); }); +} +//---------------------------------------------------------------------------- +bool CustomConnectivityUpdateGroupMerged::isDerivedParamHeterogeneous(const std::string &name) const +{ + return isParamValueHeterogeneous(name, [](const CustomConnectivityUpdateInternal &cg) { return cg.getDerivedParams(); }); +} //---------------------------------------------------------------------------- // CodeGenerator::CustomConnectivityHostUpdateGroupMerged //---------------------------------------------------------------------------- const std::string CustomConnectivityHostUpdateGroupMerged::name = "CustomConnectivityHostUpdate"; //---------------------------------------------------------------------------- -CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const BackendBase &backend, +/*CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> &groups) : CustomConnectivityUpdateGroupMergedBase(index, typeContext, groups) { @@ -499,9 +545,9 @@ CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged } } -} +}*/ //---------------------------------------------------------------------------- -void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase &backend, CodeStream &os) const +void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const { CodeStream::Scope b(os); os << "// merged custom connectivity host update group " << getIndex() << std::endl; @@ -564,6 +610,16 @@ void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase & } } //---------------------------------------------------------------------------- +bool CustomConnectivityHostUpdateGroupMerged::isParamHeterogeneous(const std::string &name) const +{ + return isParamValueHeterogeneous(name, [](const 
CustomConnectivityUpdateInternal &cg) { return cg.getParams(); }); +} +//---------------------------------------------------------------------------- +bool CustomConnectivityHostUpdateGroupMerged::isDerivedParamHeterogeneous(const std::string &name) const +{ + return isParamValueHeterogeneous(name, [](const CustomConnectivityUpdateInternal &cg) { return cg.getDerivedParams(); }); +} +//---------------------------------------------------------------------------- void CustomConnectivityHostUpdateGroupMerged::addVarPushPullFuncSubs(const BackendBase &backend, Substitutions &subs, const Models::Base::VarVec &vars, const std::string &count, VarLocation(CustomConnectivityUpdateInternal:: *getVarLocationFn)(const std::string&) const) const diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index 8897cf8114..a9fe80ae7d 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -413,8 +413,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor m_Environment = environment; m_ForEachSynapseHandler(m_Environment, - [this, &forEachSynapseStatement]() + [this, &forEachSynapseStatement](EnvironmentBase &env) { + m_Environment = env; forEachSynapseStatement.getBody()->accept(*this); }); // Restore old environment From f8e2896c26df6eb73b55186b4fc366d51b7d726b Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 27 Jun 2023 17:27:27 +0100 Subject: [PATCH 260/725] fixed issue with string index overrides in environment and got custom connectivity update compiling --- .../customConnectivityUpdateGroupMerged.h | 139 +++++++- .../genn/genn/code_generator/environment.h | 24 +- .../genn/customConnectivityUpdateInternal.h | 8 + include/genn/genn/customUpdateInternal.h | 4 + .../customConnectivityUpdateGroupMerged.cc | 312 +++++------------- 5 files changed, 239 insertions(+), 248 deletions(-) diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index a3e8c86910..035c794e4d 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -32,7 +32,7 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public GroupMerged + void addPrivateVarPointerFields(EnvironmentGroupMergedField &env, const std::string &arrayPrefix) + { + // Loop through variables and add private pointer field + const A archetypeAdaptor(getArchetype()); + for(const auto &v : archetypeAdaptor.getDefs()) { + const auto resolvedType = v.type.resolve(getTypeContext()); + env.addField(resolvedType.createPointer(), "_" + v.name, v.name, + [arrayPrefix, v](const auto &g, size_t) + { + return arrayPrefix + v.name + A(g).getNameSuffix(); + }); + } + } + + template + void addPrivateVarRefPointerFields(EnvironmentGroupMergedField &env, const std::string &arrayPrefix) + { + // Loop through variable references and add private pointer field + const A archetypeAdaptor(getArchetype()); + for(const auto &v : archetypeAdaptor.getDefs()) { + const auto resolvedType = v.type.resolve(getTypeContext()); + env.addField(resolvedType.createPointer(), "_" + v.name, v.name, + [arrayPrefix, v](const auto &g, size_t) + { + const auto varRef = A(g).getInitialisers().at(v.name); + return arrayPrefix + varRef.getVar().name + varRef.getTargetName(); + }); + } + } + + template + void addPrivateVarRefAccess(EnvironmentGroupMergedField &env, unsigned int 
batchSize, + std::function getIndexFn) + { + // Loop through variable references + const A archetypeAdaptor(getArchetype()); + for(const auto &v : archetypeAdaptor.getDefs()) { + // If model isn't batched or variable isn't duplicated + const auto &varRef = archetypeAdaptor.getInitialisers().at(v.name); + if(batchSize == 1 || !varRef.isDuplicated()) { + // Add field with qualified type which indexes private pointer field + const auto resolvedType = v.type.resolve(getTypeContext()); + const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; + env.add(qualifiedType, v.name, "$(_" + v.name + ")[" + getIndexFn(v.access, varRef) + "]"); + } + } + } + + template + void addPrivateVarRefAccess(EnvironmentGroupMergedField &env, unsigned int batchSize, const std::string &indexSuffix) + { + addPrivateVarRefAccess(env, batchSize, [&indexSuffix](VarAccessMode, const typename A::RefType&){ return indexSuffix; }); + } + + template + void addPrivateVarAccess(EnvironmentGroupMergedField &env, unsigned int batchSize, const std::string &indexSuffix) + { + // Loop through variable references + const A archetypeAdaptor(getArchetype()); + for(const auto &v : archetypeAdaptor.getDefs()) { + // Add field with qualified type which indexes private pointer field + const auto resolvedType = v.type.resolve(getTypeContext()); + const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; + env.add(qualifiedType, v.name, "$(_" + v.name + ")[" + indexSuffix + "]"); + } + } + + template + void addTypes(GeNN::Transpiler::TypeChecker::EnvironmentBase &env, const std::vector &vars, + GeNN::Transpiler::ErrorHandlerBase &errorHandler) + { + // Loop through variables + for(const auto &v : vars) { + const auto resolvedType = v.type.resolve(getTypeContext()); + const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? 
resolvedType.addConst() : resolvedType; + env.define(Transpiler::Token{Transpiler::Token::Type::IDENTIFIER, v.name, 0}, qualifiedType, errorHandler); + } + } + //---------------------------------------------------------------------------- // Members //---------------------------------------------------------------------------- @@ -78,7 +158,7 @@ class GENN_EXPORT CustomConnectivityHostUpdateGroupMerged : public GroupMerged + void addVars(EnvironmentGroupMergedField &env, const std::string &count, const BackendBase &backend) + { + // Loop through variables and add private pointer field + const A archetypeAdaptor(getArchetype()); + for(const auto &v : archetypeAdaptor.getDefs()) { + // If var is located on the host + const auto loc = archetypeAdaptor.getLoc(v.name); + if (loc & VarLocation::HOST) { + // Add pointer field to allow user code to access + const auto resolvedType = v.type.resolve(getTypeContext()); + env.addField(resolvedType.createPointer(), v.name, v.name, + [v](const auto &g, size_t) + { + return v.name + g.getName(); + }, + "", GroupMergedFieldType::HOST); + + // If backend has device variables, also add hidden pointer field with device pointer + if(!backend.getDeviceVarPrefix().empty()) { + env.addField(resolvedType.createPointer(), "_" + backend.getDeviceVarPrefix() + v.name, backend.getDeviceVarPrefix() + v.name, + [v, &backend](const auto &g, size_t) + { + return backend.getDeviceVarPrefix() + v.name + g.getName(); + }); + } + + // Generate code to push this variable + std::stringstream pushStream; + CodeStream push(pushStream); + backend.genVariableDynamicPush(push, resolvedType, v.name, + loc, count, "group->"); + + // Add substitution + env.add(Type::ResolvedType::createFunction(Type::Void, {}), + "push" + v.name + "ToDevice", pushStream.str()); + + // Generate code to pull this variable + std::stringstream pullStream; + CodeStream pull(pullStream); + backend.genVariableDynamicPull(pull, resolvedType, v.name, + loc, count, "group->"); + + // Add substitution + env.add(Type::ResolvedType::createFunction(Type::Void, {}), + "pull" + v.name + "FromDevice", pullStream.str()); + } + } + } }; } // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index b3cfe681b6..95e83f642c 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -358,14 +358,16 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &(GroupInternal::*)(void) const; + using GetVarIndexFn = std::function; + + template + using GetVarRefIndexFn = std::function; template using GetConnectivityFn = const Snippet::Init &(GroupInternal::*)(void) const; - template using GetVarReferencesFn = const std::unordered_map &(GroupInternal::*)(void) const; @@ -567,8 +569,8 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase - void addVars(const std::string &arrayPrefix, I getIndexFn, const std::string &fieldSuffix = "") + template + void addVars(const std::string &arrayPrefix, GetVarIndexFn getIndexFn, const std::string &fieldSuffix = "") { // Loop through variables const A archetypeAdaptor(getGroup().getArchetype()); @@ -577,11 +579,11 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase - void addVarRefs(const std::string &arrayPrefix, I getIndexFn, const std::string &fieldSuffix = "") + template + void addVarRefs(const std::string &arrayPrefix, GetVarRefIndexFn getIndexFn, const std::string &fieldSuffix = "") 
{ // Loop through variable references const A archetypeAdaptor(getGroup().getArchetype()); @@ -608,8 +610,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase>::mapped_type; - using GetIndexFn = std::function; + using GetIndexFn = std::function; VarRefCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex) diff --git a/include/genn/genn/customConnectivityUpdateInternal.h b/include/genn/genn/customConnectivityUpdateInternal.h index 3c7aacaea1..4c6a7f6ba5 100644 --- a/include/genn/genn/customConnectivityUpdateInternal.h +++ b/include/genn/genn/customConnectivityUpdateInternal.h @@ -59,6 +59,8 @@ class CustomConnectivityUpdateVarAdapter const std::unordered_map &getInitialisers() const{ return m_CU.getVarInitialisers(); } + const std::string &getNameSuffix() const{ return m_CU.getName(); } + private: //---------------------------------------------------------------------------- // Members @@ -157,6 +159,8 @@ class CustomConnectivityUpdateVarRefAdapter CustomConnectivityUpdateVarRefAdapter(const CustomConnectivityUpdateInternal &cu) : m_CU(cu) {} + using RefType = Models::WUVarReference; + //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- @@ -180,6 +184,8 @@ class CustomConnectivityUpdatePreVarRefAdapter CustomConnectivityUpdatePreVarRefAdapter(const CustomConnectivityUpdateInternal &cu) : m_CU(cu) {} + using RefType = Models::VarReference; + //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- @@ -203,6 +209,8 @@ class CustomConnectivityUpdatePostVarRefAdapter CustomConnectivityUpdatePostVarRefAdapter(const CustomConnectivityUpdateInternal &cu) : m_CU(cu) {} + using RefType = Models::VarReference; + //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- diff --git a/include/genn/genn/customUpdateInternal.h b/include/genn/genn/customUpdateInternal.h index 81aa4921df..af3a001867 100644 --- a/include/genn/genn/customUpdateInternal.h +++ b/include/genn/genn/customUpdateInternal.h @@ -45,6 +45,8 @@ class CustomUpdateVarRefAdapter CustomUpdateVarRefAdapter(const CustomUpdateInternal &cu) : m_CU(cu) {} + using RefType = Models::VarReference; + //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- @@ -101,6 +103,8 @@ class CustomUpdateWUVarRefAdapter CustomUpdateWUVarRefAdapter(const CustomUpdateWUInternal &cu) : m_CU(cu) {} + using RefType = Models::WUVarReference; + //---------------------------------------------------------------------------- // Public methods //---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 1a76f9ebf4..6a9fb77582 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -12,90 +12,6 @@ using namespace GeNN; using namespace GeNN::CodeGenerator; -//---------------------------------------------------------------------------- -// Anonymous namespace 
-//---------------------------------------------------------------------------- -namespace -{ -template -void addPrivateVarPointerFields(EnvironmentGroupMergedField &env, const std::string &arrayPrefix, const G &group) -{ - // Loop through variable references and add private pointer field - const A archetypeAdaptor(group.getArchetype()); - for(const auto &v : archetypeAdaptor.getDefs()) { - const auto resolvedType = v.type.resolve(getGroup().getTypeContext()); - env.addField(resolvedType.createPointer(), "_" + v.name, v.name + fieldSuffix, - [arrayPrefix, v](const auto &g, size_t) - { - return arrayPrefix + v.name + A(g).getNameSuffix(); - }); - } -} -//---------------------------------------------------------------------------- -template -void addPrivateVarRefPointerFields(EnvironmentGroupMergedField &env, const std::string &arrayPrefix, const G &group) -{ - // Loop through variable references and add private pointer field - const A archetypeAdaptor(group.getArchetype()); - for(const auto &v : archetypeAdaptor.getDefs()) { - const auto resolvedType = v.type.resolve(getGroup().getTypeContext()); - env.addField(resolvedType.createPointer(), "_" + v.name, v.name + fieldSuffix, - [arrayPrefix, v](const auto &g, size_t) - { - const auto varRef = A(g).getInitialisers().at(v.name); - return arrayPrefix + varRef.getVar().name + varRef.getTargetName(); - }); - } -} -//---------------------------------------------------------------------------- -template -void addPrivateVarRefAccess(EnvironmentGroupMergedField &env, const G &group, unsigned int batchSize, I getIndexFn) -{ - // Loop through variable references - const A archetypeAdaptor(group.getArchetype()); - for(const auto &v : archetypeAdaptor.getDefs()) { - // If model isn't batched or variable isn't duplicated - const auto &varRef = archetypeAdaptor.getInitialisers().at(v.name); - if(batchSize == 1 || !varRef.isDuplicated()) { - // Add field with qualified type which indexes private pointer field - const auto resolvedType = v.type.resolve(getGroup().getTypeContext()); - const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; - env.add(qualifiedType, v.name, "$(_" + v.name + ")[" + getIndexFn(v.access, varRef)); - } - } -} -//---------------------------------------------------------------------------- -template -void addPrivateVarRefAccess(EnvironmentGroupMergedField &env, const G &group, unsigned int batchSize, const std::string &indexSuffix) -{ - addPrivateVarRefPointerFields(env, group, batchSize, [&indexSuffix](){ return indexSuffix; }); -} -//---------------------------------------------------------------------------- -template -void addPrivateVarAccess(EnvironmentGroupMergedField &env, const G &group, unsigned int batchSize, const std::string &indexSuffix) -{ - // Loop through variable references - const A archetypeAdaptor(group.getArchetype()); - for(const auto &v : archetypeAdaptor.getDefs()) { - // Add field with qualified type which indexes private pointer field - const auto resolvedType = v.type.resolve(getGroup().getTypeContext()); - const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? 
resolvedType.addConst() : resolvedType; - env.add(qualifiedType, v.name, "$(_" + v.name + ")[" + getIndexFn(v.access, varRef)); - } -} -//---------------------------------------------------------------------------- -template -void addTypes(Transpiler::TypeChecker::EnvironmentBase &env, const std::vector &vars, - const Type::TypeContext &typeContext, Transpiler::ErrorHandler::ErrorHandlerBase &errorHandle) -{ - for(const auto &v : vars) { - const auto resolvedType = v.type.resolve(typeContext); - const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; - env.define(Transpiler::Token{Transpiler::Token::Type::IDENTIFIER, v.name, 0}, qualifiedType, errorHandler); - } -} -} // Anonymous namespace - //---------------------------------------------------------------------------- // CodeGenerator::CustomConnectivityUpdateGroupMerged //---------------------------------------------------------------------------- @@ -224,7 +140,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdateGroupMerged::get return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const +void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { // Create new environment to add custom connectivity update fields to update group EnvironmentGroupMergedField updateEnv(env, *this); @@ -289,10 +205,10 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Add fields and private $(_XXX) substitutions for postsynaptic and synaptic variables and variable references as, // while these can only be accessed by user code inside loop, they can be used directly by add/remove synapse functions - addPrivateVarPointerFields(updateEnv, backend.getDeviceVarPrefix(), *this); - addPrivateVarPointerFields(updateEnv, backend.getDeviceVarPrefix(), *this); - addPrivateVarRefPointerFields(updateEnv, backend.getDeviceVarPrefix(), *this); - addPrivateVarRefPointerFields(updateEnv, backend.getDeviceVarPrefix(), *this); + addPrivateVarPointerFields(updateEnv, backend.getDeviceVarPrefix()); + addPrivateVarPointerFields(updateEnv, backend.getDeviceVarPrefix()); + addPrivateVarRefPointerFields(updateEnv, backend.getDeviceVarPrefix()); + addPrivateVarRefPointerFields(updateEnv, backend.getDeviceVarPrefix()); // Add private fields for dependent variables for(size_t i = 0; i < getSortedArchetypeDependentVars().size(); i++) { @@ -393,11 +309,11 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back removeSynapse << "const unsigned lastIdx = $(_row_start_idx) + $(_row_length)[$(id_pre)] - 1;" << std::endl; // Copy postsynaptic target from end of row over synapse to be deleted - removeSynapse << "$(_ind)[idx] = $(_ind)[lastIdx];" << std::endl; + removeSynapse << "$(_ind)[$(id_syn)] = $(_ind)[lastIdx];" << std::endl; // Copy custom connectivity update variables from end of row over synapse to be deleted for (size_t i = 0; i < ccuVars.size(); i++) { - removeSynapse << "$(_" << ccuVars[i].name << ")[idx] = $(_" << ccuVars[i].name << ")[lastIdx];" << std::endl; + removeSynapse << "$(_" << ccuVars[i].name << ")[$(id_syn)] = $(_" << ccuVars[i].name << ")[lastIdx];" << std::endl; } // Loop through variable references @@ -410,13 +326,13 @@ void 
CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back removeSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { CodeStream::Scope b(addSynapse); - removeSynapse << "$(_" << ccuVarRefs[i].name << ")[(b * $(_syn_stride)) + idx] = "; - removeSynapse << "$(_" << ccuVarRefs[i].name << ")[(b * $(_syn_stride)) + lastIdx];" << std::endl; + removeSynapse << "$(_" << ccuVarRefs[i].name << ")[(b * $(_syn_stride)) + $(id_syn)] = "; + removeSynapse << "$(_" << ccuVarRefs[i].name << ")[(b * $(_syn_stride)) + lastIdx];" << std::endl; } } // Otherwise, copy custom connectivity update variable references from end of row over synapse to be deleted else { - removeSynapse << "$(_" << ccuVarRefs[i].name << ")[idx] = $(_" << ccuVarRefs[i].name << ")[lastIdx];" << std::endl; + removeSynapse << "$(_" << ccuVarRefs[i].name << ")[$(id_syn)] = $(_" << ccuVarRefs[i].name << ")[lastIdx];" << std::endl; } } @@ -428,13 +344,13 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back removeSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { CodeStream::Scope b(removeSynapse); - removeSynapse << "$(_dependent_var_" << i << ")[(b * $(_syn_stride)) + idx] = "; + removeSynapse << "$(_dependent_var_" << i << ")[(b * $(_syn_stride)) + $(id_syn)] = "; removeSynapse << "$(_dependent_var_" << i << ")[(b * $(_syn_stride)) + lastIdx];" << std::endl; } } // Otherwise, copy dependent variable from end of row over synapse to be deleted else { - removeSynapse << "$(_dependent_var_" << i << ")[idx] = $(_dependent_var_" << i << ")[lastIdx];" << std::endl; + removeSynapse << "$(_dependent_var_" << i << ")[$(id_syn)] = $(_dependent_var_" << i << ")[lastIdx];" << std::endl; } } @@ -446,27 +362,27 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back removeSynapse << "j--;" << std::endl; } - // Add function substitution with parameters to initialise custom connectivity update variables and variable references - updateEnv.add(Type::ResolvedType::createFunction(Type::Void, {}), "remove_synapse", removeSynapseStream.str()); - // Pretty print code back to environment - Transpiler::ErrorHandler errorHandler("Current source injection" + std::to_string(getIndex())); + Transpiler::ErrorHandler errorHandler("Custom connectivity update" + std::to_string(getIndex())); prettyPrintStatements(cm->getRowUpdateCode(), getTypeContext(), updateEnv, errorHandler, // Within for_each_synapse loops, define the following types [this](auto &env, auto &errorHandler) { + // Add type of remove synapse function + env.define(Transpiler::Token{Transpiler::Token::Type::IDENTIFIER, "remove_synapse", 0}, Type::ResolvedType::createFunction(Type::Void, {}), errorHandler); + // Add typed indices env.define(Transpiler::Token{Transpiler::Token::Type::IDENTIFIER, "id_post", 0}, Type::Uint32.addConst(), errorHandler); env.define(Transpiler::Token{Transpiler::Token::Type::IDENTIFIER, "id_syn", 0}, Type::Uint32.addConst(), errorHandler); // Add types for variables and variable references accessible within loop // **TODO** filter - addTypes(env, getArchetype().getCustomConnectivityUpdateModel()->getVars(), getTypeContext(), errorHandler); - addTypes(env, getArchetype().getCustomConnectivityUpdateModel()->getPostVars(), getTypeContext(), errorHandler); - addTypes(env, getArchetype().getCustomConnectivityUpdateModel()->getVarRefs(), getTypeContext(), errorHandler); - addTypes(env, 
getArchetype().getCustomConnectivityUpdateModel()->getPostVarRefs(), getTypeContext(), errorHandler); + addTypes(env, getArchetype().getCustomConnectivityUpdateModel()->getVars(), errorHandler); + addTypes(env, getArchetype().getCustomConnectivityUpdateModel()->getPostVars(), errorHandler); + addTypes(env, getArchetype().getCustomConnectivityUpdateModel()->getVarRefs(), errorHandler); + addTypes(env, getArchetype().getCustomConnectivityUpdateModel()->getPostVarRefs(), errorHandler); }, - [&backend, &modelMerged, this](auto &env, auto generateBody) + [&backend, &modelMerged, &removeSynapseStream, this](auto &env, auto generateBody) { EnvironmentGroupMergedField bodyEnv(env, *this); bodyEnv.getStream() << printSubs("for(int j = 0; j < $(_row_length)[$(id_pre)]; j++)", bodyEnv); @@ -483,9 +399,9 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back bodyEnv.addVars(backend.getDeviceVarPrefix(), "$(id_post)"); // Add postsynaptic and synapse variable references, only exposing those that aren't batched - addPrivateVarRefAccess(bodyEnv, *this, modelMerged.getModel().getBatchSize(), "$(id_syn)"); + addPrivateVarRefAccess(bodyEnv, modelMerged.getModel().getBatchSize(), "$(id_syn)"); addPrivateVarRefAccess( - bodyEnv, *this, modelMerged.getModel().getBatchSize(), + bodyEnv, modelMerged.getModel().getBatchSize(), [](VarAccessMode a, const Models::VarReference &varRef) { if(varRef.getDelayNeuronGroup() != nullptr) { @@ -494,7 +410,10 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back else { return "$(id_post)"; } - }); + }); + + // Add function substitution with parameters to initialise custom connectivity update variables and variable references + bodyEnv.add(Type::ResolvedType::createFunction(Type::Void, {}), "remove_synapse", removeSynapseStream.str()); // Generate body of for_each_synapse loop within this new environment generateBody(bodyEnv); @@ -517,66 +436,62 @@ bool CustomConnectivityUpdateGroupMerged::isDerivedParamHeterogeneous(const std: //---------------------------------------------------------------------------- const std::string CustomConnectivityHostUpdateGroupMerged::name = "CustomConnectivityHostUpdate"; //---------------------------------------------------------------------------- -/*CustomConnectivityHostUpdateGroupMerged::CustomConnectivityHostUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, - const std::vector> &groups) -: CustomConnectivityUpdateGroupMergedBase(index, typeContext, groups) +void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { - using namespace Type; + CodeStream::Scope b(env.getStream()); - // Add pre and postsynaptic variables - const auto *cm = getArchetype().getCustomConnectivityUpdateModel(); - addVars(backend, cm->getPreVars(), &CustomConnectivityUpdateInternal::getPreVarLocation); - addVars(backend, cm->getPostVars(), &CustomConnectivityUpdateInternal::getPostVarLocation); - - // Add host extra global parameters - for(const auto &e : cm->getExtraGlobalParams()) { - const auto resolvedType = e.type.resolve(getTypeContext()); - addField(resolvedType.createPointer(), e.name, - [e](const auto &g, size_t) { return e.name + g.getName(); }, - GroupMergedFieldType::HOST_DYNAMIC); - - if(!backend.getDeviceVarPrefix().empty()) { - addField(resolvedType.createPointer(), backend.getDeviceVarPrefix() + e.name, - [e, &backend](const auto &g, size_t) - { - return 
backend.getDeviceVarPrefix() + e.name + g.getName(); - }, - GroupMergedFieldType::DYNAMIC); - } - } - -}*/ -//---------------------------------------------------------------------------- -void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) const -{ - CodeStream::Scope b(os); - os << "// merged custom connectivity host update group " << getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << getGroups().size() << "; g++)"; + env.getStream() << "// merged custom connectivity host update group " << getIndex() << std::endl; + env.getStream() << "for(unsigned int g = 0; g < " << getGroups().size() << "; g++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); // Get reference to group - os << "const auto *group = &mergedCustomConnectivityHostUpdateGroup" << getIndex() << "[g]; " << std::endl; + env.getStream() << "const auto *group = &mergedCustomConnectivityHostUpdateGroup" << getIndex() << "[g]; " << std::endl; + + // Create matching environment + EnvironmentGroupMergedField groupEnv(env, *this); + + // Add fields for number of pre and postsynaptic neurons + groupEnv.addField(Type::Uint32.addConst(), "num_pre", + Type::Uint32, "numSrcNeurons", + [](const auto &cg, size_t) + { + const SynapseGroupInternal *sgInternal = static_cast<const SynapseGroupInternal*>(cg.getSynapseGroup()); + return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons()); + }); + groupEnv.addField(Type::Uint32.addConst(), "num_post", + Type::Uint32, "numTrgNeurons", + [](const auto &cg, size_t) + { + const SynapseGroupInternal *sgInternal = static_cast<const SynapseGroupInternal*>(cg.getSynapseGroup()); + return std::to_string(sgInternal->getTrgNeuronGroup()->getNumNeurons()); + }); - // Create substitutions + + // Substitute parameter and derived parameter names const auto *cm = getArchetype().getCustomConnectivityUpdateModel(); - Substitutions subs; - subs.addVarSubstitution("rng", "hostRNG"); - subs.addVarSubstitution("num_pre", "group->numSrcNeurons"); - subs.addVarSubstitution("num_post", "group->numTrgNeurons"); - subs.addVarNameSubstitution(cm->getExtraGlobalParams(), "", "group->"); - subs.addVarNameSubstitution(cm->getPreVars(), "", "group->"); - subs.addVarNameSubstitution(cm->getPostVars(), "", "group->"); - subs.addParamValueSubstitution(cm->getParamNames(), getArchetype().getParams(), - [this](const std::string &p) { return isParamHeterogeneous(p); }, - "", "group->"); - subs.addVarValueSubstitution(cm->getDerivedParams(), getArchetype().getDerivedParams(), - [this](const std::string & p) { return isDerivedParamHeterogeneous(p); }, - "", "group->"); + groupEnv.addParams(cm->getParamNames(), "", &CustomConnectivityUpdateInternal::getParams, + &CustomConnectivityHostUpdateGroupMerged::isParamHeterogeneous); + groupEnv.addDerivedParams(cm->getDerivedParams(), "", &CustomConnectivityUpdateInternal::getDerivedParams, + &CustomConnectivityHostUpdateGroupMerged::isDerivedParamHeterogeneous); // Loop through EGPs for(const auto &egp : cm->getExtraGlobalParams()) { + // Add pointer field to allow user code to access const auto resolvedType = egp.type.resolve(getTypeContext()); + groupEnv.addField(resolvedType.createPointer(), egp.name, egp.name, + [egp](const auto &g, size_t) { return egp.name + g.getName(); }, + "", GroupMergedFieldType::HOST_DYNAMIC); + + // If backend has device variables, also add hidden pointer field with device pointer + if(!backend.getDeviceVarPrefix().empty()) { + 
groupEnv.addField(resolvedType.createPointer(), "_" + backend.getDeviceVarPrefix() + egp.name, backend.getDeviceVarPrefix() + egp.name, + [egp, &backend](const auto &g, size_t) + { + return backend.getDeviceVarPrefix() + egp.name + g.getName(); + }, + "", GroupMergedFieldType::DYNAMIC); + } // Generate code to push this EGP with count specified by $(0) std::stringstream pushStream; @@ -585,7 +500,8 @@ void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase & VarLocation::HOST_DEVICE, "$(0)", "group->"); // Add substitution - subs.addFuncSubstitution("push" + egp.name + "ToDevice", 1, pushStream.str()); + groupEnv.add(Type::ResolvedType::createFunction(Type::Void, {Type::Uint32}), + "push" + egp.name + "ToDevice", pushStream.str()); // Generate code to pull this EGP with count specified by $(0) std::stringstream pullStream; @@ -594,19 +510,18 @@ void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase & VarLocation::HOST_DEVICE, "$(0)", "group->"); // Add substitution - subs.addFuncSubstitution("pull" + egp.name + "FromDevice", 1, pullStream.str()); + groupEnv.add(Type::ResolvedType::createFunction(Type::Void, {Type::Uint32}), + "pull" + egp.name + "FromDevice", pullStream.str()); } - addVarPushPullFuncSubs(backend, subs, cm->getPreVars(), "group->numSrcNeurons", - &CustomConnectivityUpdateInternal::getPreVarLocation); - addVarPushPullFuncSubs(backend, subs, cm->getPostVars(), "group->numTrgNeurons", - &CustomConnectivityUpdateInternal::getPostVarLocation); + // Add pre and postsynaptic variables along with push and pull functions + // **TODO** why not pre and post var-references + addVars(groupEnv, "$(num_pre)", backend); + addVars(groupEnv, "$(num_post)", backend); - // Apply substitutons to row update code and write out - std::string code = cm->getHostUpdateCode(); - subs.applyCheckUnreplaced(code, "custom connectivity host update : merged" + std::to_string(getIndex())); - //code = ensureFtype(code, modelMerged.getModel().getPrecision()); - os << code; + // Pretty print code back to environment + Transpiler::ErrorHandler errorHandler("Custom connectivity host update" + std::to_string(getIndex())); + prettyPrintStatements(cm->getHostUpdateCode(), getTypeContext(), groupEnv, errorHandler); } } //---------------------------------------------------------------------------- @@ -618,63 +533,4 @@ bool CustomConnectivityHostUpdateGroupMerged::isParamHeterogeneous(const std::st bool CustomConnectivityHostUpdateGroupMerged::isDerivedParamHeterogeneous(const std::string &name) const { return isParamValueHeterogeneous(name, [](const CustomConnectivityUpdateInternal &cg) { return cg.getDerivedParams(); }); -} -//---------------------------------------------------------------------------- -void CustomConnectivityHostUpdateGroupMerged::addVarPushPullFuncSubs(const BackendBase &backend, Substitutions &subs, - const Models::Base::VarVec &vars, const std::string &count, - VarLocation(CustomConnectivityUpdateInternal:: *getVarLocationFn)(const std::string&) const) const -{ - // Loop through variables - for(const auto &v : vars) { - const auto resolvedType = v.type.resolve(getTypeContext()); - - // If var is located on the host - const auto loc = std::invoke(getVarLocationFn, getArchetype(), v.name); - if (loc & VarLocation::HOST) { - // Generate code to push this variable - std::stringstream pushStream; - CodeStream push(pushStream); - backend.genVariableDynamicPush(push, resolvedType, v.name, - loc, count, "group->"); - - // Add substitution - 
subs.addFuncSubstitution("push" + v.name + "ToDevice", 0, pushStream.str()); - - // Generate code to pull this variable - // **YUCK** these EGP functions should probably just be called dynamic or something - std::stringstream pullStream; - CodeStream pull(pullStream); - backend.genVariableDynamicPull(pull, resolvedType, v.name, - loc, count, "group->"); - - // Add substitution - subs.addFuncSubstitution("pull" + v.name + "FromDevice", 0, pullStream.str()); - } - } -} -//---------------------------------------------------------------------------- -void CustomConnectivityHostUpdateGroupMerged::addVars(const BackendBase &backend, const Models::Base::VarVec &vars, - VarLocation(CustomConnectivityUpdateInternal:: *getVarLocationFn)(const std::string&) const) -{ - using namespace Type; - - // Loop through variables - for(const auto &v : vars) { - // If var is located on the host - const auto resolvedType = v.type.resolve(getTypeContext()); - if (std::invoke(getVarLocationFn, getArchetype(), v.name) & VarLocation::HOST) { - addField(resolvedType.createPointer(), v.name, - [v](const auto &g, size_t) { return v.name + g.getName(); }, - GroupMergedFieldType::HOST); - - if(!backend.getDeviceVarPrefix().empty()) { - // **TODO** I think could use addPointerField - addField(resolvedType.createPointer(), backend.getDeviceVarPrefix() + v.name, - [v, &backend](const auto &g, size_t) - { - return backend.getDeviceVarPrefix() + v.name + g.getName(); - }); - } - } - } -} +} \ No newline at end of file From bb4c9343a04e06a3aecbd361167108fea5804f98 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 27 Jun 2023 18:59:34 +0100 Subject: [PATCH 261/725] handle void type --- src/genn/genn/type.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index a59a854871..4e8265ee84 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -80,6 +80,10 @@ std::string ResolvedType::getName() const description += (a.getName() + ","); } return description + ")"; + }, + [&qualifier](std::monostate) + { + return qualifier + "void"; }}, detail); } @@ -99,6 +103,10 @@ size_t ResolvedType::getSize(size_t pointerBytes) const [](const Type::ResolvedType::Function&)->size_t { throw std::runtime_error("Function types do not have size"); + }, + [](std::monostate)->size_t + { + throw std::runtime_error("Void type does not have size"); }}, detail); } From 1bb13a3c3af0e0f888fd11c3a877921dcde8aa88 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 27 Jun 2023 18:59:50 +0100 Subject: [PATCH 262/725] fix concatenation typos --- src/genn/genn/code_generator/groupMerged.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index 28a2495f9f..9f9f921599 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -166,7 +166,7 @@ std::string SynapseGroupMergedBase::getPostDenDelayIndex(unsigned int batchSize, { assert(getArchetype().isDendriticDelayRequired()); - const std::string batchID = ((batchSize == 1) ? "" : "$(_post_batch_offset) + ") + "$(" + index + ")"; + const std::string batchID = ((batchSize == 1) ? "" : "$(_post_batch_offset) + ") + std::string{"$(" + index + ")"}; if(offset.empty()) { return "(*$(_den_delay_ptr) * $(num_post) + " + batchID; @@ -194,7 +194,7 @@ std::string SynapseGroupMergedBase::getPrePrevSpikeTimeIndex(bool delay, unsigne return (singleBatch ? 
"$(_pre_prev_spike_time_delay_offset) + " : "$(_pre_prev_spike_time_batch_delay_offset) + ") + index; } else { - return (singleBatch ? "" : "$(_pre_batch_offset) + ") + "$(" + index + ")"; + return (singleBatch ? "" : "$(_pre_batch_offset) + ") + std::string{"$(" + index + ")"}; } } //-------------------------------------------------------------------------- @@ -206,20 +206,20 @@ std::string SynapseGroupMergedBase::getPostPrevSpikeTimeIndex(bool delay, unsign return (singleBatch ? "$(_post_prev_spike_time_delay_offset) + " : "$(_post_prev_spike_time_batch_delay_offset) + ") + "$(" + index + ")";; } else { - return (singleBatch ? "" : "$(_post_batch_offset) + ") + "$(" + index + ")";; + return (singleBatch ? "" : "$(_post_batch_offset) + ") + std::string{"$(" + index + ")"}; } } //-------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getSynVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const { const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); - return (singleBatch ? "" : "$(_syn_batch_offset)") + "$(" + index + ")"; + return (singleBatch ? "" : "$(_syn_batch_offset)") + std::string{"$(" + index + ")"}; } //-------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getKernelVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const { const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); - return (singleBatch ? "" : "$(_kern_batch_offset)") + "$(" + index + ")"; + return (singleBatch ? "" : "$(_kern_batch_offset)") + std::string{"$(" + index + ")"}; } //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type SynapseGroupMergedBase::getHashDigest(Role role) const From 139a9e1a9a759dccb04013a5ec2307c79eab528a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 27 Jun 2023 19:00:06 +0100 Subject: [PATCH 263/725] fully hack out hashing --- src/genn/genn/code_generator/generateModules.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/genn/genn/code_generator/generateModules.cc b/src/genn/genn/code_generator/generateModules.cc index 88d0bb370c..155ee53872 100644 --- a/src/genn/genn/code_generator/generateModules.cc +++ b/src/genn/genn/code_generator/generateModules.cc @@ -103,6 +103,11 @@ std::pair, MemAlloc> generateAll(const ModelSpecInterna // Create merged model ModelSpecMerged modelMerged(model, backend); + // **TODO** because merged group fields are populated in the same pass + // as code is generated, we will need to ALWAYS generate code but only + // write it if hashes match - this will be less gross once we're doing + // runtime compilation + // If force rebuild flag is set or model should be rebuilt //const auto hashDigest = modelMerged.getHashDigest(backend); MemAlloc mem = MemAlloc::zero(); @@ -129,7 +134,7 @@ std::pair, MemAlloc> generateAll(const ModelSpecInterna } // Open file - std::ofstream os((outputPath / "model.sha").str()); + /*std::ofstream os((outputPath / "model.sha").str()); // Write digest as hex with each word seperated by a space os << std::hex; @@ -140,7 +145,7 @@ std::pair, MemAlloc> generateAll(const ModelSpecInterna // Write model memory usage estimates so it can be reloaded if code doesn't need re-generating os << std::dec; - os << mem << std::endl; + os << mem << std::endl;*/ } // Show 
memory usage From 998f1a2c4ccbaa4d9c311f54415b29ab08beb9b5 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 27 Jun 2023 19:00:34 +0100 Subject: [PATCH 264/725] de-inlined method --- include/genn/genn/customUpdate.h | 2 +- src/genn/genn/customUpdate.cc | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index c88971c31e..69a7b50fe6 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -297,7 +297,7 @@ class GENN_EXPORT CustomUpdateWU : public CustomUpdateBase SynapseGroupInternal *getSynapseGroup() const { return m_SynapseGroup; } - const std::vector &getKernelSize() const { return getSynapseGroup()->getKernelSize(); } + const std::vector &getKernelSize() const; //! Updates hash with custom update /*! NOTE: this can only be called after model is finalized */ diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index 497fe2e60d..677dee5dbc 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -260,6 +260,11 @@ bool CustomUpdateWU::isTransposeOperation() const [](const auto &v) { return (v.second.getTransposeSynapseGroup() != nullptr); }); } //---------------------------------------------------------------------------- +const std::vector &CustomUpdateWU::getKernelSize() const +{ + return getSynapseGroup()->getKernelSize(); +} +//---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type CustomUpdateWU::getHashDigest() const { // Superclass From 48c541fb67d2de4f74a4b08ad9d88f89e1ff0f38 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 27 Jun 2023 19:00:43 +0100 Subject: [PATCH 265/725] started fixing up backend SIMT --- .../genn/genn/code_generator/backendBase.h | 5 - .../genn/genn/code_generator/backendSIMT.h | 197 +++++------------- src/genn/genn/code_generator/backendSIMT.cc | 12 +- 3 files changed, 60 insertions(+), 154 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 7a93a12612..c44433ff1a 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -187,13 +187,8 @@ class GENN_EXPORT BackendBase //-------------------------------------------------------------------------- typedef std::function HostHandler; - typedef std::function Handler; - typedef std::function HandlerEnv; - template - using GroupHandler = std::function ; - template using GroupHandlerEnv = std::function ; diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index b12d52bf65..5151e45bcc 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -122,20 +122,20 @@ class GENN_EXPORT BackendSIMT : public BackendBase //! This function returns the device prefix so it can be used in otherwise platform-independent code. virtual std::string getDeviceVarPrefix() const final { return getPreferences().automaticCopy ? 
"" : "d_"; } - virtual void genPopVariableInit(EnvironmentExternal &env, HandlerEnv handler) const final; - virtual void genVariableInit(EnvironmentExternal &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const final; - virtual void genSparseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const final + virtual void genPopVariableInit(EnvironmentExternalBase &env, HandlerEnv handler) const final; + virtual void genVariableInit(EnvironmentExternalBase &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const final; + virtual void genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const final { - genSynapseVariableRowInit(os, kernelSubs, handler); + genSynapseVariableRowInit(env, handler); } - virtual void genDenseSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const final + virtual void genDenseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const final { - genSynapseVariableRowInit(os, kernelSubs, handler); + genSynapseVariableRowInit(env, handler); } - virtual void genKernelSynapseVariableInit(CodeStream &os, const SynapseInitGroupMerged &sg, const Substitutions &kernelSubs, Handler handler) const final; - virtual void genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUpdateInitGroupMerged &cu, const Substitutions &kernelSubs, Handler handler) const final; + virtual void genKernelSynapseVariableInit(EnvironmentExternalBase &env, const SynapseInitGroupMerged &sg, HandlerEnv handler) const final; + virtual void genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, const CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const final; //! Should 'scalar' variables be implemented on device or can host variables be used directly? 
virtual bool isDeviceScalarRequired() const final { return true; } @@ -226,93 +226,6 @@ class GENN_EXPORT BackendSIMT : public BackendBase //-------------------------------------------------------------------------- // Private methods //-------------------------------------------------------------------------- - template - void genParallelGroup(CodeStream &os, const Substitutions &kernelSubs, const std::vector &groups, size_t &idStart, - S getPaddedSizeFunc, F filter, GroupHandler handler) const - { - // Loop through groups - for(const auto &gMerge : groups) { - if(filter(gMerge)) { - // Sum padded sizes of each group within merged group - const size_t paddedSize = std::accumulate( - gMerge.getGroups().cbegin(), gMerge.getGroups().cend(), size_t{0}, - [getPaddedSizeFunc](size_t acc, std::reference_wrapper g) - { - return (acc + getPaddedSizeFunc(g.get())); - }); - - os << "// merged" << gMerge.getIndex() << std::endl; - - // If this is the first group - if(idStart == 0) { - os << "if(id < " << paddedSize << ")"; - } - else { - os << "if(id >= " << idStart << " && id < " << idStart + paddedSize << ")"; - } - { - CodeStream::Scope b(os); - Substitutions popSubs(&kernelSubs); - - if(gMerge.getGroups().size() == 1) { - os << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group"; - os << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[0]; " << std::endl; - os << "const unsigned int lid = id - " << idStart << ";" << std::endl; - - // Use the starting thread ID of the whole merged group as group_start_id - popSubs.addVarSubstitution("group_start_id", std::to_string(idStart)); - } - else { - // Perform bisect operation to get index of merged struct - os << "unsigned int lo = 0;" << std::endl; - os << "unsigned int hi = " << gMerge.getGroups().size() << ";" << std::endl; - os << "while(lo < hi)" << std::endl; - { - CodeStream::Scope b(os); - os << "const unsigned int mid = (lo + hi) / 2;" << std::endl; - - os << "if(id < d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[mid])"; - { - CodeStream::Scope b(os); - os << "hi = mid;" << std::endl; - } - os << "else"; - { - CodeStream::Scope b(os); - os << "lo = mid + 1;" << std::endl; - } - } - - // Use this to get reference to merged group structure - os << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group"; - os << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[lo - 1]; " << std::endl; - - // Get group start thread ID and use as group_start_id - os << "const unsigned int groupStartID = d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[lo - 1];" << std::endl; - popSubs.addVarSubstitution("group_start_id", "groupStartID"); - - // Use this to calculate local id within group - os << "const unsigned int lid = id - groupStartID;" << std::endl; - } - popSubs.addVarSubstitution("id", "lid"); - - handler(os, gMerge, popSubs); - - idStart += paddedSize; - } - } - } - } - - - template - void genParallelGroup(CodeStream &os, const Substitutions &kernelSubs, const std::vector &groups, size_t &idStart, - S getPaddedSizeFunc, GroupHandler handler) const - { - genParallelGroup(os, kernelSubs, groups, idStart, getPaddedSizeFunc, - [](const T &) { return true; }, handler); - } - template void genParallelGroup(EnvironmentExternal &env, const std::vector &groups, size_t &idStart, S getPaddedSizeFunc, F filter, GroupHandlerEnv handler) const @@ -339,51 +252,51 @@ class GENN_EXPORT BackendSIMT : public BackendBase } { 
CodeStream::Scope b(env.getStream()); - EnvironmentSubstitute popEnv(env); if(gMerge.getGroups().size() == 1) { - popEnv.getStream() << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group"; - popEnv.getStream() << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[0]; " << std::endl; - popEnv.getStream() << "const unsigned int lid = id - " << idStart << ";" << std::endl; + env.getStream() << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group"; + env.getStream() << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[0]; " << std::endl; + env.getStream() << "const unsigned int lid = id - " << idStart << ";" << std::endl; // Use the starting thread ID of the whole merged group as group_start_id - popEnv.addSubstitution("group_start_id", std::to_string(idStart)); + env.add(Type::Uint32.addConst(), "group_start_id", std::to_string(idStart)); } else { // Perform bisect operation to get index of merged struct - popEnv.getStream() << "unsigned int lo = 0;" << std::endl; - popEnv.getStream() << "unsigned int hi = " << gMerge.getGroups().size() << ";" << std::endl; - popEnv.getStream() << "while(lo < hi)" << std::endl; + env.getStream() << "unsigned int lo = 0;" << std::endl; + env.getStream() << "unsigned int hi = " << gMerge.getGroups().size() << ";" << std::endl; + env.getStream() << "while(lo < hi)" << std::endl; { - CodeStream::Scope b(popEnv.getStream()); - popEnv.getStream() << "const unsigned int mid = (lo + hi) / 2;" << std::endl; + CodeStream::Scope b(env.getStream()); + env.getStream() << "const unsigned int mid = (lo + hi) / 2;" << std::endl; - popEnv.getStream() << "if(id < d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[mid])"; + env.getStream() << "if(id < d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[mid])"; { - CodeStream::Scope b(popEnv.getStream()); - popEnv.getStream() << "hi = mid;" << std::endl; + CodeStream::Scope b(env.getStream()); + env.getStream() << "hi = mid;" << std::endl; } - popEnv.getStream() << "else"; + env.getStream() << "else"; { - CodeStream::Scope b(popEnv.getStream()); - popEnv.getStream() << "lo = mid + 1;" << std::endl; + CodeStream::Scope b(env.getStream()); + env.getStream() << "lo = mid + 1;" << std::endl; } } // Use this to get reference to merged group structure - popEnv.getStream() << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group"; - popEnv.getStream() << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[lo - 1]; " << std::endl; + env.getStream() << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group"; + env.getStream() << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[lo - 1]; " << std::endl; + // Get group start thread ID and use as group_start_id - popEnv.getStream() << "const unsigned int groupStartID = d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[lo - 1];" << std::endl; - popEnv.addSubstitution("group_start_id", "groupStartID"); + env.getStream() << "const unsigned int groupStartID = d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[lo - 1];" << std::endl; + env.add(Type::Uint32.addConst(), "group_start_id", "groupStartID"); // Use this to calculate local id within group - popEnv.getStream() << "const unsigned int lid = id - groupStartID;" << std::endl; + env.getStream() << "const unsigned int lid = id - groupStartID;" << std::endl; } - 
popEnv.addSubstitution("id", "lid"); + env.add(Type::Uint32.addConst(), "id", "lid"); - handler(popEnv, gMerge); + handler(env, gMerge); idStart += paddedSize; } @@ -478,57 +391,57 @@ class GENN_EXPORT BackendSIMT : public BackendBase // Helper function to generate kernel code to initialise variables associated with synapse group or custom WU update with sparse connectivity template - void genSparseSynapseVarInit(CodeStream &os, const ModelSpecMerged &modelMerged, const G &g, Substitutions &popSubs, - bool varInitRequired, GroupHandler handler) const + void genSparseSynapseVarInit(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, const G &g, + bool varInitRequired, GroupHandlerEnv handler) const { // Calculate how many blocks rows need to be processed in (in order to store row lengths in shared memory) const size_t blockSize = getKernelBlockSize(KernelInitializeSparse); - os << "const unsigned int numBlocks = (group->numSrcNeurons + " << blockSize << " - 1) / " << blockSize << ";" << std::endl; + env.getStream() << "const unsigned int numBlocks = (" << env["num_pre"] << " + " << blockSize << " - 1) / " << blockSize << ";" << std::endl; - os << "unsigned int idx = " << popSubs["id"] << ";" << std::endl; + env.getStream() << "unsigned int idx = " << env["id"] << ";" << std::endl; // Loop through blocks - os << "for(unsigned int r = 0; r < numBlocks; r++)"; + env.getStream() << "for(unsigned int r = 0; r < numBlocks; r++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); // Calculate number of rows to process in this block - os << "const unsigned numRowsInBlock = (r == (numBlocks - 1))"; - os << " ? ((group->numSrcNeurons - 1) % " << blockSize << ") + 1"; - os << " : " << blockSize << ";" << std::endl; + env.getStream() << "const unsigned numRowsInBlock = (r == (numBlocks - 1))"; + env.getStream() << " ? 
((" << env["num_pre"] << " - 1) % " << blockSize << ") + 1"; + env.getStream() << " : " << blockSize << ";" << std::endl; // Use threads to copy block of sparse structure into shared memory - genSharedMemBarrier(os); - os << "if (" << getThreadID() << " < numRowsInBlock)"; + genSharedMemBarrier(env.getStream()); + env.getStream() << "if (" << getThreadID() << " < numRowsInBlock)"; { - CodeStream::Scope b(os); - os << "shRowLength[" << getThreadID() << "] = group->rowLength[(r * " << blockSize << ") + " << getThreadID() << "];" << std::endl; + CodeStream::Scope b(env.getStream()); + env.getStream() << "shRowLength[" << getThreadID() << "] = " << env["_row_length"] << "[(r * " << blockSize << ") + " << getThreadID() << "];" << std::endl; } - genSharedMemBarrier(os); + genSharedMemBarrier(env.getStream()); // Loop through rows - os << "for(unsigned int i = 0; i < numRowsInBlock; i++)"; + env.getStream() << "for(unsigned int i = 0; i < numRowsInBlock; i++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); // If there is a synapse for this thread to initialise - os << "if(" << popSubs["id"] << " < shRowLength[i])"; + env.getStream() << "if(" << env["id"] << " < shRowLength[i])"; { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); // Generate initialisation code if(varInitRequired) { - popSubs.addVarSubstitution("id_pre", "((r * " + std::to_string(blockSize) + ") + i)"); - popSubs.addVarSubstitution("id_post", "group->ind[idx]"); - g.generateInit(*this, os, modelMerged, popSubs); + env.add(Type::Uint32.addConst(), "id_pre", "((r * " + std::to_string(blockSize) + ") + i)"); + env.add(Type::Uint32.addConst(), "id_post", "$(_ind)[idx]"); + g.generateInit(*this, env, modelMerged); } // Call handler - handler(os, g, popSubs); + handler(env, g); } // If matrix is ragged, advance index to next row by adding stride - os << "idx += group->rowStride;" << std::endl; + env.getStream() << "idx += " << env["_row_stride"] << ";" << std::endl; } } } @@ -537,7 +450,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase void genRecordingSharedMemInit(CodeStream &os, const std::string &suffix) const; - void genSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const; + void genSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const; // Get appropriate presynaptic update strategy to use for this synapse group const PresynapticUpdateStrategySIMT::Base *getPresynapticUpdateStrategy(const SynapseGroupInternal &sg) const diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index b3028171a3..1a22b05889 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -1745,15 +1745,13 @@ void BackendSIMT::genRecordingSharedMemInit(CodeStream &os, const std::string &s } } //-------------------------------------------------------------------------- -void BackendSIMT::genSynapseVariableRowInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const +void BackendSIMT::genSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const { - // Pre and postsynaptic ID should already be provided via parallelism - assert(kernelSubs.hasVarSubstitution("id_pre")); - assert(kernelSubs.hasVarSubstitution("id_post")); + EnvironmentExternal varEnv(env); - Substitutions varSubs(&kernelSubs); - varSubs.addVarSubstitution("id_syn", "(" + kernelSubs["id_pre"] + " * group->rowStride) + " + kernelSubs["id"]); - 
handler(os, varSubs); + // **TODO** 64-bit id_syn + varEnv.add(Type::Uint32.addConst(), "id_syn", "($(id_pre) * $(_row_stride)) + $(id)"); + handler(varEnv); } //-------------------------------------------------------------------------- const PresynapticUpdateStrategySIMT::Base *BackendSIMT::getPresynapticUpdateStrategy(const SynapseGroupInternal &sg, From 0e233c4f301c7a40bb09a243d3316ac69df21b84 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 28 Jun 2023 17:00:37 +0100 Subject: [PATCH 266/725] * New print and printLine methods in EnvironmentExternalBase to save typing * Implemented most of SIMT presynaptic update strategies and backend --- .../genn/genn/code_generator/backendBase.h | 2 +- .../genn/genn/code_generator/backendSIMT.h | 73 +- .../genn/genn/code_generator/environment.h | 6 + .../genn/genn/code_generator/groupMerged.h | 4 +- .../presynapticUpdateStrategySIMT.h | 112 +-- .../backends/single_threaded_cpu/backend.cc | 74 +- src/genn/genn/code_generator/backendBase.cc | 12 +- src/genn/genn/code_generator/backendSIMT.cc | 929 +++++++++--------- src/genn/genn/code_generator/codeGenUtils.cc | 1 + .../customConnectivityUpdateGroupMerged.cc | 2 +- src/genn/genn/code_generator/environment.cc | 10 + .../genn/code_generator/initGroupMerged.cc | 4 +- .../code_generator/neuronUpdateGroupMerged.cc | 44 +- .../presynapticUpdateStrategySIMT.cc | 492 +++++----- 14 files changed, 871 insertions(+), 894 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index c44433ff1a..dc2a2ea440 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -719,7 +719,7 @@ class GENN_EXPORT BackendBase } void genCustomUpdateIndexCalculation(EnvironmentGroupMergedField &env) const; - void genCustomConnectivityUpdateIndexCalculation(CodeStream &os, const CustomConnectivityUpdateGroupMerged &cu) const; + void genCustomConnectivityUpdateIndexCalculation(EnvironmentGroupMergedField &env) const; //! 
Get the initial value to start reduction operations from std::string getReductionInitialValue(VarAccessMode access, const Type::ResolvedType &type) const; diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 5151e45bcc..5cc9d4609d 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -183,15 +183,15 @@ class GENN_EXPORT BackendSIMT : public BackendBase //------------------------------------------------------------------------ // Protected API //------------------------------------------------------------------------ - void genNeuronPrevSpikeTimeUpdateKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t &idStart) const; - void genNeuronSpikeQueueUpdateKernel(CodeStream &os, const ModelSpecMerged &modelMerged, size_t &idStart) const; + void genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; + void genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; - void genNeuronUpdateKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t &idStart) const; + void genNeuronUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; - void genSynapseDendriticDelayUpdateKernel(CodeStream &os, const ModelSpecMerged &modelMerged, size_t &idStart) const; - void genPresynapticUpdateKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t &idStart) const; - void genPostsynapticUpdateKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t &idStart) const; - void genSynapseDynamicsKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t &idStart) const; + void genSynapseDendriticDelayUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; + void genPresynapticUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; + void genPostsynapticUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; + void genSynapseDynamicsKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; void genCustomUpdateKernel(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const; @@ -202,12 +202,12 @@ class GENN_EXPORT BackendSIMT : public BackendBase void genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const; - void genCustomConnectivityUpdateKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, + void genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const; - void genInitializeKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t &idStart) const; + void genInitializeKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; - void genInitializeSparseKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, + void 
genInitializeSparseKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged,
                                   size_t numInitializeThreads, size_t &idStart) const;
 
     //! Helper wrapper around padSize to pad size to a kernel size
@@ -227,7 +227,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase
     // Private methods
     //--------------------------------------------------------------------------
     template
-    void genParallelGroup(EnvironmentExternal &env, const std::vector &groups, size_t &idStart,
+    void genParallelGroup(EnvironmentExternalBase &env, const std::vector &groups, size_t &idStart,
                           S getPaddedSizeFunc, F filter, GroupHandlerEnv handler) const
     {
         // Loop through groups
@@ -289,7 +289,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase
 
                         // Get group start thread ID and use as group_start_id
                         env.getStream() << "const unsigned int groupStartID = d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[lo - 1];" << std::endl;
-                        env.add(Type::Uint32.addConst(), "group_start_id", "groupStartID");
+                        env.add(Type::Uint32.addConst(), "_group_start_id", "groupStartID");
 
                         // Use this to calculate local id within group
                         env.getStream() << "const unsigned int lid = id - groupStartID;" << std::endl;
@@ -306,7 +306,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase
 
 
     template
-    void genParallelGroup(EnvironmentExternal &env, const std::vector &groups, size_t &idStart,
+    void genParallelGroup(EnvironmentExternalBase &env, const std::vector &groups, size_t &idStart,
                           S getPaddedSizeFunc, GroupHandlerEnv handler) const
     {
         genParallelGroup(env, groups, idStart, getPaddedSizeFunc,
@@ -315,31 +315,32 @@ class GENN_EXPORT BackendSIMT : public BackendBase
 
     // Helper function to generate kernel code to initialise variables associated with synapse group or custom WU update with dense/kernel connectivity
     template
-    void genSynapseVarInit(CodeStream &os, const ModelSpecMerged &modelMerged, const G &g, Substitutions &popSubs,
+    void genSynapseVarInit(EnvironmentGroupMergedField &env, const ModelSpecMerged &modelMerged,
                            bool initRNGRequired, bool kernel, size_t kernelDimensions) const
    {
-        os << "if(" << popSubs["id"] << " < ";
+        env.getStream() << "if(" << env["id"] << " < ";
 
        // If synapse group has kernel weights, check ID against product of kernel dimensions
        if (kernel) {
            // Loop through kernel dimensions and multiply together
-            os << "(";
+            env.getStream() << "(";
            for (size_t i = 0; i < kernelDimensions; i++) {
-                os << g.getKernelSize(i);
+                env.print(getKernelSize(env.getGroup(), i));
                if (i != (kernelDimensions - 1)) {
-                    os << " * ";
+                    env.getStream() << " * ";
                }
            }
-            os << ")";
+            env.getStream() << ")";
        }
        // Otherwise, against number of postsynaptic neurons
        else {
-            os << "group->numTrgNeurons";
+            env.getStream() << env["num_post"];
        }
-        os << ")";
+        env.getStream() << ")";
        {
-            CodeStream::Scope b(os);
-
+            CodeStream::Scope b(env.getStream());
+            EnvironmentGroupMergedField initEnv(env, env.getGroup());
+
            // If an RNG is required for initialisation,
            // make copy of global phillox RNG and skip ahead by thread id
            // **NOTE** not LOCAL id
@@ -351,41 +352,43 @@ class GENN_EXPORT BackendSIMT : public BackendBase
            if (kernel) {
                // Loop through kernel dimensions to generate separate indices
                for (size_t i = 0; i < kernelDimensions; i++) {
-                    os << "const unsigned int kernelID" << i << " = (" << popSubs["id"];
+                    std::ostringstream kernelIDInit;
+                    kernelIDInit << "const unsigned int kernelID" << i << " = ($(id)";
 
                    // If this isn't the last dimension
                    if (i < (kernelDimensions - 1)) {
                        // Loop backwards through other kernel and generate code to divide 
by product of subsequent dimensions
-                        os << " / (";
+                        kernelIDInit << " / (";
                        for (size_t j = (kernelDimensions - 1); j > i; j--) {
-                            os << g.getKernelSize(j);
+                            kernelIDInit << getKernelSize(env.getGroup(), j);
 
                            if (j != (i + 1)) {
-                                os << " * ";
+                                kernelIDInit << " * ";
                            }
                        }
-                        os << ")";
+                        kernelIDInit << ")";
                    }
-                    os << ")";
+                    kernelIDInit << ")";
 
                    // If this isn't the first dimension, take modulus of kernel size
                    if (i > 0) {
-                        os << " % " << g.getKernelSize(i);
+                        kernelIDInit << " % " << getKernelSize(env.getGroup(), i);
                    }
 
-                    os << ";" << std::endl;
+                    kernelIDInit << ";" << std::endl;
 
                    // Add substitution
-                    popSubs.addVarSubstitution("id_kernel_" + std::to_string(i), "kernelID" + std::to_string(i));
+                    initEnv.add(Type::Uint32.addConst(), "id_kernel_" + std::to_string(i), "kernelID" + std::to_string(i),
+                                {initEnv.addInitialiser(kernelIDInit.str())});
                }
            }
            // Otherwise, just substitute postsynaptic index
            else {
-                popSubs.addVarSubstitution("id_post", popSubs["id"]);
+                initEnv.add(Type::Uint32.addConst(), "id_post", "$(id)");
            }
 
            // Generate init code
-            g.generateInit(*this, os, modelMerged, popSubs);
+            env.getGroup().generateInit(*this, initEnv, modelMerged);
        }
    }
 
@@ -446,7 +449,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase
         }
     }
 
-    void genEmitSpike(const ModelSpecMerged &modelMerged, CodeStream &os, const Substitutions &subs, const std::string &suffix, bool recordingEnabled) const;
+    void genEmitSpike(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, const std::string &suffix, bool recordingEnabled) const;
 
     void genRecordingSharedMemInit(CodeStream &os, const std::string &suffix) const;
 
diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h
index 95e83f642c..dc554b308d 100644
--- a/include/genn/genn/code_generator/environment.h
+++ b/include/genn/genn/code_generator/environment.h
@@ -64,6 +64,12 @@ class EnvironmentExternalBase : public Transpiler::PrettyPrinter::EnvironmentBas
     virtual void define(const Transpiler::Token &name, const GeNN::Type::ResolvedType &type, Transpiler::ErrorHandlerBase &errorHandler) override;
 
+    //------------------------------------------------------------------------
+    // Public API
+    //------------------------------------------------------------------------
+    void print(const std::string &format);
+    void printLine(const std::string &format);
+
     //------------------------------------------------------------------------
     // Operators
     //------------------------------------------------------------------------
diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h
index a198425667..471b32b384 100644
--- a/include/genn/genn/code_generator/groupMerged.h
+++ b/include/genn/genn/code_generator/groupMerged.h
@@ -647,12 +647,12 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMergedisDelayRequired() && s.getArchetype().getTrgNeuronGroup()->isTrueSpikeRequired()) {
-            groupEnv.getStream() << printSubs("const unsigned int numSpikes = $(_trg_spk_cnt)[$(_post_delay_slot)];", groupEnv) << std::endl;
+            groupEnv.printLine("const unsigned int numSpikes = $(_trg_spk_cnt)[$(_post_delay_slot)];");
         }
         else {
-            groupEnv.getStream() << printSubs("const unsigned int numSpikes = $(_trg_spk_cnt)[0];", groupEnv) << std::endl;
+            groupEnv.printLine("const unsigned int numSpikes = $(_trg_spk_cnt)[0];");
         }
 
         // Loop through postsynaptic spikes
@@ -469,15 +469,15 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos
             // **TODO** prod types
             const 
std::string offsetTrueSpkPost = (s.getArchetype().getTrgNeuronGroup()->isTrueSpikeRequired() && s.getArchetype().getTrgNeuronGroup()->isDelayRequired()) ? "$(_post_delay_offset) + " : "";
-            groupEnv.getStream() << printSubs("const unsigned int spike = $(_trg_spk)[" + offsetTrueSpkPost + "j];", groupEnv) << std::endl;
+            groupEnv.printLine("const unsigned int spike = $(_trg_spk)[" + offsetTrueSpkPost + "j];");
 
             // Loop through column of presynaptic neurons
             if (s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
-                groupEnv.getStream() << printSubs("const unsigned int npre = $(_col_length)[spike];", groupEnv) << std::endl;
+                groupEnv.printLine("const unsigned int npre = $(_col_length)[spike];");
                 groupEnv.getStream() << "for (unsigned int i = 0; i < npre; i++)";
             }
             else {
-                groupEnv.getStream() << "for (unsigned int i = 0; i < " << groupEnv["num_pre"] << "; i++)";
+                groupEnv.print("for (unsigned int i = 0; i < $(num_pre); i++)");
             }
             {
                 CodeStream::Scope b(groupEnv.getStream());
@@ -804,7 +804,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host
 
                         // Update transpose variable
                         // **YUCK** this is sorta outside scope
-                        groupEnv.getStream() << printSubs("$(" + transposeVarName + "_transpose)[(j * $(num_pre)) + i] = l" + transposeVarName + ";", groupEnv) << std::endl;
+                        groupEnv.printLine("$(" + transposeVarName + "_transpose)[(j * $(num_pre)) + i] = l" + transposeVarName + ";");
                     }
                 }
 
@@ -981,11 +981,11 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler
                 // If matrix connectivity is ragged
                 if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
                     // Zero row lengths
-                    funcEnv.getStream() << printSubs("std::fill_n($(_row_length), $(num_pre), 0);", funcEnv) << std::endl;
+                    funcEnv.printLine("std::fill_n($(_row_length), $(num_pre), 0);");
                 }
                 else if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) {
-                    funcEnv.getStream() << printSubs("const size_t gpSize = ((((size_t)$(num_pre) * (size_t)$(_row_stride)) + 32 - 1) / 32);", funcEnv) << std::endl;
-                    funcEnv.getStream() << "std::fill(" << groupEnv["_num_gp"] << ", gpSize, 0);" << std::endl;
+                    funcEnv.printLine("const size_t gpSize = ((((size_t)$(num_pre) * (size_t)$(_row_stride)) + 32 - 1) / 32);");
+                    funcEnv.printLine("std::fill($(_gp), gpSize, 0);");
                 }
                 else {
                     throw std::runtime_error("Only BITMASK and SPARSE format connectivity can be generated using a connectivity initialiser");
@@ -995,7 +995,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler
                 const auto *snippet = s.getArchetype().getConnectivityInitialiser().getSnippet();
                 if(!snippet->getRowBuildCode().empty()) {
                     // Generate loop through source neurons
-                    groupEnv.getStream() << "for (unsigned int i = 0; i <" << groupEnv["num_pre"] << "; i++)";
+                    groupEnv.print("for (unsigned int i = 0; i < $(num_pre); i++)");
 
                     // Configure substitutions
                     groupEnv.add(Type::Uint32.addConst(), "id_pre", "i");
@@ -1010,7 +1010,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler
                     assert(!snippet->getColBuildCode().empty());
 
                     // Loop through target neurons
-                    groupEnv.getStream() << "for (unsigned int j = 0; j < " << groupEnv["num_post"] << "; j++)";
+                    groupEnv.print("for (unsigned int j = 0; j < $(num_post); j++)");
 
                     // Configure substitutions
                     groupEnv.add(Type::Uint32.addConst(), "id_post", "j");
@@ -1139,7 +1139,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler
                 // If postsynaptic learning is required, 
initially zero column lengths if (!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { groupEnv.getStream() << "// Zero column lengths" << std::endl; - groupEnv.getStream() << printSubs("std::fill_n($(_col_length), $(num_post), 0);", groupEnv) << std::endl; + groupEnv.printLine("std::fill_n($(_col_length), $(num_post), 0);"); } groupEnv.getStream() << "// Loop through presynaptic neurons" << std::endl; @@ -1487,7 +1487,7 @@ void Backend::genVariableInit(EnvironmentExternalBase &env, const std::string &c //-------------------------------------------------------------------------- void Backend::genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const { - env.getStream() << printSubs("for (unsigned int j = 0; j < $(_row_length)[$(id_pre)]; j++)", env); + env.print("for (unsigned int j = 0; j < $(_row_length)[$(id_pre)]; j++)"); { CodeStream::Scope b(env.getStream()); @@ -1518,14 +1518,12 @@ void Backend::genDenseSynapseVariableRowInit(EnvironmentExternalBase &env, Handl //-------------------------------------------------------------------------- void Backend::genKernelSynapseVariableInit(EnvironmentExternalBase &env, const SynapseInitGroupMerged &sg, HandlerEnv handler) const { - assert(false); - //genKernelIteration(os, sg, sg.getArchetype().getKernelSize().size(), kernelSubs, handler); + genKernelIteration(env, sg, sg.getArchetype().getKernelSize().size(), handler); } //-------------------------------------------------------------------------- void Backend::genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, const CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const { - assert(false); - //genKernelIteration(os, cu, cu.getArchetype().getSynapseGroup()->getKernelSize().size(), kernelSubs, handler); + genKernelIteration(env, cu, cu.getArchetype().getSynapseGroup()->getKernelSize().size(), handler); } //-------------------------------------------------------------------------- void Backend::genGlobalDeviceRNG(CodeStream&, CodeStream&, CodeStream&, CodeStream&, CodeStream&, MemAlloc&) const @@ -1832,10 +1830,10 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda // Detect spike events or spikes and do the update env.getStream() << "// process presynaptic events: " << (trueSpike ? 
"True Spikes" : "Spike type events") << std::endl; if(sg.getArchetype().getSrcNeuronGroup()->isDelayRequired()) { - env.getStream() << printSubs("for (unsigned int i = 0; i < $(_src_spk_cnt" + eventSuffix + ")[$(pre_delay_slot)]; i++)", env); + env.print("for (unsigned int i = 0; i < $(_src_spk_cnt" + eventSuffix + ")[$(pre_delay_slot)]; i++)"); } else { - env.getStream() << printSubs("for (unsigned int i = 0; i < $(_src_spk_cnt" + eventSuffix + ")[0]; i++)", env); + env.print("for (unsigned int i = 0; i < $(_src_spk_cnt" + eventSuffix + ")[0]; i++)"); } { CodeStream::Scope b(env.getStream()); @@ -1867,7 +1865,7 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda // If connectivity is sparse if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - groupEnv.getStream() << "const unsigned int npost = " << groupEnv["_row_length"] << "[ipre];" << std::endl; + groupEnv.printLine("const unsigned int npost = $(_row_length)[ipre];"); groupEnv.getStream() << "for (unsigned int j = 0; j < npost; j++)"; { CodeStream::Scope b(groupEnv.getStream()); @@ -1892,13 +1890,13 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda } else if(getPreferences().enableBitmaskOptimisations && (sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK)) { // Determine the number of words in each row - groupEnv.getStream() << "const unsigned int rowWords = ((" << env["_num_post"] << " + 32 - 1) / 32);" << std::endl; + groupEnv.printLine("const unsigned int rowWords = (($(num_post) + 32 - 1) / 32);"); groupEnv.getStream() << "for(unsigned int w = 0; w < rowWords; w++)"; { CodeStream::Scope b(groupEnv.getStream()); // Read row word - groupEnv.getStream() << "uint32_t connectivityWord = " << groupEnv["_gp"] << "[(ipre * rowWords) + w];" << std::endl; + groupEnv.printLine("uint32_t connectivityWord = $(_gp)[(ipre * rowWords) + w];"); // Set ipost to first synapse in connectivity word groupEnv.getStream() << "unsigned int ipost = w * 32;" << std::endl; @@ -1921,7 +1919,7 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda // If we aren't in padding region // **TODO** don't bother checking if there is no padding - groupEnv.getStream() << "if(ipost < " << groupEnv["num_post"] << ")"; + groupEnv.print("if(ipost < $(num_post))"); { CodeStream::Scope b(env.getStream()); if(trueSpike) { @@ -1947,7 +1945,7 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { // **TODO** 64-bit index - synEnv.getStream() << printSubs("const uint64_t gid = ($(id_pre) * $(num_post)) + $(id_post);", synEnv) << std::endl; + synEnv.printLine("const uint64_t gid = ($(id_pre) * $(num_post)) + $(id_post);"); synEnv.getStream() << "if (B(" << synEnv["_gp"] << "[gid / 32], gid & 31))" << CodeStream::OB(20); } @@ -1984,27 +1982,27 @@ void Backend::genEmitSpike(EnvironmentExternalBase &env, NeuronUpdateGroupMerged const std::string spikeQueueOffset = spikeDelayRequired ? "$(_write_delay_offset) + " : ""; const std::string suffix = trueSpike ? 
"" : "_evnt"; - env.getStream() << printSubs("$(_spk" + suffix + ")[" + spikeQueueOffset + "$(_spk_cnt" + suffix + ")", env); + env.print("$(_spk" + suffix + ")[" + spikeQueueOffset + "$(_spk_cnt" + suffix + ")"); if(spikeDelayRequired) { // WITH DELAY - env.getStream() << "[*" << env["_spk_que_ptr"] << "]++]"; + env.print("[*$(_spk_que_ptr)]++]"); } else { // NO DELAY env.getStream() << "[0]++]"; } - env.getStream() << " = " << env["id"] << ";" << std::endl; + env.printLine(" = $(id);"); // Reset spike and spike-like-event times const std::string queueOffset = ng.getArchetype().isDelayRequired() ? "$(_write_delay_offset) + " : ""; if(trueSpike && ng.getArchetype().isSpikeTimeRequired()) { - env.getStream() << printSubs("$(_spk_time)[" + queueOffset + "$(id)] = $(t);", env) << std::endl; + env.printLine("$(_spk_time)[" + queueOffset + "$(id)] = $(t);"); } else if(!trueSpike && ng.getArchetype().isSpikeEventTimeRequired()) { - env.getStream() << printSubs("$(_spk_evnt_time)[" + queueOffset + "$(id)] = $(t);", env) << std::endl; + env.printLine("$(_spk_evnt_time)[" + queueOffset + "$(id)] = $(t);"); } // If recording is enabled if(recordingEnabled) { - env.getStream() << printSubs("$(_record_spk" + suffix + ")[(recordingTimestep * numRecordingWords) + ($(id) / 32)] |= (1 << ($(id) % 32));", env) << std::endl; + env.printLine("$(_record_spk" + suffix + ")[(recordingTimestep * numRecordingWords) + ($(id) / 32)] |= (1 << ($(id) % 32));"); } } //-------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 2749435d28..3fb2fcabc3 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -73,16 +73,18 @@ void BackendBase::genCustomUpdateIndexCalculation(EnvironmentGroupMergedField &env) const { // If there are delays on presynaptic variable references - if(cu.getArchetype().getPreDelayNeuronGroup() != nullptr) { - os << "const unsigned int preDelayOffset = (*group->preSpkQuePtr * group->numSrcNeurons);" << std::endl; + if(env.getGroup().getArchetype().getPreDelayNeuronGroup() != nullptr) { + env.add(Type::Uint32.addConst(), "_pre_delay_offset", "preDelayOffset", + {env.addInitialiser("const unsigned int preDelayOffset = (*$(_pre_spk_que_ptr) * $(num_pre));")}); } // If there are delays on postsynaptic variable references - if(cu.getArchetype().getPostDelayNeuronGroup() != nullptr) { - os << "const unsigned int postDelayOffset = (*group->postSpkQuePtr * group->numTrgNeurons);" << std::endl; + if(env.getGroup().getArchetype().getPostDelayNeuronGroup() != nullptr) { + env.add(Type::Uint32.addConst(), "_post_delay_offset", "postDelayOffset", + {env.addInitialiser("const unsigned int postDelayOffset = (*$(_post_spk_que_ptr) * $(num_post));")}); } } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 1a22b05889..f07e60c25f 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -58,48 +58,44 @@ size_t BackendSIMT::getSynapticMatrixRowStride(const SynapseGroupInternal &sg) c return getPresynapticUpdateStrategy(sg)->getSynapticMatrixRowStride(sg); } //-------------------------------------------------------------------------- -void BackendSIMT::genPopVariableInit(CodeStream &os, const Substitutions &kernelSubs, Handler handler) const +void 
BackendSIMT::genPopVariableInit(EnvironmentExternalBase &env, HandlerEnv handler) const { - Substitutions varSubs(&kernelSubs); - // If this is first thread in group - os << "if(" << varSubs["id"] << " == 0)"; + env.getStream() << "if(" << env["id"] << " == 0)"; { - CodeStream::Scope b(os); - handler(os, varSubs); + CodeStream::Scope b(env.getStream()); + handler(env); } } //-------------------------------------------------------------------------- -void BackendSIMT::genVariableInit(CodeStream &os, const std::string &, const std::string &countVarName, - const Substitutions &kernelSubs, Handler handler) const +void BackendSIMT::genVariableInit(EnvironmentExternalBase &env, const std::string&, const std::string&, HandlerEnv handler) const { // Variable should already be provided via parallelism - assert(kernelSubs.hasVarSubstitution(countVarName)); + //assert(kernelSubs.hasVarSubstitution(countVarName)); - Substitutions varSubs(&kernelSubs); - handler(os, varSubs); + handler(env); } //-------------------------------------------------------------------------- -void BackendSIMT::genKernelSynapseVariableInit(CodeStream &os, const SynapseInitGroupMerged &, const Substitutions &kernelSubs, Handler handler) const +void BackendSIMT::genKernelSynapseVariableInit(EnvironmentExternalBase &env, const SynapseInitGroupMerged &sg, HandlerEnv handler) const { // Variable should already be provided via parallelism - assert(kernelSubs.hasVarSubstitution("id")); + //assert(kernelSubs.hasVarSubstitution("id")); - Substitutions varSubs(&kernelSubs); - varSubs.addVarSubstitution("id_syn", varSubs["id"]); + EnvironmentExternal varEnv(env); + varEnv.add(Type::Uint32.addConst(), "id_syn", "$(id)"); - handler(os, varSubs); + handler(varEnv); } //-------------------------------------------------------------------------- -void BackendSIMT::genKernelCustomUpdateVariableInit(CodeStream &os, const CustomWUUpdateInitGroupMerged&, const Substitutions &kernelSubs, Handler handler) const +void BackendSIMT::genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, const CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const { // Variable should already be provided via parallelism - assert(kernelSubs.hasVarSubstitution("id")); + //assert(kernelSubs.hasVarSubstitution("id")); - Substitutions varSubs(&kernelSubs); - varSubs.addVarSubstitution("id_syn", varSubs["id"]); + EnvironmentExternal varEnv(env); + varEnv.add(Type::Uint32.addConst(), "id_syn", "$(id)"); - handler(os, varSubs); + handler(varEnv); } //-------------------------------------------------------------------------- bool BackendSIMT::isGlobalHostRNGRequired(const ModelSpecMerged &modelMerged) const @@ -313,84 +309,84 @@ void BackendSIMT::addPresynapticUpdateStrategy(PresynapticUpdateStrategySIMT::Ba s_PresynapticUpdateStrategies.push_back(strategy); } //-------------------------------------------------------------------------- -void BackendSIMT::genNeuronPrevSpikeTimeUpdateKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const { const unsigned int batchSize = modelMerged.getModel().getBatchSize(); // Parallelise over neuron groups idStart = 0; genParallelGroup( - os, kernelSubs, modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups(), idStart, + env, modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups(), idStart, [this](const NeuronGroupInternal &ng) { 
return padKernelSize(ng.getNumNeurons(), KernelNeuronUpdate); },
-        [batchSize,this](CodeStream &os, const NeuronPrevSpikeTimeUpdateGroupMerged &ng, Substitutions &popSubs)
+        [batchSize, this](EnvironmentExternalBase &popEnv, NeuronPrevSpikeTimeUpdateGroupMerged &ng)
         {
-            CodeStream::Scope b(os);
+            CodeStream::Scope b(popEnv.getStream());
 
             // If neuron group requires delays
             if(ng.getArchetype().isDelayRequired()) {
                 if(batchSize == 1) {
-                    os << "const unsigned int lastTimestepDelaySlot = *group->spkQuePtr;" << std::endl;
+                    popEnv.printLine("const unsigned int lastTimestepDelaySlot = *$(_spk_que_ptr);");
                 }
                 else {
-                    os << "const unsigned int lastTimestepDelaySlot = *group->spkQuePtr + (batch * " << ng.getArchetype().getNumDelaySlots() << ");" << std::endl;
+                    popEnv.printLine("const unsigned int lastTimestepDelaySlot = *$(_spk_que_ptr) + (batch * " + std::to_string(ng.getArchetype().getNumDelaySlots()) + ");");
                 }
-                os << "const unsigned int lastTimestepDelayOffset = lastTimestepDelaySlot * group->numNeurons;" << std::endl;
+                popEnv.printLine("const unsigned int lastTimestepDelayOffset = lastTimestepDelaySlot * $(num_neurons);");
 
                 if(ng.getArchetype().isPrevSpikeTimeRequired()) {
                     // If there is a spike for this thread, set previous spike time to time of last timestep
                     // **NOTE** spkQuePtr is updated below so this already points to last timestep
-                    os << "if(" << popSubs["id"] << " < group->spkCnt[lastTimestepDelaySlot])";
+                    popEnv.print("if($(id) < $(_spk_cnt)[lastTimestepDelaySlot])");
                     {
-                        CodeStream::Scope b(os);
-                        os << "group->prevST[lastTimestepDelayOffset + group->spk[lastTimestepDelayOffset + " << popSubs["id"] << "]] = " << popSubs["t"] << " - DT;" << std::endl;
+                        CodeStream::Scope b(popEnv.getStream());
+                        popEnv.printLine("$(_prev_spk_time)[lastTimestepDelayOffset + $(_spk)[lastTimestepDelayOffset + $(id)]] = $(t) - DT;");
                     }
                 }
                 if(ng.getArchetype().isPrevSpikeEventTimeRequired()) {
                     // If there is a spike-like-event for this thread, set previous spike-like-event time to time of last timestep
                     // **NOTE** spkQuePtr is updated below so this already points to last timestep
-                    os << "if(" << popSubs["id"] << " < group->spkCntEvnt[lastTimestepDelaySlot])";
+                    popEnv.print("if($(id) < $(_spk_cnt_evnt)[lastTimestepDelaySlot])");
                     {
-                        CodeStream::Scope b(os);
-                        os << "group->prevSET[lastTimestepDelayOffset + group->spkEvnt[lastTimestepDelayOffset + " << popSubs["id"] << "]] = " << popSubs["t"] << " - DT;" << std::endl;
+                        CodeStream::Scope b(popEnv.getStream());
+                        popEnv.printLine("$(_prev_spk_evnt_time)[lastTimestepDelayOffset + $(_spk_evnt)[lastTimestepDelayOffset + $(id)]] = $(t) - DT;");
                    }
                }
            }
            // Otherwise
            else {
                if(batchSize > 1) {
-                    os << "const unsigned int batchOffset = group->numNeurons * batch;" << std::endl;
+                    popEnv.printLine("const unsigned int batchOffset = $(num_neurons) * batch;");
                }
 
                if(ng.getArchetype().isPrevSpikeTimeRequired()) {
                    // If there is a spike for this thread, set previous spike time to time of last timestep
-                    os << "if(" << popSubs["id"] << " < group->spkCnt[" << ((batchSize == 1) ? "0" : "batch") << "])";
+                    popEnv.print("if($(id) < $(_spk_cnt)[" + std::string{(batchSize == 1) ? 
"0" : "batch"} + "])"); { - CodeStream::Scope b(os); - os << "group->prevST[group->spk["; + CodeStream::Scope b(popEnv.getStream()); + popEnv.print("$(_prev_spk_time)[$(_spk)["); if(batchSize > 1) { - os << "batchOffset + "; + popEnv.getStream() << "batchOffset + "; } - os << popSubs["id"] << "]] = " << popSubs["t"] << " - DT;" << std::endl; + popEnv.printLine("$(id)]] = $(t) - DT;"); } } if(ng.getArchetype().isPrevSpikeEventTimeRequired()) { // If there is a spike-like-event for this thread, set previous spike-like-event time to time of last timestep - os << "if(" << popSubs["id"] << " < group->spkCntEvnt[" << ((batchSize == 1) ? "0" : "batch") << "])"; + popEnv.print("if($(id) < $(_spk_cnt_evnt)[" + std::string{(batchSize == 1) ? "0" : "batch"} + "])"); { - CodeStream::Scope b(os); - os << "group->prevSET[group->spkEvnt["; + CodeStream::Scope b(popEnv.getStream()); + popEnv.print("$(_prev_spk_evnt_time)[$(_spk_evnt)["); if(batchSize > 1) { - os << "batchOffset + "; + popEnv.getStream() << "batchOffset + "; } - os << popSubs["id"] << "]] = " << popSubs["t"] << " - DT;" << std::endl; + popEnv.printLine("$(id)]] = $(t) - DT;"); } } } - os << std::endl; + popEnv.getStream() << std::endl; }); } //-------------------------------------------------------------------------- -void BackendSIMT::genNeuronSpikeQueueUpdateKernel(CodeStream &os, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const { const unsigned int batchSize = modelMerged.getModel().getBatchSize(); @@ -398,34 +394,34 @@ void BackendSIMT::genNeuronSpikeQueueUpdateKernel(CodeStream &os, const ModelSpe idStart = 0; for(const auto &n : modelMerged.getMergedNeuronSpikeQueueUpdateGroups()) { if(idStart == 0) { - os << "if(id < " << n.getGroups().size() << ")"; + env.getStream() << "if(id < " << n.getGroups().size() << ")"; } else { - os << "if(id >= " << idStart << " && id < " << idStart + n.getGroups().size() << ")"; + env.getStream() << "if(id >= " << idStart << " && id < " << idStart + n.getGroups().size() << ")"; } { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); // Use this to get reference to merged group structure - os << getPointerPrefix() << "struct MergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << " *group = &d_mergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << "[id - " << idStart << "]; " << std::endl; + env.getStream() << getPointerPrefix() << "struct MergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << " *group = &d_mergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << "[id - " << idStart << "]; " << std::endl; if(n.getArchetype().isDelayRequired()) { // with delay - os << "*group->spkQuePtr = (*group->spkQuePtr + 1) % " << n.getArchetype().getNumDelaySlots() << ";" << std::endl; + env.getStream() << "*" << env["_spk_que_ptr"] << " = (*" << env["_spk_que_ptr"] << " + 1) % " << n.getArchetype().getNumDelaySlots() << ";" << std::endl; } if(batchSize > 1) { - os << "for(unsigned int batch = 0; batch < " << batchSize << "; batch++)" << CodeStream::OB(1); + env.getStream() << "for(unsigned int batch = 0; batch < " << batchSize << "; batch++)" << CodeStream::OB(1); } - n.genMergedGroupSpikeCountReset(os, batchSize); + n.genMergedGroupSpikeCountReset(env, batchSize); if(batchSize > 1) { - os << CodeStream::CB(1); + env.getStream() << CodeStream::CB(1); } } idStart += n.getGroups().size(); } } 
//-------------------------------------------------------------------------- -void BackendSIMT::genNeuronUpdateKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const { const unsigned int batchSize = modelMerged.getModel().getBatchSize(); @@ -433,79 +429,80 @@ void BackendSIMT::genNeuronUpdateKernel(CodeStream &os, const Substitutions &ker if(std::any_of(modelMerged.getMergedNeuronUpdateGroups().cbegin(), modelMerged.getMergedNeuronUpdateGroups().cend(), [](const NeuronUpdateGroupMerged &n) { return n.getArchetype().isSpikeEventRequired(); })) { - os << getSharedPrefix() << "unsigned int shSpkEvnt[" << getKernelBlockSize(KernelNeuronUpdate) << "];" << std::endl; - os << getSharedPrefix() << "unsigned int shPosSpkEvnt;" << std::endl; - os << getSharedPrefix() << "unsigned int shSpkEvntCount;" << std::endl; - os << std::endl; - os << "if (" << getThreadID() << " == 1)"; + env.getStream() << getSharedPrefix() << "unsigned int shSpkEvnt[" << getKernelBlockSize(KernelNeuronUpdate) << "];" << std::endl; + env.getStream() << getSharedPrefix() << "unsigned int shPosSpkEvnt;" << std::endl; + env.getStream() << getSharedPrefix() << "unsigned int shSpkEvntCount;" << std::endl; + env.getStream() << std::endl; + env.getStream() << "if (" << getThreadID() << " == 1)"; { - CodeStream::Scope b(os); - os << "shSpkEvntCount = 0;" << std::endl; + CodeStream::Scope b(env.getStream()); + env.getStream() << "shSpkEvntCount = 0;" << std::endl; } - os << std::endl; + env.getStream() << std::endl; } // If any neuron groups emit true spikes if(std::any_of(modelMerged.getMergedNeuronUpdateGroups().cbegin(), modelMerged.getMergedNeuronUpdateGroups().cend(), [](const NeuronUpdateGroupMerged &n) { return !n.getArchetype().getNeuronModel()->getThresholdConditionCode().empty(); })) { - os << getSharedPrefix() << "unsigned int shSpk[" << getKernelBlockSize(KernelNeuronUpdate) << "];" << std::endl; - os << getSharedPrefix() << "unsigned int shPosSpk;" << std::endl; - os << getSharedPrefix() << "unsigned int shSpkCount;" << std::endl; - os << "if (" << getThreadID() << " == 0)"; + env.getStream() << getSharedPrefix() << "unsigned int shSpk[" << getKernelBlockSize(KernelNeuronUpdate) << "];" << std::endl; + env.getStream() << getSharedPrefix() << "unsigned int shPosSpk;" << std::endl; + env.getStream() << getSharedPrefix() << "unsigned int shSpkCount;" << std::endl; + env.getStream() << "if (" << getThreadID() << " == 0)"; { - CodeStream::Scope b(os); - os << "shSpkCount = 0;" << std::endl; + CodeStream::Scope b(env.getStream()); + env.getStream() << "shSpkCount = 0;" << std::endl; } - os << std::endl; + env.getStream() << std::endl; } // If any neuron groups record spikes if(std::any_of(modelMerged.getMergedNeuronUpdateGroups().cbegin(), modelMerged.getMergedNeuronUpdateGroups().cend(), [](const NeuronUpdateGroupMerged &n) { return n.getArchetype().isSpikeRecordingEnabled(); })) { - genRecordingSharedMemInit(os, ""); + genRecordingSharedMemInit(env.getStream(), ""); } // If any neuron groups record spike-like events if(std::any_of(modelMerged.getMergedNeuronUpdateGroups().cbegin(), modelMerged.getMergedNeuronUpdateGroups().cend(), [](const NeuronUpdateGroupMerged &n) { return n.getArchetype().isSpikeEventRecordingEnabled(); })) { - genRecordingSharedMemInit(os, "Evnt"); + genRecordingSharedMemInit(env.getStream(), "Evnt"); } - 
genSharedMemBarrier(os); + genSharedMemBarrier(env.getStream()); // Parallelise over neuron groups idStart = 0; genParallelGroup( - os, kernelSubs, modelMerged.getMergedNeuronUpdateGroups(), idStart, + env, modelMerged.getMergedNeuronUpdateGroups(), idStart, [this](const NeuronGroupInternal &ng) { return padKernelSize(ng.getNumNeurons(), KernelNeuronUpdate); }, - [batchSize, &modelMerged, this](CodeStream &os, const NeuronUpdateGroupMerged &ng, Substitutions &popSubs) + [batchSize, &modelMerged, this](EnvironmentExternalBase &popEnv, NeuronUpdateGroupMerged &ng) { - genNeuronIndexCalculation(os, ng, batchSize); - os << std::endl; + EnvironmentGroupMergedField neuronEnv(popEnv, ng); + genNeuronIndexCalculation(neuronEnv, batchSize); + neuronEnv.getStream() << std::endl; // Call handler to generate generic neuron code - os << "if(" << popSubs["id"] << " < group->numNeurons)"; + neuronEnv.print("if($(id) < $(num_neurons))"); { - CodeStream::Scope b(os); + CodeStream::Scope b(neuronEnv.getStream()); // Copy global RNG stream to local and use pointer to this for rng if(ng.getArchetype().isSimRNGRequired()) { genPopulationRNGPreamble(os, popSubs, "group->rng[" + ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) + "]"); } - ng.generateNeuronUpdate(*this, os, modelMerged, popSubs, + ng.generateNeuronUpdate(*this, neuronEnv, modelMerged, // Emit true spikes - [&modelMerged, this](CodeStream &neuronUpdateKernelsBody, const NeuronUpdateGroupMerged &ng, Substitutions &subs) + [&modelMerged, this](EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng) { - genEmitSpike(modelMerged, neuronUpdateKernelsBody, subs, "", ng.getArchetype().isSpikeRecordingEnabled()); + genEmitSpike(env, modelMerged, "", ng.getArchetype().isSpikeRecordingEnabled()); }, // Emit spike-like events - [&modelMerged, this](CodeStream &neuronUpdateKernelsBody, const NeuronUpdateGroupMerged &ng, Substitutions &subs) + [&modelMerged, this](EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng) { - genEmitSpike(modelMerged, neuronUpdateKernelsBody, subs, "Evnt", ng.getArchetype().isSpikeEventRecordingEnabled()); + genEmitSpike(env, modelMerged, "Evnt", ng.getArchetype().isSpikeEventRecordingEnabled()); }); // Copy local stream back to local @@ -514,64 +511,64 @@ void BackendSIMT::genNeuronUpdateKernel(CodeStream &os, const Substitutions &ker } } - genSharedMemBarrier(os); + genSharedMemBarrier(neuronEnv.getStream()); if(ng.getArchetype().isSpikeEventRequired()) { - os << "if (" << getThreadID() << " == 1)"; + neuronEnv.getStream() << "if (" << getThreadID() << " == 1)"; { - CodeStream::Scope b(os); - os << "if (shSpkEvntCount > 0)"; + CodeStream::Scope b(neuronEnv.getStream()); + neuronEnv.getStream() << "if (shSpkEvntCount > 0)"; { - CodeStream::Scope b(os); - os << "shPosSpkEvnt = " << getAtomic(Type::Uint32) << "(&group->spkCntEvnt"; + CodeStream::Scope b(neuronEnv.getStream()); + neuronEnv.getStream() << "shPosSpkEvnt = " << getAtomic(Type::Uint32) << "(&group->spkCntEvnt"; if(ng.getArchetype().isDelayRequired()) { - os << "[*group->spkQuePtr"; + neuronEnv.getStream() << "[*" << neuronEnv["_spk_que_ptr"]; if(batchSize > 1) { - os << " + (batch * " << ng.getArchetype().getNumDelaySlots() << ")"; + neuronEnv.getStream() << " + (batch * " << ng.getArchetype().getNumDelaySlots() << ")"; } - os << "], shSpkEvntCount);" << std::endl; + neuronEnv.getStream() << "], shSpkEvntCount);" << std::endl; } else { - os << "[" << ((batchSize > 1) ? 
"batch" : "0") << "], shSpkEvntCount);" << std::endl; + neuronEnv.getStream() << "[" << ((batchSize > 1) ? "batch" : "0") << "], shSpkEvntCount);" << std::endl; } } } - genSharedMemBarrier(os); + genSharedMemBarrier(neuronEnv.getStream()); } if(!ng.getArchetype().getNeuronModel()->getThresholdConditionCode().empty()) { - os << "if(" << getThreadID() << " == 0)"; + neuronEnv.getStream() << "if(" << getThreadID() << " == 0)"; { - CodeStream::Scope b(os); - os << "if (shSpkCount > 0)"; + CodeStream::Scope b(neuronEnv.getStream()); + neuronEnv.getStream() << "if (shSpkCount > 0)"; { - CodeStream::Scope b(os); - os << "shPosSpk = " << getAtomic(Type::Uint32) << "(&group->spkCnt"; + CodeStream::Scope b(neuronEnv.getStream()); + neuronEnv.getStream() << "shPosSpk = " << getAtomic(Type::Uint32) << "(&group->spkCnt"; if(ng.getArchetype().isDelayRequired() && ng.getArchetype().isTrueSpikeRequired()) { - os << "[*group->spkQuePtr"; + neuronEnv.getStream() << "[*" << neuronEnv["_spk_que_ptr"]; if(batchSize > 1) { - os << " + (batch * " << ng.getArchetype().getNumDelaySlots() << ")"; + neuronEnv.getStream() << " + (batch * " << ng.getArchetype().getNumDelaySlots() << ")"; } - os << "], shSpkCount);" << std::endl; + neuronEnv.getStream() << "], shSpkCount);" << std::endl; } else { - os << "[" << ((batchSize > 1) ? "batch" : "0") << "], shSpkCount);" << std::endl; + neuronEnv.getStream() << "[" << ((batchSize > 1) ? "batch" : "0") << "], shSpkCount);" << std::endl; } } } - genSharedMemBarrier(os); + genSharedMemBarrier(neuronEnv.getStream()); } const std::string queueOffset = ng.getWriteVarIndex(ng.getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, ""); if(ng.getArchetype().isSpikeEventRequired()) { - os << "if(" << getThreadID() << " < shSpkEvntCount)"; + neuronEnv.getStream() << "if(" << getThreadID() << " < shSpkEvntCount)"; { - CodeStream::Scope b(os); - os << "const unsigned int n = shSpkEvnt[" << getThreadID() << "];" << std::endl; + CodeStream::Scope b(neuronEnv.getStream()); + neuronEnv.getStream() << "const unsigned int n = shSpkEvnt[" << getThreadID() << "];" << std::endl; - os << "group->spkEvnt[" << queueOffset << "shPosSpkEvnt + " << getThreadID() << "] = n;" << std::endl; + neuronEnv.printLine("$(_spk_evnt)[" + queueOffset + "shPosSpkEvnt + " + getThreadID() + "] = n;"); if(ng.getArchetype().isSpikeEventTimeRequired()) { - os << "group->seT[" << queueOffset << "n] = t;" << std::endl; + neuronEnv.printLine("$(_spk_evnt_time)[" + queueOffset + "n] = t;"); } } } @@ -579,34 +576,33 @@ void BackendSIMT::genNeuronUpdateKernel(CodeStream &os, const Substitutions &ker if(!ng.getArchetype().getNeuronModel()->getThresholdConditionCode().empty()) { const std::string queueOffsetTrueSpk = ng.getWriteVarIndex(ng.getArchetype().isTrueSpikeRequired() && ng.getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, ""); - - os << "if(" << getThreadID() << " < shSpkCount)"; + neuronEnv.getStream() << "if(" << getThreadID() << " < shSpkCount)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(neuronEnv.getStream()); - os << "const unsigned int n = shSpk[" << getThreadID() << "];" << std::endl; + neuronEnv.getStream() << "const unsigned int n = shSpk[" << getThreadID() << "];" << std::endl; // Create new substition stack and explicitly replace id with 'n' and perform WU var update - Substitutions wuSubs(&popSubs); - wuSubs.addVarSubstitution("id", "n", true); - ng.generateWUVarUpdate(*this, os, modelMerged, wuSubs); + EnvironmentExternal wuEnv(neuronEnv); + 
wuEnv.add(Type::Uint32.addConst(), "id", "n");
+                    ng.generateWUVarUpdate(*this, wuEnv, modelMerged);
 
-                    os << "group->spk[" << queueOffsetTrueSpk << "shPosSpk + " << getThreadID() << "] = n;" << std::endl;
+                    neuronEnv.printLine("$(_spk)[" + queueOffsetTrueSpk + "shPosSpk + " + getThreadID() + "] = n;");
                     if(ng.getArchetype().isSpikeTimeRequired()) {
-                        os << "group->sT[" << queueOffset << "n] = t;" << std::endl;
+                        neuronEnv.printLine("$(_spk_time)[" + queueOffset + "n] = t;");
                     }
                 }
             }
 
             // If we're recording spikes or spike-like events, use enough threads to copy this block's recording words
             if(ng.getArchetype().isSpikeRecordingEnabled() || ng.getArchetype().isSpikeEventRecordingEnabled()) {
-                os << "if(" << getThreadID() << " < " << m_KernelBlockSizes[KernelNeuronUpdate] / 32 << ")";
+                neuronEnv.getStream() << "if(" << getThreadID() << " < " << m_KernelBlockSizes[KernelNeuronUpdate] / 32 << ")";
                 {
-                    CodeStream::Scope b(os);
+                    CodeStream::Scope b(neuronEnv.getStream());
 
                     // Calculate number of words which will be used to record this population's spikes in each batch
-                    os << "const unsigned int numRecordingWords = (group->numNeurons + 31) / 32;" << std::endl;
-                    os << "const unsigned int popWordIdx = (" << popSubs["id"] << " / 32) + " << getThreadID() << ";" << std::endl;
+                    neuronEnv.printLine("const unsigned int numRecordingWords = ($(num_neurons) + 31) / 32;");
+                    neuronEnv.printLine("const unsigned int popWordIdx = ($(id) / 32) + " + getThreadID() + ";");
 
                     // Build global index
                     std::string globalIndex = "(recordingTimestep * numRecordingWords * " + std::to_string(batchSize) + ") + popWordIdx";
@@ -614,25 +610,25 @@ void BackendSIMT::genNeuronUpdateKernel(CodeStream &os, const Substitutions &ker
                         globalIndex += " + (batch * numRecordingWords)";
                     }
 
-                    os << "if(popWordIdx < numRecordingWords)";
+                    neuronEnv.getStream() << "if(popWordIdx < numRecordingWords)";
                     {
-                        CodeStream::Scope c(os);
+                        CodeStream::Scope c(neuronEnv.getStream());
                         // If we are recording spikes, copy word to correct location in global memory
                         if(ng.getArchetype().isSpikeRecordingEnabled()) {
-                            os << "group->recordSpk[" << globalIndex << "] = shSpkRecord";
+                            neuronEnv.getStream() << neuronEnv["_record_spk"] << "[" << globalIndex << "] = shSpkRecord";
                             if(m_KernelBlockSizes[KernelNeuronUpdate] != 32) {
-                                os << "[" << getThreadID() << "]";
+                                neuronEnv.getStream() << "[" << getThreadID() << "]";
                             }
-                            os << ";" << std::endl;
+                            neuronEnv.getStream() << ";" << std::endl;
                         }
 
                         // If we are recording spike-like events, copy word to correct location in global memory
                         if(ng.getArchetype().isSpikeEventRecordingEnabled()) {
-                            os << "group->recordSpkEvent[" << globalIndex << "] = shSpkEvntRecord";
+                            neuronEnv.getStream() << neuronEnv["_record_spk_evnt"] << "[" << globalIndex << "] = shSpkEvntRecord";
                             if(m_KernelBlockSizes[KernelNeuronUpdate] != 32) {
-                                os << "[" << getThreadID() << "]";
+                                neuronEnv.getStream() << "[" << getThreadID() << "]";
                             }
-                            os << ";" << std::endl;
+                            neuronEnv.getStream() << ";" << std::endl;
                         }
                     }
                 }
@@ -640,33 +636,35 @@ void BackendSIMT::genNeuronUpdateKernel(CodeStream &os, const Substitutions &ker
         });
 }
 //--------------------------------------------------------------------------
-void BackendSIMT::genSynapseDendriticDelayUpdateKernel(CodeStream &os, const ModelSpecMerged &modelMerged, size_t &idStart) const
+void BackendSIMT::genSynapseDendriticDelayUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const
 {
     // Loop through merged synapse groups
     idStart = 0;
     for(const auto &n : 
modelMerged.getMergedSynapseDendriticDelayUpdateGroups()) { - os << "// merged" << n.getIndex() << std::endl; + env.getStream() << "// merged" << n.getIndex() << std::endl; if(idStart == 0) { - os << "if(id < " << n.getGroups().size() << ")"; + env.getStream() << "if(id < " << n.getGroups().size() << ")"; } else { - os << "if(id >= " << idStart << " && id < " << idStart + n.getGroups().size() << ")"; + env.getStream() << "if(id >= " << idStart << " && id < " << idStart + n.getGroups().size() << ")"; } { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); // Use this to get reference to merged group structure - os << getPointerPrefix() << "struct MergedSynapseDendriticDelayUpdateGroup" << n.getIndex() << " *group = &d_mergedSynapseDendriticDelayUpdateGroup" << n.getIndex() << "[id - " << idStart << "]; " << std::endl; + env.getStream() << getPointerPrefix() << "struct MergedSynapseDendriticDelayUpdateGroup" << n.getIndex() << " *group = &d_mergedSynapseDendriticDelayUpdateGroup" << n.getIndex() << "[id - " << idStart << "]; " << std::endl; - os << "*group->denDelayPtr = (*group->denDelayPtr + 1) % " << n.getArchetype().getMaxDendriticDelayTimesteps() << ";" << std::endl; + env.printLine("*$(_den_delay_ptr) = (*$(_den_delay_ptr) + 1) % " + std::to_string(n.getArchetype().getMaxDendriticDelayTimesteps()) + ";"); } idStart += n.getGroups().size(); } - os << std::endl; + env.getStream() << std::endl; } //-------------------------------------------------------------------------- -void BackendSIMT::genPresynapticUpdateKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const { + EnvironmentExternal kernelEnv(env); + // We need shLg if any synapse groups accumulate into shared memory // Determine the maximum shared memory outputs size_t maxSharedMemPerThread = 0; @@ -677,48 +675,32 @@ void BackendSIMT::genPresynapticUpdateKernel(CodeStream &os, const Substitutions // If any shared memory is required, declare array if(maxSharedMemPerThread > 0) { - os << getSharedPrefix() <<" scalar shLg[" << maxSharedMemPerThread * getKernelBlockSize(KernelPresynapticUpdate) << "];" << std::endl; - } - - // If any of these synapse groups also have sparse connectivity, allocate shared memory for row length - if(std::any_of(modelMerged.getMergedPresynapticUpdateGroups().cbegin(), modelMerged.getMergedPresynapticUpdateGroups().cend(), - [](const PresynapticUpdateGroupMerged &sg) - { - return (sg.getArchetype().getSpanType() == SynapseGroup::SpanType::POSTSYNAPTIC - && (sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE)); - })) - { - os << getSharedPrefix() << "unsigned int shRowLength[" << getKernelBlockSize(KernelPresynapticUpdate) << "];" << std::endl; + kernelEnv.getStream() << getSharedPrefix() <<" scalar shLg[" << maxSharedMemPerThread * getKernelBlockSize(KernelPresynapticUpdate) << "];" << std::endl; } - if(std::any_of(modelMerged.getMergedPresynapticUpdateGroups().cbegin(), modelMerged.getMergedPresynapticUpdateGroups().cend(), - [](const PresynapticUpdateGroupMerged &sg) - { - return (sg.getArchetype().isTrueSpikeRequired() || !sg.getArchetype().getWUModel()->getLearnPostCode().empty()); - })) - { - os << getSharedPrefix() << "unsigned int shSpk[" << getKernelBlockSize(KernelPresynapticUpdate) << "];" << std::endl; - } + // Shared memory for row length + 
kernelEnv.add(Type::Uint32.createPointer(), "_sh_row_length", "shRowLength", + {kernelEnv.addInitialiser(getSharedPrefix() + "unsigned int shRowLength[" + std::to_string(getKernelBlockSize(KernelPresynapticUpdate)) + "];")}); - if(std::any_of(modelMerged.getMergedPresynapticUpdateGroups().cbegin(), modelMerged.getMergedPresynapticUpdateGroups().cend(), - [](const PresynapticUpdateGroupMerged &sg) { return (sg.getArchetype().isSpikeEventRequired()); })) - { - os << getSharedPrefix() << "unsigned int shSpkEvnt[" << getKernelBlockSize(KernelPresynapticUpdate) << "];" << std::endl; - } + // Shared memory for spikes and spike events + kernelEnv.add(Type::Uint32.createPointer(), "_sh_spk", "shSpk", + {kernelEnv.addInitialiser(getSharedPrefix() + "unsigned int shSpk[" + std::to_string(getKernelBlockSize(KernelPresynapticUpdate)) + "];")}); + kernelEnv.add(Type::Uint32.createPointer(), "_sh_spk_evnt", "shSpkEvnt", + {kernelEnv.addInitialiser(getSharedPrefix() + "unsigned int shSpkEvnt[" + std::to_string(getKernelBlockSize(KernelPresynapticUpdate)) + "];")}); // Parallelise over synapse groups idStart = 0; genParallelGroup( - os, kernelSubs, modelMerged.getMergedPresynapticUpdateGroups(), idStart, + kernelEnv, modelMerged.getMergedPresynapticUpdateGroups(), idStart, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumPresynapticUpdateThreads(sg, getPreferences()), KernelPresynapticUpdate); }, - [&modelMerged, this](CodeStream &os, const PresynapticUpdateGroupMerged &sg, const Substitutions &popSubs) + [&modelMerged, this](EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg) { // Get presynaptic update strategy to use for this synapse group const auto *presynapticUpdateStrategy = getPresynapticUpdateStrategy(sg.getArchetype()); LOGD_BACKEND << "Using '" << typeid(*presynapticUpdateStrategy).name() << "' presynaptic update strategy for merged synapse group '" << sg.getIndex() << "'"; // Generate index calculation code - genSynapseIndexCalculation(os, sg, modelMerged.getModel().getBatchSize()); + genSynapseIndexCalculation(env, modelMerged.getModel().getBatchSize()); // Generate preamble presynapticUpdateStrategy->genPreamble(os, modelMerged, sg, popSubs, *this); @@ -742,85 +724,87 @@ void BackendSIMT::genPresynapticUpdateKernel(CodeStream &os, const Substitutions }); } //-------------------------------------------------------------------------- -void BackendSIMT::genPostsynapticUpdateKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const { - os << getSharedPrefix() << "unsigned int shSpk[" << getKernelBlockSize(KernelPostsynapticUpdate) << "];" << std::endl; - if(std::any_of(modelMerged.getModel().getSynapseGroups().cbegin(), modelMerged.getModel().getSynapseGroups().cend(), - [](const ModelSpec::SynapseGroupValueType &s) - { - return ((s.second.getMatrixType() & SynapseMatrixConnectivity::SPARSE) && !s.second.getWUModel()->getLearnPostCode().empty()); - })) - { - os << getSharedPrefix() << "unsigned int shColLength[" << getKernelBlockSize(KernelPostsynapticUpdate) << "];" << std::endl; - } + EnvironmentExternal kernelEnv(env); + + // Shared memory for column length and spikes + kernelEnv.add(Type::Uint32.createPointer(), "_sh_colw_length", "shColLength", + {kernelEnv.addInitialiser(getSharedPrefix() + "unsigned int shColLength[" + 
std::to_string(getKernelBlockSize(KernelPostsynapticUpdate)) + "];")}); + kernelEnv.add(Type::Uint32.createPointer(), "_sh_spk", "shSpk", + {kernelEnv.addInitialiser(getSharedPrefix() + "unsigned int shSpk[" + std::to_string(getKernelBlockSize(KernelPostsynapticUpdate)) + "];")}); // Parallelise over postsynaptic update groups idStart = 0; - genParallelGroup(os, kernelSubs, modelMerged.getMergedPostsynapticUpdateGroups(), idStart, + genParallelGroup(kernelEnv, modelMerged.getMergedPostsynapticUpdateGroups(), idStart, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumPostsynapticUpdateThreads(sg), KernelPostsynapticUpdate); }, - [&modelMerged, this](CodeStream &os, const PostsynapticUpdateGroupMerged &sg, Substitutions &popSubs) + [&modelMerged, this](EnvironmentExternalBase &env, PostsynapticUpdateGroupMerged &sg) { + EnvironmentGroupMergedField groupEnv(env); + // Generate index calculation code const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - genSynapseIndexCalculation(os, sg, batchSize); + genSynapseIndexCalculation(groupEnv, batchSize); - os << "const unsigned int numSpikes = group->trgSpkCnt[" << sg.getPostSlot(batchSize) << "];" << std::endl; + groupEnv.printLine("const unsigned int numSpikes = $(_trg_spk_cnt)[" + sg.getPostSlot(batchSize) + "];"); - - os << "const unsigned int numSpikeBlocks = (numSpikes + " << getKernelBlockSize(KernelPostsynapticUpdate) - 1 << ") / " << getKernelBlockSize(KernelPostsynapticUpdate) << ";" << std::endl; - os << "for (unsigned int r = 0; r < numSpikeBlocks; r++)"; + groupEnv.getStream() << "const unsigned int numSpikeBlocks = (numSpikes + " << getKernelBlockSize(KernelPostsynapticUpdate) - 1 << ") / " << getKernelBlockSize(KernelPostsynapticUpdate) << ";" << std::endl; + groupEnv.getStream() << "for (unsigned int r = 0; r < numSpikeBlocks; r++)"; { - CodeStream::Scope b(os); - os << "const unsigned int numSpikesInBlock = (r == numSpikeBlocks - 1) ? ((numSpikes - 1) % " << getKernelBlockSize(KernelPostsynapticUpdate) << ") + 1 : " << getKernelBlockSize(KernelPostsynapticUpdate) << ";" << std::endl; + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "const unsigned int numSpikesInBlock = (r == numSpikeBlocks - 1) ? 
((numSpikes - 1) % " << getKernelBlockSize(KernelPostsynapticUpdate) << ") + 1 : " << getKernelBlockSize(KernelPostsynapticUpdate) << ";" << std::endl; - os << "if (" << getThreadID() << " < numSpikesInBlock)"; + groupEnv.getStream() << "if (" << getThreadID() << " < numSpikesInBlock)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(groupEnv.getStream()); const std::string index = "(r * " + std::to_string(getKernelBlockSize(KernelPostsynapticUpdate)) + ") + " + getThreadID(); - os << "const unsigned int spk = group->trgSpk[" << sg.getPostVarIndex(batchSize, VarAccessDuplication::DUPLICATE, index) << "];" << std::endl; - os << "shSpk[" << getThreadID() << "] = spk;" << std::endl; + groupEnv.printLine("const unsigned int spk = $(_trg_spk)[" + sg.getPostVarIndex(batchSize, VarAccessDuplication::DUPLICATE, index) + "];"); + groupEnv.getStream() << "shSpk[" << getThreadID() << "] = spk;" << std::endl; if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << "shColLength[" << getThreadID() << "] = group->colLength[spk];" << std::endl; + groupEnv.getStream() << "shColLength[" << getThreadID() << "] = group->colLength[spk];" << std::endl; } } - genSharedMemBarrier(os); - os << "// only work on existing neurons" << std::endl; - os << "if (" << popSubs["id"] << " < group->colStride)"; + genSharedMemBarrier(groupEnv.getStream()); + groupEnv.getStream() << "// only work on existing neurons" << std::endl; + groupEnv.print("if ($(id) < $(_col_stride))"); { - CodeStream::Scope b(os); - os << "// loop through all incoming spikes for learning" << std::endl; - os << "for (unsigned int j = 0; j < numSpikesInBlock; j++)"; + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "// loop through all incoming spikes for learning" << std::endl; + groupEnv.getStream() << "for (unsigned int j = 0; j < numSpikesInBlock; j++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(groupEnv.getStream()); - Substitutions synSubs(&popSubs); if (sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << "if (" << synSubs["id"] << " < shColLength[j])" << CodeStream::OB(1540); - os << "const unsigned int synAddress = group->remap[(shSpk[j] * group->colStride) + " << popSubs["id"] << "];" << std::endl; + groupEnv.print("if ($(id) < $(_sh_col_length)[j])"); + groupEnv.getStream() << CodeStream::OB(1540); + } + + EnvironmentGroupMergedField synEnv(groupEnv, sg); + if (sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { + synEnv.add(Type::Uint32.addConst(), "id_syn", "synAddress", + {synEnv.addInitialiser("const unsigned int synAddress = $(_remap)[($(_sh_spk)[j] * $(_col_stride)) + $(id)];")}); // **OPTIMIZE** we can do a fast constant divide optimization here - os << "const unsigned int ipre = synAddress / group->rowStride;" << std::endl; - synSubs.addVarSubstitution("id_pre", "ipre"); + synEnv.add(Type::Uint32.addConst(), "id_pre", "idPre", + {synEnv.addInitialiser("const unsigned int idPre = $(id_syn) / $(_row_stride);")}); } else { - os << "const unsigned int synAddress = (" << synSubs["id"] << " * group->numTrgNeurons) + shSpk[j];" << std::endl; - synSubs.addVarSubstitution("id_pre", synSubs["id"]); + synEnv.add(Type::Uint32.addConst(), "id_syn", "synAddress", + {synEnv.addInitialiser("const unsigned int synAddress = ($(id) * $(num_post)) + $(_sh_spk)[j];")}); + + synEnv.add(Type::Uint32.addConst(), "id_pre", "$(id)"); } - synSubs.addVarSubstitution("id_post", "shSpk[j]"); - synSubs.addVarSubstitution("id_syn", "synAddress"); +
synEnv.add(Type::Uint32.addConst(), "id_post", "$(_sh_spk)[j]"); - if(sg.getArchetype().isPresynapticOutputRequired()) { - synSubs.addFuncSubstitution("addToPre", 1, - getAtomic(modelMerged.getModel().getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); - } - - sg.generateSynapseUpdate(*this, os, modelMerged, synSubs); + synEnv.add(Type::AddToPre, "addToPre", getAtomic(modelMerged.getModel().getPrecision()) + "(&$(_out_pre)[" + sg.getPreISynIndex(batchSize, "$(id_pre)") + "], $(0))"); + + sg.generateSynapseUpdate(*this, synEnv, modelMerged); if (sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << CodeStream::CB(1540); + synEnv.getStream() << CodeStream::CB(1540); } } } @@ -829,68 +813,63 @@ void BackendSIMT::genPostsynapticUpdateKernel(CodeStream &os, const Substitution ); } //-------------------------------------------------------------------------- -void BackendSIMT::genSynapseDynamicsKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genSynapseDynamicsKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const { // Parallelise over synapse groups whose weight update models have code for synapse dynamics idStart = 0; genParallelGroup( - os, kernelSubs, modelMerged.getMergedSynapseDynamicsGroups(), idStart, + env, modelMerged.getMergedSynapseDynamicsGroups(), idStart, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumSynapseDynamicsThreads(sg), KernelSynapseDynamicsUpdate); }, - [&modelMerged, this](CodeStream &os, const SynapseDynamicsGroupMerged &sg, Substitutions &popSubs) + [&modelMerged, this](EnvironmentExternalBase &env, SynapseDynamicsGroupMerged &sg) { + EnvironmentGroupMergedField groupEnv(env); + // Generate index calculation code const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - genSynapseIndexCalculation(os, sg, batchSize); - - Substitutions synSubs(&popSubs); + genSynapseIndexCalculation(groupEnv, batchSize); if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << "if (" << popSubs["id"] << " < (group->numSrcNeurons * group->rowStride))"; + groupEnv.print("if ($(id) < ($(num_pre) * $(_row_stride)))"); } else { - os << "if (" << popSubs["id"] << " < (group->numSrcNeurons * group->numTrgNeurons))"; + groupEnv.print("if ($(id( < ($(num_pre) * $(num_post)))"); } { - CodeStream::Scope b(os); + CodeStream::Scope b(groupEnv.getStream()); + EnvironmentGroupMergedField synEnv(groupEnv, sg); if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // **OPTIMIZE * *we can do a fast constant divide optimization here and use the result to calculate the remainder - os << "const unsigned int row = " << popSubs["id"] << " / group->rowStride;" << std::endl; - os << "const unsigned int col = " << popSubs["id"] << " % group->rowStride;" << std::endl; + synEnv.printLine("const unsigned int row = $(id) / $(_row_stride);"); + synEnv.printLine("const unsigned int col = $(id) % $(_row_stride);"); - synSubs.addVarSubstitution("id_pre", "row"); - synSubs.addVarSubstitution("id_post", "group->ind[" + popSubs["id"] + "]"); - synSubs.addVarSubstitution("id_syn", popSubs["id"]); + synEnv.add(Type::Uint32.addConst(), "id_pre", "row"); + synEnv.add(Type::Uint32.addConst(), "id_post", "$(_ind)[$(id)]"); - os << "if(col < group->rowLength[row])"; - os << CodeStream::OB(1); + synEnv.getStream() << "if(col < " << 
synEnv["_row_length"] << "[row])"; + synEnv.getStream() << CodeStream::OB(1); } else { // **OPTIMIZE** we can do a fast constant divide optimization here and use the result to calculate the remainder - synSubs.addVarSubstitution("id_pre", "(" + popSubs["id"] + " / group->rowStride)"); - synSubs.addVarSubstitution("id_post", "(" + popSubs["id"] + " % group->rowStride)"); - synSubs.addVarSubstitution("id_syn", popSubs["id"]); + synEnv.add(Type::Uint32.addConst(), "id_pre", "idPre", + {synEnv.addInitialiser("const unsigned int idPre = ($(id) / $(_row_stride))")}); + synEnv.add(Type::Uint32.addConst(), "id_post", "idPost", + {synEnv.addInitialiser("const unsigned int idPost = ($(id) % $(_row_stride)")}); } - // If dendritic delay is required, always use atomic operation to update dendritic delay buffer - // **TODO** once synapse dynamics gets refactored into update strategy classes, move the index building code elsewhere - if(sg.getArchetype().isDendriticDelayRequired()) { - synSubs.addFuncSubstitution("addToInSynDelay", 2, getAtomic(modelMerged.getModel().getPrecision()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, synSubs["id_post"], "$(1)") + "], $(0))"); - } - // Otherwise - else { - synSubs.addFuncSubstitution("addToInSyn", 1, getAtomic(modelMerged.getModel().getPrecision()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, synSubs["id_post"]) + "], $(0))"); - } + synEnv.add(Type::Uint32.addConst(), "id_syn", "$(id)"]); + + synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", + getAtomic(modelMerged.getModel().getPrecision()) + "(&$(_den_delay)[" + sg.getPostDenDelayIndex(batchSize, "$(id_post)", "$(1)") + "], $(0))"); + synEnv.add(Type::AddToPost, "addToPost", + getAtomic(modelMerged.getModel().getPrecision()) + "(&$(_out_post)[" + sg.getPostISynIndex(batchSize, "$(id_post)") + "], $(0))"); + synEnv.add(Type::AddToPre, "addToPre", + getAtomic(modelMerged.getModel().getPrecision()) + "(&$(_out_pre)[" + sg.getPreISynIndex(batchSize, "$(id_pre)") + "], $(0))"); - if(sg.getArchetype().isPresynapticOutputRequired()) { - synSubs.addFuncSubstitution("addToPre", 1, - getAtomic(modelMerged.getModel().getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); - } - - sg.generateSynapseUpdate(*this, os, modelMerged, synSubs); + sg.generateSynapseUpdate(*this, synEnv, modelMerged); if (sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << CodeStream::CB(1); + synEnv.getStream() << CodeStream::CB(1); } } }); @@ -906,81 +885,81 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, const ModelSpe return getPaddedNumCustomUpdateThreads(cu, modelMerged.getModel().getBatchSize()); }, [&updateGroup](const CustomUpdateGroupMerged &cg) { return (cg.getArchetype().getUpdateGroupName() == updateGroup); }, - [&modelMerged, this](EnvironmentExternal &env, const CustomUpdateGroupMerged &cg) + [&modelMerged, this](EnvironmentExternalBase &env, CustomUpdateGroupMerged &cg) { const size_t blockSize = getKernelBlockSize(KernelCustomUpdate); const unsigned int batchSize = modelMerged.getModel().getBatchSize(); // If update is a batch reduction - EnvironmentSubstitute cuEnv(env); if(cg.getArchetype().isBatchReduction()) { - cuEnv.getStream() << "// only do this for existing neurons" << std::endl; - cuEnv.getStream() << "if(" << cuEnv.getName("id") << " < group->size)"; + env.getStream() << "// only do this for existing neurons" << std::endl; + env.getStream() << "if(" << env["id"] << " < group->size)"; { - 
CodeStream::Scope b(cuEnv.getStream()); + CodeStream::Scope b(env.getStream()); + EnvironmentGroupMergedField groupEnv(env); // Initialise reduction targets - const auto reductionTargets = genInitReductionTargets(cuEnv.getStream(), cg, cuEnv.getName("id")); + const auto reductionTargets = genInitReductionTargets(groupEnv.getStream(), cg, groupEnv["id"]); // Loop through batches // **TODO** this naive approach is good for reduction when there are lots of neurons/synapses but, // if this isn't the case (TF uses a threshold of 4096), we should do something smarter - cuEnv.getStream() << "for(unsigned int batch = 0; batch < " << batchSize << "; batch++)"; + groupEnv.getStream() << "for(unsigned int batch = 0; batch < " << batchSize << "; batch++)"; { - CodeStream::Scope b(cuEnv.getStream()); - cuEnv.addSubstitution("batch", "batch"); + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.add(Type::Uint32.addConst(), "batch", "batch"); - genCustomUpdateIndexCalculation(cuEnv.getStream(), cg); + genCustomUpdateIndexCalculation(groupEnv); // **THINK** it would be great to 'lift' reads of SHARED variables out of this loop - cg.generateCustomUpdate(*this, cuEnv); + cg.generateCustomUpdate(*this, groupEnv); // Loop through reduction targets and generate reduction for(const auto &r : reductionTargets) { - cuEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; + groupEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; } } // Loop through reduction targets and write reduced value back to memory for(const auto &r : reductionTargets) { - cuEnv.getStream() << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; + groupEnv.getStream() << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; } } } // Otherwise, if this is a neuron reduction else if (cg.getArchetype().isNeuronReduction()) { - cuEnv.getStream() << "// only do this for existing neurons" << std::endl; - cuEnv.getStream() << "if(" << cuEnv.getName("id") << " < " << (32 * modelMerged.getModel().getBatchSize()) << ")"; + env.getStream() << "// only do this for existing neurons" << std::endl; + env.getStream() << "if(" << env["id"] << " < " << (32 * modelMerged.getModel().getBatchSize()) << ")"; { - CodeStream::Scope b(cuEnv.getStream()); + CodeStream::Scope b(env.getStream()); + EnvironmentGroupMergedField groupEnv(env); // Split ID into lane and batch - cuEnv.getStream() << "const unsigned int lane = " << cuEnv.getName("id") << " % 32;" << std::endl; - cuEnv.getStream() << "const unsigned int batch = " << cuEnv.getName("id") << " / 32;" << std::endl; - cuEnv.addSubstitution("batch", "batch"); + groupEnv.getStream() << "const unsigned int lane = " << env["id"] << " % 32;" << std::endl; + groupEnv.getStream() << "const unsigned int batch = " << env["id"] << " / 32;" << std::endl; + groupEnv.add(Type::Uint32.addConst(), "batch", "batch"); - genCustomUpdateIndexCalculation(cuEnv.getStream(), cg); + genCustomUpdateIndexCalculation(groupEnv); // Initialise reduction targets - const auto reductionTargets = genInitReductionTargets(cuEnv.getStream(), cg); + const auto reductionTargets = genInitReductionTargets(groupEnv.getStream(), cg); // Loop through warps of data // **TODO** this approach is good for reductions where there are small numbers of neurons but large batches sizes but, // if this isn't the case (TF uses a threshold of 1024), we should do something smarter - cuEnv.getStream() 
<< "for(unsigned int idx = lane; idx < group->size; idx += 32)"; + groupEnv.getStream() << "for(unsigned int idx = lane; idx < " << groupEnv["size"] << "; idx += 32)"; { - CodeStream::Scope b(cuEnv.getStream()); + CodeStream::Scope b(groupEnv.getStream()); // Re-substitute id with loop index - EnvironmentSubstitute reductionEnv(cuEnv); - reductionEnv.addSubstitution("id", "idx"); + groupEnv.add(Type::Uint32.addConst(), "id", "idx"); // **THINK** it would be great to 'lift' reads of NEURON_SHARED variables out of this loop - cg.generateCustomUpdate(*this, reductionEnv); + cg.generateCustomUpdate(*this, groupEnv); // Loop through reduction targets and generate reduction for (const auto &r : reductionTargets) { - reductionEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; + groupEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; } } @@ -988,50 +967,51 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, const ModelSpe // **YUCK** CUDA-specific for (unsigned int i = 16; i > 0; i /= 2) { for (const auto &r : reductionTargets) { - cuEnv.getStream() << getReductionOperation("lr" + r.name, "__shfl_down_sync(0xFFFFFFFF, lr" + r.name + ", " + std::to_string(i) + ")", - r.access, r.type) << ";" << std::endl; + groupEnv.getStream() << getReductionOperation("lr" + r.name, "__shfl_down_sync(0xFFFFFFFF, lr" + r.name + ", " + std::to_string(i) + ")", + r.access, r.type) << ";" << std::endl; } } // In first lane, loop through reduction targets and write reduced value back to memory - cuEnv.getStream() << "if(lane == 0)"; + groupEnv.getStream() << "if(lane == 0)"; { - CodeStream::Scope b(cuEnv.getStream()); + CodeStream::Scope b(groupEnv.getStream()); for (const auto &r : reductionTargets) { - cuEnv.getStream() << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; + groupEnv.getStream() << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; } } } } // Otherwise else { + EnvironmentGroupMergedField groupEnv(env); + if(cg.getArchetype().isBatched()) { // Split ID into intra-batch ID and batch // **TODO** fast-divide style optimisations here - cuEnv.getStream() << "const unsigned int paddedSize = " << blockSize << " * ((group->size + " << blockSize << " - 1) / " << blockSize << ");" << std::endl; - cuEnv.getStream() << "const unsigned int bid = " << cuEnv.getName("id") << " % paddedSize;" << std::endl; - cuEnv.getStream() << "const unsigned int batch = " << cuEnv.getName("id") << " / paddedSize;" << std::endl; - + const std::string blockSizeStr = std::to_string(blockSize); + const size_t paddedSizeInit = groupEnv.addInitialiser("const unsigned int paddedSize = " + blockSizeStr + " * (($(size) + " + blockSizeStr + " - 1) / " + blockSizeStr + ");" << std::endl; + // Replace id in substitution with intra-batch ID and add batch - cuEnv.addSubstitution("id", "bid"); - cuEnv.addSubstitution("batch", "batch"); + groupEnv.add(Type::Uint32.addConst(), "id", "bid", + {paddedSizeInit, groupEnv.addInitialiser("const unsigned int bid = $(id) % paddedSize;")}); + groupEnv.add(Type::Uint32.addConst(), "batch", "batch", + {paddedSizeInit, groupEnv.addInitialiser("const unsigned int batch = $(id) / paddedSize;")}); } // Otherwise, just substitute "batch" for 0 else { - cuEnv.addSubstitution("batch", "0"); + groupEnv.add(Type::Uint32.addConst(), "batch", "0"); } - cuEnv.getStream() << "// only do this for existing neurons" << std::endl; - 
cuEnv.getStream() << "if(" << cuEnv.getName("id") << " < group->size)"; + groupEnv.getStream() << "// only do this for existing neurons" << std::endl; + groupEnv.print("if($(id) < $(size))"); { - CodeStream::Scope b(cuEnv.getStream()); + CodeStream::Scope b(groupEnv.getStream()); - genCustomUpdateIndexCalculation(cuEnv.getStream(), cg); - cg.generateCustomUpdate(*this, cuEnv); + genCustomUpdateIndexCalculation(groupEnv); + cg.generateCustomUpdate(*this, groupEnv); } } - - }); } //-------------------------------------------------------------------------- @@ -1045,7 +1025,7 @@ void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, const ModelS return getPaddedNumCustomUpdateWUThreads(cg, modelMerged.getModel().getBatchSize()); }, [&updateGroup](const CustomUpdateWUGroupMerged &cg) { return (cg.getArchetype().getUpdateGroupName() == updateGroup); }, - [&modelMerged, this](EnvironmentExternal &env, const CustomUpdateWUGroupMerged &cg) + [&modelMerged, this](EnvironmentExternalBase &env, CustomUpdateWUGroupMerged &cg) { const SynapseGroupInternal *sg = cg.getArchetype().getSynapseGroup(); const size_t blockSize = getKernelBlockSize(KernelCustomUpdate); @@ -1056,7 +1036,7 @@ void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, const ModelS // Loop through kernel dimensions and multiply together env.getStream() << "const unsigned int size = "; for (size_t i = 0; i < sg->getKernelSize().size(); i++) { - env.getStream() << cg.getKernelSize(i); + env.print(getKernelSize(cg, i)); if (i != (sg->getKernelSize().size() - 1)) { env.getStream() << " * "; } @@ -1064,102 +1044,99 @@ void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, const ModelS env.getStream() << ";" << std::endl; } else { - env.getStream() << "const unsigned int size = group->numSrcNeurons * group->rowStride;" << std::endl; + env.printLine("const unsigned int size = $(num_pre) * $(_row_stride);"); } // If update isn't a batch reduction - EnvironmentSubstitute cuEnv(env); + EnvironmentGroupMergedField groupEnv(env, cg); if(!cg.getArchetype().isBatchReduction()) { // If it's batched if(cg.getArchetype().isBatched()) { - cuEnv.getStream() << "const unsigned int paddedSize = " << blockSize << " * ((size + " << blockSize << " - 1) / " << blockSize << ");" << std::endl; - - // Split ID into intra-batch ID and batch + // Split ID into intra-batch ID and batch // **TODO** fast-divide style optimisations here - cuEnv.getStream() << "const unsigned int bid = " << cuEnv.getName("id") << " % paddedSize;" << std::endl; - cuEnv.getStream() << "const unsigned int batch = " << cuEnv.getName("id") << " / paddedSize;" << std::endl; - + const std::string blockSizeStr = std::to_string(blockSize); + const size_t paddedSizeInit = groupEnv.addInitialiser("const unsigned int paddedSize = " + blockSizeStr + " * ((size + " + blockSizeStr + " - 1) / " + blockSizeStr + ");" << std::endl; + // Replace id in substitution with intra-batch ID and add batch - cuEnv.addSubstitution("id", "bid"); - cuEnv.addSubstitution("batch", "batch"); - - // Calculate batch offset - cuEnv.getStream() << "const unsigned int batchOffset = size * batch;" << std::endl; + groupEnv.add(Type::Uint32.addConst(), "id", "bid", + {paddedSizeInit, groupEnv.addInitialiser("const unsigned int bid = $(id) % paddedSize;")}); + groupEnv.add(Type::Uint32.addConst(), "batch", "batch", + {paddedSizeInit, groupEnv.addInitialiser("const unsigned int batch = $(id) / paddedSize;")}); + groupEnv.add(Type::Uint32.addConst(), "_batch_offset", "batchOffset", + 
{groupEnv.addInitialiser("const unsigned int batchOffset = size * $(batch);")}); } // Otherwise, just substitute "batch" for 0 else { - cuEnv.addSubstitution("batch", "0"); + groupEnv.add(Type::Uint32.addConst(), "batch", "0"); } } // if this isn't a padding thread - cuEnv.getStream() << "if (" << cuEnv.getName("id") << " < size)"; + groupEnv.getStream() << "if (" << groupEnv["id"] << " < size)"; { - CodeStream::Scope b(cuEnv.getStream()); + CodeStream::Scope b(groupEnv.getStream()); + EnvironmentGroupMergedField synEnv(groupEnv, cg); if (sg->getMatrixType() & SynapseMatrixWeight::KERNEL) { - cuEnv.addSubstitution("id_syn", cuEnv.getName("id")); - cuEnv.addSubstitution("id_kernel", cuEnv.getName("id")); + synEnv.add(Type::Uint32.addConst(), "id_syn", "$(id)"); + synEnv.add(Type::Uint32.addConst(), "id_kernel", "$(id)"); } else { if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // **OPTIMIZE * *we can do a fast constant divide optimization here and use the result to calculate the remainder - cuEnv.getStream() << "const unsigned int row = " << cuEnv.getName("id") << " / group->rowStride;" << std::endl; - cuEnv.getStream() << "const unsigned int col = " << cuEnv.getName("id") << " % group->rowStride;" << std::endl; - - cuEnv.addSubstitution("id_pre", "row"); - cuEnv.addSubstitution("id_post", "group->ind[" + cuEnv.getName("id") + "]"); - cuEnv.addSubstitution("id_syn", cuEnv.getName("id")); + synEnv.printLine("const unsigned int row = $(id) / $(_row_stride);"); + synEnv.printLine("const unsigned int col = $(id) % $(_row_stride);"); - cuEnv.getStream() << "if(col < group->rowLength[row])"; - cuEnv.getStream() << CodeStream::OB(2); + synEnv.add(Type::Uint32.addConst(), "id_pre", "row"); + synEnv.add(Type::Uint32.addConst(), "id_post", "$(_ind)[$(id)]"); + + synEnv.print("if(col < $(_row_length)[row])"); + synEnv.getStream() << CodeStream::OB(2); } else { // **OPTIMIZE** we can do a fast constant divide optimization here and use the result to calculate the remainder - cuEnv.addSubstitution("id_pre", "(" + cuEnv.getName("id") + " / group->rowStride)"); - cuEnv.addSubstitution("id_post", "(" +cuEnv.getName("id") + " % group->rowStride)"); - cuEnv.addSubstitution("id_syn", cuEnv.getName("id")); + synEnv.add(Type::Uint32.addConst(), "id_pre", "idPre", + {synEnv.addInitialiser("const unsigned int idPre = $(id) / $(_row_stride)")}); + synEnv.add(Type::Uint32.addConst(), "id_post", "idPost", + {synEnv.addInitialiser("const unsigned int idPost = $(id) % $(_row_stride)")}); } } + synEnv.add(Type::Uint32.addConst(), "id_syn", "$(id)"); + // Initialise reduction targets - const auto reductionTargets = genInitReductionTargets(cuEnv.getStream(), cg, cuEnv.getName("id_syn")); + const auto reductionTargets = genInitReductionTargets(synEnv.getStream(), cg, synEnv["id_syn"])); // If this is a reduction if(cg.getArchetype().isBatchReduction()) { // Loop through batches // **TODO** this naive approach is good for reduction when there are lots of neurons/synapses but, // if this isn't the case (TF uses a threshold of 4096), we should do something smarter - cuEnv.getStream() << "for(unsigned int batch = 0; batch < " << batchSize << "; batch++)"; - cuEnv.getStream() << CodeStream::OB(1); - cuEnv.addSubstitution("batch", "batch"); + synEnv.getStream() << "for(unsigned int batch = 0; batch < " << batchSize << "; batch++)"; + synEnv.getStream() << CodeStream::OB(1); + synEnv.add(Type::Uint32.addConst(), "batch", "batch"); } - // Calculate batch offset if required - if(cg.getArchetype().isBatched()) { - 
cuEnv.getStream() << "const unsigned int batchOffset = size * batch;" << std::endl; - } - - cg.generateCustomUpdate(*this, cuEnv); + cg.generateCustomUpdate(*this, synEnv); // If this is a reduction if(cg.getArchetype().isBatchReduction()) { // Loop through reduction targets and generate reduction for(const auto &r : reductionTargets) { - cuEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; + synEnv.getStream() << getReductionOperation("lr" + r.name, "l" + r.name, r.access, r.type) << ";" << std::endl; } // End for loop through batches - cuEnv.getStream() << CodeStream::CB(1); + synEnv.getStream() << CodeStream::CB(1); // Loop through reduction targets and write reduced value back to memory for(const auto &r : reductionTargets) { - cuEnv.getStream() << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; + synEnv.getStream() << "group->" << r.name << "[" << r.index << "] = lr" << r.name << ";" << std::endl; } } if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - cuEnv.getStream() << CodeStream::CB(2); + synEnv.getStream() << CodeStream::CB(2); } } }); @@ -1179,8 +1156,10 @@ void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, con return getPaddedNumCustomUpdateTransposeWUThreads(cg, modelMerged.getModel().getBatchSize()); }, [&updateGroup](const CustomUpdateTransposeWUGroupMerged &cg) { return (cg.getArchetype().getUpdateGroupName() == updateGroup); }, - [&modelMerged, this, blockSize](EnvironmentExternal &env, const CustomUpdateTransposeWUGroupMerged &cg) + [&modelMerged, this, blockSize](EnvironmentExternalBase &env, CustomUpdateTransposeWUGroupMerged &cg) { + EnvironmentGroupMergedField groupEnv(env, cg); + // Get index of variable being transposed const size_t transposeVarIdx = std::distance(cg.getArchetype().getVarReferences().cbegin(), std::find_if(cg.getArchetype().getVarReferences().cbegin(), cg.getArchetype().getVarReferences().cend(), @@ -1188,64 +1167,63 @@ void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, con const std::string transposeVarName = cg.getArchetype().getCustomUpdateModel()->getVarRefs().at(transposeVarIdx).name; // To allow these kernels to be batched, we turn 2D grid into wide 1D grid of 2D block so calculate size - env.getStream() << "const unsigned int numXBlocks = (group->numTrgNeurons + " << (blockSize - 1) << ") / " << blockSize << ";" << std::endl; + groupEnv.getStream() << "const unsigned int numXBlocks = (" << groupEnv["num_post"] << " + " << (blockSize - 1) << ") / " << blockSize << ";" << std::endl; // Calculate what block this kernel starts at (because of kernel merging, it may not start at block 0) - env.getStream() << "const unsigned int blockStart = " << env.getName("group_start_id") << " / " << blockSize << ";" << std::endl; + groupEnv.getStream() << "const unsigned int blockStart = " << groupEnv["_group_start_id"] << " / " << blockSize << ";" << std::endl; - EnvironmentSubstitute synEnv(env); if(cg.getArchetype().isBatched()) { // If there's multiple batches we also need to know how many Y blocks and hence total blocks there are - synEnv.getStream() << "const unsigned int numYBlocks = (group->numSrcNeurons + " << (blockSize - 1) << ") / " << blockSize << ";" << std::endl; - synEnv.getStream() << "const unsigned int numBlocks = numXBlocks * numYBlocks;" << std::endl; + groupEnv.getStream() << "const unsigned int numYBlocks = (" << groupEnv["num_pre"] << " + " << (blockSize - 1) << ") / " << blockSize 
<< ";" << std::endl; + groupEnv.getStream() << "const unsigned int numBlocks = numXBlocks * numYBlocks;" << std::endl; // Therefore determine block and batch - synEnv.getStream() << "const unsigned int batchBlock = " << getBlockID(0) << " - blockStart;" << std::endl; - synEnv.getStream() << "const unsigned int block = batchBlock % numBlocks;" << std::endl; - synEnv.getStream() << "const unsigned int batch = batchBlock / numBlocks;" << std::endl; + groupEnv.getStream() << "const unsigned int batchBlock = " << getBlockID(0) << " - blockStart;" << std::endl; + groupEnv.getStream() << "const unsigned int block = batchBlock % numBlocks;" << std::endl; + groupEnv.getStream() << "const unsigned int batch = batchBlock / numBlocks;" << std::endl; // Finally, calculate batch offset into arrays etc - synEnv.getStream() << "const unsigned int batchOffset = batch * group->numSrcNeurons * group->numTrgNeurons;" << std::endl; + groupEnv.printLine("const unsigned int batchOffset = batch * $(num_pre) * $(num_post);"); // Add batch to substitutions - synEnv.addSubstitution("batch", "batch"); + groupEnv.add(Type::Uint32.addConst(), "batch", "batch"); } // Otherwise, just substitute "batch" for 0 else { - synEnv.getStream() << "const unsigned int block = " << getBlockID(0) << " - blockStart;" << std::endl; - synEnv.addSubstitution("batch", "0"); + groupEnv.getStream() << "const unsigned int block = " << getBlockID(0) << " - blockStart;" << std::endl; + groupEnv.add(Type::Uint32.addConst(), "batch", "0"); } // Divide block index into x and y // **TODO** fast-divide style optimisations here - synEnv.getStream() << "const unsigned int blockX = (block % numXBlocks);" << std::endl; - synEnv.getStream() << "const unsigned int blockY = (block / numXBlocks);" << std::endl; + groupEnv.getStream() << "const unsigned int blockX = (block % numXBlocks);" << std::endl; + groupEnv.getStream() << "const unsigned int blockY = (block / numXBlocks);" << std::endl; { - CodeStream::Scope b(synEnv.getStream()); - synEnv.getStream() << "// Calculate coordinate of thread in input matrix" << std::endl; - synEnv.getStream() << "const unsigned int x = (blockX * " << blockSize << ") + " << getThreadID(0) << ";" << std::endl; - synEnv.getStream() << "const unsigned int y = (blockY * " << blockSize << ") + " << getThreadID(1) << ";" << std::endl; + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "// Calculate coordinate of thread in input matrix" << std::endl; + groupEnv.getStream() << "const unsigned int x = (blockX * " << blockSize << ") + " << getThreadID(0) << ";" << std::endl; + groupEnv.getStream() << "const unsigned int y = (blockY * " << blockSize << ") + " << getThreadID(1) << ";" << std::endl; - synEnv.getStream() << "// If thread isn't off the 'right' edge of the input matrix" << std::endl; - synEnv.getStream() << "if(x < group->numTrgNeurons)"; + groupEnv.getStream() << "// If thread isn't off the 'right' edge of the input matrix" << std::endl; + groupEnv.getStream() << "if(x < " << groupEnv["num_post"] << ")"; { - CodeStream::Scope b(synEnv.getStream()); - synEnv.getStream() << "// Loop through input rows " << std::endl; - synEnv.getStream() << "for (unsigned int j = 0; j < " << blockSize << "; j += 8)"; + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "// Loop through input rows " << std::endl; + groupEnv.getStream() << "for (unsigned int j = 0; j < " << blockSize << "; j += 8)"; { - CodeStream::Scope b(synEnv.getStream()); - synEnv.getStream() << "// If thread isn't off the 
'bottom' edge of the input matrix" << std::endl; - synEnv.getStream() << "if((y + j) < group->numSrcNeurons)"; + groupEnv.getStream() << "// If thread isn't off the 'bottom' edge of the input matrix" << std::endl; + groupEnv.getStream() << "if((y + j) < " << groupEnv["num_pre"] << ")"; { - CodeStream::Scope b(synEnv.getStream()); - synEnv.getStream() << "// Read forward weight from global memory" << std::endl; - synEnv.getStream() << "const unsigned int idx = ((y + j) * group->numTrgNeurons) + x;" << std::endl; + CodeStream::Scope b(groupEnv.getStream()); + EnvironmentGroupMergedField synEnv(groupEnv, cg); - synEnv.addSubstitution("id_pre", "y"); - synEnv.addSubstitution("id_post", "x"); - synEnv.addSubstitution("id_syn", "idx"); - cg.generateCustomUpdate(*this, env); + synEnv.add(Type::Uint32.addConst(), "id_pre", "y"); + synEnv.add(Type::Uint32.addConst(), "id_post", "x"); + synEnv.add(Type::Uint32.addConst(), "id_syn", "idx", + {synEnv.addInitialiser("const unsigned int idx = ((y + j) * $(num_post)) + x;")}); + cg.generateCustomUpdate(*this, synEnv); // Write forward weight to shared memory synEnv.getStream() << "shTile[" << getThreadID(1) << " + j][" << getThreadID(0) << "] = l" << transposeVarName << ";" << std::endl; @@ -1255,28 +1233,28 @@ void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, con } genSharedMemBarrier(env.getStream()); { - CodeStream::Scope b(synEnv.getStream()); - synEnv.getStream() << "// Calculate (transposed) coordinate of thread in output matrix" << std::endl; - synEnv.getStream() << "const unsigned int x = (blockY * " << blockSize << ") + " << getThreadID(0) << ";" << std::endl; - synEnv.getStream() << "const unsigned int y = (blockX * " << blockSize << ") + " << getThreadID(1) << ";" << std::endl; + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "// Calculate (transposed) coordinate of thread in output matrix" << std::endl; + groupEnv.getStream() << "const unsigned int x = (blockY * " << blockSize << ") + " << getThreadID(0) << ";" << std::endl; + groupEnv.getStream() << "const unsigned int y = (blockX * " << blockSize << ") + " << getThreadID(1) << ";" << std::endl; - synEnv.getStream() << "// If thread isn't off the 'right' edge of the output matrix" << std::endl; - synEnv.getStream() << "if(x < group->numSrcNeurons)"; + groupEnv.getStream() << "// If thread isn't off the 'bottom' edge of the output matrix" << std::endl; + groupEnv.getStream() << "if(x < " << groupEnv["num_pre"] << ")"; { CodeStream::Scope b(synEnv.getStream()); synEnv.getStream() << "// Loop through output rows" << std::endl; synEnv.getStream() << "for(unsigned int j = 0; j < " << blockSize << "; j += 8)"; { CodeStream::Scope b(synEnv.getStream()); - synEnv.getStream() << "// If thread isn't off the 'bottom' edge of the output matrix" << std::endl; - synEnv.getStream() << "if((y + j) < group->numTrgNeurons)"; + synEnv.getStream() << "// If thread isn't off the 'right' edge of the output matrix" << std::endl; + synEnv.getStream() << "if((y + j) < " << groupEnv["num_post"] << ")"; { CodeStream::Scope b(synEnv.getStream()); synEnv.getStream() << "group->" << transposeVarName << "Transpose["; if(cg.getArchetype().isBatched()) { synEnv.getStream() << "batchOffset + "; } - synEnv.getStream() << "((y + j) * group->numSrcNeurons) + x] = shTile[" << getThreadID(0) << "][" << getThreadID(1) << " + j];" << std::endl; + synEnv.getStream() << "((y + j) * " << groupEnv["num_pre"] << ") + x] = shTile[" <<
getThreadID(0) << "][" << getThreadID(1) << " + j];" << std::endl; } } } @@ -1284,35 +1262,37 @@ void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, con }); } //-------------------------------------------------------------------------- -void BackendSIMT::genCustomConnectivityUpdateKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, +void BackendSIMT::genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const { // Parallelise across presynaptic neurons genParallelGroup( - os, kernelSubs, modelMerged.getMergedCustomConnectivityUpdateGroups(), idStart, + env, modelMerged.getMergedCustomConnectivityUpdateGroups(), idStart, [this](const CustomConnectivityUpdateInternal &cg) { return padSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelCustomUpdate); }, [&updateGroup](const CustomConnectivityUpdateGroupMerged &cg) { return (cg.getArchetype().getUpdateGroupName() == updateGroup); }, - [&modelMerged, this](CodeStream &os, const CustomConnectivityUpdateGroupMerged &cg, Substitutions &popSubs) + [&modelMerged, this](EnvironmentExternalBase &env, const CustomConnectivityUpdateGroupMerged &cg) { - os << "// only do this for existing presynaptic neurons" << std::endl; - os << "if(" << popSubs["id"] << " < group->numSrcNeurons)"; - { - CodeStream::Scope b(os); + EnvironmentGroupMergedField groupEnv(env, cg); + + genCustomConnectivityUpdateIndexCalculation(groupEnv); - genCustomConnectivityUpdateIndexCalculation(os, cg); + groupEnv.getStream() << "// only do this for existing presynaptic neurons" << std::endl; + groupEnv.print("if($(id) < $(num_pre))"); + { + CodeStream::Scope b(groupEnv.getStream()); // Configure substitutions - popSubs.addVarSubstitution("id_pre", popSubs["id"]); + groupEnv.add(Type::Uint32.addConst(), "id_pre", "$(id)"); // Copy global RNG stream to local and use pointer to this for rng if(cg.getArchetype().isRowSimRNGRequired()) { genPopulationRNGPreamble(os, popSubs, "group->rng[" + popSubs["id"] + "]"); } - cg.generateUpdate(*this, os, modelMerged.getModel().getBatchSize(), popSubs); + cg.generateUpdate(*this, groupEnv, modelMerged); // Copy local stream back to local if(cg.getArchetype().isRowSimRNGRequired()) { @@ -1322,20 +1302,20 @@ void BackendSIMT::genCustomConnectivityUpdateKernel(CodeStream &os, const Substi }); } //-------------------------------------------------------------------------- -void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const { os << "// ------------------------------------------------------------------------" << std::endl; os << "// Local neuron groups" << std::endl; idStart = 0; genParallelGroup( - os, kernelSubs, modelMerged.getMergedNeuronInitGroups(), idStart, + env, modelMerged.getMergedNeuronInitGroups(), idStart, [this](const NeuronGroupInternal &ng) { return padKernelSize(ng.getNumNeurons(), KernelInitialize); }, - [&modelMerged, this](CodeStream &os, const NeuronInitGroupMerged &ng, Substitutions &popSubs) + [&modelMerged, this](EnvironmentExternalBase &env, NeuronInitGroupMerged &ng) { - os << "// only do this for existing neurons" << std::endl; - os << "if(" << popSubs["id"] << " < group->numNeurons)"; + env.getStream() << "// only do this for 
existing neurons" << std::endl; + env.print("if($(id) < $(num_neurons))"); { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); // If population RNGs are initialised on device and this neuron is going to require one, if(isPopulationRNGInitialisedOnDevice() && ng.getArchetype().isSimRNGRequired()) { @@ -1345,9 +1325,9 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne } // Otherwise, loop through batches and initialise independent RNGs using GLOBAL thread id as basis of sequence else { - os << "for(unsigned int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; + env.getStream() << "for(unsigned int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); genPopulationRNGInit(os, "group->rng[(b * group->numNeurons) + " + popSubs["id"] + "]", "deviceRNGSeed", "(b * " + std::to_string(getNumInitialisationRNGStreams(modelMerged)) + ") + id"); } @@ -1362,7 +1342,7 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne genGlobalRNGSkipAhead(os, popSubs, "id"); } - ng.generateInit(*this, os, modelMerged, popSubs); + ng.generateInit(*this, env, modelMerged); } }); os << std::endl; @@ -1370,26 +1350,28 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne os << "// ------------------------------------------------------------------------" << std::endl; os << "// Synapse groups" << std::endl; genParallelGroup( - os, kernelSubs, modelMerged.getMergedSynapseInitGroups(), idStart, + env, modelMerged.getMergedSynapseInitGroups(), idStart, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumInitThreads(sg), KernelInitialize); }, - [&modelMerged, this](CodeStream &os, const SynapseInitGroupMerged &sg, Substitutions &popSubs) + [&modelMerged, this](EnvironmentExternalBase &env, SynapseInitGroupMerged &sg) { - genSynapseVarInit(os, modelMerged, sg, popSubs, sg.getArchetype().isWUInitRNGRequired(), - (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL), sg.getArchetype().getKernelSize().size()); + EnvironmentGroupMergedField groupEnv(env, sg); + genSynapseVarInit(groupEnv, modelMerged, sg.getArchetype().isWUInitRNGRequired(), + (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL), + sg.getArchetype().getKernelSize().size()); }); os << std::endl; os << "// ------------------------------------------------------------------------" << std::endl; os << "// Custom update groups" << std::endl; genParallelGroup( - os, kernelSubs, modelMerged.getMergedCustomUpdateInitGroups(), idStart, + env, modelMerged.getMergedCustomUpdateInitGroups(), idStart, [this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelInitialize); }, - [&modelMerged, this](CodeStream &os, const CustomUpdateInitGroupMerged &cg, Substitutions &popSubs) + [&modelMerged, this](EnvironmentExternalBase &env, CustomUpdateInitGroupMerged &cg) { - os << "// only do this for existing variables" << std::endl; - os << "if(" << popSubs["id"] << " < group->size)"; + env.getStream() << "// only do this for existing variables" << std::endl; + env.print("if($(id) < $(size))"); { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); // If this custom update requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id @@ -1398,7 +1380,7 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne genGlobalRNGSkipAhead(os, popSubs, 
"id"); } - cg.generateInit(*this, os, modelMerged, popSubs); + cg.generateInit(*this, env, modelMerged); } }); os << std::endl; @@ -1406,12 +1388,13 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne os << "// ------------------------------------------------------------------------" << std::endl; os << "// Custom WU update groups" << std::endl; genParallelGroup( - os, kernelSubs, modelMerged.getMergedCustomWUUpdateInitGroups(), idStart, + env, modelMerged.getMergedCustomWUUpdateInitGroups(), idStart, [this](const CustomUpdateWUInternal &cg) { return padKernelSize(getNumInitThreads(cg), KernelInitialize); }, - [&modelMerged, this](CodeStream &os, const CustomWUUpdateInitGroupMerged &cg, Substitutions &popSubs) + [&modelMerged, this](EnvironmentExternalBase &env, CustomWUUpdateInitGroupMerged &cg) { const SynapseGroup *sg = cg.getArchetype().getSynapseGroup(); - genSynapseVarInit(os, modelMerged, cg, popSubs, cg.getArchetype().isInitRNGRequired(), + EnvironmentGroupMergedField groupEnv(env, cg); + genSynapseVarInit(groupEnv, modelMerged, cg.getArchetype().isInitRNGRequired(), (sg->getMatrixType() & SynapseMatrixWeight::KERNEL), sg->getKernelSize().size()); }); os << std::endl; @@ -1419,14 +1402,14 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne os << "// ------------------------------------------------------------------------" << std::endl; os << "// Custom connectivity presynaptic update groups" << std::endl; genParallelGroup( - os, kernelSubs, modelMerged.getMergedCustomConnectivityUpdatePreInitGroups(), idStart, + env, modelMerged.getMergedCustomConnectivityUpdatePreInitGroups(), idStart, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelInitialize); }, - [&modelMerged, this](CodeStream &os, const CustomConnectivityUpdatePreInitGroupMerged &cg, Substitutions &popSubs) + [&modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdatePreInitGroupMerged &cg) { - os << "// only do this for existing variables" << std::endl; - os << "if(" << popSubs["id"] << " < group->size)"; + env.getStream() << "// only do this for existing variables" << std::endl; + env.print("if($(id) < $(size))"); { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); // If population RNGs are initialised on device and this custom connectivity update // required one, initialise single RNG using GLOBAL thread id for sequence @@ -1441,7 +1424,7 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne genGlobalRNGSkipAhead(os, popSubs, "id"); } - cg.generateInit(*this, os, modelMerged, popSubs); + cg.generateInit(*this, env, modelMerged); } }); os << std::endl; @@ -1449,14 +1432,14 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne os << "// ------------------------------------------------------------------------" << std::endl; os << "// Custom connectivity postsynaptic update groups" << std::endl; genParallelGroup( - os, kernelSubs, modelMerged.getMergedCustomConnectivityUpdatePostInitGroups(), idStart, + env, modelMerged.getMergedCustomConnectivityUpdatePostInitGroups(), idStart, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons(), KernelInitialize); }, - [&modelMerged, this](CodeStream &os, const CustomConnectivityUpdatePostInitGroupMerged &cg, Substitutions &popSubs) + [&modelMerged, 
this](EnvironmentExternalBase &env, CustomConnectivityUpdatePostInitGroupMerged &cg) { - os << "// only do this for existing variables" << std::endl; - os << "if(" << popSubs["id"] << " < group->size)"; + env.getStream() << "// only do this for existing variables" << std::endl; + env.print("if($(id) < $(size))"); { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); // If this custom update requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id @@ -1465,7 +1448,7 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne genGlobalRNGSkipAhead(os, popSubs, "id"); } - cg.generateInit(*this, os, modelMerged, popSubs); + cg.generateInit(*this, env, modelMerged); } }); os << std::endl; @@ -1473,41 +1456,39 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne os << "// ------------------------------------------------------------------------" << std::endl; os << "// Synapse groups with sparse connectivity" << std::endl; genParallelGroup( - os, kernelSubs, modelMerged.getMergedSynapseConnectivityInitGroups(), idStart, + env, modelMerged.getMergedSynapseConnectivityInitGroups(), idStart, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumConnectivityInitThreads(sg), KernelInitialize); }, - [&modelMerged, this](CodeStream &os, const SynapseConnectivityInitGroupMerged &sg, Substitutions &popSubs) + [&modelMerged, this](EnvironmentExternalBase &env, SynapseConnectivityInitGroupMerged &sg) { + EnvironmentGroupMergedField groupEnv(env, sg); + // If there is row-building code in this snippet const auto *snippet = sg.getArchetype().getConnectivityInitialiser().getSnippet(); if(!snippet->getRowBuildCode().empty()) { - os << "// only do this for existing presynaptic neurons" << std::endl; - os << "if(" << popSubs["id"] << " < group->numSrcNeurons)"; + groupEnv.getStream() << "// only do this for existing presynaptic neurons" << std::endl; + groupEnv.print("if($(id) < $(num_pre))"); // Configure substitutions - popSubs.addVarSubstitution("id_pre", popSubs["id"]); - popSubs.addVarSubstitution("id_post_begin", "0"); - popSubs.addVarSubstitution("id_thread", "0"); - popSubs.addVarSubstitution("num_threads", "1"); - popSubs.addVarSubstitution("num_pre", "group->numSrcNeurons"); - popSubs.addVarSubstitution("num_post", "group->numTrgNeurons"); + groupEnv.add(Type::Uint32.addConst(), "id_pre", "$(id)"); + groupEnv.add(Type::Uint32.addConst(), "id_post_begin", "0"); + groupEnv.add(Type::Uint32.addConst(), "id_thread", "0"); + groupEnv.add(Type::Uint32.addConst(), "num_threads", "1"); } // Otherwise else { assert(!snippet->getColBuildCode().empty()); - os << "// only do this for existing postsynaptic neurons" << std::endl; - os << "if(" << popSubs["id"] << " < group->numTrgNeurons)"; + groupEnv.getStream() << "// only do this for existing postsynaptic neurons" << std::endl; + groupEnv.print("if($(id) < $(num_post))"); // Configure substitutions - popSubs.addVarSubstitution("id_post", popSubs["id"]); - popSubs.addVarSubstitution("id_pre_begin", "0"); - popSubs.addVarSubstitution("id_thread", "0"); - popSubs.addVarSubstitution("num_threads", "1"); - popSubs.addVarSubstitution("num_pre", "group->numSrcNeurons"); - popSubs.addVarSubstitution("num_post", "group->numTrgNeurons"); + groupEnv.add(Type::Uint32.addConst(), "id_post", "$(id)"); + groupEnv.add(Type::Uint32.addConst(), "id_pre_begin", "0"); + groupEnv.add(Type::Uint32.addConst(), "id_thread", "0"); + 
groupEnv.add(Type::Uint32.addConst(), "num_threads", "1"); } { - CodeStream::Scope b(os); + CodeStream::Scope b(groupEnv.getStream()); // Create new stream to generate addSynapse function which initializes all kernel variables std::ostringstream kernelInitStream; @@ -1521,10 +1502,10 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne // Calculate index in data structure of this synapse if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { if(!snippet->getRowBuildCode().empty()) { - kernelInit << "const unsigned int idx = " << "(" << popSubs["id_pre"] << " * group->rowStride) + group->rowLength[" << popSubs["id"] << "];" << std::endl; + kernelInit << "const unsigned int idx = " << "($(id_pre) * $(_row_stride)) + $(_row_length)[$(id)];" << std::endl; } else { - kernelInit << "const unsigned int idx = " << "(($(0)) * group->rowStride) + group->rowLength[$(0)];" << std::endl; + kernelInit << "const unsigned int idx = " << "(($(0)) * $(_row_stride))) + $(_row_length)[$(0)];" << std::endl; } } @@ -1623,7 +1604,7 @@ void BackendSIMT::genInitializeKernel(CodeStream &os, const Substitutions &kerne os << std::endl; } //-------------------------------------------------------------------------- -void BackendSIMT::genInitializeSparseKernel(CodeStream &os, const Substitutions &kernelSubs, const ModelSpecMerged &modelMerged, +void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t numInitializeThreads, size_t &idStart) const { // Shared memory array so row lengths don't have to be read by EVERY postsynaptic thread @@ -1631,9 +1612,9 @@ void BackendSIMT::genInitializeSparseKernel(CodeStream &os, const Substitutions os << getSharedPrefix() << "unsigned int shRowLength[" << getKernelBlockSize(KernelInitializeSparse) << "];" << std::endl; // Initialise weight update variables for synapse groups with sparse connectivity - genParallelGroup(os, kernelSubs, modelMerged.getMergedSynapseSparseInitGroups(), idStart, + genParallelGroup(env, modelMerged.getMergedSynapseSparseInitGroups(), idStart, [this](const SynapseGroupInternal &sg) { return padKernelSize(sg.getMaxConnections(), KernelInitializeSparse); }, - [numInitializeThreads, &modelMerged, this](CodeStream &os, const SynapseSparseInitGroupMerged &sg, Substitutions &popSubs) + [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, SynapseSparseInitGroupMerged &sg) { // If this post synapse requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id @@ -1644,33 +1625,33 @@ void BackendSIMT::genInitializeSparseKernel(CodeStream &os, const Substitutions // Generate sparse synapse variable initialisation code genSparseSynapseVarInit( - os, modelMerged, sg, popSubs, sg.getArchetype().isWUVarInitRequired(), - [this](CodeStream &os, const SynapseSparseInitGroupMerged &sg, Substitutions&) + env, modelMerged, sg, sg.getArchetype().isWUVarInitRequired(), + [this](EnvironmentExternalBase &env, SynapseSparseInitGroupMerged &sg) { // If postsynaptic learning is required if(!sg.getArchetype().getWUModel()->getLearnPostCode().empty()) { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); // Extract index of synapse's postsynaptic target - os << "const unsigned int postIndex = group->ind[idx];" << std::endl; + env.getStream() << "const unsigned int postIndex = " << env["_ind"] << "[idx];" << std::endl; // Atomically increment length of column of connectivity associated with this target // 
**NOTE** this returns previous length i.e. where to insert new entry - os << "const unsigned int colLocation = " << getAtomic(Type::Uint32) << "(&group->colLength[postIndex], 1);" << std::endl; + env.getStream() << "const unsigned int colLocation = " << getAtomic(Type::Uint32) << "(&" << env["_col_length"] << "[postIndex], 1);" << std::endl; // From this calculate index into column-major matrix - os << "const unsigned int colMajorIndex = (postIndex * group->colStride) + colLocation;" << std::endl; + env.getStream() << "const unsigned int colMajorIndex = (postIndex * " << env["_col_stride"] << ") + colLocation;" << std::endl; // Add remapping entry at this location pointing back to row-major index - os << "group->remap[colMajorIndex] = idx;" << std::endl; + env.getStream() << "group->remap[colMajorIndex] = idx;" << std::endl; } }); }); // Initialise weight update variables for synapse groups with sparse connectivity - genParallelGroup(os, kernelSubs, modelMerged.getMergedCustomWUUpdateSparseInitGroups(), idStart, + genParallelGroup(env, modelMerged.getMergedCustomWUUpdateSparseInitGroups(), idStart, [this](const CustomUpdateWUInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }, - [numInitializeThreads, &modelMerged, this](CodeStream &os, const CustomWUUpdateSparseInitGroupMerged &cg, Substitutions &popSubs) + [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, CustomWUUpdateSparseInitGroupMerged &cg) { // If this custom update requires an RNG for initialisation, // make copy of global Philox RNG and skip ahead by thread id @@ -1681,14 +1662,14 @@ void BackendSIMT::genInitializeSparseKernel(CodeStream &os, const Substitutions // Generate sparse synapse variable initialisation code genSparseSynapseVarInit( - os, modelMerged, cg, popSubs, true, - [](CodeStream&, const CustomWUUpdateSparseInitGroupMerged&, Substitutions&){}); + env, modelMerged, cg, true, + [](EnvironmentExternalBase &env, CustomWUUpdateSparseInitGroupMerged&){}); }); // Initialise weight update variables for synapse groups with sparse connectivity - genParallelGroup(os, kernelSubs, modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups(), idStart, + genParallelGroup(env, modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups(), idStart, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }, - [numInitializeThreads, &modelMerged, this](CodeStream &os, const CustomConnectivityUpdateSparseInitGroupMerged &cg, Substitutions &popSubs) + [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdateSparseInitGroupMerged &cg) { // If this custom update requires an RNG for initialisation, // make copy of global Philox RNG and skip ahead by thread id @@ -1699,8 +1680,8 @@ void BackendSIMT::genInitializeSparseKernel(CodeStream &os, const Substitutions // Generate sparse synapse variable initialisation code genSparseSynapseVarInit( - os, modelMerged, cg, popSubs, true, - [](CodeStream&, const CustomConnectivityUpdateSparseInitGroupMerged&, Substitutions&){}); + env, modelMerged, cg, true, + [](EnvironmentExternalBase&, CustomConnectivityUpdateSparseInitGroupMerged&){}); }); } //-------------------------------------------------------------------------- @@ -1709,18 +1690,18 @@ size_t BackendSIMT::padKernelSize(size_t size, Kernel kernel) const return padSize(size, getKernelBlockSize(kernel)); } 
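For context, the column-length/remap initialisation generated above amounts to the following standalone kernel. This is a minimal sketch, not code from the patch: the kernel name and signature are hypothetical, it assumes colLength has been zero-initialised beforehand, and the parameter names simply mirror the merged-group fields referenced in the hunk.

__global__ void buildColumnRemap(const unsigned int *rowLength, const unsigned int *ind,
                                 unsigned int *colLength, unsigned int *remap,
                                 unsigned int numPre, unsigned int rowStride, unsigned int colStride)
{
    // One thread per presynaptic neuron walks that neuron's row of synapses
    const unsigned int i = (blockIdx.x * blockDim.x) + threadIdx.x;
    if(i < numPre) {
        for(unsigned int j = 0; j < rowLength[i]; j++) {
            const unsigned int idx = (i * rowStride) + j;   // row-major synapse index
            const unsigned int post = ind[idx];             // postsynaptic target of this synapse
            // atomicAdd returns the PREVIOUS column length i.e. where to insert the new entry
            const unsigned int colLoc = atomicAdd(&colLength[post], 1u);
            remap[(post * colStride) + colLoc] = idx;       // point back to the row-major index
        }
    }
}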
//-------------------------------------------------------------------------- -void BackendSIMT::genEmitSpike(const ModelSpecMerged &modelMerged, CodeStream &os, const Substitutions &subs, const std::string &suffix, bool recordingEnabled) const +void BackendSIMT::genEmitSpike(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, const std::string &suffix, bool recordingEnabled) const { - os << "const unsigned int spk" << suffix << "Idx = " << getAtomic(Type::Uint32, AtomicOperation::ADD, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Count, 1);" << std::endl; - os << "shSpk" << suffix << "[spk" << suffix << "Idx] = " << subs["id"] << ";" << std::endl; + env.getStream() << "const unsigned int spk" << suffix << "Idx = " << getAtomic(Type::Uint32, AtomicOperation::ADD, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Count, 1);" << std::endl; + env.getStream() << "shSpk" << suffix << "[spk" << suffix << "Idx] = " << env["id"] << ";" << std::endl; // If recording is enabled, set bit in recording word if(recordingEnabled) { if(m_KernelBlockSizes[KernelNeuronUpdate] == 32) { - os << getAtomic(Type::Uint32, AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record, 1 << " << getThreadID() << ");" << std::endl; + env.getStream() << getAtomic(Type::Uint32, AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record, 1 << " << getThreadID() << ");" << std::endl; } else { - os << getAtomic(Type::Uint32, AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record[" << getThreadID() << " / 32], 1 << (" << getThreadID() << " % 32));" << std::endl; + env.getStream() << getAtomic(Type::Uint32, AtomicOperation::OR, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Record[" << getThreadID() << " / 32], 1 << (" << getThreadID() << " % 32));" << std::endl; } } } diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index f9aad04da2..f3f55da3f9 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -548,6 +548,7 @@ void prettyPrintStatements(const std::string &code, const Type::TypeContext &typ // Pretty print Transpiler::PrettyPrinter::print(std::get<0>(statementTypes), env, typeContext, std::get<1>(statementTypes), forEachSynapsePrettyPrintHandler); } +//-------------------------------------------------------------------------- std::string printSubs(const std::string &format, EnvironmentExternalBase &env) { // Create regex iterator to iterate over $(XXX) style variables in format string diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 6a9fb77582..57c44d6a9a 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -385,7 +385,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back [&backend, &modelMerged, &removeSynapseStream, this](auto &env, auto generateBody) { EnvironmentGroupMergedField bodyEnv(env, *this); - bodyEnv.getStream() << printSubs("for(int j = 0; j < $(_row_length)[$(id_pre)]; j++)", bodyEnv); + bodyEnv.print("for(int j = 0; j < $(_row_length)[$(id_pre)]; j++)"); { CodeStream::Scope b(bodyEnv.getStream()); diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index f5231bd211..1e68aa7dad 100644 ---
a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -30,6 +30,16 @@ void EnvironmentExternalBase::define(const Token&, const Type::ResolvedType&, Er throw std::runtime_error("Cannot declare variable in external environment"); } //---------------------------------------------------------------------------- +void EnvironmentExternalBase::print(const std::string &format) +{ + getStream() << printSubs(format, *this); +} +//---------------------------------------------------------------------------- +void EnvironmentExternalBase::printLine(const std::string &format) +{ + getStream() << printSubs(format, *this) << std::endl; +} +//---------------------------------------------------------------------------- CodeStream &EnvironmentExternalBase::getContextStream() const { return std::visit( diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 95551bda2c..864792ad06 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -28,14 +28,14 @@ void genVariableFill(EnvironmentExternalBase &env, const std::string &target, co // If there's only one, don't generate a loop if(numValues == 1) { - env.getStream() << env[target] << "[" << env[idx] << "] = " << value << ";" << std::endl; + env.printLine("$(" + target + ")[$(" + idx + ")] = " + value + ";"); } // Otherwise else { env.getStream() << "for(unsigned int d = 0; d < " << numValues << "; d++)"; { CodeStream::Scope b(env.getStream()); - env.getStream() << env[target] << "[(d * " << printSubs(stride, env) << ") + " << env[idx] << "] = " << value << ";" << std::endl; + env.printLine("$(" + target + ")[(d * " + stride + ") + $(" + idx + ")] = " + value + ";"); } } } diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 47b6112946..177899e4b8 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -87,10 +87,9 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); // Read into local variable + const std::string idx = ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"); psmEnv.getStream() << "// postsynaptic model " << getIndex() << std::endl; - psmEnv.getStream() << "scalar linSyn = " << psmEnv["_out_post"] << "["; - psmEnv.getStream() << printSubs(ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"), psmEnv); - psmEnv.getStream() << "];" << std::endl; + psmEnv.printLine("scalar linSyn = $(_out_post)[" + idx + "];"); // If dendritic delay is required if (getArchetype().isDendriticDelayRequired()) { @@ -101,10 +100,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix();}); // Get reference to dendritic delay buffer input for this timestep - psmEnv.getStream() << backend.getPointerPrefix() << "scalar *denDelayFront = "; - psmEnv.getStream() << "&" << psmEnv["_den_delay"] << "[(*" << psmEnv["_den_delay_ptr"] << " * " << psmEnv["num_neurons"] << ") + "; - psmEnv.getStream() << printSubs(ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, 
"id"), psmEnv); - psmEnv.getStream() << "];" << std::endl; + psmEnv.printLine(backend.getPointerPrefix() + "scalar *denDelayFront = &$(_den_delay)[(*$(_den_delay_ptr) * $(num_neurons)) + " + idx + "];"); // Add delayed input from buffer into inSyn psmEnv.getStream() << "linSyn += *denDelayFront;" << std::endl; @@ -140,9 +136,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env prettyPrintStatements(psm->getDecayCode(), getTypeContext(), varEnv, decayErrorHandler); // Write back linSyn - varEnv.getStream() << psmEnv["_out_post"] << "["; - varEnv.getStream() << printSubs(ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"), psmEnv); - varEnv.getStream() << "] = linSyn;" << std::endl; + varEnv.printLine("$(_out_post)[" + ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id") + "] = linSyn;"); } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &hash) const @@ -178,15 +172,11 @@ void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backe [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); // Add reverse insyn variable to - outSynEnv.getStream() << getArchetype().getPreTargetVar() << " += "; - outSynEnv.getStream() << outSynEnv["_out_pre"] << "["; - outSynEnv.getStream() << printSubs(ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"), env); - outSynEnv.getStream() << "];" << std::endl; + const std::string idx = ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"); + outSynEnv.printLine(getArchetype().getPreTargetVar() + " += $(_out_pre)[" + idx + "];"); // Zero it again - outSynEnv.getStream() << outSynEnv["_out_pre"] << "["; - outSynEnv.getStream() << printSubs(ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"), env); - outSynEnv.getStream() << "] = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; + outSynEnv.printLine("$(_out_pre)[" + idx + "] = " + modelMerged.scalarExpr(0.0) + ";"); } //---------------------------------------------------------------------------- @@ -252,13 +242,9 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentEx // Loop through variables and copy between read and write delay slots for(const auto &v : getArchetype().getWUModel()->getPostVars()) { if(v.access & VarAccessMode::READ_WRITE) { - env.getStream() << env[v.name] << "["; - env.getStream() << printSubs(ng.getWriteVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), "id"), env); - env.getStream() << "] = "; - - env.getStream() << env[v.name] << "["; - env.getStream() << printSubs(ng.getReadVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), "id"), env); - env.getStream() << "];" << std::endl; + const unsigned int batchSize = modelMerged.getModel().getBatchSize(); + env.print("$(" + v.name + ")[" + ng.getWriteVarIndex(true, batchSize, getVarAccessDuplication(v.access), "id") + "] = "); + env.printLine("$(" + v.name + ")[" + ng.getReadVarIndex(true, batchSize, getVarAccessDuplication(v.access), "id") + "];"); } } } @@ -344,13 +330,9 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(EnvironmentEx // Loop through variables and copy between read and write delay slots 
for(const auto &v : getArchetype().getWUModel()->getPreVars()) { if(v.access & VarAccessMode::READ_WRITE) { - env.getStream() << env[v.name] << "["; - env.getStream() << printSubs(ng.getWriteVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), "id"), env); - env.getStream() << "] = "; - - env.getStream() << env[v.name] << "["; - env.getStream() << printSubs(ng.getReadVarIndex(true, modelMerged.getModel().getBatchSize(), getVarAccessDuplication(v.access), "id"), env); - env.getStream() << "];" << std::endl; + const unsigned int batchSize = modelMerged.getModel().getBatchSize(); + env.print("$(" + v.name + ")[" + ng.getWriteVarIndex(true, batchSize, getVarAccessDuplication(v.access), "id") + "] = "); + env.printLine("$(" + v.name + ")[" + ng.getReadVarIndex(true, batchSize, getVarAccessDuplication(v.access), "id") + "];"); } } } diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index 447d5fe168..c841175481 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -75,52 +75,52 @@ size_t PreSpan::getSharedMemoryPerThread(const PresynapticUpdateGroupMerged&, co return 0; } //---------------------------------------------------------------------------- -void PreSpan::genPreamble(CodeStream &, const ModelSpecMerged&, const PresynapticUpdateGroupMerged&, - const Substitutions&, const BackendSIMT&) const +void PreSpan::genPreamble(EnvironmentExternalBase&, const ModelSpecMerged&, + PresynapticUpdateGroupMerged&, const BackendSIMT&) const { } //---------------------------------------------------------------------------- -void PreSpan::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, - const Substitutions &popSubs, const BackendSIMT &backend, bool trueSpike) const +void PreSpan::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, + PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const { // Get suffix based on type of events const ModelSpecInternal &model = modelMerged.getModel(); const unsigned int batchSize = model.getBatchSize(); - const std::string eventSuffix = trueSpike ? "" : "Evnt"; + const std::string eventSuffix = trueSpike ? 
"" : "_evnt"; const auto *wu = sg.getArchetype().getWUModel(); const size_t numThreadsPerSpike = sg.getArchetype().getNumThreadsPerSpike(); if(numThreadsPerSpike > 1) { - os << "const unsigned int spike = " << popSubs["id"] << " / " << numThreadsPerSpike << ";" << std::endl; - os << "const unsigned int thread = " << popSubs["id"] << " % " << numThreadsPerSpike << ";" << std::endl; + env.getStream() << "const unsigned int spike = " << env["id"] << " / " << numThreadsPerSpike << ";" << std::endl; + env.getStream() << "const unsigned int thread = " << env["id"] << " % " << numThreadsPerSpike << ";" << std::endl; } else { - os << "const unsigned int spike = " << popSubs["id"] << ";" << std::endl; + env.getStream() << "const unsigned int spike = " << env["id"] << ";" << std::endl; } if(sg.getArchetype().isPresynapticOutputRequired()) { - os << "scalar lrevInSyn= 0.0;" << std::endl; + env.getStream() << "scalar lrevInSyn= 0.0;" << std::endl; } - os << "if (spike < group->srcSpkCnt" << eventSuffix << "[" << sg.getPreSlot(batchSize) << "])"; + env.print("if (spike < $(_src_spk_cnt" + eventSuffix + ")[" + sg.getPreSlot(batchSize) + "])"); { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); - if(backend.supportsNamespace() && !wu->getSimSupportCode().empty()) { + /*if(backend.supportsNamespace() && !wu->getSimSupportCode().empty()) { os << "using namespace " << modelMerged.getPresynapticUpdateSupportCodeNamespace(wu->getSimSupportCode()) << ";" << std::endl; - } + }*/ - os << "const unsigned int preInd = group->srcSpk" << eventSuffix << "[" << sg.getPreVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "spike") << "];" << std::endl; + env.printLine("const unsigned int preInd = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "spike") + "];"); if(numThreadsPerSpike > 1) { - os << "unsigned int synAddress = (preInd * group->rowStride) + thread;" << std::endl; + env.printLine("unsigned int synAddress = (preInd * $(_row_stride)) + thread;"); } else { - os << "unsigned int synAddress = preInd * group->rowStride;" << std::endl; + env.printLine("unsigned int synAddress = preInd * $(_row_stride);"); } - os << "const unsigned int npost = group->rowLength[preInd];" << std::endl; + env.printLine("const unsigned int npost = $(_row_length)[preInd];"); - if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { + /*if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { os << "if("; Substitutions threshSubs(&popSubs); @@ -133,65 +133,56 @@ void PreSpan::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, cons os << ")"; os << CodeStream::OB(130); - } + }*/ if(numThreadsPerSpike > 1) { - os << "for(unsigned int i = thread; i < npost; i += " << numThreadsPerSpike << ", synAddress += " << numThreadsPerSpike << ")"; + env.getStream() << "for(unsigned int i = thread; i < npost; i += " << numThreadsPerSpike << ", synAddress += " << numThreadsPerSpike << ")"; } else { - os << "for(unsigned int i = 0; i < npost; i++, synAddress++)"; + env.getStream() << "for(unsigned int i = 0; i < npost; i++, synAddress++)"; } { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); // **TODO** pretty sure __ldg will boost performance here - basically will bring whole row into cache - os << "const unsigned int ipost = group->ind[synAddress];" << std::endl; + env.printLine("const unsigned int ipost = $(_ind)[synAddress];"); // Create substitution stack for presynaptic simulation code - Substitutions synSubs(&popSubs); - 
synSubs.addVarSubstitution("id_pre", "preInd"); - synSubs.addVarSubstitution("id_post", "ipost"); - synSubs.addVarSubstitution("id_syn", "synAddress"); - - // If dendritic delay is required, use atomic operation to update dendritic delay buffer - if(sg.getArchetype().isDendriticDelayRequired()) { - synSubs.addFuncSubstitution("addToInSynDelay", 2, - backend.getAtomic(model.getPrecision()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, "ipost", "$(1)") + "], $(0))"); - } - // Otherwise, substitute global memory array for $(inSyn) - else { - synSubs.addFuncSubstitution("addToInSyn", 1, - backend.getAtomic(model.getPrecision()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, "ipost") + "], $(0))"); - } - - if(sg.getArchetype().isPresynapticOutputRequired()) { - synSubs.addFuncSubstitution("addToPre", 1, "lrevInSyn += $(0)"); - } + EnvironmentGroupMergedField synEnv(env, sg); + synEnv.add(Type::Uint32.addConst(), "id_pre", "preInd"); + synEnv.add(Type::Uint32.addConst(), "id_post", "ipost"); + synEnv.add(Type::Uint32.addConst(), "id_syn", "synAddress"); + + synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", + backend.getAtomic(model.getPrecision()) + "(&$(_den_delay)[" + sg.getPostDenDelayIndex(batchSize, "ipost", "$(1)") + "], $(0))"); + synEnv.add(Type::AddToPost, "addToPost", + backend.getAtomic(model.getPrecision()) + "(&$(_out_post)[" + sg.getPostISynIndex(batchSize, "ipost") + "], $(0))"); + synEnv.add(Type::AddToPre, "addToPre", "lrevInSyn += $(0)"); if(trueSpike) { - sg.generateSpikeUpdate(backend, os, modelMerged, synSubs); + sg.generateSpikeUpdate(backend, synEnv, modelMerged); } else { - sg.generateSpikeEventUpdate(backend, os, modelMerged, synSubs); + sg.generateSpikeEventUpdate(backend, synEnv, modelMerged); } } - if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { + /*if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { os << CodeStream::CB(130); - } + }*/ // Should this be in the Postamble? 
if(sg.getArchetype().isPresynapticOutputRequired()) { // write lrevInSyn to global memory if not 0 - os << "if(lrevInSyn != 0.0) " << backend.getAtomic(model.getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, "preInd") + "], lrevInSyn);" << std::endl; + env.getStream() << "if(lrevInSyn != 0.0) " << backend.getAtomic(model.getPrecision()) + "(&" + env["_out_pre"] + "[" + sg.getPreISynIndex(batchSize, "preInd") + "], lrevInSyn);" << std::endl; } } } //---------------------------------------------------------------------------- -void PreSpan::genPostamble(CodeStream&, const ModelSpecMerged&, const PresynapticUpdateGroupMerged&, - const Substitutions&, const BackendSIMT&) const +void PreSpan::genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged&, + PresynapticUpdateGroupMerged&, const BackendSIMT&) const { } @@ -227,20 +218,20 @@ bool PostSpan::isCompatible(const SynapseGroupInternal &sg, const PreferencesBas && !(sg.getMatrixType() & SynapseMatrixConnectivity::TOEPLITZ)); } //---------------------------------------------------------------------------- -void PostSpan::genPreamble(CodeStream &os, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, - const Substitutions &, const BackendSIMT &backend) const +void PostSpan::genPreamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, + PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const { // If data structure is dense, we can accumulate output directly into register if(shouldAccumulateInRegister(sg)) { - os << "scalar linSyn = 0;" << std::endl; + env.getStream() << "scalar linSyn = 0;" << std::endl; } else if(isSmallSharedMemoryPop(sg, backend)) { - os << "if(" << backend.getThreadID() << " < group->numTrgNeurons)"; + env.getStream() << "if(" << backend.getThreadID() << " < group->numTrgNeurons)"; { - CodeGenerator::CodeStream::Scope b(os); - os << "shLg[" << backend.getThreadID() << "] = 0;" << std::endl; + CodeGenerator::CodeStream::Scope b(env.getStream()); + env.getStream() << "shLg[" << backend.getThreadID() << "] = 0;" << std::endl; } - backend.genSharedMemBarrier(os); + backend.genSharedMemBarrier(env.getStream()); } } //---------------------------------------------------------------------------- @@ -250,62 +241,62 @@ size_t PostSpan::getSharedMemoryPerThread(const PresynapticUpdateGroupMerged &sg return isSmallSharedMemoryPop(sg, backend) ? 1 : 0; } //---------------------------------------------------------------------------- -void PostSpan::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, - const Substitutions &popSubs, const BackendSIMT &backend, bool trueSpike) const +void PostSpan::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, + PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const { // Get suffix based on type of events const ModelSpecInternal &model = modelMerged.getModel(); const unsigned int batchSize = model.getBatchSize(); - const std::string eventSuffix = trueSpike ? "" : "Evnt"; + const std::string eventSuffix = trueSpike ? 
"" : "_evnt"; - os << "const unsigned int numSpikes = group->srcSpkCnt" << eventSuffix << "[" << sg.getPreSlot(batchSize) << "];" << std::endl; - os << "const unsigned int numSpikeBlocks = (numSpikes + " << backend.getKernelBlockSize(KernelPresynapticUpdate) << " - 1) / " << backend.getKernelBlockSize(KernelPresynapticUpdate) << ";" << std::endl; + env.printLine("const unsigned int numSpikes = $(_src_spk_cnt" + eventSuffix + ")[" + sg.getPreSlot(batchSize) + "];"); + env.getStream() << "const unsigned int numSpikeBlocks = (numSpikes + " << backend.getKernelBlockSize(KernelPresynapticUpdate) << " - 1) / " << backend.getKernelBlockSize(KernelPresynapticUpdate) << ";" << std::endl; const auto *wu = sg.getArchetype().getWUModel(); - os << "for (unsigned int r = 0; r < numSpikeBlocks; r++)"; + env.getStream() << "for (unsigned int r = 0; r < numSpikeBlocks; r++)"; { - CodeStream::Scope b(os); - os << "const unsigned int numSpikesInBlock = (r == numSpikeBlocks - 1) ? ((numSpikes - 1) % " << backend.getKernelBlockSize(KernelPresynapticUpdate) << ") + 1 : " << backend.getKernelBlockSize(KernelPresynapticUpdate) << ";" << std::endl; + CodeStream::Scope b(env.getStream()); + env.getStream() << "const unsigned int numSpikesInBlock = (r == numSpikeBlocks - 1) ? ((numSpikes - 1) % " << backend.getKernelBlockSize(KernelPresynapticUpdate) << ") + 1 : " << backend.getKernelBlockSize(KernelPresynapticUpdate) << ";" << std::endl; - backend.genSharedMemBarrier(os); - os << "if (" << backend.getThreadID() << " < numSpikesInBlock)"; + backend.genSharedMemBarrier(env.getStream()); + env.getStream() << "if (" << backend.getThreadID() << " < numSpikesInBlock)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); const std::string index = "(r * " + std::to_string(backend.getKernelBlockSize(KernelPresynapticUpdate)) + ") + " + backend.getThreadID(); - os << "const unsigned int spk = group->srcSpk" << eventSuffix << "[" << sg.getPreVarIndex(batchSize, VarAccessDuplication::DUPLICATE, index) << "];" << std::endl; - os << "shSpk" << eventSuffix << "[" << backend.getThreadID() << "] = spk;" << std::endl; + env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDuplication::DUPLICATE, index) + "];"); + env.printLine("$(_sh_spk" + eventSuffix + ")[" + backend.getThreadID() + "] = spk;"); if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << "shRowLength[" << backend.getThreadID() << "] = group->rowLength[spk];" << std::endl; + env.printLine("$(_sh_row_length)[" + backend.getThreadID() + "] = $(_row_length)[spk];"); } } - backend.genSharedMemBarrier(os); + backend.genSharedMemBarrier(env.getStream()); - os << "// loop through all incoming spikes" << std::endl; - os << "for (unsigned int j = 0; j < numSpikesInBlock; j++)"; + env.getStream() << "// loop through all incoming spikes" << std::endl; + env.getStream() << "for (unsigned int j = 0; j < numSpikesInBlock; j++)"; { - CodeStream::Scope b(os); - os << "// only work on existing neurons" << std::endl; - os << "if (" << popSubs["id"] << " < group->rowStride)"; + CodeStream::Scope b(env.getStream()); + env.getStream() << "// only work on existing neurons" << std::endl; + env.print("if ($(id) < $(_row_stride))"); { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { // If this can only be represented using a 64-bit number if(backend.areSixtyFourBitSynapseIndicesRequired(sg)) { - 
os << "const uint64_t gid = (shSpk" << eventSuffix << "[j] * (uint64_t)group->rowStride) + " << popSubs["id"] << ";" << std::endl; + env.printLine("const uint64_t gid = ($(_sh_spk" + eventSuffix + ")[j] * (uint64_t)$(_row_stride)) + $(id);"); } else { - os << "const unsigned int gid = (shSpk" << eventSuffix << "[j] * group->rowStride) + " << popSubs["id"] << ";" << std::endl; + env.printLine("const unsigned int gid = ($(_sh_spk" + eventSuffix + ")[j] * $(_row_stride)) + $(id);"); } } - if(backend.supportsNamespace() && !wu->getSimSupportCode().empty()) { + /*if(backend.supportsNamespace() && !wu->getSimSupportCode().empty()) { os << "using namespace " << modelMerged.getPresynapticUpdateSupportCodeNamespace(wu->getSimSupportCode()) << ";" << std::endl; - } - if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { - os << "if("; + }*/ + /*if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { + env.getStream() << "if("; if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { // Note: we will just access global mem. For compute >= 1.2 simultaneous access to same global mem in the (half-)warp will be coalesced - no worries - os << "(B(group->gp[gid / 32], gid & 31)) && "; + env.getStream() << "(B(group->gp[gid / 32], gid & 31)) && "; } Substitutions threshSubs(&popSubs); @@ -318,106 +309,110 @@ void PostSpan::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, con os << ")"; os << CodeStream::OB(130); } - else if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { - os << "if (B(group->gp[gid / 32], gid & 31))" << CodeStream::OB(135); + else */if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { + env.getStream() << "if (B(" << env["_gp"] << "[gid / 32], gid & 31))" << CodeStream::OB(135); } - os << "const unsigned int synAddress = (shSpk" << eventSuffix << "[j] * group->rowStride) + " + popSubs["id"] + ";" << std::endl; + EnvironmentGroupMergedField synEnv(env, sg); - Substitutions synSubs(&popSubs); - synSubs.addVarSubstitution("id_pre", "shSpk" + eventSuffix + "[j]"); - synSubs.addVarSubstitution("id_syn", "synAddress"); + synEnv.add(Type::Uint32.addConst(), "id_pre", "$(_sh_spk" + eventSuffix + ")[j]"); + synEnv.add(Type::Uint32.addConst(), "id_syn", "synAddress", + {synEnv.addInitialiser( "const unsigned int synAddress = ($(_sh_spk" + eventSuffix + ")[j] * $(_row_stride)) + $(id);")}); if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << "const unsigned int npost = shRowLength[j];" << std::endl; + synEnv.getStream() << "const unsigned int npost = " << synEnv["_sh_row_length"] << "[j];" << std::endl; - os << "if (" << popSubs["id"] << " < npost)" << CodeStream::OB(140); - os << "const unsigned int ipost = group->ind[synAddress];" << std::endl; + synEnv.getStream() << "if (" << synEnv["id"] << " < npost)" << CodeStream::OB(140); + synEnv.getStream() << "const unsigned int ipost = " << synEnv["_ind"] << "[synAddress];" << std::endl; - synSubs.addVarSubstitution("id_post", "ipost"); + synEnv.add(Type::Uint32.addConst(), "id_post", "ipost"); } else { // DENSE - synSubs.addVarSubstitution("id_post", popSubs["id"]); + synEnv.add(Type::Uint32.addConst(), "id_post", "$(id)"); } + /*synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", + backend.getAtomic(model.getPrecision()) + "(&$(_den_delay)[" + sg.getPostDenDelayIndex(batchSize, "ipost", "$(1)") + "], $(0))"); + + synEnv.add(Type::AddToPost, "addToPost", + backend.getAtomic(model.getPrecision()) + 
"(&$(_out_post)[" + sg.getPostISynIndex(batchSize, "ipost") + "], $(0))"); + + synEnv.add(Type::AddToPre, "addToPre", "lrevInSyn += $(0)"); + */ // If dendritic delay is required, always use atomic operation to update dendritic delay buffer - if(sg.getArchetype().isDendriticDelayRequired()) { - synSubs.addFuncSubstitution("addToInSynDelay", 2, - backend.getAtomic(model.getPrecision()) + "(&group->denDelay[" + sg.getPostDenDelayIndex(batchSize, synSubs["id_post"], "$(1)") + "], $(0))"); + synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", + backend.getAtomic(model.getPrecision()) + "(&$(_den_delay)[" + sg.getPostDenDelayIndex(batchSize, "$(id_post)", "$(1)") + "], $(0))"); + + // If we should accumulate in register, add parameter to register + if(shouldAccumulateInRegister(sg)) { + synEnv.add(Type::AddToPost, "addToPost", "linSyn += $(0)"); } - // Otherwise + // Otherwise, if we should use shared memory, add to shared memory + // **THINK** this is only correct if there are no multapses i.e. there is only one synapse between any pair of pre and postsynaptic neurons + else if(isSmallSharedMemoryPop(sg, backend)) { + synEnv.add(Type::AddToPost, "addToPost", "shLg[$(id_post)] += $(0)"); + } + // Otherwise, use global memory atomic else { - // If we should accumulate in register, add parameter to register - if(shouldAccumulateInRegister(sg)) { - synSubs.addFuncSubstitution("addToInSyn", 1, "linSyn += $(0)"); - } - // Otherwise, if we should use shared memory, add to shared memory - // **THINK** this is only correct if there are no multapses i.e. there is only one synapse between any pair of pre and postsynaptic neurons - else if(isSmallSharedMemoryPop(sg, backend)) { - synSubs.addFuncSubstitution("addToInSyn", 1, "shLg[" + synSubs["id_post"] + "] += $(0)"); - } - // Otherwise, use global memory atomic - else { - synSubs.addFuncSubstitution("addToInSyn", 1, - backend.getAtomic(model.getPrecision()) + "(&group->inSyn[" + sg.getPostISynIndex(batchSize, synSubs["id_post"]) + "], $(0))"); - } + synEnv.add(Type::AddToPost, "addToPost", + backend.getAtomic(model.getPrecision()) + "(&$(_out_post)[" + sg.getPostISynIndex(batchSize, "$(id_post)") + "], $(0))"); } if(sg.getArchetype().isPresynapticOutputRequired()) { - synSubs.addFuncSubstitution("addToPre", 1, - backend.getAtomic(model.getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); + synEnv.add(Type::AddToPre, "addToPre", + backend.getAtomic(model.getPrecision()) + "(&$(_out_pre)([" + sg.getPreISynIndex(batchSize, "$(id_pre)") + "], $(0))"); } if(trueSpike) { - sg.generateSpikeUpdate(backend, os, modelMerged, synSubs); + sg.generateSpikeUpdate(backend, synEnv, modelMerged); } else { - sg.generateSpikeEventUpdate(backend, os, modelMerged, synSubs); + sg.generateSpikeEventUpdate(backend, synEnv, modelMerged); } if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << CodeStream::CB(140); // end if (id < npost) + synEnv.getStream() << CodeStream::CB(140); // end if (id < npost) } - if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { + /*if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { os << CodeStream::CB(130); // end if (eCode) } - else if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { - os << CodeStream::CB(135); // end if (B(dd_gp" << sg.getName() << "[gid / 32], gid + else */if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { + env.getStream() << CodeStream::CB(135); // end if 
(B(dd_gp" << sg.getName() << "[gid / 32], gid } } } } } //---------------------------------------------------------------------------- -void PostSpan::genPostamble(CodeStream &os, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, - const Substitutions &popSubs, const BackendSIMT &backend) const +void PostSpan::genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, + PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const { // If we should accumulate output directly into register const ModelSpecInternal &model = modelMerged.getModel(); const unsigned int batchSize = model.getBatchSize(); if(shouldAccumulateInRegister(sg)) { - os << "// only do this for existing neurons" << std::endl; - os << "if (" << popSubs["id"] << " < group->numTrgNeurons)"; + env.getStream() << "// only do this for existing neurons" << std::endl; + env.print("if ($(id) < $(num_post))"); { - CodeStream::Scope b(os); - const std::string inSyn = "group->inSyn[" + sg.getPostISynIndex(batchSize, popSubs["id"]) + "]"; + CodeStream::Scope b(env.getStream()); + const std::string inSyn = printSubs("$(_out_post)[" + sg.getPostISynIndex(batchSize, "$(id)") + "]", env); if(sg.getArchetype().isPSModelFused()) { - os << backend.getAtomic(model.getPrecision()) << "(&" << inSyn << ", linSyn);" << std::endl; + env.getStream() << backend.getAtomic(model.getPrecision()) << "(&" << inSyn << ", linSyn);" << std::endl; } else { - os << inSyn << " += linSyn;" << std::endl; + env.getStream() << inSyn << " += linSyn;" << std::endl; } } } // Otherwise, if we should accumulate into shared memory else if(isSmallSharedMemoryPop(sg, backend)) { - backend.genSharedMemBarrier(os); - os << "if(" << backend.getThreadID() << " < group->numTrgNeurons)"; + backend.genSharedMemBarrier(env.getStream()); + env.getStream() << "if(" << backend.getThreadID() << " < " << env["num_post"] << ")"; { - CodeGenerator::CodeStream::Scope b(os); - os << backend.getAtomic(model.getPrecision()) << "(&group->inSyn[" << sg.getPostISynIndex(batchSize, backend.getThreadID()) << "], "; - os << "shLg[" << backend.getThreadID() << "]); " << std::endl; + CodeGenerator::CodeStream::Scope b(env.getStream()); + const std::string inSyn = printSubs("$(_out_post)[" + sg.getPostISynIndex(batchSize, backend.getThreadID()) + "]", env); + env.getStream() << backend.getAtomic(model.getPrecision()) << "(&" << inSyn << "], shLg[" << backend.getThreadID() << "]); " << std::endl; } } } @@ -459,57 +454,56 @@ size_t PreSpanProcedural::getSharedMemoryPerThread(const PresynapticUpdateGroupM return 0; } //---------------------------------------------------------------------------- -void PreSpanProcedural::genPreamble(CodeStream&, const ModelSpecMerged&, const PresynapticUpdateGroupMerged&, - const Substitutions&, const BackendSIMT&) const +void PreSpanProcedural::genPreamble(EnvironmentExternalBase&, const ModelSpecMerged&, + PresynapticUpdateGroupMerged&, const BackendSIMT&) const { } //---------------------------------------------------------------------------- -void PreSpanProcedural::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, - const Substitutions &popSubs, const BackendSIMT &backend, bool trueSpike) const +void PreSpanProcedural::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, + PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const { // Get suffix based on type of events const ModelSpecInternal &model = 
modelMerged.getModel(); const unsigned int batchSize = model.getBatchSize(); - const std::string eventSuffix = trueSpike ? "" : "Evnt"; + const std::string eventSuffix = trueSpike ? "" : "_evnt"; const auto *wu = sg.getArchetype().getWUModel(); const size_t numThreadsPerSpike = sg.getArchetype().getNumThreadsPerSpike(); if(numThreadsPerSpike > 1) { - os << "const unsigned int spike = " << popSubs["id"] << " / " << numThreadsPerSpike << ";" << std::endl; - os << "const unsigned int thread = " << popSubs["id"] << " % " << numThreadsPerSpike << ";" << std::endl; - os << "const unsigned int numPostPerThread = (group->numTrgNeurons + " << numThreadsPerSpike << " - 1) / " << numThreadsPerSpike << ";" << std::endl; + const std::string numThreadsPerSpikeStr = std::to_string(numThreadsPerSpike); + env.printLine("const unsigned int spike = $(id) / " + numThreadsPerSpikeStr + ";"); + env.printLine("const unsigned int thread = $(id) % " + numThreadsPerSpikeStr + ";"); + env.printLine("const unsigned int numPostPerThread = ($(num_post) + " + numThreadsPerSpikeStr + " - 1) / " + numThreadsPerSpikeStr + ";"); // Calculate the starting position and length of the sub-row to process on this thread // **TODO** fast-divide style optimisations here - os << "const unsigned int idPostStart = thread * numPostPerThread;" << std::endl; - os << "const unsigned int postRemainder = group->numTrgNeurons % numPostPerThread;" << std::endl; - os << "const unsigned int numPost = (postRemainder == 0 || thread < " << (numThreadsPerSpike - 1) << ") ? numPostPerThread : postRemainder;" << std::endl; + env.getStream() << "const unsigned int idPostStart = thread * numPostPerThread;" << std::endl; + env.getStream() << "const unsigned int postRemainder = " << env["num_post"] << " % numPostPerThread;" << std::endl; + env.getStream() << "const unsigned int numPost = (postRemainder == 0 || thread < " << (numThreadsPerSpike - 1) << ") ? 
numPostPerThread : postRemainder;" << std::endl; } else { - os << "const unsigned int spike = " << popSubs["id"] << ";" << std::endl; + env.printLine("const unsigned int spike = $(id);"); } if(sg.getArchetype().isPresynapticOutputRequired()) { - os << "scalar lrevInSyn= 0.0;" << std::endl; + env.getStream() << "scalar lrevInSyn = 0.0;" << std::endl; } // If there is a spike for this thread to process - os << "if (spike < group->srcSpkCnt" << eventSuffix << "[" << sg.getPreSlot(batchSize) << "])"; + env.print("if (spike < $(_src_spk_cnt" + eventSuffix + ")[" + sg.getPreSlot(batchSize) + "])"); { - CodeStream::Scope b(os); - - // Determine the index of the presynaptic neuron this thread is responsible for - os << "const unsigned int preInd = group->srcSpk" << eventSuffix << "[" << sg.getPreVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "spike") << "];" << std::endl; + CodeStream::Scope b(env.getStream()); - // Create substitution stack and add presynaptic index - Substitutions synSubs(&popSubs); - synSubs.addVarSubstitution("id_pre", "preInd"); + // Create environment and add presynaptic index + EnvironmentGroupMergedField synEnv(env, sg); + synEnv.add(Type::Uint32.addConst(), "id_pre", "preInd", + {synEnv.addInitialiser("const unsigned int preInd = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "spike") + "];")}); - if(backend.supportsNamespace() && !wu->getSimSupportCode().empty()) { + /*if(backend.supportsNamespace() && !wu->getSimSupportCode().empty()) { os << "using namespace " << modelMerged.getPresynapticUpdateSupportCodeNamespace(wu->getSimSupportCode()) << ";" << std::endl; - } + }*/ - if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { + /*if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { os << "if("; // Generate weight update threshold condition @@ -520,11 +514,12 @@ void PreSpanProcedural::genUpdate(CodeStream &os, const ModelSpecMerged &modelMe os << ")"; os << CodeStream::OB(130); - } + }*/ // Create substitution stack for generating procedural connectivity code - Substitutions connSubs(&synSubs); - connSubs.addVarSubstitution("num_threads", std::to_string(numThreadsPerSpike)); + assert(false); + /*Substitutions connSubs(&synSubs); + synEnv.add("num_threads", std::to_string(numThreadsPerSpike)); // If this connectivity requires an RNG for initialisation, // make copy of connect Phillox RNG and skip ahead to id that would have been used to initialize any variables associated with it @@ -605,21 +600,21 @@ void PreSpanProcedural::genUpdate(CodeStream &os, const ModelSpecMerged &modelMe // Generate procedural connectivity code sg.generateProceduralConnectivity(backend, os, connSubs); - if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { - os << CodeStream::CB(130); - } + //if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { + // os << CodeStream::CB(130); + //} // Should this be in the Postamble? 
if(sg.getArchetype().isPresynapticOutputRequired()) { // write lrevInSyn to global memory if not 0 os << "if(lrevInSyn != 0.0) " << backend.getAtomic(model.getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, "preInd") + "], lrevInSyn);" << std::endl; - } + }*/ } } //---------------------------------------------------------------------------- -void PreSpanProcedural::genPostamble(CodeStream&, const ModelSpecMerged&, const PresynapticUpdateGroupMerged&, - const Substitutions&, const BackendSIMT&) const +void PreSpanProcedural::genPostamble(EnvironmentExternalBase&, const ModelSpecMerged&, + PresynapticUpdateGroupMerged&, const BackendSIMT&) const { } @@ -647,17 +642,17 @@ bool PostSpanBitmask::isCompatible(const SynapseGroupInternal &sg, const Prefere && !sg.isDendriticDelayRequired()); } //---------------------------------------------------------------------------- -void PostSpanBitmask::genPreamble(CodeStream &os, const ModelSpecMerged &, const PresynapticUpdateGroupMerged &, - const Substitutions &, const BackendSIMT &backend) const +void PostSpanBitmask::genPreamble(EnvironmentExternalBase &env, const ModelSpecMerged &, + PresynapticUpdateGroupMerged &, const BackendSIMT &backend) const { // Loop through bits written by this thread for(size_t i = 0; i < 32; i++) { // Zero entries in this thread's shared memory array // **NOTE** this is ordered to prevent bank conflicts const std::string index = std::to_string(i * backend.getKernelBlockSize(KernelPresynapticUpdate)) + " + " + backend.getThreadID(); - os << "shLg[" << index << "] = 0;" << std::endl; + env.getStream() << "shLg[" << index << "] = 0;" << std::endl; } - backend.genSharedMemBarrier(os); + backend.genSharedMemBarrier(env.getStream()); } //---------------------------------------------------------------------------- size_t PostSpanBitmask::getSharedMemoryPerThread(const PresynapticUpdateGroupMerged&, const BackendSIMT&) const @@ -666,47 +661,47 @@ size_t PostSpanBitmask::getSharedMemoryPerThread(const PresynapticUpdateGroupMer return 32; } //---------------------------------------------------------------------------- -void PostSpanBitmask::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, - const Substitutions &popSubs, const BackendSIMT &backend, bool trueSpike) const +void PostSpanBitmask::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, + PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const { // Get suffix based on type of events const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - const std::string eventSuffix = trueSpike ? "" : "Evnt"; + const std::string eventSuffix = trueSpike ? 
"" : "_evnt"; // Get blocksize const size_t blockSize = backend.getKernelBlockSize(KernelPresynapticUpdate); - os << "const unsigned int numSpikes = group->srcSpkCnt" << eventSuffix << "[" << sg.getPreSlot(batchSize) << "];" << std::endl; - os << "const unsigned int numSpikeBlocks = (numSpikes + " << blockSize << " - 1) / " << blockSize << ";" << std::endl; + env.printLine("const unsigned int numSpikes = $(_src_spk_cnt" + eventSuffix + ")[" + sg.getPreSlot(batchSize) + "];"); + env.getStream() << "const unsigned int numSpikeBlocks = (numSpikes + " << blockSize << " - 1) / " << blockSize << ";" << std::endl; const auto *wu = sg.getArchetype().getWUModel(); - os << "const unsigned int rowWords = (group->numTrgNeurons + 32 - 1) / 32;" << std::endl; - os << "for (unsigned int r = 0; r < numSpikeBlocks; r++)"; + env.printLine("const unsigned int rowWords = ($(num_post) + 32 - 1) / 32;"); + env.getStream() << "for (unsigned int r = 0; r < numSpikeBlocks; r++)"; { - CodeStream::Scope b(os); - os << "const unsigned int numSpikesInBlock = (r == numSpikeBlocks - 1) ? ((numSpikes - 1) % " << blockSize << ") + 1 : " << blockSize << ";" << std::endl; + CodeStream::Scope b(env.getStream()); + env.getStream() << "const unsigned int numSpikesInBlock = (r == numSpikeBlocks - 1) ? ((numSpikes - 1) % " << blockSize << ") + 1 : " << blockSize << ";" << std::endl; - backend.genSharedMemBarrier(os); - os << "if (" << backend.getThreadID() << " < numSpikesInBlock)"; + backend.genSharedMemBarrier(env.getStream()); + env.getStream() << "if (" << backend.getThreadID() << " < numSpikesInBlock)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); const std::string index = "(r * " + std::to_string(backend.getKernelBlockSize(KernelPresynapticUpdate)) + ") + " + backend.getThreadID(); - os << "const unsigned int spk = group->srcSpk" << eventSuffix << "[" << sg.getPreVarIndex(batchSize, VarAccessDuplication::DUPLICATE, index) << "];" << std::endl; - os << "shSpk" << eventSuffix << "[" << backend.getThreadID() << "] = spk;" << std::endl; + env.printLine("const unsigned int spk = $(_src_spk" + eventSuffix + ")[" + sg.getPreVarIndex(batchSize, VarAccessDuplication::DUPLICATE, index) + "];"); + env.printLine("$(_sh_spk" + eventSuffix + ")[" + backend.getThreadID() + "] = spk;"); } - backend.genSharedMemBarrier(os); + backend.genSharedMemBarrier(env.getStream()); - os << "// loop through all incoming spikes" << std::endl; - os << "for (unsigned int j = 0; j < numSpikesInBlock; j++)"; + env.getStream() << "// loop through all incoming spikes" << std::endl; + env.getStream() << "for (unsigned int j = 0; j < numSpikesInBlock; j++)"; { - CodeStream::Scope b(os); - os << "// only work on existing neurons" << std::endl; - os << "if (" << popSubs["id"] << " < rowWords)"; + CodeStream::Scope b(env.getStream()); + env.getStream() << "// only work on existing neurons" << std::endl; + env.print("if ($(id) < rowWords)"); { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); - if(backend.supportsNamespace() && !wu->getSimSupportCode().empty()) { + /*if(backend.supportsNamespace() && !wu->getSimSupportCode().empty()) { os << "using namespace " << modelMerged.getPresynapticUpdateSupportCodeNamespace(wu->getSimSupportCode()) << ";" << std::endl; } if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { @@ -720,83 +715,82 @@ void PostSpanBitmask::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerg os << ")"; os << CodeStream::OB(130); - } + }*/ // Read row word - os << "uint32_t 
connectivityWord = group->gp[(shSpk" << eventSuffix << "[j] * rowWords) + " << popSubs["id"] << "];" << std::endl; + env.printLine("uint32_t connectivityWord = $(_gp)[($(_sh_spk" + eventSuffix + ")[j] * rowWords) + $(id)];"); // While there any bits left - os << "unsigned int ibit = 0;" << std::endl; - os << "while(connectivityWord != 0)"; + env.getStream() << "unsigned int ibit = 0;" << std::endl; + env.getStream() << "while(connectivityWord != 0)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(env.getStream()); + EnvironmentGroupMergedField synEnv(env, sg); // Cound leading zeros (as bits are indexed backwards this is index of next synapse) - os << "const int numLZ = " << backend.getCLZ() << "(connectivityWord);" << std::endl; + synEnv.getStream() << "const int numLZ = " << backend.getCLZ() << "(connectivityWord);" << std::endl; // Shift off zeros and the one just discovered // **NOTE** if numLZ == 31, undefined behaviour results in C++, BUT in CUDA this PRESUMABLY emits // In a 'shl' PTX instruction where "Shift amounts greater than the register width N are clamped to N." - os << "connectivityWord <<= (numLZ + 1);" << std::endl; + synEnv.getStream() << "connectivityWord <<= (numLZ + 1);" << std::endl; // Add to bit index - os << "ibit += numLZ;" << std::endl; + synEnv.getStream() << "ibit += numLZ;" << std::endl; // Calculate postsynaptic index - os << "const unsigned int ipost = ibit + (" << popSubs["id"] << " * 32);" << std::endl; + synEnv.printLine("const unsigned int ipost = ibit + ($(id) * 32);"); - Substitutions synSubs(&popSubs); - synSubs.addVarSubstitution("id_pre", "shSpk" + eventSuffix + "[j]"); - synSubs.addVarSubstitution("id_syn", "synAddress"); - synSubs.addVarSubstitution("id_post", "ipost"); - synSubs.addFuncSubstitution("addToInSyn", 1, "shLg[(ibit * " + std::to_string(blockSize) + ") + " + backend.getThreadID() + "] += $(0)"); + synEnv.add(Type::Uint32.addConst(), "id_pre", "$(_sh_spk" + eventSuffix + ")[j]"); + synEnv.add(Type::Uint32.addConst(), "id_post", "ipost"); + + + synEnv.add(Type::AddToPost, "addToPost", + "shLg[(ibit * " + std::to_string(blockSize) + ") + " + backend.getThreadID() + "] += $(0)"); + synEnv.add(Type::AddToPre, "addToPre", + backend.getAtomic(modelMerged.getModel().getPrecision()) + "(&$(_out_pre)[" + sg.getPreISynIndex(batchSize, "$(id_pre)") + "], $(0))"); - if(sg.getArchetype().isPresynapticOutputRequired()) { - synSubs.addFuncSubstitution("addToPre", 1, - backend.getAtomic(modelMerged.getModel().getPrecision()) + "(&group->revInSyn[" + sg.getPreISynIndex(batchSize, synSubs["id_pre"]) + "], $(0))"); - } - if(trueSpike) { - sg.generateSpikeUpdate(backend, os, modelMerged, synSubs); + sg.generateSpikeUpdate(backend, synEnv, modelMerged); } else { - sg.generateSpikeEventUpdate(backend, os, modelMerged, synSubs); + sg.generateSpikeEventUpdate(backend, synEnv, modelMerged); } - os << "ibit++;" << std::endl; + synEnv.getStream() << "ibit++;" << std::endl; } - if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { + /*if(!trueSpike && sg.getArchetype().isEventThresholdReTestRequired()) { os << CodeStream::CB(130); // end if (eCode) - } + }*/ } } } } //---------------------------------------------------------------------------- -void PostSpanBitmask::genPostamble(CodeStream &os, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, - const Substitutions &popSubs, const BackendSIMT &backend) const +void PostSpanBitmask::genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, + 
PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const { - backend.genSharedMemBarrier(os); + backend.genSharedMemBarrier(env.getStream()); const size_t blockSize = backend.getKernelBlockSize(KernelPresynapticUpdate); // Use first 32 threads in each block to write shared memory back to global memory - os << "if (" << backend.getThreadID() << " < 32)"; + env.getStream() << "if (" << backend.getThreadID() << " < 32)"; { - CodeStream::Scope b(os); - os << "unsigned int glbIdx = ((" << backend.getBlockID() << " - (" << popSubs["group_start_id"] << " / " << blockSize << ")) * " << 32 * blockSize << ") + " << backend.getThreadID() << ";" << std::endl; - os << "unsigned int shIdx = " << backend.getThreadID() << " * " << blockSize << ";" << std::endl; - os << "const unsigned int endShIdx = shIdx + 32;" << std::endl; - os << "for(;shIdx < endShIdx && glbIdx < group->numTrgNeurons; shIdx++, glbIdx += 32)"; + CodeStream::Scope b(env.getStream()); + env.printLine("unsigned int glbIdx = ((" + backend.getBlockID() + " - ($(_group_start_id) / " + std::to_string(blockSize) + ")) * " + std::to_string(32 * blockSize) + ") + " + backend.getThreadID() + ";"); + env.getStream() << "unsigned int shIdx = " << backend.getThreadID() << " * " << blockSize << ";" << std::endl; + env.getStream() << "const unsigned int endShIdx = shIdx + 32;" << std::endl; + env.print("for(;shIdx < endShIdx && glbIdx < $(num_post); shIdx++, glbIdx += 32)"); { - CodeStream::Scope b(os); - const std::string inSyn = "group->inSyn[" + sg.getPostISynIndex(modelMerged.getModel().getBatchSize(), "glbIdx") +"]"; + CodeStream::Scope b(env.getStream()); + const std::string inSyn = "$(_out_post)[" + sg.getPostISynIndex(modelMerged.getModel().getBatchSize(), "glbIdx") +"]"; if(sg.getArchetype().isPSModelFused()) { - os << backend.getAtomic(modelMerged.getModel().getPrecision()) << "(&" << inSyn << ", shLg[shIdx]);" << std::endl; + env.printLine(backend.getAtomic(modelMerged.getModel().getPrecision()) + "(&" + inSyn + ", shLg[shIdx]);"); } else { - os << inSyn << " += shLg[shIdx];" << std::endl; + env.printLine(inSyn + " += shLg[shIdx];"); } } } @@ -820,16 +814,16 @@ bool PostSpanToeplitz::isCompatible(const SynapseGroupInternal &sg, const Prefer return (sg.getMatrixType() & SynapseMatrixConnectivity::TOEPLITZ); } //---------------------------------------------------------------------------- -void PostSpanToeplitz::genPreamble(CodeStream &os, const ModelSpecMerged &, const PresynapticUpdateGroupMerged &sg, - const Substitutions &, const BackendSIMT &backend) const +void PostSpanToeplitz::genPreamble(EnvironmentExternalBase &env, const ModelSpecMerged &, + PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const { if(isSmallSharedMemoryPop(sg, backend)) { - os << "if(" << backend.getThreadID() << " < group->numTrgNeurons)"; + env.print("if(" + backend.getThreadID() + " < $(num_post))"); { - CodeGenerator::CodeStream::Scope b(os); - os << "shLg[" << backend.getThreadID() << "] = 0;" << std::endl; + CodeGenerator::CodeStream::Scope b(env.getStream()); + env.getStream() << "shLg[" << backend.getThreadID() << "] = 0;" << std::endl; } - backend.genSharedMemBarrier(os); + backend.genSharedMemBarrier(env.getStream()); } } //---------------------------------------------------------------------------- @@ -839,16 +833,17 @@ size_t PostSpanToeplitz::getSharedMemoryPerThread(const PresynapticUpdateGroupMe return isSmallSharedMemoryPop(sg, backend) ? 
1 : 0; } //---------------------------------------------------------------------------- -void PostSpanToeplitz::genUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, - const Substitutions &popSubs, const BackendSIMT &backend, bool trueSpike) const +void PostSpanToeplitz::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, + PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const { const auto &connectInit = sg.getArchetype().getToeplitzConnectivityInitialiser(); // Get suffix based on type of events const ModelSpecInternal &model = modelMerged.getModel(); const unsigned int batchSize = model.getBatchSize(); - const std::string eventSuffix = trueSpike ? "" : "Evnt"; - + const std::string eventSuffix = trueSpike ? "" : "_evnt"; + assert(false); + /* // Create substitution stack for generating Toeplitz connectivity code Substitutions connSubs(&popSubs); connSubs.addVarSubstitution("id_diag", connSubs["id"]); @@ -977,21 +972,20 @@ void PostSpanToeplitz::genUpdate(CodeStream &os, const ModelSpecMerged &modelMer } } } - } + }*/ } //---------------------------------------------------------------------------- -void PostSpanToeplitz::genPostamble(CodeStream &os, const ModelSpecMerged &modelMerged, const PresynapticUpdateGroupMerged &sg, - const Substitutions&, const BackendSIMT &backend) const +void PostSpanToeplitz::genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, + PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const { // If we should accumulate into shared memory if(isSmallSharedMemoryPop(sg, backend)) { - backend.genSharedMemBarrier(os); - os << "if(" << backend.getThreadID() << " < group->numTrgNeurons)"; + backend.genSharedMemBarrier(env.getStream()); + env.print("if(" + backend.getThreadID() + " < $(num_post))"); { - CodeGenerator::CodeStream::Scope b(os); - os << backend.getAtomic(modelMerged.getModel().getPrecision()); - os << "(&group->inSyn[" << sg.getPostISynIndex(modelMerged.getModel().getBatchSize(), backend.getThreadID()) << "], "; - os << "shLg[" << backend.getThreadID() << "]); " << std::endl; + CodeGenerator::CodeStream::Scope b(env.getStream()); + const std::string idx = sg.getPostISynIndex(modelMerged.getModel().getBatchSize(), backend.getThreadID()); + env.printLine(backend.getAtomic(modelMerged.getModel().getPrecision()) + "(&$(_out_post)[" + idx + "], shLg[" + backend.getThreadID() + "]);"); } } } From 8c09fbe445007ef75382bde8c5ad96360b24f5d0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 28 Jun 2023 18:07:40 +0100 Subject: [PATCH 267/725] Fixed up backend SIMT RNG methods and closing in on compiling --- include/genn/backends/cuda/backend.h | 6 +- include/genn/backends/opencl/backend.h | 4 +- .../genn/genn/code_generator/backendSIMT.h | 55 ++-- src/genn/backends/cuda/backend.cc | 10 +- src/genn/backends/opencl/backend.cc | 8 +- src/genn/genn/code_generator/backendSIMT.cc | 258 ++++++++++-------- 6 files changed, 180 insertions(+), 161 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index 5161d60a4e..819b1fe989 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -156,17 +156,17 @@ class BACKEND_EXPORT Backend : public BackendSIMT virtual void genSharedMemBarrier(CodeStream &os) const override; //! 
For SIMT backends which initialize RNGs on device, initialize population RNG with specified seed and sequence - virtual void genPopulationRNGInit(CodeStream &os, const std::string &globalRNG, const std::string &seed, const std::string &sequence) const override; + virtual void genPopulationRNGInit(CodeStream &os, const std::string &globalRNG, const std::string &seed, const std::string &sequence) const final; //! Generate a preamble to add substitution name for population RNG - virtual void genPopulationRNGPreamble(CodeStream &os, Substitutions &subs, const std::string &globalRNG, const std::string &name = "rng") const override; + virtual std::string genPopulationRNGPreamble(CodeStream &os, const std::string &globalRNG) const final; //! If required, generate a postamble for population RNG /*! For example, in OpenCL, this is used to write local RNG state back to global memory*/ virtual void genPopulationRNGPostamble(CodeStream &os, const std::string &globalRNG) const override; //! Generate code to skip ahead local copy of global RNG - virtual void genGlobalRNGSkipAhead(CodeStream &os, Substitutions &subs, const std::string &sequence, const std::string &name = "rng") const override; + virtual std::string genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const final; //-------------------------------------------------------------------------- // CodeGenerator::BackendBase virtuals diff --git a/include/genn/backends/opencl/backend.h b/include/genn/backends/opencl/backend.h index 43f8b9e7be..c5d88eab2c 100644 --- a/include/genn/backends/opencl/backend.h +++ b/include/genn/backends/opencl/backend.h @@ -122,14 +122,14 @@ class BACKEND_EXPORT Backend : public BackendSIMT virtual void genPopulationRNGInit(CodeStream &os, const std::string &globalRNG, const std::string &seed, const std::string &sequence) const override; //! Generate a preamble to add substitution name for population RNG - virtual void genPopulationRNGPreamble(CodeStream &os, Substitutions &subs, const std::string &globalRNG, const std::string &name = "rng") const override; + virtual std::string genPopulationRNGPreamble(CodeStream &os, const std::string &globalRNG) const override; //! If required, generate a postamble for population RNG /*! For example, in OpenCL, this is used to write local RNG state back to global memory*/ virtual void genPopulationRNGPostamble(CodeStream &os, const std::string &globalRNG) const override; //! Generate code to skip ahead local copy of global RNG - virtual void genGlobalRNGSkipAhead(CodeStream &os, Substitutions &subs, const std::string &sequence, const std::string &name = "rng") const override; + virtual std::string genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const override; //-------------------------------------------------------------------------- // CodeGenerator::BackendBase:: virtuals diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 5cc9d4609d..750900e18f 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -103,14 +103,14 @@ class GENN_EXPORT BackendSIMT : public BackendBase virtual void genPopulationRNGInit(CodeStream &os, const std::string &globalRNG, const std::string &seed, const std::string &sequence) const = 0; //! 
Generate a preamble to add substitution name for population RNG
-    virtual void genPopulationRNGPreamble(CodeStream &os, Substitutions &subs, const std::string &globalRNG, const std::string &name = "rng") const = 0;
+    virtual std::string genPopulationRNGPreamble(CodeStream &os, const std::string &globalRNG) const = 0;
 
     //! If required, generate a postamble for population RNG
     /*! For example, in OpenCL, this is used to write local RNG state back to global memory*/
     virtual void genPopulationRNGPostamble(CodeStream &os, const std::string &globalRNG) const = 0;
 
     //! Generate code to skip ahead local copy of global RNG
-    virtual void genGlobalRNGSkipAhead(CodeStream &os, Substitutions &subs, const std::string &sequence, const std::string &name = "rng") const = 0;
+    virtual std::string genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const = 0;
 
     //------------------------------------------------------------------------
     // BackendBase virtuals
@@ -227,11 +227,11 @@ class GENN_EXPORT BackendSIMT : public BackendBase
     // Private methods
     //--------------------------------------------------------------------------
     template
-    void genParallelGroup(EnvironmentExternalBase &env, const std::vector &groups, size_t &idStart,
+    void genParallelGroup(EnvironmentExternalBase &env, std::vector &groups, size_t &idStart,
                           S getPaddedSizeFunc, F filter, GroupHandlerEnv handler) const
     {
         // Loop through groups
-        for(const auto &gMerge : groups) {
+        for(auto &gMerge : groups) {
             if(filter(gMerge)) {
                 // Sum padded sizes of each group within merged group
                 const size_t paddedSize = std::accumulate(
@@ -252,51 +252,52 @@ class GENN_EXPORT BackendSIMT : public BackendBase
                 }
                 {
                     CodeStream::Scope b(env.getStream());
+                    EnvironmentExternal groupEnv(env);
 
                     if(gMerge.getGroups().size() == 1) {
-                        env.getStream() << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group";
-                        env.getStream() << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[0]; " << std::endl;
-                        env.getStream() << "const unsigned int lid = id - " << idStart << ";" << std::endl;
+                        groupEnv.getStream() << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group";
+                        groupEnv.getStream() << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[0]; " << std::endl;
+                        groupEnv.getStream() << "const unsigned int lid = id - " << idStart << ";" << std::endl;
 
                         // Use the starting thread ID of the whole merged group as group_start_id
-                        env.add(Type::Uint32.addConst(), "group_start_id", std::to_string(idStart));
+                        groupEnv.add(Type::Uint32.addConst(), "_group_start_id", std::to_string(idStart));
                     }
                     else {
                         // Perform bisect operation to get index of merged struct
-                        env.getStream() << "unsigned int lo = 0;" << std::endl;
-                        env.getStream() << "unsigned int hi = " << gMerge.getGroups().size() << ";" << std::endl;
-                        env.getStream() << "while(lo < hi)" << std::endl;
+                        groupEnv.getStream() << "unsigned int lo = 0;" << std::endl;
+                        groupEnv.getStream() << "unsigned int hi = " << gMerge.getGroups().size() << ";" << std::endl;
+                        groupEnv.getStream() << "while(lo < hi)" << std::endl;
                         {
-                            CodeStream::Scope b(env.getStream());
-                            env.getStream() << "const unsigned int mid = (lo + hi) / 2;" << std::endl;
+                            CodeStream::Scope b(groupEnv.getStream());
+                            groupEnv.getStream() << "const unsigned int mid = (lo + hi) / 2;" << std::endl;
 
-                            env.getStream() << "if(id < d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[mid])";
+                            groupEnv.getStream() << "if(id < d_merged" << T::name <<
"GroupStartID" << gMerge.getIndex() << "[mid])"; { - CodeStream::Scope b(env.getStream()); - env.getStream() << "hi = mid;" << std::endl; + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "hi = mid;" << std::endl; } - env.getStream() << "else"; + groupEnv.getStream() << "else"; { - CodeStream::Scope b(env.getStream()); - env.getStream() << "lo = mid + 1;" << std::endl; + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "lo = mid + 1;" << std::endl; } } // Use this to get reference to merged group structure - env.getStream() << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group"; - env.getStream() << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[lo - 1]; " << std::endl; + groupEnv.getStream() << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group"; + groupEnv.getStream() << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[lo - 1]; " << std::endl; // Get group start thread ID and use as group_start_id - env.getStream() << "const unsigned int groupStartID = d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[lo - 1];" << std::endl; - env.add(Type::Uint32.addConst(), "_group_start_id", "groupStartID"); + groupEnv.getStream() << "const unsigned int groupStartID = d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[lo - 1];" << std::endl; + groupEnv.add(Type::Uint32.addConst(), "_group_start_id", "groupStartID"); // Use this to calculate local id within group - env.getStream() << "const unsigned int lid = id - groupStartID;" << std::endl; + groupEnv.getStream() << "const unsigned int lid = id - groupStartID;" << std::endl; } - env.add(Type::Uint32.addConst(), "id", "lid"); + groupEnv.add(Type::Uint32.addConst(), "id", "lid"); - handler(env, gMerge); + handler(groupEnv, gMerge); idStart += paddedSize; } @@ -306,7 +307,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase template - void genParallelGroup(EnvironmentExternalBase &env, const std::vector &groups, size_t &idStart, + void genParallelGroup(EnvironmentExternalBase &env, std::vector &groups, size_t &idStart, S getPaddedSizeFunc, GroupHandlerEnv handler) const { genParallelGroup(env, groups, idStart, getPaddedSizeFunc, diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 031e9a5ac2..878ab3f25b 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -395,23 +395,21 @@ void Backend::genPopulationRNGInit(CodeStream &os, const std::string &globalRNG, os << "curand_init(" << seed << ", " << sequence << ", 0, &" << globalRNG << ");" << std::endl; } //-------------------------------------------------------------------------- -void Backend::genPopulationRNGPreamble(CodeStream &, Substitutions &subs, const std::string &globalRNG, const std::string &name) const +std::string Backend::genPopulationRNGPreamble(CodeStream &, const std::string &globalRNG) const { - subs.addVarSubstitution(name, "&" + globalRNG); + return "&" + globalRNG; } //-------------------------------------------------------------------------- void Backend::genPopulationRNGPostamble(CodeStream&, const std::string&) const { } //-------------------------------------------------------------------------- -void Backend::genGlobalRNGSkipAhead(CodeStream &os, Substitutions &subs, const std::string &sequence, const std::string &name) const +std::string Backend::genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const { // 
Skipahead RNG os << "curandStatePhilox4_32_10_t localRNG = d_rng;" << std::endl; os << "skipahead_sequence((unsigned long long)" << sequence << ", &localRNG);" << std::endl; - - // Add substitution for RNG - subs.addVarSubstitution(name, "&localRNG"); + return "&localRNG"; } //-------------------------------------------------------------------------- void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const diff --git a/src/genn/backends/opencl/backend.cc b/src/genn/backends/opencl/backend.cc index 5b6f18c0d8..734ba3b418 100644 --- a/src/genn/backends/opencl/backend.cc +++ b/src/genn/backends/opencl/backend.cc @@ -277,11 +277,11 @@ void Backend::genPopulationRNGInit(CodeStream&, const std::string&, const std::s assert(false); } //-------------------------------------------------------------------------- -void Backend::genPopulationRNGPreamble(CodeStream &os, Substitutions &subs, const std::string &globalRNG, const std::string &name) const +std::string Backend::genPopulationRNGPreamble(CodeStream &os, const std::string &globalRNG) const { os << "clrngLfsr113Stream localStream;" << std::endl; os << "clrngLfsr113CopyOverStreamsFromGlobal(1, &localStream, &" << globalRNG << ");" << std::endl; - subs.addVarSubstitution(name, "&localStream"); + return "&localStream"; } //-------------------------------------------------------------------------- void Backend::genPopulationRNGPostamble(CodeStream &os, const std::string &globalRNG) const @@ -289,7 +289,7 @@ void Backend::genPopulationRNGPostamble(CodeStream &os, const std::string &globa os << "clrngLfsr113CopyOverStreamsToGlobal(1, &" << globalRNG << ", &localStream);" << std::endl; } //-------------------------------------------------------------------------- -void Backend::genGlobalRNGSkipAhead(CodeStream &os, Substitutions &subs, const std::string &sequence, const std::string &name) const +std::string Backend::genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const { // Make local copy of host stream os << "clrngPhilox432Stream localStream;" << std::endl; @@ -300,7 +300,7 @@ void Backend::genGlobalRNGSkipAhead(CodeStream &os, Substitutions &subs, const s os << "const clrngPhilox432Counter steps = {{0, " << sequence << "}, {0, 0}};" << std::endl; os << "localStream.current.ctr = clrngPhilox432Add(localStream.current.ctr, steps);" << std::endl; os << "localStream.current.deckIndex = 0;" << std::endl; - subs.addVarSubstitution(name, "&localStream"); + return "&localStream"; } //-------------------------------------------------------------------------- void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index f07e60c25f..9d3394f761 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -338,7 +338,7 @@ void BackendSIMT::genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &en popEnv.print("if($(id) < $(_spk_cnt)[lastTimestepDelaySlot])"); { CodeStream::Scope b(popEnv.getStream()); - popEnv.printLine("$(_prev_spk_time)[lastTimestepDelayOffset + $(_spk)[lastTimestepDelayOffset + $(id)]] = $(t) - DT;") + popEnv.printLine("$(_prev_spk_time)[lastTimestepDelayOffset + $(_spk)[lastTimestepDelayOffset + $(id)]] = $(t) - DT;"); } } if(ng.getArchetype().isPrevSpikeEventTimeRequired()) { @@ -489,8 +489,9 @@ void 
BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, const Mode
                 CodeStream::Scope b(neuronEnv.getStream());
 
                 // Copy global RNG stream to local and use pointer to this for rng
+                const std::string rng = printSubs("$(_rng)[" + ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)") + "]", neuronEnv);
                 if(ng.getArchetype().isSimRNGRequired()) {
-                    genPopulationRNGPreamble(os, popSubs, "group->rng[" + ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) + "]");
+                    neuronEnv.add(Type::Void, "rng", genPopulationRNGPreamble(neuronEnv.getStream(), rng));
                 }
 
                 ng.generateNeuronUpdate(*this, neuronEnv, modelMerged,
@@ -507,7 +508,7 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, const Mode
 
                 // Copy local stream back to global
                 if(ng.getArchetype().isSimRNGRequired()) {
-                    genPopulationRNGPostamble(os, "group->rng[" + ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, popSubs["id"]) + "]");
+                    genPopulationRNGPostamble(neuronEnv.getStream(), rng);
                 }
             }
@@ -602,7 +603,7 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, const Mode
             // Calculate number of words which will be used to record this population's spikes in each batch
             neuronEnv.printLine("const unsigned int numRecordingWords = ($(num_neurons) + 31) / 32;");
-            neuronEnv.printLine("const unsigned int popWordIdx = ($(id) / 32) + " + getThreadID() + << ";");
+            neuronEnv.printLine("const unsigned int popWordIdx = ($(id) / 32) + " + getThreadID() + ";");
 
             // Build global index
             std::string globalIndex = "(recordingTimestep * numRecordingWords * " + std::to_string(batchSize) + ") + popWordIdx";
@@ -679,13 +680,13 @@ void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, const
     }
 
     // Shared memory for row length
-    kernelEnv.add(Type::Uint32.createPointer(), "_sh_row_length", "shRowLength",
+    kernelEnv.add(Type::Void, "_sh_row_length", "shRowLength",
                   {kernelEnv.addInitialiser(getSharedPrefix() + "unsigned int shRowLength[" + std::to_string(getKernelBlockSize(KernelPresynapticUpdate)) + "];")});
 
     // Shared memory for spikes and spike events
-    kernelEnv.add(Type::Uint32.createPointer(), "_sh_spk", "shSpk",
+    kernelEnv.add(Type::Void, "_sh_spk", "shSpk",
                   {kernelEnv.addInitialiser(getSharedPrefix() + "unsigned int shSpk[" + std::to_string(getKernelBlockSize(KernelPresynapticUpdate)) + "];")});
-    kernelEnv.add(Type::Uint32.createPointer(), "_sh_spk_evnt", "shSpkEvnt",
+    kernelEnv.add(Type::Void, "_sh_spk_evnt", "shSpkEvnt",
                   {kernelEnv.addInitialiser(getSharedPrefix() + "unsigned int shSpkEvnt[" + std::to_string(getKernelBlockSize(KernelPresynapticUpdate)) + "];")});
 
     // Parallelise over synapse groups
     genParallelGroup(
@@ -695,32 +696,34 @@ void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, const
         [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumPresynapticUpdateThreads(sg, getPreferences()), KernelPresynapticUpdate); },
         [&modelMerged, this](EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg)
         {
+            EnvironmentGroupMergedField groupEnv(env, sg);
+
            // Get presynaptic update strategy to use for this synapse group
            const auto *presynapticUpdateStrategy = getPresynapticUpdateStrategy(sg.getArchetype());
            LOGD_BACKEND << "Using '" << typeid(*presynapticUpdateStrategy).name() << "' presynaptic update strategy for merged synapse group '" << sg.getIndex() << "'";
 
            // Generate index calculation code
-            genSynapseIndexCalculation(env, modelMerged.getModel().getBatchSize());
+            genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize());
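[Editor's note: the RNG handling in the hunks above is the pattern this patch converges on, so a minimal sketch may help. The names `backend`, `env` and `requiresRNG` below are illustrative assumptions, not code from the patch; the point is that genPopulationRNGPreamble() now *returns* the expression naming the thread-local RNG instead of mutating a Substitutions stack, and the caller registers it in its environment itself:]

    // Resolve the global RNG expression once so preamble and postamble agree
    const std::string rng = printSubs("$(_rng)[$(id)]", env);
    if(requiresRNG) {
        // e.g. returns "&" + rng on CUDA, "&localStream" on OpenCL
        env.add(Type::Void, "rng", backend.genPopulationRNGPreamble(env.getStream(), rng));
    }
    // ... generate update code that reads $(rng) ...
    if(requiresRNG) {
        // on OpenCL this writes the local RNG state back to global memory
        backend.genPopulationRNGPostamble(env.getStream(), rng);
    }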
 
             // Generate preamble
-            presynapticUpdateStrategy->genPreamble(os, modelMerged, sg, popSubs, *this);
+            presynapticUpdateStrategy->genPreamble(groupEnv, modelMerged, sg, *this);
 
             // If spike events should be processed
             if(sg.getArchetype().isSpikeEventRequired()) {
-                CodeStream::Scope b(os);
-                presynapticUpdateStrategy->genUpdate(os, modelMerged, sg, popSubs, *this, false);
+                CodeStream::Scope b(groupEnv.getStream());
+                presynapticUpdateStrategy->genUpdate(groupEnv, modelMerged, sg, *this, false);
             }
 
             // If true spikes should be processed
             if(sg.getArchetype().isTrueSpikeRequired()) {
-                CodeStream::Scope b(os);
-                presynapticUpdateStrategy->genUpdate(os, modelMerged, sg, popSubs, *this, true);
+                CodeStream::Scope b(groupEnv.getStream());
+                presynapticUpdateStrategy->genUpdate(groupEnv, modelMerged, sg, *this, true);
             }
 
-            os << std::endl;
+            groupEnv.getStream() << std::endl;
 
             // Generate postamble
-            presynapticUpdateStrategy->genPostamble(os, modelMerged, sg, popSubs, *this);
+            presynapticUpdateStrategy->genPostamble(groupEnv, modelMerged, sg, *this);
         });
 }
 //--------------------------------------------------------------------------
@@ -729,9 +732,9 @@ void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, cons
     EnvironmentExternal kernelEnv(env);
 
     // Shared memory for column length and spikes
-    kernelEnv.add(Type::Uint32.createPointer(), "_sh_colw_length", "shColLength",
+    kernelEnv.add(Type::Void, "_sh_col_length", "shColLength",
                   {kernelEnv.addInitialiser(getSharedPrefix() + "unsigned int shColLength[" + std::to_string(getKernelBlockSize(KernelPostsynapticUpdate)) + "];")});
-    kernelEnv.add(Type::Uint32.createPointer(), "_sh_spk", "shSpk",
+    kernelEnv.add(Type::Void, "_sh_spk", "shSpk",
                   {kernelEnv.addInitialiser(getSharedPrefix() + "unsigned int shSpk[" + std::to_string(getKernelBlockSize(KernelPostsynapticUpdate)) + "];")});
 
     // Parallelise over postsynaptic update groups
@@ -740,7 +743,7 @@ void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, cons
         [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumPostsynapticUpdateThreads(sg), KernelPostsynapticUpdate); },
         [&modelMerged, this](EnvironmentExternalBase &env, PostsynapticUpdateGroupMerged &sg)
         {
-            EnvironmentGroupMergedField groupEnv(env);
+            EnvironmentGroupMergedField groupEnv(env, sg);
 
             // Generate index calculation code
             const unsigned int batchSize = modelMerged.getModel().getBatchSize();
@@ -788,7 +791,7 @@ void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, cons
                         // **OPTIMIZE** we can do a fast constant divide optimization here
                         synEnv.add(Type::Uint32.addConst(), "id_pre", "idPre",
-                                   {synEnv.addInitialiser("const unsigned int idPre = $(synEnv) / $(_row_stride);"});
+                                   {synEnv.addInitialiser("const unsigned int idPre = $(synEnv) / $(_row_stride);")});
                     }
                     else {
                         synEnv.add(Type::Uint32.addConst(), "id_syn", "synAddress",
@@ -857,7 +860,7 @@ void BackendSIMT::genSynapseDynamicsKernel(EnvironmentExternalBase &env, const M
                        {synEnv.addInitialiser("const unsigned int idPost = ($(id) % $(_row_stride));")});
                 }
 
-                synEnv.add(Type::Uint32.addConst(), "id_syn", "$(id)"]);
+                synEnv.add(Type::Uint32.addConst(), "id_syn", "$(id)");
 
                 synEnv.add(Type::AddToPostDenDelay, "addToPostDelay",
                            getAtomic(modelMerged.getModel().getPrecision()) + "(&$(_den_delay)[" + sg.getPostDenDelayIndex(batchSize, "$(id_post)", "$(1)") + "], $(0))");
@@ -990,7 +993,7 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, const ModelSpe
                 // Split ID into intra-batch ID and batch
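[Editor's worked example of the decomposition generated here; `batch` is an assumed name for the quotient the generated code computes:]

    // blockSize = 32, size = 100  =>  paddedSize = 32 * ((100 + 32 - 1) / 32) = 128
    // global thread id = 200      =>  batch = 200 / 128 = 1, bid = 200 % 128 = 72
    // so "bid" replaces "id" inside the group and every batch receives an
    // identically padded slice of threads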
                 // **TODO** fast-divide style optimisations here
                 const std::string blockSizeStr = std::to_string(blockSize);
-                const size_t paddedSizeInit = groupEnv.addInitialiser("const unsigned int paddedSize = " + blockSizeStr + " * (($(size) + " + blockSizeStr + " - 1) / " + blockSizeStr + ");" << std::endl;
+                const size_t paddedSizeInit = groupEnv.addInitialiser("const unsigned int paddedSize = " + blockSizeStr + " * (($(size) + " + blockSizeStr + " - 1) / " + blockSizeStr + ");");
 
                 // Replace id in substitution with intra-batch ID and add batch
                 groupEnv.add(Type::Uint32.addConst(), "id", "bid",
@@ -1055,7 +1058,7 @@ void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, const ModelS
                 // Split ID into intra-batch ID and batch
                 // **TODO** fast-divide style optimisations here
                 const std::string blockSizeStr = std::to_string(blockSize);
-                const size_t paddedSizeInit = groupEnv.addInitialiser("const unsigned int paddedSize = " + blockSizeStr + " * ((size + " + blockSizeStr + " - 1) / " + blockSizeStr + ");" << std::endl;
+                const size_t paddedSizeInit = groupEnv.addInitialiser("const unsigned int paddedSize = " + blockSizeStr + " * ((size + " + blockSizeStr + " - 1) / " + blockSizeStr + ");");
 
                 // Replace id in substitution with intra-batch ID and add batch
                 groupEnv.add(Type::Uint32.addConst(), "id", "bid",
@@ -1105,7 +1108,7 @@ void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, const ModelS
                 synEnv.add(Type::Uint32.addConst(), "id_syn", "$(id)");
 
                 // Initialise reduction targets
-                const auto reductionTargets = genInitReductionTargets(synEnv.getStream(), cg, synEnv["id_syn"]));
+                const auto reductionTargets = genInitReductionTargets(synEnv.getStream(), cg, synEnv["id_syn"]);
 
                 // If this is a reduction
                 if(cg.getArchetype().isBatchReduction()) {
@@ -1241,20 +1244,20 @@ void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, con
             groupEnv.getStream() << "// If thread isn't off the 'bottom' edge of the output matrix" << std::endl;
             groupEnv.getStream() << "if(x < " << groupEnv["num_pre"] << ")";
             {
-                CodeStream::Scope b(synEnv.getStream());
-                synEnv.getStream() << "// Loop through output rows" << std::endl;
-                synEnv.getStream() << "for(unsigned int j = 0; j < " << blockSize << "; j += 8)";
+                CodeStream::Scope b(groupEnv.getStream());
+                groupEnv.getStream() << "// Loop through output rows" << std::endl;
+                groupEnv.getStream() << "for(unsigned int j = 0; j < " << blockSize << "; j += 8)";
                 {
-                    CodeStream::Scope b(synEnv.getStream());
-                    synEnv.getStream() << "// If thread isn't off the 'right' edge of the output matrix" << std::endl;
-                    synEnv.getStream() << "if((y + j) < group" << groupEnv["num_post"] << ")";
+                    CodeStream::Scope b(groupEnv.getStream());
+                    groupEnv.getStream() << "// If thread isn't off the 'right' edge of the output matrix" << std::endl;
+                    groupEnv.getStream() << "if((y + j) < " << groupEnv["num_post"] << ")";
                     {
-                        CodeStream::Scope b(synEnv.getStream());
-                        synEnv.getStream() << "group->" << transposeVarName << "Transpose[";
+                        CodeStream::Scope b(groupEnv.getStream());
+                        groupEnv.getStream() << "group->" << transposeVarName << "Transpose[";
                         if(cg.getArchetype().isBatched()) {
-                            synEnv.getStream() << "batchOffset + ";
+                            groupEnv.getStream() << "batchOffset + ";
                         }
-                        synEnv.getStream() << "((y + j) * " << groupEnv["num_pre"] << ") + x] = shTile[" << getThreadID(0) << "][" << getThreadID(1) << " + j];" << std::endl;
+                        groupEnv.getStream() << "((y + j) * " << groupEnv["num_pre"] << ") + x] = shTile[" << getThreadID(0) << "][" <<
getThreadID(1) << " + j];" << std::endl; } } } @@ -1273,7 +1276,7 @@ void BackendSIMT::genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env return padSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelCustomUpdate); }, [&updateGroup](const CustomConnectivityUpdateGroupMerged &cg) { return (cg.getArchetype().getUpdateGroupName() == updateGroup); }, - [&modelMerged, this](EnvironmentExternalBase &env, const CustomConnectivityUpdateGroupMerged &cg) + [&modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdateGroupMerged &cg) { EnvironmentGroupMergedField groupEnv(env, cg); @@ -1288,15 +1291,16 @@ void BackendSIMT::genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env groupEnv.add(Type::Uint32.addConst(), "id_pre", "$(id)"); // Copy global RNG stream to local and use pointer to this for rng + const std::string rng = printSubs("$(_rng)[$(id)]", groupEnv); if(cg.getArchetype().isRowSimRNGRequired()) { - genPopulationRNGPreamble(os, popSubs, "group->rng[" + popSubs["id"] + "]"); + groupEnv.add(Type::Void, "rng", genPopulationRNGPreamble(groupEnv.getStream(), rng)); } cg.generateUpdate(*this, groupEnv, modelMerged); // Copy local stream back to local if(cg.getArchetype().isRowSimRNGRequired()) { - genPopulationRNGPostamble(os, "group->rng[" + popSubs["id"] + "]"); + genPopulationRNGPostamble(groupEnv.getStream(), rng); } } }); @@ -1304,8 +1308,8 @@ void BackendSIMT::genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env //-------------------------------------------------------------------------- void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const { - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Local neuron groups" << std::endl; + env.getStream() << "// ------------------------------------------------------------------------" << std::endl; + env.getStream() << "// Local neuron groups" << std::endl; idStart = 0; genParallelGroup( env, modelMerged.getMergedNeuronInitGroups(), idStart, @@ -1316,20 +1320,22 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS env.print("if($(id) < $(num_neurons))"); { CodeStream::Scope b(env.getStream()); + EnvironmentGroupMergedField groupEnv(env, ng); // If population RNGs are initialised on device and this neuron is going to require one, if(isPopulationRNGInitialisedOnDevice() && ng.getArchetype().isSimRNGRequired()) { // If batch size is 1, initialise single RNG using GLOBAL thread id for sequence if(modelMerged.getModel().getBatchSize() == 1) { - genPopulationRNGInit(os, "group->rng[" + popSubs["id"] + "]", "deviceRNGSeed", "id"); + genPopulationRNGInit(groupEnv.getStream(), printSubs("$(_rng)[$(id)]", groupEnv), + "deviceRNGSeed", "id"); } // Otherwise, loop through batches and initialise independent RNGs using GLOBAL thread id as basis of sequence else { env.getStream() << "for(unsigned int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { - CodeStream::Scope b(env.getStream()); - genPopulationRNGInit(os, "group->rng[(b * group->numNeurons) + " + popSubs["id"] + "]", "deviceRNGSeed", - "(b * " + std::to_string(getNumInitialisationRNGStreams(modelMerged)) + ") + id"); + CodeStream::Scope b(groupEnv.getStream()); + genPopulationRNGInit(groupEnv.getStream(), printSubs("$(_rng)[(b * $(num_neurons)) + $(id)]", groupEnv), + "deviceRNGSeed", "(b * " + std::to_string(getNumInitialisationRNGStreams(modelMerged)) + ") + 
id"); } } @@ -1339,16 +1345,16 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id if(ng.getArchetype().isInitRNGRequired()) { - genGlobalRNGSkipAhead(os, popSubs, "id"); + groupEnv.add(Type::Void, "rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } - ng.generateInit(*this, env, modelMerged); + ng.generateInit(*this, groupEnv, modelMerged); } }); - os << std::endl; + env.getStream() << std::endl; - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Synapse groups" << std::endl; + env.getStream() << "// ------------------------------------------------------------------------" << std::endl; + env.getStream() << "// Synapse groups" << std::endl; genParallelGroup( env, modelMerged.getMergedSynapseInitGroups(), idStart, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumInitThreads(sg), KernelInitialize); }, @@ -1359,10 +1365,10 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL), sg.getArchetype().getKernelSize().size()); }); - os << std::endl; + env.getStream() << std::endl; - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Custom update groups" << std::endl; + env.getStream() << "// ------------------------------------------------------------------------" << std::endl; + env.getStream() << "// Custom update groups" << std::endl; genParallelGroup( env, modelMerged.getMergedCustomUpdateInitGroups(), idStart, [this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelInitialize); }, @@ -1372,21 +1378,22 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS env.print("if($(id) < $(size))"); { CodeStream::Scope b(env.getStream()); + EnvironmentGroupMergedField groupEnv(env, cg); // If this custom update requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id if(cg.getArchetype().isInitRNGRequired()) { - genGlobalRNGSkipAhead(os, popSubs, "id"); + groupEnv.add(Type::Void, "rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } - cg.generateInit(*this, env, modelMerged); + cg.generateInit(*this, groupEnv, modelMerged); } }); - os << std::endl; + env.getStream() << std::endl; - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Custom WU update groups" << std::endl; + env.getStream() << "// ------------------------------------------------------------------------" << std::endl; + env.getStream() << "// Custom WU update groups" << std::endl; genParallelGroup( env, modelMerged.getMergedCustomWUUpdateInitGroups(), idStart, [this](const CustomUpdateWUInternal &cg) { return padKernelSize(getNumInitThreads(cg), KernelInitialize); }, @@ -1397,10 +1404,10 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS genSynapseVarInit(groupEnv, modelMerged, cg.getArchetype().isInitRNGRequired(), (sg->getMatrixType() & SynapseMatrixWeight::KERNEL), sg->getKernelSize().size()); }); - os << std::endl; + env.getStream() << std::endl; - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Custom connectivity presynaptic update groups" << std::endl; + env.getStream() << "// 
------------------------------------------------------------------------" << std::endl; + env.getStream() << "// Custom connectivity presynaptic update groups" << std::endl; genParallelGroup( env, modelMerged.getMergedCustomConnectivityUpdatePreInitGroups(), idStart, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelInitialize); }, @@ -1410,27 +1417,29 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS env.print("if($(id) < $(size))"); { CodeStream::Scope b(env.getStream()); + EnvironmentGroupMergedField groupEnv(env, cg); // If population RNGs are initialised on device and this custom connectivity update // required one, initialise single RNG using GLOBAL thread id for sequence if(isPopulationRNGInitialisedOnDevice() && cg.getArchetype().isRowSimRNGRequired()) { - genPopulationRNGInit(os, "group->rng[" + popSubs["id"] + "]", "deviceRNGSeed", "id"); + genPopulationRNGInit(groupEnv.getStream(), printSubs("$(_rng)[$(id)]", groupEnv), + "deviceRNGSeed", "id"); } // If this custom update requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id if(cg.getArchetype().isPreVarInitRNGRequired()) { - genGlobalRNGSkipAhead(os, popSubs, "id"); + groupEnv.add(Type::Void, "rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } - cg.generateInit(*this, env, modelMerged); + cg.generateInit(*this, groupEnv, modelMerged); } }); - os << std::endl; + env.getStream() << std::endl; - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Custom connectivity postsynaptic update groups" << std::endl; + env.getStream() << "// ------------------------------------------------------------------------" << std::endl; + env.getStream() << "// Custom connectivity postsynaptic update groups" << std::endl; genParallelGroup( env, modelMerged.getMergedCustomConnectivityUpdatePostInitGroups(), idStart, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons(), KernelInitialize); }, @@ -1440,21 +1449,29 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS env.print("if($(id) < $(size))"); { CodeStream::Scope b(env.getStream()); + EnvironmentGroupMergedField groupEnv(env, cg); + + // If population RNGs are initialised on device and this custom connectivity update + // required one, initialise single RNG using GLOBAL thread id for sequence + if(isPopulationRNGInitialisedOnDevice() && cg.getArchetype().isRowSimRNGRequired()) { + genPopulationRNGInit(groupEnv.getStream(), printSubs("$(_rng)[$(id)]", groupEnv), + "deviceRNGSeed", "id"); + } // If this custom update requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id if(cg.getArchetype().isPostVarInitRNGRequired()) { - genGlobalRNGSkipAhead(os, popSubs, "id"); + groupEnv.add(Type::Void, "rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } - cg.generateInit(*this, env, modelMerged); + cg.generateInit(*this, groupEnv, modelMerged); } }); - os << std::endl; + env.getStream() << std::endl; - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Synapse groups with sparse connectivity" << std::endl; + env.getStream() << "// ------------------------------------------------------------------------" << 
std::endl;
+    env.getStream() << "// Synapse groups with sparse connectivity" << std::endl;
     genParallelGroup(
         env, modelMerged.getMergedSynapseConnectivityInitGroups(), idStart,
         [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumConnectivityInitThreads(sg), KernelInitialize); },
@@ -1502,16 +1519,17 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS
                     // Calculate index in data structure of this synapse
                     if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
                         if(!snippet->getRowBuildCode().empty()) {
-                            kernelInit << "const unsigned int idx = " << "($(id_pre) * $(_row_stride)) + $(_row_length)[$(id)];" << std::endl;
+                            kernelInit << "const unsigned int idx = ($(id_pre) * $(_row_stride)) + $(_row_length)[$(id)];" << std::endl;
                         }
                         else {
-                            kernelInit << "const unsigned int idx = " << "(($(0)) * $(_row_stride))) + $(_row_length)[$(0)];" << std::endl;
+                            kernelInit << "const unsigned int idx = (($(0)) * $(_row_stride)) + $(_row_length)[$(0)];" << std::endl;
                         }
                     }
 
                     // If there is a kernel
                     if(!sg.getArchetype().getKernelSize().empty()) {
-                        Substitutions kernelInitSubs(&popSubs);
+                        assert(false);
+                        /*Substitutions kernelInitSubs(&popSubs);
 
                         // Replace $(id_post) with first 'function' parameter as simulation code is
                         // going to be, in turn, substituted into procedural connectivity generation code
@@ -1533,19 +1551,19 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS
                         }
 
                         // Call handler to initialize variables
-                        sg.generateKernelInit(*this, kernelInit, modelMerged, kernelInitSubs);
+                        sg.generateKernelInit(*this, kernelInit, modelMerged, kernelInitSubs);*/
                     }
 
                     // If matrix is sparse
                     if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
                         // If there is row-building code in this snippet
                         if(!snippet->getRowBuildCode().empty()) {
-                            kernelInit << "group->ind[idx] = $(0);" << std::endl;
-                            kernelInit << "group->rowLength[" << popSubs["id"] << "]++;" << std::endl;
+                            kernelInit << "$(_ind)[idx] = $(0);" << std::endl;
+                            kernelInit << "$(_row_length)[$(id)]++;" << std::endl;
                         }
                         // Otherwise
                         else {
-                            kernelInit << "group->ind[(($(0)) * group->rowStride) + " << getAtomic(Type::Uint32) << +"(&group->rowLength[$(0)], 1)] = " << popSubs["id_post"] << ";";
+                            kernelInit << "$(_ind)[(($(0)) * $(_row_stride)) + " << getAtomic(Type::Uint32) << +"(&$(_row_length)[$(0)], 1)] = $(id_post);";
                         }
                     }
                     // Otherwise, if it's bitmask
                     else {
 
                         // If there is row-building code in this snippet
                         if(!snippet->getRowBuildCode().empty()) {
-                            kernelInit << "const " << indexType << " rowStartGID = " << popSubs["id"] << " * (" << indexType << ")group->rowStride;" << std::endl;
-                            kernelInit << getAtomic(Type::Uint32, AtomicOperation::OR) << "(&group->gp[(rowStartGID + ($(0))) / 32], 0x80000000 >> ((rowStartGID + ($(0))) & 31));" << std::endl;
+                            kernelInit << "const " << indexType << " rowStartGID = $(id) * (" << indexType << ")$(_row_stride);" << std::endl;
+                            kernelInit << getAtomic(Type::Uint32, AtomicOperation::OR) << "(&$(_gp)[(rowStartGID + ($(0))) / 32], 0x80000000 >> ((rowStartGID + ($(0))) & 31));" << std::endl;
                         }
                         // Otherwise
                         else {
-                            kernelInit << "const " << indexType << " colStartGID = " << popSubs["id"] << ";" << std::endl;
-                            kernelInit << getAtomic(Type::Uint32, AtomicOperation::OR) << "(&group->gp[(colStartGID + (($(0)) * group->rowStride)) / 32], 0x80000000 >> ((colStartGID + (($(0)) * group->rowStride)) & 31));" << std::endl;
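[Editor's illustration of the bit addressing used by both branches; the numbers are arbitrary:]

    // With $(_row_stride) = 64, the synapse (row 3, column 10) has global bit index
    // gid = (3 * 64) + 10 = 202, so it lives in 32-bit word gid / 32 = 6 and is set
    // by OR-ing in the mask 0x80000000 >> (202 & 31) = 0x80000000 >> 10 = 0x00200000,
    // matching the (rowStartGID + ($(0))) / 32 word index and
    // 0x80000000 >> ((rowStartGID + ($(0))) & 31) mask emitted here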
+ kernelInit << "const " << indexType << " colStartGID = $(id);" << std::endl; + kernelInit << getAtomic(Type::Uint32, AtomicOperation::OR) << "(&$(_gp)[(colStartGID + (($(0)) * $(_row_stride))) / 32], 0x80000000 >> ((colStartGID + (($(0)) * $(_row_stride))) & 31));" << std::endl; } } } kernelInit << "while(false)"; - popSubs.addFuncSubstitution("addSynapse", 1 + (unsigned int)sg.getArchetype().getKernelSize().size(), - kernelInitStream.str()); + groupEnv.add(Type::ResolvedType::createFunction(Type::Void, {Type::Uint32}), "addSynapse", //1 + (unsigned int)sg.getArchetype().getKernelSize().size(), + kernelInitStream.str()); + + // If this connectivity requires an RNG for initialisation, + // make copy of global phillox RNG and skip ahead by thread id + // **NOTE** not LOCAL id + if(Utils::isRNGRequired(snippet->getRowBuildCode())) { + groupEnv.add(Type::Void, "rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); + } - // If there is row - building code in this snippet + // If there is row-building code in this snippet if(!snippet->getRowBuildCode().empty()) { // If this is a sparse matrix, zero row length if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - os << "group->rowLength[" + popSubs["id"] + "] = 0;" << std::endl; - } - - // If this connectivity requires an RNG for initialisation, - // make copy of global phillox RNG and skip ahead by thread id - // **NOTE** not LOCAL id - if(Utils::isRNGRequired(snippet->getRowBuildCode())) { - genGlobalRNGSkipAhead(os, popSubs, "id"); + groupEnv.printLine("$(_row_length)[$(id)] = 0;"); } // Call row-based connectivity handler - sg.generateSparseRowInit(*this, os, modelMerged, popSubs); + sg.generateSparseRowInit(*this, groupEnv); } - // Otherwise + // Otherwise, call column-based connectivity handler + // **NOTE** in this case, row length gets zeroed by a memset call in backend else { - // If this connectivity requires an RNG for initialisation, - // make copy of global phillox RNG and skip ahead by thread id - // **NOTE** not LOCAL id - if(Utils::isRNGRequired(snippet->getColBuildCode())) { - genGlobalRNGSkipAhead(os, popSubs, "id"); - } - - // Call column-based connectivity handler - sg.generateSparseColumnInit(*this, os, modelMerged, popSubs); + sg.generateSparseColumnInit(*this, groupEnv); } } }); - os << std::endl; + env.getStream() << std::endl; } //-------------------------------------------------------------------------- void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t numInitializeThreads, size_t &idStart) const { - // Shared memory array so row lengths don't have to be read by EVERY postsynaptic thread - // **TODO** check actually required - os << getSharedPrefix() << "unsigned int shRowLength[" << getKernelBlockSize(KernelInitializeSparse) << "];" << std::endl; + EnvironmentExternal envKernel(env); + envKernel.add(Type::Void, "_sh_row_length", "shRowLength", + {envKernel.addInitialiser(getSharedPrefix() + "unsigned int shRowLength[" + std::to_string(getKernelBlockSize(KernelInitializeSparse)) + "];")}); // Initialise weight update variables for synapse groups with sparse connectivity - genParallelGroup(env, modelMerged.getMergedSynapseSparseInitGroups(), idStart, + genParallelGroup(envKernel, modelMerged.getMergedSynapseSparseInitGroups(), idStart, [this](const SynapseGroupInternal &sg) { return padKernelSize(sg.getMaxConnections(), KernelInitializeSparse); }, [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, 
SynapseSparseInitGroupMerged &sg)
                     {
+                         EnvironmentGroupMergedField groupEnv(env, sg);
+
                         // If this post synapse requires an RNG for initialisation,
                         // make copy of global phillox RNG and skip ahead by thread id
                         // **NOTE** not LOCAL id
                         if(sg.getArchetype().isWUInitRNGRequired()) {
-                             genGlobalRNGSkipAhead(os, popSubs, std::to_string(numInitializeThreads) + " + id");
+                             groupEnv.add(Type::Void, "rng",
+                                          genGlobalRNGSkipAhead(groupEnv.getStream(), std::to_string(numInitializeThreads) + " + id"));
                         }
 
                         // Generate sparse synapse variable initialisation code
                         genSparseSynapseVarInit(
-                             env, modelMerged, sg, sg.getArchetype().isWUVarInitRequired(),
+                             groupEnv, modelMerged, sg, sg.getArchetype().isWUVarInitRequired(),
                             [this](EnvironmentExternalBase &env, SynapseSparseInitGroupMerged &sg)
                             {
                                 // If postsynaptic learning is required
@@ -1633,54 +1647,60 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, const
                                     CodeStream::Scope b(env.getStream());
 
                                     // Extract index of synapse's postsynaptic target
-                                    env.getStream() << "const unsigned int postIndex = " << env["_ind"] << "[idx];" << std::endl;
+                                    env.printLine("const unsigned int postIndex = $(_ind)[idx];");
 
                                     // Atomically increment length of column of connectivity associated with this target
                                     // **NOTE** this returns previous length i.e. where to insert new entry
-                                    env.getStream() << "const unsigned int colLocation = " << getAtomic(Type::Uint32) << "(&" << env["_col_length"] << "[postIndex], 1);" << std::endl;
+                                    env.printLine("const unsigned int colLocation = " + getAtomic(Type::Uint32) + "(&$(_col_length)[postIndex], 1);");
 
                                     // From this calculate index into column-major matrix
-                                    env.getStream() << "const unsigned int colMajorIndex = (postIndex * " << env["_col_stride"] << ") + colLocation;" << std::endl;
+                                    env.printLine("const unsigned int colMajorIndex = (postIndex * $(_col_stride)) + colLocation;");
 
                                     // Add remapping entry at this location pointing back to row-major index
-                                    env.getStream() << "group->remap[colMajorIndex] = idx;" << std::endl;
+                                    env.printLine("$(_remap)[colMajorIndex] = idx;");
                                 }
                             });
                     });
 
     // Initialise weight update variables for synapse groups with sparse connectivity
-    genParallelGroup(env, modelMerged.getMergedCustomWUUpdateSparseInitGroups(), idStart,
+    genParallelGroup(envKernel, modelMerged.getMergedCustomWUUpdateSparseInitGroups(), idStart,
                     [this](const CustomUpdateWUInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); },
                     [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, CustomWUUpdateSparseInitGroupMerged &cg)
                     {
+                         EnvironmentGroupMergedField groupEnv(env, cg);
+
                         // If this custom update requires an RNG for initialisation,
                         // make copy of global phillox RNG and skip ahead by thread id
                         // **NOTE** not LOCAL id
                         if(cg.getArchetype().isInitRNGRequired()) {
-                             genGlobalRNGSkipAhead(os, popSubs, std::to_string(numInitializeThreads) + " + id");
+                             groupEnv.add(Type::Void, "rng",
+                                          genGlobalRNGSkipAhead(groupEnv.getStream(), std::to_string(numInitializeThreads) + " + id"));
                         }
 
                         // Generate sparse synapse variable initialisation code
                         genSparseSynapseVarInit(
-                             env, modelMerged, cg, true,
+                             groupEnv, modelMerged, cg, true,
                             [](EnvironmentExternalBase &env, CustomWUUpdateSparseInitGroupMerged&){});
                     });
 
     // Initialise weight update variables for synapse groups with sparse connectivity
-    genParallelGroup(env, 
+    genParallelGroup(envKernel, 
modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups(), idStart, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }, [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdateSparseInitGroupMerged &cg) { + EnvironmentGroupMergedField groupEnv(env, cg); + // If this custom update requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id if(cg.getArchetype().isVarInitRNGRequired()) { - genGlobalRNGSkipAhead(os, popSubs, std::to_string(numInitializeThreads) + " + id"); + groupEnv.add(Type::Void, "rng", + genGlobalRNGSkipAhead(groupEnv.getStream(), std::to_string(numInitializeThreads) + " + id")); } // Generate sparse synapse variable initialisation code genSparseSynapseVarInit( - env, modelMerged, cg, true, + groupEnv, modelMerged, cg, true, [](EnvironmentExternalBase&, CustomConnectivityUpdateSparseInitGroupMerged&){}); }); } From a4a04fc13b380f439ae27a9745dc4336eaea1339 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 28 Jun 2023 18:14:25 +0100 Subject: [PATCH 268/725] removed substitutions --- .../genn/genn/code_generator/backendSIMT.h | 1 - .../genn/genn/code_generator/codeGenUtils.h | 6 +- .../genn/genn/code_generator/substitutions.h | 200 ------------------ .../backends/single_threaded_cpu/backend.cc | 3 +- .../genn/code_generator/generateRunner.cc | 1 - src/genn/genn/code_generator/groupMerged.cc | 2 +- src/genn/genn/code_generator/substitutions.cc | 110 ---------- src/genn/genn/genn.vcxproj | 2 - src/genn/genn/modelSpec.cc | 6 +- 9 files changed, 8 insertions(+), 323 deletions(-) delete mode 100644 include/genn/genn/code_generator/substitutions.h delete mode 100644 src/genn/genn/code_generator/substitutions.cc diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 750900e18f..8a5a6938f1 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -14,7 +14,6 @@ #include "code_generator/codeStream.h" #include "code_generator/environment.h" #include "code_generator/presynapticUpdateStrategySIMT.h" -#include "code_generator/substitutions.h" //-------------------------------------------------------------------------- // GeNN::CodeGenerator::Kernel diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 976d9e3b22..d7538bde18 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -19,7 +19,6 @@ #include "backendBase.h" #include "codeStream.h" #include "lazyString.h" -#include "substitutions.h" #include "teeStream.h" // GeNN transpiler includes @@ -136,12 +135,13 @@ GENN_EXPORT void prettyPrintStatements(const std::string &code, const Type::Type Transpiler::PrettyPrinter::StatementHandler forEachSynapsePrettyPrintHandler = nullptr); GENN_EXPORT std::string printSubs(const std::string &format, EnvironmentExternalBase &env); + //------------------------------------------------------------------------- /*! \brief Function for performing the code and value substitutions necessary to insert neuron related variables, parameters, and extraGlobal parameters into synaptic code. 
*/ //------------------------------------------------------------------------- -template +/*template void neuronSubstitutionsInSynapticCode(CodeGenerator::Substitutions &substitutions, const NeuronGroupInternal *archetypeNG, const std::string &delayOffset, const std::string &sourceSuffix, const std::string &destSuffix, const std::string &varPrefix, const std::string &varSuffix, bool useLocalNeuronVars, @@ -186,7 +186,7 @@ void neuronSubstitutionsInSynapticCode(CodeGenerator::Substitutions &substitutio // Substitute extra global parameters from neuron model substitutions.addVarNameSubstitution(nm->getExtraGlobalParams(), sourceSuffix, "group->", destSuffix); -} +}*/ template bool isKernelSizeHeterogeneous(const G &group, size_t dimensionIndex) diff --git a/include/genn/genn/code_generator/substitutions.h b/include/genn/genn/code_generator/substitutions.h deleted file mode 100644 index b119c743f8..0000000000 --- a/include/genn/genn/code_generator/substitutions.h +++ /dev/null @@ -1,200 +0,0 @@ -#pragma once - -// Standard C++ includes -#include -#include -#include - -// Standard C includes -#include - -// GeNN includes -#include "gennExport.h" -#include "gennUtils.h" -#include "logging.h" - -//-------------------------------------------------------------------------- -// GeNN::CodeGenerator::Substitutions -//-------------------------------------------------------------------------- -namespace GeNN::CodeGenerator -{ -class GENN_EXPORT Substitutions -{ -public: - //! Immutable structure for specifying how to implement - //! a generic function e.g. gennrand_uniform - /*! **NOTE** for the sake of easy initialisation first two parameters of GenericFunction are repeated (C++17 fixes) */ - struct FunctionTemplate - { - // **HACK** while GCC and CLang automatically generate this fine/don't require it, VS2013 seems to need it - FunctionTemplate operator = (const FunctionTemplate &o) - { - return FunctionTemplate{o.genericName, o.numArguments, o.funcTemplate}; - } - - //! Generic name used to refer to function in user code - const std::string genericName; - - //! Number of function arguments - const unsigned int numArguments; - - //! 
The function template (for use with ::functionSubstitute) used when model uses double precision - const std::string funcTemplate; - }; - - Substitutions(const Substitutions *parent = nullptr) : m_Parent(parent) - { - assert(m_Parent != this); - } - - Substitutions(const std::vector &functions) : m_Parent(nullptr) - { - // Loop through functions and add as substitutions - for(const auto &f: functions) { - addFuncSubstitution(f.genericName, f.numArguments, f.funcTemplate); - } - } - - //-------------------------------------------------------------------------- - // Public API - //-------------------------------------------------------------------------- - template - void addVarNameSubstitution(const std::vector &variables, const std::string &sourceSuffix = "", - const std::string &destPrefix = "", const std::string &destSuffix = "") - { - for(const auto &v : variables) { - addVarSubstitution(v.name + sourceSuffix, - destPrefix + v.name + destSuffix); - } - } - - template - void addVarNameSubstitution(const std::vector &variables, const std::string &sourceSuffix, - const std::string &destPrefix, S getDestSuffixFn, F filterFn) - { - for(const auto &v : variables) { - if (filterFn(v.access, v.name)) { - addVarSubstitution(v.name + sourceSuffix, - destPrefix + v.name + getDestSuffixFn(v.access, v.name)); - } - } - } - - template - void addVarNameSubstitution(const std::vector &variables, const std::string &sourceSuffix, - const std::string &destPrefix, S getDestSuffixFn) - { - typedef decltype(T::access) AccessType; - addVarNameSubstitution(variables, sourceSuffix, destPrefix, - getDestSuffixFn, [](AccessType, const std::string&) { return true; }); - } - - template - void addVarValueSubstitution(const std::vector &variables, const std::unordered_map &values, - const std::string &sourceSuffix = "") - { - if(variables.size() != values.size()) { - throw std::runtime_error("Number of variables does not match number of values"); - } - - for(const auto &v : variables) { - addVarSubstitution(v.name + sourceSuffix, - "(" + Utils::writePreciseString(values.at(v.name)) + ")"); - } - } - - void addParamValueSubstitution(const std::vector ¶mNames, const std::unordered_map &values, - const std::string &sourceSuffix = ""); - - template - void addParamValueSubstitution(const std::vector ¶mNames, const std::unordered_map &values, G isHeterogeneousFn, - const std::string &sourceSuffix = "", const std::string &destPrefix = "", const std::string &destSuffix = "") - { - if(paramNames.size() != values.size()) { - throw std::runtime_error("Number of parameters does not match number of values"); - } - - for(const auto &p : paramNames) { - if(isHeterogeneousFn(p)) { - addVarSubstitution(p + sourceSuffix, - destPrefix + p + destSuffix); - } - else { - addVarSubstitution(p + sourceSuffix, - "(" + Utils::writePreciseString(values.at(p)) + ")"); - } - } - } - - template - void addVarValueSubstitution(const std::vector &variables, const std::vector &values, G isHeterogeneousFn, - const std::string &sourceSuffix = "", const std::string &destPrefix = "", const std::string &destSuffix = "") - { - if(variables.size() != values.size()) { - throw std::runtime_error("Number of variables does not match number of values"); - } - - for(size_t i = 0; i < variables.size(); i++) { - if(isHeterogeneousFn(i)) { - addVarSubstitution(variables[i].name + sourceSuffix, - destPrefix + variables[i].name + destSuffix); - } - else { - addVarSubstitution(variables[i].name + sourceSuffix, - "(" + Utils::writePreciseString(values[i]) + ")"); - } - } 
- } - - template - void addVarValueSubstitution(const std::vector &variables, const std::unordered_map &values, G isHeterogeneousFn, - const std::string &sourceSuffix = "", const std::string &destPrefix = "", const std::string &destSuffix = "") - { - if(variables.size() != values.size()) { - throw std::runtime_error("Number of variables does not match number of values"); - } - - for(const auto &v : variables) { - if(isHeterogeneousFn(v.name)) { - addVarSubstitution(v.name + sourceSuffix, - destPrefix + v.name + destSuffix); - } - else { - addVarSubstitution(v.name + sourceSuffix, - "(" + Utils::writePreciseString(values.at(v.name)) + ")"); - } - } - } - - void addVarSubstitution(const std::string &source, const std::string &destionation, bool allowOverride = false); - void addFuncSubstitution(const std::string &source, unsigned int numArguments, const std::string &funcTemplate, bool allowOverride = false); - bool hasVarSubstitution(const std::string &source) const; - - const std::string &getVarSubstitution(const std::string &source) const; - - void apply(std::string &code) const; - void applyCheckUnreplaced(std::string &code, const std::string &context) const; - - //-------------------------------------------------------------------------- - // Public API - //-------------------------------------------------------------------------- - const std::string operator[] (const std::string &source) const - { - return getVarSubstitution(source); - } - -private: - //-------------------------------------------------------------------------- - // Private API - //-------------------------------------------------------------------------- - void applyFuncs(std::string &code) const; - void applyVars(std::string &code) const; - - //-------------------------------------------------------------------------- - // Members - //-------------------------------------------------------------------------- - std::map m_VarSubstitutions; - std::map> m_FuncSubstitutions; - const Substitutions *m_Parent; -}; -} // namespace GeNN::CodeGenerator diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 3691543c31..6f42e317fc 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -9,7 +9,6 @@ #include "code_generator/environment.h" #include "code_generator/modelSpecMerged.h" #include "code_generator/standardLibrary.h" -#include "code_generator/substitutions.h" using namespace GeNN; using namespace GeNN::CodeGenerator; @@ -189,7 +188,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host } if(n.getArchetype().isPrevSpikeEventTimeRequired()) { // Loop through neurons which spiked last timestep and set their spike time to time of previous timestep - groupEnv.print("for(unsigned int i = 0; i < $(_spk_cnt_evnt)[0]; i++)"; + groupEnv.print("for(unsigned int i = 0; i < $(_spk_cnt_evnt)[0]; i++)"); { CodeStream::Scope b(groupEnv.getStream()); groupEnv.printLine("$(_prev_spk_evnt_time)[$(_spk_evnt)[i]] = t - DT;"); diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index a4019b27cd..67ff8ec143 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -15,7 +15,6 @@ #include "code_generator/codeGenUtils.h" #include "code_generator/codeStream.h" #include "code_generator/groupMerged.h" -#include "code_generator/substitutions.h" #include "code_generator/teeStream.h" 
#include "code_generator/backendBase.h" #include "code_generator/modelSpecMerged.h" diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index 9f9f921599..b1a8db22cd 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -203,7 +203,7 @@ std::string SynapseGroupMergedBase::getPostPrevSpikeTimeIndex(bool delay, unsign const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); if(delay) { - return (singleBatch ? "$(_post_prev_spike_time_delay_offset) + " : "$(_post_prev_spike_time_batch_delay_offset) + ") + "$(" + index + ")";; + return (singleBatch ? "$(_post_prev_spike_time_delay_offset) + " : "$(_post_prev_spike_time_batch_delay_offset) + ") + std::string{"$(" + index + ")"}; } else { return (singleBatch ? "" : "$(_post_batch_offset) + ") + std::string{"$(" + index + ")"}; diff --git a/src/genn/genn/code_generator/substitutions.cc b/src/genn/genn/code_generator/substitutions.cc deleted file mode 100644 index b51904401c..0000000000 --- a/src/genn/genn/code_generator/substitutions.cc +++ /dev/null @@ -1,110 +0,0 @@ -#include "code_generator/substitutions.h" - -// GeNN code generator includes -#include "code_generator/codeGenUtils.h" - -//-------------------------------------------------------------------------- -// GeNN::CodeGenerator::Substitutions -//-------------------------------------------------------------------------- -namespace GeNN::CodeGenerator -{ -void Substitutions::addParamValueSubstitution(const std::vector ¶mNames, const std::unordered_map &values, - const std::string &sourceSuffix) -{ - if(paramNames.size() != values.size()) { - throw std::runtime_error("Number of parameters does not match number of values"); - } - - for(const auto &p : paramNames) { - addVarSubstitution(p + sourceSuffix, - "(" + Utils::writePreciseString(values.at(p)) + ")"); - } -} -//-------------------------------------------------------------------------- -void Substitutions::addVarSubstitution(const std::string &source, const std::string &destionation, bool allowOverride) -{ - auto res = m_VarSubstitutions.emplace(source, destionation); - if(!allowOverride && !res.second) { - throw std::runtime_error("'" + source + "' already has a variable substitution"); - } -} -//-------------------------------------------------------------------------- -void Substitutions::addFuncSubstitution(const std::string &source, unsigned int numArguments, - const std::string &funcTemplate, bool allowOverride) -{ - auto res = m_FuncSubstitutions.emplace(std::piecewise_construct, - std::forward_as_tuple(source), - std::forward_as_tuple(numArguments, funcTemplate)); - if(!allowOverride && !res.second) { - throw std::runtime_error("'" + source + "' already has a function substitution"); - } -} -//-------------------------------------------------------------------------- -bool Substitutions::hasVarSubstitution(const std::string &source) const -{ - if (m_VarSubstitutions.find(source) != m_VarSubstitutions.end()) { - return true; - } - else if (m_Parent) { - return m_Parent->hasVarSubstitution(source); - } - else { - return false; - } -} -//-------------------------------------------------------------------------- -const std::string &Substitutions::getVarSubstitution(const std::string &source) const -{ - auto var = m_VarSubstitutions.find(source); - if(var != m_VarSubstitutions.end()) { - return var->second; - } - else if(m_Parent) { - return m_Parent->getVarSubstitution(source); - } - else { 
-        throw std::runtime_error("Nothing to substitute for '" + source + "'");
-    }
-}
-//--------------------------------------------------------------------------
-void Substitutions::apply(std::string &code) const
-{
-    // Apply function and variable substitutions
-    // **NOTE** functions may contain variables so evaluate ALL functions first
-    applyFuncs(code);
-    applyVars(code);
-}
-//--------------------------------------------------------------------------
-void Substitutions::applyCheckUnreplaced(std::string &code, const std::string &context) const
-{
-    apply(code);
-    checkUnreplacedVariables(code, context);
-}
-//--------------------------------------------------------------------------
-void Substitutions::applyFuncs(std::string &code) const
-{
-    // Apply function substitutions
-    for(const auto &f : m_FuncSubstitutions) {
-        functionSubstitute(code, f.first, f.second.first, f.second.second);
-    }
-
-    // If we have a parent, apply their function substitutions too
-    if(m_Parent) {
-        m_Parent->applyFuncs(code);
-    }
-}
-//--------------------------------------------------------------------------
-void Substitutions::applyVars(std::string &code) const
-{
-    // Apply variable substitutions
-    for(const auto &v : m_VarSubstitutions) {
-        LOGD_CODE_GEN << "Substituting '$(" << v.first << ")' for '" << v.second << "'";
-        substitute(code, "$(" + v.first + ")", v.second);
-    }
-
-    // If we have a parent, apply their variable substitutions too
-    if(m_Parent) {
-        m_Parent->applyVars(code);
-    }
-}
-} // namespace GeNN::CodeGenerator
diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj
index f5f836a3ec..5eaf2a5734 100644
--- a/src/genn/genn/genn.vcxproj
+++ b/src/genn/genn/genn.vcxproj
@@ -38,7 +38,6 @@
-
@@ -93,7 +92,6 @@
-
diff --git a/src/genn/genn/modelSpec.cc b/src/genn/genn/modelSpec.cc
index fd81fe5b33..a054b8fbfa 100644
--- a/src/genn/genn/modelSpec.cc
+++ b/src/genn/genn/modelSpec.cc
@@ -28,7 +28,6 @@
 // GeNN code generator includes
 #include "code_generator/codeGenUtils.h"
-#include "code_generator/substitutions.h"
 // ---------------------------------------------------------------------------
 // GeNN::ModelSpec
@@ -287,7 +286,8 @@ void ModelSpec::finalize()
             const auto *wu = sg->getWUModel();
             if(!wu->getEventThresholdConditionCode().empty()) {
-                using namespace CodeGenerator;
+                assert(false);
+                /*using namespace CodeGenerator;
                 // do an early replacement of weight update model parameters and derived parameters
                 // **NOTE** this is really gross but I can't really see an alternative - merging decisions are based on the spike event conditions set
@@ -301,7 +301,7 @@ void ModelSpec::finalize()
                 thresholdSubs.apply(eCode);
                 // Add code and name of support code namespace to set
-                n.second.addSpkEventCondition(eCode, sg);
+                n.second.addSpkEventCondition(eCode, sg);*/
             }
         }
         if (n.second.getSpikeEventCondition().size() > 1) {

From 52bbddc92ba1f64de1efb465e186c9d46886d0af Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 29 Jun 2023 13:48:53 +0100
Subject: [PATCH 269/725] de-templated genXXXMergedGroup methods of ModelSpecMerged

* Allows required type inference in BackendSIMT::genParallelGroup
* Reduces header-bloat
---
 .../genn/code_generator/modelSpecMerged.h | 295 +++---------------
 .../genn/code_generator/modelSpecMerged.cc | 231 ++++++++++++++
 2 files changed, 271 insertions(+), 255 deletions(-)

diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h
index 3259caa554..43cd191330 100644
---
a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -78,7 +78,7 @@ class GENN_EXPORT ModelSpecMerged typedef std::map MergedEGPMap; template - using GenerateMergedGroupFn = std::function; + using GenMergedGroupFn = std::function; //-------------------------------------------------------------------------- // Public API @@ -164,254 +164,39 @@ class GENN_EXPORT ModelSpecMerged //! Get merged custom connectivity update groups where host processing needs to be performed const std::vector &getMergedCustomConnectivityHostUpdateGroups() const { return m_MergedCustomConnectivityHostUpdateGroups; } - template - void genMergedNeuronUpdateGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronUpdateGroups, - [](const NeuronGroupInternal &){ return true; }, - &NeuronGroupInternal::getHashDigest, generateGroup); - } - - template - void genMergedPresynapticUpdateGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedPresynapticUpdateGroups, - [](const SynapseGroupInternal &sg) { return (sg.isSpikeEventRequired() || sg.isTrueSpikeRequired()); }, - &SynapseGroupInternal::getWUHashDigest, generateGroup); - } - - template - void genMergedPostsynapticUpdateGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedPostsynapticUpdateGroups, - [](const SynapseGroupInternal &sg){ return !sg.getWUModel()->getLearnPostCode().empty(); }, - &SynapseGroupInternal::getWUHashDigest, generateGroup); - } - - template - void genMergedSynapseDynamicsGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseDynamicsGroups, - [](const SynapseGroupInternal &sg){ return !sg.getWUModel()->getSynapseDynamicsCode().empty(); }, - &SynapseGroupInternal::getWUHashDigest, generateGroup); - } - - template - void genMergedCustomUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, G generateGroup) - { - createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateGroups, - [&updateGroupName](const CustomUpdateInternal &cg) { return cg.getUpdateGroupName() == updateGroupName; }, - &CustomUpdateInternal::getHashDigest, generateGroup); - } - - template - void genMergedCustomUpdateWUGroups(const BackendBase &backend, const std::string &updateGroupName, G generateGroup) - { - createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomUpdateWUGroups, - [&updateGroupName](const CustomUpdateWUInternal &cg) - { - return (!cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName); - }, - &CustomUpdateWUInternal::getHashDigest, generateGroup); - } + void genMergedNeuronUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedPresynapticUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedPostsynapticUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedSynapseDynamicsGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedCustomUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn generateGroup); + void genMergedCustomUpdateWUGroups(const BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn generateGroup); + void 
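+    // like the custom update generators above, the transpose, host-reduction and
+    // connectivity-update variants below only generate merged groups whose update
+    // group name matches updateGroupName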
genMergedCustomUpdateTransposeWUGroups(const BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn generateGroup); + void genMergedCustomUpdateHostReductionGroups(const BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn generateGroup); + void genMergedCustomWUUpdateHostReductionGroups(const BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn generateGroup); + void genMergedCustomConnectivityUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn generateGroup); + void genMergedCustomConnectivityHostUpdateGroups(BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn generateGroup); + void genMergedNeuronSpikeQueueUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedNeuronPrevSpikeTimeUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedSynapseDendriticDelayUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedNeuronInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedCustomUpdateInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedCustomWUUpdateInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedSynapseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedSynapseConnectivityInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedSynapseSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedCustomWUUpdateSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedCustomConnectivityUpdatePreInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedCustomConnectivityUpdatePostInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedCustomConnectivityUpdateSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedSynapseConnectivityHostInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - template - void genMergedCustomUpdateTransposeWUGroups(const BackendBase &backend, const std::string &updateGroupName, G generateGroup) - { - createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomUpdateTransposeWUGroups, - [&updateGroupName](const CustomUpdateWUInternal &cg) - { - return (cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName); - }, - &CustomUpdateWUInternal::getHashDigest, generateGroup); - } - - template - void genMergedCustomUpdateHostReductionGroups(const BackendBase &backend, const std::string &updateGroupName, G generateGroup) - { - createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateHostReductionGroups, - [&updateGroupName](const CustomUpdateInternal &cg) - { - return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName); - }, - &CustomUpdateInternal::getHashDigest, generateGroup, true); - } - - template - void genMergedCustomWUUpdateHostReductionGroups(const BackendBase &backend, const std::string &updateGroupName, G generateGroup) - { - createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateHostReductionGroups, - [&updateGroupName](const CustomUpdateWUInternal &cg) - { - return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName); - }, - 
&CustomUpdateWUInternal::getHashDigest, generateGroup, true); - } - - template - void genMergedCustomConnectivityUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, G generateGroup) - { - createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateGroups, - [&updateGroupName](const CustomConnectivityUpdateInternal &cg) - { - return (!cg.getCustomConnectivityUpdateModel()->getRowUpdateCode().empty() && cg.getUpdateGroupName() == updateGroupName); - }, - &CustomConnectivityUpdateInternal::getHashDigest, generateGroup); - } - - template - void genMergedCustomConnectivityHostUpdateGroups(BackendBase &backend, const std::string &updateGroupName, G generateGroup) - { - createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityHostUpdateGroups, - [&updateGroupName](const CustomConnectivityUpdateInternal &cg) - { - return (!cg.getCustomConnectivityUpdateModel()->getHostUpdateCode().empty() && cg.getUpdateGroupName() == updateGroupName); - }, - &CustomConnectivityUpdateInternal::getHashDigest, generateGroup, true); - } - - template - void genMergedNeuronSpikeQueueUpdateGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronSpikeQueueUpdateGroups, - [](const NeuronGroupInternal &){ return true; }, - &NeuronGroupInternal::getSpikeQueueUpdateHashDigest, generateGroup); - } - - template - void genMergedNeuronPrevSpikeTimeUpdateGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronPrevSpikeTimeUpdateGroups, - [](const NeuronGroupInternal &ng){ return (ng.isPrevSpikeTimeRequired() || ng.isPrevSpikeEventTimeRequired()); }, - &NeuronGroupInternal::getPrevSpikeTimeUpdateHashDigest, generateGroup); - } - - template - void genMergedSynapseDendriticDelayUpdateGroups(const BackendBase &backend, G generateGroup) - { - std::vector> synapseGroupsWithDendriticDelay; - for(const auto &n : getModel().getNeuronGroups()) { - for(const auto *sg : n.second.getFusedPSMInSyn()) { - if(sg->isDendriticDelayRequired()) { - synapseGroupsWithDendriticDelay.push_back(std::cref(*sg)); - } - } - } - createMergedGroups(backend, synapseGroupsWithDendriticDelay, m_MergedSynapseDendriticDelayUpdateGroups, - &SynapseGroupInternal::getDendriticDelayUpdateHashDigest, generateGroup); - } - - template - void genMergedNeuronInitGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronInitGroups, - [](const NeuronGroupInternal &){ return true; }, - &NeuronGroupInternal::getInitHashDigest, generateGroup); - } - - template - void genMergedCustomUpdateInitGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateInitGroups, - [](const CustomUpdateInternal &cg) { return cg.isVarInitRequired(); }, - &CustomUpdateInternal::getInitHashDigest, generateGroup); - } - - template - void genMergedCustomWUUpdateInitGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateInitGroups, - [](const CustomUpdateWUInternal &cg) - { - return (((cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::DENSE) - || (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL)) - && cg.isVarInitRequired()); - }, - &CustomUpdateWUInternal::getInitHashDigest, 
generateGroup); - } - - template - void genMergedSynapseInitGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseInitGroups, - [](const SynapseGroupInternal &sg) - { - return (((sg.getMatrixType() & SynapseMatrixConnectivity::DENSE) - || (sg.getMatrixType() & SynapseMatrixWeight::KERNEL)) - && sg.isWUVarInitRequired()); - }, - &SynapseGroupInternal::getWUInitHashDigest, generateGroup); - } - - template - void genMergedSynapseConnectivityInitGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseConnectivityInitGroups, - [](const SynapseGroupInternal &sg){ return sg.isSparseConnectivityInitRequired(); }, - &SynapseGroupInternal::getConnectivityInitHashDigest, generateGroup); - } - - template - void genMergedSynapseSparseInitGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseSparseInitGroups, - [&backend](const SynapseGroupInternal &sg) - { - return ((sg.getMatrixType() & SynapseMatrixConnectivity::SPARSE) && - (sg.isWUVarInitRequired() - || (backend.isPostsynapticRemapRequired() && !sg.getWUModel()->getLearnPostCode().empty()))); - }, - &SynapseGroupInternal::getWUInitHashDigest, generateGroup); - } - - template - void genMergedCustomWUUpdateSparseInitGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateSparseInitGroups, - [](const CustomUpdateWUInternal &cg) - { - return (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) && cg.isVarInitRequired(); - }, - &CustomUpdateWUInternal::getInitHashDigest, generateGroup); - } - - template - void genMergedCustomConnectivityUpdatePreInitGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePreInitGroups, - [&backend](const CustomConnectivityUpdateInternal &cg) - { - return (cg.isPreVarInitRequired() || (backend.isPopulationRNGInitialisedOnDevice() && cg.isRowSimRNGRequired())); - }, - &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); - } - - template - void genMergedCustomConnectivityUpdatePostInitGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePostInitGroups, - [](const CustomConnectivityUpdateInternal &cg) { return cg.isPostVarInitRequired(); }, - &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); - } - - template - void genMergedCustomConnectivityUpdateSparseInitGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateSparseInitGroups, - [](const CustomConnectivityUpdateInternal &cg) { return cg.isVarInitRequired(); }, - &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); - } - - template - void genMergedSynapseConnectivityHostInitGroups(const BackendBase &backend, G generateGroup) - { - createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseConnectivityHostInitGroups, - [](const SynapseGroupInternal &sg) - { - return !sg.getConnectivityInitialiser().getSnippet()->getHostInitCode().empty(); - }, - &SynapseGroupInternal::getConnectivityHostInitHashDigest, generateGroup, true); - } void 
genMergedNeuronUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedNeuronUpdateGroups); } void genMergedPresynapticUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedPresynapticUpdateGroups); } @@ -542,14 +327,14 @@ class GENN_EXPORT ModelSpecMerged } } - template + template void createMergedGroups(const BackendBase &backend, - const std::vector> &unmergedGroups, - std::vector &mergedGroups, D getHashDigest, G generateGroup, bool host = false) + const std::vector> &unmergedGroups, + std::vector &mergedGroups, D getHashDigest, GenMergedGroupFn generateGroup, bool host = false) { // Create a hash map to group together groups with the same SHA1 digest std::unordered_map>, + std::vector>, Utils::SHA1Hash> protoMergedGroups; // Add unmerged groups to correct vector @@ -589,13 +374,13 @@ class GENN_EXPORT ModelSpecMerged } } - template + template void createMergedGroups(const BackendBase &backend, - const std::map &groups, std::vector &mergedGroups, + const std::map &groups, std::vector &mergedGroups, F filter, D getHashDigest, G generateGroup, bool host = false) { // Build temporary vector of references to groups that pass filter - std::vector> unmergedGroups; + std::vector> unmergedGroups; for(const auto &g : groups) { if(filter(g.second)) { unmergedGroups.emplace_back(std::cref(g.second)); diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index 6d1e34bf0d..ce4165673a 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -97,6 +97,237 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa assignGroups(backend, m_MergedCustomConnectivityUpdateSparseInitGroups, memorySpaces); } //---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedNeuronUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronUpdateGroups, + [](const NeuronGroupInternal &){ return true; }, + &NeuronGroupInternal::getHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedPresynapticUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedPresynapticUpdateGroups, + [](const SynapseGroupInternal &sg) { return (sg.isSpikeEventRequired() || sg.isTrueSpikeRequired()); }, + &SynapseGroupInternal::getWUHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedPostsynapticUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedPostsynapticUpdateGroups, + [](const SynapseGroupInternal &sg){ return !sg.getWUModel()->getLearnPostCode().empty(); }, + &SynapseGroupInternal::getWUHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedSynapseDynamicsGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseDynamicsGroups, + [](const SynapseGroupInternal &sg){ return 
!sg.getWUModel()->getSynapseDynamicsCode().empty(); }, + &SynapseGroupInternal::getWUHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedCustomUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateGroups, + [&updateGroupName](const CustomUpdateInternal &cg) { return cg.getUpdateGroupName() == updateGroupName; }, + &CustomUpdateInternal::getHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedCustomUpdateWUGroups(const BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomUpdateWUGroups, + [&updateGroupName](const CustomUpdateWUInternal &cg) + { + return (!cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName); + }, + &CustomUpdateWUInternal::getHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedCustomUpdateTransposeWUGroups(const BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomUpdateTransposeWUGroups, + [&updateGroupName](const CustomUpdateWUInternal &cg) + { + return (cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName); + }, + &CustomUpdateWUInternal::getHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedCustomUpdateHostReductionGroups(const BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateHostReductionGroups, + [&updateGroupName](const CustomUpdateInternal &cg) + { + return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName); + }, + &CustomUpdateInternal::getHashDigest, generateGroup, true); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedCustomWUUpdateHostReductionGroups(const BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateHostReductionGroups, + [&updateGroupName](const CustomUpdateWUInternal &cg) + { + return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName); + }, + &CustomUpdateWUInternal::getHashDigest, generateGroup, true); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedCustomConnectivityUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateGroups, + [&updateGroupName](const CustomConnectivityUpdateInternal &cg) + { + return (!cg.getCustomConnectivityUpdateModel()->getRowUpdateCode().empty() && cg.getUpdateGroupName() == updateGroupName); + }, + &CustomConnectivityUpdateInternal::getHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void 
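+// **NOTE** unlike the other generators, this one takes a non-const backend reference
+// and passes host=true to createMergedGroups, as these groups need processing on the host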
ModelSpecMerged::genMergedCustomConnectivityHostUpdateGroups(BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityHostUpdateGroups, + [&updateGroupName](const CustomConnectivityUpdateInternal &cg) + { + return (!cg.getCustomConnectivityUpdateModel()->getHostUpdateCode().empty() && cg.getUpdateGroupName() == updateGroupName); + }, + &CustomConnectivityUpdateInternal::getHashDigest, generateGroup, true); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedNeuronSpikeQueueUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronSpikeQueueUpdateGroups, + [](const NeuronGroupInternal &){ return true; }, + &NeuronGroupInternal::getSpikeQueueUpdateHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedNeuronPrevSpikeTimeUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronPrevSpikeTimeUpdateGroups, + [](const NeuronGroupInternal &ng){ return (ng.isPrevSpikeTimeRequired() || ng.isPrevSpikeEventTimeRequired()); }, + &NeuronGroupInternal::getPrevSpikeTimeUpdateHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedSynapseDendriticDelayUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + std::vector> synapseGroupsWithDendriticDelay; + for(const auto &n : getModel().getNeuronGroups()) { + for(const auto *sg : n.second.getFusedPSMInSyn()) { + if(sg->isDendriticDelayRequired()) { + synapseGroupsWithDendriticDelay.push_back(std::cref(*sg)); + } + } + } + createMergedGroups(backend, synapseGroupsWithDendriticDelay, m_MergedSynapseDendriticDelayUpdateGroups, + &SynapseGroupInternal::getDendriticDelayUpdateHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedNeuronInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronInitGroups, + [](const NeuronGroupInternal &){ return true; }, + &NeuronGroupInternal::getInitHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedCustomUpdateInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateInitGroups, + [](const CustomUpdateInternal &cg) { return cg.isVarInitRequired(); }, + &CustomUpdateInternal::getInitHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedCustomWUUpdateInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateInitGroups, + [](const CustomUpdateWUInternal &cg) + { + return (((cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::DENSE) + || (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL)) + && cg.isVarInitRequired()); + }, + &CustomUpdateWUInternal::getInitHashDigest, 
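+                       // groups are merged by their init hash digest, so custom WU updates
+                       // with identical initialisation logic share one merged structure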
generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedSynapseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseInitGroups, + [](const SynapseGroupInternal &sg) + { + return (((sg.getMatrixType() & SynapseMatrixConnectivity::DENSE) + || (sg.getMatrixType() & SynapseMatrixWeight::KERNEL)) + && sg.isWUVarInitRequired()); + }, + &SynapseGroupInternal::getWUInitHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedSynapseConnectivityInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseConnectivityInitGroups, + [](const SynapseGroupInternal &sg){ return sg.isSparseConnectivityInitRequired(); }, + &SynapseGroupInternal::getConnectivityInitHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedSynapseSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseSparseInitGroups, + [&backend](const SynapseGroupInternal &sg) + { + return ((sg.getMatrixType() & SynapseMatrixConnectivity::SPARSE) && + (sg.isWUVarInitRequired() + || (backend.isPostsynapticRemapRequired() && !sg.getWUModel()->getLearnPostCode().empty()))); + }, + &SynapseGroupInternal::getWUInitHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedCustomWUUpdateSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateSparseInitGroups, + [](const CustomUpdateWUInternal &cg) + { + return (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) && cg.isVarInitRequired(); + }, + &CustomUpdateWUInternal::getInitHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedCustomConnectivityUpdatePreInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePreInitGroups, + [&backend](const CustomConnectivityUpdateInternal &cg) + { + return (cg.isPreVarInitRequired() || (backend.isPopulationRNGInitialisedOnDevice() && cg.isRowSimRNGRequired())); + }, + &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedCustomConnectivityUpdatePostInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePostInitGroups, + [](const CustomConnectivityUpdateInternal &cg) { return cg.isPostVarInitRequired(); }, + &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); +} +//---------------------------------------------------------------------------- +void ModelSpecMerged::genMergedCustomConnectivityUpdateSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +{ + createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), 
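+                       // sparse variant: includes any custom connectivity update with variables to initialise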
m_MergedCustomConnectivityUpdateSparseInitGroups,
+                       [](const CustomConnectivityUpdateInternal &cg) { return cg.isVarInitRequired(); },
+                       &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup);
+}
+//----------------------------------------------------------------------------
+void ModelSpecMerged::genMergedSynapseConnectivityHostInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup)
+{
+    createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseConnectivityHostInitGroups,
+                       [](const SynapseGroupInternal &sg)
+                       {
+                           return !sg.getConnectivityInitialiser().getSnippet()->getHostInitCode().empty();
+                       },
+                       &SynapseGroupInternal::getConnectivityHostInitHashDigest, generateGroup, true);
+}
+//----------------------------------------------------------------------------
 boost::uuids::detail::sha1::digest_type ModelSpecMerged::getHashDigest(const BackendBase &backend) const
 {
     boost::uuids::detail::sha1 hash;

From 446d85e47bb55c1ff4d5756262bfae537b57dc0b Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 29 Jun 2023 13:49:10 +0100
Subject: [PATCH 270/725] SIMT backend and single-threaded CPU backend now compile
---
 .../backends/single_threaded_cpu/backend.h | 4 +-
 .../genn/genn/code_generator/backendSIMT.h | 207 ++++++++++--------
 .../backends/single_threaded_cpu/backend.cc | 10 +-
 src/genn/genn/code_generator/backendSIMT.cc | 129 +++++------
 4 files changed, 180 insertions(+), 170 deletions(-)

diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h
index aa1478e2ef..6d9298b9dc 100644
--- a/include/genn/backends/single_threaded_cpu/backend.h
+++ b/include/genn/backends/single_threaded_cpu/backend.h
@@ -125,8 +125,8 @@ class BACKEND_EXPORT Backend : public BackendBase
     virtual void genVariableInit(EnvironmentExternalBase &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const final;
     virtual void genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const final;
     virtual void genDenseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const final;
-    virtual void genKernelSynapseVariableInit(EnvironmentExternalBase &env, const SynapseInitGroupMerged &sg, HandlerEnv handler) const final;
-    virtual void genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, const CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const final;
+    virtual void genKernelSynapseVariableInit(EnvironmentExternalBase &env, SynapseInitGroupMerged &sg, HandlerEnv handler) const final;
+    virtual void genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const final;
     virtual void genGlobalDeviceRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, MemAlloc &memAlloc) const final;
diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h
index 8a5a6938f1..5b2971e19d 100644
--- a/include/genn/genn/code_generator/backendSIMT.h
+++ b/include/genn/genn/code_generator/backendSIMT.h
@@ -182,31 +182,31 @@ class GENN_EXPORT BackendSIMT : public BackendBase
     //------------------------------------------------------------------------
     // Protected API
     //------------------------------------------------------------------------
-    void genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const;
-
void genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; + void genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; + void genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; - void genNeuronUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; + void genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; - void genSynapseDendriticDelayUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; - void genPresynapticUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; - void genPostsynapticUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; - void genSynapseDynamicsKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; + void genSynapseDendriticDelayUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; + void genPresynapticUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; + void genPostsynapticUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; + void genSynapseDynamicsKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; - void genCustomUpdateKernel(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, + void genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const; - void genCustomUpdateWUKernel(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, + void genCustomUpdateWUKernel(EnvironmentExternal &env, ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const; - void genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, + void genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const; - void genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, + void genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const; - void genInitializeKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const; + void genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; - void genInitializeSparseKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, + void genInitializeSparseKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t numInitializeThreads, size_t &idStart) const; //! 
Helper wrapper around padSize to pad size to a kernel size @@ -219,103 +219,126 @@ class GENN_EXPORT BackendSIMT : public BackendBase //-------------------------------------------------------------------------- // Type definitions //-------------------------------------------------------------------------- - template - using GetPaddedGroupSizeFunc = std::function; + template + using GenMergedGroupsFn = void (ModelSpecMerged::*)(const BackendBase&, std::function); + template + using GenMergedCustomUpdateGroupsFn = void (ModelSpecMerged::*)(const BackendBase&, const std::string &, std::function); + //-------------------------------------------------------------------------- // Private methods //-------------------------------------------------------------------------- - template - void genParallelGroup(EnvironmentExternalBase &env, std::vector &groups, size_t &idStart, - S getPaddedSizeFunc, F filter, GroupHandlerEnv handler) const + template + void genGroup(EnvironmentExternalBase &env, T &gMerge, size_t &idStart, + S getPaddedSizeFn, GroupHandlerEnv handler) const { - // Loop through groups - for(auto &gMerge : groups) { - if(filter(gMerge)) { - // Sum padded sizes of each group within merged group - const size_t paddedSize = std::accumulate( - gMerge.getGroups().cbegin(), gMerge.getGroups().cend(), size_t{0}, - [getPaddedSizeFunc](size_t acc, std::reference_wrapper g) - { - return (acc + getPaddedSizeFunc(g.get())); - }); + // Sum padded sizes of each group within merged group + const size_t paddedSize = std::accumulate( + gMerge.getGroups().cbegin(), gMerge.getGroups().cend(), size_t{0}, + [getPaddedSizeFn](size_t acc, std::reference_wrapper g) + { + return (acc + getPaddedSizeFn(g.get())); + }); - env.getStream() << "// merged" << gMerge.getIndex() << std::endl; + env.getStream() << "// merged" << gMerge.getIndex() << std::endl; - // If this is the first group - if(idStart == 0) { - env.getStream() << "if(id < " << paddedSize << ")"; - } - else { - env.getStream() << "if(id >= " << idStart << " && id < " << idStart + paddedSize << ")"; - } - { - CodeStream::Scope b(env.getStream()); - EnvironmentExternal groupEnv(env); + // If this is the first group + if(idStart == 0) { + env.getStream() << "if(id < " << paddedSize << ")"; + } + else { + env.getStream() << "if(id >= " << idStart << " && id < " << idStart + paddedSize << ")"; + } + { + CodeStream::Scope b(env.getStream()); + EnvironmentExternal groupEnv(env); - if(gMerge.getGroups().size() == 1) { - groupEnv.getStream() << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group"; - groupEnv.getStream() << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[0]; " << std::endl; - groupEnv.getStream() << "const unsigned int lid = id - " << idStart << ";" << std::endl; + if(gMerge.getGroups().size() == 1) { + groupEnv.getStream() << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group"; + groupEnv.getStream() << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[0]; " << std::endl; + groupEnv.getStream() << "const unsigned int lid = id - " << idStart << ";" << std::endl; - // Use the starting thread ID of the whole merged group as group_start_id - groupEnv.add(Type::Uint32.addConst(), "group_start_id", std::to_string(idStart)); + // Use the starting thread ID of the whole merged group as group_start_id + groupEnv.add(Type::Uint32.addConst(), "group_start_id", std::to_string(idStart)); + } + else { + // Perform bisect operation to get index of 
merged struct + groupEnv.getStream() << "unsigned int lo = 0;" << std::endl; + groupEnv.getStream() << "unsigned int hi = " << gMerge.getGroups().size() << ";" << std::endl; + groupEnv.getStream() << "while(lo < hi)" << std::endl; + { + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "const unsigned int mid = (lo + hi) / 2;" << std::endl; + + groupEnv.getStream() << "if(id < d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[mid])"; + { + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "hi = mid;" << std::endl; } - else { - // Perform bisect operation to get index of merged struct - groupEnv.getStream() << "unsigned int lo = 0;" << std::endl; - groupEnv.getStream() << "unsigned int hi = " << gMerge.getGroups().size() << ";" << std::endl; - groupEnv.getStream() << "while(lo < hi)" << std::endl; - { - CodeStream::Scope b(groupEnv.getStream()); - groupEnv.getStream() << "const unsigned int mid = (lo + hi) / 2;" << std::endl; - - groupEnv.getStream() << "if(id < d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[mid])"; - { - CodeStream::Scope b(groupEnv.getStream()); - groupEnv.getStream() << "hi = mid;" << std::endl; - } - groupEnv.getStream() << "else"; - { - CodeStream::Scope b(groupEnv.getStream()); - groupEnv.getStream() << "lo = mid + 1;" << std::endl; - } - } + groupEnv.getStream() << "else"; + { + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "lo = mid + 1;" << std::endl; + } + } - // Use this to get reference to merged group structure - groupEnv.getStream() << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group"; - groupEnv.getStream() << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[lo - 1]; " << std::endl; + // Use this to get reference to merged group structure + groupEnv.getStream() << getPointerPrefix() << "struct Merged" << T::name << "Group" << gMerge.getIndex() << " *group"; + groupEnv.getStream() << " = &d_merged" << T::name << "Group" << gMerge.getIndex() << "[lo - 1]; " << std::endl; - // Get group start thread ID and use as group_start_id - groupEnv.getStream() << "const unsigned int groupStartID = d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[lo - 1];" << std::endl; - groupEnv.add(Type::Uint32.addConst(), "_group_start_id", "groupStartID"); + // Get group start thread ID and use as group_start_id + groupEnv.getStream() << "const unsigned int groupStartID = d_merged" << T::name << "GroupStartID" << gMerge.getIndex() << "[lo - 1];" << std::endl; + groupEnv.add(Type::Uint32.addConst(), "_group_start_id", "groupStartID"); - // Use this to calculate local id within group - groupEnv.getStream() << "const unsigned int lid = id - groupStartID;" << std::endl; - } - groupEnv.add(Type::Uint32.addConst(), "id", "lid"); + // Use this to calculate local id within group + groupEnv.getStream() << "const unsigned int lid = id - groupStartID;" << std::endl; + } + groupEnv.add(Type::Uint32.addConst(), "id", "lid"); - handler(groupEnv, gMerge); + handler(groupEnv, gMerge); - idStart += paddedSize; - } - } + idStart += paddedSize; } } - + template + void genParallelGroup(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart, + GenMergedGroupsFn generateGroupFn, S getPaddedSizeFunc, GroupHandlerEnv handler) const + { + std::invoke(generateGroupFn, modelMerged, *this, + [this, getPaddedSizeFunc, handler, &env, &idStart](T &g) + { + genGroup(env, g, idStart, getPaddedSizeFunc, handler); + }); + } + + 
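+    // overload for named update groups: forwards updateGroupName to the merged group
+    // generator so only custom update groups belonging to that update are generated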
template + void genParallelGroup(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, const std::string &updateGroupName, size_t &idStart, + GenMergedCustomUpdateGroupsFn generateGroupFn, S getPaddedSizeFunc, GroupHandlerEnv handler) const + { + std::invoke(generateGroupFn, modelMerged, *this, updateGroupName, + [this, getPaddedSizeFunc, handler, &env, &idStart](T &g) + { + genGroup(env, g, idStart, getPaddedSizeFunc, handler); + }); + } + + + + + /*template void genParallelGroup(EnvironmentExternalBase &env, std::vector &groups, size_t &idStart, S getPaddedSizeFunc, GroupHandlerEnv handler) const { genParallelGroup(env, groups, idStart, getPaddedSizeFunc, [](const T &) { return true; }, handler); - } + }*/ // Helper function to generate kernel code to initialise variables associated with synapse group or custom WU update with dense/kernel connectivity template - void genSynapseVarInit(EnvironmentGroupMergedField &env, const ModelSpecMerged &modelMerged, + void genSynapseVarInit(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, G &g, bool initRNGRequired, bool kernel, size_t kernelDimensions) const { env.getStream() << "if(" << env["id"] << " < "; @@ -339,13 +362,14 @@ class GENN_EXPORT BackendSIMT : public BackendBase env.getStream() << ")"; { CodeStream::Scope b(env.getStream()); - EnvironmentGroupMergedField initEnv(env, env.getGroup()); + EnvironmentGroupMergedField initEnv(env, g); // If an RNG is required for initialisation, // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id if(initRNGRequired) { - genGlobalRNGSkipAhead(os, popSubs, "id"); + initEnv.add(Type::Void, "rng", + genGlobalRNGSkipAhead(initEnv.getStream(), "id")); } // If synapse group has kernel weights @@ -360,7 +384,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase // Loop backwards through other kernel and generate code to divide by product of subsequent dimensions kernelIDInit << " / ("; for (size_t j = (kernelDimensions - 1); j > i; j--) { - kernelIDInit << getKernelSize(env.getGroup(), j); + kernelIDInit << getKernelSize(g, j); if (j != (i + 1)) { kernelIDInit << " * "; @@ -372,7 +396,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase // If this isn't the first dimension, take modulus of kernel size if (i > 0) { - kernelIDInit << " % " << getKernelSize(env.getGroup(), i); + kernelIDInit << " % " << getKernelSize(g, i); } kernelIDInit << ";" << std::endl; @@ -394,7 +418,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase // Helper function to generate kernel code to initialise variables associated with synapse group or custom WU update with sparse connectivity template - void genSparseSynapseVarInit(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, const G &g, + void genSparseSynapseVarInit(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, G &g, bool varInitRequired, GroupHandlerEnv handler) const { // Calculate how many blocks rows need to be processed in (in order to store row lengths in shared memory) @@ -431,12 +455,13 @@ class GENN_EXPORT BackendSIMT : public BackendBase env.getStream() << "if(" << env["id"] << " < shRowLength[i])"; { CodeStream::Scope b(env.getStream()); - + // Generate initialisation code if(varInitRequired) { - env.add(Type::Uint32.addConst(), "id_pre", "((r * " + std::to_string(blockSize) + ") + i)"); - env.add(Type::Uint32.addConst(), "id_post", "$(_ind)[idx]"); - g.generateInit(*this, env, modelMerged); + EnvironmentExternal initEnv(env); + initEnv.add(Type::Uint32.addConst(), 
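+                // expose the row (id_pre) and column (id_post) of the current synapse
+                // to the variable initialisation code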
"id_pre", "((r * " + std::to_string(blockSize) + ") + i)"); + initEnv.add(Type::Uint32.addConst(), "id_post", "$(_ind)[idx]"); + g.generateInit(*this, initEnv, modelMerged); } // Call handler diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 6f42e317fc..6a80cb7ba3 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -81,7 +81,7 @@ void genKernelIteration(EnvironmentExternalBase &env, G &g, size_t numKernelDims { // Loop through this kernel dimensions const std::string idxVar = "k" + std::to_string(depth); - env.getStream() << "for(unsigned int " << idxVar << " = 0; " << idxVar << " < " << printSubs(getKernelSize(g, depth), env) << "; " << idxVar << "++)"; + env.print("for(unsigned int " + idxVar + " = 0; " + idxVar + " < " + getKernelSize(g, depth) + "; " + idxVar + "++)"); { CodeStream::Scope b(env.getStream()); EnvironmentGroupMergedField loopEnv(env, g); @@ -468,7 +468,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // **TODO** prod types const std::string offsetTrueSpkPost = (s.getArchetype().getTrgNeuronGroup()->isTrueSpikeRequired() && s.getArchetype().getTrgNeuronGroup()->isDelayRequired()) ? "$(_post_delay_offset) + " : ""; - groupEnv.printLine("const unsigned int spike = $(_trg_spk)[" + offsetTrueSpkPost + "j];", groupEnv); + groupEnv.printLine("const unsigned int spike = $(_trg_spk)[" + offsetTrueSpkPost + "j];"); // Loop through column of presynaptic neurons if (s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { @@ -737,7 +737,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, c); - genCustomConnectivityUpdateIndexCalculation(funcEnv.getStream(), c); + genCustomConnectivityUpdateIndexCalculation(groupEnv); // Loop through presynaptic neurons funcEnv.getStream() << "for(unsigned int i = 0; i < " << funcEnv["num_pre"] << "; i++)"; @@ -1515,12 +1515,12 @@ void Backend::genDenseSynapseVariableRowInit(EnvironmentExternalBase &env, Handl } } //-------------------------------------------------------------------------- -void Backend::genKernelSynapseVariableInit(EnvironmentExternalBase &env, const SynapseInitGroupMerged &sg, HandlerEnv handler) const +void Backend::genKernelSynapseVariableInit(EnvironmentExternalBase &env, SynapseInitGroupMerged &sg, HandlerEnv handler) const { genKernelIteration(env, sg, sg.getArchetype().getKernelSize().size(), handler); } //-------------------------------------------------------------------------- -void Backend::genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, const CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const +void Backend::genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const { genKernelIteration(env, cu, cu.getArchetype().getSynapseGroup()->getKernelSize().size(), handler); } diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 9d3394f761..eb78d217e7 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -309,14 +309,14 @@ void BackendSIMT::addPresynapticUpdateStrategy(PresynapticUpdateStrategySIMT::Ba s_PresynapticUpdateStrategies.push_back(strategy); } //-------------------------------------------------------------------------- -void 
BackendSIMT::genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const { const unsigned int batchSize = modelMerged.getModel().getBatchSize(); // Parallelise over neuron groups idStart = 0; genParallelGroup( - env, modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups(), idStart, + env, modelMerged, idStart, &ModelSpecMerged::genMergedNeuronPrevSpikeTimeUpdateGroups, [this](const NeuronGroupInternal &ng) { return padKernelSize(ng.getNumNeurons(), KernelNeuronUpdate); }, [batchSize, this](EnvironmentExternalBase &popEnv, NeuronPrevSpikeTimeUpdateGroupMerged &ng) { @@ -386,7 +386,7 @@ void BackendSIMT::genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &en } //-------------------------------------------------------------------------- -void BackendSIMT::genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const { const unsigned int batchSize = modelMerged.getModel().getBatchSize(); @@ -421,7 +421,7 @@ void BackendSIMT::genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, } } //-------------------------------------------------------------------------- -void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const { const unsigned int batchSize = modelMerged.getModel().getBatchSize(); @@ -475,7 +475,7 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, const Mode // Parallelise over neuron groups idStart = 0; genParallelGroup( - env, modelMerged.getMergedNeuronUpdateGroups(), idStart, + env, modelMerged, idStart, &ModelSpecMerged::genMergedNeuronUpdateGroups, [this](const NeuronGroupInternal &ng) { return padKernelSize(ng.getNumNeurons(), KernelNeuronUpdate); }, [batchSize, &modelMerged, this](EnvironmentExternalBase &popEnv, NeuronUpdateGroupMerged &ng) { @@ -637,7 +637,7 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, const Mode }); } //-------------------------------------------------------------------------- -void BackendSIMT::genSynapseDendriticDelayUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genSynapseDendriticDelayUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const { // Loop through merged synapse groups idStart = 0; @@ -662,7 +662,7 @@ void BackendSIMT::genSynapseDendriticDelayUpdateKernel(EnvironmentExternalBase & env.getStream() << std::endl; } //-------------------------------------------------------------------------- -void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const { EnvironmentExternal kernelEnv(env); @@ -692,11 +692,11 @@ void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, const // Parallelise over synapse groups idStart = 0; genParallelGroup( - kernelEnv, 
modelMerged.getMergedPresynapticUpdateGroups(), idStart, + kernelEnv, modelMerged, idStart, &ModelSpecMerged::genMergedPresynapticUpdateGroups, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumPresynapticUpdateThreads(sg, getPreferences()), KernelPresynapticUpdate); }, [&modelMerged, this](EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg) { - EnvironmentGroupMergedField groupEnv(env, sg); + EnvironmentGroupMergedField groupEnv(env, sg); // Get presynaptic update strategy to use for this synapse group const auto *presynapticUpdateStrategy = getPresynapticUpdateStrategy(sg.getArchetype()); @@ -727,7 +727,7 @@ void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, const }); } //-------------------------------------------------------------------------- -void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const { EnvironmentExternal kernelEnv(env); @@ -739,7 +739,8 @@ void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, cons // Parallelise over postsynaptic update groups idStart = 0; - genParallelGroup<PostsynapticUpdateGroupMerged>(kernelEnv, modelMerged.getMergedPostsynapticUpdateGroups(), idStart, + genParallelGroup<PostsynapticUpdateGroupMerged>( + kernelEnv, modelMerged, idStart, &ModelSpecMerged::genMergedPostsynapticUpdateGroups, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumPostsynapticUpdateThreads(sg), KernelPostsynapticUpdate); }, [&modelMerged, this](EnvironmentExternalBase &env, PostsynapticUpdateGroupMerged &sg) { @@ -816,16 +817,16 @@ void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, cons ); } //-------------------------------------------------------------------------- -void BackendSIMT::genSynapseDynamicsKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genSynapseDynamicsKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const { // Parallelise over synapse groups whose weight update models have code for synapse dynamics idStart = 0; genParallelGroup<SynapseDynamicsGroupMerged>( - env, modelMerged.getMergedSynapseDynamicsGroups(), idStart, + env, modelMerged, idStart, &ModelSpecMerged::genMergedSynapseDynamicsGroups, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumSynapseDynamicsThreads(sg), KernelSynapseDynamicsUpdate); }, [&modelMerged, this](EnvironmentExternalBase &env, SynapseDynamicsGroupMerged &sg) { - EnvironmentGroupMergedField<SynapseDynamicsGroupMerged> groupEnv(env); + EnvironmentGroupMergedField<SynapseDynamicsGroupMerged> groupEnv(env, sg); // Generate index calculation code const unsigned int batchSize = modelMerged.getModel().getBatchSize(); @@ -839,7 +840,7 @@ void BackendSIMT::genSynapseDynamicsKernel(EnvironmentExternalBase &env, const M } { CodeStream::Scope b(groupEnv.getStream()); - EnvironmentGroupMergedField synEnv(groupEnv, sg); + EnvironmentGroupMergedField synEnv(groupEnv, sg); if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // **OPTIMIZE** we can do a fast constant divide optimization here and use the result to calculate the remainder @@ -878,28 +879,24 @@ void BackendSIMT::genSynapseDynamicsKernel(EnvironmentExternalBase &env, const M }); } //-------------------------------------------------------------------------- -void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, +void 
BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const { + const unsigned int batchSize = modelMerged.getModel().getBatchSize(); genParallelGroup( - env, modelMerged.getMergedCustomUpdateGroups(), idStart, - [&modelMerged, this](const CustomUpdateInternal &cu) - { - return getPaddedNumCustomUpdateThreads(cu, modelMerged.getModel().getBatchSize()); - }, - [&updateGroup](const CustomUpdateGroupMerged &cg) { return (cg.getArchetype().getUpdateGroupName() == updateGroup); }, - [&modelMerged, this](EnvironmentExternalBase &env, CustomUpdateGroupMerged &cg) + env, modelMerged, updateGroup, idStart, &ModelSpecMerged::genMergedCustomUpdateGroups, + [batchSize, this](const CustomUpdateInternal &cu) { return getPaddedNumCustomUpdateThreads(cu, batchSize); }, + [batchSize, this](EnvironmentExternalBase &env, CustomUpdateGroupMerged &cg) { const size_t blockSize = getKernelBlockSize(KernelCustomUpdate); - const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - + // If update is a batch reduction if(cg.getArchetype().isBatchReduction()) { env.getStream() << "// only do this for existing neurons" << std::endl; env.getStream() << "if(" << env["id"] << " < group->size)"; { CodeStream::Scope b(env.getStream()); - EnvironmentGroupMergedField groupEnv(env); + EnvironmentGroupMergedField groupEnv(env, cg); // Initialise reduction targets const auto reductionTargets = genInitReductionTargets(groupEnv.getStream(), cg, groupEnv["id"]); @@ -932,10 +929,10 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, const ModelSpe // Otherwise, if this is a neuron reduction else if (cg.getArchetype().isNeuronReduction()) { env.getStream() << "// only do this for existing neurons" << std::endl; - env.getStream() << "if(" << env["id"] << " < " << (32 * modelMerged.getModel().getBatchSize()) << ")"; + env.getStream() << "if(" << env["id"] << " < " << (32 * batchSize) << ")"; { CodeStream::Scope b(env.getStream()); - EnvironmentGroupMergedField groupEnv(env); + EnvironmentGroupMergedField groupEnv(env, cg); // Split ID into lane and batch groupEnv.getStream() << "const unsigned int lane = " << env["id"] << " % 32;" << std::endl; @@ -987,7 +984,7 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, const ModelSpe } // Otherwise else { - EnvironmentGroupMergedField groupEnv(env); + EnvironmentGroupMergedField groupEnv(env, cg); if(cg.getArchetype().isBatched()) { // Split ID into intra-batch ID and batch @@ -1018,21 +1015,17 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, const ModelSpe }); } //-------------------------------------------------------------------------- -void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, +void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const { + const unsigned int batchSize = modelMerged.getModel().getBatchSize(); genParallelGroup( - env, modelMerged.getMergedCustomUpdateWUGroups(), idStart, - [&modelMerged, this](const CustomUpdateWUInternal &cg) - { - return getPaddedNumCustomUpdateWUThreads(cg, modelMerged.getModel().getBatchSize()); - }, - [&updateGroup](const CustomUpdateWUGroupMerged &cg) { return (cg.getArchetype().getUpdateGroupName() == updateGroup); }, - [&modelMerged, this](EnvironmentExternalBase &env, CustomUpdateWUGroupMerged &cg) + env, modelMerged, updateGroup, idStart, 
&ModelSpecMerged::genMergedCustomUpdateWUGroups, + [batchSize, this](const CustomUpdateWUInternal &cu) { return getPaddedNumCustomUpdateWUThreads(cu, batchSize); }, + [batchSize, this](EnvironmentExternalBase &env, CustomUpdateWUGroupMerged &cg) { const SynapseGroupInternal *sg = cg.getArchetype().getSynapseGroup(); const size_t blockSize = getKernelBlockSize(KernelCustomUpdate); - const unsigned int batchSize = modelMerged.getModel().getBatchSize(); // Calculate size of each batch to update if (sg->getMatrixType() & SynapseMatrixWeight::KERNEL) { @@ -1145,21 +1138,16 @@ void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, const ModelS }); } //-------------------------------------------------------------------------- -void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, const ModelSpecMerged &modelMerged, +void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const { // Generate 2D array const size_t blockSize = getKernelBlockSize(KernelCustomTransposeUpdate); env.getStream() << getSharedPrefix() << " float shTile[" << blockSize << "][" << (blockSize + 1) << "];" << std::endl; - genParallelGroup<CustomUpdateTransposeWUGroupMerged>( - env, modelMerged.getMergedCustomUpdateTransposeWUGroups(), idStart, - [&modelMerged, this](const CustomUpdateWUInternal &cg) - { - return getPaddedNumCustomUpdateTransposeWUThreads(cg, modelMerged.getModel().getBatchSize()); - }, - [&updateGroup](const CustomUpdateTransposeWUGroupMerged &cg) { return (cg.getArchetype().getUpdateGroupName() == updateGroup); }, - [&modelMerged, this, blockSize](EnvironmentExternalBase &env, CustomUpdateTransposeWUGroupMerged &cg) + genParallelGroup<CustomUpdateTransposeWUGroupMerged>( + env, modelMerged, updateGroup, idStart, &ModelSpecMerged::genMergedCustomUpdateTransposeWUGroups, + [&modelMerged, this](const CustomUpdateWUInternal &cu) { return getPaddedNumCustomUpdateTransposeWUThreads(cu, modelMerged.getModel().getBatchSize()); }, + [blockSize, this](EnvironmentExternalBase &env, CustomUpdateTransposeWUGroupMerged &cg) { EnvironmentGroupMergedField<CustomUpdateTransposeWUGroupMerged> groupEnv(env, cg); @@ -1265,17 +1253,13 @@ void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, con }); } //-------------------------------------------------------------------------- -void BackendSIMT::genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, +void BackendSIMT::genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, const std::string &updateGroup, size_t &idStart) const { // Parallelise across presynaptic neurons genParallelGroup<CustomConnectivityUpdateGroupMerged>( - env, modelMerged.getMergedCustomConnectivityUpdateGroups(), idStart, - [this](const CustomConnectivityUpdateInternal &cg) - { - return padSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelCustomUpdate); - }, - [&updateGroup](const CustomConnectivityUpdateGroupMerged &cg) { return (cg.getArchetype().getUpdateGroupName() == updateGroup); }, + env, modelMerged, updateGroup, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdateGroups, + [this](const CustomConnectivityUpdateInternal &cg) { return padSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelCustomUpdate); }, [&modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdateGroupMerged &cg) { EnvironmentGroupMergedField<CustomConnectivityUpdateGroupMerged> groupEnv(env, cg); @@ -1306,13 +1290,13 @@ void BackendSIMT::genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env }); } 
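The refactoring pattern repeated in every kernel generator above is worth spelling out: rather than receiving a pre-built vector from modelMerged.getMergedXXXGroups() plus, for named updates, a separate update-group filter lambda, genParallelGroup now receives the ModelSpecMerged itself and a pointer-to-member such as &ModelSpecMerged::genMergedCustomUpdateGroups, which generates the matching merged groups on demand and invokes the per-group callback via std::invoke. Below is a minimal compilable sketch of that dispatch shape; Env, MergedGroup, Model and Backend are invented stand-ins for GeNN's real classes, and only the member-pointer-plus-std::invoke pattern is taken from the patch:

#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Invented stand-ins for GeNN's environment and merged-group types
struct Env {};
struct MergedGroup { std::string updateGroupName; };

struct Model
{
    // Stands in for ModelSpecMerged::genMergedCustomUpdateGroups etc.:
    // builds the merged groups and hands each matching one to 'handler'
    void genMergedGroups(const std::string &updateGroupName, std::function<void(MergedGroup&)> handler)
    {
        for(MergedGroup &g : m_Groups) {
            if(g.updateGroupName == updateGroupName) {
                handler(g);
            }
        }
    }

    std::vector<MergedGroup> m_Groups;
};

struct Backend
{
    // New-style genParallelGroup: a pointer to a Model member function
    // replaces the old pre-built std::vector of merged groups and the
    // separate update-group filter lambda
    template<typename GenMergedGroupsFn>
    void genParallelGroup(Env &env, Model &model, const std::string &updateGroupName,
                          size_t &idStart, GenMergedGroupsFn generateGroupFn) const
    {
        std::invoke(generateGroupFn, model, updateGroupName,
                    [&env, &idStart](MergedGroup&)
                    {
                        // real code would pad idStart and emit the group's kernel body here
                        idStart++;
                    });
    }
};

int main()
{
    Model model;
    model.m_Groups = {{"update1"}, {"update2"}, {"update1"}};

    Backend backend;
    Env env;
    size_t idStart = 0;
    backend.genParallelGroup(env, model, "update1", idStart, &Model::genMergedGroups);
    return (idStart == 2) ? 0 : 1;   // the two "update1" groups were dispatched
}

Routing the call through the model class this way lets merged groups be created and recorded at the point they are generated, instead of being built eagerly for every kernel whether or not any group matches the requested update group name.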
//-------------------------------------------------------------------------- -void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const { env.getStream() << "// ------------------------------------------------------------------------" << std::endl; env.getStream() << "// Local neuron groups" << std::endl; idStart = 0; genParallelGroup( - env, modelMerged.getMergedNeuronInitGroups(), idStart, + env, modelMerged, idStart, &ModelSpecMerged::genMergedNeuronInitGroups, [this](const NeuronGroupInternal &ng) { return padKernelSize(ng.getNumNeurons(), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, NeuronInitGroupMerged &ng) { @@ -1356,12 +1340,11 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS env.getStream() << "// ------------------------------------------------------------------------" << std::endl; env.getStream() << "// Synapse groups" << std::endl; genParallelGroup( - env, modelMerged.getMergedSynapseInitGroups(), idStart, + env, modelMerged, idStart, &ModelSpecMerged::genMergedSynapseInitGroups, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumInitThreads(sg), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, SynapseInitGroupMerged &sg) { - EnvironmentGroupMergedField groupEnv(env, sg); - genSynapseVarInit(groupEnv, modelMerged, sg.getArchetype().isWUInitRNGRequired(), + genSynapseVarInit(env, modelMerged, sg, sg.getArchetype().isWUInitRNGRequired(), (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL), sg.getArchetype().getKernelSize().size()); }); @@ -1370,7 +1353,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS env.getStream() << "// ------------------------------------------------------------------------" << std::endl; env.getStream() << "// Custom update groups" << std::endl; genParallelGroup( - env, modelMerged.getMergedCustomUpdateInitGroups(), idStart, + env, modelMerged, idStart, &ModelSpecMerged::genMergedCustomUpdateInitGroups, [this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, CustomUpdateInitGroupMerged &cg) { @@ -1395,13 +1378,12 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS env.getStream() << "// ------------------------------------------------------------------------" << std::endl; env.getStream() << "// Custom WU update groups" << std::endl; genParallelGroup( - env, modelMerged.getMergedCustomWUUpdateInitGroups(), idStart, + env, modelMerged, idStart, &ModelSpecMerged::genMergedCustomWUUpdateInitGroups, [this](const CustomUpdateWUInternal &cg) { return padKernelSize(getNumInitThreads(cg), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, CustomWUUpdateInitGroupMerged &cg) { const SynapseGroup *sg = cg.getArchetype().getSynapseGroup(); - EnvironmentGroupMergedField groupEnv(env, cg); - genSynapseVarInit(groupEnv, modelMerged, cg.getArchetype().isInitRNGRequired(), + genSynapseVarInit(env, modelMerged, cg, cg.getArchetype().isInitRNGRequired(), (sg->getMatrixType() & SynapseMatrixWeight::KERNEL), sg->getKernelSize().size()); }); env.getStream() << std::endl; @@ -1409,7 +1391,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS env.getStream() << "// 
------------------------------------------------------------------------" << std::endl; env.getStream() << "// Custom connectivity presynaptic update groups" << std::endl; genParallelGroup( - env, modelMerged.getMergedCustomConnectivityUpdatePreInitGroups(), idStart, + env, modelMerged, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdatePreInitGroups, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdatePreInitGroupMerged &cg) { @@ -1441,7 +1423,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS env.getStream() << "// ------------------------------------------------------------------------" << std::endl; env.getStream() << "// Custom connectivity postsynaptic update groups" << std::endl; genParallelGroup( - env, modelMerged.getMergedCustomConnectivityUpdatePostInitGroups(), idStart, + env, modelMerged, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdatePostInitGroups, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons(), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdatePostInitGroupMerged &cg) { @@ -1473,7 +1455,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS env.getStream() << "// ------------------------------------------------------------------------" << std::endl; env.getStream() << "// Synapse groups with sparse connectivity" << std::endl; genParallelGroup( - env, modelMerged.getMergedSynapseConnectivityInitGroups(), idStart, + env, modelMerged, idStart, &ModelSpecMerged::genMergedSynapseConnectivityInitGroups, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumConnectivityInitThreads(sg), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, SynapseConnectivityInitGroupMerged &sg) { @@ -1615,7 +1597,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, const ModelS env.getStream() << std::endl; } //-------------------------------------------------------------------------- -void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, +void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t numInitializeThreads, size_t &idStart) const { EnvironmentExternal envKernel(env); @@ -1623,9 +1605,10 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, const {envKernel.addInitialiser(getSharedPrefix() + "unsigned int shRowLength[" + std::to_string(getKernelBlockSize(KernelInitializeSparse)) + "];")}); // Initialise weight update variables for synapse groups with sparse connectivity - genParallelGroup(envKernel, modelMerged.getMergedSynapseSparseInitGroups(), idStart, - [this](const SynapseGroupInternal &sg) { return padKernelSize(sg.getMaxConnections(), KernelInitializeSparse); }, - [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, SynapseSparseInitGroupMerged &sg) + genParallelGroup( + env, modelMerged, idStart, &ModelSpecMerged::genMergedSynapseSparseInitGroups, + [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumConnectivityInitThreads(sg), KernelInitializeSparse); }, + [&modelMerged, numInitializeThreads, this](EnvironmentExternalBase &env, SynapseSparseInitGroupMerged &sg) { 
EnvironmentGroupMergedField groupEnv(env, sg); @@ -1663,7 +1646,8 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, const }); // Initialise weight update variables for synapse groups with sparse connectivity - genParallelGroup(envKernel, modelMerged.getMergedCustomWUUpdateSparseInitGroups(), idStart, + genParallelGroup( + env, modelMerged, idStart, &ModelSpecMerged::genMergedCustomWUUpdateSparseInitGroups, [this](const CustomUpdateWUInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }, [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, CustomWUUpdateSparseInitGroupMerged &cg) { @@ -1684,7 +1668,8 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, const }); // Initialise weight update variables for synapse groups with sparse connectivity - genParallelGroup(envKernel, modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups(), idStart, + genParallelGroup( + env, modelMerged, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdateSparseInitGroups, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }, [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdateSparseInitGroupMerged &cg) { From 0907e7b3c5da51d95437c653a016d5d68eaa7d1e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 29 Jun 2023 14:20:34 +0100 Subject: [PATCH 271/725] WIP --- .../genn/genn/code_generator/groupMerged.h | 133 +----------- .../genn/code_generator/initGroupMerged.h | 20 +- .../code_generator/neuronUpdateGroupMerged.h | 55 +++++ .../genn/code_generator/generateRunner.cc | 13 +- src/genn/genn/code_generator/groupMerged.cc | 195 ++---------------- .../genn/code_generator/initGroupMerged.cc | 44 ++++ .../code_generator/neuronUpdateGroupMerged.cc | 39 ++++ 7 files changed, 172 insertions(+), 327 deletions(-) diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 471b32b384..81b9b1f745 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -127,21 +127,6 @@ class ChildGroupMerged }); } - //! Helper to test whether parameter values are heterogeneous within merged group - /*template - bool isParamValueHeterogeneous(size_t index, P getParamValuesFn) const - { - // Get value of parameter in archetype group - const double archetypeValue = getParamValuesFn(getArchetype()).at(index); - - // Return true if any parameter values differ from the archetype value - return std::any_of(getGroups().cbegin(), getGroups().cend(), - [archetypeValue, index, getParamValuesFn](const GroupInternal &g) - { - return (getParamValuesFn(g).at(index) != archetypeValue); - }); - }*/ - //! 
Helper to update hash with the hash of calling getHashableFn on each group template void updateHash(H getHashableFn, boost::uuids::detail::sha1 &hash) const @@ -450,60 +435,6 @@ class GroupMerged : public ChildGroupMerged std::vector m_Fields; }; -//---------------------------------------------------------------------------- -// GeNN::CodeGenerator::NeuronSpikeQueueUpdateGroupMerged -//---------------------------------------------------------------------------- -class GENN_EXPORT NeuronSpikeQueueUpdateGroupMerged : public GroupMerged -{ -public: - using GroupMerged::GroupMerged; - - //------------------------------------------------------------------------ - // Public API - //------------------------------------------------------------------------ - void generateRunner(const BackendBase &backend, - CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, - CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, - CodeStream &runnerMergedStructAlloc) const - { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, - runnerVarDecl, runnerMergedStructAlloc, name); - } - - void genMergedGroupSpikeCountReset(EnvironmentExternalBase &env, unsigned int batchSize) const; - - //---------------------------------------------------------------------------- - // Static constants - //---------------------------------------------------------------------------- - static const std::string name; -}; - -//---------------------------------------------------------------------------- -// GeNN::CodeGenerator::NeuronPrevSpikeTimeUpdateGroupMerged -//---------------------------------------------------------------------------- -class GENN_EXPORT NeuronPrevSpikeTimeUpdateGroupMerged : public GroupMerged -{ -public: - using GroupMerged::GroupMerged; - - //------------------------------------------------------------------------ - // Public API - //------------------------------------------------------------------------ - void generateRunner(const BackendBase &backend, - CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, - CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, - CodeStream &runnerMergedStructAlloc) const - { - generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, - runnerVarDecl, runnerMergedStructAlloc, name); - } - - //---------------------------------------------------------------------------- - // Static constants - //---------------------------------------------------------------------------- - static const std::string name; -}; - //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronGroupMergedBase //---------------------------------------------------------------------------- @@ -561,9 +492,6 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged> &groups) - : GroupMerged(index, typeContext, groups), m_ArchetypeCode(archetypeCode) - {} - + using GroupMerged::GroupMerged; + //---------------------------------------------------------------------------- // Protected methods //---------------------------------------------------------------------------- @@ -689,31 +586,5 @@ class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged { public: - SynapseConnectivityInitGroupMerged(size_t index, const Type::TypeContext &typeContext, - const std::vector> &groups) - : SynapseGroupMergedBase(index, typeContext, SynapseGroupMergedBase::Role::ConnectivityInit, "", groups) - {} - - 
boost::uuids::detail::sha1::digest_type getHashDigest() const - { - return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::ConnectivityInit); - } + using GroupMerged::GroupMerged; + + boost::uuids::detail::sha1::digest_type getHashDigest() const; void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, @@ -320,6 +314,12 @@ class GENN_EXPORT SynapseConnectivityInitGroupMerged : public SynapseGroupMerged static const std::string name; private: + //! Should the sparse connectivity initialization parameter be implemented heterogeneously? + bool isSparseConnectivityInitParamHeterogeneous(const std::string ¶mName) const; + + //! Should the sparse connectivity initialization parameter be implemented heterogeneously? + bool isSparseConnectivityInitDerivedParamHeterogeneous(const std::string ¶mName) const; + //---------------------------------------------------------------------------- // Private methods //---------------------------------------------------------------------------- diff --git a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h index 7a4e5dd410..928a22fa5f 100644 --- a/include/genn/genn/code_generator/neuronUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/neuronUpdateGroupMerged.h @@ -216,4 +216,59 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase std::vector m_MergedInSynWUMPostCodeGroups; std::vector m_MergedOutSynWUMPreCodeGroups; }; + + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::NeuronSpikeQueueUpdateGroupMerged +//---------------------------------------------------------------------------- +class GENN_EXPORT NeuronSpikeQueueUpdateGroupMerged : public GroupMerged +{ +public: + using GroupMerged::GroupMerged; + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + void generateRunner(const BackendBase &backend, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const + { + generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + runnerVarDecl, runnerMergedStructAlloc, name); + } + + void genMergedGroupSpikeCountReset(EnvironmentExternalBase &env, unsigned int batchSize) const; + + //---------------------------------------------------------------------------- + // Static constants + //---------------------------------------------------------------------------- + static const std::string name; +}; + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::NeuronPrevSpikeTimeUpdateGroupMerged +//---------------------------------------------------------------------------- +class GENN_EXPORT NeuronPrevSpikeTimeUpdateGroupMerged : public GroupMerged +{ +public: + using GroupMerged::GroupMerged; + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + void generateRunner(const BackendBase &backend, + CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, + CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, + CodeStream &runnerMergedStructAlloc) const + { + 
generateRunnerBase(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + runnerVarDecl, runnerMergedStructAlloc, name); + } + + //---------------------------------------------------------------------------- + // Static constants + //---------------------------------------------------------------------------- + static const std::string name; +}; } // namespace GeNN::CodeGenerator diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 67ff8ec143..d7dd3f929e 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -441,14 +441,14 @@ void genRunnerFusedVars(const ModelSpecMerged &modelMerged, const BackendBase &b for(const auto &var : varAdaptor.getDefs()) { const auto resolvedType = var.type.resolve(modelMerged.getTypeContext()); backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - resolvedType, var.name + varAdaptor.getFusedSuffix(), varAdaptor.getLoc(var.name), + resolvedType, var.name + varAdaptor.getNameSuffix(), varAdaptor.getLoc(var.name), getSizeFn(group, var), mem); // Loop through EGPs required to initialize variable for(const auto &egp : varAdaptor.getInitialisers().at(var.name).getSnippet()->getExtraGlobalParams()) { genExtraGlobalParam(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerExtraGlobalParamFunc, - egp.type, egp.name + var.name + varAdaptor.getFusedSuffix(), + egp.type, egp.name + var.name + varAdaptor.getNameSuffix(), true, VarLocation::HOST_DEVICE); } } @@ -734,14 +734,17 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Loop through merged synapse connectivity host initialisation groups for(const auto &m : modelMerged.getMergedSynapseConnectivityHostInitGroups()) { - m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, - runnerVarDecl, runnerMergedStructAlloc); + assert(false); + //m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + // runnerVarDecl, runnerMergedStructAlloc); } // Loop through merged synapse connectivity host init groups and generate host init code // **NOTE** this is done here so valid pointers get copied straight into subsequent structures and merged EGP system isn't required for(const auto &sg : modelMerged.getMergedSynapseConnectivityHostInitGroups()) { - sg.generateInit(backend, runnerMergedStructAlloc, modelMerged); + assert(false); + //EnvironmentExternal env(runnerMergedStructAlloc); + //sg.generateInit(backend, runnerMergedStructAlloc, modelMerged); } // Generate merged neuron initialisation groups diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc index b1a8db22cd..0b0614da3a 100644 --- a/src/genn/genn/code_generator/groupMerged.cc +++ b/src/genn/genn/code_generator/groupMerged.cc @@ -15,131 +15,47 @@ using namespace GeNN; using namespace GeNN::CodeGenerator; -//---------------------------------------------------------------------------- -// GeNN::CodeGenerator::NeuronSpikeQueueUpdateGroupMerged -//---------------------------------------------------------------------------- -const std::string NeuronSpikeQueueUpdateGroupMerged::name = "NeuronSpikeQueueUpdate"; -//---------------------------------------------------------------------------- -void 
NeuronSpikeQueueUpdateGroupMerged::genMergedGroupSpikeCountReset(EnvironmentExternalBase &env, unsigned int batchSize) const -{ - if(getArchetype().isSpikeEventRequired()) { - if(getArchetype().isDelayRequired()) { - env.getStream() << env["_spk_cnt_evnt"] << "[*" << env["_spk_que_ptr"]; - if(batchSize > 1) { - env.getStream() << " + (batch * " << getArchetype().getNumDelaySlots() << ")"; - } - env.getStream() << "] = 0; " << std::endl; - } - else { - env.getStream() << env["_spk_cnt_evnt"] << "[" << ((batchSize > 1) ? "batch" : "0") << "] = 0;" << std::endl; - } - } - - if(getArchetype().isTrueSpikeRequired() && getArchetype().isDelayRequired()) { - env.getStream() << env["_spk_cnt"] << "[*" << env["_spk_que_ptr"]; - if(batchSize > 1) { - env.getStream() << " + (batch * " << getArchetype().getNumDelaySlots() << ")"; - } - env.getStream() << "] = 0; " << std::endl; - } - else { - env.getStream() << env["_spk_cnt"] << "[" << ((batchSize > 1) ? "batch" : "0") << "] = 0;" << std::endl; - } -} - -//---------------------------------------------------------------------------- -// GeNN::CodeGenerator::NeuronPrevSpikeTimeUpdateGroupMerged -//---------------------------------------------------------------------------- -const std::string NeuronPrevSpikeTimeUpdateGroupMerged::name = "NeuronPrevSpikeTimeUpdate"; - -//---------------------------------------------------------------------------- -// GeNN::CodeGenerator::NeuronGroupMergedBase -//---------------------------------------------------------------------------- -bool NeuronGroupMergedBase::isVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const -{ - const auto *varInitSnippet = getArchetype().getVarInitialisers().at(varName).getSnippet(); - return isParamReferenced({varInitSnippet->getCode()}, paramName); -} - //---------------------------------------------------------------------------- // GeNN::CodeGenerator::SynapseGroupMergedBase //---------------------------------------------------------------------------- bool SynapseGroupMergedBase::isWUParamHeterogeneous(const std::string ¶mName) const { - return (isWUParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUParams(); })); + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUParams(); }); } //---------------------------------------------------------------------------- bool SynapseGroupMergedBase::isWUDerivedParamHeterogeneous(const std::string ¶mName) const { - return (isWUParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUDerivedParams(); })); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isWUGlobalVarHeterogeneous(const std::string &varName) const -{ - return (isWUGlobalVarReferenced(varName) && - isParamValueHeterogeneous(varName, [](const SynapseGroupInternal &sg) { return sg.getWUConstInitVals(); })); + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUDerivedParams(); }); } //---------------------------------------------------------------------------- bool SynapseGroupMergedBase::isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const { - return (isWUVarInitParamReferenced(varName, paramName) && - isParamValueHeterogeneous(paramName, [varName](const SynapseGroupInternal &sg){ return sg.getWUVarInitialisers().at(varName).getParams(); 
})); + return isParamValueHeterogeneous(paramName, [varName](const SynapseGroupInternal &sg){ return sg.getWUVarInitialisers().at(varName).getParams(); }); } //---------------------------------------------------------------------------- bool SynapseGroupMergedBase::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const { - return (isWUVarInitParamReferenced(varName, paramName) && - isParamValueHeterogeneous(paramName, [varName](const SynapseGroupInternal &sg) { return sg.getWUVarInitialisers().at(varName).getDerivedParams(); })); + return isParamValueHeterogeneous(paramName, [varName](const SynapseGroupInternal &sg) { return sg.getWUVarInitialisers().at(varName).getDerivedParams(); }); } //---------------------------------------------------------------------------- bool SynapseGroupMergedBase::isSparseConnectivityInitParamHeterogeneous(const std::string ¶mName) const { - return (isSparseConnectivityInitParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getParams(); })); + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getParams(); }); } //---------------------------------------------------------------------------- bool SynapseGroupMergedBase::isSparseConnectivityInitDerivedParamHeterogeneous(const std::string ¶mName) const { - return (isSparseConnectivityInitParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getDerivedParams(); })); + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getDerivedParams(); }); } //---------------------------------------------------------------------------- bool SynapseGroupMergedBase::isToeplitzConnectivityInitParamHeterogeneous(const std::string ¶mName) const { - return (isToeplitzConnectivityInitParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getToeplitzConnectivityInitialiser().getParams(); })); + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getToeplitzConnectivityInitialiser().getParams(); }); } //---------------------------------------------------------------------------- bool SynapseGroupMergedBase::isToeplitzConnectivityInitDerivedParamHeterogeneous(const std::string ¶mName) const { - return (isToeplitzConnectivityInitParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getToeplitzConnectivityInitialiser().getDerivedParams(); })); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isSrcNeuronParamHeterogeneous(const std::string ¶mName) const -{ - return (isSrcNeuronParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getSrcNeuronGroup()->getParams(); })); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isSrcNeuronDerivedParamHeterogeneous(const std::string ¶mName) const -{ - return (isSrcNeuronParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getSrcNeuronGroup()->getDerivedParams(); })); -} 
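The functions being removed here all had the shape isParamValueHeterogeneous gated behind an isXXXParamReferenced check; from this commit on, the heterogeneity query is a plain value comparison across the groups of a merged group, matching the commented-out helper shown earlier in this patch. A self-contained sketch of that comparison, using an invented Group stand-in rather than GeNN's SynapseGroupInternal:

#include <algorithm>
#include <map>
#include <string>
#include <vector>

// Invented stand-in for one population in a merged group (GeNN would use
// SynapseGroupInternal, NeuronGroupInternal, ...)
struct Group
{
    std::map<std::string, double> params;
};

// Mirrors isParamValueHeterogeneous: the parameter is heterogeneous if any
// group's value differs from the archetype (first) group's value
bool isParamValueHeterogeneous(const std::vector<Group> &groups, const std::string &name)
{
    const double archetypeValue = groups.front().params.at(name);
    return std::any_of(groups.cbegin(), groups.cend(),
                       [&name, archetypeValue](const Group &g)
                       {
                           return g.params.at(name) != archetypeValue;
                       });
}

int main()
{
    Group a{{{"tau", 10.0}, {"v_rest", -65.0}}};
    Group b{{{"tau", 20.0}, {"v_rest", -65.0}}};
    const std::vector<Group> merged{a, b};

    // "tau" differs between the groups, "v_rest" does not
    return (isParamValueHeterogeneous(merged, "tau")
            && !isParamValueHeterogeneous(merged, "v_rest")) ? 0 : 1;
}

The exact (rather than tolerance-based) comparison of the double values is deliberate: if any group disagrees with the archetype at all, the parameter cannot be baked into the merged code as a literal and has to become a per-group struct field.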
-//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isTrgNeuronParamHeterogeneous(const std::string &paramName) const -{ - return (isTrgNeuronParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getTrgNeuronGroup()->getParams(); })); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isTrgNeuronDerivedParamHeterogeneous(const std::string &paramName) const -{ - return (isTrgNeuronParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getTrgNeuronGroup()->getDerivedParams(); })); + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getToeplitzConnectivityInitialiser().getDerivedParams(); }); } //---------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getPreSlot(unsigned int batchSize) const @@ -224,27 +140,7 @@ std::string SynapseGroupMergedBase::getKernelVarIndex(unsigned int batchSize, Va //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type SynapseGroupMergedBase::getHashDigest(Role role) const { - const bool updateRole = ((role == Role::PresynapticUpdate) - || (role == Role::PostsynapticUpdate) - || (role == Role::SynapseDynamics)); - - // Update hash with archetype's hash - boost::uuids::detail::sha1 hash; - if(updateRole) { - Utils::updateHash(getArchetype().getWUHashDigest(), hash); - } - else if (role == Role::ConnectivityInit) { - Utils::updateHash(getArchetype().getConnectivityInitHashDigest(), hash); - } - else { - Utils::updateHash(getArchetype().getWUInitHashDigest(), hash); - } - - // Update hash with number of neurons in pre and postsynaptic population - updateHash([](const SynapseGroupInternal &g) { return g.getSrcNeuronGroup()->getNumNeurons(); }, hash); - updateHash([](const SynapseGroupInternal &g) { return g.getTrgNeuronGroup()->getNumNeurons(); }, hash); - updateHash([](const SynapseGroupInternal &g) { return g.getMaxConnections(); }, hash); - updateHash([](const SynapseGroupInternal &g) { return g.getMaxSourceConnections(); }, hash); + updateHash([](const SynapseGroupInternal &g) { return g.getMaxSourceConnections(); }, hash); if(updateRole) { // Update hash with weight update model parameters and derived parameters @@ -340,13 +236,13 @@ std::string SynapseGroupMergedBase::getVarIndex(bool delay, unsigned int batchSi { if (delay) { if (varDuplication == VarAccessDuplication::SHARED_NEURON) { - return prefix + ((batchSize == 1) ? "DelaySlot" : "BatchDelaySlot"); + return prefix + ((batchSize == 1) ? 
"$(_delay_slot)" : "$(_batch_delay_slot)"); } else if (varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { - return prefix + "DelayOffset + " + index; + return prefix + "$(_delay_offset) + " + index; } else { - return prefix + "BatchDelayOffset + " + index; + return prefix + "$(_batch_delay_offset) + " + index; } } else { @@ -357,70 +253,7 @@ std::string SynapseGroupMergedBase::getVarIndex(bool delay, unsigned int batchSi return index; } else { - return prefix + "BatchOffset + " + index; + return prefix + "$(_batch_offset) + " + index; } } -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isWUParamReferenced(const std::string ¶mName) const -{ - return isParamReferenced({getArchetypeCode()}, paramName); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isWUGlobalVarReferenced(const std::string &varName) const -{ - // If synapse group has global WU variables - if(getArchetype().getMatrixType() & SynapseMatrixWeight::GLOBAL) { - return isParamReferenced({getArchetypeCode()}, varName); - } - // Otherwise, return false - else { - return false; - } -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isWUVarInitParamReferenced(const std::string &varName, const std::string ¶mName) const -{ - // If parameter isn't referenced in code, there's no point implementing it hetereogeneously! - const auto *varInitSnippet = getArchetype().getWUVarInitialisers().at(varName).getSnippet(); - return isParamReferenced({varInitSnippet->getCode()}, paramName); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isSparseConnectivityInitParamReferenced(const std::string ¶mName) const -{ - const auto *snippet = getArchetype().getConnectivityInitialiser().getSnippet(); - const auto rowBuildStateVars = snippet->getRowBuildStateVars(); - const auto colBuildStateVars = snippet->getColBuildStateVars(); - - // Build list of code strings containing row build code and any row build state variable values - std::vector codeStrings{snippet->getRowBuildCode(), snippet->getColBuildCode()}; - std::transform(rowBuildStateVars.cbegin(), rowBuildStateVars.cend(), std::back_inserter(codeStrings), - [](const Snippet::Base::ParamVal &p) { return p.value; }); - std::transform(colBuildStateVars.cbegin(), colBuildStateVars.cend(), std::back_inserter(codeStrings), - [](const Snippet::Base::ParamVal &p) { return p.value; }); - - return isParamReferenced(codeStrings, paramName); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isToeplitzConnectivityInitParamReferenced(const std::string ¶mName) const -{ - const auto *snippet = getArchetype().getToeplitzConnectivityInitialiser().getSnippet(); - const auto diagonalBuildStateVars = snippet->getDiagonalBuildStateVars(); - - // Build list of code strings containing diagonal build code and any diagonal build state variable values - std::vector codeStrings{snippet->getDiagonalBuildCode()}; - std::transform(diagonalBuildStateVars.cbegin(), diagonalBuildStateVars.cend(), std::back_inserter(codeStrings), - [](const Snippet::Base::ParamVal &p) { return p.value; }); - - return isParamReferenced(codeStrings, paramName); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isSrcNeuronParamReferenced(const std::string ¶mName) 
const -{ - return isParamReferenced({getArchetypeCode()}, paramName + "_pre"); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isTrgNeuronParamReferenced(const std::string ¶mName) const -{ - return isParamReferenced({getArchetypeCode()}, paramName + "_post"); -} +} \ No newline at end of file diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 864792ad06..59c8046a78 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -623,6 +623,40 @@ void SynapseSparseInitGroupMerged::generateInit(const BackendBase &backend, Envi //---------------------------------------------------------------------------- const std::string SynapseConnectivityInitGroupMerged::name = "SynapseConnectivityInit"; //---------------------------------------------------------------------------- +boost::uuids::detail::sha1::digest_type SynapseConnectivityInitGroupMerged::getHashDigest() const +{ + boost::uuids::detail::sha1 hash; + + // Update hash with archetype connectivity initialisation hash + Utils::updateHash(getArchetype().getConnectivityInitHashDigest(), hash); + + // Update hash with number of neurons in pre and postsynaptic population + updateHash([](const SynapseGroupInternal &g) { return g.getSrcNeuronGroup()->getNumNeurons(); }, hash); + updateHash([](const SynapseGroupInternal &g) { return g.getTrgNeuronGroup()->getNumNeurons(); }, hash); + updateHash([](const SynapseGroupInternal &g) { return g.getMaxConnections(); }, hash); + updateHash([](const SynapseGroupInternal &g) { return g.getMaxSourceConnections(); }, hash); + + // Update hash with connectivity parameters and derived parameters + updateParamHash( + &SynapseGroupMergedBase::isConnectivityInitParamReferenced, + [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getParams(); }, hash); + + updateParamHash( + &SynapseGroupMergedBase::isConnectivityInitParamReferenced, + [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getDerivedParams(); }, hash); + + if(!getArchetype().getKernelSize().empty()) { + updateHash([](const SynapseGroupInternal &g) { return g.getKernelSize(); }, hash); + + // Update hash with each group's variable initialisation parameters and derived parameters + updateVarInitParamHash( + &SynapseGroupMergedBase::isWUVarInitParamReferenced, hash); + updateVarInitDerivedParamHash( + &SynapseGroupMergedBase::isWUVarInitParamReferenced, hash); + } + return hash.get_digest(); +} +//---------------------------------------------------------------------------- void SynapseConnectivityInitGroupMerged::generateSparseRowInit(const BackendBase &backend, EnvironmentExternalBase &env) { genInitConnectivity(backend, env, true); @@ -651,6 +685,16 @@ void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase &b }); } //---------------------------------------------------------------------------- +bool SynapseConnectivityInitGroupMerged::isSparseConnectivityInitParamHeterogeneous(const std::string ¶mName) const +{ + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getParams(); }); +} +//---------------------------------------------------------------------------- +bool SynapseConnectivityInitGroupMerged::isSparseConnectivityInitDerivedParamHeterogeneous(const std::string ¶mName) const +{ + return isParamValueHeterogeneous(paramName, 
[](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getDerivedParams(); }); +} +//---------------------------------------------------------------------------- void SynapseConnectivityInitGroupMerged::genInitConnectivity(const BackendBase &backend, EnvironmentExternalBase &env, bool rowNotColumns) { const auto &connectInit = getArchetype().getConnectivityInitialiser(); diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 177899e4b8..a291ee5bad 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -826,3 +826,42 @@ bool NeuronUpdateGroupMerged::isDerivedParamHeterogeneous(const std::string &par { return isParamValueHeterogeneous(paramName, [](const NeuronGroupInternal &ng) { return ng.getDerivedParams(); }); } + + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::NeuronSpikeQueueUpdateGroupMerged +//---------------------------------------------------------------------------- +const std::string NeuronSpikeQueueUpdateGroupMerged::name = "NeuronSpikeQueueUpdate"; +//---------------------------------------------------------------------------- +void NeuronSpikeQueueUpdateGroupMerged::genMergedGroupSpikeCountReset(EnvironmentExternalBase &env, unsigned int batchSize) const +{ + if(getArchetype().isSpikeEventRequired()) { + if(getArchetype().isDelayRequired()) { + env.getStream() << env["_spk_cnt_evnt"] << "[*" << env["_spk_que_ptr"]; + if(batchSize > 1) { + env.getStream() << " + (batch * " << getArchetype().getNumDelaySlots() << ")"; + } + env.getStream() << "] = 0; " << std::endl; + } + else { + env.getStream() << env["_spk_cnt_evnt"] << "[" << ((batchSize > 1) ? "batch" : "0") << "] = 0;" << std::endl; + } + } + + if(getArchetype().isTrueSpikeRequired() && getArchetype().isDelayRequired()) { + env.getStream() << env["_spk_cnt"] << "[*" << env["_spk_que_ptr"]; + if(batchSize > 1) { + env.getStream() << " + (batch * " << getArchetype().getNumDelaySlots() << ")"; + } + env.getStream() << "] = 0; " << std::endl; + } + else { + env.getStream() << env["_spk_cnt"] << "[" << ((batchSize > 1) ? 
"batch" : "0") << "] = 0;" << std::endl; + } +} + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::NeuronPrevSpikeTimeUpdateGroupMerged +//---------------------------------------------------------------------------- +const std::string NeuronPrevSpikeTimeUpdateGroupMerged::name = "NeuronPrevSpikeTimeUpdate"; + From 022fc118a27d1327d8b43a5e3962490e1f8bca8d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 29 Jun 2023 18:51:16 +0100 Subject: [PATCH 272/725] WIP moving around of old bits of group merged into saner structure --- .../genn/genn/code_generator/backendBase.h | 6 +- .../genn/genn/code_generator/groupMerged.h | 157 ++--------- .../genn/code_generator/initGroupMerged.h | 42 +-- .../code_generator/synapseUpdateGroupMerged.h | 122 +++++++-- src/genn/genn/code_generator/backendBase.cc | 2 +- src/genn/genn/code_generator/groupMerged.cc | 259 ------------------ .../genn/code_generator/initGroupMerged.cc | 28 +- .../code_generator/neuronUpdateGroupMerged.cc | 18 +- .../synapseUpdateGroupMerged.cc | 193 +++++++++++++ src/genn/genn/genn.vcxproj | 1 - 10 files changed, 336 insertions(+), 492 deletions(-) delete mode 100644 src/genn/genn/code_generator/groupMerged.cc diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index dc2a2ea440..19e6396883 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -39,9 +39,9 @@ template class EnvironmentGroupMergedField; class EnvironmentExternalBase; class ModelSpecMerged; +template +class GroupMerged; class NeuronUpdateGroupMerged; -class Substitutions; -class SynapseGroupMergedBase; class PresynapticUpdateGroupMerged; class PostsynapticUpdateGroupMerged; class SynapseDynamicsGroupMerged; @@ -453,7 +453,7 @@ class GENN_EXPORT BackendBase return isDeviceScalarRequired() ? getDeviceVarPrefix() : ("&" + getDeviceVarPrefix()); } - bool areSixtyFourBitSynapseIndicesRequired(const SynapseGroupMergedBase &sg) const; + bool areSixtyFourBitSynapseIndicesRequired(const GroupMerged &sg) const; //! Get backend-specific pointer size in bytes size_t getPointerBytes() const{ return m_PointerBytes; } diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 81b9b1f745..8e6502d018 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -102,16 +102,6 @@ class ChildGroupMerged //------------------------------------------------------------------------ // Protected API //------------------------------------------------------------------------ - //! Helper to test whether parameter is referenced in vector of codestrings - bool isParamReferenced(const std::vector &codeStrings, const std::string ¶mName) const - { - return std::any_of(codeStrings.begin(), codeStrings.end(), - [¶mName](const std::string &c) - { - return (c.find("$(" + paramName + ")") != std::string::npos); - }); - } - //! 
Helper to test whether parameter values are heterogeneous within merged group template bool isParamValueHeterogeneous(const std::string &name, P getParamValuesFn) const @@ -136,62 +126,53 @@ class ChildGroupMerged } } - template - void updateParamHash(R isParamReferencedFn, V getValueFn, boost::uuids::detail::sha1 &hash) const + template + void updateParamHash(V getValueFn, boost::uuids::detail::sha1 &hash) const { // Loop through parameters const auto &archetypeParams = getValueFn(getArchetype()); for(const auto &p : archetypeParams) { - // If any of the code strings reference the parameter - if((static_cast(this)->*isParamReferencedFn)(p.first)) { - // Loop through groups - for(const auto &g : getGroups()) { - // Update hash with parameter value - Utils::updateHash(getValueFn(g.get()).at(p.first), hash); - } + // Loop through groups + for(const auto &g : getGroups()) { + // Update hash with parameter value + Utils::updateHash(getValueFn(g.get()).at(p.first), hash); } } } - template - void updateVarInitParamHash(R isParamReferencedFn, boost::uuids::detail::sha1 &hash) const + template + void updateVarInitParamHash(boost::uuids::detail::sha1 &hash) const { // Loop through variables const auto &archetypeVarInitialisers = A(getArchetype()).getInitialisers(); for(const auto &varInit : archetypeVarInitialisers) { // Loop through parameters for(const auto &p : varInit.second.getParams()) { - // If any of the code strings reference the parameter - if((static_cast(this)->*isParamReferencedFn)(varInit.first, p.first)) { - // Loop through groups - for(const auto &g : getGroups()) { - const auto &values = A(g.get()).getInitialisers().at(varInit.first).getParams(); - - // Update hash with parameter value - Utils::updateHash(values.at(p.first), hash); - } + // Loop through groups + for(const auto &g : getGroups()) { + const auto &values = A(g.get()).getInitialisers().at(varInit.first).getParams(); + + // Update hash with parameter value + Utils::updateHash(values.at(p.first), hash); } } } } - template - void updateVarInitDerivedParamHash(R isDerivedParamReferencedFn, boost::uuids::detail::sha1 &hash) const + template + void updateVarInitDerivedParamHash(boost::uuids::detail::sha1 &hash) const { // Loop through variables const auto &archetypeVarInitialisers = A(getArchetype()).getInitialisers(); for(const auto &varInit : archetypeVarInitialisers) { // Loop through parameters for(const auto &d : varInit.second.getDerivedParams()) { - // If any of the code strings reference the parameter - if((static_cast(this)->*isDerivedParamReferencedFn)(varInit.first, d.first)) { - // Loop through groups - for(const auto &g : getGroups()) { - const auto &values = A(g.get()).getInitialisers().at(varInit.first).getDerivedParams(); - - // Update hash with parameter value - Utils::updateHash(values.at(d.first), hash); - } + // Loop through groups + for(const auto &g : getGroups()) { + const auto &values = A(g.get()).getInitialisers().at(varInit.first).getDerivedParams(); + + // Update hash with parameter value + Utils::updateHash(values.at(d.first), hash); } } } @@ -493,98 +474,4 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged -{ -public: - //------------------------------------------------------------------------ - // Public API - //------------------------------------------------------------------------ - //! Should the weight update model parameter be implemented heterogeneously? - bool isWUParamHeterogeneous(const std::string ¶mName) const; - - //! 
Should the weight update model derived parameter be implemented heterogeneously? - bool isWUDerivedParamHeterogeneous(const std::string &paramName) const; - - //! Should the weight update model variable initialization parameter be implemented heterogeneously? - bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const; - - //! Should the weight update model variable initialization derived parameter be implemented heterogeneously? - bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const; - - //! Should the sparse connectivity initialization parameter be implemented heterogeneously? - bool isSparseConnectivityInitParamHeterogeneous(const std::string &paramName) const; - - //! Should the sparse connectivity initialization parameter be implemented heterogeneously? - bool isSparseConnectivityInitDerivedParamHeterogeneous(const std::string &paramName) const; - - //! Should the Toeplitz connectivity initialization parameter be implemented heterogeneously? - bool isToeplitzConnectivityInitParamHeterogeneous(const std::string &paramName) const; - - //! Should the Toeplitz connectivity initialization parameter be implemented heterogeneously? - bool isToeplitzConnectivityInitDerivedParamHeterogeneous(const std::string &paramName) const; - - std::string getPreSlot(unsigned int batchSize) const; - std::string getPostSlot(unsigned int batchSize) const; - - std::string getPreVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const - { - return getPreVarIndex(getArchetype().getSrcNeuronGroup()->isDelayRequired(), batchSize, varDuplication, index); - } - - std::string getPostVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const - { - return getPostVarIndex(getArchetype().getTrgNeuronGroup()->isDelayRequired(), batchSize, varDuplication, index); - } - - std::string getPreWUVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const - { - return getPreVarIndex(getArchetype().getDelaySteps() != 0, batchSize, varDuplication, index); - } - - std::string getPostWUVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const - { - return getPostVarIndex(getArchetype().getBackPropDelaySteps() != 0, batchSize, varDuplication, index); - } - - std::string getPostDenDelayIndex(unsigned int batchSize, const std::string &index, const std::string &offset) const; - - std::string getPreVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; - std::string getPostVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; - - std::string getPrePrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; - std::string getPostPrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; - - std::string getPostISynIndex(unsigned int batchSize, const std::string &index) const - { - return ((batchSize == 1) ? "" : "$(_post_batch_offset) + ") + index; - } - - std::string getPreISynIndex(unsigned int batchSize, const std::string &index) const - { - return ((batchSize == 1) ? 
"" : "$(pre_batch_offset) + ") + index; - } - - std::string getSynVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; - std::string getKernelVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; - -protected: - using GroupMerged::GroupMerged; - - //---------------------------------------------------------------------------- - // Protected methods - //---------------------------------------------------------------------------- - boost::uuids::detail::sha1::digest_type getHashDigest(Role role) const; - - -private: - //------------------------------------------------------------------------ - // Private methods - //------------------------------------------------------------------------ - std::string getVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, - const std::string &index, const std::string &prefix) const; -}; } // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index 051505d865..c27b973963 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -20,23 +20,21 @@ class GENN_EXPORT InitGroupMergedBase : public B //! Should the var init parameter be implemented heterogeneously? bool isVarInitParamHeterogeneous(const std::string &varName, const std::string ¶mName) const { - return (isVarInitParamReferenced(varName, paramName) && - this->isParamValueHeterogeneous(paramName, - [&varName](const auto &g) - { - return A(g).getInitialisers().at(varName).getParams(); - })); + return this->isParamValueHeterogeneous(paramName, + [&varName](const auto &g) + { + return A(g).getInitialisers().at(varName).getParams(); + }); } //! Should the var init derived parameter be implemented heterogeneously? bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string ¶mName) const { - return (isVarInitParamReferenced(varName, paramName) && - this->isParamValueHeterogeneous(paramName, - [&varName](const auto &g) - { - return A(g).getInitialisers().at(varName).getDerivedParams(); - })); + return this->isParamValueHeterogeneous(paramName, + [&varName](const auto &g) + { + return A(g).getInitialisers().at(varName).getDerivedParams(); + }); } protected: //---------------------------------------------------------------------------- @@ -45,22 +43,9 @@ class GENN_EXPORT InitGroupMergedBase : public B void updateBaseHash(boost::uuids::detail::sha1 &hash) const { // Update hash with each group's variable initialisation parameters and derived parameters - this->template updateVarInitParamHash, A>( - &InitGroupMergedBase::isVarInitParamHeterogeneous, hash); + this->template updateVarInitParamHash(hash); - this->template updateVarInitDerivedParamHash, A>( - &InitGroupMergedBase::isVarInitDerivedParamHeterogeneous, hash); - } - -private: - //---------------------------------------------------------------------------- - // Private methods - //---------------------------------------------------------------------------- - //! Is the var init parameter referenced? 
- bool isVarInitParamReferenced(const std::string &varName, const std::string &paramName) const - { - const auto *varInitSnippet = A(this->getArchetype()).getInitialisers().at(varName).getSnippet(); - return this->isParamReferenced({varInitSnippet->getCode()}, paramName); + this->template updateVarInitDerivedParamHash<A>(hash); } }; @@ -364,9 +349,6 @@ class GENN_EXPORT SynapseConnectivityHostInitGroupMerged : public GroupMerged { public: - PresynapticUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, - const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups) - : SynapseGroupMergedBase(index, typeContext, SynapseGroupMergedBase::Role::PresynapticUpdate, - groups.front().get().getWUModel()->getSimCode() + groups.front().get().getWUModel()->getEventCode() + groups.front().get().getWUModel()->getEventThresholdConditionCode(), groups) - {} + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + //! Should the weight update model parameter be implemented heterogeneously? + bool isWUParamHeterogeneous(const std::string &paramName) const; + + //! Should the weight update model derived parameter be implemented heterogeneously? + bool isWUDerivedParamHeterogeneous(const std::string &paramName) const; + + //! Should the weight update model variable initialization parameter be implemented heterogeneously? + bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + + //! Should the weight update model variable initialization derived parameter be implemented heterogeneously? + bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const; + + //! Should the sparse connectivity initialization parameter be implemented heterogeneously? + bool isSparseConnectivityInitParamHeterogeneous(const std::string &paramName) const; + + //! Should the sparse connectivity initialization parameter be implemented heterogeneously? + bool isSparseConnectivityInitDerivedParamHeterogeneous(const std::string &paramName) const; + + //! Should the Toeplitz connectivity initialization parameter be implemented heterogeneously? + bool isToeplitzConnectivityInitParamHeterogeneous(const std::string &paramName) const; - boost::uuids::detail::sha1::digest_type getHashDigest() const + //! Should the Toeplitz connectivity initialization parameter be implemented heterogeneously? 
+ bool isToeplitzConnectivityInitDerivedParamHeterogeneous(const std::string &paramName) const; + + std::string getPreSlot(unsigned int batchSize) const; + std::string getPostSlot(unsigned int batchSize) const; + + std::string getPreVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const { - return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::PresynapticUpdate); + return getPreVarIndex(getArchetype().getSrcNeuronGroup()->isDelayRequired(), batchSize, varDuplication, index); + } + + std::string getPostVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const + { + return getPostVarIndex(getArchetype().getTrgNeuronGroup()->isDelayRequired(), batchSize, varDuplication, index); } + std::string getPreWUVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const + { + return getPreVarIndex(getArchetype().getDelaySteps() != 0, batchSize, varDuplication, index); + } + + std::string getPostWUVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const + { + return getPostVarIndex(getArchetype().getBackPropDelaySteps() != 0, batchSize, varDuplication, index); + } + + std::string getPostDenDelayIndex(unsigned int batchSize, const std::string &index, const std::string &offset) const; + + std::string getPreVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; + std::string getPostVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; + + std::string getPrePrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; + std::string getPostPrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; + + std::string getPostISynIndex(unsigned int batchSize, const std::string &index) const + { + return ((batchSize == 1) ? "" : "$(_post_batch_offset) + ") + index; + } + + std::string getPreISynIndex(unsigned int batchSize, const std::string &index) const + { + return ((batchSize == 1) ? 
"" : "$(pre_batch_offset) + ") + index; + } + + std::string getSynVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; + std::string getKernelVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; + + boost::uuids::detail::sha1::digest_type getHashDigest() const; + +protected: + using GroupMerged::GroupMerged; + + +private: + //------------------------------------------------------------------------ + // Private methods + //------------------------------------------------------------------------ + std::string getVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, + const std::string &index, const std::string &prefix) const; +}; + +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::PresynapticUpdateGroupMerged +//---------------------------------------------------------------------------- +class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase +{ +public: + using SynapseGroupMergedBase::SynapseGroupMergedBase; + void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, @@ -49,16 +131,7 @@ class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT PostsynapticUpdateGroupMerged : public SynapseGroupMergedBase { public: - PostsynapticUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext, - const std::vector> &groups) - : SynapseGroupMergedBase(index, typeContext, SynapseGroupMergedBase::Role::PostsynapticUpdate, - groups.front().get().getWUModel()->getLearnPostCode(), groups) - {} - - boost::uuids::detail::sha1::digest_type getHashDigest() const - { - return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::PostsynapticUpdate); - } + using SynapseGroupMergedBase::SynapseGroupMergedBase; void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, @@ -83,16 +156,7 @@ class GENN_EXPORT PostsynapticUpdateGroupMerged : public SynapseGroupMergedBase class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase { public: - SynapseDynamicsGroupMerged(size_t index, const Type::TypeContext &typeContext, - const std::vector> &groups) - : SynapseGroupMergedBase(index, typeContext, SynapseGroupMergedBase::Role::SynapseDynamics, - groups.front().get().getWUModel()->getSynapseDynamicsCode(), groups) - {} - - boost::uuids::detail::sha1::digest_type getHashDigest() const - { - return SynapseGroupMergedBase::getHashDigest(SynapseGroupMergedBase::Role::SynapseDynamics); - } + using SynapseGroupMergedBase::SynapseGroupMergedBase; void generateRunner(const BackendBase &backend, CodeStream &definitionsInternal, CodeStream &definitionsInternalFunc, diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 3fb2fcabc3..92edc89b7b 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -20,7 +20,7 @@ BackendBase::BackendBase(const PreferencesBase &preferences) { } //-------------------------------------------------------------------------- -bool BackendBase::areSixtyFourBitSynapseIndicesRequired(const SynapseGroupMergedBase &sg) const +bool BackendBase::areSixtyFourBitSynapseIndicesRequired(const GroupMerged &sg) const { // Loop through merged groups 
and calculate maximum number of synapses size_t maxSynapses = 0; diff --git a/src/genn/genn/code_generator/groupMerged.cc b/src/genn/genn/code_generator/groupMerged.cc deleted file mode 100644 index 0b0614da3a..0000000000 --- a/src/genn/genn/code_generator/groupMerged.cc +++ /dev/null @@ -1,259 +0,0 @@ -#include "code_generator/groupMerged.h" - -// PLOG includes -#include <plog/Log.h> - -// GeNN includes -#include "modelSpecInternal.h" - -// GeNN code generator includes -#include "code_generator/backendBase.h" -#include "code_generator/codeGenUtils.h" -#include "code_generator/codeStream.h" -#include "code_generator/environment.h" - -using namespace GeNN; -using namespace GeNN::CodeGenerator; - -//---------------------------------------------------------------------------- -// GeNN::CodeGenerator::SynapseGroupMergedBase -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isWUParamHeterogeneous(const std::string &paramName) const -{ - return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUParams(); }); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isWUDerivedParamHeterogeneous(const std::string &paramName) const -{ - return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUDerivedParams(); }); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const -{ - return isParamValueHeterogeneous(paramName, [varName](const SynapseGroupInternal &sg){ return sg.getWUVarInitialisers().at(varName).getParams(); }); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const -{ - return isParamValueHeterogeneous(paramName, [varName](const SynapseGroupInternal &sg) { return sg.getWUVarInitialisers().at(varName).getDerivedParams(); }); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isSparseConnectivityInitParamHeterogeneous(const std::string &paramName) const -{ - return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getParams(); }); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isSparseConnectivityInitDerivedParamHeterogeneous(const std::string &paramName) const -{ - return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getDerivedParams(); }); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isToeplitzConnectivityInitParamHeterogeneous(const std::string &paramName) const -{ - return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getToeplitzConnectivityInitialiser().getParams(); }); -} -//---------------------------------------------------------------------------- -bool SynapseGroupMergedBase::isToeplitzConnectivityInitDerivedParamHeterogeneous(const std::string &paramName) const -{ - return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getToeplitzConnectivityInitialiser().getDerivedParams(); }); -}
-//---------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPreSlot(unsigned int batchSize) const -{ - if(getArchetype().getSrcNeuronGroup()->isDelayRequired()) { - return (batchSize == 1) ? "$(_pre_delay_slot)" : "$(_pre_batch_delay_slot)"; - } - else { - return (batchSize == 1) ? "0" : "$(batch)"; - } -} -//---------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPostSlot(unsigned int batchSize) const -{ - if(getArchetype().getTrgNeuronGroup()->isDelayRequired()) { - return (batchSize == 1) ? "$(_post_delay_slot)" : "$(_post_batch_delay_slot)"; - } - else { - return (batchSize == 1) ? "0" : "$(batch)"; - } -} -//---------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPostDenDelayIndex(unsigned int batchSize, const std::string &index, const std::string &offset) const -{ - assert(getArchetype().isDendriticDelayRequired()); - - const std::string batchID = ((batchSize == 1) ? "" : "$(_post_batch_offset) + ") + std::string{"$(" + index + ")"}; - - if(offset.empty()) { - return "(*$(_den_delay_ptr) * $(num_post) + " + batchID; - } - else { - return "(((*(_den_delay_ptr) + " + offset + ") % " + std::to_string(getArchetype().getMaxDendriticDelayTimesteps()) + ") * $(num_post)) + " + batchID; - } -} -//---------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPreVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const -{ - return getVarIndex(delay, batchSize, varDuplication, index, "pre"); -} -//-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPostVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const -{ - return getVarIndex(delay, batchSize, varDuplication, index, "post"); -} -//-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPrePrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const -{ - const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); - - if(delay) { - return (singleBatch ? "$(_pre_prev_spike_time_delay_offset) + " : "$(_pre_prev_spike_time_batch_delay_offset) + ") + index; - } - else { - return (singleBatch ? "" : "$(_pre_batch_offset) + ") + std::string{"$(" + index + ")"}; - } -} -//-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getPostPrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const -{ - const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); - - if(delay) { - return (singleBatch ? "$(_post_prev_spike_time_delay_offset) + " : "$(_post_prev_spike_time_batch_delay_offset) + ") + std::string{"$(" + index + ")"}; - } - else { - return (singleBatch ? 
"" : "$(_post_batch_offset) + ") + std::string{"$(" + index + ")"}; - } -} -//-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getSynVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const -{ - const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); - return (singleBatch ? "" : "$(_syn_batch_offset)") + std::string{"$(" + index + ")"}; -} -//-------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getKernelVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const -{ - const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); - return (singleBatch ? "" : "$(_kern_batch_offset)") + std::string{"$(" + index + ")"}; -} -//---------------------------------------------------------------------------- -boost::uuids::detail::sha1::digest_type SynapseGroupMergedBase::getHashDigest(Role role) const -{ - teHash([](const SynapseGroupInternal &g) { return g.getMaxSourceConnections(); }, hash); - - if(updateRole) { - // Update hash with weight update model parameters and derived parameters - updateHash([](const SynapseGroupInternal &g) { return g.getWUParams(); }, hash); - updateHash([](const SynapseGroupInternal &g) { return g.getWUDerivedParams(); }, hash); - - // Update hash with presynaptic neuron population parameters and derived parameters - updateParamHash( - &SynapseGroupMergedBase::isSrcNeuronParamReferenced, - [](const SynapseGroupInternal &g) { return g.getSrcNeuronGroup()->getParams(); }, hash); - - updateParamHash( - &SynapseGroupMergedBase::isSrcNeuronParamReferenced, - [](const SynapseGroupInternal &g) { return g.getSrcNeuronGroup()->getDerivedParams(); }, hash); - - // Update hash with postsynaptic neuron population parameters and derived parameters - updateParamHash( - &SynapseGroupMergedBase::isTrgNeuronParamReferenced, - [](const SynapseGroupInternal &g) { return g.getTrgNeuronGroup()->getParams(); }, hash); - - updateParamHash( - &SynapseGroupMergedBase::isTrgNeuronParamReferenced, - [](const SynapseGroupInternal &g) { return g.getTrgNeuronGroup()->getDerivedParams(); }, hash); - } - - - // If we're updating a hash for a group with procedural connectivity or initialising connectivity - if((getArchetype().getMatrixType() & SynapseMatrixConnectivity::PROCEDURAL) || (role == Role::ConnectivityInit)) { - // Update hash with connectivity parameters and derived parameters - updateParamHash( - &SynapseGroupMergedBase::isSparseConnectivityInitParamReferenced, - [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getParams(); }, hash); - - updateParamHash( - &SynapseGroupMergedBase::isSparseConnectivityInitParamReferenced, - [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getDerivedParams(); }, hash); - } - - // If we're updating a hash for a group with Toeplitz connectivity - if((getArchetype().getMatrixType() & SynapseMatrixConnectivity::TOEPLITZ) && updateRole) { - // Update hash with connectivity parameters and derived parameters - updateParamHash( - &SynapseGroupMergedBase::isToeplitzConnectivityInitParamReferenced, - [](const SynapseGroupInternal &sg) { return sg.getToeplitzConnectivityInitialiser().getParams(); }, hash); - - updateParamHash( - &SynapseGroupMergedBase::isToeplitzConnectivityInitParamReferenced, - [](const SynapseGroupInternal &sg) { return 
sg.getToeplitzConnectivityInitialiser().getDerivedParams(); }, hash); - } - - if(getArchetype().getMatrixType() & SynapseMatrixWeight::GLOBAL) { - // If this is an update role - // **NOTE **global variable values aren't useful during initialization - if(updateRole) { - updateParamHash( - &SynapseGroupMergedBase::isWUGlobalVarReferenced, - [](const SynapseGroupInternal &sg) { return sg.getWUConstInitVals(); }, hash); - } - } - // Otherwise (weights are individual or procedural) - else { - const bool connectInitRole = (role == Role::ConnectivityInit); - const bool varInitRole = (role == Role::Init || role == Role::SparseInit); - const bool proceduralWeights = (getArchetype().getMatrixType() & SynapseMatrixWeight::PROCEDURAL); - const bool individualWeights = (getArchetype().getMatrixType() & SynapseMatrixWeight::INDIVIDUAL); - const bool kernelWeights = (getArchetype().getMatrixType() & SynapseMatrixWeight::INDIVIDUAL); - - // If synapse group has a kernel and we're either updating with procedural - // weights or initialising individual weights, update hash with kernel size - if(!getArchetype().getKernelSize().empty() && - ((proceduralWeights && updateRole) || (connectInitRole && individualWeights) || (kernelWeights && !updateRole))) - { - updateHash([](const SynapseGroupInternal &g) { return g.getKernelSize(); }, hash); - } - - // If weights are procedural, we're initializing individual variables or we're initialising variables in a kernel - // **NOTE** some of these won't actually be required - could do this per-variable in loop over vars - if((proceduralWeights && updateRole) || (connectInitRole && !getArchetype().getKernelSize().empty()) - || (varInitRole && individualWeights) || (varInitRole && kernelWeights)) - { - // Update hash with each group's variable initialisation parameters and derived parameters - updateVarInitParamHash( - &SynapseGroupMergedBase::isWUVarInitParamReferenced, hash); - updateVarInitDerivedParamHash( - &SynapseGroupMergedBase::isWUVarInitParamReferenced, hash); - } - } - return hash.get_digest(); -} -//---------------------------------------------------------------------------- -std::string SynapseGroupMergedBase::getVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, - const std::string &index, const std::string &prefix) const -{ - if (delay) { - if (varDuplication == VarAccessDuplication::SHARED_NEURON) { - return prefix + ((batchSize == 1) ? "$(_delay_slot)" : "$(_batch_delay_slot)"); - } - else if (varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { - return prefix + "$(_delay_offset) + " + index; - } - else { - return prefix + "$(_batch_delay_offset) + " + index; - } - } - else { - if (varDuplication == VarAccessDuplication::SHARED_NEURON) { - return (batchSize == 1) ? 
"0" : "batch"; - } - else if (varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { - return index; - } - else { - return prefix + "$(_batch_offset) + " + index; - } - } -} \ No newline at end of file diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 59c8046a78..f3a6d45a0f 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -637,22 +637,15 @@ boost::uuids::detail::sha1::digest_type SynapseConnectivityInitGroupMerged::getH updateHash([](const SynapseGroupInternal &g) { return g.getMaxSourceConnections(); }, hash); // Update hash with connectivity parameters and derived parameters - updateParamHash( - &SynapseGroupMergedBase::isConnectivityInitParamReferenced, - [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getParams(); }, hash); - - updateParamHash( - &SynapseGroupMergedBase::isConnectivityInitParamReferenced, - [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getDerivedParams(); }, hash); + updateParamHash([](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getParams(); }, hash); + updateParamHash([](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getDerivedParams(); }, hash); if(!getArchetype().getKernelSize().empty()) { updateHash([](const SynapseGroupInternal &g) { return g.getKernelSize(); }, hash); // Update hash with each group's variable initialisation parameters and derived parameters - updateVarInitParamHash( - &SynapseGroupMergedBase::isWUVarInitParamReferenced, hash); - updateVarInitDerivedParamHash( - &SynapseGroupMergedBase::isWUVarInitParamReferenced, hash); + updateVarInitParamHash(hash); + updateVarInitDerivedParamHash(hash); } return hash.get_digest(); } @@ -844,21 +837,12 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac //---------------------------------------------------------------------------- bool SynapseConnectivityHostInitGroupMerged::isConnectivityInitParamHeterogeneous(const std::string ¶mName) const { - return (isSparseConnectivityInitParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg){ return sg.getConnectivityInitialiser().getParams(); })); + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg){ return sg.getConnectivityInitialiser().getParams(); }); } //---------------------------------------------------------------------------- bool SynapseConnectivityHostInitGroupMerged::isConnectivityInitDerivedParamHeterogeneous(const std::string ¶mName) const { - return (isSparseConnectivityInitParamReferenced(paramName) && - isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getDerivedParams(); })); -} -//---------------------------------------------------------------------------- -bool SynapseConnectivityHostInitGroupMerged::isSparseConnectivityInitParamReferenced(const std::string ¶mName) const -{ - // If parameter isn't referenced in code, there's no point implementing it hetereogeneously! 
- const auto *connectInitSnippet = getArchetype().getConnectivityInitialiser().getSnippet(); - return isParamReferenced({connectInitSnippet->getHostInitCode()}, paramName); + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getDerivedParams(); }); } // ---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index a291ee5bad..7d0b6aba52 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -53,10 +53,8 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::CurrentSource::updateHash(boost::uuids::detail::sha1 &hash) const { - updateParamHash(&CurrentSource::isParamReferenced, - [](const CurrentSourceInternal &g) { return g.getParams(); }, hash); - updateParamHash(&CurrentSource::isParamReferenced, - [](const CurrentSourceInternal &g) { return g.getDerivedParams(); }, hash); + updateParamHash([](const CurrentSourceInternal &g) { return g.getParams(); }, hash); + updateParamHash([](const CurrentSourceInternal &g) { return g.getDerivedParams(); }, hash); } //---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::CurrentSource::isParamHeterogeneous(const std::string &paramName) const @@ -141,10 +139,8 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &hash) const { - updateParamHash(&InSynPSM::isParamReferenced, - [](const SynapseGroupInternal &g) { return g.getPSParams(); }, hash); - updateParamHash(&InSynPSM::isParamReferenced, - [](const SynapseGroupInternal &g) { return g.getPSDerivedParams(); }, hash); + updateParamHash([](const SynapseGroupInternal &g) { return g.getPSParams(); }, hash); + updateParamHash([](const SynapseGroupInternal &g) { return g.getPSDerivedParams(); }, hash); } //---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::InSynPSM::isParamHeterogeneous(const std::string &paramName) const @@ -252,10 +248,8 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentEx //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynWUMPostCode::updateHash(boost::uuids::detail::sha1 &hash) const { - updateParamHash(&InSynWUMPostCode::isParamReferenced, - [](const SynapseGroupInternal &g) { return g.getWUParams(); }, hash); - updateParamHash(&InSynWUMPostCode::isParamReferenced, - [](const SynapseGroupInternal &g) { return g.getWUDerivedParams(); }, hash); + updateParamHash([](const SynapseGroupInternal &g) { return g.getWUParams(); }, hash); + updateParamHash([](const SynapseGroupInternal &g) { return g.getWUDerivedParams(); }, hash); } //---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::InSynWUMPostCode::isParamHeterogeneous(const std::string &paramName) const diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 2174078de6..fabf7b61c4 100644 --- 
a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -165,6 +165,199 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa } } // Anonymous namespace +//---------------------------------------------------------------------------- +// GeNN::CodeGenerator::SynapseGroupMergedBase +//---------------------------------------------------------------------------- +bool SynapseGroupMergedBase::isWUParamHeterogeneous(const std::string &paramName) const +{ + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUParams(); }); +} +//---------------------------------------------------------------------------- +bool SynapseGroupMergedBase::isWUDerivedParamHeterogeneous(const std::string &paramName) const +{ + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getWUDerivedParams(); }); +} +//---------------------------------------------------------------------------- +bool SynapseGroupMergedBase::isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const +{ + return isParamValueHeterogeneous(paramName, [varName](const SynapseGroupInternal &sg){ return sg.getWUVarInitialisers().at(varName).getParams(); }); +} +//---------------------------------------------------------------------------- +bool SynapseGroupMergedBase::isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const +{ + return isParamValueHeterogeneous(paramName, [varName](const SynapseGroupInternal &sg) { return sg.getWUVarInitialisers().at(varName).getDerivedParams(); }); +} +//---------------------------------------------------------------------------- +bool SynapseGroupMergedBase::isSparseConnectivityInitParamHeterogeneous(const std::string &paramName) const +{ + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getParams(); }); +} +//---------------------------------------------------------------------------- +bool SynapseGroupMergedBase::isSparseConnectivityInitDerivedParamHeterogeneous(const std::string &paramName) const +{ + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getDerivedParams(); }); +} +//---------------------------------------------------------------------------- +bool SynapseGroupMergedBase::isToeplitzConnectivityInitParamHeterogeneous(const std::string &paramName) const +{ + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getToeplitzConnectivityInitialiser().getParams(); }); +} +//---------------------------------------------------------------------------- +bool SynapseGroupMergedBase::isToeplitzConnectivityInitDerivedParamHeterogeneous(const std::string &paramName) const +{ + return isParamValueHeterogeneous(paramName, [](const SynapseGroupInternal &sg) { return sg.getToeplitzConnectivityInitialiser().getDerivedParams(); }); +} +//---------------------------------------------------------------------------- +std::string SynapseGroupMergedBase::getPreSlot(unsigned int batchSize) const +{ + if(getArchetype().getSrcNeuronGroup()->isDelayRequired()) { + return (batchSize == 1) ? "$(_pre_delay_slot)" : "$(_pre_batch_delay_slot)"; + } + else { + return (batchSize == 1) ? 
"0" : "$(batch)"; + } +} +//---------------------------------------------------------------------------- +std::string SynapseGroupMergedBase::getPostSlot(unsigned int batchSize) const +{ + if(getArchetype().getTrgNeuronGroup()->isDelayRequired()) { + return (batchSize == 1) ? "$(_post_delay_slot)" : "$(_post_batch_delay_slot)"; + } + else { + return (batchSize == 1) ? "0" : "$(batch)"; + } +} +//---------------------------------------------------------------------------- +std::string SynapseGroupMergedBase::getPostDenDelayIndex(unsigned int batchSize, const std::string &index, const std::string &offset) const +{ + assert(getArchetype().isDendriticDelayRequired()); + + const std::string batchID = ((batchSize == 1) ? "" : "$(_post_batch_offset) + ") + std::string{"$(" + index + ")"}; + + if(offset.empty()) { + return "(*$(_den_delay_ptr) * $(num_post) + " + batchID; + } + else { + return "(((*(_den_delay_ptr) + " + offset + ") % " + std::to_string(getArchetype().getMaxDendriticDelayTimesteps()) + ") * $(num_post)) + " + batchID; + } +} +//---------------------------------------------------------------------------- +std::string SynapseGroupMergedBase::getPreVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const +{ + return getVarIndex(delay, batchSize, varDuplication, index, "pre"); +} +//-------------------------------------------------------------------------- +std::string SynapseGroupMergedBase::getPostVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const +{ + return getVarIndex(delay, batchSize, varDuplication, index, "post"); +} +//-------------------------------------------------------------------------- +std::string SynapseGroupMergedBase::getPrePrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const +{ + const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); + + if(delay) { + return (singleBatch ? "$(_pre_prev_spike_time_delay_offset) + " : "$(_pre_prev_spike_time_batch_delay_offset) + ") + index; + } + else { + return (singleBatch ? "" : "$(_pre_batch_offset) + ") + std::string{"$(" + index + ")"}; + } +} +//-------------------------------------------------------------------------- +std::string SynapseGroupMergedBase::getPostPrevSpikeTimeIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const +{ + const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); + + if(delay) { + return (singleBatch ? "$(_post_prev_spike_time_delay_offset) + " : "$(_post_prev_spike_time_batch_delay_offset) + ") + std::string{"$(" + index + ")"}; + } + else { + return (singleBatch ? "" : "$(_post_batch_offset) + ") + std::string{"$(" + index + ")"}; + } +} +//-------------------------------------------------------------------------- +std::string SynapseGroupMergedBase::getSynVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const +{ + const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); + return (singleBatch ? 
"" : "$(_syn_batch_offset)") + std::string{"$(" + index + ")"}; +} +//-------------------------------------------------------------------------- +std::string SynapseGroupMergedBase::getKernelVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const +{ + const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); + return (singleBatch ? "" : "$(_kern_batch_offset)") + std::string{"$(" + index + ")"}; +} +//---------------------------------------------------------------------------- +std::string SynapseGroupMergedBase::getVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, + const std::string &index, const std::string &prefix) const +{ + if (delay) { + if (varDuplication == VarAccessDuplication::SHARED_NEURON) { + return prefix + ((batchSize == 1) ? "$(_delay_slot)" : "$(_batch_delay_slot)"); + } + else if (varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { + return prefix + "$(_delay_offset) + " + index; + } + else { + return prefix + "$(_batch_delay_offset) + " + index; + } + } + else { + if (varDuplication == VarAccessDuplication::SHARED_NEURON) { + return (batchSize == 1) ? "0" : "batch"; + } + else if (varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { + return index; + } + else { + return prefix + "$(_batch_offset) + " + index; + } + } +} +//---------------------------------------------------------------------------- +boost::uuids::detail::sha1::digest_type SynapseGroupMergedBase::getHashDigest() const +{ + boost::uuids::detail::sha1 hash; + + // Update hash with number of neurons in pre and postsynaptic population + updateHash([](const SynapseGroupInternal &g) { return g.getSrcNeuronGroup()->getNumNeurons(); }, hash); + updateHash([](const SynapseGroupInternal &g) { return g.getTrgNeuronGroup()->getNumNeurons(); }, hash); + updateHash([](const SynapseGroupInternal &g) { return g.getMaxConnections(); }, hash); + updateHash([](const SynapseGroupInternal &g) { return g.getMaxSourceConnections(); }, hash); + + // Update hash with weight update model parameters and derived parameters + updateHash([](const SynapseGroupInternal &g) { return g.getWUParams(); }, hash); + updateHash([](const SynapseGroupInternal &g) { return g.getWUDerivedParams(); }, hash); + + // If we're updating a hash for a group with procedural connectivity or initialising connectivity + if(getArchetype().getMatrixType() & SynapseMatrixConnectivity::PROCEDURAL) { + updateParamHash([](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getParams(); }, hash); + updateParamHash([](const SynapseGroupInternal &sg) { return sg.getConnectivityInitialiser().getDerivedParams(); }, hash); + } + + // If we're updating a hash for a group with Toeplitz connectivity + if(getArchetype().getMatrixType() & SynapseMatrixConnectivity::TOEPLITZ) { + // Update hash with connectivity parameters and derived parameters + updateParamHash([](const SynapseGroupInternal &sg) { return sg.getToeplitzConnectivityInitialiser().getParams(); }, hash); + + updateParamHash([](const SynapseGroupInternal &sg) { return sg.getToeplitzConnectivityInitialiser().getDerivedParams(); }, hash); + } + + // If weights are procedural + if(getArchetype().getMatrixType() & SynapseMatrixWeight::PROCEDURAL) { + // If synapse group has a kernel, update hash with kernel size + if(!getArchetype().getKernelSize().empty()) { + updateHash([](const SynapseGroupInternal &g) { return g.getKernelSize(); }, hash); + } + + // 
Update hash with each group's variable initialisation parameters and derived parameters + updateVarInitParamHash<SynapseWUVarAdapter>(hash); + updateVarInitDerivedParamHash<SynapseWUVarAdapter>(hash); + } + + return hash.get_digest(); +} + //---------------------------------------------------------------------------- // GeNN::CodeGenerator::PresynapticUpdateGroupMerged //---------------------------------------------------------------------------- diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj index 5eaf2a5734..d0cb8d900e 100644 --- a/src/genn/genn/genn.vcxproj +++ b/src/genn/genn/genn.vcxproj @@ -31,7 +31,6 @@ -    <ClCompile Include="code_generator\groupMerged.cc" /> From d20e565d8723bc8a992abbf73e8ff7e57185ac62 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 30 Jun 2023 10:33:48 +0100 Subject: [PATCH 273/725] std::monostate needs to be hashable --- include/genn/genn/gennUtils.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/include/genn/genn/gennUtils.h b/include/genn/genn/gennUtils.h index 60f95e9b69..d2b6cd0763 100644 --- a/include/genn/genn/gennUtils.h +++ b/include/genn/genn/gennUtils.h @@ -138,6 +138,11 @@ inline void updateHash(const T& value, boost::uuids::detail::sha1& hash) hash.process_bytes(&value, sizeof(T)); } +//! Hash monostate +inline void updateHash(std::monostate, boost::uuids::detail::sha1&) +{ +} + //! Hash strings inline void updateHash(const std::string &string, boost::uuids::detail::sha1 &hash) { From 109b06ba9e93d291740040856935d20cf9cc88e1 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 30 Jun 2023 10:49:08 +0100 Subject: [PATCH 274/725] fixed final compile errors --- include/genn/genn/code_generator/backendBase.h | 4 ++-- .../genn/genn/code_generator/initGroupMerged.h | 9 +++++++++ .../genn/code_generator/initGroupMerged.cc | 18 ++++++++++++++++++ .../code_generator/neuronUpdateGroupMerged.cc | 6 ++---- .../presynapticUpdateStrategySIMT.cc | 1 - 5 files changed, 31 insertions(+), 7 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 19e6396883..b99cb5d875 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -315,8 +315,8 @@ class GENN_EXPORT BackendBase virtual void genVariableInit(EnvironmentExternalBase &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const = 0; virtual void genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const = 0; virtual void genDenseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const = 0; - virtual void genKernelSynapseVariableInit(EnvironmentExternalBase &env, const SynapseInitGroupMerged &sg, HandlerEnv handler) const = 0; - virtual void genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, const CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const = 0; + virtual void genKernelSynapseVariableInit(EnvironmentExternalBase &env, SynapseInitGroupMerged &sg, HandlerEnv handler) const = 0; + virtual void genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const = 0; //! Generate a single RNG instance /*! On single-threaded platforms this can be a standard RNG like M.T. 
but, on parallel platforms, it is likely to be a counter-based RNG */ diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index c27b973963..0fbf554743 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -280,6 +280,9 @@ class GENN_EXPORT SynapseConnectivityInitGroupMerged : public GroupMerged(&OutSynWUMPreCode::isParamReferenced, - [](const SynapseGroupInternal &g) { return g.getWUParams(); }, hash); - updateParamHash(&OutSynWUMPreCode::isParamReferenced, - [](const SynapseGroupInternal &g) { return g.getWUDerivedParams(); }, hash); + updateParamHash([](const SynapseGroupInternal &g) { return g.getWUParams(); }, hash); + updateParamHash([](const SynapseGroupInternal &g) { return g.getWUDerivedParams(); }, hash); } //---------------------------------------------------------------------------- bool NeuronUpdateGroupMerged::OutSynWUMPreCode::isParamHeterogeneous(const std::string &paramName) const diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index c841175481..10a20d219c 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -12,7 +12,6 @@ #include "code_generator/codeStream.h" #include "code_generator/groupMerged.h" #include "code_generator/modelSpecMerged.h" -#include "code_generator/substitutions.h" //---------------------------------------------------------------------------- // Anonymous namespace From 612180b91b1af855c873d64ef50e09b7ed272117 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 30 Jun 2023 11:08:55 +0100 Subject: [PATCH 275/725] fixed linker error --- include/genn/genn/code_generator/generateModules.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/genn/genn/code_generator/generateModules.h b/include/genn/genn/code_generator/generateModules.h index c9ec3bfb61..0498fecab8 100644 --- a/include/genn/genn/code_generator/generateModules.h +++ b/include/genn/genn/code_generator/generateModules.h @@ -30,15 +30,15 @@ GENN_EXPORT std::pair<std::vector<std::string>, MemAlloc> generateAll(const Mode const filesystem::path &sharePath, const filesystem::path &outputPath, bool forceRebuild = false); -GENN_EXPORT void generateNeuronUpdate(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, +GENN_EXPORT void generateNeuronUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, const std::string &suffix = ""); -GENN_EXPORT void generateCustomUpdate(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, +GENN_EXPORT void generateCustomUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, const std::string &suffix = ""); -GENN_EXPORT void generateSynapseUpdate(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, +GENN_EXPORT void generateSynapseUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, const std::string &suffix = ""); -GENN_EXPORT void generateInit(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, +GENN_EXPORT void generateInit(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, const std::string &suffix = ""); } From 674fb19bfadbf98fbc8bc15b0b9d9de82ab33304 Mon Sep 17 00:00:00 2001 From: 
neworderofjamie Date: Fri, 30 Jun 2023 11:09:03 +0100 Subject: [PATCH 276/725] removed GLOBALG --- include/genn/genn/synapseMatrixType.h | 12 ++++-------- .../code_generator/presynapticUpdateStrategySIMT.cc | 4 ++-- src/genn/genn/synapseGroup.cc | 13 +++++++++---- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/include/genn/genn/synapseMatrixType.h b/include/genn/genn/synapseMatrixType.h index 720000d70e..c6bd0d5a42 100644 --- a/include/genn/genn/synapseMatrixType.h +++ b/include/genn/genn/synapseMatrixType.h @@ -18,7 +18,6 @@ enum class SynapseMatrixConnectivity : unsigned int //! Flags defining different types of synaptic matrix connectivity enum class SynapseMatrixWeight : unsigned int { - GLOBAL = (1 << 5), INDIVIDUAL = (1 << 6), PROCEDURAL = (1 << 7), KERNEL = (1 << 8) @@ -27,16 +26,13 @@ enum class SynapseMatrixWeight : unsigned int //! Supported combinations of SynapticMatrixConnectivity and SynapticMatrixWeight enum class SynapseMatrixType : unsigned int { - DENSE_GLOBALG = static_cast<unsigned int>(SynapseMatrixConnectivity::DENSE) | static_cast<unsigned int>(SynapseMatrixWeight::GLOBAL), - DENSE_INDIVIDUALG = static_cast<unsigned int>(SynapseMatrixConnectivity::DENSE) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL), + DENSE = static_cast<unsigned int>(SynapseMatrixConnectivity::DENSE) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL), DENSE_PROCEDURALG = static_cast<unsigned int>(SynapseMatrixConnectivity::DENSE) | static_cast<unsigned int>(SynapseMatrixWeight::PROCEDURAL), - BITMASK_GLOBALG = static_cast<unsigned int>(SynapseMatrixConnectivity::BITMASK) | static_cast<unsigned int>(SynapseMatrixWeight::GLOBAL), - SPARSE_GLOBALG = static_cast<unsigned int>(SynapseMatrixConnectivity::SPARSE) | static_cast<unsigned int>(SynapseMatrixWeight::GLOBAL), - SPARSE_INDIVIDUALG = static_cast<unsigned int>(SynapseMatrixConnectivity::SPARSE) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL), - PROCEDURAL_GLOBALG = static_cast<unsigned int>(SynapseMatrixConnectivity::PROCEDURAL) | static_cast<unsigned int>(SynapseMatrixWeight::GLOBAL), + BITMASK = static_cast<unsigned int>(SynapseMatrixConnectivity::BITMASK) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL), + SPARSE = static_cast<unsigned int>(SynapseMatrixConnectivity::SPARSE) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL), PROCEDURAL_PROCEDURALG = static_cast<unsigned int>(SynapseMatrixConnectivity::PROCEDURAL) | static_cast<unsigned int>(SynapseMatrixWeight::PROCEDURAL), PROCEDURAL_KERNELG = static_cast<unsigned int>(SynapseMatrixConnectivity::PROCEDURAL) | static_cast<unsigned int>(SynapseMatrixWeight::KERNEL), - TOEPLITZ_KERNELG = static_cast<unsigned int>(SynapseMatrixConnectivity::TOEPLITZ) | static_cast<unsigned int>(SynapseMatrixWeight::KERNEL), + TOEPLITZ = static_cast<unsigned int>(SynapseMatrixConnectivity::TOEPLITZ) | static_cast<unsigned int>(SynapseMatrixWeight::KERNEL), }; //---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index 10a20d219c..4e2c615498 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -441,10 +441,10 @@ size_t PreSpanProcedural::getSynapticMatrixRowStride(const SynapseGroupInternal bool PreSpanProcedural::isCompatible(const SynapseGroupInternal &sg, const PreferencesBase &) const { // Presynaptic procedural parallelism can be used when synapse groups have - // procedural connectivity and weights are either GLOBAL, PROCEDURAL or KERNEL + // procedural connectivity and there are either no variables or variables are PROCEDURAL or KERNEL const auto matrixType = sg.getMatrixType(); return ((matrixType & SynapseMatrixConnectivity::PROCEDURAL) && ((matrixType & 
SynapseMatrixWeight::GLOBAL) || (matrixType & SynapseMatrixWeight::PROCEDURAL) + && (sg.getWUModel()->getVars().empty() || (matrixType & SynapseMatrixWeight::PROCEDURAL) || (matrixType & SynapseMatrixWeight::KERNEL))); } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc index 960c1a6ea9..ef93d63d15 100644 --- a/src/genn/genn/synapseGroup.cc +++ b/src/genn/genn/synapseGroup.cc @@ -563,10 +563,15 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType } // If connectivity initialisation snippet defines a kernel and matrix type doesn't support it, give error - if(!m_KernelSize.empty() && (m_MatrixType != SynapseMatrixType::PROCEDURAL_PROCEDURALG) && (m_MatrixType != SynapseMatrixType::TOEPLITZ_KERNELG) && (m_MatrixType != SynapseMatrixType::SPARSE_INDIVIDUALG) && (m_MatrixType != SynapseMatrixType::PROCEDURAL_KERNELG)) + if(!m_KernelSize.empty() && (m_MatrixType != SynapseMatrixType::PROCEDURAL_PROCEDURALG) && (m_MatrixType != SynapseMatrixType::TOEPLITZ) + && (m_MatrixType != SynapseMatrixType::SPARSE) && (m_MatrixType != SynapseMatrixType::PROCEDURAL_KERNELG)) { - throw std::runtime_error("Connectivity initialisation snippet which use a kernel can only be used with PROCEDURAL_PROCEDURALG, PROCEDURAL_KERNELG, TOEPLITZ_KERNELG or SPARSE_INDIVIDUALG connectivity."); + throw std::runtime_error("Connectivity initialisation snippets which use a kernel can only be used with PROCEDURAL_PROCEDURALG, PROCEDURAL_KERNELG, TOEPLITZ or SPARSE connectivity."); + } + + // Check BITMASK connectivity isn't used with models with variables + if((m_MatrixType & SynapseMatrixConnectivity::BITMASK) && !m_WUModel->getVars().empty()) { + throw std::runtime_error("BITMASK connectivity can only be used with weight update models without variables like StaticPulseConstantWeight."); } // If connectivity is dense and there is connectivity initialiser code, give error @@ -578,7 +583,7 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType // If synapse group uses sparse or procedural connectivity but no kernel size is provided, // check that no variable's initialisation snippets require a kernel - if(((m_MatrixType == SynapseMatrixType::SPARSE_INDIVIDUALG) || (m_MatrixType == SynapseMatrixType::PROCEDURAL_PROCEDURALG)) && + if(((m_MatrixType == SynapseMatrixType::SPARSE) || (m_MatrixType == SynapseMatrixType::PROCEDURAL_PROCEDURALG)) && m_KernelSize.empty() && std::any_of(getWUVarInitialisers().cbegin(), getWUVarInitialisers().cend(), [](const auto &v) { return v.second.getSnippet()->requiresKernel(); })) { From a214327621eebf185f771cc3c0ae79208e1f65a5 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 30 Jun 2023 11:18:16 +0100 Subject: [PATCH 277/725] missing CodeStream constructor --- include/genn/genn/code_generator/environment.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index dc554b308d..d901a0d020 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -233,17 +233,17 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P public: template<typename... PolicyArgs> EnvironmentExternalDynamicBase(EnvironmentExternalBase &enclosing, PolicyArgs&&... policyArgs) - : EnvironmentExternalBase(enclosing), P(std::forward<PolicyArgs>(policyArgs)...) 
+ : EnvironmentExternalBase(enclosing), P(std::forward<PolicyArgs>(policyArgs)...), m_Contents(m_ContentsStream) {} template<typename... PolicyArgs> EnvironmentExternalDynamicBase(Transpiler::PrettyPrinter::EnvironmentBase &enclosing, PolicyArgs&&... policyArgs) - : EnvironmentExternalBase(enclosing), P(std::forward<PolicyArgs>(policyArgs)...) + : EnvironmentExternalBase(enclosing), P(std::forward<PolicyArgs>(policyArgs)...), m_Contents(m_ContentsStream) {} template<typename... PolicyArgs> EnvironmentExternalDynamicBase(CodeStream &os, PolicyArgs&&... policyArgs) - : EnvironmentExternalBase(os), P(std::forward<PolicyArgs>(policyArgs)...) + : EnvironmentExternalBase(os), P(std::forward<PolicyArgs>(policyArgs)...), m_Contents(m_ContentsStream) {} ~EnvironmentExternalDynamicBase() @@ -276,7 +276,7 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P } } - virtual CodeStream &getStream() final { return m_Contents; } + virtual CodeStream &getStream() final { return m_Contents; } //------------------------------------------------------------------------ // TypeChecker::EnvironmentBase virtuals From 06568d15dd4059b6c3e12fdee2aa66f1e12ef437 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 30 Jun 2023 11:20:43 +0100 Subject: [PATCH 278/725] removed assert now that evaluation is lazy --- src/genn/genn/code_generator/synapseUpdateGroupMerged.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index fabf7b61c4..854f9badaa 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -230,8 +230,6 @@ std::string SynapseGroupMergedBase::getPostSlot(unsigned int batchSize) const //---------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getPostDenDelayIndex(unsigned int batchSize, const std::string &index, const std::string &offset) const { - assert(getArchetype().isDendriticDelayRequired()); - const std::string batchID = ((batchSize == 1) ? "" : "$(_post_batch_offset) + ") + std::string{"$(" + index + ")"}; if(offset.empty()) { From cdbf04b394e3387cd4809d3d3a6a7cccb6803503 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 30 Jun 2023 11:24:18 +0100 Subject: [PATCH 279/725] started updating WUMs --- include/genn/genn/weightUpdateModels.h | 47 ++++++++++++++++++++------ src/genn/genn/weightUpdateModels.cc | 1 + 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/include/genn/genn/weightUpdateModels.h b/include/genn/genn/weightUpdateModels.h index 41b8e53bc6..b394472608 100644 --- a/include/genn/genn/weightUpdateModels.h +++ b/include/genn/genn/weightUpdateModels.h @@ -176,7 +176,32 @@ class StaticPulse : public Base SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}}); - SET_SIM_CODE("$(addToInSyn, $(g));\n"); + SET_SIM_CODE("addToPost(g);\n"); +}; + +//---------------------------------------------------------------------------- +// GeNN::WeightUpdateModels::StaticPulseConstantWeight +//---------------------------------------------------------------------------- +//! Pulse-coupled, static synapse. +/*! No learning rule is applied to the synapse and for each pre-synaptic spike, + the synaptic conductances are simply added to the postsynaptic input variable. + The model has 1 parameter: + - g - conductance + and no other variables.
+ + \c sim code is: + + \code + "addToPost(g);" + \endcode*/ +class StaticPulseConstantWeight : public Base +{ +public: + DECLARE_SNIPPET(StaticPulseConstantWeight); + + SET_PARAM_NAMES({"g"}); + + SET_SIM_CODE("addToPost(g);\n"); +}; + +//---------------------------------------------------------------------------- // GeNN::WeightUpdateModels::StaticPulseDendriticDelay //---------------------------------------------------------------------------- @@ -202,7 +227,7 @@ class StaticPulseDendriticDelay : public Base SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}, {"d", "uint8_t", VarAccess::READ_ONLY}}); - SET_SIM_CODE("$(addToInSynDelay, $(g), $(d));\n"); + SET_SIM_CODE("addToPostDelay(g, d);\n"); }; //---------------------------------------------------------------------------- @@ -239,9 +264,9 @@ class StaticGraded : public Base SET_PARAM_NAMES({"Epre", "Vslope"}); SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}}); - SET_EVENT_CODE("$(addToInSyn, fmax(0.0, $(g) * tanh(($(V_pre) - $(Epre)) / $(Vslope))* DT));\n"); + SET_EVENT_CODE("addToPost(fmax(0.0, g * tanh((V_pre - Epre) / Vslope) * DT));\n"); - SET_EVENT_THRESHOLD_CONDITION_CODE("$(V_pre) > $(Epre)"); + SET_EVENT_THRESHOLD_CONDITION_CODE("V_pre > Epre"); }; //---------------------------------------------------------------------------- @@ -311,15 +336,15 @@ class PiecewiseSTDP : public Base SET_VARS({{"g", "scalar"}, {"gRaw", "scalar"}}); SET_SIM_CODE( - "$(addToInSyn, $(g));\n" - "scalar dt = $(sT_post) - $(t) - ($(tauShift)); \n" + "addToPost(g);\n" + "scalar dt = sT_post - t - tauShift; \n" "scalar dg = 0;\n" - "if (dt > $(lim0)) \n" - " dg = -($(off0)) ; \n" + "if (dt > lim0) \n" + " dg = -off0 ; \n" "else if (dt > 0) \n" - " dg = $(slope0) * dt + ($(off1)); \n" - "else if (dt > $(lim1)) \n" - " dg = $(slope1) * dt + ($(off1)); \n" + " dg = slope0 * dt + off1; \n" + "else if (dt > lim1) \n" + " dg = slope1 * dt + off1; \n" "else dg = - ($(off2)) ; \n" "$(gRaw) += dg; \n" "$(g)=$(gMax)/2 *(tanh($(gSlope)*($(gRaw) - ($(gMid))))+1); \n"); diff --git a/src/genn/genn/weightUpdateModels.cc b/src/genn/genn/weightUpdateModels.cc index 619fec2a67..7280e232bf 100644 --- a/src/genn/genn/weightUpdateModels.cc +++ b/src/genn/genn/weightUpdateModels.cc @@ -5,6 +5,7 @@ using namespace GeNN; namespace GeNN::WeightUpdateModels { IMPLEMENT_SNIPPET(StaticPulse); +IMPLEMENT_SNIPPET(StaticPulseConstantWeight); IMPLEMENT_SNIPPET(StaticPulseDendriticDelay); IMPLEMENT_SNIPPET(StaticGraded); IMPLEMENT_SNIPPET(PiecewiseSTDP); From cb82fa2dc99d23e8fcbe7cbfec7521e398ce5f6f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 30 Jun 2023 11:46:29 +0100 Subject: [PATCH 280/725] fixed typo --- src/genn/backends/single_threaded_cpu/backend.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 6a80cb7ba3..15f0542591 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1860,7 +1860,7 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda // Add correct functions for applying synaptic input groupEnv.add(Type::AddToPostDenDelay, "addToPostDelay", "$(_den_delay)[" + sg.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)"); groupEnv.add(Type::AddToPost, "addToPost", "$(_out_post)[" + sg.getPostISynIndex(1, "j") + "] += $(0)"); - groupEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + sg.getPreISynIndex(1, env["id_pre"]) + "] += $(0)"); + groupEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + sg.getPreISynIndex(1, "$(id_pre)") + "] += $(0)"); // If connectivity is sparse
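Under the new syntax introduced in [PATCH 279/725], a user-defined weight update model might be sketched as follows; the macros are the ones used in the header above, while MyStaticPulse and its scale parameter are purely illustrative:

// Hypothetical user model using the new function-call syntax shown above:
// variables and parameters are referenced by bare name, and synaptic input is
// delivered through addToPost()/addToPostDelay() rather than $(addToInSyn, ...)
class MyStaticPulse : public GeNN::WeightUpdateModels::Base
{
public:
    DECLARE_SNIPPET(MyStaticPulse);

    SET_PARAM_NAMES({"scale"});
    SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}});

    // The old style would have been "$(addToInSyn, $(scale) * $(g));"
    SET_SIM_CODE("addToPost(scale * g);\n");
};
IMPLEMENT_SNIPPET(MyStaticPulse);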
if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { From f20ccf1e46305b416c876ed2c20c0112024ad0ff Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 30 Jun 2023 11:46:44 +0100 Subject: [PATCH 281/725] added connectivity to genSynapseIndexCalculation --- include/genn/genn/code_generator/backendBase.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index b99cb5d875..b5c4050b87 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -600,6 +600,22 @@ class GENN_EXPORT BackendBase [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getTrgNeuronGroup()->getName(); }); env.addField(Type::Uint32.createPointer(), "_trg_spk", "trgSpk", [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getTrgNeuronGroup()->getName(); }); + + // Connectivity fields + if(env.getGroup().getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { + env.addField(Type::Uint32.createPointer(), "_gp", "gp", + [this](const auto &sg, size_t) { return getDeviceVarPrefix() + "gp" + sg.getName(); }); + } + else if(env.getGroup().getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { + env.addField(Type::Uint32.createPointer(), "_row_length", "rowLength", + [this](const auto &sg, size_t) { return getDeviceVarPrefix() + "rowLength" + sg.getName(); }); + env.addField(env.getGroup().getArchetype().getSparseIndType().createPointer(), "_ind", "ind", + [this](const auto &sg, size_t) { return getDeviceVarPrefix() + "ind" + sg.getName(); }); + env.addField(Type::Uint32.createPointer(), "_col_length", "colLength", + [this](const auto &sg, size_t) { return getDeviceVarPrefix() + "colLength" + sg.getName(); }); + env.addField(Type::Uint32.createPointer(), "_remap", "remap", + [this](const auto &sg, size_t) { return getDeviceVarPrefix() + "remap" + sg.getName(); }); + } // If batching is enabled if(batchSize > 1) { From b82fb6c91d92bee5b51bd925473278f8fac57bea Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 30 Jun 2023 11:59:31 +0100 Subject: [PATCH 282/725] lazy string doesn't treat $(0) style parameter placeholders as things to be lazily evaluated - pretty printer handles these --- src/genn/genn/code_generator/lazyString.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/genn/genn/code_generator/lazyString.cc b/src/genn/genn/code_generator/lazyString.cc index 1ed631aad5..d40ff44727 100644 --- a/src/genn/genn/code_generator/lazyString.cc +++ b/src/genn/genn/code_generator/lazyString.cc @@ -18,7 +18,8 @@ using namespace GeNN::CodeGenerator; LazyString::LazyString(const std::string &format, EnvironmentExternalBase &env) { // Create regex iterator to iterate over $(XXX) style variables in format string - std::regex regex("\\$\\(([\\w]+)\\)"); + // **NOTE** this doesn't match function argument $(0) + std::regex regex("\\$\\(([a-zA-Z_][\\w]+)\\)"); std::sregex_iterator matchesBegin(format.cbegin(), format.cend(), regex); std::sregex_iterator matchesEnd; From 105796ca2595269f1506c5d589be11c4fe0431a2 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 30 Jun 2023 11:59:49 +0100 Subject: [PATCH 283/725] inSyn and Isyn shouldn't be const in postsynaptic model code --- src/genn/genn/code_generator/neuronUpdateGroupMerged.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git
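The regex behaviour that [PATCH 282/725] relies on, matching named identifiers such as $(id_pre) while skipping numeric argument placeholders such as $(0), can be checked with a small standalone program:

#include <iostream>
#include <regex>
#include <string>

int main()
{
    // Same pattern as the commit above: the first character inside $(...) must
    // be a letter or underscore, so $(0) and $(1) are left for the pretty printer
    const std::regex regex("\\$\\(([a-zA-Z_][\\w]+)\\)");

    const std::string format = "$(_out_pre)[$(id_pre)] += $(0)";
    for(std::sregex_iterator m(format.cbegin(), format.cend(), regex), end; m != end; ++m) {
        std::cout << (*m)[1] << std::endl;  // prints "_out_pre" then "id_pre"
    }
    return 0;
}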
a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 93de42417b..422b8dd913 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -113,10 +113,10 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env psmEnv.addExtraGlobalParams(psm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", fieldSuffix); // **TODO** naming convention - psmEnv.add(modelMerged.getModel().getPrecision().addConst(), "inSyn", "linSyn"); + psmEnv.add(modelMerged.getModel().getPrecision(), "inSyn", "linSyn"); // Allow synapse group's PS output var to override what Isyn points to - psmEnv.add(modelMerged.getModel().getPrecision().addConst(), "Isyn", getArchetype().getPSTargetVar()); + psmEnv.add(modelMerged.getModel().getPrecision(), "Isyn", getArchetype().getPSTargetVar()); // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( From d5ce598568810c967a1b515848152111b889181c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 30 Jun 2023 12:00:06 +0100 Subject: [PATCH 284/725] call genNeuronIndexCalculation in correct place --- src/genn/backends/single_threaded_cpu/backend.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 15f0542591..83fd220455 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -233,6 +233,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Get reference to group funcEnv.getStream() << "const auto *group = &mergedNeuronUpdateGroup" << n.getIndex() << "[g]; " << std::endl; EnvironmentGroupMergedField groupEnv(funcEnv, n); + genNeuronIndexCalculation(groupEnv, 1); // If spike or spike-like event recording is in use if(n.getArchetype().isSpikeRecordingEnabled() || n.getArchetype().isSpikeEventRecordingEnabled()) { @@ -250,7 +251,6 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host } } - genNeuronIndexCalculation(groupEnv, 1); groupEnv.getStream() << std::endl; groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["num_neurons"] << "; i++)"; From 21299639591d13d6aaf2483d008f577c14c0ea18 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 30 Jun 2023 12:13:43 +0100 Subject: [PATCH 285/725] WIP more fixes --- include/genn/genn/code_generator/codeGenUtils.h | 8 ++++++++ src/genn/backends/single_threaded_cpu/backend.cc | 9 ++++++++- src/genn/genn/code_generator/codeGenUtils.cc | 4 +++- src/genn/genn/code_generator/generateRunner.cc | 1 - src/genn/genn/code_generator/neuronUpdateGroupMerged.cc | 2 +- 5 files changed, 20 insertions(+), 4 deletions(-) diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index d7538bde18..2fe3b0126b 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -136,6 +136,14 @@ GENN_EXPORT void prettyPrintStatements(const std::string &code, const Type::Type GENN_EXPORT std::string printSubs(const std::string &format, EnvironmentExternalBase &env); + +template +inline std::string writePreciseLiteral(T value, const Type::ResolvedType &type) +{ + const auto &numeric = type.getNumeric(); + return writePreciseString(value, numeric.maxDigits10) + 
numeric.literalSuffix; +} + //------------------------------------------------------------------------- /*! \brief Function for performing the code and value substitutions necessary to insert neuron related variables, parameters, and extraGlobal parameters into synaptic code. diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 83fd220455..beb671356a 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -140,6 +140,8 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host EnvironmentExternal funcEnv(neuronUpdateEnv); funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); funcEnv.add(Type::Uint32.addConst(), "batch", "0"); + funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "dt", + writePreciseLiteral(modelMerged.getModel().getDT(), modelMerged.getModel().getTimePrecision())); Timer t(funcEnv.getStream(), "neuronUpdate", modelMerged.getModel().isTimingEnabled()); modelMerged.genMergedNeuronPrevSpikeTimeUpdateGroups( @@ -320,6 +322,8 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos EnvironmentExternal funcEnv(synapseUpdateEnv); funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); funcEnv.add(Type::Uint32.addConst(), "batch", "0"); + funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "dt", + writePreciseLiteral(modelMerged.getModel().getDT(), modelMerged.getModel().getTimePrecision())); // Synapse dynamics { @@ -564,7 +568,8 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host EnvironmentExternal funcEnv(customUpdateEnv); funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); funcEnv.add(Type::Uint32.addConst(), "batch", "0"); - + funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "dt", + writePreciseLiteral(modelMerged.getModel().getDT(), modelMerged.getModel().getTimePrecision())); // Loop through host update groups and generate code for those in this custom update group for (const auto &cg : modelMerged.getMergedCustomConnectivityHostUpdateGroups()) { @@ -850,6 +855,8 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler { CodeStream::Scope b(initEnv.getStream()); EnvironmentExternal funcEnv(initEnv); + funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "dt", + writePreciseLiteral(modelMerged.getModel().getDT(), modelMerged.getModel().getTimePrecision())); Timer t(funcEnv.getStream(), "init", model.isTimingEnabled()); diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index f3f55da3f9..0546ea4c95 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -483,7 +483,9 @@ std::string upgradeCodeString(const std::string &codeString) // **TODO** old style function call to standard C (these are ambiguous so need to be applied to existing genn functions) std::regex variable(R"(\$\(([_a-zA-Z][a-zA-Z0-9]*)\))"); - return std::regex_replace(codeString, variable, "$1"); + std::string upgraded = std::regex_replace(codeString, variable, "$1"); + + return upgraded; } //---------------------------------------------------------------------------- std::tuple scanParseAndTypeCheckStatements( diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 
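The round-trip-exact literal printing that writePreciseLiteral performs can be reproduced in isolation. This is a minimal sketch which assumes maxDigits10 mirrors std::numeric_limits<float>::max_digits10 and which only handles float:

#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>

// Minimal stand-in for writePreciseLiteral: print enough significant digits
// that parsing the literal back yields exactly the same float, then append
// the type's literal suffix
std::string preciseFloatLiteral(float value)
{
    std::ostringstream stream;
    stream << std::setprecision(std::numeric_limits<float>::max_digits10) << value << "f";
    return stream.str();
}

int main()
{
    // A DT of 0.1 is not exactly representable in binary floating point, so the
    // emitted literal shows the value the simulation will actually use
    std::cout << preciseFloatLiteral(0.1f) << std::endl;  // "0.100000001f"
    return 0;
}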
d7dd3f929e..b190594be4 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -560,7 +560,6 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // write DT macro const ModelSpecInternal &model = modelMerged.getModel(); - definitions << "#define DT " << Utils::writePreciseString(model.getDT()) << model.getTimePrecision().getNumeric().literalSuffix << std::endl; // Write ranges of scalar and time types genTypeRange(definitions, model.getPrecision(), "SCALAR"); diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 422b8dd913..a3f72b43d5 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -574,7 +574,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E neuronVarEnv.getStream() << "// calculate membrane potential" << std::endl; Transpiler::ErrorHandler errorHandler("Neuron sim code " + std::to_string(getIndex())); - prettyPrintExpression(nm->getSimCode(), getTypeContext(), neuronVarEnv, errorHandler); + prettyPrintStatements(nm->getSimCode(), getTypeContext(), neuronVarEnv, errorHandler); // Generate var update for outgoing synaptic populations with presynaptic update code for (auto &sg : m_MergedOutSynWUMPreCodeGroups) { From f9985993850620c58f5f4dae9256357394b8106e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 11:49:44 +0100 Subject: [PATCH 286/725] changed DT to dt --- include/genn/genn/neuronModels.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/genn/genn/neuronModels.h b/include/genn/genn/neuronModels.h index 9c44f4a810..e2d969fef4 100644 --- a/include/genn/genn/neuronModels.h +++ b/include/genn/genn/neuronModels.h @@ -222,7 +222,7 @@ class LIF : public Base " $(V) = alpha - ($(ExpTC) * (alpha - $(V)));\n" "}\n" "else {\n" - " $(RefracTime) -= DT;\n" + " $(RefracTime) -= dt;\n" "}\n" ); @@ -446,7 +446,7 @@ class TraubMiles : public Base SET_SIM_CODE( "scalar Imem;\n" "unsigned int mt;\n" - "scalar mdt= DT/25.0;\n" + "scalar mdt= dt/25.0;\n" "for (mt=0; mt < 25; mt++) {\n" " Imem= -($(m)*$(m)*$(m)*$(h)*$(gNa)*($(V)-($(ENa)))+\n" " $(n)*$(n)*$(n)*$(n)*$(gK)*($(V)-($(EK)))+\n" @@ -501,7 +501,7 @@ class TraubMilesFast : public TraubMiles SET_SIM_CODE( "scalar Imem;\n" "unsigned int mt;\n" - "scalar mdt= DT/25.0;\n" + "scalar mdt= dt/25.0;\n" "for (mt=0; mt < 25; mt++) {\n" " Imem= -($(m)*$(m)*$(m)*$(h)*$(gNa)*($(V)-($(ENa)))+\n" " $(n)*$(n)*$(n)*$(n)*$(gK)*($(V)-($(EK)))+\n" @@ -534,7 +534,7 @@ class TraubMilesAlt : public TraubMiles SET_SIM_CODE( "scalar Imem;\n" "unsigned int mt;\n" - "scalar mdt= DT/25.0;\n" + "scalar mdt= dt/25.0;\n" "for (mt=0; mt < 25; mt++) {\n" " Imem= -($(m)*$(m)*$(m)*$(h)*$(gNa)*($(V)-($(ENa)))+\n" " $(n)*$(n)*$(n)*$(n)*$(gK)*($(V)-($(EK)))+\n" From dc602e84589cbd82df5b304eb1f4b0de85fed82a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 11:50:13 +0100 Subject: [PATCH 287/725] use dt correctly in generated code --- src/genn/backends/single_threaded_cpu/backend.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index beb671356a..383ded3aa9 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -167,7 +167,7 @@ void 
Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host groupEnv.print("for(unsigned int i = 0; i < $(_spk_cnt)[$(_read_delay_slot)]; i++)"); { CodeStream::Scope b(groupEnv.getStream()); - groupEnv.printLine("$(_prev_spk_time)[$(_read_delay_offset) + $(_spk)[$(_read_delay_offset) + i]] = t - DT;"); + groupEnv.printLine("$(_prev_spk_time)[$(_read_delay_offset) + $(_spk)[$(_read_delay_offset) + i]] = t - $(dt);"); } } if(n.getArchetype().isPrevSpikeEventTimeRequired()) { @@ -175,7 +175,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host groupEnv.print("for(unsigned int i = 0; i < $(_spk_cnt_envt)[$(_read_delay_slot)]; i++)"); { CodeStream::Scope b(groupEnv.getStream()); - groupEnv.printLine("$(_prev_spk_evnt_time)[$(_read_delay_offset) + $(_spk_evnt)[$(_read_delay_offset) + i]] = t - DT;"); + groupEnv.printLine("$(_prev_spk_evnt_time)[$(_read_delay_offset) + $(_spk_evnt)[$(_read_delay_offset) + i]] = t - $(dt);"); } } } @@ -185,7 +185,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host groupEnv.print("for(unsigned int i = 0; i < $(_spk_cnt)[0]; i++)"); { CodeStream::Scope b(groupEnv.getStream()); - groupEnv.printLine("$(_prev_spk_time)[$(_spk)[i]] = t - DT;"); + groupEnv.printLine("$(_prev_spk_time)[$(_spk)[i]] = t - $(dt);"); } } if(n.getArchetype().isPrevSpikeEventTimeRequired()) { @@ -193,7 +193,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host groupEnv.print("for(unsigned int i = 0; i < $(_spk_cnt_evnt)[0]; i++)"); { CodeStream::Scope b(groupEnv.getStream()); - groupEnv.printLine("$(_prev_spk_evnt_time)[$(_spk_evnt)[i]] = t - DT;"); + groupEnv.printLine("$(_prev_spk_evnt_time)[$(_spk_evnt)[i]] = t - $(dt);"); } } } From 98e11a636ea72c297ffd582a6e3bc335cb0137c1 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 11:50:28 +0100 Subject: [PATCH 288/725] EnvironmentLibrary::getTypes wasn't searching context --- src/genn/genn/code_generator/environment.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index 1e68aa7dad..b9ae2d3828 100644 --- a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -96,8 +96,7 @@ std::vector EnvironmentLibrary::getTypes(const Transpiler::T { const auto [typeBegin, typeEnd] = m_Library.get().equal_range(name.lexeme); if (typeBegin == typeEnd) { - errorHandler.error(name, "Undefined identifier"); - throw TypeChecker::TypeCheckError(); + return getContextTypes(name, errorHandler); } else { std::vector types; From 027b7fd63d397ffba0e3e40b62f549c3b53a75b6 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 12:56:47 +0100 Subject: [PATCH 289/725] fixed lots of small bugs --- .../genn/genn/code_generator/backendBase.h | 58 +++++++++---------- .../genn/genn/code_generator/codeGenUtils.h | 2 +- .../backends/single_threaded_cpu/backend.cc | 31 +++++----- src/genn/genn/code_generator/codeGenUtils.cc | 29 ++++++++-- .../code_generator/customUpdateGroupMerged.cc | 2 +- .../genn/code_generator/initGroupMerged.cc | 36 +++++++----- .../code_generator/neuronUpdateGroupMerged.cc | 20 +++---- 7 files changed, 103 insertions(+), 75 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index b5c4050b87..d57235ae39 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ 
b/include/genn/genn/code_generator/backendBase.h @@ -455,35 +455,6 @@ class GENN_EXPORT BackendBase bool areSixtyFourBitSynapseIndicesRequired(const GroupMerged &sg) const; - //! Get backend-specific pointer size in bytes - size_t getPointerBytes() const{ return m_PointerBytes; } - - const PreferencesBase &getPreferences() const { return m_Preferences; } - - template - const T &getPreferences() const { return static_cast(m_Preferences); } - -protected: - //-------------------------------------------------------------------------- - // ReductionTarget - //-------------------------------------------------------------------------- - //! Simple struct to hold reduction targets - struct ReductionTarget - { - std::string name; - Type::ResolvedType type; - VarAccessMode access; - std::string index; - }; - - //-------------------------------------------------------------------------- - // Protected API - //-------------------------------------------------------------------------- - void setPointerBytes(size_t pointerBytes) - { - m_PointerBytes = pointerBytes; - } - template void genNeuronIndexCalculation(EnvironmentGroupMergedField &env, unsigned int batchSize) const { @@ -737,6 +708,35 @@ class GENN_EXPORT BackendBase void genCustomConnectivityUpdateIndexCalculation(EnvironmentGroupMergedField &env) const; + //! Get backend-specific pointer size in bytes + size_t getPointerBytes() const{ return m_PointerBytes; } + + const PreferencesBase &getPreferences() const { return m_Preferences; } + + template + const T &getPreferences() const { return static_cast(m_Preferences); } + +protected: + //-------------------------------------------------------------------------- + // ReductionTarget + //-------------------------------------------------------------------------- + //! Simple struct to hold reduction targets + struct ReductionTarget + { + std::string name; + Type::ResolvedType type; + VarAccessMode access; + std::string index; + }; + + //-------------------------------------------------------------------------- + // Protected API + //-------------------------------------------------------------------------- + void setPointerBytes(size_t pointerBytes) + { + m_PointerBytes = pointerBytes; + } + //! 
Get the initial value to start reduction operations from std::string getReductionInitialValue(VarAccessMode access, const Type::ResolvedType &type) const; diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 2fe3b0126b..b47f502f78 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -141,7 +141,7 @@ template inline std::string writePreciseLiteral(T value, const Type::ResolvedType &type) { const auto &numeric = type.getNumeric(); - return writePreciseString(value, numeric.maxDigits10) + numeric.literalSuffix; + return Utils::writePreciseString(value, numeric.maxDigits10) + numeric.literalSuffix; } //------------------------------------------------------------------------- diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 383ded3aa9..876f16ea16 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -23,18 +23,18 @@ const EnvironmentLibrary::Library cpuSinglePrecisionFunctions = { {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Float, {}), "standardUniformDistribution(hostRNG)"}}, {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Float, {}), "standardNormalDistribution(hostRNG)"}}, {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Float, {}), "standardExponentialDistribution(hostRNG)"}}, - {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Float, {}), "std::lognormal_distribution($(0), $(1))(hostRNG)"}}, - {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Float, {}), "std::gamma_distribution($(0), 1.0f)(hostRNG)"}}, - {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Float, {}), "std::binomial_distribution($(0), $(1))(hostRNG)"}}, + {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Float, {Type::Float, Type::Float}), "std::lognormal_distribution($(0), $(1))(hostRNG)"}}, + {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Float, {Type::Float}), "std::gamma_distribution($(0), 1.0f)(hostRNG)"}}, + {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Float}), "std::binomial_distribution($(0), $(1))(hostRNG)"}}, }; const EnvironmentLibrary::Library cpuDoublePrecisionFunctions = { - {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Float, {}), "standardUniformDistribution(hostRNG)"}}, - {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Float, {}), "standardNormalDistribution(hostRNG)"}}, - {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Float, {}), "standardExponentialDistribution(hostRNG)"}}, - {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Float, {}), "std::lognormal_distribution($(0), $(1))(hostRNG)"}}, - {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Float, {}), "std::gamma_distribution($(0), 1.0)(hostRNG)"}}, - {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Float, {}), "std::binomial_distribution($(0), $(1))(hostRNG)"}}, + {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Double, {}), "standardUniformDistribution(hostRNG)"}}, + {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Double, {}), "standardNormalDistribution(hostRNG)"}}, + {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Double, {}), "standardExponentialDistribution(hostRNG)"}}, + 
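Once $(0) and $(1) are spliced in, the host code generated from entries like gennrand_binomial above is equivalent to something like the following hypothetical snippet, where hostRNG stands in for the backend's global generator:

#include <iostream>
#include <random>

int main()
{
    std::mt19937 hostRNG;

    // What a generated gennrand_binomial(20, 0.25) call expands to: a temporary
    // distribution constructed and sampled in a single expression
    const unsigned int k = std::binomial_distribution<unsigned int>(20, 0.25)(hostRNG);

    // Likewise gennrand_gamma(2.0), with the scale parameter fixed at 1.0
    const double x = std::gamma_distribution<double>(2.0, 1.0)(hostRNG);

    std::cout << k << " " << x << std::endl;
    return 0;
}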
{"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Double, {Type::Double, Type::Double}), "std::lognormal_distribution($(0), $(1))(hostRNG)"}}, + {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Double, {Type::Double}), "std::gamma_distribution($(0), 1.0)(hostRNG)"}}, + {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Double}), "std::binomial_distribution($(0), $(1))(hostRNG)"}}, }; //-------------------------------------------------------------------------- @@ -848,8 +848,10 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler std::ostringstream initStream; CodeStream init(initStream); - // Begin environment with standard library - EnvironmentLibrary initEnv(init, StandardLibrary::getFunctions()); + // Begin environment with RNG library and standard library + EnvironmentLibrary rngEnv(init, (modelMerged.getModel().getPrecision() == Type::Float) ? cpuSinglePrecisionFunctions : cpuDoublePrecisionFunctions); + EnvironmentLibrary initEnv(rngEnv, StandardLibrary::getFunctions()); + initEnv.getStream() << "void initialize()"; { @@ -983,15 +985,16 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // Get reference to group funcEnv.getStream() << "const auto *group = &mergedSynapseConnectivityInitGroup" << s.getIndex() << "[g]; " << std::endl; EnvironmentGroupMergedField groupEnv(funcEnv, s); + genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); // If matrix connectivity is ragged if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // Zero row lengths - funcEnv.printLine("std::fill_n($(_row_length), $(num_pre), 0);"); + groupEnv.printLine("std::fill_n($(_row_length), $(num_pre), 0);"); } else if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { - funcEnv.printLine("const size_t gpSize = ((((size_t)$(num_pre) * (size_t)$(_row_stride)) + 32 - 1) / 32);"); - funcEnv.printLine("std::fill($(_gp), gpSize, 0);"); + groupEnv.printLine("const size_t gpSize = ((((size_t)$(num_pre) * (size_t)$(_row_stride)) + 32 - 1) / 32);"); + groupEnv.printLine("std::fill($(_gp), gpSize, 0);"); } else { throw std::runtime_error("Only BITMASK and SPARSE format connectivity can be generated using a connectivity initialiser"); diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index 0546ea4c95..0ba457ddb3 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -479,13 +479,32 @@ std::string disambiguateNamespaceFunction(const std::string supportCode, const s //---------------------------------------------------------------------------- std::string upgradeCodeString(const std::string &codeString) { + + // Build vector of regular expressions to replace old style function calls + const std::vector> functionReplacements{ + {std::regex(R"(\$\(gennrand_uniform\))"), "gennrand_uniform()"}, + {std::regex(R"(\$\(gennrand_normal\))"), "gennrand_normal()"}, + {std::regex(R"(\$\(gennrand_exponential\))"), "gennrand_exponential()"}, + {std::regex(R"(\$\(gennrand_log_normal,(.*)\))"), "gennrand_log_normal($1)"}, + {std::regex(R"(\$\(gennrand_gamma,(.*)\))"), "gennrand_gamma($1)"}, + {std::regex(R"(\$\(gennrand_binomial,(.*)\))"), "gennrand_binomial($1)"}, + {std::regex(R"(\$\(addSynapse,(.*)\))"), "addSynapse($1)"}, + {std::regex(R"(\$\(endRow\))"), "endRow()"}, + {std::regex(R"(\$\(endCol\))"), "endCol()"}}; + + // Apply 
sustitutions to upgraded code string + std::string upgradedCodeString = codeString; + for(const auto &f : functionReplacements) { + upgradedCodeString = std::regex_replace(upgradedCodeString, f.first, f.second); + } + // **TODO** snake-case -> camel case known built in variables e.g id_pre -> idPre - // **TODO** old style function call to standard C (these are ambiguous so need to be applied to existing genn functions) - std::regex variable(R"(\$\(([_a-zA-Z][a-zA-Z0-9]*)\))"); - - std::string upgraded = std::regex_replace(codeString, variable, "$1"); - return upgraded; + // Replace old style $(XX) variables with plain XX + // **NOTE** this is done after functions as single-parameter function calls and variables were indistinguishable with old syntax + const std::regex variable(R"(\$\(([_a-zA-Z][_a-zA-Z0-9]*)\))"); + upgradedCodeString = std::regex_replace(upgradedCodeString, variable, "$1"); + return upgradedCodeString; } //---------------------------------------------------------------------------- std::tuple scanParseAndTypeCheckStatements( diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 2084d69ded..3cba9602aa 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -259,7 +259,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdateBase(const BackendBase & // Create an environment which caches variable references in local variables if they are accessed EnvironmentLocalVarRefCache varRefEnv( - *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", + *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", [this, &varEnv](const std::string&, const Models::WUVarReference &v) { return getVarRefIndex(getVarAccessDuplication(v.getVar().access), diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 37b76bd290..b05e2001b7 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -219,7 +219,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir // Add field for InSyn and zero groupEnv.addField(getScalarType().createPointer(), "_out_post", "outPost", [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); - backend.genVariableInit(env, "num_neurons", "id", + backend.genVariableInit(groupEnv, "num_neurons", "id", [&modelMerged] (EnvironmentExternalBase &varEnv) { genVariableFill(varEnv, "_out_post", modelMerged.scalarExpr(0.0), @@ -233,7 +233,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir // Add field for dendritic delay buffer and zero groupEnv.addField(getScalarType().createPointer(), "_den_delay", "denDelay", [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); - backend.genVariableInit(env, "num_neurons", "id", + backend.genVariableInit(groupEnv, "num_neurons", "id", [&modelMerged, this](EnvironmentExternalBase &varEnv) { genVariableFill(varEnv, "_den_delay", modelMerged.scalarExpr(0.0), @@ -245,7 +245,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir // Add field for dendritic delay pointer and zero groupEnv.addField(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr", [&backend](const auto &g, size_t) { return 
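Applied to a legacy code string, the two-stage upgrade implemented in upgradeCodeString behaves like this self-contained reproduction; the input string is invented for illustration:

#include <iostream>
#include <regex>
#include <string>

int main()
{
    std::string code = "$(g) = $(gennrand_uniform) * $(gMax);";

    // Stage 1: old-style function calls, which would otherwise be
    // indistinguishable from variables once the $() is stripped
    code = std::regex_replace(code, std::regex(R"(\$\(gennrand_uniform\))"), "gennrand_uniform()");

    // Stage 2: remaining $(XX) tokens are plain variable references
    code = std::regex_replace(code, std::regex(R"(\$\(([_a-zA-Z][_a-zA-Z0-9]*)\))"), "$1");

    std::cout << code << std::endl;  // "g = gennrand_uniform() * gMax;"
    return 0;
}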
backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); - backend.genPopVariableInit(env, + backend.genPopVariableInit(groupEnv, [](EnvironmentExternalBase &varEnv) { varEnv.getStream() << "*" << varEnv["_den_delay_ptr"] << " = 0;" << std::endl; @@ -372,6 +372,7 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, Environment // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this); + backend.genNeuronIndexCalculation(groupEnv, model.getBatchSize()); // Initialise spike counts genInitSpikeCount(backend, groupEnv, false, model.getBatchSize()); @@ -406,31 +407,31 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, Environment // Add spike queue pointer field and zero groupEnv.addField(Type::Uint32.createPointer(), "_spk_que_ptr", "spkQuePtr", [&backend](const auto &g, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + g.getName(); }); - backend.genPopVariableInit(env, + backend.genPopVariableInit(groupEnv, [](EnvironmentExternalBase &varEnv) { - varEnv.getStream() << "*" << varEnv["_spk_que_ptr"] << " = 0;" << std::endl; + varEnv.printLine("*$(_spk_que_ptr) = 0;"); }); } // Initialise neuron variables - genInitNeuronVarCode(backend, env, *this, "", "num_neurons", 0, modelMerged.getModel().getBatchSize()); + genInitNeuronVarCode(backend, groupEnv, *this, "", "num_neurons", 0, modelMerged.getModel().getBatchSize()); // Generate initialisation code for child groups for (auto &cs : m_MergedCurrentSourceGroups) { - cs.generate(backend, env, *this, modelMerged); + cs.generate(backend, groupEnv, *this, modelMerged); } for(auto &sg : m_MergedInSynPSMGroups) { - sg.generate(backend, env, *this, modelMerged); + sg.generate(backend, groupEnv, *this, modelMerged); } for (auto &sg : m_MergedOutSynPreOutputGroups) { - sg.generate(backend, env, *this, modelMerged); + sg.generate(backend, groupEnv, *this, modelMerged); } for (auto &sg : m_MergedOutSynWUMPreVarGroups) { - sg.generate(backend, env, *this, modelMerged); + sg.generate(backend, groupEnv, *this, modelMerged); } for (auto &sg : m_MergedInSynWUMPostVarGroups) { - sg.generate(backend, env, *this, modelMerged); + sg.generate(backend, groupEnv, *this, modelMerged); } } //-------------------------------------------------------------------------- @@ -538,6 +539,7 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, Environmen { // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this); + backend.genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); // If model is batched and has kernel weights const bool kernel = (getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL); @@ -561,7 +563,7 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, Environmen // If we're using non-kernel weights, generate loop over source neurons if (!kernel) { - groupEnv.getStream() << "for(unsigned int i = 0; i < group->numSrcNeurons; i++)"; + groupEnv.print("for(unsigned int i = 0; i < $(num_pre); i++)"); groupEnv.getStream() << CodeStream::OB(1); groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); } @@ -611,7 +613,10 @@ boost::uuids::detail::sha1::digest_type SynapseSparseInitGroupMerged::getHashDig //---------------------------------------------------------------------------- void SynapseSparseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { - genInitWUVarCode(backend, env, *this, "$(num_pre) * $(_row_stride)", 
modelMerged.getModel().getBatchSize(), + // Create environment for group + EnvironmentGroupMergedField groupEnv(env, *this); + backend.genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); + genInitWUVarCode(backend, groupEnv, *this, "$(num_pre) * $(_row_stride)", modelMerged.getModel().getBatchSize(), [&backend](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { backend.genSparseSynapseVariableRowInit(varInitEnv, handler); @@ -731,13 +736,14 @@ void SynapseConnectivityInitGroupMerged::genInitConnectivity(const BackendBase & const auto stateVars = rowNotColumns ? snippet->getRowBuildStateVars() : snippet->getColBuildStateVars(); const std::string context = rowNotColumns ? "row" : "column"; for(const auto &a : stateVars) { - - groupEnv.getStream() << a.type.resolve(getTypeContext()).getName() << " " << a.name << " = "; + const auto resolvedType = a.type.resolve(getTypeContext()); + groupEnv.getStream() << resolvedType.getName() << " _" << a.name << " = "; Transpiler::ErrorHandler errorHandler("Connectivity init " + context + " build state var" + std::to_string(getIndex())); prettyPrintExpression(a.value, getTypeContext(), groupEnv, errorHandler); groupEnv.getStream() << ";" << std::endl; + groupEnv.add(resolvedType, a.name, "_" + a.name); } groupEnv.getStream() << "while(true)"; { diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index a3f72b43d5..43ad86016f 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -40,7 +40,7 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( - *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, + *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [&modelMerged, &ng](const std::string&, VarAccessDuplication d) { return ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, "id"); @@ -120,7 +120,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( - *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, + *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [&modelMerged, &ng](const std::string&, VarAccessDuplication d) { return ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, "id"); @@ -202,7 +202,7 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back // Create an environment which caches variables in local variables if they are accessed const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varEnv( - *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, + *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) { return ng.getReadVarIndex(delayed, batchSize, d, "id"); @@ -288,7 +288,7 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back // Create an environment which caches variables in local variables if they are accessed const bool delayed = (getArchetype().getDelaySteps() != 
NO_DELAY); EnvironmentLocalVarCache varEnv( - *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), "l", fieldSuffix, + *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [batchSize, delayed, &ng](const std::string&, VarAccessDuplication d) { return ng.getReadVarIndex(delayed, batchSize, d, "id"); @@ -515,34 +515,34 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Create an environment which caches variables in local variables if they are accessed // **NOTE** we do this right at the top so that local copies can be used by child groups EnvironmentLocalVarCache neuronVarEnv( - *this, *this, getTypeContext(), neuronEnv, backend.getDeviceVarPrefix(), "l", "", + *this, *this, getTypeContext(), neuronEnv, backend.getDeviceVarPrefix(), "", "l", [batchSize, &neuronEnv, this](const std::string &varName, VarAccessDuplication d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); - return getReadVarIndex(delayed, batchSize, d, neuronEnv["id"]) ; + return getReadVarIndex(delayed, batchSize, d, "id") ; }, [batchSize, &neuronEnv, this](const std::string &varName, VarAccessDuplication d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); - return getWriteVarIndex(delayed, batchSize, d, neuronEnv["id"]) ; + return getWriteVarIndex(delayed, batchSize, d, "id") ; }); // Loop through incoming synapse groups for(auto &sg : m_MergedInSynPSMGroups) { - CodeStream::Scope b(env.getStream()); + CodeStream::Scope b(neuronVarEnv.getStream()); sg.generate(backend, neuronVarEnv, *this, modelMerged); } // Loop through outgoing synapse groups with presynaptic output for (auto &sg : m_MergedOutSynPreOutputGroups) { - CodeStream::Scope b(env.getStream()); + CodeStream::Scope b(neuronVarEnv.getStream()); sg.generate(backend, neuronVarEnv, *this, modelMerged); } // Loop through all of neuron group's current sources for (auto &cs : m_MergedCurrentSourceGroups) { - CodeStream::Scope b(env.getStream()); + CodeStream::Scope b(neuronVarEnv.getStream()); cs.generate(backend, neuronVarEnv, *this, modelMerged); } From e93888f3f58a8aa28b50fc5f1b2bf8ba4af08c32 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 17:45:57 +0100 Subject: [PATCH 290/725] for some reason had to turn on /bigobj --- src/genn/genn/genn.vcxproj | 1 + 1 file changed, 1 insertion(+) diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj index d0cb8d900e..66b9642e74 100644 --- a/src/genn/genn/genn.vcxproj +++ b/src/genn/genn/genn.vcxproj @@ -173,6 +173,7 @@ 4251 stdcpp17 true + /bigobj %(AdditionalOptions) true From b1328754bd5715fcf259ee0ce4bba6c475b6a1d1 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 17:46:08 +0100 Subject: [PATCH 291/725] fixed several subtle lambda capturing bugs --- .../genn/genn/code_generator/environment.h | 71 ++++++++++--------- 1 file changed, 36 insertions(+), 35 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index d901a0d020..2acdc380f2 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -205,12 +205,14 @@ class EnvironmentFieldPolicy if (std::get<2>(payload) && !std::get<0>(payload)) { // Extract field from payload const auto &field = std::get<2>(payload).value(); + const auto &group = getGroup(); - // Add to field group using lambda function to 
potentially map from group to field + // Add to field group using lambda function to potentially map from group to field + // **NOTE** this will have been destroyed by the time this is called, so careful capturing is needed! m_FieldGroup.get().addField(std::get<0>(field), std::get<1>(field), - [this, &field](const typename F::GroupInternal &, size_t i) + [field, &group](const typename F::GroupInternal &, size_t i) { - return std::get<2>(field)(getGroup().getGroups().at(i), i); + return std::get<2>(field)(group.getGroups().at(i), i); }, std::get<3>(field)); @@ -364,7 +366,6 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &(GroupInternal::*)(void) const; using GetVarIndexFn = std::function; @@ -411,11 +412,13 @@
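The class of bug fixed in [PATCH 291/725] boils down to a lambda outliving a local it captured by reference; a minimal sketch of the hazard, and of the capture-by-value fix now used above, is:

#include <functional>
#include <iostream>
#include <string>

std::function<std::string()> makeCallback()
{
    const std::string field = "rowLength";

    // BUG: [&field] would capture a reference to a local that is destroyed when
    // makeCallback() returns, so invoking the callback later would read a
    // dangling reference. Capturing by value copies the string instead.
    return [field]() { return field + "Suffix"; };
}

int main()
{
    // The local 'field' is long gone by the time the callback runs
    const auto callback = makeCallback();
    std::cout << callback() << std::endl;  // "rowLengthSuffix"
    return 0;
}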
some sort of presynaptic code From 9bac45fc41e9788819802b6a43eec4a975136b3d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 18:01:55 +0100 Subject: [PATCH 294/725] fixed another lambda capture bug --- src/genn/genn/code_generator/initGroupMerged.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index b05e2001b7..7413f9a1de 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -473,7 +473,7 @@ void NeuronInitGroupMerged::genInitSpikes(const BackendBase &backend, Environmen const std::string suffix = spikeEvent ? "Evnt" : ""; EnvironmentGroupMergedField spikeEnv(env, *this); spikeEnv.addField(Type::Uint32.createPointer(), "_spk", "spk" + suffix, - [&backend, &suffix](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpk" + suffix + g.getName(); }); + [&backend, suffix](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpk" + suffix + g.getName(); }); // Generate variable initialisation code From 6eec9347094b4582e506bf5baf62d7818bd8e5c4 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 19:14:42 +0100 Subject: [PATCH 295/725] missing suffixes --- src/genn/genn/code_generator/initGroupMerged.cc | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 7413f9a1de..e9dcb228e9 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -217,7 +217,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir EnvironmentGroupMergedField groupEnv(env, *this, ng); // Add field for InSyn and zero - groupEnv.addField(getScalarType().createPointer(), "_out_post", "outPost", + groupEnv.addField(getScalarType().createPointer(), "_out_post", "outPost" + fieldSuffix, [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); backend.genVariableInit(groupEnv, "num_neurons", "id", [&modelMerged] (EnvironmentExternalBase &varEnv) @@ -231,7 +231,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir // If dendritic delays are required if(getArchetype().isDendriticDelayRequired()) { // Add field for dendritic delay buffer and zero - groupEnv.addField(getScalarType().createPointer(), "_den_delay", "denDelay", + groupEnv.addField(getScalarType().createPointer(), "_den_delay", "denDelay" + fieldSuffix, [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); backend.genVariableInit(groupEnv, "num_neurons", "id", [&modelMerged, this](EnvironmentExternalBase &varEnv) @@ -243,7 +243,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir }); // Add field for dendritic delay pointer and zero - groupEnv.addField(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr", + groupEnv.addField(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr" + fieldSuffix, [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); backend.genPopVariableInit(groupEnv, [](EnvironmentExternalBase &varEnv) @@ -262,13 +262,11 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir void 
NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, EnvironmentExternalBase &env, NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) { - const std::string suffix = "OutSyn" + std::to_string(getIndex()); - // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this, ng); // Add - groupEnv.addField(getScalarType().createPointer(), "_out_pre", "outPre", + groupEnv.addField(getScalarType().createPointer(), "_out_pre", "outPreOutSyn" + std::to_string(getIndex()), [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); backend.genVariableInit(env, "num_neurons", "id", [&modelMerged] (EnvironmentExternalBase &varEnv) From a99bec05d80f56db5f07c9e845b01140a71a12c9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 19:14:48 +0100 Subject: [PATCH 296/725] new names for insyn variables and use resolved type to get timepoint type --- src/genn/backends/cuda/backend.cc | 18 +++++++++--------- .../backends/single_threaded_cpu/backend.cc | 4 ++-- src/genn/genn/code_generator/generateRunner.cc | 12 ++++++------ .../code_generator/neuronUpdateGroupMerged.cc | 9 +++++---- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 878ab3f25b..6f1e67e775 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -445,7 +445,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged // If any neuron groups require their previous spike times updating size_t idNeuronPrevSpikeTimeUpdate = 0; if(!modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups().empty()) { - os << "extern \"C\" __global__ void " << KernelNames[KernelNeuronPrevSpikeTimeUpdate] << "(timepoint t)"; + os << "extern \"C\" __global__ void " << KernelNames[KernelNeuronPrevSpikeTimeUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)"; { CodeStream::Scope b(os); @@ -474,7 +474,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged os << std::endl; size_t idStart = 0; - os << "extern \"C\" __global__ void " << KernelNames[KernelNeuronUpdate] << "(timepoint t"; + os << "extern \"C\" __global__ void " << KernelNames[KernelNeuronUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t"; if(model.isRecordingInUse()) { os << ", unsigned int recordingTimestep"; } @@ -496,7 +496,7 @@ void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged genNeuronUpdateKernel(os, kernelSubs, modelMerged, idStart); } - os << "void updateNeurons(timepoint t"; + os << "void updateNeurons(" << modelMerged.getModel().getTimePrecision().getName() << " t"; if(model.isRecordingInUse()) { os << ", unsigned int recordingTimestep"; } @@ -579,7 +579,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge // If there are any presynaptic update groups size_t idPresynapticStart = 0; if(!modelMerged.getMergedPresynapticUpdateGroups().empty()) { - os << "extern \"C\" __global__ void " << KernelNames[KernelPresynapticUpdate] << "(timepoint t)" << std::endl; // end of synapse kernel header + os << "extern \"C\" __global__ void " << KernelNames[KernelPresynapticUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; // end of synapse kernel header { CodeStream::Scope b(os); @@ -601,7 +601,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const 
ModelSpecMerged &modelMerge // If any synapse groups require postsynaptic learning size_t idPostsynapticStart = 0; if(!modelMerged.getMergedPostsynapticUpdateGroups().empty()) { - os << "extern \"C\" __global__ void " << KernelNames[KernelPostsynapticUpdate] << "(timepoint t)" << std::endl; + os << "extern \"C\" __global__ void " << KernelNames[KernelPostsynapticUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; { CodeStream::Scope b(os); @@ -622,7 +622,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge size_t idSynapseDynamicsStart = 0; if(!modelMerged.getMergedSynapseDynamicsGroups().empty()) { - os << "extern \"C\" __global__ void " << KernelNames[KernelSynapseDynamicsUpdate] << "(timepoint t)" << std::endl; // end of synapse kernel header + os << "extern \"C\" __global__ void " << KernelNames[KernelSynapseDynamicsUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; // end of synapse kernel header { CodeStream::Scope b(os); @@ -641,7 +641,7 @@ void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerge } } - os << "void updateSynapses(timepoint t)"; + os << "void updateSynapses(" << modelMerged.getModel().getTimePrecision().getName() << " t)"; { CodeStream::Scope b(os); @@ -749,7 +749,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged [this](const CustomConnectivityUpdateInternal &cg){ return padKernelSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelCustomUpdate); }, [g](const CustomConnectivityUpdateGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; }); - os << "extern \"C\" __global__ void " << KernelNames[KernelCustomUpdate] << g << "(timepoint t)" << std::endl; + os << "extern \"C\" __global__ void " << KernelNames[KernelCustomUpdate] << g << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; { CodeStream::Scope b(os); @@ -780,7 +780,7 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged [&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); }, [g](const CustomUpdateTransposeWUGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; }); - os << "extern \"C\" __global__ void " << KernelNames[KernelCustomTransposeUpdate] << g << "(timepoint t)" << std::endl; + os << "extern \"C\" __global__ void " << KernelNames[KernelCustomTransposeUpdate] << g << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; { CodeStream::Scope b(os); diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 876f16ea16..14d2d814f3 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -129,7 +129,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Begin environment with standard library EnvironmentLibrary neuronUpdateEnv(neuronUpdate, StandardLibrary::getFunctions()); - neuronUpdateEnv.getStream() << "void updateNeurons(timepoint t"; + neuronUpdateEnv.getStream() << "void updateNeurons(" << modelMerged.getModel().getTimePrecision().getName() << " t"; if(modelMerged.getModel().isRecordingInUse()) { neuronUpdateEnv.getStream() << ", unsigned int recordingTimestep"; } @@ -315,7 +315,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // 
Begin environment with standard library EnvironmentLibrary synapseUpdateEnv(synapseUpdate, StandardLibrary::getFunctions()); - synapseUpdateEnv.getStream() << "void updateSynapses(timepoint t)"; + synapseUpdateEnv.getStream() << "void updateSynapses(" << modelMerged.getModel().getTimePrecision().getName() << " t)"; { CodeStream::Scope b(synapseUpdateEnv.getStream()); diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index b190594be4..b1e4560a58 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -617,9 +617,9 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Define and declare time variables definitionsVar << "EXPORT_VAR unsigned long long iT;" << std::endl; - definitionsVar << "EXPORT_VAR timepoint t;" << std::endl; + definitionsVar << "EXPORT_VAR " << model.getTimePrecision().getName() << " t;" << std::endl; runnerVarDecl << "unsigned long long iT;" << std::endl; - runnerVarDecl << "timepoint t;" << std::endl; + runnerVarDecl << model.getTimePrecision().getName() << " t;" << std::endl; if(model.isRecordingInUse()) { runnerVarDecl << "unsigned long long numRecordingTimesteps = 0;" << std::endl; @@ -1245,7 +1245,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Loop through merged postsynaptic models of incoming synaptic populations for(const auto *sg : n.second.getFusedPSMInSyn()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - model.getPrecision(), "inSyn" + sg->getFusedPSVarSuffix(), + model.getPrecision(), "outPost" + sg->getFusedPSVarSuffix(), sg->getInSynLocation(), sg->getTrgNeuronGroup()->getNumNeurons() * batchSize, mem); if (sg->isDendriticDelayRequired()) { @@ -1266,7 +1266,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // Loop through fused outgoing synapse populations with weightupdate models that have presynaptic output for(const auto *sg : n.second.getFusedPreOutputOutSyn()) { backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, - model.getPrecision(), "revInSyn" + sg->getFusedPreOutputSuffix(), + model.getPrecision(), "outPre" + sg->getFusedPreOutputSuffix(), sg->getInSynLocation(), sg->getSrcNeuronGroup()->getNumNeurons() * batchSize, mem); } @@ -1790,12 +1790,12 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, definitions << "EXPORT_FUNC void stepTime();" << std::endl; definitions << std::endl; definitions << "// Functions generated by backend" << std::endl; - definitions << "EXPORT_FUNC void updateNeurons(timepoint t"; + definitions << "EXPORT_FUNC void updateNeurons(" << modelMerged.getModel().getTimePrecision().getName() << " t"; if(model.isRecordingInUse()) { definitions << ", unsigned int recordingTimestep"; } definitions << "); " << std::endl; - definitions << "EXPORT_FUNC void updateSynapses(timepoint t);" << std::endl; + definitions << "EXPORT_FUNC void updateSynapses(" << modelMerged.getModel().getTimePrecision().getName() << " t);" << std::endl; definitions << "EXPORT_FUNC void initialize();" << std::endl; definitions << "EXPORT_FUNC void initializeSparse();" << std::endl; diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index f2917b8545..c1bc1e0efd 100644 --- 
a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -502,15 +502,16 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E neuronEnv.addExtraGlobalParams(nm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); // Substitute spike times + const std::string timePrecision = modelMerged.getModel().getTimePrecision().getName(); const std::string spikeTimeReadIndex = getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "id"); neuronEnv.add(getTimeType().addConst(), "sT", "lsT", - {neuronEnv.addInitialiser("const timepoint lsT = $(_spk_time)[" + spikeTimeReadIndex + "];")}); + {neuronEnv.addInitialiser("const " + timePrecision + " lsT = $(_spk_time)[" + spikeTimeReadIndex + "];")}); neuronEnv.add(getTimeType().addConst(), "prev_sT", "lprevST", - {neuronEnv.addInitialiser("const timepoint lprevST = $(_prev_spk_time)[" + spikeTimeReadIndex + "];")}); + {neuronEnv.addInitialiser("const " + timePrecision + " lprevST = $(_prev_spk_time)[" + spikeTimeReadIndex + "];")}); neuronEnv.add(getTimeType().addConst(), "seT", "lseT", - {neuronEnv.addInitialiser("const timepoint lseT = $(_spk_evnt_time)[" + spikeTimeReadIndex+ "];")}); + {neuronEnv.addInitialiser("const " + timePrecision + " lseT = $(_spk_evnt_time)[" + spikeTimeReadIndex+ "];")}); neuronEnv.add(getTimeType().addConst(), "prev_seT", "lprevSET", - {neuronEnv.addInitialiser("const timepoint lprevSET = $(_prev_spk_evnt_time)[" + spikeTimeReadIndex + "];")}); + {neuronEnv.addInitialiser("const " + timePrecision + " lprevSET = $(_prev_spk_evnt_time)[" + spikeTimeReadIndex + "];")}); // Create an environment which caches variables in local variables if they are accessed // **NOTE** we do this right at the top so that local copies can be used by child groups From 047a1d91733d689fa7853406fc07c713c053771f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 20:17:35 +0100 Subject: [PATCH 297/725] updated built in init var snippet syntax --- include/genn/genn/initVarSnippet.h | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/include/genn/genn/initVarSnippet.h b/include/genn/genn/initVarSnippet.h index c56e2b6749..2056c1e22d 100644 --- a/include/genn/genn/initVarSnippet.h +++ b/include/genn/genn/initVarSnippet.h @@ -60,7 +60,7 @@ class Constant : public Base public: DECLARE_SNIPPET(InitVarSnippet::Constant); - SET_CODE("$(value) = $(constant);"); + SET_CODE("value = constant;"); SET_PARAM_NAMES({"constant"}); }; @@ -73,7 +73,7 @@ class Kernel : public Base { DECLARE_SNIPPET(InitVarSnippet::Kernel); - SET_CODE("$(value) = $(kernel)[$(id_kernel)];"); + SET_CODE("value = kernel[id_kernel];"); SET_EXTRA_GLOBAL_PARAMS({{"kernel", "scalar*"}}); }; @@ -92,8 +92,8 @@ class Uniform : public Base DECLARE_SNIPPET(InitVarSnippet::Uniform); SET_CODE( - "const scalar scale = $(max) - $(min);\n" - "$(value) = $(min) + ($(gennrand_uniform) * scale);"); + "const scalar scale = max - min;\n" + "value = min + (gennrand_uniform() * scale);"); SET_PARAM_NAMES({"min", "max"}); }; @@ -111,7 +111,7 @@ class Normal : public Base public: DECLARE_SNIPPET(InitVarSnippet::Normal); - SET_CODE("$(value) = $(mean) + ($(gennrand_normal) * $(sd));"); + SET_CODE("value = mean + (gennrand_normal() * sd);"); SET_PARAM_NAMES({"mean", "sd"}); }; @@ -136,9 +136,9 @@ class NormalClipped : public Base "scalar normal;\n" "do\n" "{\n" - " normal = $(mean) + ($(gennrand_normal) * $(sd));\n" - "} 
while (normal > $(max) || normal < $(min));\n" - "$(value) = normal;\n"); + " normal = mean + (gennrand_normal() * sd);\n" + "} while (normal > max || normal < min);\n" + "value = normal;\n"); SET_PARAM_NAMES({"mean", "sd", "min", "max"}); }; @@ -165,9 +165,9 @@ class NormalClippedDelay : public Base "scalar normal;\n" "do\n" "{\n" - " normal = $(meanTimestep) + ($(gennrand_normal) * $(sdTimestep));\n" - "} while (normal > $(maxTimestep) || normal < $(minTimestep));\n" - "$(value) = rint(normal);\n"); + " normal = meanTimestep + (gennrand_normal() * sdTimestep);\n" + "} while (normal > maxTimestep || normal < minTimestep);\n" + "value = rint(normal);\n"); SET_PARAM_NAMES({"mean", "sd", "min", "max"}); SET_DERIVED_PARAMS({ @@ -189,7 +189,7 @@ class Exponential : public Base public: DECLARE_SNIPPET(InitVarSnippet::Exponential); - SET_CODE("$(value) = $(lambda) * $(gennrand_exponential);"); + SET_CODE("value = lambda * gennrand_exponential();"); SET_PARAM_NAMES({"lambda"}); }; @@ -207,7 +207,7 @@ class Gamma : public Base public: DECLARE_SNIPPET(InitVarSnippet::Gamma); - SET_CODE("$(value) = $(b) * $(gennrand_gamma, $(a));"); + SET_CODE("value = b * gennrand_gamma(a);"); SET_PARAM_NAMES({"a", "b"}); }; @@ -225,7 +225,7 @@ class Binomial : public Base public: DECLARE_SNIPPET(InitVarSnippet::Binomial); - SET_CODE("$(value) = $(gennrand_binomial, (unsigned int)$(n), $(p));"); + SET_CODE("value = gennrand_binomial((unsigned int)n, p);"); SET_PARAM_NAMES({"n", "p"}); }; From f2e98844ae821a9a3b326aa18fcf5697d63a5730 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 20:21:00 +0100 Subject: [PATCH 298/725] removed trailing comma when pretty-printing var declarations --- src/genn/genn/transpiler/prettyPrinter.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index a9fe80ae7d..dc41ec1dea 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -457,13 +457,17 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { m_Environment.get().getStream() << varDeclaration.getType().getName() << " "; - for(const auto &var : varDeclaration.getInitDeclaratorList()) { + const size_t numDeclarators = varDeclaration.getInitDeclaratorList().size(); + for(size_t i = 0; i < numDeclarators; i++) { + const auto &var = varDeclaration.getInitDeclaratorList()[i]; m_Environment.get().getStream() << m_Environment.get().define(std::get<0>(var).lexeme); if(std::get<1>(var)) { m_Environment.get().getStream() << " = "; std::get<1>(var)->accept(*this); } - m_Environment.get().getStream() << ", "; + if(i != (numDeclarators - 1)) { + m_Environment.get().getStream() << ", "; + } } m_Environment.get().getStream() << ";"; } From 9e9cb8fd384433409bc913dda806352de998593d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 20:21:20 +0100 Subject: [PATCH 299/725] fixed some synapse group generation issues --- .../backends/single_threaded_cpu/backend.cc | 30 ++++++++++++------- .../synapseUpdateGroupMerged.cc | 8 ++--- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 14d2d814f3..00c73d76d5 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -389,9 +389,9 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos } // 
Add correct functions for apply synaptic input - synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", "$(_den_delay)[" + s.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)"); - synEnv.add(Type::AddToPost, "addToPost", "$(_out_post)[" + s.getPostISynIndex(1, "j") + "] += $(0)"); - synEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + s.getPreISynIndex(1, "id_pre") + "] += $(0)"); + synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", "$(_den_delay)[" + s.getPostDenDelayIndex(1, "$(id_post)", "$(1)") + "] += $(0)"); + synEnv.add(Type::AddToPost, "addToPost", "$(_out_post)[" + s.getPostISynIndex(1, "$(id_post)") + "] += $(0)"); + synEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + s.getPreISynIndex(1, "$(id_pre)") + "] += $(0)"); // Call synapse dynamics handler s.generateSynapseUpdate(*this, synEnv, modelMerged); @@ -505,7 +505,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos } synEnv.add(Type::Uint32.addConst(), "id_post", "spike"); - synEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + s.getPreISynIndex(1, "id_pre") + "] += $(0)"); + synEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + s.getPreISynIndex(1, "$(id_pre)") + "] += $(0)"); s.generateSynapseUpdate(*this, synEnv, modelMerged); } @@ -1867,14 +1867,9 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda groupEnv.getStream() << CodeStream::OB(10); } - // Add correct functions for apply synaptic input - groupEnv.add(Type::AddToPostDenDelay, "addToPostDelay", "$(_den_delay)[" + sg.getPostDenDelayIndex(1, "j", "$(1)") + "] += $(0)"); - groupEnv.add(Type::AddToPost, "addToPost", "$(_out_post)[" + sg.getPostISynIndex(1, "j") + "] += $(0)"); - groupEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + sg.getPreISynIndex(1, "$(id_pre)") + "] += $(0)"); - // If connectivity is sparse if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - groupEnv.printLine("const unsigned int npost = $(_row_length)[ipre];"); + groupEnv.printLine("const unsigned int npost = $(_row_length)[$(id_pre)];"); groupEnv.getStream() << "for (unsigned int j = 0; j < npost; j++)"; { CodeStream::Scope b(groupEnv.getStream()); @@ -1886,6 +1881,11 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda synEnv.add(Type::Uint32, "id_post", "idPost", {synEnv.addInitialiser("const unsigned int idPost = $(_ind)[$(id_syn)];")}); + // Add correct functions for apply synaptic input + synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", "$(_den_delay)[" + sg.getPostDenDelayIndex(1, "$(id_post)", "$(1)") + "] += $(0)"); + synEnv.add(Type::AddToPost, "addToPost", "$(_out_post)[" + sg.getPostISynIndex(1, "$(id_post)") + "] += $(0)"); + synEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + sg.getPreISynIndex(1, "$(id_pre)") + "] += $(0)"); + if(trueSpike) { sg.generateSpikeUpdate(*this, synEnv, modelMerged); } @@ -1910,6 +1910,11 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda // Set ipost to first synapse in connectivity word groupEnv.getStream() << "unsigned int ipost = w * 32;" << std::endl; groupEnv.add(Type::Uint32, "id_post", "ipost"); + + // Add correct functions for apply synaptic input + groupEnv.add(Type::AddToPostDenDelay, "addToPostDelay", "$(_den_delay)[" + sg.getPostDenDelayIndex(1, "$(id_post)", "$(1)") + "] += $(0)"); + groupEnv.add(Type::AddToPost, "addToPost", "$(_out_post)[" + sg.getPostISynIndex(1, "$(id_post)") + "] += $(0)"); + groupEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + 
sg.getPreISynIndex(1, "$(id_pre)") + "] += $(0)"); // While there any bits left groupEnv.getStream() << "while(connectivityWord != 0)"; @@ -1951,6 +1956,11 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda CodeStream::Scope b(groupEnv.getStream()); EnvironmentGroupMergedField synEnv(groupEnv, sg); synEnv.add(Type::Uint32, "id_post", "ipost"); + + // Add correct functions for apply synaptic input + synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", "$(_den_delay)[" + sg.getPostDenDelayIndex(1, "$(id_post)", "$(1)") + "] += $(0)"); + synEnv.add(Type::AddToPost, "addToPost", "$(_out_post)[" + sg.getPostISynIndex(1, "$(id_post)") + "] += $(0)"); + synEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + sg.getPreISynIndex(1, "$(id_pre)") + "] += $(0)"); if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { // **TODO** 64-bit index diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 854f9badaa..94b1d960d9 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -33,12 +33,12 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa synEnv.addVars(backend.getDeviceVarPrefix(), [&sg, batchSize](VarAccess a, const std::string&) { - return sg.getPreWUVarIndex(batchSize, getVarAccessDuplication(a), "id_pre"); + return sg.getPreWUVarIndex(batchSize, getVarAccessDuplication(a), "$(id_pre)"); }); synEnv.addVars(backend.getDeviceVarPrefix(), [&sg, batchSize](VarAccess a, const std::string&) { - return sg.getPostWUVarIndex(batchSize, getVarAccessDuplication(a), "id_post"); + return sg.getPostWUVarIndex(batchSize, getVarAccessDuplication(a), "$(id_post)"); }); @@ -54,7 +54,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa synEnv.addVars(backend.getDeviceVarPrefix(), [&sg, batchSize](VarAccess a, const std::string&) { - return sg.getSynVarIndex(batchSize, getVarAccessDuplication(a), "id_syn"); + return sg.getSynVarIndex(batchSize, getVarAccessDuplication(a), "$(id_syn)"); }); } // Otherwise, if weights are procedual @@ -104,7 +104,7 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa synEnv.addVars(backend.getDeviceVarPrefix(), [&sg, batchSize](VarAccess a, const std::string&) { - return sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), "id_kernel"); + return sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), "$(id_kernel)"); }); } // Otherwise, substitute variables for constant values From 257746e8cecd2da113d70726c3dc9533ff7c8ce0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 20:27:40 +0100 Subject: [PATCH 300/725] fixed some more use of "scalar" --- .../genn/genn/code_generator/modelSpecMerged.h | 3 --- src/genn/backends/single_threaded_cpu/backend.cc | 6 +++--- src/genn/genn/code_generator/initGroupMerged.cc | 10 +++++----- src/genn/genn/code_generator/modelSpecMerged.cc | 6 ------ .../code_generator/neuronUpdateGroupMerged.cc | 15 +++++++-------- 5 files changed, 15 insertions(+), 25 deletions(-) diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 43cd191330..2c47375bf6 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -251,9 +251,6 @@ class GENN_EXPORT ModelSpecMerged //! 
Get hash digest of init module boost::uuids::detail::sha1::digest_type getInitArchetypeHashDigest() const; - //! Get the string literal that should be used to represent a value in scalar type - std::string scalarExpr(double value) const; - //! Does model have any EGPs? bool anyPointerEGPs() const; diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 00c73d76d5..5797ca2176 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1362,9 +1362,9 @@ void Backend::genRunnerPreamble(CodeStream &os, const ModelSpecMerged &modelMerg // If a global RNG is required, implement standard host distributions as recreating them each call is slow if(isGlobalHostRNGRequired(modelMerged)) { - os << "std::uniform_real_distribution<" << model.getPrecision().getName() << "> standardUniformDistribution(" << modelMerged.scalarExpr(0.0) << ", " << modelMerged.scalarExpr(1.0) << ");" << std::endl; - os << "std::normal_distribution<" << model.getPrecision().getName() << "> standardNormalDistribution(" << modelMerged.scalarExpr(0.0) << ", " << modelMerged.scalarExpr(1.0) << ");" << std::endl; - os << "std::exponential_distribution<" << model.getPrecision().getName() << "> standardExponentialDistribution(" << modelMerged.scalarExpr(1.0) << ");" << std::endl; + os << "std::uniform_real_distribution<" << model.getPrecision().getName() << "> standardUniformDistribution(" << writePreciseLiteral(0.0, model.getPrecision()) << ", " << writePreciseLiteral(1.0, model.getPrecision()) << ");" << std::endl; + os << "std::normal_distribution<" << model.getPrecision().getName() << "> standardNormalDistribution(" << writePreciseLiteral(0.0, model.getPrecision()) << ", " << writePreciseLiteral(1.0, model.getPrecision()) << ");" << std::endl; + os << "std::exponential_distribution<" << model.getPrecision().getName() << "> standardExponentialDistribution(" << writePreciseLiteral(1.0, model.getPrecision()) << ");" << std::endl; os << std::endl; } os << std::endl; diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index e9dcb228e9..cd1c9d90ea 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -220,9 +220,9 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir groupEnv.addField(getScalarType().createPointer(), "_out_post", "outPost" + fieldSuffix, [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); backend.genVariableInit(groupEnv, "num_neurons", "id", - [&modelMerged] (EnvironmentExternalBase &varEnv) + [&modelMerged, this] (EnvironmentExternalBase &varEnv) { - genVariableFill(varEnv, "_out_post", modelMerged.scalarExpr(0.0), + genVariableFill(varEnv, "_out_post", writePreciseLiteral(0.0, getScalarType()), "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE, modelMerged.getModel().getBatchSize()); @@ -236,7 +236,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir backend.genVariableInit(groupEnv, "num_neurons", "id", [&modelMerged, this](EnvironmentExternalBase &varEnv) { - genVariableFill(varEnv, "_den_delay", modelMerged.scalarExpr(0.0), + genVariableFill(varEnv, "_den_delay", writePreciseLiteral(0.0, getScalarType()), "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE, modelMerged.getModel().getBatchSize(), true, 
getArchetype().getMaxDendriticDelayTimesteps()); @@ -269,9 +269,9 @@ void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend groupEnv.addField(getScalarType().createPointer(), "_out_pre", "outPreOutSyn" + std::to_string(getIndex()), [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); backend.genVariableInit(env, "num_neurons", "id", - [&modelMerged] (EnvironmentExternalBase &varEnv) + [&modelMerged, this] (EnvironmentExternalBase &varEnv) { - genVariableFill(varEnv, "_out_pre", modelMerged.scalarExpr(0.0), + genVariableFill(varEnv, "_out_pre", writePreciseLiteral(0.0, getScalarType()), "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE, modelMerged.getModel().getBatchSize()); }); diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index ce4165673a..3a22421ad3 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -636,12 +636,6 @@ boost::uuids::detail::sha1::digest_type ModelSpecMerged::getInitArchetypeHashDig return hash.get_digest(); } //---------------------------------------------------------------------------- -std::string ModelSpecMerged::scalarExpr(double value) const -{ - const auto scalarNumeric = m_TypeContext.at("scalar").getNumeric(); - return Utils::writePreciseString(value, scalarNumeric.maxDigits10) + scalarNumeric.literalSuffix; -} -//---------------------------------------------------------------------------- bool ModelSpecMerged::anyPointerEGPs() const { // Loop through grouped merged EGPs diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index c1bc1e0efd..f9da2c2480 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -80,31 +80,30 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env EnvironmentGroupMergedField psmEnv(env, *this, ng); // Add inSyn - const auto scalarType = modelMerged.getModel().getPrecision(); - psmEnv.addField(scalarType.createPointer(), "_out_post", "outPost" + fieldSuffix, + psmEnv.addField(getScalarType().createPointer(), "_out_post", "outPost" + fieldSuffix, [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); // Read into local variable const std::string idx = ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"); psmEnv.getStream() << "// postsynaptic model " << getIndex() << std::endl; - psmEnv.printLine("scalar linSyn = $(_out_post)[" + idx + "];"); + psmEnv.printLine(getScalarType().getName() + " linSyn = $(_out_post)[" + idx + "];"); // If dendritic delay is required if (getArchetype().isDendriticDelayRequired()) { // Add dendritic delay buffer and pointer into it - psmEnv.addField(scalarType.createPointer(), "_den_delay", "denDelay" + fieldSuffix, + psmEnv.addField(getScalarType().createPointer(), "_den_delay", "denDelay" + fieldSuffix, [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix();}); psmEnv.addField(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr" + fieldSuffix, [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix();}); // Get reference to dendritic delay buffer input for this timestep - 
psmEnv.printLine(backend.getPointerPrefix() + "scalar *denDelayFront = &$(_den_delay)[(*$(_den_delay_ptr) * $(num_neurons)) + " + idx + "];"); + psmEnv.printLine(backend.getPointerPrefix() + getScalarType().getName() + " *denDelayFront = &$(_den_delay)[(*$(_den_delay_ptr) * $(num_neurons)) + " + idx + "];"); // Add delayed input from buffer into inSyn psmEnv.getStream() << "linSyn += *denDelayFront;" << std::endl; // Zero delay buffer slot - psmEnv.getStream() << "*denDelayFront = " << modelMerged.scalarExpr(0.0) << ";" << std::endl; + psmEnv.getStream() << "*denDelayFront = " << writePreciseLiteral(0.0, getScalarType()) << ";" << std::endl; } // Add parameters, derived parameters and extra global parameters to environment @@ -172,7 +171,7 @@ void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backe outSynEnv.printLine(getArchetype().getPreTargetVar() + " += $(_out_pre)[" + idx + "];"); // Zero it again - outSynEnv.printLine("$(_out_pre)[" + idx + "] = " + modelMerged.scalarExpr(0.0) + ";"); + outSynEnv.printLine("$(_out_pre)[" + idx + "] = " + writePreciseLiteral(0.0, getScalarType()) + ";"); } //---------------------------------------------------------------------------- @@ -487,7 +486,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Add default input variable neuronEnv.add(modelMerged.getModel().getPrecision(), "Isyn", "Isyn", - {neuronEnv.addInitialiser("scalar Isyn = 0;")}); + {neuronEnv.addInitialiser(getScalarType().getName() + " Isyn = 0;")}); // **NOTE** arbitrary code in param value to be deprecated for (const auto &v : nm->getAdditionalInputVars()) { From 4d7c9f8f0de404759c425436d70311aaf737561d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 20:45:01 +0100 Subject: [PATCH 301/725] iterative initialiser evaluation --- .../genn/genn/code_generator/environment.h | 33 ++++++++++++++++--- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index a866baba8a..8e8315bd61 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -250,11 +250,34 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P ~EnvironmentExternalDynamicBase() { - // Loop through initialiser - for(const auto &i : m_Initialisers) { - // If variable requiring initialiser has been referenced, write out initialiser - if (i.first) { - getContextStream() << i.second.str() << std::endl; + // Loop through initialisers + std::vector<std::string> initialiserCode(m_Initialisers.size()); + + // Because initialisers may refer to other initialisers, + // keep evaluating initialisers until no new ones are found + bool anyReferences; + do { + // Loop through initialiser + anyReferences = false; + for(size_t i = 0; i < m_Initialisers.size(); i++) { + // If initialiser has been referenced + auto &initialiser = m_Initialisers[i]; + if (initialiser.first) { + // Evaluate lazy string into vector + initialiserCode[i] = initialiser.second.str(); + + // Clear referenced flag and set flag to ensure another iteration occurs + initialiser.first = false; + anyReferences = true; + } + } + } while(anyReferences); + + // Write out generated initialiser code + // **NOTE** in order + for(const auto &i : initialiserCode) { + if(!i.empty()) { + getContextStream() << i << std::endl; } }
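Note: the destructor rewrite above sweeps the initialiser list repeatedly because evaluating one lazy initialiser can mark further initialisers as referenced; a single pass would miss anything discovered mid-pass. A minimal standalone sketch of the same fixed-point idea (illustrative only, not GeNN code; the LazyInit type and every name below are invented):

    #include <cstddef>
    #include <functional>
    #include <string>
    #include <vector>

    struct LazyInit
    {
        bool referenced;                     // set when other code refers to this initialiser
        std::function<std::string()> build;  // evaluating may set 'referenced' on other entries
    };

    std::vector<std::string> evaluateAll(std::vector<LazyInit> &inits)
    {
        std::vector<std::string> code(inits.size());
        bool anyReferences;
        do {
            anyReferences = false;
            for (std::size_t i = 0; i < inits.size(); i++) {
                if (inits[i].referenced) {
                    code[i] = inits[i].build();   // may flag further initialisers as referenced
                    inits[i].referenced = false;  // don't evaluate this entry again next sweep
                    anyReferences = true;         // something changed, so sweep again
                }
            }
        } while (anyReferences);
        return code;  // still in declaration order, as in the destructor above
    }

Assuming references between initialisers are acyclic, each entry is evaluated once and the do/while settles after at most inits.size() + 1 sweeps.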
From 150973528ed382164256bb3ccc5f8fa277908d80 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sat, 1 Jul 2023 20:55:35 +0100 Subject: [PATCH 302/725] fixed minor issues - VA benchmark builds --- src/genn/genn/code_generator/generateRunner.cc | 2 +- src/genn/genn/code_generator/initGroupMerged.cc | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index b1e4560a58..a1aad3fa75 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -1756,7 +1756,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, } // Advance time runner << "iT++;" << std::endl; - runner << "t = iT*DT;" << std::endl; + runner << "t = iT * " << writePreciseLiteral(model.getDT(), model.getTimePrecision()) << ";" << std::endl; // Write step time finalize logic to runner runner << runnerStepTimeFinaliseStream.str(); diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index cd1c9d90ea..a98568e48a 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -48,14 +48,14 @@ void genScalarFill(EnvironmentExternalBase &env, const std::string &target, cons // If there's only one, don't generate a loop if(numValues == 1) { - env.getStream() << env[target] << "[0] = " << value << ";" << std::endl; + env.printLine("$(" + target + ")[0] = " + value + ";"); } // Otherwise else { env.getStream() << "for(unsigned int d = 0; d < " << numValues << "; d++)"; { CodeStream::Scope b(env.getStream()); - env.getStream() << env[target] << "[d] = " << value << ";" << std::endl; + env.printLine("$(" + target + ")[d] = " + value + ";"); } } } @@ -104,7 +104,7 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e prettyPrintStatements(snippet->getCode(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all delay slots and batches - genScalarFill(varInitEnv, "value", "initVal", getVarAccessDuplication(var.access), + genScalarFill(varInitEnv, "_value", "$(value)", getVarAccessDuplication(var.access), batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); }); } @@ -125,7 +125,7 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e prettyPrintStatements(snippet->getCode(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all delay slots and batches - genVariableFill(varInitEnv, "value", "initVal", "id", "$(" + count + ")", + genVariableFill(varInitEnv, "_value", "$(value)", "id", "$(" + count + ")", getVarAccessDuplication(var.access), batchSize, adaptor.isVarDelayed(var.name), numDelaySlots); }); } @@ -185,7 +185,7 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, prettyPrintStatements(snippet->getCode(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all batches - genVariableFill(varInitEnv, "value", "initVal", "id_syn", stride, + genVariableFill(varInitEnv, "_value", "$(value)", "id_syn", stride, getVarAccessDuplication(var.access), batchSize); }); } From b917b1be75e9d7b4bb8296d10594ed261b709280 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sun, 2 Jul 2023 11:42:51 +0100 Subject: [PATCH 303/725] strip out unused code gen utils --- .../genn/genn/code_generator/codeGenUtils.h | 41 -- src/genn/genn/code_generator/codeGenUtils.cc | 383 ------------------ 2 files changed, 424 deletions(-) diff --git a/include/genn/genn/code_generator/codeGenUtils.h 
b/include/genn/genn/code_generator/codeGenUtils.h index b47f502f78..99e6ea6a60 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -37,35 +37,6 @@ class EnvironmentExternalBase; //-------------------------------------------------------------------------- namespace GeNN::CodeGenerator { -//-------------------------------------------------------------------------- -//! \brief Tool for substituting strings in the neuron code strings or other templates -//-------------------------------------------------------------------------- -GENN_EXPORT void substitute(std::string &s, const std::string &trg, const std::string &rep); - -//-------------------------------------------------------------------------- -//! \brief Tool for substituting variable names in the neuron code strings or other templates using regular expressions -//-------------------------------------------------------------------------- -GENN_EXPORT bool regexVarSubstitute(std::string &s, const std::string &trg, const std::string &rep); - -//-------------------------------------------------------------------------- -//! \brief Tool for substituting function names in the neuron code strings or other templates using regular expressions -//-------------------------------------------------------------------------- -GENN_EXPORT bool regexFuncSubstitute(std::string &s, const std::string &trg, const std::string &rep); - -//-------------------------------------------------------------------------- -/*! \brief This function substitutes function calls in the form: - * - * $(functionName, parameter1, param2Function(0.12, "string")) - * - * with replacement templates in the form: - * - * actualFunction(CONSTANT, $(0), $(1)) - * - */ -//-------------------------------------------------------------------------- -GENN_EXPORT void functionSubstitute(std::string &code, const std::string &funcName, - unsigned int numParams, const std::string &replaceFuncTemplate); - //! Divide two integers, rounding up i.e. effectively taking ceil inline size_t ceilDivide(size_t numerator, size_t denominator) { @@ -80,18 +51,6 @@ inline size_t padSize(size_t size, size_t blockSize) GENN_EXPORT void genTypeRange(CodeStream &os, const Type::ResolvedType &type, const std::string &prefix); -//-------------------------------------------------------------------------- -/*! \brief This function implements a parser that converts any floating point constant in a code snippet to a floating point constant with an explicit precision (by appending "f" or removing it). - */ -//-------------------------------------------------------------------------- -GENN_EXPORT std::string ensureFtype(const std::string &oldcode, const std::string &type); - -//-------------------------------------------------------------------------- -/*! \brief This function checks for unknown variable definitions and returns a gennError if any are found - */ -//-------------------------------------------------------------------------- -GENN_EXPORT void checkUnreplacedVariables(const std::string &code, const std::string &codeName); - //-------------------------------------------------------------------------- /*! \brief This function substitutes function names in a code with namespace as prefix of the function name for backends that do not support namespaces by checking that the function indeed exists in the support code and returns the substituted code. 
*/ diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index 0ba457ddb3..a10e4fc76d 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -38,149 +38,6 @@ //-------------------------------------------------------------------------- namespace { -const std::string digits="0123456789"; -const std::string op= std::string("+-*/(<>= ,;")+std::string("\n")+std::string("\t"); - -enum MathsFunc -{ - MathsFuncCPP, - MathsFuncSingle, - MathsFuncMax, -}; - -const char *mathsFuncs[][MathsFuncMax] = { - {"cos", "cosf"}, - {"sin", "sinf"}, - {"tan", "tanf"}, - {"acos", "acosf"}, - {"asin", "asinf"}, - {"atan", "atanf"}, - {"atan2", "atan2f"}, - {"cosh", "coshf"}, - {"sinh", "sinhf"}, - {"tanh", "tanhf"}, - {"acosh", "acoshf"}, - {"asinh", "asinhf"}, - {"atanh", "atanhf"}, - {"exp", "expf"}, - {"frexp", "frexpf"}, - {"ldexp", "ldexpf"}, - {"log", "logf"}, - {"log10", "log10f"}, - {"modf", "modff"}, - {"exp2", "exp2f"}, - {"expm1", "expm1f"}, - {"ilogb", "ilogbf"}, - {"log1p", "log1pf"}, - {"log2", "log2f"}, - {"logb", "logbf"}, - {"scalbn", "scalbnf"}, - {"scalbln", "scalblnf"}, - {"pow", "powf"}, - {"sqrt", "sqrtf"}, - {"cbrt", "cbrtf"}, - {"hypot", "hypotf"}, - {"erf", "erff"}, - {"erfc", "erfcf"}, - {"tgamma", "tgammaf"}, - {"lgamma", "lgammaf"}, - {"ceil", "ceilf"}, - {"floor", "floorf"}, - {"fmod", "fmodf"}, - {"trunc", "truncf"}, - {"round", "roundf"}, - {"lround", "lroundf"}, - {"llround", "llroundf"}, - {"rint", "rintf"}, - {"lrint", "lrintf"}, - {"nearbyint", "nearbyintf"}, - {"remainder", "remainderf"}, - {"remquo", "remquof"}, - {"copysign", "copysignf"}, - {"nan", "nanf"}, - {"nextafter", "nextafterf"}, - {"nexttoward", "nexttowardf"}, - {"fdim", "fdimf"}, - {"fmax", "fmaxf"}, - {"fmin", "fminf"}, - {"fabs", "fabsf"}, - {"fma", "fmaf"} -}; -//-------------------------------------------------------------------------- -void ensureMathFunctionFtype(std::string &code) -{ - // Replace any outstanding explicit single-precision maths functions - // with C++ versions where overloads should work the same - for(const auto &m : mathsFuncs) { - GeNN::CodeGenerator::regexFuncSubstitute(code, m[MathsFuncSingle], m[MathsFuncCPP]); - } -} -//-------------------------------------------------------------------------- -void doFinal(std::string &code, unsigned int i, const std::string &type, unsigned int &state) -{ - if (code[i] == 'f') { - if (type == "double") { - code.erase(i,1); - } - } - else { - if (type == "float") { - code.insert(i,1,'f'); - } - } - if (i < code.size()-1) { - if (op.find(code[i]) == std::string::npos) { - state= 0; - } - else { - state= 1; - } - } -} -//-------------------------------------------------------------------------- -bool regexSubstitute(std::string &s, const std::regex ®ex, const std::string &format) -{ - // **NOTE** the following code performs the same function as std::regex_replace - // but has a return value indicating whether any replacements are made - // see http://en.cppreference.com/w/cpp/regex/regex_replace - - // Create regex iterator to iterate over matches found in code - std::sregex_iterator matchesBegin(s.cbegin(), s.cend(), regex); - std::sregex_iterator matchesEnd; - - // If there are no matches, leave s unmodified and return false - if(matchesBegin == matchesEnd) { - return false; - } - // Otherwise - else { - // Loop through matches - std::string output; - for(std::sregex_iterator m = matchesBegin;;) { - // Copy the non-matched subsequence 
(m->prefix()) onto output - std::copy(m->prefix().first, m->prefix().second, std::back_inserter(output)); - - // Then replaces the matched subsequence with the formatted replacement string - m->format(std::back_inserter(output), format); - - // If there are no subsequent matches - if(std::next(m) == matchesEnd) { - // Copy the remaining non-matched characters onto output - std::copy(m->suffix().first, m->suffix().second, std::back_inserter(output)); - break; - } - // Otherwise go onto next match - else { - m++; - } - } - - // Set reference to newly processed version and return true - s = output; - return true; - } -} - std::string trimWhitespace(const std::string& str) { const std::string whitespace = " \t\r\n"; @@ -203,133 +60,6 @@ std::string trimWhitespace(const std::string& str) //---------------------------------------------------------------------------- namespace GeNN::CodeGenerator { -void substitute(std::string &s, const std::string &trg, const std::string &rep) -{ - size_t found= s.find(trg); - while (found != std::string::npos) { - s.replace(found,trg.length(),rep); - found= s.find(trg); - } -} -//---------------------------------------------------------------------------- -bool regexVarSubstitute(std::string &s, const std::string &trg, const std::string &rep) -{ - // Build a regex to match variable name with at least one - // character that can't be in a variable name on either side (or an end/beginning of string) - // **NOTE** the suffix is non-capturing so two instances of variables separated by a single character are matched e.g. a*a - std::regex regex("(^|[^0-9a-zA-Z_])" + trg + "(?=$|[^a-zA-Z0-9_])"); - - // Create format string to replace in text - // **NOTE** preceding character is captured as C++ regex doesn't support lookbehind so this needs to be replaced in - const std::string format = "$1" + rep; - - return regexSubstitute(s, regex, format); -} - -//---------------------------------------------------------------------------- -bool regexFuncSubstitute(std::string &s, const std::string &trg, const std::string &rep) -{ - // Build a regex to match function name with at least one - // character that can't be part of the function name on the left and a bracket on the right (with optional whitespace) - // **NOTE** the suffix is non-capturing so two instances of functions separated by a single character are matched e.g. 
sin(cos(x)); - std::regex regex("(^|[^0-9a-zA-Z_])" + trg + "(?=\\s*\\()"); - - // Create format string to replace in text - // **NOTE** preceding character is captured as C++ regex doesn't support lookbehind so this needs to be replaced in - const std::string format = "$1" + rep; - - return regexSubstitute(s, regex, format); -} -//---------------------------------------------------------------------------- -void functionSubstitute(std::string &code, const std::string &funcName, - unsigned int numParams, const std::string &replaceFuncTemplate) -{ - // If there are no parameters, just replace the function name (wrapped in '$()') - // with the template (which will, inherantly, not have any parameters) - if(numParams == 0) { - substitute(code, "$(" + funcName + ")", replaceFuncTemplate); - } - // Otherwise - else { - // Reserve vector to hold parameters - std::vector params; - params.reserve(numParams); - - // String to hold parameter currently being parsed - std::string currentParam = ""; - - // Function will start with opening GeNN wrapper, name and comma before first argument - // **NOTE** need to match up to comma so longer function names with same prefix aren't matched - const std::string funcStart = "$(" + funcName + ","; - - // Find first occurance of start of function - size_t idx = code.find(funcStart); - - // While functions are found - while (idx != std::string::npos) { - // Loop through characters following funcStart - unsigned int bracketDepth = 0; - const size_t startIdx = idx; - for(idx = idx + funcStart.length(); idx < code.size(); idx++) { - // If this character is a comma at function bracket depth - if(code[idx] == ',' && bracketDepth == 0) { - currentParam = trimWhitespace(currentParam); - assert(!currentParam.empty()); - - // Add parameter to array - params.push_back(currentParam); - currentParam = ""; - } - // Otherwise - else { - // If this is an open bracket, increase bracket depth - if(code[idx] == '(') { - bracketDepth++; - } - // Otherwise, it's a close bracket - else if(code[idx] == ')') { - // If we are at a deeper bracket depth than function, decrease bracket depth - if(bracketDepth > 0) { - bracketDepth--; - } - // Otherwise - else { - currentParam = trimWhitespace(currentParam); - assert(!currentParam.empty()); - - // Add parameter to array - params.push_back(currentParam); - currentParam = ""; - - // Check parameters match - assert(params.size() == numParams); - - // Substitute parsed parameters into function template - std::string replaceFunc = replaceFuncTemplate; - for(unsigned int p = 0; p < numParams; p++) { - substitute(replaceFunc, "$(" + std::to_string(p) + ")", params[p]); - } - - // Clear parameters now they have been substituted - // into the final string to replace in to code - params.clear(); - - // Replace this into code - code.replace(startIdx, idx - startIdx + 1, replaceFunc); - break; - } - } - - // Add character to parameter string - currentParam += code[idx]; - } - } - - // Find start of next function to replace - idx = code.find(funcStart, idx); - } - } -} //---------------------------------------------------------------------------- void genTypeRange(CodeStream &os, const Type::ResolvedType &type, const std::string &prefix) { @@ -339,119 +69,6 @@ void genTypeRange(CodeStream &os, const Type::ResolvedType &type, const std::str os << "#define " << prefix << "_MAX " << Utils::writePreciseString(numeric.max, numeric.maxDigits10) << numeric.literalSuffix << std::endl; } 
//---------------------------------------------------------------------------- -std::string ensureFtype(const std::string &oldcode, const std::string &type) -{ -// cerr << "entering ensure" << endl; -// cerr << oldcode << endl; - std::string code= oldcode; - unsigned int i= 0; - unsigned int state= 1; // allowed to start with a number straight away. - while (i < code.size()) { - switch (state) - { - case 0: // looking for a valid lead-in - if (op.find(code[i]) != std::string::npos) { - state= 1; - break; - } - break; - case 1: // looking for start of number - if (digits.find(code[i]) != std::string::npos) { - state= 2; // found the beginning of a number starting with a digit - break; - } - if (code[i] == '.') { - state= 3; // number starting with a dot - break; - } - if (op.find(code[i]) == std::string::npos) { - state= 0; - break; - } - break; - case 2: // in a number, looking for more digits, '.', 'e', 'E', or end of number - if (code[i] == '.') { - state= 3; // number now also contained a dot - break; - } - if ((code[i] == 'e') || (code[i] == 'E')) { - state= 4; - break; - } - if (digits.find(code[i]) == std::string::npos) {// the number looks like an integer ... - if (op.find(code[i]) != std::string::npos) state= 1; - else state= 0; - break; - } - break; - case 3: // we have had '.' now looking for digits or 'e', 'E' - if ((code[i] == 'e') || (code[i] == 'E')) { - state= 4; - break; - } - if (digits.find(code[i]) == std::string::npos) { - doFinal(code, i, type, state); - break; - } - break; - case 4: // we have had '.' and 'e', 'E', digits only now - if (digits.find(code[i]) != std::string::npos) { - state= 6; - break; - } - if ((code[i] != '+') && (code[i] != '-')) { - if (op.find(code[i]) != std::string::npos) state= 1; - else state= 0; - break; - } - else { - state= 5; - break; - } - case 5: // now one or more digits or else ... - if (digits.find(code[i]) != std::string::npos) { - state= 6; - break; - } - else { - if (op.find(code[i]) != std::string::npos) state= 1; - else state= 0; - break; - } - case 6: // any non-digit character will trigger action - if (digits.find(code[i]) == std::string::npos) { - doFinal(code, i, type, state); - break; - } - break; - } - i++; - } - if ((state == 3) || (state == 6)) { - if (type == "float") { - code= code+"f"; - } - } - ensureMathFunctionFtype(code); - return code; -} -//---------------------------------------------------------------------------- -void checkUnreplacedVariables(const std::string &code, const std::string &codeName) -{ - std::regex rgx("\\$\\([\\w]+\\)"); - std::string vars= ""; - for (std::sregex_iterator it(code.begin(), code.end(), rgx), end; it != end; it++) { - vars+= it->str().substr(2,it->str().size()-3) + ", "; - } - if (vars.size() > 0) { - vars= vars.substr(0, vars.size()-2); - - vars = (vars.find(",") != std::string::npos) ? 
"variables " + vars + " were " : "variable " + vars + " was "; - - throw std::runtime_error("The "+vars+"undefined in code "+codeName+"."); - } -} -//---------------------------------------------------------------------------- std::string disambiguateNamespaceFunction(const std::string supportCode, const std::string code, std::string namespaceName) { // Regex for function call - looks for words with succeeding parentheses with or without any data inside the parentheses (arguments) std::regex funcCallRegex(R"(\w+(?=\(.*\)))"); From f5fdd9aaa0d6a483d936c2407a0d79b980932358 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sun, 2 Jul 2023 11:44:14 +0100 Subject: [PATCH 304/725] update current source model syntax --- include/genn/genn/currentSourceModels.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/include/genn/genn/currentSourceModels.h b/include/genn/genn/currentSourceModels.h index 1974b46b9f..ac7196d225 100644 --- a/include/genn/genn/currentSourceModels.h +++ b/include/genn/genn/currentSourceModels.h @@ -55,7 +55,7 @@ class DC : public Base { DECLARE_SNIPPET(DC); - SET_INJECTION_CODE("$(injectCurrent, $(amp));\n"); + SET_INJECTION_CODE("injectCurrent(amp);\n"); SET_PARAM_NAMES({"amp"}); }; @@ -72,7 +72,7 @@ class GaussianNoise : public Base { DECLARE_SNIPPET(GaussianNoise); - SET_INJECTION_CODE("$(injectCurrent, $(mean) + $(gennrand_normal) * $(sd));\n"); + SET_INJECTION_CODE("injectCurrent(mean + (gennrand_normal() * sd));\n"); SET_PARAM_NAMES({"mean", "sd"} ); }; @@ -92,16 +92,16 @@ class PoissonExp : public Base DECLARE_SNIPPET(PoissonExp); SET_INJECTION_CODE( - "scalar p = 1.0f;\n" + "scalar p = 1.0;\n" "unsigned int numSpikes = 0;\n" "do\n" "{\n" " numSpikes++;\n" - " p *= $(gennrand_uniform);\n" + " p *= gennrand_uniform();\n" "} while (p > $(ExpMinusLambda));\n" - "$(current) += $(Init) * (scalar)(numSpikes - 1);\n" - "$(injectCurrent, $(current));\n" - "$(current) *= $(ExpDecay);\n"); + "current += Init * (scalar)(numSpikes - 1);\n" + "injectCurrent(current);\n" + "current *= ExpDecay;\n"); SET_PARAM_NAMES({"weight", "tauSyn", "rate"}); SET_VARS({{"current", "scalar"}}); From 34608577bb21bd99ad64a851f42878050ad6f175 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sun, 2 Jul 2023 11:44:30 +0100 Subject: [PATCH 305/725] fixed num_neurons typo --- src/genn/genn/code_generator/initGroupMerged.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index a98568e48a..32db799a6c 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -201,7 +201,7 @@ void NeuronInitGroupMerged::CurrentSource::generate(const BackendBase &backend, { genInitNeuronVarCode( backend, env, *this, ng, "CS" + std::to_string(getIndex()), - "$(num_neurons)", 0, modelMerged.getModel().getBatchSize()); + "num_neurons", 0, modelMerged.getModel().getBatchSize()); } @@ -253,7 +253,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir } genInitNeuronVarCode( - backend, groupEnv, *this, ng, fieldSuffix, "$(num_neurons)", 0, modelMerged.getModel().getBatchSize()); + backend, groupEnv, *this, ng, fieldSuffix, "num_neurons", 0, modelMerged.getModel().getBatchSize()); } //---------------------------------------------------------------------------- From 1d0293b3f9919a47b88a36373874e89ba02874fe Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sun, 2 Jul 2023 16:54:31 +0100 
From 1d0293b3f9919a47b88a36373874e89ba02874fe Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sun, 2 Jul 2023 16:54:31 +0100 Subject: [PATCH 306/725] condition of Conditional expression needs visiting during typechecking even though its type isn't super-relevant --- src/genn/genn/transpiler/typeChecker.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 75a54f88bb..8297a85ed9 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -420,6 +420,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Conditional &conditional) final { + evaluateType(conditional.getCondition()); const auto trueType = evaluateType(conditional.getTrue()); const auto falseType = evaluateType(conditional.getFalse()); if (trueType.isNumeric() && falseType.isNumeric()) { From b7fa3dce2061a1972e66b178c68ec55711bc7725 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sun, 2 Jul 2023 16:54:43 +0100 Subject: [PATCH 307/725] pow takes two arguments! --- src/genn/genn/code_generator/standardLibrary.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/code_generator/standardLibrary.cc b/src/genn/genn/code_generator/standardLibrary.cc index a09db367b7..aed9c0f758 100644 --- a/src/genn/genn/code_generator/standardLibrary.cc +++ b/src/genn/genn/code_generator/standardLibrary.cc @@ -63,7 +63,7 @@ const auto libraryTypes = initLibraryTypes( ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(exp), ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(expm1), ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(exp2), - ADD_ONE_ARG_FLOAT_DOUBLE_FUNC(pow), + ADD_TWO_ARG_FLOAT_DOUBLE_FUNC(pow), std::make_pair("scalbn", std::make_pair(Type::ResolvedType::createFunction(Type::Float, {Type::Float, Type::Int32}), "scalbn($(0), $(1))")), std::make_pair("scalbn", std::make_pair(Type::ResolvedType::createFunction(Type::Double, {Type::Double, Type::Int32}), "scalbn($(0), $(1))")), From 2c188be5e9ab9818789e03b428981f428319fc00 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sun, 2 Jul 2023 16:55:06 +0100 Subject: [PATCH 308/725] SynapseSparseInitGroups needs to be provided with indices earlier on --- src/genn/backends/single_threaded_cpu/backend.cc | 1 + src/genn/genn/code_generator/initGroupMerged.cc | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 5797ca2176..d58ef8e339 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1144,6 +1144,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // Get reference to group funcEnv.getStream() << "const auto *group = &mergedSynapseSparseInitGroup" << s.getIndex() << "[g]; " << std::endl; EnvironmentGroupMergedField groupEnv(funcEnv, s); + genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); // If postsynaptic learning is required, initially zero column lengths if (!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 32db799a6c..e814c90353 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -612,9 +612,7 @@ boost::uuids::detail::sha1::digest_type SynapseSparseInitGroupMerged::getHashDig void SynapseSparseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { // 
Create environment for group - EnvironmentGroupMergedField groupEnv(env, *this); - backend.genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); - genInitWUVarCode(backend, groupEnv, *this, "$(num_pre) * $(_row_stride)", modelMerged.getModel().getBatchSize(), + genInitWUVarCode(backend, env, *this, "$(num_pre) * $(_row_stride)", modelMerged.getModel().getBatchSize(), [&backend](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { backend.genSparseSynapseVariableRowInit(varInitEnv, handler); From 2fdf9c6e7ded164acb814fd475f141e0fc5371e5 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sun, 2 Jul 2023 19:44:12 +0100 Subject: [PATCH 309/725] fixed variable initialization parameters --- .../genn/genn/code_generator/environment.h | 70 +++++++++---------- .../genn/code_generator/initGroupMerged.cc | 8 +-- 2 files changed, 36 insertions(+), 42 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 8e8315bd61..2087a77a5d 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -564,51 +564,45 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase - void addVarInitParams(IsVarInitHeterogeneousFn isHeterogeneous, const std::string &fieldSuffix = "") + void addVarInitParams(IsVarInitHeterogeneousFn isHeterogeneous, + const std::string &varName, const std::string &fieldSuffix = "") { - // Loop through weight update model variables - const A archetypeAdaptor(getGroup().getArchetype()); - for(const auto &v : archetypeAdaptor.getDefs()) { - // Loop through parameters - for(const auto &p : archetypeAdaptor.getInitialisers().at(v.name).getParams()) { - // If parameter is heterogeneous, add scalar field - if(std::invoke(isHeterogeneous, getGroup(), v.name, p.first)) { - addScalar(p.first, v.name + fieldSuffix, - [p, v](const auto &g, size_t) - { - return A(g).getInitialisers().at(v.name).getParams().at(p.first); - }); - } - // Otherwise, just add a const-qualified scalar to the type environment with archetype value - else { - add(getGroup().getScalarType().addConst(), p.first, - writePreciseLiteral(p.second, getGroup().getScalarType())); - } + // Loop through parameters + for(const auto &p : A(getGroup().getArchetype()).getInitialisers().at(varName).getParams()) { + // If parameter is heterogeneous, add scalar field + if(std::invoke(isHeterogeneous, getGroup(), varName, p.first)) { + addScalar(p.first, varName + fieldSuffix, + [p, varName](const auto &g, size_t) + { + return A(g).getInitialisers().at(varName).getParams().at(p.first); + }); + } + // Otherwise, just add a const-qualified scalar to the type environment with archetype value + else { + add(getGroup().getScalarType().addConst(), p.first, + writePreciseLiteral(p.second, getGroup().getScalarType())); } } } template - void addVarInitDerivedParams(IsVarInitHeterogeneousFn isHeterogeneous, const std::string &fieldSuffix = "") + void addVarInitDerivedParams(IsVarInitHeterogeneousFn isHeterogeneous, + const std::string &varName, const std::string &fieldSuffix = "") { - // Loop through weight update model variables - const A archetypeAdaptor(getGroup().getArchetype()); - for(const auto &v : archetypeAdaptor.getDefs()) { - // Loop through parameters - for(const auto &p : archetypeAdaptor.getInitialisers().at(v.name).getDerivedParams()) { - // If derived parameter is heterogeneous, add scalar field - if(std::invoke(isHeterogeneous, getGroup(), v.name, 
p.first)) { - addScalar(p.first, v.name + fieldSuffix, - [p, v](const auto &g, size_t) - { - return A(g).getInitialisers().at(v.name).getDerivedParams().at(p.first); - }); - } - // Otherwise, just add a const-qualified scalar to the type environment with archetype value - else { - add(getGroup().getScalarType().addConst(), p.first, - writePreciseLiteral(p.second, getGroup().getScalarType())); - } + // Loop through derived parameters + for(const auto &p : A(getGroup().getArchetype()).getInitialisers().at(varName).getDerivedParams()) { + // If derived parameter is heterogeneous, add scalar field + if(std::invoke(isHeterogeneous, getGroup(), varName, p.first)) { + addScalar(p.first, varName + fieldSuffix, + [p, varName](const auto &g, size_t) + { + return A(g).getInitialisers().at(varName).getDerivedParams().at(p.first); + }); + } + // Otherwise, just add a const-qualified scalar to the type environment with archetype value + else { + add(getGroup().getScalarType().addConst(), p.first, + writePreciseLiteral(p.second, getGroup().getScalarType())); } } } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index e814c90353..762175629f 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -76,8 +76,8 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e // Substitute in parameters and derived parameters for initialising variables EnvironmentGroupMergedField varEnv(env, group, fieldGroup); - varEnv.addVarInitParams(&G::isVarInitParamHeterogeneous, fieldSuffix); - varEnv.addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, fieldSuffix); + varEnv.addVarInitParams(&G::isVarInitParamHeterogeneous, var.name, fieldSuffix); + varEnv.addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, var.name, fieldSuffix); varEnv.addExtraGlobalParams(snippet->getExtraGlobalParams(), backend.getDeviceVarPrefix(), var.name, fieldSuffix); // Add field for variable itself @@ -159,8 +159,8 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, // Substitute in parameters and derived parameters for initialising variables EnvironmentGroupMergedField varEnv(env, group); - varEnv.addVarInitParams(&G::isVarInitParamHeterogeneous); - varEnv.addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous); + varEnv.addVarInitParams(&G::isVarInitParamHeterogeneous, var.name); + varEnv.addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, var.name); varEnv.addExtraGlobalParams(snippet->getExtraGlobalParams(), backend.getDeviceVarPrefix(), var.name); // Add field for variable itself From 7eabfc85bbd7b430f1881539070e2461e10ba094 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sun, 2 Jul 2023 19:49:20 +0100 Subject: [PATCH 310/725] missing $ --- src/genn/genn/code_generator/synapseUpdateGroupMerged.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 94b1d960d9..7df932b4c5 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -236,7 +236,7 @@ std::string SynapseGroupMergedBase::getPostDenDelayIndex(unsigned int batchSize, return "(*$(_den_delay_ptr) * $(num_post) + " + batchID; } else { - return "(((*(_den_delay_ptr) + " + offset + ") % " + std::to_string(getArchetype().getMaxDendriticDelayTimesteps()) + ") 
* $(num_post)) + " + batchID; + return "(((*$(_den_delay_ptr) + " + offset + ") % " + std::to_string(getArchetype().getMaxDendriticDelayTimesteps()) + ") * $(num_post)) + " + batchID; } } //----------------------------------------------------------------------------
From 6b7e807f85379667c46200becadfa94641fccb99 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 3 Jul 2023 17:22:25 +0100 Subject: [PATCH 311/725] removed separate scanning+parsing+typechecking functions --- .../genn/genn/code_generator/codeGenUtils.h | 16 ------ src/genn/genn/code_generator/codeGenUtils.cc | 51 ++++++------------- 2 files changed, 15 insertions(+), 52 deletions(-)
diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 99e6ea6a60..bc2f7e056d 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -63,22 +63,6 @@ GENN_EXPORT std::string disambiguateNamespaceFunction(const std::string supportC //-------------------------------------------------------------------------- GENN_EXPORT std::string upgradeCodeString(const std::string &codeString); -//-------------------------------------------------------------------------- -/*! \brief This function uses the transpiler to scan, parse and type check statements contained in a code string */ //-------------------------------------------------------------------------- GENN_EXPORT std::tuple scanParseAndTypeCheckStatements( const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment, Transpiler::ErrorHandlerBase &errorHandler, Transpiler::TypeChecker::StatementHandler forEachSynapseHandler = nullptr); -//-------------------------------------------------------------------------- -/*! \brief This function uses the transpiler to scan, parse and type check expression contained in a code string */ //-------------------------------------------------------------------------- GENN_EXPORT std::tuple scanParseAndTypeCheckExpression( const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment, Transpiler::ErrorHandlerBase &errorHandler); - //-------------------------------------------------------------------------- /*! 
\brief This function uses the transpiler to scan, parse, type check and pretty print expression contained in a code string */ diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index a10e4fc76d..fdb4bd5a05 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -124,9 +124,7 @@ std::string upgradeCodeString(const std::string &codeString) return upgradedCodeString; } //---------------------------------------------------------------------------- -std::tuple scanParseAndTypeCheckStatements( - const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment, - Transpiler::ErrorHandlerBase &errorHandler, Transpiler::TypeChecker::StatementHandler forEachSynapseHandler) +void prettyPrintExpression(const std::string &code, const Type::TypeContext &typeContext, EnvironmentExternalBase &env, Transpiler::ErrorHandlerBase &errorHandler) { using namespace Transpiler; @@ -136,55 +134,36 @@ std::tuple scanParseAndTypeCheckExpression( - const std::string &code, const Type::TypeContext &typeContext, Transpiler::TypeChecker::EnvironmentBase &environment, Transpiler::ErrorHandlerBase &errorHandler) + //-------------------------------------------------------------------------- +void prettyPrintStatements(const std::string &code, const Type::TypeContext &typeContext, EnvironmentExternalBase &env, + Transpiler::ErrorHandlerBase &errorHandler, Transpiler::TypeChecker::StatementHandler forEachSynapseTypeCheckHandler, + Transpiler::PrettyPrinter::StatementHandler forEachSynapsePrettyPrintHandler) { using namespace Transpiler; - + // Upgrade code string const std::string upgradedCode = upgradeCodeString(code); // Scan code string to convert to tokens const auto tokens = Scanner::scanSource(upgradedCode, typeContext, errorHandler); - // Parse tokens as expression - auto expression = Parser::parseExpression(tokens, typeContext, errorHandler); + // Parse tokens as block item list (function body) + auto updateStatements = Parser::parseBlockItemList(tokens, typeContext, errorHandler); // Resolve types - auto resolvedTypes= TypeChecker::typeCheck(expression.get(), environment, errorHandler); - - // Move into tuple and eturn - return std::make_tuple(std::move(expression), std::move(resolvedTypes)); -} -//---------------------------------------------------------------------------- -void prettyPrintExpression(const std::string &code, const Type::TypeContext &typeContext, EnvironmentExternalBase &env, Transpiler::ErrorHandlerBase &errorHandler) -{ - // Scan, parse and type check expression - auto expressionTypes = scanParseAndTypeCheckExpression(code, typeContext, env, errorHandler); - - // Pretty print - Transpiler::PrettyPrinter::print(std::get<0>(expressionTypes), env, typeContext, std::get<1>(expressionTypes)); -} - //-------------------------------------------------------------------------- -void prettyPrintStatements(const std::string &code, const Type::TypeContext &typeContext, EnvironmentExternalBase &env, - Transpiler::ErrorHandlerBase &errorHandler, Transpiler::TypeChecker::StatementHandler forEachSynapseTypeCheckHandler, - Transpiler::PrettyPrinter::StatementHandler forEachSynapsePrettyPrintHandler) -{ - // Scan, parse and type check statements - auto statementTypes = scanParseAndTypeCheckStatements(code, typeContext, env, errorHandler, forEachSynapseTypeCheckHandler); + auto resolvedTypes= TypeChecker::typeCheck(updateStatements, env, errorHandler, 
forEachSynapseTypeCheckHandler); // Pretty print - Transpiler::PrettyPrinter::print(std::get<0>(statementTypes), env, typeContext, std::get<1>(statementTypes), forEachSynapsePrettyPrintHandler); + PrettyPrinter::print(updateStatements, env, typeContext, resolvedTypes, forEachSynapsePrettyPrintHandler); } //-------------------------------------------------------------------------- std::string printSubs(const std::string &format, EnvironmentExternalBase &env) From 0cc709fa28de3556256b6173d54f72d707a29cf4 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 4 Jul 2023 13:02:24 +0100 Subject: [PATCH 312/725] started fixing up GCC compilation --- .../genn/genn/code_generator/backendBase.h | 1 + .../genn/genn/code_generator/groupMerged.h | 41 ++++++++++--------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index d57235ae39..d614f026a7 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -18,6 +18,7 @@ // GeNN includes #include "gennExport.h" #include "gennUtils.h" +#include "synapseMatrixType.h" #include "type.h" #include "varAccess.h" #include "variableMode.h" diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 8e6502d018..3a9a61c3e4 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -195,7 +195,7 @@ template class GroupMerged : public ChildGroupMerged { public: - GroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> groups) + GroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> groups) : ChildGroupMerged(index, typeContext, std::move(groups)) {} @@ -209,17 +209,17 @@ class GroupMerged : public ChildGroupMerged const std::string &getMemorySpace() const { return m_MemorySpace; } //! Get group fields - const std::vector &getFields() const{ return m_Fields; } + const std::vector::Field> &getFields() const{ return m_Fields; } //! Get group fields, sorted into order they will appear in struct - std::vector getSortedFields(const BackendBase &backend) const + std::vector::Field> getSortedFields(const BackendBase &backend) const { // Make a copy of fields and sort so largest come first. This should mean that due // to structure packing rules, significant memory is saved and estimate is more precise auto sortedFields = m_Fields; const size_t pointerBytes = backend.getPointerBytes(); std::sort(sortedFields.begin(), sortedFields.end(), - [pointerBytes](const Field &a, const Field &b) + [pointerBytes](const auto &a, const auto &b) { return (std::get<0>(a).getSize(pointerBytes) > std::get<0>(b).getSize(pointerBytes)); }); @@ -230,7 +230,7 @@ class GroupMerged : public ChildGroupMerged //! 
Generate declaration of struct to hold this merged group void generateStruct(CodeStream &os, const BackendBase &backend, const std::string &name, bool host = false) const { - os << "struct Merged" << name << "Group" << getIndex() << std::endl; + os << "struct Merged" << name << "Group" << this->getIndex() << std::endl; { // Loop through fields and write to structure CodeStream::Scope b(os); @@ -291,7 +291,7 @@ class GroupMerged : public ChildGroupMerged // Add total size of array of merged structures to merged struct data // **NOTE** to match standard struct packing rules we pad to a multiple of the largest field size - return padSize(structSize, largestFieldSize) * getGroups().size(); + return padSize(structSize, largestFieldSize) * this->getGroups().size(); } //! Assign memory spaces to group @@ -320,9 +320,10 @@ class GroupMerged : public ChildGroupMerged } } - void addField(const Type::ResolvedType &type, const std::string &name, GetFieldValueFunc getFieldValue, GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) + void addField(const Type::ResolvedType &type, const std::string &name, typename ChildGroupMerged<G>::GetFieldValueFunc getFieldValue, + GroupMergedFieldType fieldType = GroupMergedFieldType::STANDARD) { // Add field to data structure m_Fields.emplace_back(type, name, getFieldValue, fieldType); } @@ -341,7 +342,7 @@ class GroupMerged : public ChildGroupMerged // If this isn't a host merged structure, generate definition for function to push group if(!host) { - definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << "Group" << getIndex() << "ToDevice(unsigned int idx, "; + definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << "Group" << this->getIndex() << "ToDevice(unsigned int idx, "; generateStructFieldArgumentDefinitions(definitionsInternalFunc, backend); definitionsInternalFunc << ");" << std::endl; } @@ -350,7 +351,7 @@ class GroupMerged : public ChildGroupMerged for(const auto &f : sortedFields) { // If this field is a dynamic pointer if((std::get<3>(f) & GroupMergedFieldType::DYNAMIC) && std::get<0>(f).isPointer()) { - definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << getIndex() << std::get<1>(f) << "ToDevice(unsigned int idx, "; + definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << this->getIndex() << std::get<1>(f) << "ToDevice(unsigned int idx, "; definitionsInternalFunc << backend.getMergedGroupFieldHostTypeName(std::get<0>(f)) << " value);" << std::endl; } @@ -365,23 +366,23 @@ class GroupMerged : public ChildGroupMerged generateStruct(definitionsInternal, backend, name, true); // Declare array of these structs containing individual neuron group pointers etc - runnerVarDecl << "Merged" << name << "Group" << getIndex() << " merged" << name << "Group" << getIndex() << "[" << getGroups().size() << "];" << std::endl; + runnerVarDecl << "Merged" << name << "Group" << this->getIndex() << " merged" << name << "Group" << this->getIndex() << "[" << this->getGroups().size() << "];" << std::endl; // Export it - definitionsInternalVar << "EXPORT_VAR Merged" << name << "Group" << getIndex() << " merged" << name << "Group" << getIndex() << "[" << getGroups().size() << "]; " << std::endl; + definitionsInternalVar << "EXPORT_VAR Merged" << name << "Group" << this->getIndex() << " merged" << name << "Group" << this->getIndex() << "[" << this->getGroups().size() << "]; " << std::endl; } // Loop through groups - for(size_t groupIndex = 0; groupIndex < getGroups().size(); groupIndex++) { + for(size_t groupIndex = 0; groupIndex < this->getGroups().size(); groupIndex++) { // If this is a merged group used on the host, directly set array entry if(host) { - runnerMergedStructAlloc << "merged" << name << "Group" << getIndex() << "[" << groupIndex << "] = {"; + runnerMergedStructAlloc << "merged" << name << "Group" << this->getIndex() << "[" << groupIndex << "] = {"; generateStructFieldArguments(runnerMergedStructAlloc, groupIndex, sortedFields); runnerMergedStructAlloc << "};" << std::endl; } // Otherwise, call function to push to device else { - runnerMergedStructAlloc << "pushMerged" << name << "Group" << getIndex() << "ToDevice(" << groupIndex << ", "; + runnerMergedStructAlloc << "pushMerged" << name << "Group" << this->getIndex() << "ToDevice(" << groupIndex << ", "; generateStructFieldArguments(runnerMergedStructAlloc, groupIndex, sortedFields); runnerMergedStructAlloc << ");" << std::endl; } @@ -393,10 +394,10 @@ class GroupMerged : public ChildGroupMerged // Private methods //------------------------------------------------------------------------ void generateStructFieldArguments(CodeStream &os, size_t groupIndex, - const std::vector<Field> &sortedFields) const + const std::vector<typename ChildGroupMerged<G>::Field> &sortedFields) const { // Get group by index - const auto &g = getGroups()[groupIndex]; + const auto &g = this->getGroups()[groupIndex]; // Loop through fields for(size_t fieldIndex = 0; fieldIndex < sortedFields.size(); fieldIndex++) { @@ -413,7 +414,7 @@ class GroupMerged : public ChildGroupMerged // Members //------------------------------------------------------------------------ std::string m_MemorySpace; - std::vector<Field> m_Fields; + std::vector<typename ChildGroupMerged<G>::Field> m_Fields; }; //---------------------------------------------------------------------------- @@ -431,14 +432,14 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged void orderNeuronGroupChildren(std::vector &childGroups, const Type::TypeContext &typeContext, G getVectorFunc, H getHashDigestFunc) const { - const std::vector &archetypeChildren = (getArchetype().*getVectorFunc)(); + const std::vector &archetypeChildren = std::invoke(getVectorFunc, getArchetype()); // Resize vector of vectors to hold children for all neuron groups, sorted in a consistent manner std::vector>> sortedGroupChildren; sortedGroupChildren.resize(archetypeChildren.size()); // Create temporary vector of children and their digests - std::vector> childDigests; + std::vector> childDigests; childDigests.reserve(archetypeChildren.size()); // Loop through groups
From fa72bbd678fff82eabfe36760ef795eb6a256d65 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 4 Jul 2023 14:10:21 +0100 Subject: [PATCH 313/725] updated implementation of various methods which search for function calls - this is a hack, should be driven by parser --- src/genn/genn/gennUtils.cc | 13 +++---------- src/genn/genn/synapseGroup.cc | 25 ++++++++++++++++--------- 2 files changed, 19 insertions(+), 19 deletions(-)
diff --git a/src/genn/genn/gennUtils.cc b/src/genn/genn/gennUtils.cc index e944c97369..83e8796593 100644 --- a/src/genn/genn/gennUtils.cc +++ b/src/genn/genn/gennUtils.cc @@ -44,19 +44,12 @@ namespace GeNN::Utils bool isRNGRequired(const std::string &code) { // Loop through random functions + // **TODO** regex followed by optional whitespace and ( would be better for(const auto &r : randomFuncs) { - // If this function takes no arguments, return true if - // generic function name enclosed in $() markers is found - if(r.numArguments == 0) { - if(code.find("$(" + r.genericName + ")") != std::string::npos) { return true; } } - // Otherwise, return true if generic function name - // prefixed by $( and suffixed with comma is found - else if(code.find("$(" + r.genericName + ",") != std::string::npos) { + if(code.find(r.genericName) != std::string::npos) { return true; } + } return false;
diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc index ef93d63d15..da9835d9d1 100644 --- a/src/genn/genn/synapseGroup.cc +++ b/src/genn/genn/synapseGroup.cc @@ -308,17 +308,20 @@ VarLocation SynapseGroup::getSparseConnectivityExtraGlobalParamLocation(const st bool SynapseGroup::isDendriticDelayRequired() const { // If addToPostDelay function is used in sim code, return true - if(getWUModel()->getSimCode().find("$(addToInSynDelay") != std::string::npos) { + // **TODO** regex followed by optional whitespace and ( would be better + if(getWUModel()->getSimCode().find("addToPostDelay") != std::string::npos) { return true; } // If addToPostDelay function is used in event code, return true - if(getWUModel()->getEventCode().find("$(addToInSynDelay") != std::string::npos) { + // **TODO** regex followed by optional whitespace and ( would be better + if(getWUModel()->getEventCode().find("addToPostDelay") != std::string::npos) { return true; } - // If addToPostDelay function is used in synapse dynamics, return true - if(getWUModel()->getSynapseDynamicsCode().find("$(addToInSynDelay") != std::string::npos) { + // **TODO** regex followed by optional whitespace and ( would be better + if(getWUModel()->getSynapseDynamicsCode().find("addToPostDelay") != std::string::npos) { return true; } @@ -328,22 +331,26 @@ bool SynapseGroup::isDendriticDelayRequired() const bool SynapseGroup::isPresynapticOutputRequired() const { // If addToPre function is used in sim_code, return true - if(getWUModel()->getSimCode().find("$(addToPre") != std::string::npos) { + // **TODO** regex followed by optional whitespace and ( would be better + if(getWUModel()->getSimCode().find("addToPre") != std::string::npos) { return true; } // If addToPre function is used in learn_post_code, return true - if(getWUModel()->getLearnPostCode().find("$(addToPre") != std::string::npos) { + // **TODO** regex followed by optional whitespace and ( would be better + if(getWUModel()->getLearnPostCode().find("addToPre") != std::string::npos) { return true; } // If addToPre function is used in event_code, return true - if(getWUModel()->getEventCode().find("$(addToPre") != std::string::npos) { + // **TODO** regex followed by optional whitespace and ( would be better + if(getWUModel()->getEventCode().find("addToPre") != std::string::npos) { return true; } // If addToPre function is used in synapse_dynamics, return true - if(getWUModel()->getSynapseDynamicsCode().find("$(addToPre") != std::string::npos) { + // **TODO** regex followed by optional whitespace and ( would be better + if(getWUModel()->getSynapseDynamicsCode().find("addToPre") != std::string::npos) { return true; } @@ -395,7 +402,7 @@ bool SynapseGroup::isWUPostInitRNGRequired() const //---------------------------------------------------------------------------- bool SynapseGroup::isHostInitRNGRequired() const { - return (m_SparseConnectivityInitialiser.getSnippet()->getHostInitCode().find("$(rng)") != std::string::npos); + return Utils::isRNGRequired(m_SparseConnectivityInitialiser.getSnippet()->getHostInitCode()); } //---------------------------------------------------------------------------- bool SynapseGroup::isWUVarInitRequired
From dff4824eaeb25f436e61d843714bed38a64ace9e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 4 Jul 2023 16:18:06 +0100 Subject: [PATCH 314/725] token-based functions for determining if identifiers are referenced and whether any of the RNG functions are referenced --- include/genn/genn/gennUtils.h | 12 +++++-- src/genn/genn/gennUtils.cc | 65 +++++++++++++++-------------- 2 files changed, 37 insertions(+), 40 deletions(-)
diff --git a/include/genn/genn/gennUtils.h b/include/genn/genn/gennUtils.h index d2b6cd0763..8cdae6fcbe 100644 --- a/include/genn/genn/gennUtils.h +++ b/include/genn/genn/gennUtils.h @@ -21,6 +21,9 @@ // GeNN includes #include "gennExport.h" +// GeNN code generator includes +#include "transpiler/token.h" + // Forward declarations namespace GeNN::Models { @@ -35,12 +38,17 @@ namespace GeNN::Utils //-------------------------------------------------------------------------- //! \brief Is the named identifier referenced anywhere in the token stream //-------------------------------------------------------------------------- -GENN_EXPORT bool isRNGRequired(const std::string &code); +GENN_EXPORT bool isIdentifierReferenced(const std::string &identifierName, const std::vector<Transpiler::Token> &tokens); + +//-------------------------------------------------------------------------- +//! \brief Does the code string contain any functions requiring random number generator +//-------------------------------------------------------------------------- +GENN_EXPORT bool isRNGRequired(const std::vector<Transpiler::Token> &tokens); //-------------------------------------------------------------------------- //! \brief Does the model with the vectors of variable initialisers and modes require an RNG for the specified init location i.e. host or device //-------------------------------------------------------------------------- -GENN_EXPORT bool isRNGRequired(const std::unordered_map<std::string, Models::VarInit> &varInitialisers); +GENN_EXPORT bool isRNGRequired(const std::unordered_map<std::string, std::vector<Transpiler::Token>> &varInitialisers); //-------------------------------------------------------------------------- //! \brief Is the variable name valid? GeNN variable names must obey C variable naming rules
diff --git a/src/genn/genn/gennUtils.cc b/src/genn/genn/gennUtils.cc index 83e8796593..66bacf72ee 100644 --- a/src/genn/genn/gennUtils.cc +++ b/src/genn/genn/gennUtils.cc @@ -11,29 +11,13 @@ namespace { -//-------------------------------------------------------------------------- -// GenericFunction -//-------------------------------------------------------------------------- -//! Immutable structure for specifying the name and number of -//! arguments of a generic function e.g. gennrand_uniform -struct GenericFunction -{ - //! Generic name used to refer to function in user code - const std::string genericName; - - //! Number of function arguments - const unsigned int numArguments; -}; - -GenericFunction randomFuncs[] = { - {"gennrand_uniform", 0}, - {"gennrand_normal", 0}, - {"gennrand_exponential", 0}, - {"gennrand_log_normal", 2}, - {"gennrand_gamma", 1}, - {"gennrand_binomial", 2} -}; +const std::unordered_set<std::string> randomFuncs{ + "gennrand_uniform", + "gennrand_normal", + "gennrand_exponential", + "gennrand_log_normal", + "gennrand_gamma", + "gennrand_binomial"}; } //-------------------------------------------------------------------------- @@ -41,28 +25,33 @@ GenericFunction randomFuncs[] = { //-------------------------------------------------------------------------- namespace GeNN::Utils { -bool isRNGRequired(const std::string &code) +bool isIdentifierReferenced(const std::string &identifierName, const std::vector<Transpiler::Token> &tokens) { - // Loop through random functions - // **TODO** regex followed by optional whitespace and ( would be better - for(const auto &r : randomFuncs) { - if(code.find(r.genericName) != std::string::npos) { - return true; - } - - } - return false; + // Return true if any identifier's lexemes match identifier name + return std::any_of(tokens.cbegin(), tokens.cend(), + [&identifierName](const auto &t) + { + return (t.type == Transpiler::Token::Type::IDENTIFIER && t.lexeme == identifierName); + }); + +} +//-------------------------------------------------------------------------- +bool isRNGRequired(const std::vector<Transpiler::Token> &tokens) +{ + // Return true if any identifier's lexemes are in set of random functions + return std::any_of(tokens.cbegin(), tokens.cend(), + [](const auto &t) + { + return (t.type == Transpiler::Token::Type::IDENTIFIER && randomFuncs.find(t.lexeme) != randomFuncs.cend()); + }); } //-------------------------------------------------------------------------- -bool isRNGRequired(const std::unordered_map<std::string, Models::VarInit> &varInitialisers) +bool isRNGRequired(const std::unordered_map<std::string, std::vector<Transpiler::Token>> &varInitialisers) { // Return true if any of these variable initialisers require an RNG return std::any_of(varInitialisers.cbegin(), varInitialisers.cend(), - [](const auto &varInit) - { - return isRNGRequired(varInit.second.getSnippet()->getCode()); - }); + [](const auto &varInit) { return isRNGRequired(varInit.second); }); } //-------------------------------------------------------------------------- void validateVarName(const std::string &name, const std::string &description)
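The natural counterpart, which the next patch starts wiring in, is to scan each code string once during finalise and keep the token vector around for repeated queries. A self-contained toy version of that pipeline — a deliberately naive identifier-only scanner standing in for GeNN's Transpiler::Scanner, which additionally handles literals, operators and comments:

#include <algorithm>
#include <cctype>
#include <string>
#include <unordered_set>
#include <vector>

// Naive stand-in for Transpiler::Scanner::scanSource: extracts identifier
// lexemes only, which is all the RNG query below needs
std::vector<std::string> scanIdentifiers(const std::string &code)
{
    std::vector<std::string> identifiers;
    std::string current;
    auto flush = [&]()
    {
        // Identifiers cannot start with a digit
        if(!current.empty() && !std::isdigit(static_cast<unsigned char>(current.front()))) {
            identifiers.push_back(current);
        }
        current.clear();
    };
    for(char c : code) {
        if(std::isalnum(static_cast<unsigned char>(c)) || c == '_') {
            current += c;
        }
        else {
            flush();
        }
    }
    flush();
    return identifiers;
}

// Query the scanned tokens rather than re-searching the raw string
bool isRNGRequired(const std::vector<std::string> &identifiers)
{
    static const std::unordered_set<std::string> randomFuncs{
        "gennrand_uniform", "gennrand_normal", "gennrand_exponential",
        "gennrand_log_normal", "gennrand_gamma", "gennrand_binomial"};
    return std::any_of(identifiers.cbegin(), identifiers.cend(),
                       [](const std::string &id){ return randomFuncs.count(id) != 0; });
}

With this split, a group's finalise can invoke the scanner once and every later isRNGRequired or isIdentifierReferenced query becomes a cheap lookup over the stored tokens.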
From 966ecb17ea6d0fc88aedf8feb9f08106b78c3d95 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 4 Jul 2023 17:05:26 +0100 Subject: [PATCH 315/725] started hooking up scanning of tokens in group finalise functions --- include/genn/genn/gennUtils.h | 4 ++ include/genn/genn/modelSpec.h | 2 + include/genn/genn/models.h | 11 ++++ include/genn/genn/neuronGroup.h | 38 +++++++++---- include/genn/genn/neuronGroupInternal.h | 2 +- include/genn/genn/snippet.h | 15 ++--- src/genn/genn/customConnectivityUpdate.cc | 2 +- src/genn/genn/gennUtils.cc | 49 ++++++++++++++++ src/genn/genn/initVarSnippet.cc | 3 +- src/genn/genn/modelSpec.cc | 11 +++- src/genn/genn/models.cc | 7 +++ src/genn/genn/neuronGroup.cc | 68 +++++++++++++---------- 12 files changed, 160 insertions(+), 52 deletions(-)
diff --git a/include/genn/genn/gennUtils.h b/include/genn/genn/gennUtils.h index 8cdae6fcbe..2e76df47de 100644 --- a/include/genn/genn/gennUtils.h +++ b/include/genn/genn/gennUtils.h @@ -20,6 +20,7 @@ // GeNN includes #include "gennExport.h" +#include "type.h" // GeNN code generator includes #include "transpiler/token.h" // Forward declarations namespace GeNN::Models { @@ -35,6 +36,9 @@ class VarInit; //-------------------------------------------------------------------------- namespace GeNN::Utils { +GENN_EXPORT std::vector<Transpiler::Token> scanCode(const std::string &code, const Type::TypeContext &typeContext, + const std::string &errorContext); + //-------------------------------------------------------------------------- //! \brief Is the named identifier referenced anywhere in the token stream //--------------------------------------------------------------------------
diff --git a/include/genn/genn/modelSpec.h b/include/genn/genn/modelSpec.h index 5c7e502af3..e886257537 100644 --- a/include/genn/genn/modelSpec.h +++ b/include/genn/genn/modelSpec.h @@ -675,6 +675,8 @@ class GENN_EXPORT ModelSpec //! Get hash digest used for detecting changes boost::uuids::detail::sha1::digest_type getHashDigest() const; + Type::TypeContext getTypeContext() const; + //! Get std::map containing local named NeuronGroup objects in model const std::map &getNeuronGroups() const{ return m_LocalNeuronGroups; }
diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 1df29d2c24..927447f260 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -136,6 +136,17 @@ class VarInit : public Snippet::Init : Snippet::Init(InitVarSnippet::Constant::getInstance(), {{"constant", constant}}) { } + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + void finalise(double dt, const Type::TypeContext &context, const std::string &errorContext); + +private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + std::vector<Transpiler::Token> m_CodeTokens; }; //----------------------------------------------------------------------------
diff --git a/include/genn/genn/neuronGroup.h b/include/genn/genn/neuronGroup.h index 8425f65077..edd7417954 100644 --- a/include/genn/genn/neuronGroup.h +++ b/include/genn/genn/neuronGroup.h @@ -176,15 +176,6 @@ class GENN_EXPORT NeuronGroup //! Is spike event recording enabled for this population? bool isSpikeEventRecordingEnabled() const { return m_SpikeEventRecordingEnabled; } - //! Does this neuron group require an RNG to simulate? - bool isSimRNGRequired() const; - - //! Does this neuron group require an RNG for its init code? - bool isInitRNGRequired() const; - - //! Does this neuron group require any sort of recording? - bool isRecordingEnabled() const; - protected: NeuronGroup(const std::string &name, int numNeurons, const NeuronModels::Base *neuronModel, const std::unordered_map &params, const std::unordered_map &varInitialisers, @@ -207,7 +198,7 @@ class GENN_EXPORT NeuronGroup void addInSyn(SynapseGroupInternal *synapseGroup){ m_InSyn.push_back(synapseGroup); } void addOutSyn(SynapseGroupInternal *synapseGroup){ m_OutSyn.push_back(synapseGroup); } - void initDerivedParams(double dt); + void finalise(double dt, const Type::TypeContext &context); //! Fuse incoming postsynaptic models void fusePrePostSynapses(bool fusePSM, bool fusePrePostWUM); @@ -228,6 +219,15 @@ class GENN_EXPORT NeuronGroup const std::vector &getFusedWUPreOutSyn() const { return m_FusedWUPreOutSyn; } const std::vector &getFusedPreOutputOutSyn() const { return m_FusedPreOutputOutSyn; } + //! Does this neuron group require an RNG to simulate? + bool isSimRNGRequired() const; + + //! Does this neuron group require an RNG for its init code? + bool isInitRNGRequired() const; + + //! Does this neuron group require any sort of recording? + bool isRecordingEnabled() const; + //! Gets pointers to all current sources which provide input to this neuron group const std::vector &getCurrentSources() const { return m_MergedCurrentSourceGroups; } @@ -247,6 +247,15 @@ class GENN_EXPORT NeuronGroup //! Helper to get vector of outgoing synapse groups which have presynaptic variables std::vector getFusedOutSynWithPreVars() const; + //! Tokens produced by scanner from sim code + const std::vector<Transpiler::Token> &getSimCodeTokens() const { return m_SimCodeTokens; } + + //! Tokens produced by scanner from threshold condition code + const std::vector<Transpiler::Token> &getThresholdConditionCodeTokens() const { return m_ThresholdConditionCodeTokens; } + + //! Tokens produced by scanner from reset code + const std::vector<Transpiler::Token> &getResetCodeTokens() const { return m_ResetCodeTokens; } + bool isVarQueueRequired(const std::string &var) const; bool isVarQueueRequired(size_t index) const{ return m_VarQueueRequired[index]; } @@ -319,6 +328,15 @@ class GENN_EXPORT NeuronGroup //! Location of extra global parameters std::vector m_ExtraGlobalParamLocation; + //! Tokens produced by scanner from sim code + std::vector<Transpiler::Token> m_SimCodeTokens; + + //! Tokens produced by scanner from threshold condition code + std::vector<Transpiler::Token> m_ThresholdConditionCodeTokens; + + //! Tokens produced by scanner from reset code + std::vector<Transpiler::Token> m_ResetCodeTokens; + //! Is spike recording enabled for this population? bool m_SpikeRecordingEnabled;
diff --git a/include/genn/genn/neuronGroupInternal.h b/include/genn/genn/neuronGroupInternal.h index 22f8ace253..fd75b851a1 100644 --- a/include/genn/genn/neuronGroupInternal.h +++ b/include/genn/genn/neuronGroupInternal.h @@ -25,7 +25,7 @@ class NeuronGroupInternal : public NeuronGroup using NeuronGroup::addSpkEventCondition; using NeuronGroup::addInSyn; using NeuronGroup::addOutSyn; - using NeuronGroup::initDerivedParams; + using NeuronGroup::finalise; using NeuronGroup::fusePrePostSynapses; using NeuronGroup::injectCurrent; using NeuronGroup::getFusedPSMInSyn;
diff --git a/include/genn/genn/snippet.h b/include/genn/genn/snippet.h index e9a6e05a7b..854e224058 100644 --- a/include/genn/genn/snippet.h +++ b/include/genn/genn/snippet.h @@ -191,7 +191,14 @@ class Init const std::unordered_map &getParams() const{ return m_Params; } const std::unordered_map &getDerivedParams() const{ return m_DerivedParams; } - void initDerivedParams(double dt) + + boost::uuids::detail::sha1::digest_type getHashDigest() const + { + return getSnippet()->getHashDigest(); + } + +protected: + void finalise(double dt) { auto derivedParams = m_Snippet->getDerivedParams(); @@ -200,12 +207,6 @@ class Init m_DerivedParams.emplace(d.name, d.func(m_Params, dt)); } } - - boost::uuids::detail::sha1::digest_type getHashDigest() const - { - return getSnippet()->getHashDigest(); - } - private: //---------------------------------------------------------------------------- // Members
diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index e45292ed3d..b233674e13 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -97,7 +97,7 @@ bool CustomConnectivityUpdate::isRowSimRNGRequired() const //------------------------------------------------------------------------ bool CustomConnectivityUpdate::isHostRNGRequired() const { - return (getCustomConnectivityUpdateModel()->getHostUpdateCode().find("$(rng)") != std::string::npos); + return Utils::isRNGRequired(getCustomConnectivityUpdateModel()->getHostUpdateCode()); } //------------------------------------------------------------------------ CustomConnectivityUpdate::CustomConnectivityUpdate(const std::string &name, const std::string &updateGroupName, SynapseGroupInternal *synapseGroup,
diff --git a/src/genn/genn/gennUtils.cc b/src/genn/genn/gennUtils.cc index 66bacf72ee..683ba88db6 100644 --- a/src/genn/genn/gennUtils.cc +++ b/src/genn/genn/gennUtils.cc @@ -2,6 +2,7 @@ // Standard C++ includes #include <algorithm> +#include <regex> // Standard C includes #include @@ -9,6 +10,13 @@ // GeNN includes #include "models.h" +// GeNN transpiler includes +#include "transpiler/errorHandler.h" +#include "transpiler/scanner.h" + +//-------------------------------------------------------------------------- +// Anonymous namespace +//-------------------------------------------------------------------------- namespace { const std::unordered_set<std::string> randomFuncs{ @@ -18,13 +26,54 @@ const std::unordered_set<std::string> randomFuncs{ "gennrand_log_normal", "gennrand_gamma", "gennrand_binomial"}; + +std::string upgradeCodeString(const std::string &codeString) +{ + // Build vector of regular expressions to replace old style function calls + // **TODO** build from set of random functions + const std::vector<std::pair<std::regex, std::string>> functionReplacements{ + {std::regex(R"(\$\(gennrand_uniform\))"), "gennrand_uniform()"}, + {std::regex(R"(\$\(gennrand_normal\))"), "gennrand_normal()"}, + {std::regex(R"(\$\(gennrand_exponential\))"), "gennrand_exponential()"}, + {std::regex(R"(\$\(gennrand_log_normal,(.*)\))"), "gennrand_log_normal($1)"}, + {std::regex(R"(\$\(gennrand_gamma,(.*)\))"), "gennrand_gamma($1)"}, + {std::regex(R"(\$\(gennrand_binomial,(.*)\))"), "gennrand_binomial($1)"}, + {std::regex(R"(\$\(addSynapse,(.*)\))"), "addSynapse($1)"}, + {std::regex(R"(\$\(endRow\))"), "endRow()"}, + {std::regex(R"(\$\(endCol\))"), "endCol()"}}; + + // Apply substitutions to upgraded code string + std::string upgradedCodeString = codeString; + for(const auto &f : functionReplacements) { + upgradedCodeString = std::regex_replace(upgradedCodeString, f.first, f.second); + } + + // **TODO** snake-case -> camel case known built in variables e.g id_pre -> idPre + + // Replace old style $(XX) variables with plain XX + // **NOTE** this is done after functions as single-parameter function calls and variables were indistinguishable with old syntax + const std::regex variable(R"(\$\(([_a-zA-Z][_a-zA-Z0-9]*)\))"); + upgradedCodeString = std::regex_replace(upgradedCodeString, variable, "$1"); + return upgradedCodeString; } +} // Anonymous namespace //-------------------------------------------------------------------------- // GeNN::Utils //-------------------------------------------------------------------------- namespace GeNN::Utils { +std::vector<Transpiler::Token> scanCode(const std::string &code, const Type::TypeContext &typeContext, + const std::string &errorContext) +{ + // Upgrade code string + const std::string upgradedCode = upgradeCodeString(code); + + // Scan code string and return tokens + Transpiler::ErrorHandler errorHandler(errorContext); + return Transpiler::Scanner::scanSource(upgradedCode, typeContext, errorHandler); +} +//-------------------------------------------------------------------------- bool isIdentifierReferenced(const std::string &identifierName, const std::vector<Transpiler::Token> &tokens) { // Return true if any identifier's lexemes match identifier name
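To make the behaviour of upgradeCodeString concrete, here is a minimal, self-contained demonstration applying two representative rules from the table above — the zero-argument RNG call first, then the generic $(XX) variable rewrite, in that order for exactly the reason the **NOTE** gives:

#include <iostream>
#include <regex>
#include <string>

int main()
{
    // Two representative rules: old-style zero-argument RNG calls, then the
    // generic $(XX) -> XX variable rewrite (which would otherwise also match
    // the un-upgraded function call)
    const std::regex uniformCall(R"(\$\(gennrand_uniform\))");
    const std::regex variable(R"(\$\(([_a-zA-Z][_a-zA-Z0-9]*)\))");

    std::string code = "V = $(Vrest) + $(gennrand_uniform) * $(Vrange);";
    code = std::regex_replace(code, uniformCall, "gennrand_uniform()");
    code = std::regex_replace(code, variable, "$1");

    // Prints: V = Vrest + gennrand_uniform() * Vrange;
    std::cout << code << std::endl;
    return 0;
}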
diff --git a/src/genn/genn/initVarSnippet.cc b/src/genn/genn/initVarSnippet.cc index 3db718a5a7..3538fd7fe1 100644 --- a/src/genn/genn/initVarSnippet.cc +++ b/src/genn/genn/initVarSnippet.cc @@ -37,6 +37,7 @@ void Base::validate(const std::unordered_map &paramValues) //---------------------------------------------------------------------------- bool Base::requiresKernel() const { - return (getCode().find("$(id_kernel)") != std::string::npos); + // **TODO** regex followed by optional whitespace and ( would be better + return (getCode().find("id_kernel") != std::string::npos); } } // namespace GeNN::InitVarSnippet \ No newline at end of file
diff --git a/src/genn/genn/modelSpec.cc b/src/genn/genn/modelSpec.cc index a054b8fbfa..79289004b6 100644 --- a/src/genn/genn/modelSpec.cc +++ b/src/genn/genn/modelSpec.cc @@ -213,10 +213,10 @@ CustomUpdateWU *ModelSpec::addCustomUpdate(const std::string &name, const std::s // --------------------------------------------------------------------------- void ModelSpec::finalize() { - // NEURON GROUPS + // Finalise neuron groups + const auto typeContext = getTypeContext(); for(auto &n : m_LocalNeuronGroups) { - // Initialize derived parameters - n.second.initDerivedParams(m_DT); + n.second.finalise(m_DT, typeContext); } // SYNAPSE groups @@ -385,6 +385,11 @@ boost::uuids::detail::sha1::digest_type ModelSpec::getHashDigest() const return hash.get_digest(); } // --------------------------------------------------------------------------- +Type::TypeContext ModelSpec::getTypeContext() const +{ + return Type::TypeContext{{"scalar", getPrecision()}, {"timepoint", getTimePrecision()}}; +} +// --------------------------------------------------------------------------- NeuronGroupInternal *ModelSpec::findNeuronGroupInternal(const std::string &name) { // If a matching local neuron group is found, return it
diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index ad9b50f7b5..5f27ca3429 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -34,6 +34,13 @@ void Base::validate(const std::unordered_map &paramValues, Utils::validateInitialisers(vars, varValues, "variable", description); } +void VarInit::finalise(double dt, const Type::TypeContext &context, const std::string &errorContext) +{ + Snippet::Init::finalise(dt); + + m_CodeTokens = Utils::scanCode(getSnippet()->getCode(), context, errorContext); +} + //---------------------------------------------------------------------------- // VarReference //----------------------------------------------------------------------------
diff --git a/src/genn/genn/neuronGroup.cc b/src/genn/genn/neuronGroup.cc index 51b22baa6e..a1a86202e5 100644 --- a/src/genn/genn/neuronGroup.cc +++ b/src/genn/genn/neuronGroup.cc @@ -220,13 +220,35 @@ bool NeuronGroup::isZeroCopyEnabled() const return false; } + +//---------------------------------------------------------------------------- +bool NeuronGroup::isRecordingEnabled() const +{ + // Return true if spike recording is enabled + if(m_SpikeRecordingEnabled) { + return true; + } + + // Return true if spike event recording is enabled + if(m_SpikeEventRecordingEnabled) { + return true; + } + else { + return false; + } +} +//---------------------------------------------------------------------------- +void NeuronGroup::injectCurrent(CurrentSourceInternal *src) +{ + m_MergedCurrentSourceGroups.push_back(src); +} //---------------------------------------------------------------------------- bool NeuronGroup::isSimRNGRequired() const { // Returns true if any parts of the neuron code require an RNG - if(Utils::isRNGRequired(getNeuronModel()->getSimCode()) - || Utils::isRNGRequired(getNeuronModel()->getThresholdConditionCode()) - || Utils::isRNGRequired(getNeuronModel()->getResetCode())) + if(Utils::isRNGRequired(getSimCodeTokens()) + || Utils::isRNGRequired(getThresholdConditionCodeTokens()) + || Utils::isRNGRequired(getResetCodeTokens())) { return true; } @@ -282,27 +304,6 @@ bool NeuronGroup::isInitRNGRequired() const [](const SynapseGroupInternal *sg){ return sg->isPSInitRNGRequired(); }); } //---------------------------------------------------------------------------- -bool NeuronGroup::isRecordingEnabled() const -{ - // Return true if spike recording is enabled - if(m_SpikeRecordingEnabled) { - return true; - } - - // Return true if spike event recording is enabled - if(m_SpikeEventRecordingEnabled) { - return true; - } - else { - return false; - } -} -//---------------------------------------------------------------------------- -void NeuronGroup::injectCurrent(CurrentSourceInternal *src) -{ - m_MergedCurrentSourceGroups.push_back(src); -} -//---------------------------------------------------------------------------- NeuronGroup::NeuronGroup(const std::string &name, int numNeurons, const NeuronModels::Base *neuronModel, const std::unordered_map &params, const std::unordered_map &varInitialisers, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) @@ -334,7 +335,7 @@ void NeuronGroup::updatePostVarQueues(const std::string &code) updateVarQueues(code, "_post"); } //---------------------------------------------------------------------------- -void NeuronGroup::initDerivedParams(double dt) +void NeuronGroup::finalise(double dt, const Type::TypeContext &context) { auto derivedParams = getNeuronModel()->getDerivedParams(); @@ -343,10 +344,18 @@ m_DerivedParams.emplace(d.name, d.func(m_Params, dt)); } - // Initialise derived parameters for variable initialisers + // Finalise variable initialisers for(auto &v : m_VarInitialisers) { - v.second.initDerivedParams(dt); + v.second.finalise(dt, context, "Variable '" + v.first + "' initialisation code"); } + + // Scan neuron model code strings + m_SimCodeTokens = Utils::scanCode(getNeuronModel()->getSimCode(), context, + "Neuron group '" + getName() + "' sim code"); + m_ThresholdConditionCodeTokens = Utils::scanCode(getNeuronModel()->getThresholdConditionCode(), context, + "Neuron group '" + getName() + "' threshold condition code"); + m_ResetCodeTokens = Utils::scanCode(getNeuronModel()->getResetCode(), context, + "Neuron group '" + getName() + "' reset code"); } //---------------------------------------------------------------------------- void NeuronGroup::fusePrePostSynapses(bool fusePSM, bool fusePrePostWUM) @@ -452,7 +461,8 @@ std::vector NeuronGroup::getFusedOutSynWithPreVars() cons //---------------------------------------------------------------------------- void NeuronGroup::addSpkEventCondition(const std::string &code, SynapseGroupInternal *synapseGroup) { - const auto *wu = synapseGroup->getWUModel(); + assert(false); + /*const auto *wu = synapseGroup->getWUModel(); // Determine if any EGPs are required by threshold code const auto wuEGPs = wu->getExtraGlobalParams(); @@ -471,7 +481,7 @@ void NeuronGroup::addSpkEventCondition(const std::string &code, SynapseGroupInte }); // Add threshold, support code, synapse group and whether egps are required to set - m_SpikeEventCondition.emplace(code, wu->getSimSupportCode(), egpInThresholdCode || preVarInThresholdCode, synapseGroup); + m_SpikeEventCondition.emplace(code, wu->getSimSupportCode(), egpInThresholdCode || preVarInThresholdCode, synapseGroup);*/ } //---------------------------------------------------------------------------- bool NeuronGroup::isVarQueueRequired(const std::string &var) const
From 6436f442ac636fccc52b3eaffe7611e316e86a2b Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 4 Jul 2023 18:58:55 +0100 Subject: [PATCH 316/725] * scanning in finalise throughout * new helpers to check whether tokens are empty (the EOF token actually provides a very nice verification that we're not calling these pre-finalise) --- include/genn/genn/gennUtils.h | 4 +- .../genn/genn/initSparseConnectivitySnippet.h | 26 +- .../genn/initToeplitzConnectivitySnippet.h | 21 +- include/genn/genn/initVarSnippet.h | 3 - include/genn/genn/modelSpec.h | 2 +- include/genn/genn/modelSpecInternal.h | 2 +- include/genn/genn/models.h | 16 +- include/genn/genn/neuronGroup.h | 6 +- include/genn/genn/synapseGroup.h | 92 ++-- include/genn/genn/synapseGroupInternal.h | 3 +- src/genn/genn/gennUtils.cc | 22 +- .../genn/initSparseConnectivitySnippet.cc | 24 + .../genn/initToeplitzConnectivitySnippet.cc | 17 + src/genn/genn/initVarSnippet.cc | 6 - src/genn/genn/modelSpec.cc | 28 +- src/genn/genn/models.cc | 17 +- src/genn/genn/neuronGroup.cc | 26 +- src/genn/genn/synapseGroup.cc | 438 ++++++++++-------- 18 files changed, 447 insertions(+), 306 deletions(-)
diff --git a/include/genn/genn/gennUtils.h b/include/genn/genn/gennUtils.h index 2e76df47de..9a9518e2d8 100644 --- a/include/genn/genn/gennUtils.h +++ b/include/genn/genn/gennUtils.h @@ -39,6 +39,8 @@ namespace GeNN::Utils GENN_EXPORT std::vector<Transpiler::Token> scanCode(const std::string &code, const Type::TypeContext &typeContext, const std::string &errorContext); +GENN_EXPORT bool areTokensEmpty(const std::vector<Transpiler::Token> &tokens); + //-------------------------------------------------------------------------- //! \brief Is the named identifier referenced anywhere in the token stream //-------------------------------------------------------------------------- @@ -52,7 +54,7 @@ GENN_EXPORT bool isRNGRequired(const std::vector<Transpiler::Token> &tokens); //-------------------------------------------------------------------------- //! \brief Does the model with the vectors of variable initialisers and modes require an RNG for the specified init location i.e. host or device //-------------------------------------------------------------------------- -GENN_EXPORT bool isRNGRequired(const std::unordered_map<std::string, std::vector<Transpiler::Token>> &varInitialisers); +GENN_EXPORT bool isRNGRequired(const std::unordered_map<std::string, Models::VarInit> &varInitialisers); //-------------------------------------------------------------------------- //! \brief Is the variable name valid? 
GeNN variable names must obey C variable naming rules diff --git a/include/genn/genn/initSparseConnectivitySnippet.h b/include/genn/genn/initSparseConnectivitySnippet.h index a84e9fce07..610116515a 100644 --- a/include/genn/genn/initSparseConnectivitySnippet.h +++ b/include/genn/genn/initSparseConnectivitySnippet.h @@ -72,7 +72,31 @@ class GENN_EXPORT Base : public Snippet::Base //---------------------------------------------------------------------------- // Init //---------------------------------------------------------------------------- -using Init = Snippet::Init; +class Init : public Snippet::Init +{ +public: + using Snippet::Init::Init; + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + void finalise(double dt, const Type::TypeContext &context, const std::string &errorContext); + + bool isRNGRequired() const; + bool isHostRNGRequired() const; + + const std::vector &getRowBuildCodeTokens() const{ return m_RowBuildCodeTokens; } + const std::vector &getColBuildCodeTokens() const{ return m_ColBuildCodeTokens; } + const std::vector &getHostInitCodeTokens() const{ return m_HostInitCodeTokens; } + +private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + std::vector m_RowBuildCodeTokens; + std::vector m_ColBuildCodeTokens; + std::vector m_HostInitCodeTokens; +}; //---------------------------------------------------------------------------- // InitSparseConnectivitySnippet::Uninitialised diff --git a/include/genn/genn/initToeplitzConnectivitySnippet.h b/include/genn/genn/initToeplitzConnectivitySnippet.h index e3d6638444..e32e18f42c 100644 --- a/include/genn/genn/initToeplitzConnectivitySnippet.h +++ b/include/genn/genn/initToeplitzConnectivitySnippet.h @@ -58,7 +58,26 @@ class GENN_EXPORT Base : public Snippet::Base //---------------------------------------------------------------------------- // Init //---------------------------------------------------------------------------- -using Init = Snippet::Init; +class Init : public Snippet::Init +{ +public: + using Snippet::Init::Init; + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + void finalise(double dt, const Type::TypeContext &context, const std::string &errorContext); + + bool isRNGRequired() const; + + const std::vector &getDiagonalBuildCodeTokens() const{ return m_DiagonalBuildCodeTokens; } +private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + std::vector m_DiagonalBuildCodeTokens; + +}; //---------------------------------------------------------------------------- // GeNN::InitToeplitzConnectivitySnippet::Uninitialised diff --git a/include/genn/genn/initVarSnippet.h b/include/genn/genn/initVarSnippet.h index 2056c1e22d..015976dd98 100644 --- a/include/genn/genn/initVarSnippet.h +++ b/include/genn/genn/initVarSnippet.h @@ -30,9 +30,6 @@ class GENN_EXPORT Base : public Snippet::Base //! Validate names of parameters etc void validate(const std::unordered_map ¶mValues) const; - - //! 
Does this var init snippet require kernel-based connectivity - bool requiresKernel() const; }; //---------------------------------------------------------------------------- diff --git a/include/genn/genn/modelSpec.h b/include/genn/genn/modelSpec.h index e886257537..1b8608ba72 100644 --- a/include/genn/genn/modelSpec.h +++ b/include/genn/genn/modelSpec.h @@ -661,7 +661,7 @@ class GENN_EXPORT ModelSpec // Protected methods //-------------------------------------------------------------------------- //! Finalise model - void finalize(); + void finalise(); //-------------------------------------------------------------------------- // Protected const methods diff --git a/include/genn/genn/modelSpecInternal.h b/include/genn/genn/modelSpecInternal.h index 4420f3873c..3ae2aa8832 100644 --- a/include/genn/genn/modelSpecInternal.h +++ b/include/genn/genn/modelSpecInternal.h @@ -22,7 +22,7 @@ class ModelSpecInternal : public ModelSpec using ModelSpec::getCustomWUUpdates; using ModelSpec::getCustomConnectivityUpdates; - using ModelSpec::finalize; + using ModelSpec::finalise; using ModelSpec::zeroCopyInUse; using ModelSpec::isRecordingInUse; diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 927447f260..109fa2873b 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -127,21 +127,19 @@ class GENN_EXPORT Base : public Snippet::Base class VarInit : public Snippet::Init { public: - VarInit(const InitVarSnippet::Base *snippet, const std::unordered_map ¶ms) - : Snippet::Init(snippet, params) - { - } - - VarInit(double constant) - : Snippet::Init(InitVarSnippet::Constant::getInstance(), {{"constant", constant}}) - { - } + using Snippet::Init::Init; //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ void finalise(double dt, const Type::TypeContext &context, const std::string &errorContext); + bool isRNGRequired() const; + + bool isKernelRequired() const; + + const std::vector &getCodeTokens() const{ return m_CodeTokens; } + private: //------------------------------------------------------------------------ // Members diff --git a/include/genn/genn/neuronGroup.h b/include/genn/genn/neuronGroup.h index edd7417954..431cabfddf 100644 --- a/include/genn/genn/neuronGroup.h +++ b/include/genn/genn/neuronGroup.h @@ -188,10 +188,10 @@ class GENN_EXPORT NeuronGroup void checkNumDelaySlots(unsigned int requiredDelay); //! Update which presynaptic variables require queues based on piece of code - void updatePreVarQueues(const std::string &code); + void updatePreVarQueues(const std::vector &tokens); //! Update which postsynaptic variables require queues based on piece of code - void updatePostVarQueues(const std::string &code); + void updatePostVarQueues(const std::vector &tokens); void addSpkEventCondition(const std::string &code, SynapseGroupInternal *synapseGroup); @@ -278,7 +278,7 @@ class GENN_EXPORT NeuronGroup // Private methods //------------------------------------------------------------------------ //! 
Update which variables require queues based on piece of code - void updateVarQueues(const std::string &code, const std::string &suffix); + void updateVarQueues(const std::vector &tokens, const std::string &suffix); //------------------------------------------------------------------------ // Members diff --git a/include/genn/genn/synapseGroup.h b/include/genn/genn/synapseGroup.h index 58dd7a39f4..c1439ad848 100644 --- a/include/genn/genn/synapseGroup.h +++ b/include/genn/genn/synapseGroup.h @@ -201,36 +201,6 @@ class GENN_EXPORT SynapseGroup /*! This is only used by extra global parameters which are pointers*/ VarLocation getSparseConnectivityExtraGlobalParamLocation(const std::string &paramName) const; - //! Does this synapse group require dendritic delay? - bool isDendriticDelayRequired() const; - - //! Does this synapse group define presynaptic output? - bool isPresynapticOutputRequired() const; - - //! Does this synapse group require an RNG to generate procedural connectivity? - bool isProceduralConnectivityRNGRequired() const; - - //! Does this synapse group require an RNG for it's postsynaptic init code? - bool isPSInitRNGRequired() const; - - //! Does this synapse group require an RNG for it's weight update init code? - bool isWUInitRNGRequired() const; - - //! Does this synapse group require an RNG for it's weight update presynaptic variable init code? - bool isWUPreInitRNGRequired() const; - - //! Does this synapse group require an RNG for it's weight update postsynaptic variable init code? - bool isWUPostInitRNGRequired() const; - - //! Does this synapse group require a RNG for any sort of initialization - bool isHostInitRNGRequired() const; - - //! Is var init code required for any variables in this synapse group's weight update model? - bool isWUVarInitRequired() const; - - //! Is sparse connectivity initialisation code required for this synapse group? - bool isSparseConnectivityInitRequired() const; - protected: SynapseGroup(const std::string &name, SynapseMatrixType matrixType, unsigned int delaySteps, const WeightUpdateModels::Base *wu, const std::unordered_map &wuParams, const std::unordered_map &wuVarInitialisers, const std::unordered_map &wuPreVarInitialisers, const std::unordered_map &wuPostVarInitialisers, @@ -254,7 +224,7 @@ class GENN_EXPORT SynapseGroup void setFusedWUPostVarSuffix(const std::string &suffix){ m_FusedWUPostVarSuffix = suffix; } void setFusedPreOutputSuffix(const std::string &suffix){ m_FusedPreOutputSuffix = suffix; } - void initDerivedParams(double dt); + void finalise(double dt, const Type::TypeContext &context); //!
Add reference to custom connectivity update, referencing this synapse group void addCustomUpdateReference(CustomConnectivityUpdateInternal *cu){ m_CustomConnectivityUpdateReferences.push_back(cu); } @@ -271,7 +241,17 @@ class GENN_EXPORT SynapseGroup const std::unordered_map &getWUDerivedParams() const{ return m_WUDerivedParams; } const std::unordered_map &getPSDerivedParams() const{ return m_PSDerivedParams; } - const SynapseGroupInternal *getWeightSharingMaster() const { return m_WeightSharingMaster; } + const std::vector &getWUSimCodeTokens() const{ return m_WUSimCodeTokens; } + const std::vector &getWUEventCodeTokens() const{ return m_WUEventCodeTokens; } + const std::vector &getWUPostLearnCodeTokens() const{ return m_WUPostLearnCodeTokens; } + const std::vector &getWUSynapseDynamicsCodeTokens() const{ return m_WUSynapseDynamicsCodeTokens; } + const std::vector &getWUEventThresholdCodeTokens() const{ return m_WUEventThresholdCodeTokens; } + const std::vector &getWUPreSpikeCodeTokens() const{ return m_WUPreSpikeCodeTokens; } + const std::vector &getWUPostSpikeCodeTokens() const{ return m_WUPostSpikeCodeTokens; } + const std::vector &getWUPreDynamicsCodeTokens() const{ return m_WUPreDynamicsCodeTokens; } + const std::vector &getWUPostDynamicsCodeTokens() const{ return m_WUPostDynamicsCodeTokens; } + const std::vector &getPSApplyInputCodeTokens() const{ return m_PSApplyInputCodeTokens; } + const std::vector &getPSDecayCodeTokens() const{ return m_PSDecayCodeTokens; } //!< Does the event threshold needs to be retested in the synapse kernel? /*! This is required when the pre-synaptic neuron population's outgoing synapse groups require different event threshold */ @@ -313,6 +293,24 @@ class GENN_EXPORT SynapseGroup //! model been fused with those from other synapse groups? bool isWUPostModelFused() const { return m_FusedWUPostVarSuffix != getName(); } + //! Does this synapse group require dendritic delay? + bool isDendriticDelayRequired() const; + + //! Does this synapse group define presynaptic output? + bool isPresynapticOutputRequired() const; + + //! Does this synapse group require an RNG to generate procedural connectivity? + bool isProceduralConnectivityRNGRequired() const; + + //! Does this synapse group require an RNG for it's weight update init code? + bool isWUInitRNGRequired() const; + + //! Is var init code required for any variables in this synapse group's weight update model? + bool isWUVarInitRequired() const; + + //! Is sparse connectivity initialisation code required for this synapse group? + bool isSparseConnectivityInitRequired() const; + //! Get the type to use for sparse connectivity indices for synapse group const Type::ResolvedType &getSparseIndType() const; @@ -422,9 +420,6 @@ class GENN_EXPORT SynapseGroup //! Pointer to postsynaptic neuron group NeuronGroupInternal * const m_TrgNeuronGroup; - //! Pointer to 'master' weight sharing group if this is a slave - const SynapseGroupInternal *m_WeightSharingMaster; - //! Does the event threshold needs to be retested in the synapse kernel? /*! This is required when the pre-synaptic neuron population's outgoing synapse groups require different event threshold */ bool m_EventThresholdReTestRequired; @@ -529,5 +524,32 @@ class GENN_EXPORT SynapseGroup //! Custom updates which reference this synapse group /*! Because, if connectivity is sparse, all groups share connectivity this is required if connectivity changes. */ std::vector m_CustomUpdateReferences; + + //! 
Tokens produced by scanner from sim code + std::vector m_WUSimCodeTokens; + + std::vector m_WUEventCodeTokens; + + std::vector m_WUPostLearnCodeTokens; + + std::vector m_WUSynapseDynamicsCodeTokens; + + std::vector m_WUEventThresholdCodeTokens; + + std::vector m_WUPreSpikeCodeTokens; + + std::vector m_WUPostSpikeCodeTokens; + + std::vector m_WUPreDynamicsCodeTokens; + + std::vector m_WUPostDynamicsCodeTokens; + + std::vector m_PSApplyInputCodeTokens; + + std::vector m_PSDecayCodeTokens; + + //! Tokens produced by scanner from reset code + std::vector m_ResetCodeTokens; + }; } // namespace GeNN diff --git a/include/genn/genn/synapseGroupInternal.h b/include/genn/genn/synapseGroupInternal.h index 4d0e2b410b..6662656601 100644 --- a/include/genn/genn/synapseGroupInternal.h +++ b/include/genn/genn/synapseGroupInternal.h @@ -32,7 +32,6 @@ class SynapseGroupInternal : public SynapseGroup using SynapseGroup::getSrcNeuronGroup; using SynapseGroup::getTrgNeuronGroup; - using SynapseGroup::getWeightSharingMaster; using SynapseGroup::getWUDerivedParams; using SynapseGroup::getPSDerivedParams; using SynapseGroup::setEventThresholdReTestRequired; @@ -40,7 +39,7 @@ class SynapseGroupInternal : public SynapseGroup using SynapseGroup::setFusedPreOutputSuffix; using SynapseGroup::setFusedWUPreVarSuffix; using SynapseGroup::setFusedWUPostVarSuffix; - using SynapseGroup::initDerivedParams; + using SynapseGroup::finalise; using SynapseGroup::addCustomUpdateReference; using SynapseGroup::isEventThresholdReTestRequired; using SynapseGroup::getFusedPSVarSuffix; diff --git a/src/genn/genn/gennUtils.cc b/src/genn/genn/gennUtils.cc index 683ba88db6..810461551b 100644 --- a/src/genn/genn/gennUtils.cc +++ b/src/genn/genn/gennUtils.cc @@ -74,8 +74,26 @@ std::vector scanCode(const std::string &code, const Type::Typ return Transpiler::Scanner::scanSource(upgradedCode, typeContext, errorHandler); } //-------------------------------------------------------------------------- +bool areTokensEmpty(const std::vector &tokens) +{ + // For easy parsing, there should always be at least one token + assert(tokens.size() >= 1); + + // If there's only one token, assert it is actually an EOF and return true + if(tokens.size() == 1) { + assert(tokens.front().type == Transpiler::Token::Type::END_OF_FILE); + return true; + } + // Otherwise, return false + else { + return false; + } +} +//-------------------------------------------------------------------------- bool isIdentifierReferenced(const std::string &identifierName, const std::vector &tokens) { + assert(!tokens.empty()); + + // Return true if any identifier's lexemes match identifier name return std::any_of(tokens.cbegin(), tokens.cend(), [&identifierName](const auto &t) @@ -87,6 +105,8 @@ bool isIdentifierReferenced(const std::string &identifierName, const std::vector //-------------------------------------------------------------------------- bool isRNGRequired(const std::vector &tokens) { + assert(!tokens.empty()); + + // Return true if any identifier's lexemes are in set of random functions return std::any_of(tokens.cbegin(), tokens.cend(), [](const auto &t) @@ -96,7 +116,7 @@ bool isRNGRequired(const std::vector &tokens) } //-------------------------------------------------------------------------- -bool isRNGRequired(const std::unordered_map> &varInitialisers) +bool isRNGRequired(const std::unordered_map &varInitialisers) { // Return true if any of these variable initialisers require an RNG return std::any_of(varInitialisers.cbegin(), varInitialisers.cend(), diff
--git a/src/genn/genn/initSparseConnectivitySnippet.cc b/src/genn/genn/initSparseConnectivitySnippet.cc index 9a83167ad6..ad3813cde1 100644 --- a/src/genn/genn/initSparseConnectivitySnippet.cc +++ b/src/genn/genn/initSparseConnectivitySnippet.cc @@ -38,4 +38,28 @@ void Base::validate(const std::unordered_map &paramValues) Utils::validateVecNames(getRowBuildStateVars(), "Row building state variable"); Utils::validateVecNames(getColBuildStateVars(), "Column building state variable"); } + +//---------------------------------------------------------------------------- +// GeNN::InitSparseConnectivitySnippet::Init +//---------------------------------------------------------------------------- +void Init::finalise(double dt, const Type::TypeContext &context, const std::string &errorContext) +{ + // Superclass + Snippet::Init::finalise(dt); + + // Scan code tokens + m_RowBuildCodeTokens = Utils::scanCode(getSnippet()->getRowBuildCode(), context, errorContext + "row build code"); + m_ColBuildCodeTokens = Utils::scanCode(getSnippet()->getColBuildCode(), context, errorContext + "col build code"); + m_HostInitCodeTokens = Utils::scanCode(getSnippet()->getHostInitCode(), context, errorContext + "host init code"); +} +//---------------------------------------------------------------------------- +bool Init::isRNGRequired() const +{ + return (Utils::isRNGRequired(m_RowBuildCodeTokens) || Utils::isRNGRequired(m_ColBuildCodeTokens)); +} +//---------------------------------------------------------------------------- +bool Init::isHostRNGRequired() const +{ + return Utils::isRNGRequired(m_HostInitCodeTokens); +} } // namespace GeNN::InitSparseConnectivitySnippet diff --git a/src/genn/genn/initToeplitzConnectivitySnippet.cc b/src/genn/genn/initToeplitzConnectivitySnippet.cc index c07f73766f..7d9283f82c 100644 --- a/src/genn/genn/initToeplitzConnectivitySnippet.cc +++ b/src/genn/genn/initToeplitzConnectivitySnippet.cc @@ -29,4 +29,21 @@ void Base::validate(const std::unordered_map &paramValues) Snippet::Base::validate(paramValues, "Toeplitz connectivity initialiser "); Utils::validateVecNames(getDiagonalBuildStateVars(), "Row building state variable"); } + +//---------------------------------------------------------------------------- +// GeNN::InitToeplitzConnectivitySnippet::Init +//---------------------------------------------------------------------------- +void Init::finalise(double dt, const Type::TypeContext &context, const std::string &errorContext) +{ + // Superclass + Snippet::Init::finalise(dt); + + // Scan code tokens + m_DiagonalBuildCodeTokens = Utils::scanCode(getSnippet()->getDiagonalBuildCode(), context, errorContext + "diagonal build code"); +} +//---------------------------------------------------------------------------- +bool Init::isRNGRequired() const +{ + return Utils::isRNGRequired(m_DiagonalBuildCodeTokens); +} } // namespace GeNN::InitToeplitzConnectivitySnippet \ No newline at end of file diff --git a/src/genn/genn/initVarSnippet.cc b/src/genn/genn/initVarSnippet.cc index 3538fd7fe1..2992cd4eeb 100644 --- a/src/genn/genn/initVarSnippet.cc +++ b/src/genn/genn/initVarSnippet.cc @@ -34,10 +34,4 @@ void Base::validate(const std::unordered_map &paramValues) // Superclass Snippet::Base::validate(paramValues, "Variable initialiser "); } -//---------------------------------------------------------------------------- -bool Base::requiresKernel() const -{ - // **TODO** regex followed by optional whitespace and ( would b better - return (getCode().find("id_kernel") != std::string::npos); -} } // namespace GeNN::InitVarSnippet \ No newline at end of file
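The three Init classes above all follow the same pattern: each snippet's code strings are scanned exactly once, when the model is finalised, and the resulting token vectors are cached so that later queries such as isRNGRequired() are simple lookups rather than re-parses. A minimal sketch of that pattern, using simplified stand-ins for Transpiler::Token and Utils::scanCode (the real scanCode also takes a type context and error context, and the real scanner is a full lexer):

    // Sketch only: scan code once at finalise time, query cached tokens later
    #include <algorithm>
    #include <sstream>
    #include <string>
    #include <utility>
    #include <vector>

    struct Token { std::string lexeme; };   // stand-in for Transpiler::Token

    // Toy scanner: whitespace-separated lexemes (the real scanner is far richer)
    std::vector<Token> scanCode(const std::string &code)
    {
        std::vector<Token> tokens;
        std::istringstream stream(code);
        for(std::string lexeme; stream >> lexeme;) {
            tokens.push_back({lexeme});
        }
        return tokens;
    }

    class InitSketch
    {
    public:
        explicit InitSketch(std::string code) : m_Code(std::move(code)) {}

        void finalise()
        {
            // Scan exactly once and cache the result
            m_CodeTokens = scanCode(m_Code);
        }

        bool isRNGRequired() const
        {
            // Query cached tokens, e.g. for the gennrand family of identifiers
            return std::any_of(m_CodeTokens.cbegin(), m_CodeTokens.cend(),
                               [](const Token &t){ return t.lexeme.rfind("gennrand", 0) == 0; });
        }

    private:
        std::string m_Code;
        std::vector<Token> m_CodeTokens;
    };

The same trade-off recurs throughout the diffs that follow: finalise() becomes slightly more expensive, but every subsequent is*Required() query is cheap and operates on whole lexemes rather than raw strings.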
diff --git a/src/genn/genn/modelSpec.cc b/src/genn/genn/modelSpec.cc index 79289004b6..4989c728c0 100644 --- a/src/genn/genn/modelSpec.cc +++ b/src/genn/genn/modelSpec.cc @@ -211,7 +211,7 @@ CustomUpdateWU *ModelSpec::addCustomUpdate(const std::string &name, const std::s } } // --------------------------------------------------------------------------- -void ModelSpec::finalize() +void ModelSpec::finalise() { // Finalise neuron groups const auto typeContext = getTypeContext(); @@ -224,31 +224,9 @@ void ModelSpec::finalize() const auto *wu = s.second.getWUModel(); // Initialize derived parameters - s.second.initDerivedParams(m_DT); + s.second.finalise(m_DT); - // Mark any pre or postsyaptic neuron variables referenced in sim code as requiring queues - if (!wu->getSimCode().empty()) { - s.second.getSrcNeuronGroup()->updatePreVarQueues(wu->getSimCode()); - s.second.getTrgNeuronGroup()->updatePostVarQueues(wu->getSimCode()); - } - - // Mark any pre or postsyaptic neuron variables referenced in event code as requiring queues - if (!wu->getEventCode().empty()) { - s.second.getSrcNeuronGroup()->updatePreVarQueues(wu->getEventCode()); - s.second.getTrgNeuronGroup()->updatePostVarQueues(wu->getEventCode()); - } - - // Mark any pre or postsyaptic neuron variables referenced in postsynaptic update code as requiring queues - if (!wu->getLearnPostCode().empty()) { - s.second.getSrcNeuronGroup()->updatePreVarQueues(wu->getLearnPostCode()); - s.second.getTrgNeuronGroup()->updatePostVarQueues(wu->getLearnPostCode()); - } - - // Mark any pre or postsyaptic neuron variables referenced in synapse dynamics code as requiring queues - if (!wu->getSynapseDynamicsCode().empty()) { - s.second.getSrcNeuronGroup()->updatePreVarQueues(wu->getSynapseDynamicsCode()); - s.second.getTrgNeuronGroup()->updatePostVarQueues(wu->getSynapseDynamicsCode()); - } + } // CURRENT SOURCES diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 5f27ca3429..12f3925282 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -34,11 +34,26 @@ void Base::validate(const std::unordered_map &paramValues, Utils::validateInitialisers(vars, varValues, "variable", description); } +//---------------------------------------------------------------------------- +// VarInit +//---------------------------------------------------------------------------- void VarInit::finalise(double dt, const Type::TypeContext &context, const std::string &errorContext) { + // Superclass Snippet::Init::finalise(dt); - m_CodeTokens = Utils::scanCode(getSnippet()->getCode(), context, errorContext); + // Scan code tokens + m_CodeTokens = Utils::scanCode(getSnippet()->getCode(), context, errorContext + "initialisation code"); +} +//---------------------------------------------------------------------------- +bool VarInit::isRNGRequired() const +{ + return Utils::isRNGRequired(m_CodeTokens); +} +//---------------------------------------------------------------------------- +bool VarInit::isKernelRequired() const +{ + return Utils::isIdentifierReferenced("id_kernel", m_CodeTokens); } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/neuronGroup.cc b/src/genn/genn/neuronGroup.cc index a1a86202e5..bbf5873135 100644 --- a/src/genn/genn/neuronGroup.cc +++ b/src/genn/genn/neuronGroup.cc @@ -223,7 +223,7 @@ bool NeuronGroup::isZeroCopyEnabled() const //---------------------------------------------------------------------------- bool
NeuronGroup::isRecordingEnabled() const - +{ // Return true if spike recording is enabled if(m_SpikeRecordingEnabled) { return true; @@ -265,8 +265,8 @@ bool NeuronGroup::isSimRNGRequired() const return std::any_of(getInSyn().cbegin(), getInSyn().cend(), [](const SynapseGroupInternal *sg) { - return (Utils::isRNGRequired(sg->getPSModel()->getApplyInputCode()) || - Utils::isRNGRequired(sg->getPSModel()->getDecayCode())); + return (Utils::isRNGRequired(sg->getPSApplyInputCodeTokens()) || + Utils::isRNGRequired(sg->getPSDecayCodeTokens())); }); } //---------------------------------------------------------------------------- @@ -286,14 +286,14 @@ bool NeuronGroup::isInitRNGRequired() const // Return true if any incoming synapse groups require and RNG to initialize their postsynaptic variables if(std::any_of(getInSyn().cbegin(), getInSyn().cend(), - [](const SynapseGroupInternal *sg) { return sg->isWUPostInitRNGRequired(); })) + [](const SynapseGroupInternal *sg) { return Utils::isRNGRequired(sg->getWUPostVarInitialisers()); })) { return true; } // Return true if any outgoing synapse groups require and RNG to initialize their presynaptic variables if(std::any_of(getOutSyn().cbegin(), getOutSyn().cend(), - [](const SynapseGroupInternal *sg) { return sg->isWUPreInitRNGRequired(); })) + [](const SynapseGroupInternal *sg) { return Utils::isRNGRequired(sg->getWUPreVarInitialisers()); })) { return true; } @@ -301,7 +301,7 @@ bool NeuronGroup::isInitRNGRequired() const // Return true if any of the incoming synapse groups have state variables which require an RNG to initialise // **NOTE** these are included here as they are initialised in neuron initialisation threads return std::any_of(getInSyn().cbegin(), getInSyn().cend(), - [](const SynapseGroupInternal *sg){ return sg->isPSInitRNGRequired(); }); + [](const SynapseGroupInternal *sg){ return Utils::isRNGRequired(sg->getPSVarInitialisers()); }); } //---------------------------------------------------------------------------- NeuronGroup::NeuronGroup(const std::string &name, int numNeurons, const NeuronModels::Base *neuronModel, @@ -325,14 +325,14 @@ void NeuronGroup::checkNumDelaySlots(unsigned int requiredDelay) } } //---------------------------------------------------------------------------- -void NeuronGroup::updatePreVarQueues(const std::string &code) +void NeuronGroup::updatePreVarQueues(const std::vector &tokens) { - updateVarQueues(code, "_pre"); + updateVarQueues(tokens, "_pre"); } //---------------------------------------------------------------------------- -void NeuronGroup::updatePostVarQueues(const std::string &code) +void NeuronGroup::updatePostVarQueues(const std::vector &tokens) { - updateVarQueues(code, "_post"); + updateVarQueues(tokens, "_post"); } //---------------------------------------------------------------------------- void NeuronGroup::finalise(double dt, const Type::TypeContext &context) @@ -346,7 +346,7 @@ void NeuronGroup::finalise(double dt, const Type::TypeContext &context) // Finalise variable initialisers for(auto &v : m_VarInitialisers) { - v.second.finalise(dt, context, "Variable '" + v.first + "' initialisation code"); + v.second.finalise(dt, context, "Variable '" + v.first + "' "); } // Scan neuron model code strings @@ -589,13 +589,13 @@ boost::uuids::detail::sha1::digest_type NeuronGroup::getVarLocationHashDigest() return hash.get_digest(); } //---------------------------------------------------------------------------- -void NeuronGroup::updateVarQueues(const std::string &code, const std::string &suffix) 
+void NeuronGroup::updateVarQueues(const std::vector &tokens, const std::string &suffix) { // Loop through variables const auto vars = getNeuronModel()->getVars(); for(size_t i = 0; i < vars.size(); i++) { // If the code contains a reference to this variable, set corresponding flag - if (code.find(vars[i].name + suffix) != std::string::npos) { + if(Utils::isIdentifierReferenced(vars[i].name + suffix, tokens)) { m_VarQueueRequired[i] = true; } } diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc index da9835d9d1..4db333974d 100644 --- a/src/genn/genn/synapseGroup.cc +++ b/src/genn/genn/synapseGroup.cc @@ -305,131 +305,6 @@ VarLocation SynapseGroup::getSparseConnectivityExtraGlobalParamLocation(const st return m_ConnectivityExtraGlobalParamLocation[m_SparseConnectivityInitialiser.getSnippet()->getExtraGlobalParamIndex(paramName)]; } //---------------------------------------------------------------------------- -bool SynapseGroup::isDendriticDelayRequired() const -{ - // If addToInSynDelay function is used in sim code, return true - // **TODO** regex followed by optional whitespace and ( would b better - if(getWUModel()->getSimCode().find("addToPostDelay") != std::string::npos) { - return true; - } - - // If addToInSynDelay function is used in event code, return true - // **TODO** regex followed by optional whitespace and ( would b better - if(getWUModel()->getEventCode().find("addToPostDelay") != std::string::npos) { - return true; - } - - // If addToInSynDelay function is used in synapse dynamics, return tru - // **TODO** regex followed by optional whitespace and ( would b bettere - if(getWUModel()->getSynapseDynamicsCode().find("addToPostDelay") != std::string::npos) { - return true; - } - - return false; -} -//---------------------------------------------------------------------------- -bool SynapseGroup::isPresynapticOutputRequired() const -{ - // If addToPre function is used in sim_code, return true - // **TODO** regex followed by optional whitespace and ( would b better - if(getWUModel()->getSimCode().find("addToPre") != std::string::npos) { - return true; - } - - // If addToPre function is used in learn_post_code, return true - // **TODO** regex followed by optional whitespace and ( would b better - if(getWUModel()->getLearnPostCode().find("addToPre") != std::string::npos) { - return true; - } - - // If addToPre function is used in event_code, return true - // **TODO** regex followed by optional whitespace and ( would b better - if(getWUModel()->getEventCode().find("addToPre") != std::string::npos) { - return true; - } - - // If addToPre function is used in synapse_dynamics, return true - // **TODO** regex followed by optional whitespace and ( would b better - if(getWUModel()->getSynapseDynamicsCode().find("addToPre") != std::string::npos) { - return true; - } - - return false; -} -//---------------------------------------------------------------------------- -bool SynapseGroup::isProceduralConnectivityRNGRequired() const -{ - if(m_MatrixType & SynapseMatrixConnectivity::PROCEDURAL) { - return (Utils::isRNGRequired(m_SparseConnectivityInitialiser.getSnippet()->getRowBuildCode()) - || Utils::isRNGRequired(m_SparseConnectivityInitialiser.getSnippet()->getColBuildCode())); - } - else if(m_MatrixType & SynapseMatrixConnectivity::TOEPLITZ) { - return (Utils::isRNGRequired(m_ToeplitzConnectivityInitialiser.getSnippet()->getDiagonalBuildCode())); - } - else { - return false; - } -} -//---------------------------------------------------------------------------- -bool 
SynapseGroup::isPSInitRNGRequired() const -{ - // If initialising the postsynaptic variables require an RNG, return true - return Utils::isRNGRequired(m_PSVarInitialisers); -} -//---------------------------------------------------------------------------- -bool SynapseGroup::isWUInitRNGRequired() const -{ - // If initialising the weight update variables require an RNG, return true - if(Utils::isRNGRequired(m_WUVarInitialisers)) { - return true; - } - - // Return true if matrix has sparse or bitmask connectivity and an RNG is required to initialise connectivity - const auto *snippet = m_SparseConnectivityInitialiser.getSnippet(); - return (((m_MatrixType & SynapseMatrixConnectivity::SPARSE) || (m_MatrixType & SynapseMatrixConnectivity::BITMASK)) - && (Utils::isRNGRequired(snippet->getRowBuildCode()) || Utils::isRNGRequired(snippet->getColBuildCode()))); -} -//---------------------------------------------------------------------------- -bool SynapseGroup::isWUPreInitRNGRequired() const -{ - return Utils::isRNGRequired(m_WUPreVarInitialisers); -} -//---------------------------------------------------------------------------- -bool SynapseGroup::isWUPostInitRNGRequired() const -{ - return Utils::isRNGRequired(m_WUPostVarInitialisers); -} -//---------------------------------------------------------------------------- -bool SynapseGroup::isHostInitRNGRequired() const -{ - return Utils::isRNGRequired(m_SparseConnectivityInitialiser.getSnippet()->getHostInitCode()); -} -//---------------------------------------------------------------------------- -bool SynapseGroup::isWUVarInitRequired() const -{ - // If this synapse group has per-synapse or kernel state variables, - // return true if any of them have initialisation code which doesn't require a kernel - if ((getMatrixType() & SynapseMatrixWeight::INDIVIDUAL) || (getMatrixType() & SynapseMatrixWeight::KERNEL)) { - return std::any_of(m_WUVarInitialisers.cbegin(), m_WUVarInitialisers.cend(), - [](const auto &init) - { - return !init.second.getSnippet()->getCode().empty() && !init.second.getSnippet()->requiresKernel(); - }); - } - else { - return false; - } -} -//---------------------------------------------------------------------------- -bool SynapseGroup::isSparseConnectivityInitRequired() const -{ - // Return true if the matrix type is sparse or bitmask - // and there is code to initialise sparse connectivity - const auto *snippet = getConnectivityInitialiser().getSnippet(); - return (((m_MatrixType & SynapseMatrixConnectivity::SPARSE) || (m_MatrixType & SynapseMatrixConnectivity::BITMASK)) - && (!snippet->getRowBuildCode().empty() || !snippet->getColBuildCode().empty())); -} -//---------------------------------------------------------------------------- SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType, unsigned int delaySteps, const WeightUpdateModels::Base *wu, const std::unordered_map &wuParams, const std::unordered_map &wuVarInitialisers, const std::unordered_map &wuPreVarInitialisers, const std::unordered_map &wuPostVarInitialisers, const PostsynapticModels::Base *ps, const std::unordered_map &psParams, const std::unordered_map &psVarInitialisers, @@ -457,63 +332,9 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType "Synapse group " + getName() + " weight update model "); getPSModel()->validate(getPSParams(), getPSVarInitialisers(), "Synapse group " + getName() + " postsynaptic model "); - // If connectivity is procedural - if(m_MatrixType & 
SynapseMatrixConnectivity::PROCEDURAL) { - // If there's a toeplitz initialiser, give an error - if(!m_ToeplitzConnectivityInitialiser.getSnippet()->getDiagonalBuildCode().empty()) { - throw std::runtime_error("Cannot use procedural connectivity with toeplitz initialisation snippet"); - } - - // If there's no row build code, give an error - if(m_SparseConnectivityInitialiser.getSnippet()->getRowBuildCode().empty()) { - throw std::runtime_error("Cannot use procedural connectivity without specifying a connectivity initialisation snippet with row building code"); - } - - // If there's column build code, give an error - if(!m_SparseConnectivityInitialiser.getSnippet()->getColBuildCode().empty()) { - throw std::runtime_error("Cannot use procedural connectivity with connectivity initialisation snippets with column building code"); - } - - // If the weight update model has code for postsynaptic-spike triggered updating, give an error - if(!m_WUModel->getLearnPostCode().empty()) { - throw std::runtime_error("Procedural connectivity cannot be used for synapse groups with postsynaptic spike-triggered learning"); - } - - // If weight update model has code for continuous synapse dynamics, give error - // **THINK** this would actually be pretty trivial to implement - if (!m_WUModel->getSynapseDynamicsCode().empty()) { - throw std::runtime_error("Procedural connectivity cannot be used for synapse groups with continuous synapse dynamics"); - } - } - // Otherwise, if WEIGHTS are procedural e.g. in the case of DENSE_PROCEDURALG, give error if RNG is required for weights - else if(m_MatrixType & SynapseMatrixWeight::PROCEDURAL) { - if(Utils::isRNGRequired(m_WUVarInitialisers)) { - throw std::runtime_error("Procedural weights used without procedural connectivity cannot currently access RNG."); - } - } // If synapse group has Toeplitz connectivity if(m_MatrixType & SynapseMatrixConnectivity::TOEPLITZ) { - // Give an error if there is sparse connectivity initialiser code - if(!m_SparseConnectivityInitialiser.getSnippet()->getRowBuildCode().empty() || !m_SparseConnectivityInitialiser.getSnippet()->getColBuildCode().empty()) { - throw std::runtime_error("Cannot use TOEPLITZ connectivity with sparse connectivity initialisation snippet."); - } - - // Give an error if there isn't toeplitz connectivity initialiser code - if(m_ToeplitzConnectivityInitialiser.getSnippet()->getDiagonalBuildCode().empty()) { - throw std::runtime_error("TOEPLITZ connectivity requires toeplitz connectivity initialisation snippet."); - } - - // Give an error if connectivity initialisation snippet uses RNG - if(Utils::isRNGRequired(m_ToeplitzConnectivityInitialiser.getSnippet()->getDiagonalBuildCode())) { - throw std::runtime_error("TOEPLITZ connectivity cannot currently access RNG."); - } - - // If the weight update model has code for postsynaptic-spike triggered updating, give an error - if(!m_WUModel->getLearnPostCode().empty()) { - throw std::runtime_error("TOEPLITZ connectivity cannot be used for synapse groups with postsynaptic spike-triggered learning"); - } - // If toeplitz initialisation snippet provides a function to calculate kernel size, call it auto calcKernelSizeFunc = m_ToeplitzConnectivityInitialiser.getSnippet()->getCalcKernelSizeFunc(); if(calcKernelSizeFunc) { @@ -569,7 +390,7 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType } } - // If connectivity initialisation snippet defines a kernel and matrix type doesn't support it, give error + // If connectivity initialisation snippet defines a 
kernel and matrix type doesn't support it, give error if(!m_KernelSize.empty() && (m_MatrixType != SynapseMatrixType::PROCEDURAL_PROCEDURALG) && (m_MatrixType != SynapseMatrixType::TOEPLITZ) && (m_MatrixType != SynapseMatrixType::SPARSE) && (m_MatrixType != SynapseMatrixType::PROCEDURAL_KERNELG)) { @@ -581,27 +402,11 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType throw std::runtime_error("BITMASK connectivity can only be used with weight update models without variables like StaticPulseConstantWeight."); } - // If connectivity is dense and there is connectivity initialiser code, give error - if((m_MatrixType & SynapseMatrixConnectivity::DENSE) - && (!m_SparseConnectivityInitialiser.getSnippet()->getRowBuildCode().empty() || !m_SparseConnectivityInitialiser.getSnippet()->getColBuildCode().empty())) - { - throw std::runtime_error("Cannot use DENSE connectivity with connectivity initialisation snippet."); - } - - // If synapse group uses sparse or procedural connectivity but no kernel size is provided, - // check that no variable's initialisation snippets require a kernel - if(((m_MatrixType == SynapseMatrixType::SPARSE) || (m_MatrixType == SynapseMatrixType::PROCEDURAL_PROCEDURALG)) && - m_KernelSize.empty() && std::any_of(getWUVarInitialisers().cbegin(), getWUVarInitialisers().cend(), - [](const auto &v) { return v.second.getSnippet()->requiresKernel(); })) - { - throw std::runtime_error("Variable initialisation snippets which use $(id_kernel) must be used with a connectivity initialisation snippet which specifies how kernel size is calculated."); - } - // Check that the source neuron group supports the desired number of delay steps srcNeuronGroup->checkNumDelaySlots(delaySteps); } //---------------------------------------------------------------------------- -void SynapseGroup::initDerivedParams(double dt) +void SynapseGroup::finalise(double dt, const Type::TypeContext &context) { auto wuDerivedParams = getWUModel()->getDerivedParams(); auto psDerivedParams = getPSModel()->getDerivedParams(); @@ -618,27 +423,159 @@ void SynapseGroup::initDerivedParams(double dt) // Initialise derived parameters for WU variable initialisers for(auto &v : m_WUVarInitialisers) { - v.second.initDerivedParams(dt); + v.second.finalise(dt, context, + "Synapse group '" + getName() + "', weight update model variable '" + v.first + "' "); } // Initialise derived parameters for PSM variable initialisers for(auto &v : m_PSVarInitialisers) { - v.second.initDerivedParams(dt); + v.second.finalise(dt, context, + "Synapse group '" + getName() + "', postsynaptic update model variable '" + v.first + "' "); } // Initialise derived parameters for WU presynaptic variable initialisers for(auto &v : m_WUPreVarInitialisers) { - v.second.initDerivedParams(dt); + v.second.finalise(dt, context, + "Synapse group '" + getName() + "' weight update model presynaptic variable '" + v.first + "' "); } // Initialise derived parameters for WU postsynaptic variable initialisers for(auto &v : m_WUPostVarInitialisers) { - v.second.initDerivedParams(dt); + v.second.finalise(dt, context, + "Synapse group '" + getName() + "' weight update model postsynaptic variable '" + v.first + "' "); } // Initialise any derived connectivity initialiser parameters - m_SparseConnectivityInitialiser.initDerivedParams(dt); - m_ToeplitzConnectivityInitialiser.initDerivedParams(dt); + m_SparseConnectivityInitialiser.finalise(dt, context, "Synapse group '" + getName() + "'"); + m_ToeplitzConnectivityInitialiser.finalise(dt, context, 
"Synapse group '" + getName() + "'"); + + // Scan weight update model code strings + m_WUSimCodeTokens = Utils::scanCode(getWUModel()->getSimCode(), context, + "Synapse group '" + getName() + "' weight update model sim code"); + m_WUEventCodeTokens = Utils::scanCode(getWUModel()->getEventCode(), context, + " Synapse group '" + getName() + "' weight update model event code"); + m_WUPostLearnCodeTokens = Utils::scanCode(getWUModel()->getLearnPostCode(), context, + "Synapse group '" + getName() + "' weight update model learn post code"); + m_WUSynapseDynamicsCodeTokens = Utils::scanCode(getWUModel()->getSynapseDynamicsCode(), context, + "Synapse group '" + getName() + "' weight update model synapse dynamics code"); + m_WUEventThresholdCodeTokens = Utils::scanCode(getWUModel()->getEventThresholdConditionCode(), context, + "Synapse group '" + getName() + "' weight update model event threshold code"); + m_WUPreSpikeCodeTokens = Utils::scanCode(getWUModel()->getPreSpikeCode(), context, + "Synapse group '" + getName() + "' weight update model pre spike code"); + m_WUPostSpikeCodeTokens = Utils::scanCode(getWUModel()->getPostSpikeCode(), context, + "Synapse group '" + getName() + "' weight update model post spike code"); + m_WUPreDynamicsCodeTokens = Utils::scanCode(getWUModel()->getPreDynamicsCode(), context, + "Synapse group '" + getName() + "' weight update model pre dynamics code"); + m_WUPostDynamicsCodeTokens = Utils::scanCode(getWUModel()->getPostDynamicsCode(), context, + "Synapse group '" + getName() + "' weight update model post dynamics code"); + + // Scan postsynaptic update model code strings + m_PSApplyInputCodeTokens = Utils::scanCode(getPSModel()->getApplyInputCode(), context, + "Synapse group '" + getName() + "' postsynaptic update model apply input code"); + m_PSDecayCodeTokens = Utils::scanCode(getPSModel()->getDecayCode(), context, + "Synapse group '" + getName() + "' postsynaptic update model decay code"); + + // If connectivity is procedural + if(m_MatrixType & SynapseMatrixConnectivity::PROCEDURAL) { + // If there's a toeplitz initialiser, give an error + if(!Utils::areTokensEmpty(m_ToeplitzConnectivityInitialiser.getDiagonalBuildCodeTokens())) { + throw std::runtime_error("Cannot use procedural connectivity with toeplitz initialisation snippet"); + } + + // If there's no row build code, give an error + if(Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getRowBuildCodeTokens())) { + throw std::runtime_error("Cannot use procedural connectivity without specifying a connectivity initialisation snippet with row building code"); + } + + // If there's column build code, give an error + if(!Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getColBuildCodeTokens())) { + throw std::runtime_error("Cannot use procedural connectivity with connectivity initialisation snippets with column building code"); + } + + // If the weight update model has code for postsynaptic-spike triggered updating, give an error + if(!Utils::areTokensEmpty(m_WUPostLearnCodeTokens)) { + throw std::runtime_error("Procedural connectivity cannot be used for synapse groups with postsynaptic spike-triggered learning"); + } + + // If weight update model has code for continuous synapse dynamics, give error + // **THINK** this would actually be pretty trivial to implement + if (!Utils::areTokensEmpty(m_WUSynapseDynamicsCodeTokens)) { + throw std::runtime_error("Procedural connectivity cannot be used for synapse groups with continuous synapse dynamics"); + } + } + // Otherwise, if WEIGHTS are procedural e.g. 
in the case of DENSE_PROCEDURALG, give error if RNG is required for weights + else if(m_MatrixType & SynapseMatrixWeight::PROCEDURAL) { + if(Utils::isRNGRequired(m_WUVarInitialisers)) { + throw std::runtime_error("Procedural weights used without procedural connectivity cannot currently access RNG."); + } + } + + // If synapse group has Toeplitz connectivity + if(m_MatrixType & SynapseMatrixConnectivity::TOEPLITZ) { + // Give an error if there is sparse connectivity initialiser code + if(!Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getRowBuildCodeTokens()) + || !Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getColBuildCodeTokens())) + { + throw std::runtime_error("Cannot use TOEPLITZ connectivity with sparse connectivity initialisation snippet."); + } + + // Give an error if there isn't toeplitz connectivity initialiser code + if(Utils::areTokensEmpty(m_ToeplitzConnectivityInitialiser.getDiagonalBuildCodeTokens())) { + throw std::runtime_error("TOEPLITZ connectivity requires toeplitz connectivity initialisation snippet."); + } + + // Give an error if connectivity initialisation snippet uses RNG + if(m_ToeplitzConnectivityInitialiser.isRNGRequired()) { + throw std::runtime_error("TOEPLITZ connectivity cannot currently access RNG."); + } + + // If the weight update model has code for postsynaptic-spike triggered updating, give an error + if(!Utils::areTokensEmpty(m_WUPostLearnCodeTokens)) { + throw std::runtime_error("TOEPLITZ connectivity cannot be used for synapse groups with postsynaptic spike-triggered learning"); + } + } + + // If connectivity is dense and there is connectivity initialiser code, give error + if((m_MatrixType & SynapseMatrixConnectivity::DENSE) + && (!Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getRowBuildCodeTokens()) + || !Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getColBuildCodeTokens()))) + { + throw std::runtime_error("Cannot use DENSE connectivity with connectivity initialisation snippet."); + } + + // If synapse group uses sparse or procedural connectivity but no kernel size is provided, + // check that no variable's initialisation snippets require a kernel + if(((m_MatrixType == SynapseMatrixType::SPARSE) || (m_MatrixType == SynapseMatrixType::PROCEDURAL_PROCEDURALG)) && + m_KernelSize.empty() && std::any_of(getWUVarInitialisers().cbegin(), getWUVarInitialisers().cend(), + [](const auto &v) { return v.second.isKernelRequired(); })) + { + throw std::runtime_error("Variable initialisation snippets which use id_kernel must be used with a " + "connectivity initialisation snippet which specifies how kernel size is calculated."); + } + + // Mark any pre or postsynaptic neuron variables referenced in sim code as requiring queues + if (!Utils::areTokensEmpty(m_WUSimCodeTokens)) { + getSrcNeuronGroup()->updatePreVarQueues(m_WUSimCodeTokens); + getTrgNeuronGroup()->updatePostVarQueues(m_WUSimCodeTokens); + } + + // Mark any pre or postsynaptic neuron variables referenced in event code as requiring queues + if (!Utils::areTokensEmpty(m_WUEventCodeTokens)) { + getSrcNeuronGroup()->updatePreVarQueues(m_WUEventCodeTokens); + getTrgNeuronGroup()->updatePostVarQueues(m_WUEventCodeTokens); + } + + // Mark any pre or postsynaptic neuron variables referenced in postsynaptic update code as requiring queues + if (!Utils::areTokensEmpty(m_WUPostLearnCodeTokens)) { + getSrcNeuronGroup()->updatePreVarQueues(m_WUPostLearnCodeTokens); + getTrgNeuronGroup()->updatePostVarQueues(m_WUPostLearnCodeTokens); + } + + // Mark any pre or postsynaptic neuron
variables referenced in synapse dynamics code as requiring queues + if (!Utils::areTokensEmpty(m_WUSynapseDynamicsCodeTokens)) { + getSrcNeuronGroup()->updatePreVarQueues(m_WUSynapseDynamicsCodeTokens); + getTrgNeuronGroup()->updatePostVarQueues(m_WUSynapseDynamicsCodeTokens); + } } //---------------------------------------------------------------------------- bool SynapseGroup::canPSBeFused() const @@ -708,6 +645,101 @@ bool SynapseGroup::canWUMPostUpdateBeFused() const return true; } //---------------------------------------------------------------------------- +bool SynapseGroup::isDendriticDelayRequired() const +{ + // If addToPostDelay function is used in sim code, return true + if(Utils::isIdentifierReferenced("addToPostDelay", getWUSimCodeTokens())) { + return true; + } + + // If addToPostDelay function is used in event code, return true + if(Utils::isIdentifierReferenced("addToPostDelay", getWUEventCodeTokens())) { + return true; + } + + // If addToPostDelay function is used in synapse dynamics, return true + if(Utils::isIdentifierReferenced("addToPostDelay", getWUSynapseDynamicsCodeTokens())) { + return true; + } + + return false; +} +//---------------------------------------------------------------------------- +bool SynapseGroup::isPresynapticOutputRequired() const +{ + // If addToPre function is used in sim code, return true + if(Utils::isIdentifierReferenced("addToPre", getWUSimCodeTokens())) { + return true; + } + + // If addToPre function is used in event code, return true + if(Utils::isIdentifierReferenced("addToPre", getWUEventCodeTokens())) { + return true; + } + + // If addToPre function is used in learn post code, return true + if(Utils::isIdentifierReferenced("addToPre", getWUPostLearnCodeTokens())) { + return true; + } + + // If addToPre function is used in synapse dynamics, return true + if(Utils::isIdentifierReferenced("addToPre", getWUSynapseDynamicsCodeTokens())) { + return true; + } + + return false; +} +//---------------------------------------------------------------------------- +bool SynapseGroup::isProceduralConnectivityRNGRequired() const +{ + if(m_MatrixType & SynapseMatrixConnectivity::PROCEDURAL) { + return m_SparseConnectivityInitialiser.isRNGRequired(); + } + else if(m_MatrixType & SynapseMatrixConnectivity::TOEPLITZ) { + return m_ToeplitzConnectivityInitialiser.isRNGRequired(); + } + else { + return false; + } +} +//---------------------------------------------------------------------------- +bool SynapseGroup::isWUInitRNGRequired() const +{ + // If initialising the weight update variables requires an RNG, return true + if(Utils::isRNGRequired(m_WUVarInitialisers)) { + return true; + } + + // Return true if matrix has sparse or bitmask connectivity and an RNG is required to initialise connectivity + return (((m_MatrixType & SynapseMatrixConnectivity::SPARSE) || (m_MatrixType & SynapseMatrixConnectivity::BITMASK)) + && m_SparseConnectivityInitialiser.isRNGRequired()); +} +//---------------------------------------------------------------------------- +bool SynapseGroup::isWUVarInitRequired() const +{ + // If this synapse group has per-synapse or kernel state variables, + // return true if any of them have initialisation code which doesn't require a kernel + if ((getMatrixType() & SynapseMatrixWeight::INDIVIDUAL) || (getMatrixType() & SynapseMatrixWeight::KERNEL)) { + return std::any_of(m_WUVarInitialisers.cbegin(), m_WUVarInitialisers.cend(), + [](const auto &init) + { + return !Utils::areTokensEmpty(init.second.getCodeTokens()) &&
!init.second.isKernelRequired(); + }); + } + else { + return false; + } +} +//---------------------------------------------------------------------------- +bool SynapseGroup::isSparseConnectivityInitRequired() const +{ + // Return true if the matrix type is sparse or bitmask + // and there is code to initialise sparse connectivity + const auto *snippet = getConnectivityInitialiser().getSnippet(); + return (((m_MatrixType & SynapseMatrixConnectivity::SPARSE) || (m_MatrixType & SynapseMatrixConnectivity::BITMASK)) + && (!snippet->getRowBuildCode().empty() || !snippet->getColBuildCode().empty())); +} +//---------------------------------------------------------------------------- bool SynapseGroup::canPreOutputBeFused() const { // There are no variables or other non-constant objects, so these can presumably always be fused From a8267d6d15b8bddc59a1aa79a754ad71aa32b2d4 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 5 Jul 2023 09:08:19 +0100 Subject: [PATCH 317/725] hook up new token-checking functions throughout --- include/genn/genn/neuronGroup.h | 2 +- include/genn/genn/neuronGroupInternal.h | 2 ++ include/genn/genn/synapseGroupInternal.h | 17 +++++++++++++++++ .../backends/single_threaded_cpu/backend.cc | 2 +- src/genn/genn/code_generator/backendSIMT.cc | 18 +++++++++--------- 5 files changed, 30 insertions(+), 11 deletions(-) diff --git a/include/genn/genn/neuronGroup.h b/include/genn/genn/neuronGroup.h index 431cabfddf..6db8869642 100644 --- a/include/genn/genn/neuronGroup.h +++ b/include/genn/genn/neuronGroup.h @@ -219,7 +219,7 @@ class GENN_EXPORT NeuronGroup const std::vector &getFusedWUPreOutSyn() const { return m_FusedWUPreOutSyn; } const std::vector &getFusedPreOutputOutSyn() const { return m_FusedPreOutputOutSyn; } - //! Does this neuron group require an RNG to simulate? + //! Does this neuron group require an RNG to simulate? bool isSimRNGRequired() const; //! Does this neuron group require an RNG for it's init code? 
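One motivation for routing all of these checks through tokens is worth spelling out: the removed implementations located identifiers with std::string::find, which matches substrings, so code containing only addToPostDelay(...) would also appear to reference addToPost. Comparing whole lexemes produced by the scanner removes that class of false positive. A minimal sketch of the idea, assuming a simplified token type (the real Transpiler::Token also carries a token type and source position, and the real helper only compares identifier tokens):

    // Sketch only: whole-lexeme comparison instead of substring search
    #include <algorithm>
    #include <string>
    #include <vector>

    struct Token { std::string lexeme; };   // stand-in for Transpiler::Token

    bool isIdentifierReferenced(const std::string &name, const std::vector<Token> &tokens)
    {
        // Exact equality: an "addToPostDelay" token no longer matches "addToPost"
        return std::any_of(tokens.cbegin(), tokens.cend(),
                           [&name](const Token &t){ return t.lexeme == name; });
    }

This is also what makes NeuronGroup::updateVarQueues safe when one variable name is a prefix of another: the scanner emits e.g. "V_pre" as a single identifier token, so searching for vars[i].name + suffix cannot accidentally match inside a longer name.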
diff --git a/include/genn/genn/neuronGroupInternal.h b/include/genn/genn/neuronGroupInternal.h index fd75b851a1..dfbf042b90 100644 --- a/include/genn/genn/neuronGroupInternal.h +++ b/include/genn/genn/neuronGroupInternal.h @@ -40,6 +40,8 @@ class NeuronGroupInternal : public NeuronGroup using NeuronGroup::getFusedOutSynWithPreCode; using NeuronGroup::getFusedInSynWithPostVars; using NeuronGroup::getFusedOutSynWithPreVars; + using NeuronGroup::isSimRNGRequired; + using NeuronGroup::isInitRNGRequired; using NeuronGroup::isVarQueueRequired; using NeuronGroup::getHashDigest; using NeuronGroup::getInitHashDigest; diff --git a/include/genn/genn/synapseGroupInternal.h b/include/genn/genn/synapseGroupInternal.h index 6662656601..beaf77f382 100644 --- a/include/genn/genn/synapseGroupInternal.h +++ b/include/genn/genn/synapseGroupInternal.h @@ -42,6 +42,17 @@ class SynapseGroupInternal : public SynapseGroup using SynapseGroup::finalise; using SynapseGroup::addCustomUpdateReference; using SynapseGroup::isEventThresholdReTestRequired; + using SynapseGroup::getWUSimCodeTokens; + using SynapseGroup::getWUEventCodeTokens; + using SynapseGroup::getWUPostLearnCodeTokens; + using SynapseGroup::getWUSynapseDynamicsCodeTokens; + using SynapseGroup::getWUEventThresholdCodeTokens; + using SynapseGroup::getWUPreSpikeCodeTokens; + using SynapseGroup::getWUPostSpikeCodeTokens; + using SynapseGroup::getWUPreDynamicsCodeTokens; + using SynapseGroup::getWUPostDynamicsCodeTokens; + using SynapseGroup::getPSApplyInputCodeTokens; + using SynapseGroup::getPSDecayCodeTokens; using SynapseGroup::getFusedPSVarSuffix; using SynapseGroup::getFusedPreOutputSuffix; using SynapseGroup::getFusedWUPreVarSuffix; @@ -56,6 +67,12 @@ class SynapseGroupInternal : public SynapseGroup using SynapseGroup::isPSModelFused; using SynapseGroup::isWUPreModelFused; using SynapseGroup::isWUPostModelFused; + using SynapseGroup::isDendriticDelayRequired; + using SynapseGroup::isPresynapticOutputRequired; + using SynapseGroup::isProceduralConnectivityRNGRequired; + using SynapseGroup::isWUInitRNGRequired; + using SynapseGroup::isWUVarInitRequired; + using SynapseGroup::isSparseConnectivityInitRequired; using SynapseGroup::getWUHashDigest; using SynapseGroup::getWUPreHashDigest; using SynapseGroup::getWUPostHashDigest; diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index d58ef8e339..4cf58b7fa1 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1656,7 +1656,7 @@ bool Backend::isGlobalHostRNGRequired(const ModelSpecMerged &modelMerged) const if(std::any_of(model.getSynapseGroups().cbegin(), model.getSynapseGroups().cend(), [](const ModelSpec::SynapseGroupValueType &s) { - return (s.second.isWUInitRNGRequired() || s.second.isHostInitRNGRequired()); + return (s.second.isWUInitRNGRequired() || s.second.getConnectivityInitialiser().isHostRNGRequired()); })) { return true; diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index eb78d217e7..012bcb3d11 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -103,7 +103,7 @@ bool BackendSIMT::isGlobalHostRNGRequired(const ModelSpecMerged &modelMerged) co // Host RNG is required if any synapse groups or custom connectivity updates require a host RNG const ModelSpecInternal &model = modelMerged.getModel(); return (std::any_of(model.getSynapseGroups().cbegin(), 
model.getSynapseGroups().cend(), - [](const ModelSpec::SynapseGroupValueType &s){ return (s.second.isHostInitRNGRequired()); }) + [](const ModelSpec::SynapseGroupValueType &s){ return s.second.getConnectivityInitialiser().isHostRNGRequired(); }) || std::any_of(model.getCustomConnectivityUpdates().cbegin(), model.getCustomConnectivityUpdates().cend(), [](const ModelSpec::CustomConnectivityUpdateValueType &c){ return c.second.isHostRNGRequired(); })); } @@ -1462,8 +1462,8 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer EnvironmentGroupMergedField groupEnv(env, sg); // If there is row-building code in this snippet - const auto *snippet = sg.getArchetype().getConnectivityInitialiser().getSnippet(); - if(!snippet->getRowBuildCode().empty()) { + const auto &connectInit = sg.getArchetype().getConnectivityInitialiser(); + if(!Utils::areTokensEmpty(connectInit.getRowBuildCodeTokens())) { groupEnv.getStream() << "// only do this for existing presynaptic neurons" << std::endl; groupEnv.print("if($(id) < $(num_pre))"); @@ -1475,7 +1475,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer } // Otherwise else { - assert(!snippet->getColBuildCode().empty()); + assert(!Utils::areTokensEmpty(connectInit.getColBuildCodeTokens())); groupEnv.getStream() << "// only do this for existing postsynaptic neurons" << std::endl; groupEnv.print("if($(id) < $(num_post))"); @@ -1500,7 +1500,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // Calculate index in data structure of this synapse if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - if(!snippet->getRowBuildCode().empty()) { + if(!Utils::areTokensEmpty(connectInit.getRowBuildCodeTokens())) { kernelInit << "const unsigned int idx = ($(id_pre) * $(_row_stride)) + $(_row_length)[$(id)];" << std::endl; } else { @@ -1539,7 +1539,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // If matrix is sparse if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { // If there is row-building code in this snippet - if(!snippet->getRowBuildCode().empty()) { + if(!Utils::areTokensEmpty(connectInit.getRowBuildCodeTokens())) { kernelInit << "$(_ind)[idx] = $(0);" << std::endl; kernelInit << "$(_row_length)[$(id)]++;" << std::endl; } @@ -1554,7 +1554,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer const std::string indexType = areSixtyFourBitSynapseIndicesRequired(sg) ? 
"uint64_t" : "unsigned int"; // If there is row-building code in this snippet - if(!snippet->getRowBuildCode().empty()) { + if(!Utils::areTokensEmpty(connectInit.getRowBuildCodeTokens())) { kernelInit << "const " << indexType << " rowStartGID = $(id) * (" << indexType << ")($_row_stride);" << std::endl; kernelInit << getAtomic(Type::Uint32, AtomicOperation::OR) << "(&$(_gp)[(rowStartGID + ($(0))) / 32], 0x80000000 >> ((rowStartGID + ($(0))) & 31));" << std::endl; } @@ -1573,12 +1573,12 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // If this connectivity requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id - if(Utils::isRNGRequired(snippet->getRowBuildCode())) { + if(connectInit.isRNGRequired()) { groupEnv.add(Type::Void, "rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } // If there is row-building code in this snippet - if(!snippet->getRowBuildCode().empty()) { + if(!Utils::areTokensEmpty(connectInit.getRowBuildCodeTokens())) { // If this is a sparse matrix, zero row length if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { groupEnv.printLine("$(_row_length)[$(id)] = 0;"); From d2e45d45a5b92954f6395ff717a2dbe4b1676454 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 5 Jul 2023 09:26:41 +0100 Subject: [PATCH 318/725] * Defer scalar literal type resolution until type-checker * Move some parsing functionality from type into parser --- include/genn/genn/gennUtils.h | 3 +- include/genn/genn/transpiler/scanner.h | 2 +- include/genn/genn/transpiler/token.h | 2 +- include/genn/genn/transpiler/typeChecker.h | 5 +- include/genn/genn/type.h | 6 -- src/genn/genn/gennUtils.cc | 5 +- src/genn/genn/transpiler/parser.cc | 58 ++++++++++++++- src/genn/genn/transpiler/scanner.cc | 43 +++-------- src/genn/genn/transpiler/typeChecker.cc | 33 +++++---- src/genn/genn/type.cc | 83 ++++------------------ 10 files changed, 105 insertions(+), 135 deletions(-) diff --git a/include/genn/genn/gennUtils.h b/include/genn/genn/gennUtils.h index 9a9518e2d8..ac674ede7a 100644 --- a/include/genn/genn/gennUtils.h +++ b/include/genn/genn/gennUtils.h @@ -36,8 +36,7 @@ class VarInit; //-------------------------------------------------------------------------- namespace GeNN::Utils { -GENN_EXPORT std::vector scanCode(const std::string &code, const Type::TypeContext &typeContext, - const std::string &errorContext); +GENN_EXPORT std::vector scanCode(const std::string &code, const std::string &errorContext); GENN_EXPORT bool areTokensEmpty(const std::vector &tokens); diff --git a/include/genn/genn/transpiler/scanner.h b/include/genn/genn/transpiler/scanner.h index 4c2ba5f375..adfccef840 100644 --- a/include/genn/genn/transpiler/scanner.h +++ b/include/genn/genn/transpiler/scanner.h @@ -25,6 +25,6 @@ class ErrorHandlerBase; //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Scanner { -std::vector scanSource(const std::string_view &source, const Type::TypeContext &context, ErrorHandlerBase &errorHandler); +std::vector scanSource(const std::string_view &source, ErrorHandlerBase &errorHandler); } // namespace Scanner diff --git a/include/genn/genn/transpiler/token.h b/include/genn/genn/transpiler/token.h index e7d66cbf5f..d1d86afc74 100644 --- a/include/genn/genn/transpiler/token.h +++ b/include/genn/genn/transpiler/token.h @@ -37,7 +37,7 @@ struct Token SHIFT_LEFT_EQUAL, SHIFT_RIGHT_EQUAL, // Literals - IDENTIFIER, UINT32_NUMBER, 
INT32_NUMBER, FLOAT_NUMBER, DOUBLE_NUMBER, STRING, + IDENTIFIER, UINT32_NUMBER, INT32_NUMBER, FLOAT_NUMBER, DOUBLE_NUMBER, SCALAR_NUMBER, STRING, // Types TYPE_SPECIFIER, diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index 3e7800b741..80d00c1b78 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -61,8 +61,9 @@ typedef std::function StatementHandle // Free functions //--------------------------------------------------------------------------- ResolvedTypeMap typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment, - ErrorHandlerBase &errorHandler, StatementHandler forEachSynapseHandler = nullptr); + const Type::TypeContext &context, ErrorHandlerBase &errorHandler, + StatementHandler forEachSynapseHandler = nullptr); ResolvedTypeMap typeCheck(const Expression::Base *expression, EnvironmentBase &environment, - ErrorHandlerBase &errorHandler); + const Type::TypeContext &context, ErrorHandlerBase &errorHandler); } // namespace GeNN::Transpiler::TypeChecker diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 868fd341a3..da7b995123 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -353,12 +353,6 @@ inline static const ResolvedType AddToPre = ResolvedType::createFunction(Void, { inline static const ResolvedType AddToPost = ResolvedType::createFunction(Void, {Uint32}); inline static const ResolvedType AddToPostDenDelay = ResolvedType::createFunction(Void, {Uint32, Uint32}); -//! Parse a numeric type -GENN_EXPORT ResolvedType parseNumeric(const std::string &typeString, const TypeContext &context); - -//! Look up numeric type based on set of type specifiers -GENN_EXPORT ResolvedType getNumericType(const std::set &typeSpecifiers, const TypeContext &context); - //! 
Apply C type promotion rules to numeric type
 GENN_EXPORT ResolvedType getPromotedType(const ResolvedType &type);

diff --git a/src/genn/genn/gennUtils.cc b/src/genn/genn/gennUtils.cc
index 810461551b..919ea56d3f 100644
--- a/src/genn/genn/gennUtils.cc
+++ b/src/genn/genn/gennUtils.cc
@@ -63,15 +63,14 @@ std::string upgradeCodeString(const std::string &codeString)
 //--------------------------------------------------------------------------
 namespace GeNN::Utils
 {
-std::vector<Transpiler::Token> scanCode(const std::string &code, const Type::TypeContext &typeContext,
-                                        const std::string &errorContext)
+std::vector<Transpiler::Token> scanCode(const std::string &code, const std::string &errorContext)
 {
     // Upgrade code string
     const std::string upgradedCode = upgradeCodeString(code);

     // Scan code string and return tokens
     Transpiler::ErrorHandler errorHandler(errorContext);
-    return Transpiler::Scanner::scanSource(upgradedCode, typeContext, errorHandler);
+    return Transpiler::Scanner::scanSource(upgradedCode, errorHandler);
 }
 //--------------------------------------------------------------------------
 bool areTokensEmpty(const std::vector<Transpiler::Token> &tokens)
diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc
index c43c60f694..1662dbc5e2 100644
--- a/src/genn/genn/transpiler/parser.cc
+++ b/src/genn/genn/transpiler/parser.cc
@@ -25,6 +25,36 @@ using namespace GeNN::Transpiler;
 //---------------------------------------------------------------------------
 namespace
 {
+const std::map<std::set<std::string>, Type::ResolvedType> numericTypeSpecifiers{
+    {{"char"}, Type::Int8},
+    {{"int8_t"}, Type::Int8},
+
+    {{"unsigned", "char"}, Type::Uint8},
+    {{"uint8_t"}, Type::Uint8},
+
+    {{"short"}, Type::Int16},
+    {{"short", "int"}, Type::Int16},
+    {{"signed", "short"}, Type::Int16},
+    {{"signed", "short", "int"}, Type::Int16},
+    {{"int16_t"}, Type::Int16},
+
+    {{"unsigned", "short"}, Type::Uint16},
+    {{"unsigned", "short", "int"}, Type::Uint16},
+    {{"uint16_t"}, Type::Uint16},
+
+    {{"int"}, Type::Int32},
+    {{"signed"}, Type::Int32},
+    {{"signed", "int"}, Type::Int32},
+    {{"int32_t"}, Type::Int32},
+
+    {{"unsigned"}, Type::Uint32},
+    {{"unsigned", "int"}, Type::Uint32},
+    {{"uint32_t"}, Type::Uint32},
+
+    {{"float"}, Type::Float},
+    {{"double"}, Type::Double}};
+
+
 //---------------------------------------------------------------------------
 // ParseError
 //---------------------------------------------------------------------------
@@ -142,6 +172,27 @@ class ParserState
     ErrorHandlerBase &m_ErrorHandler;
 };

+// **THINK** could leave unresolved
+Type::ResolvedType getNumericType(const std::set<std::string> &typeSpecifiers, const Type::TypeContext &context)
+{
+    // If type is numeric, return
+    const auto type = numericTypeSpecifiers.find(typeSpecifiers);
+    if (type != numericTypeSpecifiers.cend()) {
+        return type->second;
+    }
+    else {
+        // **YUCK** use sets everywhere
+        if (typeSpecifiers.size() == 1) {
+            const auto contextType = context.find(*typeSpecifiers.begin());
+            if (contextType != context.cend()) {
+                return contextType->second;
+            }
+        }
+
+        // **TODO** improve error
+        throw std::runtime_error("Unknown numeric type specifier");
+    }
+}

 void synchronise(ParserState &parserState)
 {
@@ -241,8 +292,9 @@ Expression::ExpressionPtr parsePrimary(ParserState &parserState)
     //  constant
     //  "(" expression ")"
     if (parserState.match({Token::Type::FALSE, Token::Type::TRUE, Token::Type::STRING,
-                           Token::Type::DOUBLE_NUMBER, Token::Type::FLOAT_NUMBER,
-                           Token::Type::INT32_NUMBER, Token::Type::UINT32_NUMBER})) {
+                           Token::Type::DOUBLE_NUMBER, Token::Type::FLOAT_NUMBER,
+                           Token::Type::SCALAR_NUMBER, Token::Type::INT32_NUMBER,
+                           Token::Type::UINT32_NUMBER})) {
         return std::make_unique<Expression::Literal>(parserState.previous());
     }
     else if(parserState.match(Token::Type::IDENTIFIER)) {
@@ -870,6 +922,6 @@ const GeNN::Type::ResolvedType parseNumericType(const std::vector<Token> &tokens
     };

     // Return numeric type
-    return GeNN::Type::getNumericType(typeSpecifiers, context);
+    return getNumericType(typeSpecifiers, context);
 }
 }
diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc
index 7cfa95480e..a90dc17d8b 100644
--- a/src/genn/genn/transpiler/scanner.cc
+++ b/src/genn/genn/transpiler/scanner.cc
@@ -51,7 +51,8 @@ const std::unordered_map<std::string_view, Token::Type> keywords{
     {"int16_t", Token::Type::TYPE_SPECIFIER},
     {"uint32_t", Token::Type::TYPE_SPECIFIER},
     {"int32_t", Token::Type::TYPE_SPECIFIER},
-    {"bool", Token::Type::TYPE_SPECIFIER}};
+    {"bool", Token::Type::TYPE_SPECIFIER},
+    {"scalar", Token::Type::TYPE_SPECIFIER}};
 //---------------------------------------------------------------------------
 // ScanState
 //---------------------------------------------------------------------------
@@ -59,8 +60,8 @@ const std::unordered_map<std::string_view, Token::Type> keywords{
 class ScanState
 {
 public:
-    ScanState(std::string_view source, const Type::TypeContext &context, ErrorHandlerBase &errorHandler)
-    :   m_Start(0), m_Current(0), m_Line(1), m_Source(source), m_Context(context), m_ErrorHandler(errorHandler)
+    ScanState(std::string_view source, ErrorHandlerBase &errorHandler)
+    :   m_Start(0), m_Current(0), m_Line(1), m_Source(source), m_ErrorHandler(errorHandler)
     {
     }

@@ -123,30 +124,7 @@ class ScanState
     {
         m_ErrorHandler.error(getLine(), message);
     }
-
-    bool isTypedefIdentifier(std::string_view lexeme) {
-        return (m_Context.find(std::string{lexeme}) != m_Context.cend());
-    }
-
-    Token::Type getScalarTokenType() const
-    {
-        const auto scalarType = m_Context.find("scalar");
-        if (scalarType == m_Context.cend()) {
-            throw std::runtime_error("Cannot scan scalar literals without 'scalar' type being defined in type context");
-        }
-        else {
-            if (scalarType->second == Type::Float) {
-                return Token::Type::FLOAT_NUMBER;
-            }
-            else if (scalarType->second == Type::Double) {
-                return Token::Type::DOUBLE_NUMBER;
-            }
-            else {
-                throw std::runtime_error("Unsupported scalar type '" + scalarType->first + "'");
-            }
-        }
-    }
-
 private:
     //---------------------------------------------------------------------------
     // Members
     //---------------------------------------------------------------------------
     size_t m_Start;
     size_t m_Current;
     size_t m_Line;

     std::string_view m_Source;
-    const Type::TypeContext &m_Context;
     ErrorHandlerBase &m_ErrorHandler;
 };

@@ -245,9 +222,9 @@ void scanNumber(char c, ScanState &scanState, std::vector<Token> &tokens)
             emplaceToken(tokens, Token::Type::DOUBLE_NUMBER, scanState);
             scanState.advance();
         }
-        // Otherwise, emplace literal with whatever type is specified
+        // Otherwise, emplace scalar literal with type to be decoded later
         else {
-            emplaceToken(tokens, scanState.getScalarTokenType(), scanState);
+            emplaceToken(tokens, Token::Type::SCALAR_NUMBER, scanState);
         }
     }
     // Otherwise, emplace integer token
@@ -287,10 +264,6 @@ void scanIdentifier(ScanState &scanState, std::vector<Token> &tokens)
     if(k != keywords.cend()) {
         emplaceToken(tokens, k->second, scanState);
     }
-    // Otherwise, if identifier is typedef, add type specifier token
-    else if (scanState.isTypedefIdentifier(scanState.getLexeme())) {
-        emplaceToken(tokens, Token::Type::TYPE_SPECIFIER, scanState);
-    }
     // Otherwise, add identifier token
     else {
         emplaceToken(tokens, Token::Type::IDENTIFIER, scanState);
@@ -468,11 +441,11 @@ void scanToken(ScanState
&scanState, std::vector &tokens) //--------------------------------------------------------------------------- namespace GeNN::Transpiler::Scanner { -std::vector scanSource(const std::string_view &source, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) +std::vector scanSource(const std::string_view &source, ErrorHandlerBase &errorHandler) { std::vector tokens; - ScanState scanState(source, context, errorHandler); + ScanState scanState(source, errorHandler); // Scan tokens while(!scanState.isAtEnd()) { diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 8297a85ed9..5a6d9890b2 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -176,8 +176,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor { public: Visitor(const Statement::StatementList &statements, EnvironmentInternal &environment, - ResolvedTypeMap &resolvedTypes, ErrorHandlerBase &errorHandler, StatementHandler forEachSynapseHandler) - : Visitor(environment, resolvedTypes, errorHandler, forEachSynapseHandler) + ResolvedTypeMap &resolvedTypes, const Type::TypeContext &context, + ErrorHandlerBase &errorHandler, StatementHandler forEachSynapseHandler) + : Visitor(environment, resolvedTypes, context, errorHandler, forEachSynapseHandler) { for (auto &s : statements) { s.get()->accept(*this); @@ -185,17 +186,20 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } Visitor(const Expression::Base *expression, EnvironmentInternal &environment, - ResolvedTypeMap &resolvedTypes, ErrorHandlerBase &errorHandler) - : Visitor(environment, resolvedTypes, errorHandler, nullptr) + ResolvedTypeMap &resolvedTypes, const Type::TypeContext &context, + ErrorHandlerBase &errorHandler) + : Visitor(environment, resolvedTypes, context, errorHandler, nullptr) { expression->accept(*this); } private: Visitor(EnvironmentInternal &environment, ResolvedTypeMap &resolvedTypes, - ErrorHandlerBase &errorHandler, StatementHandler forEachSynapseHandler) - : m_Environment(environment), m_ErrorHandler(errorHandler), m_ForEachSynapseHandler(forEachSynapseHandler), - m_ResolvedTypes(resolvedTypes), m_InLoop(false), m_InSwitch(false) + const Type::TypeContext &context, ErrorHandlerBase &errorHandler, + StatementHandler forEachSynapseHandler) + : m_Environment(environment), m_Context(context), m_ErrorHandler(errorHandler), + m_ForEachSynapseHandler(forEachSynapseHandler), m_ResolvedTypes(resolvedTypes), + m_InLoop(false), m_InSwitch(false) { } @@ -454,6 +458,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor else if (literal.getValue().type == Token::Type::FLOAT_NUMBER) { setExpressionType(&literal, Type::Float); } + else if (literal.getValue().type == Token::Type::SCALAR_NUMBER) { + setExpressionType(&literal, m_Context.at("scalar")); + } else if (literal.getValue().type == Token::Type::INT32_NUMBER) { setExpressionType(&literal, Type::Int32); } @@ -855,6 +862,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // Members //--------------------------------------------------------------------------- std::reference_wrapper m_Environment; + const Type::TypeContext &m_Context; ErrorHandlerBase &m_ErrorHandler; StatementHandler m_ForEachSynapseHandler; ResolvedTypeMap &m_ResolvedTypes; @@ -883,20 +891,21 @@ Type::ResolvedType EnvironmentBase::getType(const Token &name, ErrorHandlerBase // GeNN::Transpiler::TypeChecker 
//--------------------------------------------------------------------------- ResolvedTypeMap GeNN::Transpiler::TypeChecker::typeCheck(const Statement::StatementList &statements, EnvironmentBase &environment, - ErrorHandlerBase &errorHandler, StatementHandler forEachSynapseHandler) + const Type::TypeContext &context, ErrorHandlerBase &errorHandler, + StatementHandler forEachSynapseHandler) { ResolvedTypeMap expressionTypes; EnvironmentInternal internalEnvironment(environment); - Visitor visitor(statements, internalEnvironment, expressionTypes, errorHandler, - forEachSynapseHandler); + Visitor visitor(statements, internalEnvironment, expressionTypes, + context, errorHandler, forEachSynapseHandler); return expressionTypes; } //--------------------------------------------------------------------------- ResolvedTypeMap GeNN::Transpiler::TypeChecker::typeCheck(const Expression::Base *expression, EnvironmentBase &environment, - ErrorHandlerBase &errorHandler) + const Type::TypeContext &context, ErrorHandlerBase &errorHandler) { ResolvedTypeMap expressionTypes; EnvironmentInternal internalEnvironment(environment); - Visitor visitor(expression, internalEnvironment, expressionTypes, errorHandler); + Visitor visitor(expression, internalEnvironment, expressionTypes, context, errorHandler); return expressionTypes; } diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index 4e8265ee84..507e259c30 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -19,35 +19,6 @@ using namespace GeNN; // Anonymous namespace namespace { -const std::map, Type::ResolvedType> numericTypeSpecifiers{ - {{"char"}, Type::Int8}, - {{"int8_t"}, Type::Int8}, - - {{"unsigned", "char"}, Type::Uint8}, - {{"uint8_t"}, Type::Uint8}, - - {{"short"}, Type::Int16}, - {{"short", "int"}, Type::Int16}, - {{"signed", "short"}, Type::Int16}, - {{"signed", "short", "int"}, Type::Int16}, - {{"int16_t"}, Type::Int16}, - - {{"unsigned", "short"}, Type::Uint16}, - {{"unsigned", "short", "int"}, Type::Uint16}, - {{"uint16_t"}, Type::Uint8}, - - {{"int"}, Type::Int32}, - {{"signed"}, Type::Int32}, - {{"signed", "int"}, Type::Int32}, - {{"int32_t"}, Type::Int32}, - - {{"unsigned"}, Type::Uint32}, - {{"unsigned", "int"}, Type::Uint32}, - {{"uint32_t"}, Type::Uint32}, - - {{"float"}, Type::Float}, - {{"double"}, Type::Double}}; -//---------------------------------------------------------------------------- // Mapping of signed integer numericTypeSpecifiers to their unsigned equivalents const std::map unsignedType{ {Type::Int8, Type::Uint8}, @@ -123,51 +94,23 @@ ResolvedType UnresolvedType::resolve(const TypeContext &typeContext) const }, [&typeContext](const std::string &name) { - return parseNumeric(name, typeContext); - }}, - detail); -} -//---------------------------------------------------------------------------- -// Free functions -//---------------------------------------------------------------------------- -ResolvedType parseNumeric(const std::string &typeString, const TypeContext &context) -{ - using namespace Transpiler; - - // Scan type - SingleLineErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(typeString, context, errorHandler); + using namespace Transpiler; - // Parse type numeric type - const auto type = Parser::parseNumericType(tokens, context, errorHandler); + // Scan type + SingleLineErrorHandler errorHandler; + const auto tokens = Scanner::scanSource(name, errorHandler); - // If an error was encountered while scanning or parsing, throw exception - if (errorHandler.hasError()) { - throw 
std::runtime_error("Error parsing type '" + std::string{typeString} + "'"); - } + // Parse type numeric type + const auto type = Parser::parseNumericType(tokens, typeContext, errorHandler); - return type; -} -//---------------------------------------------------------------------------- -ResolvedType getNumericType(const std::set &typeSpecifiers, const TypeContext &context) -{ - // If type is numeric, return - const auto type = numericTypeSpecifiers.find(typeSpecifiers); - if (type != numericTypeSpecifiers.cend()) { - return type->second; - } - else { - // **YUCK** use sets everywhere - if (typeSpecifiers.size() == 1) { - const auto contextType = context.find(*typeSpecifiers.begin()); - if (contextType != context.cend()) { - return contextType->second; - } - } + // If an error was encountered while scanning or parsing, throw exception + if (errorHandler.hasError()) { + throw std::runtime_error("Error parsing type '" + std::string{name} + "'"); + } - // **TODO** improve error - throw std::runtime_error("Unknown numeric type specifier"); - } + return type; + }}, + detail); } //---------------------------------------------------------------------------- ResolvedType getPromotedType(const ResolvedType &type) From 6400722468774401af08b36e098005b753aa7a0a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 5 Jul 2023 10:31:02 +0100 Subject: [PATCH 319/725] moved scanning to constructors --- .../genn/genn/initSparseConnectivitySnippet.h | 6 +- .../genn/initToeplitzConnectivitySnippet.h | 4 +- include/genn/genn/models.h | 4 +- include/genn/genn/neuronGroup.h | 2 +- include/genn/genn/snippet.h | 1 - .../genn/initSparseConnectivitySnippet.cc | 12 +- .../genn/initToeplitzConnectivitySnippet.cc | 7 +- src/genn/genn/models.cc | 8 +- src/genn/genn/neuronGroup.cc | 20 +- src/genn/genn/synapseGroup.cc | 225 +++++++++--------- 10 files changed, 131 insertions(+), 158 deletions(-) diff --git a/include/genn/genn/initSparseConnectivitySnippet.h b/include/genn/genn/initSparseConnectivitySnippet.h index 610116515a..fc33ff9927 100644 --- a/include/genn/genn/initSparseConnectivitySnippet.h +++ b/include/genn/genn/initSparseConnectivitySnippet.h @@ -72,16 +72,14 @@ class GENN_EXPORT Base : public Snippet::Base //---------------------------------------------------------------------------- // Init //---------------------------------------------------------------------------- -class Init : public Snippet::Init +class Init : public Snippet::Init { public: - using Snippet::Init::Init; + Init(const Base *snippet, const std::unordered_map ¶ms); //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - void finalise(double dt, const Type::TypeContext &context, const std::string &errorContext); - bool isRNGRequired() const; bool isHostRNGRequired() const; diff --git a/include/genn/genn/initToeplitzConnectivitySnippet.h b/include/genn/genn/initToeplitzConnectivitySnippet.h index e32e18f42c..a41b3d63c9 100644 --- a/include/genn/genn/initToeplitzConnectivitySnippet.h +++ b/include/genn/genn/initToeplitzConnectivitySnippet.h @@ -61,13 +61,11 @@ class GENN_EXPORT Base : public Snippet::Base class Init : public Snippet::Init { public: - using Snippet::Init::Init; + Init(const Base *snippet, const std::unordered_map ¶ms); //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - void finalise(double dt, const 
Type::TypeContext &context, const std::string &errorContext); - bool isRNGRequired() const; const std::vector &getDiagonalBuildCodeTokens() const{ return m_DiagonalBuildCodeTokens; } diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 109fa2873b..ad37cd34ae 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -127,13 +127,11 @@ class GENN_EXPORT Base : public Snippet::Base class VarInit : public Snippet::Init { public: - using Snippet::Init::Init; + VarInit(const InitVarSnippet::Base *snippet, const std::unordered_map ¶ms); //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - void finalise(double dt, const Type::TypeContext &context, const std::string &errorContext); - bool isRNGRequired() const; bool isKernelRequired() const; diff --git a/include/genn/genn/neuronGroup.h b/include/genn/genn/neuronGroup.h index 6db8869642..8d01bdb9bf 100644 --- a/include/genn/genn/neuronGroup.h +++ b/include/genn/genn/neuronGroup.h @@ -198,7 +198,7 @@ class GENN_EXPORT NeuronGroup void addInSyn(SynapseGroupInternal *synapseGroup){ m_InSyn.push_back(synapseGroup); } void addOutSyn(SynapseGroupInternal *synapseGroup){ m_OutSyn.push_back(synapseGroup); } - void finalise(double dt, const Type::TypeContext &context); + void finalise(double dt); //! Fuse incoming postsynaptic models void fusePrePostSynapses(bool fusePSM, bool fusePrePostWUM); diff --git a/include/genn/genn/snippet.h b/include/genn/genn/snippet.h index 854e224058..9eff48738b 100644 --- a/include/genn/genn/snippet.h +++ b/include/genn/genn/snippet.h @@ -197,7 +197,6 @@ class Init return getSnippet()->getHashDigest(); } -protected: void finalise(double dt) { auto derivedParams = m_Snippet->getDerivedParams(); diff --git a/src/genn/genn/initSparseConnectivitySnippet.cc b/src/genn/genn/initSparseConnectivitySnippet.cc index ad3813cde1..faae9d87ee 100644 --- a/src/genn/genn/initSparseConnectivitySnippet.cc +++ b/src/genn/genn/initSparseConnectivitySnippet.cc @@ -42,15 +42,13 @@ void Base::validate(const std::unordered_map ¶mValues) //---------------------------------------------------------------------------- // GeNN::InitSparseConnectivitySnippet::Init //---------------------------------------------------------------------------- -void Init::finalise(double dt, const Type::TypeContext &context, const std::string &errorContext) +Init::Init(const Base *snippet, const std::unordered_map ¶ms) +: Snippet::Init(snippet, params) { - // Superclass - Snippet::Init::finalise(dt); - // Scan code tokens - m_RowBuildCodeTokens = Utils::scanCode(getSnippet()->getRowBuildCode(), context, errorContext + "row build code"); - m_ColBuildCodeTokens = Utils::scanCode(getSnippet()->getColBuildCode(), context, errorContext + "col build code"); - m_HostInitCodeTokens = Utils::scanCode(getSnippet()->getHostInitCode(), context, errorContext + "host init code"); + m_RowBuildCodeTokens = Utils::scanCode(getSnippet()->getRowBuildCode(), "Row build code"); + m_ColBuildCodeTokens = Utils::scanCode(getSnippet()->getColBuildCode(), "Col build code"); + m_HostInitCodeTokens = Utils::scanCode(getSnippet()->getHostInitCode(), "Host init code"); } //---------------------------------------------------------------------------- bool Init::isRNGRequired() const diff --git a/src/genn/genn/initToeplitzConnectivitySnippet.cc b/src/genn/genn/initToeplitzConnectivitySnippet.cc index 7d9283f82c..3228645791 100644 --- 
a/src/genn/genn/initToeplitzConnectivitySnippet.cc
+++ b/src/genn/genn/initToeplitzConnectivitySnippet.cc
@@ -33,15 +33,12 @@ void Base::validate(const std::unordered_map<std::string, double> &paramValues)
 //----------------------------------------------------------------------------
 // GeNN::InitToeplitzConnectivitySnippet::Init
 //----------------------------------------------------------------------------
-void Init::finalise(double dt, const Type::TypeContext &context, const std::string &errorContext)
+Init::Init(const Base *snippet, const std::unordered_map<std::string, double> &params)
+:   Snippet::Init<Base>(snippet, params)
 {
-    // Superclass
-    Snippet::Init::finalise(dt);
-
     // Scan code tokens
-    m_DiagonalBuildCodeTokens = Utils::scanCode(getSnippet()->getDiagonalBuildCode(), context, errorContext + "diagonal build code");
+    m_DiagonalBuildCodeTokens = Utils::scanCode(getSnippet()->getDiagonalBuildCode(), "Diagonal build code");
 }
-//----------------------------------------------------------------------------
 bool Init::isRNGRequired() const
 {
     return Utils::isRNGRequired(m_DiagonalBuildCodeTokens);
diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc
index 12f3925282..736c426ad5 100644
--- a/src/genn/genn/models.cc
+++ b/src/genn/genn/models.cc
@@ -37,13 +37,11 @@ void Base::validate(const std::unordered_map<std::string, double> &paramValues,
 //----------------------------------------------------------------------------
 // VarInit
 //----------------------------------------------------------------------------
-void VarInit::finalise(double dt, const Type::TypeContext &context, const std::string &errorContext)
+VarInit::VarInit(const InitVarSnippet::Base *snippet, const std::unordered_map<std::string, double> &params)
+:   Snippet::Init<InitVarSnippet::Base>(snippet, params)
 {
-    // Superclass
-    Snippet::Init::finalise(dt);
-
     // Scan code tokens
-    m_CodeTokens = Utils::scanCode(getSnippet()->getCode(), context, errorContext + "initialisation code");
+    m_CodeTokens = Utils::scanCode(getSnippet()->getCode(), "Variable initialisation code");
 }
 //----------------------------------------------------------------------------
 bool VarInit::isRNGRequired() const
diff --git a/src/genn/genn/neuronGroup.cc b/src/genn/genn/neuronGroup.cc
index bbf5873135..834d3d9cca 100644
--- a/src/genn/genn/neuronGroup.cc
+++ b/src/genn/genn/neuronGroup.cc
@@ -316,6 +316,14 @@ NeuronGroup::NeuronGroup(const std::string &name, int numNeurons, const NeuronMo
     // Validate names
     Utils::validatePopName(name, "Neuron group");
     getNeuronModel()->validate(getParams(), getVarInitialisers(), "Neuron group " + getName());
+
+    // Scan neuron model code strings
+    m_SimCodeTokens = Utils::scanCode(getNeuronModel()->getSimCode(),
+                                      "Neuron group '" + getName() + "' sim code");
+    m_ThresholdConditionCodeTokens = Utils::scanCode(getNeuronModel()->getThresholdConditionCode(),
+                                                     "Neuron group '" + getName() + "' threshold condition code");
+    m_ResetCodeTokens = Utils::scanCode(getNeuronModel()->getResetCode(),
+                                        "Neuron group '" + getName() + "' reset code");
 }
 //----------------------------------------------------------------------------
 void NeuronGroup::checkNumDelaySlots(unsigned int requiredDelay)
@@ -335,7 +343,7 @@ void NeuronGroup::updatePostVarQueues(const std::vector<Transpiler::Token> &toke
     updateVarQueues(tokens, "_post");
 }
 //----------------------------------------------------------------------------
-void NeuronGroup::finalise(double dt, const Type::TypeContext &context)
+void NeuronGroup::finalise(double dt)
 {
     auto derivedParams = getNeuronModel()->getDerivedParams();

@@ -346,16 +354,8 @@ void NeuronGroup::finalise(double dt, const Type::TypeContext &context)

     // Finalise variable initialisers
     for(auto &v : m_VarInitialisers) {
-        v.second.finalise(dt, context,
"Variable '" + v.first + "' "); + v.second.finalise(dt); } - - // Scan neuron model code strings - m_SimCodeTokens = Utils::scanCode(getNeuronModel()->getSimCode(), context, - "Neuron group '" + getName() + "' sim code"); - m_ThresholdConditionCodeTokens = Utils::scanCode(getNeuronModel()->getThresholdConditionCode(), context, - "Neuron group '" + getName() + "' threshold condition code"); - m_SimCodeTokens = Utils::scanCode(getNeuronModel()->getResetCode(), context, - "Neuron group '" + getName() + "' reset code"); } //---------------------------------------------------------------------------- void NeuronGroup::fusePrePostSynapses(bool fusePSM, bool fusePrePostWUM) diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc index 4db333974d..d16c3652c9 100644 --- a/src/genn/genn/synapseGroup.cc +++ b/src/genn/genn/synapseGroup.cc @@ -332,9 +332,91 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType "Synapse group " + getName() + " weight update model "); getPSModel()->validate(getPSParams(), getPSVarInitialisers(), "Synapse group " + getName() + " postsynaptic model "); + // Scan weight update model code strings + m_WUSimCodeTokens = Utils::scanCode( + getWUModel()->getSimCode(), "Synapse group '" + getName() + "' weight update model sim code"); + m_WUEventCodeTokens = Utils::scanCode( + getWUModel()->getEventCode(), "Synapse group '" + getName() + "' weight update model event code"); + m_WUPostLearnCodeTokens = Utils::scanCode( + getWUModel()->getLearnPostCode(), "Synapse group '" + getName() + "' weight update model learn post code"); + m_WUSynapseDynamicsCodeTokens = Utils::scanCode( + getWUModel()->getSynapseDynamicsCode(), "Synapse group '" + getName() + "' weight update model synapse dynamics code"); + m_WUEventThresholdCodeTokens = Utils::scanCode( + getWUModel()->getEventThresholdConditionCode(), "Synapse group '" + getName() + "' weight update model event threshold code"); + m_WUPreSpikeCodeTokens = Utils::scanCode( + getWUModel()->getPreSpikeCode(), "Synapse group '" + getName() + "' weight update model pre spike code"); + m_WUPostSpikeCodeTokens = Utils::scanCode( + getWUModel()->getPostSpikeCode(), "Synapse group '" + getName() + "' weight update model post spike code"); + m_WUPreDynamicsCodeTokens = Utils::scanCode( + getWUModel()->getPreDynamicsCode(), "Synapse group '" + getName() + "' weight update model pre dynamics code"); + m_WUPostDynamicsCodeTokens = Utils::scanCode( + getWUModel()->getPostDynamicsCode(), "Synapse group '" + getName() + "' weight update model post dynamics code"); + + // Scan postsynaptic update model code strings + m_PSApplyInputCodeTokens = Utils::scanCode( + getPSModel()->getApplyInputCode(), "Synapse group '" + getName() + "' postsynaptic update model apply input code"); + m_PSDecayCodeTokens = Utils::scanCode( + getPSModel()->getDecayCode(), "Synapse group '" + getName() + "' postsynaptic update model decay code"); + + // If connectivity is procedural + if(m_MatrixType & SynapseMatrixConnectivity::PROCEDURAL) { + // If there's a toeplitz initialiser, give an error + if(!Utils::areTokensEmpty(m_ToeplitzConnectivityInitialiser.getDiagonalBuildCodeTokens())) { + throw std::runtime_error("Cannot use procedural connectivity with toeplitz initialisation snippet"); + } + + // If there's no row build code, give an error + if(Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getRowBuildCodeTokens())) { + throw std::runtime_error("Cannot use procedural connectivity without specifying a connectivity 
initialisation snippet with row building code"); + } + + // If there's column build code, give an error + if(!Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getColBuildCodeTokens())) { + throw std::runtime_error("Cannot use procedural connectivity with connectivity initialisation snippets with column building code"); + } + + // If the weight update model has code for postsynaptic-spike triggered updating, give an error + if(!Utils::areTokensEmpty(m_WUPostLearnCodeTokens)) { + throw std::runtime_error("Procedural connectivity cannot be used for synapse groups with postsynaptic spike-triggered learning"); + } + + // If weight update model has code for continuous synapse dynamics, give error + // **THINK** this would actually be pretty trivial to implement + if (!Utils::areTokensEmpty(m_WUSynapseDynamicsCodeTokens)) { + throw std::runtime_error("Procedural connectivity cannot be used for synapse groups with continuous synapse dynamics"); + } + } + // Otherwise, if WEIGHTS are procedural e.g. in the case of DENSE_PROCEDURALG, give error if RNG is required for weights + else if(m_MatrixType & SynapseMatrixWeight::PROCEDURAL) { + if(Utils::isRNGRequired(m_WUVarInitialisers)) { + throw std::runtime_error("Procedural weights used without procedural connectivity cannot currently access RNG."); + } + } // If synapse group has Toeplitz connectivity if(m_MatrixType & SynapseMatrixConnectivity::TOEPLITZ) { + // Give an error if there is sparse connectivity initialiser code + if(!Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getRowBuildCodeTokens()) + || !Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getColBuildCodeTokens())) + { + throw std::runtime_error("Cannot use TOEPLITZ connectivity with sparse connectivity initialisation snippet."); + } + + // Give an error if there isn't toeplitz connectivity initialiser code + if(Utils::areTokensEmpty(m_ToeplitzConnectivityInitialiser.getDiagonalBuildCodeTokens())) { + throw std::runtime_error("TOEPLITZ connectivity requires toeplitz connectivity initialisation snippet."); + } + + // Give an error if connectivity initialisation snippet uses RNG + if(Utils::isRNGRequired(m_ToeplitzConnectivityInitialiser.getDiagonalBuildCodeTokens())) { + throw std::runtime_error("TOEPLITZ connectivity cannot currently access RNG."); + } + + // If the weight update model has code for postsynaptic-spike triggered updating, give an error + if(!Utils::areTokensEmpty(m_WUPostLearnCodeTokens)) { + throw std::runtime_error("TOEPLITZ connectivity cannot be used for synapse groups with postsynaptic spike-triggered learning"); + } + // If toeplitz initialisation snippet provides a function to calculate kernel size, call it auto calcKernelSizeFunc = m_ToeplitzConnectivityInitialiser.getSnippet()->getCalcKernelSizeFunc(); if(calcKernelSizeFunc) { @@ -390,16 +472,28 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType } } - // If connectivity initialisation snippet defines a kernel and matrix type doesn't support it, give error + // If connectivity initialisation snippet defines a kernel and matrix type doesn't support it, give error if(!m_KernelSize.empty() && (m_MatrixType != SynapseMatrixType::PROCEDURAL_PROCEDURALG) && (m_MatrixType != SynapseMatrixType::TOEPLITZ) && (m_MatrixType != SynapseMatrixType::SPARSE) && (m_MatrixType != SynapseMatrixType::PROCEDURAL_KERNELG)) { - throw std::runtime_error("Connectivity initialisation snippet which use a kernel can only be used with PROCEDURAL_PROCEDURALG, PROCEDURAL_KERNELG, TOEPLITZ or 
SPARSE connectivity."); + throw std::runtime_error("BITMASK connectivity can only be used with weight update models without variables like StaticPulseConstantWeight."); } - // Check BITMASK connectivity isn't used with models with variables - if((m_MatrixType & SynapseMatrixConnectivity::BITMASK) && !m_WUModel->getVars().empty()) { - throw std::runtime_error("BITMASK connectivity can only be used with weight update models without variables like StaticPulseConstantWeight."); + // If connectivity is dense and there is connectivity initialiser code, give error + if((m_MatrixType & SynapseMatrixConnectivity::DENSE) + && (!Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getRowBuildCodeTokens()) + || !Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getColBuildCodeTokens()))) + { + throw std::runtime_error("Cannot use DENSE connectivity with connectivity initialisation snippet."); + } + + // If synapse group uses sparse or procedural connectivity but no kernel size is provided, + // check that no variable's initialisation snippets require a kernel + if(((m_MatrixType == SynapseMatrixType::SPARSE) || (m_MatrixType == SynapseMatrixType::PROCEDURAL_PROCEDURALG)) && + m_KernelSize.empty() && std::any_of(getWUVarInitialisers().cbegin(), getWUVarInitialisers().cend(), + [](const auto &v) { return v.second.isKernelRequired(); })) + { + throw std::runtime_error("Variable initialisation snippets which use $(id_kernel) must be used with a connectivity initialisation snippet which specifies how kernel size is calculated."); } // Check that the source neuron group supports the desired number of delay steps @@ -423,136 +517,29 @@ void SynapseGroup::finalise(double dt, const Type::TypeContext &context) // Initialise derived parameters for WU variable initialisers for(auto &v : m_WUVarInitialisers) { - v.second.finalise(dt, context, - "Synapse group '" + getName() + "', weight update model variable '" + v.first + "' "); + v.second.finalise(dt); } // Initialise derived parameters for PSM variable initialisers for(auto &v : m_PSVarInitialisers) { - v.second.finalise(dt, context, - "Synapse group '" + getName() + "', postsynaptic update model variable '" + v.first + "' "); + v.second.finalise(dt); } // Initialise derived parameters for WU presynaptic variable initialisers for(auto &v : m_WUPreVarInitialisers) { - v.second.finalise(dt, context, - "Synapse group '" + getName() + "' weight update model presynaptic variable '" + v.first + "' "); + v.second.finalise(dt); } // Initialise derived parameters for WU postsynaptic variable initialisers for(auto &v : m_WUPostVarInitialisers) { - v.second.finalise(dt, context, - "Synapse group '" + getName() + "' weight update model postsynaptic variable '" + v.first + "' "); + v.second.finalise(dt); } // Initialise any derived connectivity initialiser parameters - m_SparseConnectivityInitialiser.finalise(dt, context, "Synapse group '" + getName() + "'"); - m_ToeplitzConnectivityInitialiser.finalise(dt, context, "Synapse group '" + getName() + "'"); - - // Scan weight update model code strings - m_WUSimCodeTokens = Utils::scanCode(getWUModel()->getSimCode(), context, - "Synapse group '" + getName() + "' weight update model sim code"); - m_WUEventCodeTokens = Utils::scanCode(getWUModel()->getEventCode(), context, - " Synapse group '" + getName() + "' weight update model event code"); - m_WUPostLearnCodeTokens = Utils::scanCode(getWUModel()->getLearnPostCode(), context, - "Synapse group '" + getName() + "' weight update model learn post code"); - 
m_WUSynapseDynamicsCodeTokens = Utils::scanCode(getWUModel()->getSynapseDynamicsCode(), context, - "Synapse group '" + getName() + "' weight update model synapse dynamics code"); - m_WUEventThresholdCodeTokens = Utils::scanCode(getWUModel()->getEventThresholdConditionCode(), context, - "Synapse group '" + getName() + "' weight update model event threshold code"); - m_WUPreSpikeCodeTokens = Utils::scanCode(getWUModel()->getPreSpikeCode(), context, - "Synapse group '" + getName() + "' weight update model pre spike code"); - m_WUPostSpikeCodeTokens = Utils::scanCode(getWUModel()->getPostSpikeCode(), context, - "Synapse group '" + getName() + "' weight update model post spike code"); - m_WUPreDynamicsCodeTokens = Utils::scanCode(getWUModel()->getPreDynamicsCode(), context, - "Synapse group '" + getName() + "' weight update model pre dynamics code"); - m_WUPostDynamicsCodeTokens = Utils::scanCode(getWUModel()->getPostDynamicsCode(), context, - "Synapse group '" + getName() + "' weight update model post dynamics code"); - - // Scan postsynaptic update model code strings - m_PSApplyInputCodeTokens = Utils::scanCode(getPSModel()->getApplyInputCode(), context, - "Synapse group '" + getName() + "' postsynaptic update model apply input code"); - m_PSDecayCodeTokens = Utils::scanCode(getPSModel()->getDecayCode(), context, - "Synapse group '" + getName() + "' postsynaptic update model decay code"); - - // If connectivity is procedural - if(m_MatrixType & SynapseMatrixConnectivity::PROCEDURAL) { - // If there's a toeplitz initialiser, give an error - if(!Utils::areTokensEmpty(m_ToeplitzConnectivityInitialiser.getDiagonalBuildCodeTokens())) { - throw std::runtime_error("Cannot use procedural connectivity with toeplitz initialisation snippet"); - } - - // If there's no row build code, give an error - if(Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getRowBuildCodeTokens())) { - throw std::runtime_error("Cannot use procedural connectivity without specifying a connectivity initialisation snippet with row building code"); - } - - // If there's column build code, give an error - if(!Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getColBuildCodeTokens())) { - throw std::runtime_error("Cannot use procedural connectivity with connectivity initialisation snippets with column building code"); - } - - // If the weight update model has code for postsynaptic-spike triggered updating, give an error - if(!Utils::areTokensEmpty(m_WUPostLearnCodeTokens)) { - throw std::runtime_error("Procedural connectivity cannot be used for synapse groups with postsynaptic spike-triggered learning"); - } - - // If weight update model has code for continuous synapse dynamics, give error - // **THINK** this would actually be pretty trivial to implement - if (!Utils::areTokensEmpty(m_WUSynapseDynamicsCodeTokens)) { - throw std::runtime_error("Procedural connectivity cannot be used for synapse groups with continuous synapse dynamics"); - } - } - // Otherwise, if WEIGHTS are procedural e.g. 
in the case of DENSE_PROCEDURALG, give error if RNG is required for weights
-    else if(m_MatrixType & SynapseMatrixWeight::PROCEDURAL) {
-        if(Utils::isRNGRequired(m_WUVarInitialisers)) {
-            throw std::runtime_error("Procedural weights used without procedural connectivity cannot currently access RNG.");
-        }
-    }
-
-    // If synapse group has Toeplitz connectivity
-    if(m_MatrixType & SynapseMatrixConnectivity::TOEPLITZ) {
-        // Give an error if there is sparse connectivity initialiser code
-        if(!Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getRowBuildCodeTokens())
-           || !Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getColBuildCodeTokens()))
-        {
-            throw std::runtime_error("Cannot use TOEPLITZ connectivity with sparse connectivity initialisation snippet.");
-        }
-
-        // Give an error if there isn't toeplitz connectivity initialiser code
-        if(Utils::areTokensEmpty(m_ToeplitzConnectivityInitialiser.getDiagonalBuildCodeTokens())) {
-            throw std::runtime_error("TOEPLITZ connectivity requires toeplitz connectivity initialisation snippet.");
-        }
-
-        // Give an error if connectivity initialisation snippet uses RNG
-        if(m_ToeplitzConnectivityInitialiser.isRNGRequired()) {
-            throw std::runtime_error("TOEPLITZ connectivity cannot currently access RNG.");
-        }
-
-        // If the weight update model has code for postsynaptic-spike triggered updating, give an error
-        if(!Utils::areTokensEmpty(m_WUPostLearnCodeTokens)) {
-            throw std::runtime_error("TOEPLITZ connectivity cannot be used for synapse groups with postsynaptic spike-triggered learning");
-        }
-    }
-
-    // If connectivity is dense and there is connectivity initialiser code, give error
-    if((m_MatrixType & SynapseMatrixConnectivity::DENSE)
-       && (!Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getRowBuildCodeTokens())
-           || !Utils::areTokensEmpty(m_SparseConnectivityInitialiser.getColBuildCodeTokens())))
-    {
-        throw std::runtime_error("Cannot use DENSE connectivity with connectivity initialisation snippet.");
-    }
-
-    // If synapse group uses sparse or procedural connectivity but no kernel size is provided,
-    // check that no variable's initialisation snippets require a kernel
-    if(((m_MatrixType == SynapseMatrixType::SPARSE) || (m_MatrixType == SynapseMatrixType::PROCEDURAL_PROCEDURALG)) &&
-       m_KernelSize.empty() && std::any_of(getWUVarInitialisers().cbegin(), getWUVarInitialisers().cend(),
-                                           [](const auto &v) { return v.second.isKernelRequired(); }))
-    {
-        throw std::runtime_error("Variable initialisation snippets which use id_kernel must be used with a "
-                                 "connectivity initialisation snippet which specifies how kernel size is calculated.");
-    }
+    // Mark any pre or postsynaptic neuron variables referenced in sim code as requiring queues
     if (!Utils::areTokensEmpty(m_WUSimCodeTokens)) {
         getSrcNeuronGroup()->updatePreVarQueues(m_WUSimCodeTokens);

From 11f3177f21f48acdfc64a71c64fdebb45a759e49 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Wed, 5 Jul 2023 11:50:43 +0100
Subject: [PATCH 320/725] take tokens rather than strings into prettyPrint
 helpers

---
 .../genn/genn/code_generator/codeGenUtils.h | 8 +++----
 src/genn/genn/code_generator/codeGenUtils.cc | 21 ++++---------------
 2 files changed, 8 insertions(+), 21 deletions(-)

diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h
index bc2f7e056d..406ece7faa 100644
--- a/include/genn/genn/code_generator/codeGenUtils.h
+++
b/include/genn/genn/code_generator/codeGenUtils.h
@@ -64,16 +64,16 @@ GENN_EXPORT std::string disambiguateNamespaceFunction(const std::string supportC
 GENN_EXPORT std::string upgradeCodeString(const std::string &codeString);

 //--------------------------------------------------------------------------
-/*! \brief This function uses the transpiler to scan, parse, type check and pretty print expression contained in a code string
+/*! \brief This function uses the transpiler to parse, type check and pretty print a previously scanned vector of tokens representing an expression
 */
 //--------------------------------------------------------------------------
-GENN_EXPORT void prettyPrintExpression(const std::string &code, const Type::TypeContext &typeContext, EnvironmentExternalBase &env, Transpiler::ErrorHandlerBase &errorHandler);
+GENN_EXPORT void prettyPrintExpression(const std::vector<Transpiler::Token> &tokens, const Type::TypeContext &typeContext, EnvironmentExternalBase &env, Transpiler::ErrorHandlerBase &errorHandler);

 //--------------------------------------------------------------------------
-/*! \brief This function uses the transpiler to scan, parse, type check and pretty print statametns contained in a code string
+/*! \brief This function uses the transpiler to parse, type check and pretty print a previously scanned vector of tokens representing a statement
 */
 //--------------------------------------------------------------------------
-GENN_EXPORT void prettyPrintStatements(const std::string &code, const Type::TypeContext &typeContext, EnvironmentExternalBase &env,
+GENN_EXPORT void prettyPrintStatements(const std::vector<Transpiler::Token> &tokens, const Type::TypeContext &typeContext, EnvironmentExternalBase &env,
                                        Transpiler::ErrorHandlerBase &errorHandler,
                                        Transpiler::TypeChecker::StatementHandler forEachSynapseTypeCheckHandler = nullptr,
                                        Transpiler::PrettyPrinter::StatementHandler forEachSynapsePrettyPrintHandler = nullptr);
diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc
index fdb4bd5a05..a5a7d6b87e 100644
--- a/src/genn/genn/code_generator/codeGenUtils.cc
+++ b/src/genn/genn/code_generator/codeGenUtils.cc
@@ -31,7 +31,6 @@
 // GeNN transpiler includes
 #include "transpiler/parser.h"
 #include "transpiler/prettyPrinter.h"
-#include "transpiler/scanner.h"

 //--------------------------------------------------------------------------
 // Anonymous namespace
@@ -124,43 +123,31 @@ std::string upgradeCodeString(const std::string &codeString)
     return upgradedCodeString;
 }
 //----------------------------------------------------------------------------
-void prettyPrintExpression(const std::string &code, const Type::TypeContext &typeContext, EnvironmentExternalBase &env, Transpiler::ErrorHandlerBase &errorHandler)
+void prettyPrintExpression(const std::vector<Transpiler::Token> &tokens, const Type::TypeContext &typeContext, EnvironmentExternalBase &env, Transpiler::ErrorHandlerBase &errorHandler)
 {
     using namespace Transpiler;

-    // Upgrade code string
-    const std::string upgradedCode = upgradeCodeString(code);
-
-    // Scan code string to convert to tokens
-    const auto tokens = Scanner::scanSource(upgradedCode, typeContext, errorHandler);
-
     // Parse tokens as expression
     auto expression = Parser::parseExpression(tokens, typeContext, errorHandler);

     // Resolve types
-    auto resolvedTypes = TypeChecker::typeCheck(expression.get(), env, errorHandler);
+    auto resolvedTypes = TypeChecker::typeCheck(expression.get(), env, typeContext, errorHandler);

     // Pretty print
     PrettyPrinter::print(expression, env,
typeContext, resolvedTypes);
 }
 //--------------------------------------------------------------------------
-void prettyPrintStatements(const std::string &code, const Type::TypeContext &typeContext, EnvironmentExternalBase &env,
+void prettyPrintStatements(const std::vector<Transpiler::Token> &tokens, const Type::TypeContext &typeContext, EnvironmentExternalBase &env,
                            Transpiler::ErrorHandlerBase &errorHandler, Transpiler::TypeChecker::StatementHandler forEachSynapseTypeCheckHandler,
                            Transpiler::PrettyPrinter::StatementHandler forEachSynapsePrettyPrintHandler)
 {
     using namespace Transpiler;
-
-    // Upgrade code string
-    const std::string upgradedCode = upgradeCodeString(code);
-
-    // Scan code string to convert to tokens
-    const auto tokens = Scanner::scanSource(upgradedCode, typeContext, errorHandler);

     // Parse tokens as block item list (function body)
     auto updateStatements = Parser::parseBlockItemList(tokens, typeContext, errorHandler);

     // Resolve types
-    auto resolvedTypes= TypeChecker::typeCheck(updateStatements, env, errorHandler, forEachSynapseTypeCheckHandler);
+    auto resolvedTypes= TypeChecker::typeCheck(updateStatements, env, typeContext, errorHandler, forEachSynapseTypeCheckHandler);

     // Pretty print
     PrettyPrinter::print(updateStatements, env, typeContext, resolvedTypes, forEachSynapsePrettyPrintHandler);

From 14067ccb85c51272d6ec2cd9b495a1cc2d4d7195 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Wed, 5 Jul 2023 12:02:19 +0100
Subject: [PATCH 321/725] BrEify finalize throughout and merged
 initDerivedParams with finalise

---
 include/genn/genn/currentSource.h | 2 +-
 include/genn/genn/currentSourceInternal.h | 2 +-
 include/genn/genn/customConnectivityUpdate.h | 4 +--
 .../genn/customConnectivityUpdateInternal.h | 2 +-
 include/genn/genn/customUpdate.h | 6 ++---
 include/genn/genn/customUpdateInternal.h | 5 ++--
 .../genn/code_generator/generateRunner.cc | 2 +-
 src/genn/genn/currentSource.cc | 2 +-
 src/genn/genn/customConnectivityUpdate.cc | 21 +++++++--------
 src/genn/genn/customUpdate.cc | 12 ++++++---
 src/genn/genn/modelSpec.cc | 27 +++++++------------
 11 files changed, 38 insertions(+), 47 deletions(-)

diff --git a/include/genn/genn/currentSource.h b/include/genn/genn/currentSource.h
index bd664546bd..dd18cf09c0 100644
--- a/include/genn/genn/currentSource.h
+++ b/include/genn/genn/currentSource.h
@@ -74,7 +74,7 @@ class GENN_EXPORT CurrentSource
     //------------------------------------------------------------------------
     // Protected methods
     //------------------------------------------------------------------------
-    void initDerivedParams(double dt);
+    void finalise(double dt);

     //------------------------------------------------------------------------
     // Protected const methods
diff --git a/include/genn/genn/currentSourceInternal.h b/include/genn/genn/currentSourceInternal.h
index 263d7c4548..550e70c3af 100644
--- a/include/genn/genn/currentSourceInternal.h
+++ b/include/genn/genn/currentSourceInternal.h
@@ -21,7 +21,7 @@ class CurrentSourceInternal : public CurrentSource
     }

     using CurrentSource::getTrgNeuronGroup;
-    using CurrentSource::initDerivedParams;
+    using CurrentSource::finalise;
     using CurrentSource::getDerivedParams;
     using CurrentSource::isSimRNGRequired;
     using CurrentSource::isInitRNGRequired;
diff --git a/include/genn/genn/customConnectivityUpdate.h b/include/genn/genn/customConnectivityUpdate.h
index 5a1e000370..0f2eed5e46 100644
--- a/include/genn/genn/customConnectivityUpdate.h
+++ b/include/genn/genn/customConnectivityUpdate.h
@@ -90,9 +90,7 @@ class GENN_EXPORT
CustomConnectivityUpdate //------------------------------------------------------------------------ // Protected methods //------------------------------------------------------------------------ - void initDerivedParams(double dt); - - void finalize(unsigned int batchSize); + void finalise(double dt, unsigned int batchSize); //------------------------------------------------------------------------ // Protected const methods diff --git a/include/genn/genn/customConnectivityUpdateInternal.h b/include/genn/genn/customConnectivityUpdateInternal.h index 4c6a7f6ba5..622ba81f10 100644 --- a/include/genn/genn/customConnectivityUpdateInternal.h +++ b/include/genn/genn/customConnectivityUpdateInternal.h @@ -34,7 +34,7 @@ class CustomConnectivityUpdateInternal : public CustomConnectivityUpdate using CustomConnectivityUpdate::getVarLocationHashDigest; using CustomConnectivityUpdate::getSynapseGroup; using CustomConnectivityUpdate::getDependentVariables; - using CustomConnectivityUpdate::finalize; + using CustomConnectivityUpdate::finalise; using CustomConnectivityUpdate::getHashDigest; using CustomConnectivityUpdate::getInitHashDigest; using CustomConnectivityUpdate::getPreDelayNeuronGroup; diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index 69a7b50fe6..a5b270ce06 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -67,7 +67,7 @@ class GENN_EXPORT CustomUpdateBase //------------------------------------------------------------------------ // Protected methods //------------------------------------------------------------------------ - void initDerivedParams(double dt); + void finalise(double dt); //------------------------------------------------------------------------ // Protected const methods @@ -240,7 +240,7 @@ class GENN_EXPORT CustomUpdate : public CustomUpdateBase //------------------------------------------------------------------------ // Protected methods //------------------------------------------------------------------------ - void finalize(unsigned int batchSize); + void finalise(double dt, unsigned int batchSize); //------------------------------------------------------------------------ // Protected const methods @@ -287,7 +287,7 @@ class GENN_EXPORT CustomUpdateWU : public CustomUpdateBase //------------------------------------------------------------------------ // Protected methods //------------------------------------------------------------------------ - void finalize(unsigned int batchSize); + void finalise(double dt, unsigned int batchSize); //------------------------------------------------------------------------ // Protected const methods diff --git a/include/genn/genn/customUpdateInternal.h b/include/genn/genn/customUpdateInternal.h index af3a001867..e3bb27c995 100644 --- a/include/genn/genn/customUpdateInternal.h +++ b/include/genn/genn/customUpdateInternal.h @@ -28,7 +28,7 @@ class CustomUpdateInternal : public CustomUpdate using CustomUpdateBase::isBatched; using CustomUpdateBase::getVarLocationHashDigest; - using CustomUpdate::finalize; + using CustomUpdate::finalise; using CustomUpdate::getHashDigest; using CustomUpdate::getInitHashDigest; using CustomUpdate::getDelayNeuronGroup; @@ -77,7 +77,6 @@ class CustomUpdateWUInternal : public CustomUpdateWU getSynapseGroup()->addCustomUpdateReference(this); } - using CustomUpdateBase::initDerivedParams; using CustomUpdateBase::getDerivedParams; using CustomUpdateBase::isInitRNGRequired; using CustomUpdateBase::isZeroCopyEnabled; @@ -85,7 +84,7 @@ class 
CustomUpdateWUInternal : public CustomUpdateWU using CustomUpdateBase::isReduction; using CustomUpdateBase::getVarLocationHashDigest; - using CustomUpdateWU::finalize; + using CustomUpdateWU::finalise; using CustomUpdateWU::getHashDigest; using CustomUpdateWU::getInitHashDigest; using CustomUpdateWU::getSynapseGroup; diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index a1aad3fa75..a5c3415bbe 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -1758,7 +1758,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runner << "iT++;" << std::endl; runner << "t = iT * " << writePreciseLiteral(model.getDT(), model.getTimePrecision()) << ";" << std::endl; - // Write step time finalize logic to runner + // Write step time finalise logic to runner runner << runnerStepTimeFinaliseStream.str(); } runner << std::endl; diff --git a/src/genn/genn/currentSource.cc b/src/genn/genn/currentSource.cc index bea478983f..e408245a94 100644 --- a/src/genn/genn/currentSource.cc +++ b/src/genn/genn/currentSource.cc @@ -45,7 +45,7 @@ CurrentSource::CurrentSource(const std::string &name, const CurrentSourceModels: getCurrentSourceModel()->validate(getParams(), getVarInitialisers(), "Current source " + getName()); } //---------------------------------------------------------------------------- -void CurrentSource::initDerivedParams(double dt) +void CurrentSource::finalise(double dt) { auto derivedParams = getCurrentSourceModel()->getDerivedParams(); diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index b233674e13..567c540756 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -153,7 +153,7 @@ CustomConnectivityUpdate::CustomConnectivityUpdate(const std::string &name, cons getPostVarReferences(), "Custom connectivity update " + getName()); } //------------------------------------------------------------------------ -void CustomConnectivityUpdate::initDerivedParams(double dt) +void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) { // Loop through derived parameters auto derivedParams = getCustomConnectivityUpdateModel()->getDerivedParams(); @@ -161,24 +161,21 @@ void CustomConnectivityUpdate::initDerivedParams(double dt) m_DerivedParams.emplace(d.name, d.func(getParams(), dt)); } - // Initialise derived parameters for synaptic variable initialisers + // Finalise derived parameters for synaptic variable initialisers for (auto &v : m_VarInitialisers) { - v.second.initDerivedParams(dt); + v.second.finalise(dt); } - // Initialise derived parameters for presynaptic variable initialisers + // Finalise derived parameters for presynaptic variable initialisers for (auto &v : m_PreVarInitialisers) { - v.second.initDerivedParams(dt); + v.second.finalise(dt); } - // Initialise derived parameters for postsynaptic variable initialisers + // Finalise derived parameters for postsynaptic variable initialisers for (auto &v : m_PostVarInitialisers) { - v.second.initDerivedParams(dt); + v.second.finalise(dt); } -} -//------------------------------------------------------------------------ -void CustomConnectivityUpdate::finalize(unsigned int batchSize) -{ + // If model is batched we need to check all variable references // are SHARED as, connectivity itself is always SHARED if (batchSize > 1) { @@ -388,7 +385,7 @@ NeuronGroup 
*CustomConnectivityUpdate::getVarRefDelayGroup(const std::unordered_
                                                           const std::string &errorContext) const
 {
     // If any variable references have delays
-    // **YUCK** copy and paste from CustomUpdate::finalize
+    // **YUCK** copy and paste from CustomUpdate::finalise
     auto delayRef = std::find_if(varRefs.cbegin(), varRefs.cend(),
                                  [](const auto &v) { return v.second.getDelayNeuronGroup() != nullptr; });
     if(delayRef != varRefs.cend()) {
diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc
index 677dee5dbc..3eda1d3f37 100644
--- a/src/genn/genn/customUpdate.cc
+++ b/src/genn/genn/customUpdate.cc
@@ -32,7 +32,7 @@ bool CustomUpdateBase::isVarInitRequired() const
                        [](const auto &init){ return !init.second.getSnippet()->getCode().empty(); });
 }
 //----------------------------------------------------------------------------
-void CustomUpdateBase::initDerivedParams(double dt)
+void CustomUpdateBase::finalise(double dt)
 {
     auto derivedParams = getCustomUpdateModel()->getDerivedParams();

@@ -120,8 +120,11 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro
     }
 }
 //----------------------------------------------------------------------------
-void CustomUpdate::finalize(unsigned int batchSize)
+void CustomUpdate::finalise(double dt, unsigned int batchSize)
 {
+    // Superclass
+    CustomUpdateBase::finalise(dt);
+
     // Check variable reference batching
     checkVarReferenceBatching(m_VarReferences, batchSize);

@@ -247,8 +250,11 @@ CustomUpdateWU::CustomUpdateWU(const std::string &name, const std::string &updat
     }
 }
 //----------------------------------------------------------------------------
-void CustomUpdateWU::finalize(unsigned int batchSize)
+void CustomUpdateWU::finalise(double dt, unsigned int batchSize)
 {
+    // Superclass
+    CustomUpdateBase::finalise(dt);
+
     // Check variable reference types
     checkVarReferenceBatching(m_VarReferences, batchSize);
 }
diff --git a/src/genn/genn/modelSpec.cc b/src/genn/genn/modelSpec.cc
index 4989c728c0..9e22af0649 100644
--- a/src/genn/genn/modelSpec.cc
+++ b/src/genn/genn/modelSpec.cc
@@ -219,38 +219,29 @@ void ModelSpec::finalise()
         n.second.finalise(m_DT, typeContext);
     }

-    // SYNAPSE groups
+    // Finalise synapse groups
     for(auto &s : m_LocalSynapseGroups) {
-        const auto *wu = s.second.getWUModel();
-
-        // Initialize derived parameters
         s.second.finalise(m_DT);
-
     }

-    // CURRENT SOURCES
+    // Finalise current sources
     for(auto &cs : m_LocalCurrentSources) {
-        // Initialize derived parameters
-        cs.second.initDerivedParams(m_DT);
+        cs.second.finalise(m_DT);
     }

-    // Custom update groups
+    // Finalise custom update groups
     for(auto &c : m_CustomUpdates) {
-        c.second.finalize(m_BatchSize);
-        c.second.initDerivedParams(m_DT);
+        c.second.finalise(m_DT, m_BatchSize);
     }

-    // Custom WUM update groups
+    // Finalise custom WUM update groups
     for(auto &c : m_CustomWUUpdates) {
-        c.second.finalize(m_BatchSize);
-        c.second.initDerivedParams(m_DT);
+        c.second.finalise(m_DT, m_BatchSize);
     }

-    // Custom connectivity update groups
+    // Finalise custom connectivity update groups
     for (auto &c : m_CustomConnectivityUpdates) {
-        c.second.finalize(m_BatchSize);
-        c.second.initDerivedParams(m_DT);
+        c.second.finalise(m_DT, m_BatchSize);
     }

     // Merge incoming postsynaptic models

From eb2aacaa8020912935d5b0ebcb9a6ba9484d9372 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Wed, 5 Jul 2023 13:50:34 +0100
Subject: [PATCH 322/725] compiling!
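The preceding three patches converge on a single two-phase lifecycle: code strings are scanned into tokens once, when model objects are constructed, and finalise() is left with only the work that depends on the simulation time step and batch size. A minimal sketch of that flow, assuming a default-constructed model with population setup elided (nothing below is code from the patches themselves):

    // Phase 1: construction - Utils::scanCode() tokenises every code string,
    // no TypeContext required.
    ModelSpecInternal model;    // hypothetical instance, for illustration only
    // ... add neuron populations, synapse groups and custom updates here ...

    // Phase 2: finalisation - derived parameters are computed from DT and
    // batched variable references are checked, via the finalise(m_DT) and
    // finalise(m_DT, m_BatchSize) overloads introduced above.
    model.finalise();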
--- .../genn/genn/code_generator/environment.h | 12 ++--- include/genn/genn/currentSource.h | 11 ++-- include/genn/genn/currentSourceInternal.h | 3 +- include/genn/genn/customConnectivityUpdate.h | 25 ++++----- .../genn/customConnectivityUpdateInternal.h | 6 +-- include/genn/genn/customUpdate.h | 16 +++--- include/genn/genn/customUpdateInternal.h | 3 +- include/genn/genn/models.h | 1 + include/genn/genn/neuronGroupInternal.h | 4 ++ include/genn/genn/synapseGroup.h | 2 +- include/genn/genn/synapseGroupInternal.h | 16 +++--- .../backends/single_threaded_cpu/backend.cc | 10 ++-- src/genn/generator/generator.cc | 2 +- src/genn/genn/code_generator/backendSIMT.cc | 16 +++--- .../customConnectivityUpdateGroupMerged.cc | 8 +-- .../code_generator/customUpdateGroupMerged.cc | 8 +-- .../genn/code_generator/generateRunner.cc | 2 +- .../genn/code_generator/initGroupMerged.cc | 44 +++++++--------- .../genn/code_generator/modelSpecMerged.cc | 2 +- .../code_generator/neuronUpdateGroupMerged.cc | 46 ++++++++-------- .../synapseUpdateGroupMerged.cc | 24 ++++----- src/genn/genn/currentSource.cc | 26 ++-------- src/genn/genn/customConnectivityUpdate.cc | 52 ++++++------------- src/genn/genn/customUpdate.cc | 19 ++++++- src/genn/genn/gennUtils.cc | 2 +- .../genn/initSparseConnectivitySnippet.cc | 8 +-- .../genn/initToeplitzConnectivitySnippet.cc | 3 +- src/genn/genn/modelSpec.cc | 4 +- src/genn/genn/models.cc | 5 ++ src/genn/genn/neuronGroup.cc | 4 +- src/genn/genn/synapseGroup.cc | 2 +- 31 files changed, 181 insertions(+), 205 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 2087a77a5d..ef2e9f9976 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -403,8 +403,8 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase using GetVarRefIndexFn = std::function; - template - using GetConnectivityFn = const Snippet::Init &(GroupInternal::*)(void) const; + template + using GetConnectivityFn = const I &(GroupInternal::*)(void) const; template using GetVarReferencesFn = const std::unordered_map &(GroupInternal::*)(void) const; @@ -515,8 +515,8 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase - void addConnectInitParams(const std::string &fieldSuffix, GetConnectivityFn getConnectivity, + template + void addConnectInitParams(const std::string &fieldSuffix, GetConnectivityFn getConnectivity, IsHeterogeneousFn isHeterogeneous) { // Loop through params @@ -539,8 +539,8 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase - void addConnectInitDerivedParams(const std::string &fieldSuffix, GetConnectivityFn getConnectivity, + template + void addConnectInitDerivedParams(const std::string &fieldSuffix, GetConnectivityFn getConnectivity, IsHeterogeneousFn isHeterogeneous) { // Loop through params diff --git a/include/genn/genn/currentSource.h b/include/genn/genn/currentSource.h index dd18cf09c0..f06e607028 100644 --- a/include/genn/genn/currentSource.h +++ b/include/genn/genn/currentSource.h @@ -83,12 +83,6 @@ class GENN_EXPORT CurrentSource const std::unordered_map &getDerivedParams() const{ return m_DerivedParams; } - //! Does this current source require an RNG to simulate - bool isSimRNGRequired() const; - - //! Does this current source group require an RNG for it's init code - bool isInitRNGRequired() const; - bool isZeroCopyEnabled() const; //! 
Updates hash with current source @@ -101,6 +95,8 @@ class GENN_EXPORT CurrentSource boost::uuids::detail::sha1::digest_type getVarLocationHashDigest() const; + const std::vector getInjectionCodeTokens() const{ return m_InjectionCodeTokens; } + private: //------------------------------------------------------------------------ // Members @@ -119,5 +115,8 @@ class GENN_EXPORT CurrentSource //! Location of extra global parameters std::vector m_ExtraGlobalParamLocation; + + //! Tokens produced by scanner from injection code + std::vector m_InjectionCodeTokens; }; } // namespace GeNN diff --git a/include/genn/genn/currentSourceInternal.h b/include/genn/genn/currentSourceInternal.h index 550e70c3af..a1a0697657 100644 --- a/include/genn/genn/currentSourceInternal.h +++ b/include/genn/genn/currentSourceInternal.h @@ -23,12 +23,11 @@ class CurrentSourceInternal : public CurrentSource using CurrentSource::getTrgNeuronGroup; using CurrentSource::finalise; using CurrentSource::getDerivedParams; - using CurrentSource::isSimRNGRequired; - using CurrentSource::isInitRNGRequired; using CurrentSource::isZeroCopyEnabled; using CurrentSource::getHashDigest; using CurrentSource::getInitHashDigest; using CurrentSource::getVarLocationHashDigest; + using CurrentSource::getInjectionCodeTokens; }; //---------------------------------------------------------------------------- diff --git a/include/genn/genn/customConnectivityUpdate.h b/include/genn/genn/customConnectivityUpdate.h index 0f2eed5e46..b19ca387cd 100644 --- a/include/genn/genn/customConnectivityUpdate.h +++ b/include/genn/genn/customConnectivityUpdate.h @@ -72,12 +72,6 @@ class GENN_EXPORT CustomConnectivityUpdate //! Is var init code required for any postsynaptic variables in this custom connectivity update group? bool isPostVarInitRequired() const; - //! Is a per-row RNG required for this custom connectivity update group - bool isRowSimRNGRequired() const; - - //! Is a host RNG required for this custom connectivity update group - bool isHostRNGRequired() const; - protected: CustomConnectivityUpdate(const std::string &name, const std::string &updateGroupName, SynapseGroupInternal *synapseGroup, const CustomConnectivityUpdateModels::Base *customConnectivityUpdateModel, @@ -97,15 +91,6 @@ class GENN_EXPORT CustomConnectivityUpdate //------------------------------------------------------------------------ const std::unordered_map &getDerivedParams() const { return m_DerivedParams; } - //! Does this current source group require an RNG for initialising its presynaptic variables - bool isPreVarInitRNGRequired() const; - - //! Does this current source group require an RNG for initialising its postsynaptic variables - bool isPostVarInitRNGRequired() const; - - //! 
Does this current source group require an RNG for initialising its synaptic variables - bool isVarInitRNGRequired() const; - bool isZeroCopyEnabled() const; SynapseGroupInternal *getSynapseGroup() const { return m_SynapseGroup; } @@ -125,6 +110,10 @@ class GENN_EXPORT CustomConnectivityUpdate boost::uuids::detail::sha1::digest_type getVarLocationHashDigest() const; + const std::vector getRowUpdateCodeTokens() const{ return m_RowUpdateCodeTokens; } + + const std::vector getHostUpdateCodeTokens() const{ return m_HostUpdateCodeTokens; } + const NeuronGroup *getPreDelayNeuronGroup() const { return m_PreDelayNeuronGroup; } const NeuronGroup *getPostDelayNeuronGroup() const { return m_PostDelayNeuronGroup; } @@ -164,5 +153,11 @@ class GENN_EXPORT CustomConnectivityUpdate const NeuronGroup *m_PreDelayNeuronGroup; const NeuronGroup *m_PostDelayNeuronGroup; + + //! Tokens produced by scanner from row update code + std::vector m_RowUpdateCodeTokens; + + //! Tokens produced by scanner from host update code + std::vector m_HostUpdateCodeTokens; }; } // namespace GeNN \ No newline at end of file diff --git a/include/genn/genn/customConnectivityUpdateInternal.h b/include/genn/genn/customConnectivityUpdateInternal.h index 622ba81f10..7ef824bb8e 100644 --- a/include/genn/genn/customConnectivityUpdateInternal.h +++ b/include/genn/genn/customConnectivityUpdateInternal.h @@ -25,13 +25,11 @@ class CustomConnectivityUpdateInternal : public CustomConnectivityUpdate getSynapseGroup()->addCustomUpdateReference(this); } - using CustomConnectivityUpdate::initDerivedParams; using CustomConnectivityUpdate::getDerivedParams; - using CustomConnectivityUpdate::isPreVarInitRNGRequired; - using CustomConnectivityUpdate::isPostVarInitRNGRequired; - using CustomConnectivityUpdate::isVarInitRNGRequired; using CustomConnectivityUpdate::isZeroCopyEnabled; using CustomConnectivityUpdate::getVarLocationHashDigest; + using CustomConnectivityUpdate::getRowUpdateCodeTokens; + using CustomConnectivityUpdate::getHostUpdateCodeTokens; using CustomConnectivityUpdate::getSynapseGroup; using CustomConnectivityUpdate::getDependentVariables; using CustomConnectivityUpdate::finalise; diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index a5b270ce06..f170355b7e 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -53,16 +53,7 @@ class GENN_EXPORT CustomUpdateBase protected: CustomUpdateBase(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, - VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) - : m_Name(name), m_UpdateGroupName(updateGroupName), m_CustomUpdateModel(customUpdateModel), m_Params(params), - m_VarInitialisers(varInitialisers), m_VarLocation(varInitialisers.size(), defaultVarLocation), - m_ExtraGlobalParamLocation(customUpdateModel->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation), - m_Batched(false) - { - // Validate names - Utils::validatePopName(name, "Custom update"); - Utils::validatePopName(updateGroupName, "Custom update group name"); - } + VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation); //------------------------------------------------------------------------ // Protected methods @@ -92,6 +83,8 @@ class GENN_EXPORT CustomUpdateBase boost::uuids::detail::sha1::digest_type getVarLocationHashDigest() const; + const std::vector getUpdateCodeTokens() const{ 
return m_UpdateCodeTokens; } + template bool isReduction(const std::unordered_map &varRefs, VarAccessDuplication duplication) const { @@ -163,6 +156,9 @@ class GENN_EXPORT CustomUpdateBase //! Location of extra global parameters std::vector m_ExtraGlobalParamLocation; + //! Tokens produced by scanner from update code + std::vector m_UpdateCodeTokens; + //! Is this custom update batched i.e. run in parallel across model batches bool m_Batched; }; diff --git a/include/genn/genn/customUpdateInternal.h b/include/genn/genn/customUpdateInternal.h index e3bb27c995..c4c78b5ca5 100644 --- a/include/genn/genn/customUpdateInternal.h +++ b/include/genn/genn/customUpdateInternal.h @@ -21,12 +21,12 @@ class CustomUpdateInternal : public CustomUpdate { } - using CustomUpdateBase::initDerivedParams; using CustomUpdateBase::getDerivedParams; using CustomUpdateBase::isInitRNGRequired; using CustomUpdateBase::isZeroCopyEnabled; using CustomUpdateBase::isBatched; using CustomUpdateBase::getVarLocationHashDigest; + using CustomUpdateBase::getUpdateCodeTokens; using CustomUpdate::finalise; using CustomUpdate::getHashDigest; @@ -83,6 +83,7 @@ class CustomUpdateWUInternal : public CustomUpdateWU using CustomUpdateBase::isBatched; using CustomUpdateBase::isReduction; using CustomUpdateBase::getVarLocationHashDigest; + using CustomUpdateBase::getUpdateCodeTokens; using CustomUpdateWU::finalise; using CustomUpdateWU::getHashDigest; diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index ad37cd34ae..4e1a35fab8 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -128,6 +128,7 @@ class VarInit : public Snippet::Init { public: VarInit(const InitVarSnippet::Base *snippet, const std::unordered_map ¶ms); + VarInit(double constant); //------------------------------------------------------------------------ // Public API diff --git a/include/genn/genn/neuronGroupInternal.h b/include/genn/genn/neuronGroupInternal.h index dfbf042b90..02625354c7 100644 --- a/include/genn/genn/neuronGroupInternal.h +++ b/include/genn/genn/neuronGroupInternal.h @@ -40,8 +40,12 @@ class NeuronGroupInternal : public NeuronGroup using NeuronGroup::getFusedOutSynWithPreCode; using NeuronGroup::getFusedInSynWithPostVars; using NeuronGroup::getFusedOutSynWithPreVars; + using NeuronGroup::getSimCodeTokens; + using NeuronGroup::getThresholdConditionCodeTokens; + using NeuronGroup::getResetCodeTokens; using NeuronGroup::isSimRNGRequired; using NeuronGroup::isInitRNGRequired; + using NeuronGroup::isRecordingEnabled; using NeuronGroup::isVarQueueRequired; using NeuronGroup::getHashDigest; using NeuronGroup::getInitHashDigest; diff --git a/include/genn/genn/synapseGroup.h b/include/genn/genn/synapseGroup.h index c1439ad848..c54565b63f 100644 --- a/include/genn/genn/synapseGroup.h +++ b/include/genn/genn/synapseGroup.h @@ -224,7 +224,7 @@ class GENN_EXPORT SynapseGroup void setFusedWUPostVarSuffix(const std::string &suffix){ m_FusedWUPostVarSuffix = suffix; } void setFusedPreOutputSuffix(const std::string &suffix){ m_FusedPreOutputSuffix = suffix; } - void finalise(double dt, const Type::TypeContext &context); + void finalise(double dt); //! 
Add reference to custom connectivity update, referencing this synapse group void addCustomUpdateReference(CustomConnectivityUpdateInternal *cu){ m_CustomConnectivityUpdateReferences.push_back(cu); } diff --git a/include/genn/genn/synapseGroupInternal.h b/include/genn/genn/synapseGroupInternal.h index beaf77f382..3c72c2a81f 100644 --- a/include/genn/genn/synapseGroupInternal.h +++ b/include/genn/genn/synapseGroupInternal.h @@ -34,14 +34,6 @@ class SynapseGroupInternal : public SynapseGroup using SynapseGroup::getTrgNeuronGroup; using SynapseGroup::getWUDerivedParams; using SynapseGroup::getPSDerivedParams; - using SynapseGroup::setEventThresholdReTestRequired; - using SynapseGroup::setFusedPSVarSuffix; - using SynapseGroup::setFusedPreOutputSuffix; - using SynapseGroup::setFusedWUPreVarSuffix; - using SynapseGroup::setFusedWUPostVarSuffix; - using SynapseGroup::finalise; - using SynapseGroup::addCustomUpdateReference; - using SynapseGroup::isEventThresholdReTestRequired; using SynapseGroup::getWUSimCodeTokens; using SynapseGroup::getWUEventCodeTokens; using SynapseGroup::getWUPostLearnCodeTokens; @@ -53,6 +45,14 @@ class SynapseGroupInternal : public SynapseGroup using SynapseGroup::getWUPostDynamicsCodeTokens; using SynapseGroup::getPSApplyInputCodeTokens; using SynapseGroup::getPSDecayCodeTokens; + using SynapseGroup::setEventThresholdReTestRequired; + using SynapseGroup::setFusedPSVarSuffix; + using SynapseGroup::setFusedPreOutputSuffix; + using SynapseGroup::setFusedWUPreVarSuffix; + using SynapseGroup::setFusedWUPostVarSuffix; + using SynapseGroup::finalise; + using SynapseGroup::addCustomUpdateReference; + using SynapseGroup::isEventThresholdReTestRequired; using SynapseGroup::getFusedPSVarSuffix; using SynapseGroup::getFusedPreOutputSuffix; using SynapseGroup::getFusedWUPreVarSuffix; diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 4cf58b7fa1..c95c99b610 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1686,11 +1686,11 @@ bool Backend::isGlobalHostRNGRequired(const ModelSpecMerged &modelMerged) const if(std::any_of(model.getCustomConnectivityUpdates().cbegin(), model.getCustomConnectivityUpdates().cend(), [](const ModelSpec::CustomConnectivityUpdateValueType &c) { - return (c.second.isVarInitRNGRequired() - || c.second.isPreVarInitRNGRequired() - || c.second.isPostVarInitRNGRequired() - || c.second.isRowSimRNGRequired() - || c.second.isHostRNGRequired()); + return (Utils::isRNGRequired(c.second.getVarInitialisers()) + || Utils::isRNGRequired(c.second.getPreVarInitialisers()) + || Utils::isRNGRequired(c.second.getPostVarInitialisers()) + || Utils::isRNGRequired(c.second.getRowUpdateCodeTokens()) + || Utils::isRNGRequired(c.second.getHostUpdateCodeTokens())); })) { return true; diff --git a/src/genn/generator/generator.cc b/src/genn/generator/generator.cc index 1eebeb20d2..daaa20794a 100644 --- a/src/genn/generator/generator.cc +++ b/src/genn/generator/generator.cc @@ -56,7 +56,7 @@ int main(int argc, //!< number of arguments; expected to be 3 &consoleAppender, &consoleAppender, &consoleAppender); // Finalize model - model.finalize(); + model.finalise(); // Determine code generation path const filesystem::path outputPath = targetPath / (model.getName() + "_CODE"); diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 012bcb3d11..4d441a6b5e 100644 --- 
a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -105,7 +105,7 @@ bool BackendSIMT::isGlobalHostRNGRequired(const ModelSpecMerged &modelMerged) co return (std::any_of(model.getSynapseGroups().cbegin(), model.getSynapseGroups().cend(), [](const ModelSpec::SynapseGroupValueType &s){ return s.second.getConnectivityInitialiser().isHostRNGRequired(); }) || std::any_of(model.getCustomConnectivityUpdates().cbegin(), model.getCustomConnectivityUpdates().cend(), - [](const ModelSpec::CustomConnectivityUpdateValueType &c){ return c.second.isHostRNGRequired(); })); + [](const ModelSpec::CustomConnectivityUpdateValueType &c){ return Utils::isRNGRequired(c.second.getHostUpdateCodeTokens()); })); } //-------------------------------------------------------------------------- bool BackendSIMT::isGlobalDeviceRNGRequired(const ModelSpecMerged &modelMerged) const @@ -1276,14 +1276,14 @@ void BackendSIMT::genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env // Copy global RNG stream to local and use pointer to this for rng const std::string rng = printSubs("$(_rng)[$(id)]", groupEnv); - if(cg.getArchetype().isRowSimRNGRequired()) { + if(Utils::isRNGRequired(cg.getArchetype().getRowUpdateCodeTokens())) { groupEnv.add(Type::Void, "rng", genPopulationRNGPreamble(groupEnv.getStream(), rng)); } cg.generateUpdate(*this, groupEnv, modelMerged); // Copy local stream back to local - if(cg.getArchetype().isRowSimRNGRequired()) { + if(Utils::isRNGRequired(cg.getArchetype().getRowUpdateCodeTokens())) { genPopulationRNGPostamble(groupEnv.getStream(), rng); } } @@ -1403,7 +1403,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // If population RNGs are initialised on device and this custom connectivity update // required one, initialise single RNG using GLOBAL thread id for sequence - if(isPopulationRNGInitialisedOnDevice() && cg.getArchetype().isRowSimRNGRequired()) { + if(isPopulationRNGInitialisedOnDevice() && Utils::isRNGRequired(cg.getArchetype().getRowUpdateCodeTokens())) { genPopulationRNGInit(groupEnv.getStream(), printSubs("$(_rng)[$(id)]", groupEnv), "deviceRNGSeed", "id"); } @@ -1411,7 +1411,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // If this custom update requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id - if(cg.getArchetype().isPreVarInitRNGRequired()) { + if(Utils::isRNGRequired(cg.getArchetype().getPreVarInitialisers())) { groupEnv.add(Type::Void, "rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } @@ -1435,7 +1435,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // If population RNGs are initialised on device and this custom connectivity update // required one, initialise single RNG using GLOBAL thread id for sequence - if(isPopulationRNGInitialisedOnDevice() && cg.getArchetype().isRowSimRNGRequired()) { + if(isPopulationRNGInitialisedOnDevice() && Utils::isRNGRequired(cg.getArchetype().getRowUpdateCodeTokens())) { genPopulationRNGInit(groupEnv.getStream(), printSubs("$(_rng)[$(id)]", groupEnv), "deviceRNGSeed", "id"); } @@ -1443,7 +1443,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // If this custom update requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id - if(cg.getArchetype().isPostVarInitRNGRequired()) { + 
if(Utils::isRNGRequired(cg.getArchetype().getPostVarInitialisers())) { groupEnv.add(Type::Void, "rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } @@ -1678,7 +1678,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS // If this custom update requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id - if(cg.getArchetype().isVarInitRNGRequired()) { + if(Utils::isRNGRequired(cg.getArchetype().getVarInitialisers())) { groupEnv.add(Type::Void, "rng", genGlobalRNGSkipAhead(groupEnv.getStream(), std::to_string(numInitializeThreads) + " + id")); } diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 57c44d6a9a..438b364db5 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -363,8 +363,8 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back } // Pretty print code back to environment - Transpiler::ErrorHandler errorHandler("Custom connectivity update" + std::to_string(getIndex())); - prettyPrintStatements(cm->getRowUpdateCode(), getTypeContext(), updateEnv, errorHandler, + Transpiler::ErrorHandler errorHandler("Custom connectivity update '" + getArchetype().getName() + "' row update code"); + prettyPrintStatements(getArchetype().getRowUpdateCodeTokens(), getTypeContext(), updateEnv, errorHandler, // Within for_each_synapse loops, define the following types [this](auto &env, auto &errorHandler) { @@ -520,8 +520,8 @@ void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase & addVars(groupEnv, "$(num_post)", backend); // Pretty print code back to environment - Transpiler::ErrorHandler errorHandler("Custom connectivity host update" + std::to_string(getIndex())); - prettyPrintStatements(cm->getHostUpdateCode(), getTypeContext(), groupEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Custom connectivity '" + getArchetype().getName() + "' host update code"); + prettyPrintStatements(getArchetype().getHostUpdateCodeTokens(), getTypeContext(), groupEnv, errorHandler); } } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 3cba9602aa..1daa1c0cbd 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -83,8 +83,8 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E varEnv["id"]); }); - Transpiler::ErrorHandler errorHandler("Custom update code " + std::to_string(getIndex())); - prettyPrintExpression(cm->getUpdateCode(), getTypeContext(), varRefEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Custom update '" + getArchetype().getName() + "' update code"); + prettyPrintExpression(getArchetype().getUpdateCodeTokens(), getTypeContext(), varRefEnv, errorHandler); } //---------------------------------------------------------------------------- std::string CustomUpdateGroupMerged::getVarIndex(VarAccessDuplication varDuplication, const std::string &index) const @@ -266,8 +266,8 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdateBase(const BackendBase & varEnv["id_syn"]); }); - Transpiler::ErrorHandler errorHandler("Custom WU update code " + 
std::to_string(getIndex())); - prettyPrintExpression(cm->getUpdateCode(), getTypeContext(), varRefEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Custom update '" + getArchetype().getName() + "' update code"); + prettyPrintExpression(getArchetype().getUpdateCodeTokens(), getTypeContext(), varRefEnv, errorHandler); } // ---------------------------------------------------------------------------- // GeNN::CodeGenerator::CustomUpdateWUGroupMerged diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index a5c3415bbe..5ac3ebabb2 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -1223,7 +1223,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerVarDecl, runnerExtraGlobalParamFunc, c.second); // If custom connectivity update group needs per-row RNGs - if(c.second.isRowSimRNGRequired()) { + if(Utils::isRNGRequired(c.second.getRowUpdateCodeTokens())) { backend.genPopulationRNG(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, "rowRNG" + c.first, c.second.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), mem); } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 762175629f..9ae8168f11 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -70,15 +70,14 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e // If there is any initialisation code const auto resolvedType = var.type.resolve(group.getTypeContext()); const auto &varInit = adaptor.getInitialisers().at(var.name); - const auto *snippet = varInit.getSnippet(); - if (!snippet->getCode().empty()) { + if (!Utils::areTokensEmpty(varInit.getCodeTokens())) { CodeStream::Scope b(env.getStream()); // Substitute in parameters and derived parameters for initialising variables EnvironmentGroupMergedField varEnv(env, group, fieldGroup); varEnv.addVarInitParams(&G::isVarInitParamHeterogeneous, var.name, fieldSuffix); varEnv.addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, var.name, fieldSuffix); - varEnv.addExtraGlobalParams(snippet->getExtraGlobalParams(), backend.getDeviceVarPrefix(), var.name, fieldSuffix); + varEnv.addExtraGlobalParams(varInit.getSnippet()->getExtraGlobalParams(), backend.getDeviceVarPrefix(), var.name, fieldSuffix); // Add field for variable itself varEnv.addField(resolvedType.createPointer(), "_value", var.name + fieldSuffix, @@ -91,7 +90,7 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e if (getVarAccessDuplication(var.access) == VarAccessDuplication::SHARED_NEURON) { backend.genPopVariableInit( varEnv, - [&adaptor, &fieldGroup, &fieldSuffix, &group, &resolvedType, &var, batchSize, numDelaySlots, snippet] + [&adaptor, &fieldGroup, &fieldSuffix, &group, &resolvedType, &var, &varInit, batchSize, numDelaySlots] (EnvironmentExternalBase &env) { // Generate initial value into temporary variable @@ -100,8 +99,8 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e varInitEnv.add(resolvedType, "value", "initVal"); // Pretty print variable initialisation code - Transpiler::ErrorHandler errorHandler("Variable '" + var.name + "' init code" + std::to_string(group.getIndex())); - prettyPrintStatements(snippet->getCode(), group.getTypeContext(), varInitEnv, errorHandler); + Transpiler::ErrorHandler 
errorHandler("Group '" + group.getArchetype().getName() + "' variable '" + var.name + "' init code"); + prettyPrintStatements(varInit.getCodeTokens(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all delay slots and batches genScalarFill(varInitEnv, "_value", "$(value)", getVarAccessDuplication(var.access), @@ -112,7 +111,7 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e else { backend.genVariableInit( varEnv, count, "id", - [&adaptor, &fieldGroup, &fieldSuffix, &group, &var, &resolvedType, batchSize, count, numDelaySlots, snippet] + [&adaptor, &fieldGroup, &fieldSuffix, &group, &var, &resolvedType, &varInit, batchSize, count, numDelaySlots] (EnvironmentExternalBase &env) { // Generate initial value into temporary variable @@ -121,8 +120,8 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e varInitEnv.add(resolvedType, "value", "initVal"); // Pretty print variable initialisation code - Transpiler::ErrorHandler errorHandler("Variable '" + var.name + "' init code" + std::to_string(group.getIndex())); - prettyPrintStatements(snippet->getCode(), group.getTypeContext(), varInitEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Group '" + group.getArchetype().getName() + "' variable '" + var.name + "' init code"); + prettyPrintStatements(varInit.getCodeTokens(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all delay slots and batches genVariableFill(varInitEnv, "_value", "$(value)", "id", "$(" + count + ")", @@ -150,18 +149,17 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, { A adaptor(group.getArchetype()); for (const auto &var : adaptor.getDefs()) { - // If this variable has any initialisation code and doesn't require a kernel + // If this variable has any initialisation code and doesn't require a kernel (in this case it will be initialised elsewhere) const auto resolvedType = var.type.resolve(group.getTypeContext()); const auto &varInit = adaptor.getInitialisers().at(var.name); - const auto *snippet = varInit.getSnippet(); - if(!snippet->getCode().empty() && !snippet->requiresKernel()) { + if(!Utils::areTokensEmpty(varInit.getCodeTokens()) && !varInit.isKernelRequired()) { CodeStream::Scope b(env.getStream()); // Substitute in parameters and derived parameters for initialising variables EnvironmentGroupMergedField varEnv(env, group); varEnv.addVarInitParams(&G::isVarInitParamHeterogeneous, var.name); varEnv.addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, var.name); - varEnv.addExtraGlobalParams(snippet->getExtraGlobalParams(), backend.getDeviceVarPrefix(), var.name); + varEnv.addExtraGlobalParams(varInit.getSnippet()->getExtraGlobalParams(), backend.getDeviceVarPrefix(), var.name); // Add field for variable itself varEnv.addField(resolvedType.createPointer(), "_value", var.name, @@ -172,7 +170,7 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, // Generate target-specific code to initialise variable genSynapseVariableRowInitFn(varEnv, - [&group, &resolvedType, &stride, &var, batchSize, snippet] + [&group, &resolvedType, &stride, &var, &varInit, batchSize] (EnvironmentExternalBase &env) { // Generate initial value into temporary variable @@ -182,7 +180,7 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, // Pretty print variable initialisation code Transpiler::ErrorHandler errorHandler("Variable '" + var.name + "' init code" + 
std::to_string(group.getIndex())); - prettyPrintStatements(snippet->getCode(), group.getTypeContext(), varInitEnv, errorHandler); + prettyPrintStatements(varInit.getCodeTokens(), group.getTypeContext(), varInitEnv, errorHandler); // Fill value across all batches genVariableFill(varInitEnv, "_value", "$(value)", "id_syn", stride, @@ -733,20 +731,16 @@ void SynapseConnectivityInitGroupMerged::genInitConnectivity(const BackendBase & const std::string context = rowNotColumns ? "row" : "column"; for(const auto &a : stateVars) { const auto resolvedType = a.type.resolve(getTypeContext()); - groupEnv.getStream() << resolvedType.getName() << " _" << a.name << " = "; - - Transpiler::ErrorHandler errorHandler("Connectivity init " + context + " build state var" + std::to_string(getIndex())); - prettyPrintExpression(a.value, getTypeContext(), groupEnv, errorHandler); - - groupEnv.getStream() << ";" << std::endl; + groupEnv.getStream() << resolvedType.getName() << " _" << a.name << " = " << a.value << ";" << std::endl; groupEnv.add(resolvedType, a.name, "_" + a.name); } groupEnv.getStream() << "while(true)"; { CodeStream::Scope b(groupEnv.getStream()); - Transpiler::ErrorHandler errorHandler("Connectivity init " + context + " build" + std::to_string(getIndex())); - prettyPrintStatements(rowNotColumns ? snippet->getRowBuildCode() : snippet->getColBuildCode(), getTypeContext(), groupEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Synapse group sparse connectivity '" + getArchetype().getName() + "' " + context + " build code"); + prettyPrintStatements(rowNotColumns ? connectInit.getRowBuildCodeTokens() : connectInit.getColBuildCodeTokens(), + getTypeContext(), groupEnv, errorHandler); } } @@ -850,8 +844,8 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac groupEnv.add(Type::ResolvedType::createFunction(Type::Void, {Type::Uint32}), "push" + egp.name, pushStream.str()); } } - Transpiler::ErrorHandler errorHandler("Connectivity host init" + std::to_string(getIndex())); - prettyPrintStatements(connectInit.getSnippet()->getHostInitCode(), getTypeContext(), groupEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Synapse group '" + getArchetype().getName() + "' sparse connectivity host init code"); + prettyPrintStatements(connectInit.getHostInitCodeTokens(), getTypeContext(), groupEnv, errorHandler); } } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index 3a22421ad3..c7b349d4f6 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -299,7 +299,7 @@ void ModelSpecMerged::genMergedCustomConnectivityUpdatePreInitGroups(const Backe createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePreInitGroups, [&backend](const CustomConnectivityUpdateInternal &cg) { - return (cg.isPreVarInitRequired() || (backend.isPopulationRNGInitialisedOnDevice() && cg.isRowSimRNGRequired())); + return (cg.isPreVarInitRequired() || (backend.isPopulationRNGInitialisedOnDevice() && Utils::isRNGRequired(cg.getRowUpdateCodeTokens()))); }, &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); } diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index f9da2c2480..a964da82d9 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ 
b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -47,8 +47,8 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend }); // Pretty print code back to environment - Transpiler::ErrorHandler errorHandler("Current source injection" + std::to_string(getIndex())); - prettyPrintStatements(cm->getInjectionCode(), getTypeContext(), varEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Current source '" + getArchetype().getName() + "' injection code"); + prettyPrintStatements(getArchetype().getInjectionCodeTokens(), getTypeContext(), varEnv, errorHandler); } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::CurrentSource::updateHash(boost::uuids::detail::sha1 &hash) const @@ -126,11 +126,11 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env }); // Pretty print code back to environment - Transpiler::ErrorHandler applyInputErrorHandler("Postsynaptic model apply input" + std::to_string(getIndex())); - prettyPrintStatements(psm->getApplyInputCode(), getTypeContext(), varEnv, applyInputErrorHandler); + Transpiler::ErrorHandler applyInputErrorHandler("Synapse group '" + getArchetype().getName() + "' postsynaptic model apply input code"); + prettyPrintStatements(getArchetype().getPSApplyInputCodeTokens(), getTypeContext(), varEnv, applyInputErrorHandler); - Transpiler::ErrorHandler decayErrorHandler("Postsynaptic model decay" + std::to_string(getIndex())); - prettyPrintStatements(psm->getDecayCode(), getTypeContext(), varEnv, decayErrorHandler); + Transpiler::ErrorHandler decayErrorHandler("Synapse group '" + getArchetype().getName() + "' postsynaptic model decay code"); + prettyPrintStatements(getArchetype().getPSDecayCodeTokens(), getTypeContext(), varEnv, decayErrorHandler); // Write back linSyn varEnv.printLine("$(_out_post)[" + ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id") + "] = linSyn;"); @@ -186,8 +186,8 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back const unsigned int batchSize = modelMerged.getModel().getBatchSize(); // If there are any statements to execute here - const std::string code = dynamicsNotSpike ? wum->getPostDynamicsCode() : wum->getPostSpikeCode(); - if(!code.empty()) { + const auto &tokens = dynamicsNotSpike ? getArchetype().getWUPostDynamicsCodeTokens() : getArchetype().getWUPostSpikeCodeTokens(); + if(!Utils::areTokensEmpty(tokens)) { // Create new environment to add out syn fields to neuron update group EnvironmentGroupMergedField synEnv(env, *this, ng); @@ -223,8 +223,9 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back return ng.getReadVarIndex(delay, batchSize, varDuplication, subs["id"]); });*/ - Transpiler::ErrorHandler errorHandler("Postsynaptic weight update model " + std::to_string(getIndex())); - prettyPrintStatements(code, getTypeContext(), varEnv, errorHandler); + const std::string context = dynamicsNotSpike ? 
"dynamics" : "spike"; + Transpiler::ErrorHandler errorHandler("Synapse group '" + getArchetype().getName() + "' weight update model postsynaptic " + context + " code"); + prettyPrintStatements(tokens, getTypeContext(), varEnv, errorHandler); } } //---------------------------------------------------------------------------- @@ -272,8 +273,8 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back const unsigned int batchSize = modelMerged.getModel().getBatchSize(); // If there are any statements to execute here - const std::string code = dynamicsNotSpike ? wum->getPreDynamicsCode() : wum->getPreSpikeCode(); - if(!code.empty()) { + const auto &tokens = dynamicsNotSpike ? getArchetype().getWUPreDynamicsCodeTokens() : getArchetype().getWUPreSpikeCodeTokens(); + if(!Utils::areTokensEmpty(tokens)) { // Create new environment to add out syn fields to neuron update group EnvironmentGroupMergedField synEnv(env, *this, ng); @@ -309,8 +310,9 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back return ng.getReadVarIndex(delay, batchSize, varDuplication, subs["id"]); });*/ - Transpiler::ErrorHandler errorHandler("Presynaptic weight update model " + std::to_string(getIndex())); - prettyPrintStatements(code, getTypeContext(), varEnv, errorHandler); + const std::string context = dynamicsNotSpike ? "dynamics" : "spike"; + Transpiler::ErrorHandler errorHandler("Synapse group '" + getArchetype().getName() + "' weight update model presynaptic " + context + " code"); + prettyPrintStatements(tokens, getTypeContext(), varEnv, errorHandler); } } //---------------------------------------------------------------------------- @@ -558,8 +560,8 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E if (nm->isAutoRefractoryRequired()) { neuronVarEnv.getStream() << "const bool oldSpike = ("; - Transpiler::ErrorHandler errorHandler("Neuron threshold condition " + std::to_string(getIndex())); - prettyPrintExpression(nm->getThresholdConditionCode(), getTypeContext(), neuronVarEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Neuron group '" + getArchetype().getName() + "' threshold condition code"); + prettyPrintExpression(getArchetype().getThresholdConditionCodeTokens(), getTypeContext(), neuronVarEnv, errorHandler); neuronVarEnv.getStream() << ");" << std::endl; } @@ -573,8 +575,8 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E neuronVarEnv.getStream() << "// calculate membrane potential" << std::endl; - Transpiler::ErrorHandler errorHandler("Neuron sim code " + std::to_string(getIndex())); - prettyPrintStatements(nm->getSimCode(), getTypeContext(), neuronVarEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Neuron group '" + getArchetype().getName() + "' sim code"); + prettyPrintStatements(getArchetype().getSimCodeTokens(), getTypeContext(), neuronVarEnv, errorHandler); // Generate var update for outgoing synaptic populations with presynaptic update code for (auto &sg : m_MergedOutSynWUMPreCodeGroups) { @@ -665,8 +667,8 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E neuronVarEnv.getStream() << "// test for and register a true spike" << std::endl; neuronVarEnv.getStream() << "if (("; - Transpiler::ErrorHandler errorHandler("Neuron threshold condition " + std::to_string(getIndex())); - prettyPrintExpression(nm->getThresholdConditionCode(), getTypeContext(), neuronVarEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Neuron group '" + 
getArchetype().getName() + "' threshold condition code"); + prettyPrintExpression(getArchetype().getThresholdConditionCodeTokens(), getTypeContext(), neuronVarEnv, errorHandler); neuronVarEnv.getStream() << ")"; if (nm->isAutoRefractoryRequired()) { @@ -681,8 +683,8 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E if (!nm->getResetCode().empty()) { neuronVarEnv.getStream() << "// spike reset code" << std::endl; - Transpiler::ErrorHandler errorHandler("Neuron reset code " + std::to_string(getIndex())); - prettyPrintStatements(nm->getResetCode(), getTypeContext(), neuronVarEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Neuron group '" + getArchetype().getName() + "' reset code"); + prettyPrintStatements(getArchetype().getResetCodeTokens(), getTypeContext(), neuronVarEnv, errorHandler); } } diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 7df932b4c5..3a39811a74 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -15,7 +15,7 @@ using namespace GeNN::CodeGenerator; namespace { template<typename G> -void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBase &env, std::string code, const std::string &errorContext, +void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBase &env, const std::vector<Transpiler::Token> &tokens, const std::string &errorContext, G &sg, const ModelSpecMerged &modelMerged, bool backendSupportsNamespace) { const ModelSpecInternal &model = modelMerged.getModel(); @@ -160,8 +160,8 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa }*/ // Pretty print code back to environment - Transpiler::ErrorHandler errorHandler(errorContext + std::to_string(sg.getIndex())); - prettyPrintStatements(code, sg.getTypeContext(), synEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Synapse group '" + sg.getArchetype().getName() + "' weight update model " + errorContext); + prettyPrintStatements(tokens, sg.getTypeContext(), synEnv, errorHandler); } } // Anonymous namespace @@ -386,19 +386,19 @@ void PresynapticUpdateGroupMerged::generateSpikeEventThreshold(const BackendBase });*/ // Pretty print code back to environment - Transpiler::ErrorHandler errorHandler("eventThresholdConditionCode" + std::to_string(getIndex())); - prettyPrintStatements(wum->getEventThresholdConditionCode(), getTypeContext(), synEnv, errorHandler); + Transpiler::ErrorHandler errorHandler("Synapse group '" + getArchetype().getName() + "' weight update model event threshold code"); + prettyPrintStatements(getArchetype().getWUEventThresholdCodeTokens(), getTypeContext(), synEnv, errorHandler); } //---------------------------------------------------------------------------- void PresynapticUpdateGroupMerged::generateSpikeEventUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { - applySynapseSubstitutions(backend, env, getArchetype().getWUModel()->getEventCode(), "eventCode", + applySynapseSubstitutions(backend, env, getArchetype().getWUEventCodeTokens(), "event code", *this, modelMerged, backend.supportsNamespace()); } //---------------------------------------------------------------------------- void PresynapticUpdateGroupMerged::generateSpikeUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { - applySynapseSubstitutions(backend, env, 
getArchetype().getWUModel()->getSimCode(), "simCode", + applySynapseSubstitutions(backend, env, getArchetype().getWUSimCodeTokens(), "sim code", *this, modelMerged, backend.supportsNamespace()); } //---------------------------------------------------------------------------- @@ -435,8 +435,8 @@ void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendB void PresynapticUpdateGroupMerged::generateToeplitzConnectivity(const BackendBase&, EnvironmentExternalBase &env) { // Pretty print code back to environment - Transpiler::ErrorHandler errorHandler("toeplitzSparseConnectivity" + std::to_string(getIndex())); - prettyPrintStatements(getArchetype().getToeplitzConnectivityInitialiser().getSnippet()->getDiagonalBuildCode(), + Transpiler::ErrorHandler errorHandler("Synapse group '" + getArchetype().getName() + "' Toeplitz connectivity diagonal build code"); + prettyPrintStatements(getArchetype().getToeplitzConnectivityInitialiser().getDiagonalBuildCodeTokens(), getTypeContext(), env, errorHandler); } @@ -447,12 +447,11 @@ const std::string PostsynapticUpdateGroupMerged::name = "PostsynapticUpdate"; //---------------------------------------------------------------------------- void PostsynapticUpdateGroupMerged::generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { - const auto *wum = getArchetype().getWUModel(); /*if (!wum->getLearnPostSupportCode().empty() && backend.supportsNamespace()) { os << "using namespace " << modelMerged.getPostsynapticUpdateSupportCodeNamespace(wum->getLearnPostSupportCode()) << ";" << std::endl; }*/ - applySynapseSubstitutions(backend, env, wum->getLearnPostCode(), "synapselearnPostCodeDynamics", + applySynapseSubstitutions(backend, env, getArchetype().getWUPostLearnCodeTokens(), "learn post code", *this, modelMerged, backend.supportsNamespace()); } @@ -463,12 +462,11 @@ const std::string SynapseDynamicsGroupMerged::name = "SynapseDynamics"; //---------------------------------------------------------------------------- void SynapseDynamicsGroupMerged::generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) { - const auto *wum = getArchetype().getWUModel(); /*if (!wum->getSynapseDynamicsSuppportCode().empty() && backend.supportsNamespace()) { os << "using namespace " << modelMerged.getSynapseDynamicsSupportCodeNamespace(wum->getSynapseDynamicsSuppportCode()) << ";" << std::endl; }*/ - applySynapseSubstitutions(backend, env, wum->getSynapseDynamicsCode(), "synapseDynamics", + applySynapseSubstitutions(backend, env, getArchetype().getWUSynapseDynamicsCodeTokens(), "synapse dynamics", *this, modelMerged, backend.supportsNamespace()); } diff --git a/src/genn/genn/currentSource.cc b/src/genn/genn/currentSource.cc index e408245a94..e6578bbf82 100644 --- a/src/genn/genn/currentSource.cc +++ b/src/genn/genn/currentSource.cc @@ -43,6 +43,10 @@ CurrentSource::CurrentSource(const std::string &name, const CurrentSourceModels: // Validate names Utils::validatePopName(name, "Current source"); getCurrentSourceModel()->validate(getParams(), getVarInitialisers(), "Current source " + getName()); + + // Scan current source model code string + m_InjectionCodeTokens = Utils::scanCode(getCurrentSourceModel()->getInjectionCode(), + "Current source '" + getName() + "' injection code"); } //---------------------------------------------------------------------------- void CurrentSource::finalise(double dt) @@ -56,28 +60,8 @@ void 
CurrentSource::finalise(double dt) // Initialise derived parameters for variable initialisers for(auto &v : m_VarInitialisers) { - v.second.initDerivedParams(dt); - } -} -//---------------------------------------------------------------------------- -bool CurrentSource::isSimRNGRequired() const -{ - // Returns true if any parts of the current source code require an RNG - if(Utils::isRNGRequired(getCurrentSourceModel()->getInjectionCode())) { - return true; + v.second.finalise(dt); } - - return false; -} -//---------------------------------------------------------------------------- -bool CurrentSource::isInitRNGRequired() const -{ - // If initialising the neuron variables require an RNG, return true - if(Utils::isRNGRequired(getVarInitialisers())) { - return true; - } - - return false; } //---------------------------------------------------------------------------- bool CurrentSource::isZeroCopyEnabled() const diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index 567c540756..c2258ce045 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -75,29 +75,19 @@ VarLocation CustomConnectivityUpdate::getPostVarLocation(const std::string &varN bool CustomConnectivityUpdate::isVarInitRequired() const { return std::any_of(m_VarInitialisers.cbegin(), m_VarInitialisers.cend(), - [](const auto &init){ return !init.second.getSnippet()->getCode().empty(); }); + [](const auto &v){ return !Utils::areTokensEmpty(v.second.getCodeTokens()); }); } //------------------------------------------------------------------------ bool CustomConnectivityUpdate::isPreVarInitRequired() const { return std::any_of(m_PreVarInitialisers.cbegin(), m_PreVarInitialisers.cend(), - [](const auto &init){ return !init.second.getSnippet()->getCode().empty(); }); + [](const auto &v){ return !Utils::areTokensEmpty(v.second.getCodeTokens()); }); } //------------------------------------------------------------------------ bool CustomConnectivityUpdate::isPostVarInitRequired() const { return std::any_of(m_PostVarInitialisers.cbegin(), m_PostVarInitialisers.cend(), - [](const auto &init){ return !init.second.getSnippet()->getCode().empty(); }); -} -//------------------------------------------------------------------------ -bool CustomConnectivityUpdate::isRowSimRNGRequired() const -{ - return Utils::isRNGRequired(getCustomConnectivityUpdateModel()->getRowUpdateCode()); -} -//------------------------------------------------------------------------ -bool CustomConnectivityUpdate::isHostRNGRequired() const -{ - return Utils::isRNGRequired(getCustomConnectivityUpdateModel()->getHostUpdateCode()); + [](const auto &v){ return !Utils::areTokensEmpty(v.second.getCodeTokens()); }); } //------------------------------------------------------------------------ CustomConnectivityUpdate::CustomConnectivityUpdate(const std::string &name, const std::string &updateGroupName, SynapseGroupInternal *synapseGroup, @@ -114,6 +104,20 @@ CustomConnectivityUpdate::CustomConnectivityUpdate(const std::string &name, cons m_VarReferences(varReferences), m_PreVarReferences(preVarReferences), m_PostVarReferences(postVarReferences), m_PreDelayNeuronGroup(nullptr), m_PostDelayNeuronGroup(nullptr) { + + // Validate names + Utils::validatePopName(name, "Custom connectivity update"); + Utils::validatePopName(updateGroupName, "Custom connectivity update group name"); + getCustomConnectivityUpdateModel()->validate(getParams(), getVarInitialisers(), getPreVarInitialisers(), + 
getPostVarInitialisers(), getVarReferences(), getPreVarReferences(), + getPostVarReferences(), "Custom connectivity update " + getName()); + + // Scan custom connectivity update model code strings + m_RowUpdateCodeTokens = Utils::scanCode(getCustomConnectivityUpdateModel()->getRowUpdateCode(), + "Custom connectivity update '" + getName() + "' row update code"); + m_HostUpdateCodeTokens = Utils::scanCode(getCustomConnectivityUpdateModel()->getHostUpdateCode(), + "Custom connectivity update '" + getName() + "' host update code"); + // Give error if synapse group has unsupported connectivity type if (!(getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE)) { throw std::runtime_error("Custom connectivity updates can only be attached to synapse groups with SPARSE connectivity."); @@ -144,13 +148,6 @@ CustomConnectivityUpdate::CustomConnectivityUpdate(const std::string &name, cons { throw std::runtime_error("All referenced postsynaptic variables must have the same size as postsynaptic population."); } - - // Validate names - Utils::validatePopName(name, "Custom connectivity update"); - Utils::validatePopName(updateGroupName, "Custom connectivity update group name"); - getCustomConnectivityUpdateModel()->validate(getParams(), getVarInitialisers(), getPreVarInitialisers(), - getPostVarInitialisers(), getVarReferences(), getPreVarReferences(), - getPostVarReferences(), "Custom connectivity update " + getName()); } //------------------------------------------------------------------------ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) @@ -199,21 +196,6 @@ void CustomConnectivityUpdate::finalise(double dt, unsigned int batchSize) m_PostDelayNeuronGroup = getVarRefDelayGroup(getPostVarReferences(), "postsynaptic"); } //------------------------------------------------------------------------ -bool CustomConnectivityUpdate::isPreVarInitRNGRequired() const -{ - return Utils::isRNGRequired(getPreVarInitialisers()); -} -//------------------------------------------------------------------------ -bool CustomConnectivityUpdate::isPostVarInitRNGRequired() const -{ - return Utils::isRNGRequired(getPostVarInitialisers()); -} -//------------------------------------------------------------------------ -bool CustomConnectivityUpdate::isVarInitRNGRequired() const -{ - return Utils::isRNGRequired(getVarInitialisers()); -} -//------------------------------------------------------------------------ bool CustomConnectivityUpdate::isZeroCopyEnabled() const { // If there are any synaptic variables implemented in zero-copy mode return true diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index 3eda1d3f37..da8c594247 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -32,6 +32,23 @@ bool CustomUpdateBase::isVarInitRequired() const [](const auto &init){ return !init.second.getSnippet()->getCode().empty(); }); } //---------------------------------------------------------------------------- +CustomUpdateBase::CustomUpdateBase(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, + const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, + VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) +: m_Name(name), m_UpdateGroupName(updateGroupName), m_CustomUpdateModel(customUpdateModel), m_Params(params), + m_VarInitialisers(varInitialisers), m_VarLocation(varInitialisers.size(), defaultVarLocation), + 
m_ExtraGlobalParamLocation(customUpdateModel->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation), + m_Batched(false) +{ + // Validate names + Utils::validatePopName(name, "Custom update"); + Utils::validatePopName(updateGroupName, "Custom update group name"); + + // Scan custom update model code string + m_UpdateCodeTokens = Utils::scanCode(getCustomUpdateModel()->getUpdateCode(), + "Custom update '" + getName() + "' update code"); +} +//---------------------------------------------------------------------------- void CustomUpdateBase::finalise(double dt) { auto derivedParams = getCustomUpdateModel()->getDerivedParams(); @@ -43,7 +60,7 @@ void CustomUpdateBase::finalise(double dt) // Initialise derived parameters for variable initialisers for(auto &v : m_VarInitialisers) { - v.second.initDerivedParams(dt); + v.second.finalise(dt); } } //---------------------------------------------------------------------------- diff --git a/src/genn/genn/gennUtils.cc b/src/genn/genn/gennUtils.cc index 919ea56d3f..8d8834a858 100644 --- a/src/genn/genn/gennUtils.cc +++ b/src/genn/genn/gennUtils.cc @@ -119,7 +119,7 @@ bool isRNGRequired(const std::unordered_map &varIn { // Return true if any of these variable initialisers require an RNG return std::any_of(varInitialisers.cbegin(), varInitialisers.cend(), - [](const auto &varInit) { return isRNGRequired(varInit.second); }); + [](const auto &varInit) { return isRNGRequired(varInit.second.getCodeTokens()); }); } //-------------------------------------------------------------------------- void validateVarName(const std::string &name, const std::string &description) diff --git a/src/genn/genn/initSparseConnectivitySnippet.cc b/src/genn/genn/initSparseConnectivitySnippet.cc index faae9d87ee..ebce9c1f93 100644 --- a/src/genn/genn/initSparseConnectivitySnippet.cc +++ b/src/genn/genn/initSparseConnectivitySnippet.cc @@ -46,9 +46,9 @@ Init::Init(const Base *snippet, const std::unordered_map &p : Snippet::Init(snippet, params) { // Scan code tokens - m_RowBuildCodeTokens = Utils::scanCode(getSnippet()->getRowBuildCode(), "Row build code"); - m_ColBuildCodeTokens = Utils::scanCode(getSnippet()->getColBuildCode(), "Col build code"); - m_HostInitCodeTokens = Utils::scanCode(getSnippet()->getHostInitCode(), "Host init code"); + m_RowBuildCodeTokens = Utils::scanCode(getSnippet()->getRowBuildCode(), "Sparse connectivity row build code"); + m_ColBuildCodeTokens = Utils::scanCode(getSnippet()->getColBuildCode(), "Sparse connectivity col build code"); + m_HostInitCodeTokens = Utils::scanCode(getSnippet()->getHostInitCode(), "Sparse connectivity host init code"); } //---------------------------------------------------------------------------- bool Init::isRNGRequired() const @@ -58,6 +58,6 @@ bool Init::isRNGRequired() const //---------------------------------------------------------------------------- bool Init::isHostRNGRequired() const { - return Utils::isRNGRequired(m_HostInitTokens); + return Utils::isRNGRequired(m_HostInitCodeTokens); } } // namespace GeNN::InitSparseConnectivitySnippet diff --git a/src/genn/genn/initToeplitzConnectivitySnippet.cc b/src/genn/genn/initToeplitzConnectivitySnippet.cc index 3228645791..bfc6a3bcd9 100644 --- a/src/genn/genn/initToeplitzConnectivitySnippet.cc +++ b/src/genn/genn/initToeplitzConnectivitySnippet.cc @@ -37,8 +37,9 @@ Init::Init(const Base *snippet, const std::unordered_map &p : Snippet::Init(snippet, params) { // Scan code tokens - m_DiagonalBuildCodeTokens = Utils::scanCode(getSnippet()->getDiagonalBuildCode(), 
context, errorContext + "diagonal build code"); + m_DiagonalBuildCodeTokens = Utils::scanCode(getSnippet()->getDiagonalBuildCode(), "Toeplitz connectivity diagonal build code"); } +//---------------------------------------------------------------------------- bool Init::isRNGRequired() const { return Utils::isRNGRequired(m_DiagonalBuildCodeTokens); diff --git a/src/genn/genn/modelSpec.cc b/src/genn/genn/modelSpec.cc index 9e22af0649..3f304ce6b9 100644 --- a/src/genn/genn/modelSpec.cc +++ b/src/genn/genn/modelSpec.cc @@ -216,7 +216,7 @@ void ModelSpec::finalise() // Finalise neuron groups const auto typeContext = getTypeContext(); for(auto &n : m_LocalNeuronGroups) { - n.second.finalise(m_DT, typeContext); + n.second.finalise(m_DT); } // Finalise synapse groups @@ -241,7 +241,7 @@ void ModelSpec::finalise() // Finalize custom connectivity update groups for (auto &c : m_CustomConnectivityUpdates) { - c.second.finalise(m_BatchSize); + c.second.finalise(m_DT, m_BatchSize); } // Merge incoming postsynaptic models diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 736c426ad5..11aee1446a 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -44,6 +44,11 @@ VarInit::VarInit(const InitVarSnippet::Base *snippet, const std::unordered_mapgetCode(), "Variable initialisation code"); } //---------------------------------------------------------------------------- +VarInit::VarInit(double constant) +: Snippet::Init(InitVarSnippet::Constant::getInstance(), {{"constant", constant}}) +{ +} +//---------------------------------------------------------------------------- bool VarInit::isRNGRequired() const { return Utils::isRNGRequired(m_CodeTokens); diff --git a/src/genn/genn/neuronGroup.cc b/src/genn/genn/neuronGroup.cc index 834d3d9cca..ba9a0a9ea5 100644 --- a/src/genn/genn/neuronGroup.cc +++ b/src/genn/genn/neuronGroup.cc @@ -255,7 +255,7 @@ bool NeuronGroup::isSimRNGRequired() const // Return true if any current sources require an RNG for simulation if(std::any_of(m_MergedCurrentSourceGroups.cbegin(), m_MergedCurrentSourceGroups.cend(), - [](const CurrentSourceInternal *cs){ return cs->isSimRNGRequired(); })) + [](const CurrentSourceInternal *cs){ return Utils::isRNGRequired(cs->getInjectionCodeTokens()); })) { return true; } @@ -279,7 +279,7 @@ bool NeuronGroup::isInitRNGRequired() const // Return true if any current sources require an RNG for initialisation if(std::any_of(m_MergedCurrentSourceGroups.cbegin(), m_MergedCurrentSourceGroups.cend(), - [](const CurrentSourceInternal *cs){ return cs->isInitRNGRequired(); })) + [](const CurrentSourceInternal *cs){ return Utils::isRNGRequired(cs->getVarInitialisers()); })) { return true; } diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc index d16c3652c9..f92b5c7957 100644 --- a/src/genn/genn/synapseGroup.cc +++ b/src/genn/genn/synapseGroup.cc @@ -500,7 +500,7 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType srcNeuronGroup->checkNumDelaySlots(delaySteps); } //---------------------------------------------------------------------------- -void SynapseGroup::finalise(double dt, const Type::TypeContext &context) +void SynapseGroup::finalise(double dt) { auto wuDerivedParams = getWUModel()->getDerivedParams(); auto psDerivedParams = getPSModel()->getDerivedParams(); From 7c2860c90dc4bec2fcbf3ea947c6e677759ae6ad Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 5 Jul 2023 14:00:41 +0100 Subject: [PATCH 323/725] fixed small bugs --- src/genn/genn/models.cc | 2 ++ 
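As an aside on the models.cc change above: the new VarInit(double) constructor simply delegates to the Constant snippet, so the two spellings below are interchangeable. A minimal usage sketch (the Models:: and InitVarSnippet:: qualifications are assumed from the surrounding headers):

    // Equivalent ways of initialising a variable to -65.0 after this change
    const auto shorthand = Models::VarInit(-65.0);
    const auto longhand  = Models::VarInit(InitVarSnippet::Constant::getInstance(),
                                           {{"constant", -65.0}});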
src/genn/genn/neuronGroup.cc | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 11aee1446a..ba4d018cd4 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -47,6 +47,8 @@ VarInit::VarInit(const InitVarSnippet::Base *snippet, const std::unordered_map(InitVarSnippet::Constant::getInstance(), {{"constant", constant}}) { + // Scan code tokens + m_CodeTokens = Utils::scanCode(getSnippet()->getCode(), "Variable initialisation code"); } //---------------------------------------------------------------------------- bool VarInit::isRNGRequired() const diff --git a/src/genn/genn/neuronGroup.cc b/src/genn/genn/neuronGroup.cc index ba9a0a9ea5..8a954fdf5b 100644 --- a/src/genn/genn/neuronGroup.cc +++ b/src/genn/genn/neuronGroup.cc @@ -322,8 +322,8 @@ NeuronGroup::NeuronGroup(const std::string &name, int numNeurons, const NeuronMo "Neuron group '" + getName() + "' sim code"); m_ThresholdConditionCodeTokens = Utils::scanCode(getNeuronModel()->getThresholdConditionCode(), "Neuron group '" + getName() + "' threshold condition code"); - m_SimCodeTokens = Utils::scanCode(getNeuronModel()->getResetCode(), - "Neuron group '" + getName() + "' reset code"); + m_ResetCodeTokens = Utils::scanCode(getNeuronModel()->getResetCode(), + "Neuron group '" + getName() + "' reset code"); } //---------------------------------------------------------------------------- void NeuronGroup::checkNumDelaySlots(unsigned int requiredDelay) From f85dbed6deeaba8f9c3dce5e4c3acc1b3d42cbae Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 5 Jul 2023 14:45:01 +0100 Subject: [PATCH 324/725] fixed another bug --- src/genn/genn/gennUtils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/gennUtils.cc b/src/genn/genn/gennUtils.cc index 8d8834a858..61353c405f 100644 --- a/src/genn/genn/gennUtils.cc +++ b/src/genn/genn/gennUtils.cc @@ -20,7 +20,7 @@ namespace { const std::unordered_set randomFuncs{ - "gennrand_uniform" + "gennrand_uniform", "gennrand_normal", "gennrand_exponential", "gennrand_log_normal", From 3bace56618282e0c0610960c71fdbcabd766b69f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 5 Jul 2023 16:34:57 +0100 Subject: [PATCH 325/725] removed stupid row and column build vars from sparse connectivity init --- .../genn/genn/initSparseConnectivitySnippet.h | 145 ++++++++---------- .../genn/code_generator/initGroupMerged.cc | 24 +-- .../genn/initSparseConnectivitySnippet.cc | 4 - 3 files changed, 69 insertions(+), 104 deletions(-) diff --git a/include/genn/genn/initSparseConnectivitySnippet.h b/include/genn/genn/initSparseConnectivitySnippet.h index fc33ff9927..0f1ba57153 100644 --- a/include/genn/genn/initSparseConnectivitySnippet.h +++ b/include/genn/genn/initSparseConnectivitySnippet.h @@ -16,10 +16,7 @@ // Macros //---------------------------------------------------------------------------- #define SET_ROW_BUILD_CODE(CODE) virtual std::string getRowBuildCode() const override{ return CODE; } -#define SET_ROW_BUILD_STATE_VARS(...) virtual ParamValVec getRowBuildStateVars() const override{ return __VA_ARGS__; } - #define SET_COL_BUILD_CODE(CODE) virtual std::string getColBuildCode() const override{ return CODE; } -#define SET_COL_BUILD_STATE_VARS(...) 
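The randomFuncs fix above is worth dwelling on: in C++, adjacent string literals concatenate, so the missing comma silently fused two entries into one. A minimal sketch of the failure mode:

    #include <string>
    #include <unordered_set>

    // Without the comma, the two literals merge into the single bogus entry
    // "gennrand_uniformgennrand_normal", so lookups of either real name fail
    const std::unordered_set<std::string> broken{
        "gennrand_uniform"   // <-- missing comma
        "gennrand_normal"
    };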
virtual ParamValVec getColBuildStateVars() const override{ return __VA_ARGS__; } #define SET_HOST_INIT_CODE(CODE) virtual std::string getHostInitCode() const override{ return CODE; } @@ -43,11 +40,7 @@ class GENN_EXPORT Base : public Snippet::Base // Declared virtuals //---------------------------------------------------------------------------- virtual std::string getRowBuildCode() const{ return ""; } - virtual ParamValVec getRowBuildStateVars() const{ return {}; } - virtual std::string getColBuildCode() const { return ""; } - virtual ParamValVec getColBuildStateVars() const { return {}; } - virtual std::string getHostInitCode() const{ return ""; } //! Get function to calculate the maximum row length of this connector based on the parameters and the size of the pre and postsynaptic population @@ -133,8 +126,6 @@ class FixedProbabilityBase : public Base public: virtual std::string getRowBuildCode() const override = 0; - SET_ROW_BUILD_STATE_VARS({{"prevJ", "int", -1}}); - SET_PARAM_NAMES({"prob"}); SET_DERIVED_PARAMS({{"probLogRecip", [](const std::unordered_map &pars, double){ return 1.0 / log(1.0 - pars.at("prob")); }}}); @@ -176,13 +167,16 @@ class FixedProbability : public FixedProbabilityBase DECLARE_SNIPPET(InitSparseConnectivitySnippet::FixedProbability); SET_ROW_BUILD_CODE( - "const scalar u = $(gennrand_uniform);\n" - "prevJ += (1 + (int)(log(u) * $(probLogRecip)));\n" - "if(prevJ < $(num_post)) {\n" - " $(addSynapse, prevJ + $(id_post_begin));\n" - "}\n" - "else {\n" - " $(endRow);\n" + "int prevJ = -1;\n" + "while(true) {\n" + " const scalar u = gennrand_uniform();\n" + " prevJ += (1 + (int)(log(u) * probLogRecip));\n" + " if(prevJ < num_post) {\n" + " addSynapse(prevJ + id_post_begin);\n" + " }\n" + " else {\n" + " break;\n" + " }\n" "}\n"); }; @@ -208,17 +202,20 @@ class FixedProbabilityNoAutapse : public FixedProbabilityBase DECLARE_SNIPPET(InitSparseConnectivitySnippet::FixedProbabilityNoAutapse); SET_ROW_BUILD_CODE( - "int nextJ;\n" - "do {\n" - " const scalar u = $(gennrand_uniform);\n" - " nextJ = prevJ + (1 + (int)(log(u) * $(probLogRecip)));\n" - "} while(nextJ == $(id_pre));\n" - "prevJ = nextJ;\n" - "if(prevJ < $(num_post)) {\n" - " $(addSynapse, prevJ + $(id_post_begin));\n" - "}\n" - "else {\n" - " $(endRow);\n" + "int prevJ = -1;\n" + "while(true) {\n" + " int nextJ;\n" + " do {\n" + " const scalar u = gennrand_uniform();\n" + " nextJ = prevJ + (1 + (int)(log(u) * probLogRecip));\n" + " } while(nextJ == id_pre);\n" + " prevJ = nextJ;\n" + " if(prevJ < num_post) {\n" + " addSynapse(prevJ + id_post_begin);\n" + " }\n" + " else {\n" + " break;\n" + " }\n" "}\n"); }; @@ -236,16 +233,14 @@ class FixedNumberPostWithReplacement : public Base DECLARE_SNIPPET(InitSparseConnectivitySnippet::FixedNumberPostWithReplacement); SET_ROW_BUILD_CODE( - "if(c == 0) {\n" - " $(endRow);\n" - "}\n" - "const scalar u = $(gennrand_uniform);\n" - "x += (1.0 - x) * (1.0 - pow(u, 1.0 / (scalar)c));\n" - "unsigned int postIdx = (unsigned int)(x * $(num_post));\n" - "postIdx = (postIdx < $(num_post)) ? postIdx : ($(num_post) - 1);\n" - "$(addSynapse, postIdx + $(id_post_begin));\n" - "c--;\n"); - SET_ROW_BUILD_STATE_VARS({{"x", "scalar", 0.0},{"c", "unsigned int", "$(rowLength)"}}); + "scalar x = 0.0;\n" + "for(unsigned int c = rowLength; c != 0; c--) {\n" + " const scalar u = gennrand_uniform();\n" + " x += (1.0 - x) * (1.0 - pow(u, 1.0 / (scalar)c));\n" + " unsigned int postIdx = (unsigned int)(x * num_post);\n" + " postIdx = (postIdx < num_post) ? 
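The rewritten FixedProbability row-build code above uses geometric skipping: rather than testing every postsynaptic neuron against the connection probability, it samples the gap to the next connected index directly, costing O(connections) instead of O(num_post). A standalone sketch of the same algorithm in plain C++, assuming a std::mt19937 in place of gennrand_uniform:

    #include <cmath>
    #include <random>
    #include <vector>

    // Build one row of fixed-probability connectivity by geometric skipping
    std::vector<unsigned int> buildRow(unsigned int numPost, double prob, std::mt19937 &rng)
    {
        std::uniform_real_distribution<double> dist(0.0, 1.0);
        const double probLogRecip = 1.0 / std::log(1.0 - prob);
        std::vector<unsigned int> row;
        int prevJ = -1;
        while(true) {
            // Geometrically-distributed jump to the next connected index
            prevJ += 1 + (int)(std::log(dist(rng)) * probLogRecip);
            if(prevJ >= (int)numPost) {
                break;
            }
            row.push_back((unsigned int)prevJ);
        }
        return row;
    }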
postIdx : (num_post - 1);\n" + " addSynapse(postIdx + id_post_begin);\n" + "}\n"); SET_PARAM_NAMES({"rowLength"}); @@ -285,16 +280,14 @@ class FixedNumberTotalWithReplacement : public Base DECLARE_SNIPPET(InitSparseConnectivitySnippet::FixedNumberTotalWithReplacement); SET_ROW_BUILD_CODE( - "if(c == 0) {\n" - " $(endRow);\n" - "}\n" - "const scalar u = $(gennrand_uniform);\n" - "x += (1.0 - x) * (1.0 - pow(u, 1.0 / (scalar)c));\n" - "unsigned int postIdx = (unsigned int)(x * $(num_post));\n" - "postIdx = (postIdx < $(num_post)) ? postIdx : ($(num_post) - 1);\n" - "$(addSynapse, postIdx + $(id_post_begin));\n" - "c--;\n"); - SET_ROW_BUILD_STATE_VARS({{"x", "scalar", 0.0},{"c", "unsigned int", "$(preCalcRowLength)[($(id_pre) * $(num_threads)) + $(id_thread)]"}}); + "scalar x = 0.0;\n" + "for(unsigned int c = preCalcRowLength[(id_pre * num_threads) + id_thread]; c != 0; c--) {\n" + " const scalar u = gennrand_uniform();\n" + " x += (1.0 - x) * (1.0 - pow(u, 1.0 / (scalar)c));\n" + " unsigned int postIdx = (unsigned int)(x * num_post);\n" + " postIdx = (postIdx < num_post) ? postIdx : (num_post - 1);\n" + " addSynapse(postIdx + id_post_begin);\n" + "}\n"); SET_PARAM_NAMES({"total"}); SET_EXTRA_GLOBAL_PARAMS({{"preCalcRowLength", "uint16_t*"}}) @@ -374,14 +367,11 @@ class FixedNumberPreWithReplacement : public Base DECLARE_SNIPPET(InitSparseConnectivitySnippet::FixedNumberPreWithReplacement); SET_COL_BUILD_CODE( - "if(c == 0) {\n" - " $(endCol);\n" - "}\n" - "const unsigned int idPre = (unsigned int)ceil($(gennrand_uniform) * $(num_pre)) - 1;\n" - "$(addSynapse, idPre + $(id_pre_begin));\n" - "c--;\n"); - SET_COL_BUILD_STATE_VARS({{"c", "unsigned int", "$(colLength)"}}); - + "for(unsigned int c = colLength; c != 0; c--) {\n" + " const unsigned int idPre = (unsigned int)ceil(gennrand_uniform() * num_pre) - 1;\n" + " addSynapse(idPre + id_pre_begin);\n" + "}\n"); + SET_PARAM_NAMES({"colLength"}); SET_CALC_MAX_ROW_LENGTH_FUNC( @@ -421,31 +411,28 @@ class Conv2D : public Base "conv_ih", "conv_iw", "conv_ic", "conv_oh", "conv_ow", "conv_oc"}); - SET_ROW_BUILD_STATE_VARS({{"inRow", "int", "($(id_pre) / (int)$(conv_ic)) / (int)$(conv_iw)"}, - {"inCol", "int", "($(id_pre) / (int)$(conv_ic)) % (int)$(conv_iw)"}, - {"inChan", "int", "$(id_pre) % (int)$(conv_ic)"}, - {"outRow", "int", "min((int)$(conv_oh), max(0, 1 + (int)floor((inRow + $(conv_padh) - $(conv_kh)) / $(conv_sh))))"}, - {"maxOutRow", "int", "min((int)$(conv_oh), max(0, 1 + ((inRow + (int)$(conv_padh)) / (int)$(conv_sh))))"}, - {"minOutCol", "int", "min((int)$(conv_ow), max(0, 1 + (int)floor((inCol + $(conv_padw) - $(conv_kw)) / $(conv_sw))))"}, - {"maxOutCol", "int", "min((int)$(conv_ow), max(0, 1 + ((inCol + (int)$(conv_padw)) / (int)$(conv_sw))))"}}); - SET_ROW_BUILD_CODE( - "if($(outRow) == $(maxOutRow)) {\n" - " $(endRow);\n" - "}\n" - "const int strideRow = ($(outRow) * (int)$(conv_sh)) - (int)$(conv_padh);\n" - "const int kernRow = $(inRow) - strideRow;\n" - "for(int outCol = $(minOutCol); outCol < $(maxOutCol); outCol++) {\n" - " const int strideCol = (outCol * (int)$(conv_sw)) - (int)$(conv_padw);\n" - " const int kernCol = $(inCol) - strideCol;\n" - " for(unsigned int outChan = 0; outChan < (unsigned int)$(conv_oc); outChan++) {\n" - " const int idPost = (($(outRow) * (int)$(conv_ow) * (int)$(conv_oc)) +\n" - " (outCol * (int)$(conv_oc)) +\n" - " outChan);\n" - " $(addSynapse, idPost, kernRow, kernCol, $(inChan), outChan);\n" - " }\n" - "}\n" - "$(outRow)++;\n"); + SET_ROW_BUILD_CODE( + "const int inRow = (id_pre / (int)conv_ic) / (int)conv_iw;\n" +
"const int inCol = (id_pre / (int)conv_ic) % (int)conv_iw\n" + "const int inChan = id_pre % (int)conv_ic\n" + "const int maxOutRow = min((int)conv_oh, max(0, 1 + ((inRow + (int)conv_padh) / (int)conv_sh)))\n" + "const int minOutCol = min((int)conv_ow, max(0, 1 + (int)floor((inCol + conv_padw - conv_kw) / conv_sw)))\n" + "const int maxOutCol = min((int)conv_ow, max(0, 1 + ((inCol + (int)conv_padw) / (int)conv_sw)))\n" + "int outRow = min((int)conv_oh, max(0, 1 + (int)floor((inRow + conv_padh - conv_kh) / conv_sh)))\n" + "for(;outRow < maxOutRow; outRow++) {\n" + " const int strideRow = (outRow * (int)conv_sh) - (int)conv_padh;\n" + " const int kernRow = inRow - strideRow;\n" + " for(int outCol = minOutCol; outCol < maxOutCol; outCol++) {\n" + " const int strideCol = (outCol * (int)conv_sw) - (int)conv_padw;\n" + " const int kernCol = inCol - strideCol;\n" + " for(unsigned int outChan = 0; outChan < (unsigned int)conv_oc; outChan++) {\n" + " const int idPost = ((outRow * (int)conv_ow * (int)conv_oc) +\n" + " (outCol * (int)conv_oc) +\n" + " outChan);\n" + " addSynapse(idPost, kernRow, kernCol, inChan, outChan);\n" + " }\n" + " }\n" + "}\n"); SET_CALC_MAX_ROW_LENGTH_FUNC( [](unsigned int, unsigned int, const std::unordered_map &pars) diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 9ae8168f11..3318a25cea 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -713,10 +713,6 @@ void SynapseConnectivityInitGroupMerged::genInitConnectivity(const BackendBase & // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this); - // Add substitution for end function - // **TODO** remove - groupEnv.add(Type::ResolvedType::createFunction(Type::Void, {}), rowNotColumns ? "endRow" : "endCol", "break;"); - // Substitute in parameters and derived parameters for initialising variables groupEnv.addConnectInitParams("", &SynapseGroupInternal::getConnectivityInitialiser, &SynapseConnectivityInitGroupMerged::isSparseConnectivityInitParamHeterogeneous); @@ -724,24 +720,10 @@ void SynapseConnectivityInitGroupMerged::genInitConnectivity(const BackendBase & &SynapseConnectivityInitGroupMerged::isSparseConnectivityInitDerivedParamHeterogeneous); groupEnv.addExtraGlobalParams(snippet->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", ""); - - // Initialise state variables and loop on generated code to initialise sparse connectivity - groupEnv.getStream() << "// Build sparse connectivity" << std::endl; - const auto stateVars = rowNotColumns ? snippet->getRowBuildStateVars() : snippet->getColBuildStateVars(); const std::string context = rowNotColumns ? "row" : "column"; - for(const auto &a : stateVars) { - const auto resolvedType = a.type.resolve(getTypeContext()); - groupEnv.getStream() << resolvedType.getName() << " _" << a.name << " = " << a.value << ";" << std::endl; - groupEnv.add(resolvedType, a.name, "_" + a.name); - } - groupEnv.getStream() << "while(true)"; - { - CodeStream::Scope b(groupEnv.getStream()); - - Transpiler::ErrorHandler errorHandler("Synapse group sparse connectivity '" + getArchetype().getName() + "' " + context + " build code"); - prettyPrintStatements(rowNotColumns ? 
connectInit.getRowBuildCodeTokens() : connectInit.getColBuildCodeTokens(), - getTypeContext(), groupEnv, errorHandler); - } + Transpiler::ErrorHandler errorHandler("Synapse group sparse connectivity '" + getArchetype().getName() + "' " + context + " build code"); + prettyPrintStatements(rowNotColumns ? connectInit.getRowBuildCodeTokens() : connectInit.getColBuildCodeTokens(), + getTypeContext(), groupEnv, errorHandler); } diff --git a/src/genn/genn/initSparseConnectivitySnippet.cc b/src/genn/genn/initSparseConnectivitySnippet.cc index ebce9c1f93..da50be7d2a 100644 --- a/src/genn/genn/initSparseConnectivitySnippet.cc +++ b/src/genn/genn/initSparseConnectivitySnippet.cc @@ -24,9 +24,7 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const Snippet::Base::updateHash(hash); Utils::updateHash(getRowBuildCode(), hash); - Utils::updateHash(getRowBuildStateVars(), hash); Utils::updateHash(getColBuildCode(), hash); - Utils::updateHash(getColBuildStateVars(), hash); Utils::updateHash(getHostInitCode(), hash); return hash.get_digest(); } @@ -35,8 +33,6 @@ void Base::validate(const std::unordered_map ¶mValues) { // Superclass Snippet::Base::validate(paramValues, "Sparse connectivity initialiser "); - Utils::validateVecNames(getRowBuildStateVars(), "Row building state variable"); - Utils::validateVecNames(getColBuildStateVars(), "Column building state variable"); } //---------------------------------------------------------------------------- From 86503350614ba896a0f51de807794e710610e7f8 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 5 Jul 2023 16:35:14 +0100 Subject: [PATCH 326/725] no need for macro trick around addSynapse --- src/genn/backends/single_threaded_cpu/backend.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index c95c99b610..52e9af853f 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1036,8 +1036,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler std::ostringstream addSynapseStream; CodeStream addSynapse(addSynapseStream); - // Use classic macro trick to turn block of initialization code into statement and 'eat' semicolon - addSynapse << "do"; + // Create block of code to add synapse { CodeStream::Scope b(addSynapse); @@ -1105,7 +1104,6 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler } } } - addSynapse << "while(false)"; const auto addSynapseType = Type::ResolvedType::createFunction(Type::Void, std::vector{1ull + s.getArchetype().getKernelSize().size(), Type::Uint32}); groupEnv.add(addSynapseType, "addSynapse", addSynapseStream.str()); From 90e875f9517f60c53e8c114b3703bdd6eb2aa216 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 5 Jul 2023 16:35:37 +0100 Subject: [PATCH 327/725] true and false should be treated as literals not keywords --- include/genn/genn/transpiler/token.h | 4 ++-- src/genn/genn/transpiler/parser.cc | 2 +- src/genn/genn/transpiler/scanner.cc | 4 ++-- src/genn/genn/transpiler/typeChecker.cc | 5 ++++- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/include/genn/genn/transpiler/token.h b/include/genn/genn/transpiler/token.h index d1d86afc74..9d9e6c5c6d 100644 --- a/include/genn/genn/transpiler/token.h +++ b/include/genn/genn/transpiler/token.h @@ -37,14 +37,14 @@ struct Token SHIFT_LEFT_EQUAL, SHIFT_RIGHT_EQUAL, // Literals - IDENTIFIER, UINT32_NUMBER, 
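For readers unfamiliar with the 'classic macro trick' removed above: wrapping a multi-statement expansion in do { ... } while(false) makes it behave as a single statement that consumes the caller's semicolon, so it stays safe inside un-braced if/else branches. A minimal sketch with illustrative names:

    #define ADD_SYNAPSE(idx)        \
        do {                        \
            rowLength[i]++;         \
            ind[offset++] = (idx);  \
        } while(false)

    // if(condition)
    //     ADD_SYNAPSE(j);   // still expands as exactly one statement

The commit above drops the idiom, presumably because the generated block is now emitted as an ordinary braced scope rather than substituted into statement position.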
INT32_NUMBER, FLOAT_NUMBER, DOUBLE_NUMBER, SCALAR_NUMBER, STRING, + IDENTIFIER, UINT32_NUMBER, INT32_NUMBER, FLOAT_NUMBER, DOUBLE_NUMBER, SCALAR_NUMBER, BOOLEAN, STRING, // Types TYPE_SPECIFIER, TYPE_QUALIFIER, // Keywords - DO, ELSE, FALSE, FOR, FOR_EACH_SYNAPSE, IF, TRUE, WHILE, SWITCH, CONTINUE, BREAK, CASE, DEFAULT, + DO, ELSE, FOR, FOR_EACH_SYNAPSE, IF, WHILE, SWITCH, CONTINUE, BREAK, CASE, DEFAULT, END_OF_FILE, }; diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 1662dbc5e2..cda427e795 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -291,7 +291,7 @@ Expression::ExpressionPtr parsePrimary(ParserState &parserState) // identifier // constant // "(" expression ")" - if (parserState.match({Token::Type::FALSE, Token::Type::TRUE, Token::Type::STRING, + if (parserState.match({Token::Type::BOOLEAN, Token::Type::STRING, Token::Type::DOUBLE_NUMBER, Token::Type::FLOAT_NUMBER, Token::Type::SCALAR_NUMBER, Token::Type::INT32_NUMBER, Token::Type::UINT32_NUMBER})) { diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc index a90dc17d8b..bf8f97f696 100644 --- a/src/genn/genn/transpiler/scanner.cc +++ b/src/genn/genn/transpiler/scanner.cc @@ -26,11 +26,11 @@ const std::unordered_map keywords{ {"const", Token::Type::TYPE_QUALIFIER}, {"do", Token::Type::DO}, {"else", Token::Type::ELSE}, - {"false", Token::Type::FALSE}, + {"false", Token::Type::BOOLEAN}, {"for", Token::Type::FOR}, {"for_each_synapse", Token::Type::FOR_EACH_SYNAPSE}, {"if", Token::Type::IF}, - {"true", Token::Type::TRUE}, + {"true", Token::Type::BOOLEAN}, {"while", Token::Type::WHILE}, {"switch", Token::Type::SWITCH}, {"break", Token::Type::BREAK}, diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index 5a6d9890b2..c4fbf4bffd 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -451,7 +451,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Expression::Literal &literal) final { - // Convert number token type to type + // Convert literal token type to type if (literal.getValue().type == Token::Type::DOUBLE_NUMBER) { setExpressionType(&literal, Type::Double); } @@ -467,6 +467,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor else if (literal.getValue().type == Token::Type::UINT32_NUMBER) { setExpressionType(&literal, Type::Uint32); } + else if(literal.getValue().type == Token::Type::BOOLEAN) { + setExpressionType(&literal, Type::Bool); + } else if(literal.getValue().type == Token::Type::STRING) { setExpressionType(&literal, Type::Int8.createPointer(Type::Qualifier::CONSTANT)); } From fc5641ce403017848e38c8a953131240601bd9b0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 5 Jul 2023 17:47:55 +0100 Subject: [PATCH 328/725] fixed bug in logging --- include/genn/genn/logging.h | 48 ++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/include/genn/genn/logging.h b/include/genn/genn/logging.h index 4849d7c982..0af7a87e02 100644 --- a/include/genn/genn/logging.h +++ b/include/genn/genn/logging.h @@ -17,36 +17,36 @@ class IAppender; // Macros //---------------------------------------------------------------------------- // Shorthand macros for logging to 'GeNN' channel -#define LOGV_GENN LOGV_(Logging::CHANNEL_GENN) -#define LOGD_GENN LOGD_(Logging::CHANNEL_GENN) -#define LOGI_GENN LOGI_(Logging::CHANNEL_GENN) 
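Collapsing TRUE and FALSE into one BOOLEAN token type (above) means the type checker can map every literal token straight to a resolved type. A condensed sketch of that mapping, abbreviated from the visitor shown above (SCALAR_NUMBER omitted for brevity):

    Type::ResolvedType literalType(Token::Type t)
    {
        switch(t) {
        case Token::Type::DOUBLE_NUMBER: return Type::Double;
        case Token::Type::FLOAT_NUMBER:  return Type::Float;
        case Token::Type::INT32_NUMBER:  return Type::Int32;
        case Token::Type::UINT32_NUMBER: return Type::Uint32;
        case Token::Type::BOOLEAN:       return Type::Bool;
        default: /* STRING */            return Type::Int8.createPointer(Type::Qualifier::CONSTANT);
        }
    }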
-#define LOGW_GENN LOGW_(Logging::CHANNEL_GENN) -#define LOGE_GENN LOGE_(Logging::CHANNEL_GENN) -#define LOGF_GENN LOGF_(Logging::CHANNEL_GENN) +#define LOGV_GENN LOGV_(GeNN::Logging::CHANNEL_GENN) +#define LOGD_GENN LOGD_(GeNN::Logging::CHANNEL_GENN) +#define LOGI_GENN LOGI_(GeNN::Logging::CHANNEL_GENN) +#define LOGW_GENN LOGW_(GeNN::Logging::CHANNEL_GENN) +#define LOGE_GENN LOGE_(GeNN::Logging::CHANNEL_GENN) +#define LOGF_GENN LOGF_(GeNN::Logging::CHANNEL_GENN) // Shorthand macros for logging to 'code generator' channel -#define LOGV_CODE_GEN LOGV_(Logging::CHANNEL_CODE_GEN) -#define LOGD_CODE_GEN LOGD_(Logging::CHANNEL_CODE_GEN) -#define LOGI_CODE_GEN LOGI_(Logging::CHANNEL_CODE_GEN) -#define LOGW_CODE_GEN LOGW_(Logging::CHANNEL_CODE_GEN) -#define LOGE_CODE_GEN LOGE_(Logging::CHANNEL_CODE_GEN) -#define LOGF_CODE_GEN LOGF_(Logging::CHANNEL_CODE_GEN) +#define LOGV_CODE_GEN LOGV_(GeNN::Logging::CHANNEL_CODE_GEN) +#define LOGD_CODE_GEN LOGD_(GeNN::Logging::CHANNEL_CODE_GEN) +#define LOGI_CODE_GEN LOGI_(GeNN::Logging::CHANNEL_CODE_GEN) +#define LOGW_CODE_GEN LOGW_(GeNN::Logging::CHANNEL_CODE_GEN) +#define LOGE_CODE_GEN LOGE_(GeNN::Logging::CHANNEL_CODE_GEN) +#define LOGF_CODE_GEN LOGF_(GeNN::Logging::CHANNEL_CODE_GEN) // Shorthand macros for logging to 'transpiler' channel -#define LOGV_TRANSPILER LOGV_(Logging::CHANNEL_TRANSPILER) -#define LOGD_TRANSPILER LOGD_(Logging::CHANNEL_TRANSPILER) -#define LOGI_TRANSPILER LOGI_(Logging::CHANNEL_TRANSPILER) -#define LOGW_TRANSPILER LOGW_(Logging::CHANNEL_TRANSPILER) -#define LOGE_TRANSPILER LOGE_(Logging::CHANNEL_TRANSPILER) -#define LOGF_TRANSPILER LOGF_(Logging::CHANNEL_TRANSPILER) +#define LOGV_TRANSPILER LOGV_(GeNN::Logging::CHANNEL_TRANSPILER) +#define LOGD_TRANSPILER LOGD_(GeNN::Logging::CHANNEL_TRANSPILER) +#define LOGI_TRANSPILER LOGI_(GeNN::Logging::CHANNEL_TRANSPILER) +#define LOGW_TRANSPILER LOGW_(GeNN::Logging::CHANNEL_TRANSPILER) +#define LOGE_TRANSPILER LOGE_(GeNN::Logging::CHANNEL_TRANSPILER) +#define LOGF_TRANSPILER LOGF_(GeNN::Logging::CHANNEL_TRANSPILER) // Shorthand macros for logging to 'backend' channel -#define LOGV_BACKEND LOGV_(Logging::CHANNEL_BACKEND) -#define LOGD_BACKEND LOGD_(Logging::CHANNEL_BACKEND) -#define LOGI_BACKEND LOGI_(Logging::CHANNEL_BACKEND) -#define LOGW_BACKEND LOGW_(Logging::CHANNEL_BACKEND) -#define LOGE_BACKEND LOGE_(Logging::CHANNEL_BACKEND) -#define LOGF_BACKEND LOGF_(Logging::CHANNEL_BACKEND) +#define LOGV_BACKEND LOGV_(GeNN::Logging::CHANNEL_BACKEND) +#define LOGD_BACKEND LOGD_(GeNN::Logging::CHANNEL_BACKEND) +#define LOGI_BACKEND LOGI_(GeNN::Logging::CHANNEL_BACKEND) +#define LOGW_BACKEND LOGW_(GeNN::Logging::CHANNEL_BACKEND) +#define LOGE_BACKEND LOGE_(GeNN::Logging::CHANNEL_BACKEND) +#define LOGF_BACKEND LOGF_(GeNN::Logging::CHANNEL_BACKEND) //---------------------------------------------------------------------------- From 90f0ce62243f4d1dbf74c0044a060fc9eb6380ef Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 5 Jul 2023 18:22:03 +0100 Subject: [PATCH 329/725] fixed issue with type checker not handling nested loops --- src/genn/genn/transpiler/typeChecker.cc | 39 ++++++++++++++----------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index c4fbf4bffd..aca67b16ba 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -198,8 +198,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const Type::TypeContext &context, 
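The logging fix above is a textbook macro/namespace interaction: macro bodies are expanded textually at the call site, so an unqualified Logging::CHANNEL_GENN only resolves when the caller happens to be inside namespace GeNN. Fully qualifying the constants makes the macros usable from any namespace; a minimal reproduction:

    namespace GeNN::Logging { enum Channel { CHANNEL_GENN }; }

    // Breaks when expanded outside namespace GeNN:
    //     #define LOGD_GENN LOGD_(Logging::CHANNEL_GENN)
    // Works everywhere:
    #define LOGD_GENN LOGD_(GeNN::Logging::CHANNEL_GENN)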
ErrorHandlerBase &errorHandler, StatementHandler forEachSynapseHandler) : m_Environment(environment), m_Context(context), m_ErrorHandler(errorHandler), - m_ForEachSynapseHandler(forEachSynapseHandler), m_ResolvedTypes(resolvedTypes), - m_InLoop(false), m_InSwitch(false) + m_ForEachSynapseHandler(forEachSynapseHandler), m_ResolvedTypes(resolvedTypes) { } @@ -677,7 +676,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor //--------------------------------------------------------------------------- virtual void visit(const Statement::Break &breakStatement) final { - if (!m_InLoop && !m_InSwitch) { + if (m_ActiveLoopStatements.empty() && m_ActiveSwitchStatements.empty()) { m_ErrorHandler.error(breakStatement.getToken(), "Statement not within loop"); throw TypeCheckError(); } @@ -702,7 +701,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::Continue &continueStatement) final { - if (!m_InLoop) { + if (m_ActiveLoopStatements.empty()) { m_ErrorHandler.error(continueStatement.getToken(), "Statement not within loop"); throw TypeCheckError(); } @@ -710,9 +709,10 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::Do &doStatement) final { - m_InLoop = true; + m_ActiveLoopStatements.emplace(&doStatement); doStatement.getBody()->accept(*this); - m_InLoop = false; + assert(m_ActiveLoopStatements.top() == &doStatement); + m_ActiveLoopStatements.pop(); doStatement.getCondition()->accept(*this); } @@ -743,9 +743,10 @@ class Visitor : public Expression::Visitor, public Statement::Visitor forStatement.getIncrement()->accept(*this); } - m_InLoop = true; + m_ActiveLoopStatements.emplace(&forStatement); forStatement.getBody()->accept(*this); - m_InLoop = false; + assert(m_ActiveLoopStatements.top() == &forStatement); + m_ActiveLoopStatements.pop(); // Restore old environment m_Environment = oldEnvironment; @@ -768,9 +769,10 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // Call handler to define anything required in environment m_ForEachSynapseHandler(m_Environment, m_ErrorHandler); - m_InLoop = true; + m_ActiveLoopStatements.emplace(&forEachSynapseStatement); forEachSynapseStatement.getBody()->accept(*this); - m_InLoop = false; + assert(m_ActiveLoopStatements.top() == &forEachSynapseStatement); + m_ActiveLoopStatements.pop(); // Restore old environment m_Environment = oldEnvironment; @@ -787,7 +789,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::Labelled &labelled) final { - if (!m_InSwitch) { + if (m_ActiveSwitchStatements.empty()) { m_ErrorHandler.error(labelled.getKeyword(), "Statement not within switch statement"); throw TypeCheckError(); } @@ -813,9 +815,10 @@ class Visitor : public Expression::Visitor, public Statement::Visitor throw TypeCheckError(); } - m_InSwitch = true; + m_ActiveSwitchStatements.emplace(&switchStatement); switchStatement.getBody()->accept(*this); - m_InSwitch = false; + assert(m_ActiveSwitchStatements.top() == &switchStatement); + m_ActiveSwitchStatements.pop(); } virtual void visit(const Statement::VarDeclaration &varDeclaration) final @@ -839,9 +842,11 @@ class Visitor : public Expression::Visitor, public Statement::Visitor virtual void visit(const Statement::While &whileStatement) final { whileStatement.getCondition()->accept(*this); - m_InLoop = true; + + m_ActiveLoopStatements.emplace(&whileStatement); whileStatement.getBody()->accept(*this); - m_InLoop 
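The stack-based bookkeeping above is what actually fixes nested loops: with a plain m_InLoop flag, leaving an inner loop reset the flag to false while an outer loop was still active, so a later break or continue in the outer body was wrongly rejected. A minimal sketch of the corrected pattern:

    #include <cassert>
    #include <stack>

    std::stack<const void*> activeLoops;

    void enterLoop(const void *stmt) { activeLoops.push(stmt); }
    void exitLoop(const void *stmt)
    {
        // Storing the statement pointer lets an assert verify that
        // enter/exit calls are correctly paired
        assert(!activeLoops.empty() && activeLoops.top() == stmt);
        activeLoops.pop();
    }
    bool inLoop() { return !activeLoops.empty(); }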
= false; + assert(m_ActiveLoopStatements.top() == &whileStatement); + m_ActiveLoopStatements.pop(); } private: @@ -870,8 +875,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor StatementHandler m_ForEachSynapseHandler; ResolvedTypeMap &m_ResolvedTypes; std::stack> m_CallArguments; - bool m_InLoop; - bool m_InSwitch; + std::stack m_ActiveLoopStatements; + std::stack m_ActiveSwitchStatements; }; } // Anonymous namespace From d1b9362d7f958f5698fff7d66989c513a103be96 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 5 Jul 2023 18:22:17 +0100 Subject: [PATCH 330/725] fixed indexing of various classes of synapse variable --- .../genn/code_generator/synapseUpdateGroupMerged.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 3a39811a74..69d7e8d690 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -230,7 +230,7 @@ std::string SynapseGroupMergedBase::getPostSlot(unsigned int batchSize) const //---------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getPostDenDelayIndex(unsigned int batchSize, const std::string &index, const std::string &offset) const { - const std::string batchID = ((batchSize == 1) ? "" : "$(_post_batch_offset) + ") + std::string{"$(" + index + ")"}; + const std::string batchID = ((batchSize == 1) ? "" : "$(_post_batch_offset) + ") + index; if(offset.empty()) { return "(*$(_den_delay_ptr) * $(num_post) + " + batchID; @@ -258,7 +258,7 @@ std::string SynapseGroupMergedBase::getPrePrevSpikeTimeIndex(bool delay, unsigne return (singleBatch ? "$(_pre_prev_spike_time_delay_offset) + " : "$(_pre_prev_spike_time_batch_delay_offset) + ") + index; } else { - return (singleBatch ? "" : "$(_pre_batch_offset) + ") + std::string{"$(" + index + ")"}; + return (singleBatch ? "" : "$(_pre_batch_offset) + ") + index; } } //-------------------------------------------------------------------------- @@ -267,23 +267,23 @@ std::string SynapseGroupMergedBase::getPostPrevSpikeTimeIndex(bool delay, unsign const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); if(delay) { - return (singleBatch ? "$(_post_prev_spike_time_delay_offset) + " : "$(_post_prev_spike_time_batch_delay_offset) + ") + std::string{"$(" + index + ")"}; + return (singleBatch ? "$(_post_prev_spike_time_delay_offset) + " : "$(_post_prev_spike_time_batch_delay_offset) + ") + index; } else { - return (singleBatch ? "" : "$(_post_batch_offset) + ") + std::string{"$(" + index + ")"}; + return (singleBatch ? "" : "$(_post_batch_offset) + ") + index; } } //-------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getSynVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const { const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); - return (singleBatch ? "" : "$(_syn_batch_offset)") + std::string{"$(" + index + ")"}; + return (singleBatch ? 
"" : "$(_syn_batch_offset)") + index; } //-------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getKernelVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const { const bool singleBatch = (varDuplication == VarAccessDuplication::SHARED || batchSize == 1); - return (singleBatch ? "" : "$(_kern_batch_offset)") + std::string{"$(" + index + ")"}; + return (singleBatch ? "" : "$(_kern_batch_offset)") + index; } //---------------------------------------------------------------------------- std::string SynapseGroupMergedBase::getVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, From 07d9a858dc55f77df7e059ef2f9ee9b06fe6835c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 5 Jul 2023 18:32:05 +0100 Subject: [PATCH 331/725] wrong prefix for denDelayPtr --- include/genn/genn/code_generator/backendBase.h | 2 +- src/genn/genn/code_generator/initGroupMerged.cc | 2 +- src/genn/genn/code_generator/neuronUpdateGroupMerged.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index d614f026a7..d6efc790af 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -547,7 +547,7 @@ class GENN_EXPORT BackendBase env.addField(env.getGroup().getScalarType().createPointer(), "_den_delay", "denDelay", [this](const auto &g, size_t) { return getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); env.addField(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); + [this](const auto &g, size_t) { return getScalarAddressPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); // Presynaptic output fields env.addField(env.getGroup().getScalarType().createPointer(), "_out_pre", "outPre", diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 3318a25cea..7472487c81 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -242,7 +242,7 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir // Add field for dendritic delay pointer and zero groupEnv.addField(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr" + fieldSuffix, - [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); + [&backend](const auto &g, size_t) { return backend.getScalarAddressPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); backend.genPopVariableInit(groupEnv, [](EnvironmentExternalBase &varEnv) { diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index a964da82d9..553a6fc306 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -94,7 +94,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env psmEnv.addField(getScalarType().createPointer(), "_den_delay", "denDelay" + fieldSuffix, [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix();}); psmEnv.addField(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr" + fieldSuffix, - [&backend](const auto 
&g, size_t) { return backend.getDeviceVarPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix();}); + [&backend](const auto &g, size_t) { return backend.getScalarAddressPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix();}); // Get reference to dendritic delay buffer input for this timestep psmEnv.printLine(backend.getPointerPrefix() + getScalarType().getName() + " *denDelayFront = &$(_den_delay)[(*$(_den_delay_ptr) * $(num_neurons)) + " + idx + "];"); From 39386e7f121e5e189314f8fc7abd0bc882495af8 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 5 Jul 2023 18:43:25 +0100 Subject: [PATCH 332/725] very rough fix up of synapse connectivity host init code --- .../genn/genn/code_generator/generateRunner.h | 2 +- .../genn/code_generator/generateRunner.cc | 30 ++++++++++++------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/include/genn/genn/code_generator/generateRunner.h b/include/genn/genn/code_generator/generateRunner.h index 549ac67f0a..1650d79197 100644 --- a/include/genn/genn/code_generator/generateRunner.h +++ b/include/genn/genn/code_generator/generateRunner.h @@ -25,6 +25,6 @@ class path; //-------------------------------------------------------------------------- namespace GeNN::CodeGenerator { -GENN_EXPORT MemAlloc generateRunner(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, +GENN_EXPORT MemAlloc generateRunner(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, const std::string &suffix = ""); } diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 5ac3ebabb2..ac8c5c82d8 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -522,7 +522,7 @@ void genCustomUpdate(const ModelSpecMerged &modelMerged, const BackendBase &back //-------------------------------------------------------------------------- // GeNN::CodeGenerator //-------------------------------------------------------------------------- -MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const ModelSpecMerged &modelMerged, +MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, const std::string &suffix) { // Create output streams to write to file and wrap in CodeStreams @@ -731,20 +731,28 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, definitionsInternalFunc << "// copying merged group structures to device" << std::endl; definitionsInternalFunc << "// ------------------------------------------------------------------------" << std::endl; + // Generate merged synapse connectivity host init code + // **NOTE** this needs to be done before generating the runner because this configures the required fields BUT + // needs to be written into a separate stream because it actually needs to be RUN afterwards so valid pointers + // get copied straight into subsequent structures and merged EGP system isn't required + std::ostringstream synapseConnectivityHostInitStream; + CodeStream synapseConnectivityHostInit(synapseConnectivityHostInitStream); + modelMerged.genMergedSynapseConnectivityHostInitGroups( + backend, + [&backend, &modelMerged, &synapseConnectivityHostInit](auto &sg) + { + EnvironmentExternal env(synapseConnectivityHostInit); + sg.generateInit(backend, env, modelMerged); + }); + // Loop through merged synapse connectivity host initialisation groups for(const auto &m : 
modelMerged.getMergedSynapseConnectivityHostInitGroups()) { - assert(false); - //m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, - // runnerVarDecl, runnerMergedStructAlloc); + m.generateRunner(backend, definitionsInternal, definitionsInternalFunc, definitionsInternalVar, + runnerVarDecl, runnerMergedStructAlloc); } - // Loop through merged synapse connectivity host init groups and generate host init code - // **NOTE** this is done here so valid pointers get copied straight into subsequent structures and merged EGP system isn't required - for(const auto &sg : modelMerged.getMergedSynapseConnectivityHostInitGroups()) { - assert(false); - //EnvironmentExternal env(runnerMergedStructAlloc); - //sg.generateInit(backend, runnerMergedStructAlloc, modelMerged); - } + // Now insert host initialisation code + runnerMergedStructAlloc << synapseConnectivityHostInitStream.str(); // Generate merged neuron initialisation groups for(const auto &m : modelMerged.getMergedNeuronInitGroups()) { From 6663969589f45ddb7f0295bbc51ee966394120a9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 6 Jul 2023 09:53:06 +0100 Subject: [PATCH 333/725] support for 64-bit integer types --- include/genn/genn/type.h | 12 ++++----- src/genn/genn/transpiler/parser.cc | 40 ++++++++++++++++++++++------- src/genn/genn/transpiler/scanner.cc | 3 +++ 3 files changed, 39 insertions(+), 16 deletions(-) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index da7b995123..a89f058234 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -6,13 +6,9 @@ // Standard C++ includes #include -#include #include -#include #include -#include #include -#include #include #include @@ -266,7 +262,7 @@ struct ResolvedType template static ResolvedType createNumeric(const std::string &name, int rank, const std::string &literalSuffix = "", Qualifier qualifiers = Qualifier{0}) { - return ResolvedType{Value{name, sizeof(T), Numeric{rank, std::numeric_limits::min(), std::numeric_limits::max(), + return ResolvedType{Value{name, sizeof(T), Numeric{rank, std::numeric_limits::min(), static_cast(std::numeric_limits::max()), std::numeric_limits::lowest(), std::numeric_limits::max_digits10, std::is_signed::value, std::is_integral::value, literalSuffix}}, qualifiers}; @@ -335,11 +331,13 @@ inline static const ResolvedType Bool = CREATE_NUMERIC(bool, 0, ""); inline static const ResolvedType Int8 = CREATE_NUMERIC(int8_t, 10, ""); inline static const ResolvedType Int16 = CREATE_NUMERIC(int16_t, 20, ""); inline static const ResolvedType Int32 = CREATE_NUMERIC(int32_t, 30, ""); -//DECLARE_NUMERIC_TYPE(Int64, int64_t, 40); +inline static const ResolvedType Int64 = CREATE_NUMERIC(int64_t, 40, ""); + inline static const ResolvedType Uint8 = CREATE_NUMERIC(uint8_t, 10, "u"); inline static const ResolvedType Uint16 = CREATE_NUMERIC(uint16_t, 20, "u"); inline static const ResolvedType Uint32 = CREATE_NUMERIC(uint32_t, 30, "u"); -//DECLARE_NUMERIC_TYPE(Uint64, uint64_t, 40); +inline static const ResolvedType Uint64 = CREATE_NUMERIC(uint64_t, 40, "u"); + inline static const ResolvedType Float = CREATE_NUMERIC(float, 50, "f"); inline static const ResolvedType Double = CREATE_NUMERIC(double, 60, ""); diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index cda427e795..6f4e6f2aec 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -25,7 +25,7 @@ using namespace GeNN::Transpiler; 
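The std::set to std::multiset switch in the parser diff that follows is the crux of 64-bit support: multi-token specifiers such as "long long" yield the same lexeme twice, and a std::set silently deduplicates, making "long long" indistinguishable from plain "long". A two-line demonstration:

    #include <set>
    #include <string>

    std::multiset<std::string> multi{"long", "long"}; // size() == 2: matches "long long"
    std::set<std::string>      dedup{"long", "long"}; // size() == 1: looks like "long"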
//--------------------------------------------------------------------------- namespace { -const std::map, Type::ResolvedType> numericTypeSpecifiers{ +const std::map, Type::ResolvedType> numericTypeSpecifiers{ {{"char"}, Type::Int8}, {{"int8_t"}, Type::Int8}, @@ -51,6 +51,25 @@ const std::map, Type::ResolvedType> numericTypeSpecifiers{ {{"unsigned", "int"}, Type::Uint32}, {{"uint32_t"}, Type::Uint32}, + // **NOTE** GeNN uses LP64 data model where longs are 64-bit (unlike Windows) + {{"long"}, Type::Int64}, + {{"long", "int"}, Type::Int64}, + {{"signed", "long"}, Type::Int64}, + {{"signed", "long", "int"}, Type::Int64}, + {{"long", "long"}, Type::Int64}, + {{"long", "long", "int"}, Type::Int64}, + {{"signed", "long", "long"}, Type::Int64}, + {{"signed", "long", "long", "int"}, Type::Int64}, + {{"int64_t"}, Type::Int64}, + + // **NOTE** GeNN uses LP64 data model where longs are 64-bit (unlike Windows) + {{"unsigned", "long"}, Type::Uint64}, + {{"unsigned", "long", "int"}, Type::Uint64}, + {{"unsigned", "long", "long"}, Type::Uint64}, + {{"unsigned", "long", "long", "int"}, Type::Uint64}, + {{"uint64_t"}, Type::Uint64}, + {{"size_t"}, Type::Uint64}, + {{"float"}, Type::Float}, {{"double"}, Type::Double}}; @@ -173,7 +192,7 @@ class ParserState }; // **THINK** could leave unresolved -Type::ResolvedType getNumericType(const std::set &typeSpecifiers, const Type::TypeContext &context) +Type::ResolvedType getNumericType(const std::multiset &typeSpecifiers, const Type::TypeContext &context) { // If type is numeric, return const auto type = numericTypeSpecifiers.find(typeSpecifiers); @@ -241,7 +260,7 @@ GeNN::Type::ResolvedType parseDeclarationSpecifiers(ParserState &parserState) { using namespace GeNN::Type; - std::set typeSpecifiers; + std::multiset typeSpecifiers; std::set typeQualifiers; std::vector> pointerTypeQualifiers; @@ -262,12 +281,17 @@ GeNN::Type::ResolvedType parseDeclarationSpecifiers(ParserState &parserState) if(!pointerTypeQualifiers.empty()) { parserState.error(parserState.previous(), "invalid type specifier"); } - else if(!typeSpecifiers.insert(parserState.previous().lexeme).second) { - parserState.error(parserState.previous(), "duplicate type specifier"); + else { + typeSpecifiers.insert(parserState.previous().lexeme); } } } while(parserState.match({Token::Type::TYPE_QUALIFIER, Token::Type::TYPE_SPECIFIER, Token::Type::STAR})); + // If no type specifiers are found + if(typeSpecifiers.empty()) { + parserState.error(parserState.peek(), "missing type specifier"); + throw ParseError(); + } // Lookup numeric type Type::ResolvedType type = getNumericType(typeSpecifiers, parserState.getContext()); @@ -914,11 +938,9 @@ Statement::StatementList parseBlockItemList(const std::vector &tokens, co const GeNN::Type::ResolvedType parseNumericType(const std::vector &tokens, const Type::TypeContext &context, ErrorHandlerBase &errorHandler) { ParserState parserState(tokens, context, errorHandler); - std::set typeSpecifiers; + std::multiset typeSpecifiers; while(parserState.match(Token::Type::TYPE_SPECIFIER)) { - if(!typeSpecifiers.insert(parserState.previous().lexeme).second) { - parserState.error(parserState.previous(), "duplicate type specifier"); - } + typeSpecifiers.insert(parserState.previous().lexeme); }; // Return numeric type diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc index bf8f97f696..6039600993 100644 --- a/src/genn/genn/transpiler/scanner.cc +++ b/src/genn/genn/transpiler/scanner.cc @@ -50,7 +50,10 @@ const std::unordered_map keywords{ {"uint16_t", 
Token::Type::TYPE_SPECIFIER}, {"int16_t", Token::Type::TYPE_SPECIFIER}, {"uint32_t", Token::Type::TYPE_SPECIFIER}, + {"uint64_t", Token::Type::TYPE_SPECIFIER}, {"int32_t", Token::Type::TYPE_SPECIFIER}, + {"int64_t", Token::Type::TYPE_SPECIFIER}, + {"size_t", Token::Type::TYPE_SPECIFIER}, {"bool", Token::Type::TYPE_SPECIFIER}, {"scalar", Token::Type::TYPE_SPECIFIER}}; //--------------------------------------------------------------------------- From f4c4424ead6dee21fee45bf25ec34adf31fa245b Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 6 Jul 2023 10:01:58 +0100 Subject: [PATCH 334/725] improved error message in getNumericType --- src/genn/genn/transpiler/parser.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 6f4e6f2aec..b54bda2fe4 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -26,6 +26,8 @@ using namespace GeNN::Transpiler; namespace { const std::map, Type::ResolvedType> numericTypeSpecifiers{ + {{"bool"}, Type::Bool}, + {{"char"}, Type::Int8}, {{"int8_t"}, Type::Int8}, @@ -208,8 +210,11 @@ Type::ResolvedType getNumericType(const std::multiset &typeSpecifie } } - // **TODO** improve error - throw std::runtime_error("Unknown numeric type specifier"); + // Generate string representation of type specifier and give error + std::ostringstream typeSpecifiersString; + std::copy(typeSpecifiers.cbegin(), typeSpecifiers.cend(), + std::ostream_iterator(typeSpecifiersString, " ")); + throw std::runtime_error("Unknown numeric type specifier '" + typeSpecifiersString.str() + "'"); } } From c0e01419e3ff275faa5d15c6427822684bf466b1 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 6 Jul 2023 10:12:03 +0100 Subject: [PATCH 335/725] corrected casting logic - can cast const off value type --- src/genn/genn/transpiler/typeChecker.cc | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index aca67b16ba..d84eb24ee4 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -376,12 +376,6 @@ class Visitor : public Expression::Visitor, public Statement::Visitor // Evaluate type of expression we're casting const auto rightType = evaluateType(cast.getExpression()); - // If const is being removed - if (!checkForConstRemoval(rightType, cast.getType())) { - m_ErrorHandler.error(cast.getClosingParen(), "Invalid operand types '" + cast.getType().getName() + "' and '" + rightType.getName()); - throw TypeCheckError(); - } - const auto resultType = std::visit( Utils::Overload{ // If types are numeric, any cast goes @@ -395,15 +389,19 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } }, // Otherwise, if we're trying to cast pointer to pointer - [&cast](const Type::ResolvedType::Pointer &rightPointer, const Type::ResolvedType::Pointer &castPointer) -> std::optional + [&cast, &rightType](const Type::ResolvedType::Pointer &rightPointer, const Type::ResolvedType::Pointer &castPointer) -> std::optional { - // Check that value type at the end matches - if (checkPointerTypeAssignement(*rightPointer.valueType, *castPointer.valueType)) { - return cast.getType(); + // Check that value type at the end matches + if (!checkPointerTypeAssignement(*rightPointer.valueType, *castPointer.valueType)) { + return std::nullopt; } - else { + // Check we're not trying to maketype less const + else 
if(!checkForConstRemoval(rightType, cast.getType())) { return std::nullopt; } + else { + return cast.getType(); + } }, // Otherwise, pointers can't be cast to non-pointers and vice versa [](auto, auto) -> std::optional From 1d40a22b6275f60a73da31637c116cfe7bec1c36 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 6 Jul 2023 10:20:18 +0100 Subject: [PATCH 336/725] moved host RNG functions into standard library --- .../genn/code_generator/standardLibrary.h | 6 ++- .../backends/single_threaded_cpu/backend.cc | 30 +++----------- .../genn/code_generator/standardLibrary.cc | 41 +++++++++++++++++-- 3 files changed, 49 insertions(+), 28 deletions(-) diff --git a/include/genn/genn/code_generator/standardLibrary.h b/include/genn/genn/code_generator/standardLibrary.h index cf51594eaa..f1f6cfd7c9 100644 --- a/include/genn/genn/code_generator/standardLibrary.h +++ b/include/genn/genn/code_generator/standardLibrary.h @@ -8,5 +8,9 @@ //--------------------------------------------------------------------------- namespace GeNN::CodeGenerator::StandardLibrary { -const EnvironmentLibrary::Library &getFunctions(); +//! Get standard maths functions +const EnvironmentLibrary::Library &getMathsFunctions(); + +//! Get std::random based host RNG functions +const EnvironmentLibrary::Library &getHostRNGFunctions(const Type::ResolvedType &precision); } // namespace GeNN::CodeGenerator::StandardLibrary diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 52e9af853f..e3cc5001e2 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -19,24 +19,6 @@ using namespace GeNN::Transpiler; //-------------------------------------------------------------------------- namespace { -const EnvironmentLibrary::Library cpuSinglePrecisionFunctions = { - {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Float, {}), "standardUniformDistribution(hostRNG)"}}, - {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Float, {}), "standardNormalDistribution(hostRNG)"}}, - {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Float, {}), "standardExponentialDistribution(hostRNG)"}}, - {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Float, {Type::Float, Type::Float}), "std::lognormal_distribution($(0), $(1))(hostRNG)"}}, - {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Float, {Type::Float}), "std::gamma_distribution($(0), 1.0f)(hostRNG)"}}, - {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Float}), "std::binomial_distribution($(0), $(1))(hostRNG)"}}, -}; - -const EnvironmentLibrary::Library cpuDoublePrecisionFunctions = { - {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Double, {}), "standardUniformDistribution(hostRNG)"}}, - {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Double, {}), "standardNormalDistribution(hostRNG)"}}, - {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Double, {}), "standardExponentialDistribution(hostRNG)"}}, - {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Double, {Type::Double, Type::Double}), "std::lognormal_distribution($(0), $(1))(hostRNG)"}}, - {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Double, {Type::Double}), "std::gamma_distribution($(0), 1.0)(hostRNG)"}}, - {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Double}), 
"std::binomial_distribution($(0), $(1))(hostRNG)"}}, -}; - //-------------------------------------------------------------------------- // Timer //-------------------------------------------------------------------------- @@ -127,7 +109,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host CodeStream neuronUpdate(neuronUpdateStream); // Begin environment with standard library - EnvironmentLibrary neuronUpdateEnv(neuronUpdate, StandardLibrary::getFunctions()); + EnvironmentLibrary neuronUpdateEnv(neuronUpdate, StandardLibrary::getMathsFunctions()); neuronUpdateEnv.getStream() << "void updateNeurons(" << modelMerged.getModel().getTimePrecision().getName() << " t"; if(modelMerged.getModel().isRecordingInUse()) { @@ -262,7 +244,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host groupEnv.add(Type::Uint32, "id", "i"); // Add RNG libray - EnvironmentLibrary rngEnv(groupEnv, (modelMerged.getModel().getPrecision() == Type::Float) ? cpuSinglePrecisionFunctions : cpuDoublePrecisionFunctions); + EnvironmentLibrary rngEnv(groupEnv, StandardLibrary::getHostRNGFunctions(modelMerged.getModel().getPrecision())); // Generate neuron update n.generateNeuronUpdate(*this, rngEnv, modelMerged, @@ -313,7 +295,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos CodeStream synapseUpdate(synapseUpdateStream); // Begin environment with standard library - EnvironmentLibrary synapseUpdateEnv(synapseUpdate, StandardLibrary::getFunctions()); + EnvironmentLibrary synapseUpdateEnv(synapseUpdate, StandardLibrary::getMathsFunctions()); synapseUpdateEnv.getStream() << "void updateSynapses(" << modelMerged.getModel().getTimePrecision().getName() << " t)"; { @@ -557,7 +539,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host CodeStream customUpdate(customUpdateStream); // Begin environment with standard library - EnvironmentLibrary customUpdateEnv(customUpdate, StandardLibrary::getFunctions()); + EnvironmentLibrary customUpdateEnv(customUpdate, StandardLibrary::getMathsFunctions()); // Loop through custom update groups for(const auto &g : customUpdateGroups) { @@ -849,8 +831,8 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler CodeStream init(initStream); // Begin environment with RNG library and standard library - EnvironmentLibrary rngEnv(init, (modelMerged.getModel().getPrecision() == Type::Float) ? cpuSinglePrecisionFunctions : cpuDoublePrecisionFunctions); - EnvironmentLibrary initEnv(rngEnv, StandardLibrary::getFunctions()); + EnvironmentLibrary rngEnv(init, StandardLibrary::getHostRNGFunctions(modelMerged.getModel().getPrecision())); + EnvironmentLibrary initEnv(rngEnv, StandardLibrary::getMathsFunctions()); initEnv.getStream() << "void initialize()"; diff --git a/src/genn/genn/code_generator/standardLibrary.cc b/src/genn/genn/code_generator/standardLibrary.cc index aed9c0f758..9a806ed8e5 100644 --- a/src/genn/genn/code_generator/standardLibrary.cc +++ b/src/genn/genn/code_generator/standardLibrary.cc @@ -9,6 +9,8 @@ namespace Type = GeNN::Type; +using namespace GeNN::CodeGenerator; + //--------------------------------------------------------------------------- // Macros //--------------------------------------------------------------------------- @@ -32,7 +34,7 @@ namespace template auto initLibraryTypes(Args&&... 
args)
 {
-    GeNN::CodeGenerator::EnvironmentLibrary::Library map;
+    EnvironmentLibrary::Library map;
     (map.emplace(std::forward<Args>(args)), ...);
     return map;
 }
@@ -111,6 +113,23 @@ const auto libraryTypes = initLibraryTypes(
     std::make_pair("printf", std::make_pair(Type::ResolvedType::createFunction(Type::Int32, {Type::Int8.addQualifier(Type::Qualifier::CONSTANT).createPointer()}, true), "printf($(0), $(@))")));
 }
 
+const EnvironmentLibrary::Library floatRandomFunctions = {
+    {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Float, {}), "standardUniformDistribution(hostRNG)"}},
+    {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Float, {}), "standardNormalDistribution(hostRNG)"}},
+    {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Float, {}), "standardExponentialDistribution(hostRNG)"}},
+    {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Float, {Type::Float, Type::Float}), "std::lognormal_distribution($(0), $(1))(hostRNG)"}},
+    {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Float, {Type::Float}), "std::gamma_distribution($(0), 1.0f)(hostRNG)"}},
+    {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Float}), "std::binomial_distribution($(0), $(1))(hostRNG)"}},
+};
+
+const EnvironmentLibrary::Library doubleRandomFunctions = {
+    {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Double, {}), "standardUniformDistribution(hostRNG)"}},
+    {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Double, {}), "standardNormalDistribution(hostRNG)"}},
+    {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Double, {}), "standardExponentialDistribution(hostRNG)"}},
+    {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Double, {Type::Double, Type::Double}), "std::lognormal_distribution($(0), $(1))(hostRNG)"}},
+    {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Double, {Type::Double}), "std::gamma_distribution($(0), 1.0)(hostRNG)"}},
+    {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Double}), "std::binomial_distribution($(0), $(1))(hostRNG)"}},
+};
 
 /*{,  {"frexp", "frexpf"},    // pointer arguments
@@ -122,8 +141,24 @@ const auto libraryTypes = initLibraryTypes(
 */
 //min, max, printf
-
-const GeNN::CodeGenerator::EnvironmentLibrary::Library &GeNN::CodeGenerator::StandardLibrary::getFunctions()
+//---------------------------------------------------------------------------
+// GeNN::CodeGenerator::StandardLibrary::FunctionTypes
+//---------------------------------------------------------------------------
+namespace GeNN::CodeGenerator::StandardLibrary
+{
+const EnvironmentLibrary::Library &getMathsFunctions()
 {
     return libraryTypes;
 }
+
+const EnvironmentLibrary::Library &getHostRNGFunctions(const Type::ResolvedType &precision)
+{
+    if(precision == Type::Float) {
+        return floatRandomFunctions;
+    }
+    else {
+        assert(precision == Type::Double);
+        return doubleRandomFunctions;
+    }
+}
+} // namespace GeNN::CodeGenerator::StandardLibrary

From 0ac914540dd1bba1bf2f8de6b1962fd92ba0d4d4 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 6 Jul 2023 11:05:56 +0100
Subject: [PATCH 337/725] empty substitutions are ok! 
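(Editorial note.) The assert(!str.empty()) removed below meant a name looked up
through EnvironmentFieldPolicy could never resolve to an empty string. A rough
sketch of the case this commit permits (the entry is illustrative, not taken
from the codebase):

    // an identifier may now be bound to an empty substitution; the
    // pretty-printer change below only treats a function identifier as a
    // call template when the resolved name is actually non-empty
    env.add(Type::ResolvedType::createFunction(Type::Void, {}), "barrier", "");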
--- include/genn/genn/code_generator/environment.h | 1 - src/genn/genn/transpiler/prettyPrinter.cc | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index ef2e9f9976..9d31c0f545 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -194,7 +194,6 @@ class EnvironmentFieldPolicy } // Otherwise, use value directly else { - assert(!str.empty()); return str; } } diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index dc41ec1dea..8d7588db0d 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -262,8 +262,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const auto &type = m_ResolvedTypes.at(&variable); std::string name = m_Environment.get().getName(variable.getName().lexeme, type); - // If identifier is function i.e. name is a function template - if (type.isFunction()) { + // If identifier is function and name isn't empty i.e. it contains a function template + if (type.isFunction() && !name.empty()) { // Check that there are call arguments on the stack assert(!m_CallArguments.empty()); From 916557bcd6c3ab5d2fc9f9c26266ba9d3737f9dd Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 6 Jul 2023 11:06:09 +0100 Subject: [PATCH 338/725] wrong type associated with uint16_t --- src/genn/genn/transpiler/parser.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index b54bda2fe4..b11ce05971 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -42,7 +42,7 @@ const std::map, Type::ResolvedType> numericTypeSpecif {{"unsigned", "short"}, Type::Uint16}, {{"unsigned", "short", "int"}, Type::Uint16}, - {{"uint16_t"}, Type::Uint8}, + {{"uint16_t"}, Type::Uint16}, {{"int"}, Type::Int32}, {{"signed"}, Type::Int32}, From 42bd07f553cdabc158da4e9c6f01360cfb8618d0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 6 Jul 2023 11:06:45 +0100 Subject: [PATCH 339/725] setup complete environment for SynapseConnectivityHostInitGroupMerged --- include/genn/genn/type.h | 4 +++ .../genn/code_generator/initGroupMerged.cc | 28 +++++++++++++------ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index a89f058234..98b11e66f0 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -350,6 +350,10 @@ inline static const ResolvedType Void = ResolvedType(); inline static const ResolvedType AddToPre = ResolvedType::createFunction(Void, {Uint32}); inline static const ResolvedType AddToPost = ResolvedType::createFunction(Void, {Uint32}); inline static const ResolvedType AddToPostDenDelay = ResolvedType::createFunction(Void, {Uint32, Uint32}); +inline static const ResolvedType AllocatePushPullEGP = ResolvedType::createFunction(Void, {Uint32}); + + +inline static const ResolvedType Assert = ResolvedType::createFunction(Void, {Bool}); //! 
Apply C type promotion rules to numeric type
 GENN_EXPORT ResolvedType getPromotedType(const ResolvedType &type);
 
diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc
index 7472487c81..856b2b2f55 100644
--- a/src/genn/genn/code_generator/initGroupMerged.cc
+++ b/src/genn/genn/code_generator/initGroupMerged.cc
@@ -3,6 +3,7 @@
 // GeNN code generator includes
 #include "code_generator/groupMergedTypeEnvironment.h"
 #include "code_generator/modelSpecMerged.h"
+#include "code_generator/standardLibrary.h"
 
 // GeNN transpiler includes
 #include "transpiler/errorHandler.h"
@@ -734,18 +735,27 @@ const std::string SynapseConnectivityHostInitGroupMerged::name = "SynapseConnect
 //-------------------------------------------------------------------------
 void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged)
 {
-
-    CodeStream::Scope b(env.getStream());
-    env.getStream() << "// merged synapse connectivity host init group " << getIndex() << std::endl;
-    env.getStream() << "for(unsigned int g = 0; g < " << getGroups().size() << "; g++)";
+    // Add standard library to environment
+    EnvironmentLibrary envStdLib(env, StandardLibrary::getMathsFunctions());
+
+    // Add host RNG functions to environment
+    EnvironmentLibrary envRandom(envStdLib, StandardLibrary::getHostRNGFunctions(modelMerged.getModel().getPrecision()));
+
+    // Add standard host assert function to environment
+    EnvironmentExternal envAssert(envRandom);
+    envAssert.add(Type::Assert, "assert", "assert($(0))");
+
+    CodeStream::Scope b(envAssert.getStream());
+    envAssert.getStream() << "// merged synapse connectivity host init group " << getIndex() << std::endl;
+    envAssert.getStream() << "for(unsigned int g = 0; g < " << getGroups().size() << "; g++)";
     {
-        CodeStream::Scope b(env.getStream());
+        CodeStream::Scope b(envAssert.getStream());
 
         // Get reference to group
-        env.getStream() << "const auto *group = &mergedSynapseConnectivityHostInitGroup" << getIndex() << "[g]; " << std::endl;
+        envAssert.getStream() << "const auto *group = &mergedSynapseConnectivityHostInitGroup" << getIndex() << "[g]; " << std::endl;
 
         // Create environment for group
-        EnvironmentGroupMergedField<SynapseConnectivityHostInitGroupMerged> groupEnv(env, *this);
+        EnvironmentGroupMergedField<SynapseConnectivityHostInitGroupMerged> groupEnv(envAssert, *this);
 
         const auto &connectInit = getArchetype().getConnectivityInitialiser();
 
         // If matrix type is procedural then initialized connectivity init snippet will potentially be used with multiple threads per spike. 
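// (Editorial sketch, not part of the patch.) With the layering above, code in a
// host init snippet resolves identifiers through one chain of environments:
// groupEnv -> envAssert -> envRandom -> envStdLib, so maths functions, the
// gennrand_* host RNG functions and assert() are all visible. The $(n) templates
// then expand on printing; assuming hostRNG is the std::mt19937-style generator
// declared by the generated runner, for example:
//
//     snippet:   const size_t n = gennrand_binomial(remaining, prob);
//     generated: const size_t n = std::binomial_distribution(remaining, prob)(hostRNG);
//     snippet:   assert(n <= remaining);
//     generated: assert(n <= remaining);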
@@ -812,7 +822,7 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac loc, "$(0)", "group->"); // Add substitution - groupEnv.add(Type::ResolvedType::createFunction(Type::Void, {Type::Uint32}), "allocate" + egp.name, allocStream.str()); + groupEnv.add(Type::AllocatePushPullEGP, "allocate" + egp.name, allocStream.str()); // Generate code to push this EGP with count specified by $(0) std::stringstream pushStream; @@ -823,7 +833,7 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac // Add substitution - groupEnv.add(Type::ResolvedType::createFunction(Type::Void, {Type::Uint32}), "push" + egp.name, pushStream.str()); + groupEnv.add(Type::AllocatePushPullEGP, "push" + egp.name, pushStream.str()); } } Transpiler::ErrorHandler errorHandler("Synapse group '" + getArchetype().getName() + "' sparse connectivity host init code"); From 3031d264c0d19e22119ff62bceba00abf1972609 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 6 Jul 2023 11:14:32 +0100 Subject: [PATCH 340/725] upgraded FixedNumberTotalWithReplacement host code --- .../genn/genn/initSparseConnectivitySnippet.h | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/include/genn/genn/initSparseConnectivitySnippet.h b/include/genn/genn/initSparseConnectivitySnippet.h index 0f1ba57153..5ed77b014f 100644 --- a/include/genn/genn/initSparseConnectivitySnippet.h +++ b/include/genn/genn/initSparseConnectivitySnippet.h @@ -294,34 +294,32 @@ class FixedNumberTotalWithReplacement : public Base SET_HOST_INIT_CODE( "// Allocate pre-calculated row length array\n" - "$(allocatepreCalcRowLength, $(num_pre) * $(num_threads));\n" + "allocatepreCalcRowLength(num_pre * num_threads);\n" "// Calculate row lengths\n" - "const size_t numPostPerThread = ($(num_post) + $(num_threads) - 1) / $(num_threads);\n" - "const size_t leftOverNeurons = $(num_post) % numPostPerThread;\n" - "size_t remainingConnections = $(total);\n" - "size_t matrixSize = (size_t)$(num_pre) * (size_t)$(num_post);\n" - "uint16_t *subRowLengths = $(preCalcRowLength);\n" + "const size_t numPostPerThread = (num_post + num_threads - 1) / num_threads;\n" + "const size_t leftOverNeurons = num_post % numPostPerThread;\n" + "size_t remainingConnections = total;\n" + "size_t matrixSize = (size_t)num_pre * (size_t)num_post;\n" + "uint16_t *subRowLengths = preCalcRowLength;\n" "// Loop through rows\n" - "for(size_t i = 0; i < $(num_pre); i++) {\n" - " const bool lastPre = (i == ($(num_pre) - 1));\n" + "for(size_t i = 0; i < num_pre; i++) {\n" + " const bool lastPre = (i == (num_pre - 1));\n" " // Loop through subrows\n" - " for(size_t j = 0; j < $(num_threads); j++) {\n" - " const bool lastSubRow = (j == ($(num_threads) - 1));\n" + " for(size_t j = 0; j < num_threads; j++) {\n" + " const bool lastSubRow = (j == (num_threads - 1));\n" " // If this isn't the last sub-row of the matrix\n" " if(!lastPre || ! lastSubRow) {\n" " // Get length of this subrow\n" " const unsigned int numSubRowNeurons = (leftOverNeurons != 0 && lastSubRow) ? 
leftOverNeurons : numPostPerThread;\n"
        "            // Calculate probability\n"
        "            const double probability = (double)numSubRowNeurons / (double)matrixSize;\n"
-        "            // Create distribution to sample row length\n"
-        "            std::binomial_distribution<size_t> rowLengthDist(remainingConnections, probability);\n"
        "            // Sample row length;\n"
-        "            const size_t subRowLength = rowLengthDist($(rng));\n"
+        "            const size_t subRowLength = gennrand_binomial(remainingConnections, probability);\n"
        "            // Update counters\n"
        "            remainingConnections -= subRowLength;\n"
        "            matrixSize -= numSubRowNeurons;\n"
        "            // Add row length to array\n"
-        "            assert(subRowLength < std::numeric_limits<uint16_t>::max());\n"
+        "            assert(subRowLength < 0xFFFF);\n"
        "            *subRowLengths++ = (uint16_t)subRowLength;\n"
        "        }\n"
        "    }\n"
@@ -329,7 +327,7 @@ class FixedNumberTotalWithReplacement : public Base
         "// Insert remaining connections into last sub-row\n"
         "*subRowLengths = (uint16_t)remainingConnections;\n"
         "// Push populated row length array\n"
-        "$(pushpreCalcRowLength, $(num_pre) * $(num_threads));\n");
+        "pushpreCalcRowLength(num_pre * num_threads);\n");
 
     SET_CALC_MAX_ROW_LENGTH_FUNC(
         [](unsigned int numPre, unsigned int numPost, const std::unordered_map<std::string, double> &pars)

From b1c09e4634dde0dfb25e6e506b6e3597b72baaf2 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 6 Jul 2023 11:17:22 +0100
Subject: [PATCH 341/725] stick _ in front of register-cached variables and
 neuron additional input variables

---
 include/genn/genn/code_generator/environment.h          | 8 ++++----
 src/genn/genn/code_generator/neuronUpdateGroupMerged.cc | 4 ++--
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h
index 9d31c0f545..8a1523de55 100644
--- a/include/genn/genn/code_generator/environment.h
+++ b/include/genn/genn/code_generator/environment.h
@@ -798,7 +798,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P
             if(v.access & VarAccessMode::READ_ONLY) {
                 getContextStream() << "const ";
             }
-            getContextStream() << resolvedType.getName() << " " << m_LocalPrefix << v.name;
+            getContextStream() << resolvedType.getName() << " _" << m_LocalPrefix << v.name;
 
             // If this isn't a reduction, read value from memory
             // **NOTE** by not initialising these variables for reductions,
@@ -817,7 +817,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P
             // If variables are read-write
             if(v.access & VarAccessMode::READ_WRITE) {
                 getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << printSubs(getWriteIndex(m_Group.get(), v), *this) << "]";
-                getContextStream() << " = " << m_LocalPrefix << v.name << ";" << std::endl;
+                getContextStream() << " = _" << m_LocalPrefix << v.name << ";" << std::endl;
             }
         }
     }
@@ -859,8 +859,8 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P
             // Set flag to indicate that variable has been referenced
             var->second.first = true;
 
-            // Add local prefix to variable name
-            return m_LocalPrefix + name;
+            // Add underscore and local prefix to variable name
+            return "_" + m_LocalPrefix + name;
         }
     }
 
diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
index 553a6fc306..b0b00cbec4 100644
--- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc
@@ -493,8 +493,8 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E
     // **NOTE** arbitrary code in param value to be 
deprecated
     for (const auto &v : nm->getAdditionalInputVars()) {
         const auto resolvedType = v.type.resolve(getTypeContext());
-        neuronEnv.add(resolvedType, v.name, v.name,
-                      {neuronEnv.addInitialiser(resolvedType.getName() + " " + v.name + " = " + v.value + ";")});
+        neuronEnv.add(resolvedType, v.name, "_" + v.name,
+                      {neuronEnv.addInitialiser(resolvedType.getName() + " _" + v.name + " = " + v.value + ";")});
     }
 
     // Substitute parameter and derived parameter names

From d548108d669ed0d785b5afbce774fc1ab1e53509 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 6 Jul 2023 11:17:39 +0100
Subject: [PATCH 342/725] no need for [] around index

---
 src/genn/genn/code_generator/initGroupMerged.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc
index 856b2b2f55..1e5089d1f8 100644
--- a/src/genn/genn/code_generator/initGroupMerged.cc
+++ b/src/genn/genn/code_generator/initGroupMerged.cc
@@ -791,7 +791,7 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac
             groupEnv.addField(pointerType, egp.name,
                               pointerToPointerType, egp.name,
                               [egp](const auto &g, size_t) { return "&" + egp.name + g.getName(); },
-                              "[0]", GroupMergedFieldType::HOST_DYNAMIC);
+                              "0", GroupMergedFieldType::HOST_DYNAMIC);
 
             // If backend requires separate device variables, add additional (private) field
             if(!backend.getDeviceVarPrefix().empty()) {

From 9fea8215441fa1411d1719a67a825e9597b15b57 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 6 Jul 2023 11:35:04 +0100
Subject: [PATCH 343/725] setup logging BEFORE calling ``modelDefinition``

---
 src/genn/generator/generator.cc | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/genn/generator/generator.cc b/src/genn/generator/generator.cc
index daaa20794a..d04cd3b4f4 100644
--- a/src/genn/generator/generator.cc
+++ b/src/genn/generator/generator.cc
@@ -45,16 +45,16 @@ int main(int argc,     //!< number of arguments; expected to be 3
     const filesystem::path targetPath(argv[2]);
     const bool forceRebuild = (std::stoi(argv[3]) != 0);
 
-    // Create model
-    // **NOTE** casting to external-facing model to hide model's internals
-    ModelSpecInternal model;
-    modelDefinition(static_cast<ModelSpec&>(std::ref(model)));
-
     // Initialise logging, appending all to console
     plog::ConsoleAppender<plog::TxtFormatter> consoleAppender;
     Logging::init(GENN_PREFERENCES.logLevel, GENN_PREFERENCES.logLevel, GENN_PREFERENCES.logLevel,
                   &consoleAppender, &consoleAppender, &consoleAppender);
 
+    // Create model
+    // **NOTE** casting to external-facing model to hide model's internals
+    ModelSpecInternal model;
+    modelDefinition(static_cast<ModelSpec&>(std::ref(model)));
+
     // Finalize model
     model.finalise();
 

From 3ecf4aca3abc266f0ab1c5d5436e55d59b60a2c4 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Thu, 6 Jul 2023 11:35:40 +0100
Subject: [PATCH 344/725] check errorState after scanning

---
 src/genn/genn/gennUtils.cc | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/genn/genn/gennUtils.cc b/src/genn/genn/gennUtils.cc
index 61353c405f..41a575423f 100644
--- a/src/genn/genn/gennUtils.cc
+++ b/src/genn/genn/gennUtils.cc
@@ -70,7 +70,11 @@ std::vector<Transpiler::Token> scanCode(const std::string &code, const std::string &errorContext)
 
     // Scan code string and return tokens
     Transpiler::ErrorHandler errorHandler(errorContext);
-    return Transpiler::Scanner::scanSource(upgradedCode, errorHandler);
+    const auto tokens = Transpiler::Scanner::scanSource(upgradedCode, errorHandler);
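    // (Editorial note.) Scanner::scanSource reports problems through the
    // ErrorHandler rather than by throwing, so the token vector is only
    // trustworthy once the handler has been checked; the test added below
    // turns a silent scan failure into a std::runtime_error that names the
    // snippet being scanned.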
if(errorHandler.hasError()) { + throw std::runtime_error("Error scanning " + errorContext); + } + return tokens; } //-------------------------------------------------------------------------- bool areTokensEmpty(const std::vector &tokens) From ec97bcbd0e182865036eb8f0d769d88b0a0ff742 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 6 Jul 2023 11:35:53 +0100 Subject: [PATCH 345/725] don't allow identifiers to start with _ --- src/genn/genn/transpiler/scanner.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/genn/genn/transpiler/scanner.cc b/src/genn/genn/transpiler/scanner.cc index 6039600993..ca397aba7a 100644 --- a/src/genn/genn/transpiler/scanner.cc +++ b/src/genn/genn/transpiler/scanner.cc @@ -428,11 +428,11 @@ void scanToken(ScanState &scanState, std::vector &tokens) scanNumber(c, scanState, tokens); } // Otherwise, scan identifier - else if(std::isalpha(c) || c == '_') { + else if(std::isalpha(c)) { scanIdentifier(scanState, tokens); } else { - scanState.error("Unexpected character."); + scanState.error("Unexpected character '" + std::string{c} + "'."); } } } From 5a1a21c52d7a6eb641f1fc88870624aa8547afd3 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 6 Jul 2023 12:50:17 +0100 Subject: [PATCH 346/725] start hacking away at SIMT --- include/genn/backends/cuda/backend.h | 122 +++++---- include/genn/backends/opencl/backend.h | 3 - .../backends/single_threaded_cpu/backend.h | 4 - .../genn/genn/code_generator/backendBase.h | 6 - .../genn/genn/code_generator/backendSIMT.h | 4 +- src/genn/backends/cuda/backend.cc | 253 +++++++++--------- src/genn/backends/opencl/backend.cc | 5 - .../backends/single_threaded_cpu/backend.cc | 8 +- src/genn/genn/code_generator/backendSIMT.cc | 191 +++++++------ .../code_generator/neuronUpdateGroupMerged.cc | 48 ++-- 10 files changed, 326 insertions(+), 318 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index 819b1fe989..b56653e091 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -21,7 +21,6 @@ // GeNN code generator includes #include "code_generator/backendSIMT.h" #include "code_generator/codeStream.h" -#include "code_generator/substitutions.h" // Forward declarations namespace filesystem @@ -133,27 +132,27 @@ class BACKEND_EXPORT Backend : public BackendSIMT // CodeGenerator::BackendSIMT virtuals //-------------------------------------------------------------------------- //! On some older devices, shared memory atomics are actually slower than global memory atomics so should be avoided - virtual bool areSharedMemAtomicsSlow() const override; + virtual bool areSharedMemAtomicsSlow() const final; //! Get the prefix to use for shared memory variables - virtual std::string getSharedPrefix() const override{ return "__shared__ "; } + virtual std::string getSharedPrefix() const final{ return "__shared__ "; } //! Get the ID of the current thread within the threadblock - virtual std::string getThreadID(unsigned int axis = 0) const override; + virtual std::string getThreadID(unsigned int axis = 0) const final; //! Get the ID of the current thread block - virtual std::string getBlockID(unsigned int axis = 0) const override; + virtual std::string getBlockID(unsigned int axis = 0) const final; //! Get the name of the count-leading-zeros function - virtual std::string getCLZ() const override { return "__clz"; } + virtual std::string getCLZ() const final { return "__clz"; } //! 
Get name of atomic operation - virtual std::string getAtomic(const Type::NumericBase *type, const Type::TypeContext &typeContext, - AtomicOperation op = AtomicOperation::ADD, - AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const override; + virtual std::string getAtomic(const Type::ResolvedType &type, + AtomicOperation op = AtomicOperation::ADD, + AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const final; //! Generate a shared memory barrier - virtual void genSharedMemBarrier(CodeStream &os) const override; + virtual void genSharedMemBarrier(CodeStream &os) const final; //! For SIMT backends which initialize RNGs on device, initialize population RNG with specified seed and sequence virtual void genPopulationRNGInit(CodeStream &os, const std::string &globalRNG, const std::string &seed, const std::string &sequence) const final; @@ -163,78 +162,82 @@ class BACKEND_EXPORT Backend : public BackendSIMT //! If required, generate a postamble for population RNG /*! For example, in OpenCL, this is used to write local RNG state back to global memory*/ - virtual void genPopulationRNGPostamble(CodeStream &os, const std::string &globalRNG) const override; + virtual void genPopulationRNGPostamble(CodeStream &os, const std::string &globalRNG) const final; //! Generate code to skip ahead local copy of global RNG virtual std::string genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const final; + //! Get type of population RNG + virtual Type::ResolvedType getPopulationRNGType() const final; + //-------------------------------------------------------------------------- // CodeGenerator::BackendBase virtuals //-------------------------------------------------------------------------- - virtual void genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; + virtual void genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; - virtual void genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; + virtual void genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; - virtual void genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; + virtual void genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; - virtual void genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const override; + virtual void genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; - virtual void genDefinitionsPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const override; - virtual void genDefinitionsInternalPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const override; - virtual void genRunnerPreamble(CodeStream &os, const ModelSpecMerged &modelMerged, const MemAlloc &memAlloc) const override; - virtual void genAllocateMemPreamble(CodeStream &os, const ModelSpecMerged &modelMerged, const MemAlloc &memAlloc) const override; - virtual void genFreeMemPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const override; - virtual void genStepTimeFinalisePreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const override; + virtual void genDefinitionsPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final; + virtual void genDefinitionsInternalPreamble(CodeStream &os, const 
ModelSpecMerged &modelMerged) const final; + virtual void genRunnerPreamble(CodeStream &os, const ModelSpecMerged &modelMerged, const MemAlloc &memAlloc) const final; + virtual void genAllocateMemPreamble(CodeStream &os, const ModelSpecMerged &modelMerged, const MemAlloc &memAlloc) const final; + virtual void genFreeMemPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final; + virtual void genStepTimeFinalisePreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final; //! Generate code to define a variable in the appropriate header file virtual void genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, - const Type::ValueBase *type, const std::string &name, VarLocation loc) const final; + const Type::ResolvedType &type, const std::string &name, VarLocation loc) const final; //! Generate code to instantiate a variable in the provided stream virtual void genVariableInstantiation(CodeStream &os, - const Type::ValueBase *type, const std::string &name, VarLocation loc) const final; + const Type::ResolvedType &type, const std::string &name, VarLocation loc) const final; //! Generate code to allocate variable with a size known at compile-time virtual void genVariableAllocation(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const final; //! Generate code to allocate variable with a size known at runtime virtual void genVariableDynamicAllocation(CodeStream &os, - const Type::Base *type, const std::string &name, VarLocation loc, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; + //! Generate code to free a variable virtual void genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const final; //! Generate code for pushing a variable with a size known at compile-time to the 'device' virtual void genVariablePush(CodeStream &os, - const Type::ValueBase *type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const final; //! Generate code for pulling a variable with a size known at compile-time from the 'device' virtual void genVariablePull(CodeStream &os, - const Type::ValueBase *type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, size_t count) const final; //! Generate code for pushing a variable's value in the current timestep to the 'device' virtual void genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, - const Type::ValueBase *type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, unsigned int batchSize) const final; //! Generate code for pulling a variable's value in the current timestep from the 'device' virtual void genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, - const Type::ValueBase *type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, unsigned int batchSize) const final; //! 
Generate code for pushing a variable with a size known at runtime to the 'device'
     virtual void genVariableDynamicPush(CodeStream &os, 
-                                        const Type::Base *type, const std::string &name, VarLocation loc, 
+                                        const Type::ResolvedType &type, const std::string &name, VarLocation loc, 
                                         const std::string &countVarName = "count", const std::string &prefix = "") const final;
 
     //! Generate code for pulling a variable with a size known at runtime from the 'device'
     virtual void genVariableDynamicPull(CodeStream &os, 
-                                        const Type::Base *type, const std::string &name, VarLocation loc, 
+                                        const Type::ResolvedType &type, const std::string &name, VarLocation loc, 
                                         const std::string &countVarName = "count", const std::string &prefix = "") const final;
 
     //! Generate code for pushing a new pointer to a dynamic variable into the merged group structure on 'device'
@@ -243,58 +246,59 @@ class BACKEND_EXPORT Backend : public BackendSIMT
                                         const std::string &egpName) const final;
 
     //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host
-    virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type) const override;
+    virtual std::string getMergedGroupFieldHostTypeName(const Type::ResolvedType &type) const = 0;
+
+    //! Generate a single RNG instance
+    /*! On single-threaded platforms this can be a standard RNG like M.T. but, on parallel platforms, it is likely to be a counter-based RNG */
+    virtual void genGlobalDeviceRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner,
+                                    CodeStream &allocations, CodeStream &free, MemAlloc &memAlloc) const = 0;
+
+    //! Generate an RNG with a state per population member
+    virtual void genPopulationRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner,
+                                  CodeStream &allocations, CodeStream &free,
+                                  const std::string &name, size_t count, MemAlloc &memAlloc) const = 0;
 
-    //! When generating merged structures what type to use for simulation RNGs
-    virtual const Type::ValueBase *getMergedGroupSimRNGType() const override;
-    
-    virtual void genGlobalDeviceRNG(CodeStream &definitions, CodeStream &definitionsInternal,
-                                    CodeStream &runner, CodeStream &allocations, CodeStream &free,
-                                    const Type::TypeContext &typeContext, MemAlloc &memAlloc) const override;
-    virtual void genPopulationRNG(CodeStream &definitions, CodeStream &definitionsInternal,
-                                  CodeStream &runner, CodeStream &allocations, CodeStream &free,
-                                  const Type::TypeContext &typeContext, const std::string &name, size_t count, MemAlloc &memAlloc) const override;
     virtual void genTimer(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner,
                           CodeStream &allocations, CodeStream &free, CodeStream &stepTimeFinalise,
-                          const std::string &name, bool updateInStepTime) const override;
+                          const std::string &name, bool updateInStepTime) const final;
 
     //! Generate code to return amount of free 'device' memory in bytes
-    virtual void genReturnFreeDeviceMemoryBytes(CodeStream &os) const override;
+    virtual void genReturnFreeDeviceMemoryBytes(CodeStream &os) const final;
 
     //! 
On backends which support it, generate a runtime assert - virtual void genAssert(CodeStream &os, const std::string &condition) const override; + virtual void genAssert(CodeStream &os, const std::string &condition) const final; - virtual void genMakefilePreamble(std::ostream &os) const override; - virtual void genMakefileLinkRule(std::ostream &os) const override; - virtual void genMakefileCompileRule(std::ostream &os) const override; + virtual void genMakefilePreamble(std::ostream &os) const final; + virtual void genMakefileLinkRule(std::ostream &os) const final; + virtual void genMakefileCompileRule(std::ostream &os) const final; - virtual void genMSBuildConfigProperties(std::ostream &os) const override; - virtual void genMSBuildImportProps(std::ostream &os) const override; - virtual void genMSBuildItemDefinitions(std::ostream &os) const override; - virtual void genMSBuildCompileModule(const std::string &moduleName, std::ostream &os) const override; - virtual void genMSBuildImportTarget(std::ostream &os) const override; + virtual void genMSBuildConfigProperties(std::ostream &os) const final; + virtual void genMSBuildImportProps(std::ostream &os) const final; + virtual void genMSBuildItemDefinitions(std::ostream &os) const final; + virtual void genMSBuildCompileModule(const std::string &moduleName, std::ostream &os) const final; + virtual void genMSBuildImportTarget(std::ostream &os) const final; //! Get backend-specific allocate memory parameters - virtual std::string getAllocateMemParams(const ModelSpecMerged &) const override; + virtual std::string getAllocateMemParams(const ModelSpecMerged &) const final; //! Different backends seed RNGs in different ways. Does this one initialise population RNGS on device? - virtual bool isPopulationRNGInitialisedOnDevice() const override { return true; } + virtual bool isPopulationRNGInitialisedOnDevice() const final { return true; } //! Backends which support batch-parallelism might require an additional host reduction phase after reduction kernels - virtual bool isHostReductionRequired() const override { return getPreferences().enableNCCLReductions; } + virtual bool isHostReductionRequired() const final { return getPreferences().enableNCCLReductions; } //! How many bytes of memory does 'device' have - virtual size_t getDeviceMemoryBytes() const override{ return m_ChosenDevice.totalGlobalMem; } + virtual size_t getDeviceMemoryBytes() const final{ return m_ChosenDevice.totalGlobalMem; } //! Some backends will have additional small, fast, memory spaces for read-only data which might //! Be well-suited to storing merged group structs. This method returns the prefix required to //! Place arrays in these and their size in preferential order - virtual MemorySpaces getMergedGroupMemorySpaces(const ModelSpecMerged &modelMerged) const override; + virtual MemorySpaces getMergedGroupMemorySpaces(const ModelSpecMerged &modelMerged) const final; - virtual bool supportsNamespace() const override { return true; }; + virtual bool supportsNamespace() const final { return true; }; //! 
Get hash digest of this backends identification and the preferences it has been configured with - virtual boost::uuids::detail::sha1::digest_type getHashDigest() const override; + virtual boost::uuids::detail::sha1::digest_type getHashDigest() const final; //-------------------------------------------------------------------------- // Public API diff --git a/include/genn/backends/opencl/backend.h b/include/genn/backends/opencl/backend.h index c5d88eab2c..9fe0088354 100644 --- a/include/genn/backends/opencl/backend.h +++ b/include/genn/backends/opencl/backend.h @@ -171,9 +171,6 @@ class BACKEND_EXPORT Backend : public BackendSIMT //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host virtual std::string getMergedGroupFieldHostTypeName(const Type::Base *type, const Type::TypeContext &context) const override; - //! When generating merged structures what type to use for simulation RNGs - virtual const Type::Base *getMergedGroupSimRNGType() const; - virtual void genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const override; virtual void genVariablePull(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, size_t count) const override; diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 6d9298b9dc..5333fba10c 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -118,9 +118,6 @@ class BACKEND_EXPORT Backend : public BackendBase //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host virtual std::string getMergedGroupFieldHostTypeName(const Type::ResolvedType &type) const final; - //! When generating merged structures what type to use for simulation RNGs - virtual std::optional getMergedGroupSimRNGType() const final; - virtual void genPopVariableInit(EnvironmentExternalBase &env, HandlerEnv handler) const final; virtual void genVariableInit(EnvironmentExternalBase &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const final; virtual void genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const final; @@ -159,7 +156,6 @@ class BACKEND_EXPORT Backend : public BackendBase virtual bool isGlobalHostRNGRequired(const ModelSpecMerged &modelMerged) const final; virtual bool isGlobalDeviceRNGRequired(const ModelSpecMerged &modelMerged) const final; - virtual bool isPopulationRNGRequired() const final { return false; } //! Different backends seed RNGs in different ways. Does this one initialise population RNGS on device? virtual bool isPopulationRNGInitialisedOnDevice() const final { return false; } diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index d6efc790af..292feee515 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -309,9 +309,6 @@ class GENN_EXPORT BackendBase //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host virtual std::string getMergedGroupFieldHostTypeName(const Type::ResolvedType &type) const = 0; - //! 
When generating merged structures what type to use for simulation RNGs - virtual std::optional getMergedGroupSimRNGType() const = 0; - virtual void genPopVariableInit(EnvironmentExternalBase &env, HandlerEnv handler) const = 0; virtual void genVariableInit(EnvironmentExternalBase &env, const std::string &count, const std::string &indexVarName, HandlerEnv handler) const = 0; virtual void genSparseSynapseVariableRowInit(EnvironmentExternalBase &env, HandlerEnv handler) const = 0; @@ -389,9 +386,6 @@ class GENN_EXPORT BackendBase //! Different backends use different RNGs for different things. Does this one require a global device RNG for the specified model? virtual bool isGlobalDeviceRNGRequired(const ModelSpecMerged &modelMerged) const = 0; - //! Different backends use different RNGs for different things. Does this one require population RNGs? - virtual bool isPopulationRNGRequired() const = 0; - //! Different backends seed RNGs in different ways. Does this one initialise population RNGS on device? virtual bool isPopulationRNGInitialisedOnDevice() const = 0; diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 5b2971e19d..2120596714 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -111,6 +111,9 @@ class GENN_EXPORT BackendSIMT : public BackendBase //! Generate code to skip ahead local copy of global RNG virtual std::string genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const = 0; + //! Get type of population RNG + virtual Type::ResolvedType getPopulationRNGType() const = 0; + //------------------------------------------------------------------------ // BackendBase virtuals //------------------------------------------------------------------------ @@ -141,7 +144,6 @@ class GENN_EXPORT BackendSIMT : public BackendBase virtual bool isGlobalHostRNGRequired(const ModelSpecMerged &modelMerged) const final; virtual bool isGlobalDeviceRNGRequired(const ModelSpecMerged &modelMerged) const final; - virtual bool isPopulationRNGRequired() const final { return true; } virtual bool isPostsynapticRemapRequired() const final { return true; } diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 6f1e67e775..d0714c1d67 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -16,7 +16,7 @@ #include "code_generator/codeStream.h" #include "code_generator/codeGenUtils.h" #include "code_generator/modelSpecMerged.h" -#include "code_generator/substitutions.h" +#include "code_generator/standardLibrary.h" // CUDA backend includes #include "utils.h" @@ -29,29 +29,29 @@ using namespace GeNN::CodeGenerator; //-------------------------------------------------------------------------- namespace { -const std::vector cudaSinglePrecisionFunctions = { - {"gennrand_uniform", 0, "curand_uniform($(rng))"}, - {"gennrand_normal", 0, "curand_normal($(rng))"}, - {"gennrand_exponential", 0, "exponentialDistFloat($(rng))"}, - {"gennrand_log_normal", 2, "curand_log_normal_float($(rng), $(0), $(1))"}, - {"gennrand_gamma", 1, "gammaDistFloat($(rng), $(0))"}, - {"gennrand_binomial", 2, "binomialDistFloat($(rng), $(0), $(1))"} +const EnvironmentLibrary::Library floatRandomFunctions = { + {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Float, {}), "curand_uniform($(rng))"}}, + {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Float, {}), "curand_normal($(rng))"}}, + {"gennrand_exponential", 
{Type::ResolvedType::createFunction(Type::Float, {}), "exponentialDistFloat($(rng))"}}, + {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Float, {Type::Float, Type::Float}), "curand_log_normal_float($(rng), $(0), $(1))"}}, + {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Float, {Type::Float}), "gammaDistFloat($(rng), $(0))"}}, + {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Float}), "binomialDistFloat($(rng), $(0), $(1))"}}, }; -//-------------------------------------------------------------------------- -const std::vector cudaDoublePrecisionFunctions = { - {"gennrand_uniform", 0, "curand_uniform_double($(rng))"}, - {"gennrand_normal", 0, "curand_normal_double($(rng))"}, - {"gennrand_exponential", 0, "exponentialDistDouble($(rng))"}, - {"gennrand_log_normal", 2, "curand_log_normal_double($(rng), $(0), $(1))"}, - {"gennrand_gamma", 1, "gammaDistDouble($(rng), $(0))"}, - {"gennrand_binomial", 2, "binomialDistDouble($(rng), $(0), $(1))"} + +const EnvironmentLibrary::Library doubleRandomFunctions = { + {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Double, {}), "curand_uniform_double($(rng))"}}, + {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Double, {}), "curand_normal_double($(rng))"}}, + {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Double, {}), "exponentialDistDouble($(rng))"}}, + {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Double, {Type::Double, Type::Double}), "curand_log_normal_double($(rng), $(0), $(1))"}}, + {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Double, {Type::Double}), "gammaDistDouble($(rng), $(0))"}}, + {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Double}), "binomialDistDouble($(rng), $(0), $(1))"}}, }; //-------------------------------------------------------------------------- // CUDADeviceType //-------------------------------------------------------------------------- -const Type::ResolvedType CURandState = Type::ResolvedType::createValue(); -const Type::ResolvedType CURandStatePhilox43210 = Type::ResolvedType::createValue(); +const Type::ResolvedType CURandState = Type::ResolvedType::createValue("curandState"); +const Type::ResolvedType CURandStatePhilox43210 = Type::ResolvedType::createValue("curandStatePhilox4_32_10_t"); //-------------------------------------------------------------------------- // Timer @@ -204,9 +204,15 @@ size_t getGroupStartIDSize(const std::vector &mergedGroups) }); } //----------------------------------------------------------------------- -const std::vector &getFunctionTemplates(const Type::NumericBase *precision) +const EnvironmentLibrary::Library &getRNGFunctions(const Type::ResolvedType &precision) { - return (precision->getName() == Type::Double::getInstance()->getName()) ? 
cudaDoublePrecisionFunctions : cudaSinglePrecisionFunctions; + if(precision == Type::Float) { + return floatRandomFunctions; + } + else { + assert(precision == Type::Double); + return doubleRandomFunctions; + } } //----------------------------------------------------------------------- std::string getNCCLReductionType(VarAccessMode mode) @@ -223,41 +229,34 @@ std::string getNCCLReductionType(VarAccessMode mode) } } //----------------------------------------------------------------------- -std::string getNCCLType(const Type::NumericBase *type, const Type::TypeContext &context) +std::string getNCCLType(const Type::ResolvedType &type) { - // If type is a numeric typedef, resolve it - const auto numericTypedef = dynamic_cast(type); - if (numericTypedef) { - type = numericTypedef->getResolvedType(context); - } + assert(type.isNumeric()); // Convert GeNN types to NCCL types - // **YUCK** Visitor pattern would really help here - if(dynamic_cast(type)) { + if(type == Type::Int8) { return "ncclInt8"; } - else if(dynamic_cast(type)) { + else if(type == Type::Uint8) { return "ncclUint8"; } - else if(dynamic_cast(type)) { + else if(type == Type::Int32) { return "ncclInt32"; } - else if(dynamic_cast(type)){ + else if(type == Type::Uint32){ return "ncclUint32"; } /*else if(type == "half") { return "ncclFloat16"; }*/ - else if(dynamic_cast(type)){ + else if(type == Type::Float){ return "ncclFloat32"; } - else if(dynamic_cast(type)) { + else if(type == Type::Double) { return "ncclFloat64"; } - else if (dynamic_cast(type)) { - } else { - throw std::runtime_error("Data type '" + type->getName() + "' unsupported by NCCL"); + throw std::runtime_error("Data type '" + type.getName() + "' unsupported by NCCL"); } } //----------------------------------------------------------------------- @@ -363,14 +362,12 @@ std::string Backend::getBlockID(unsigned int axis) const } } //-------------------------------------------------------------------------- -std::string Backend::getAtomic(const Type::NumericBase *type, const Type::TypeContext &typeContext, - AtomicOperation op, AtomicMemSpace) const +std::string Backend::getAtomic(const Type::ResolvedType &type, AtomicOperation op, AtomicMemSpace) const { // If operation is an atomic add - const std::string typeName = type->getResolvedName(typeContext); if(op == AtomicOperation::ADD) { - if(((getChosenCUDADevice().major < 2) && (typeName == Type::Float::getInstance()->getName())) - || (((getChosenCUDADevice().major < 6) || (getRuntimeVersion() < 8000)) && (typeName == Type::Double::getInstance()->getName()))) + if(((getChosenCUDADevice().major < 2) && (type == Type::Float)) + || (((getChosenCUDADevice().major < 6) || (getRuntimeVersion() < 8000)) && (type == Type::Double))) { return "atomicAddSW"; } @@ -380,7 +377,7 @@ std::string Backend::getAtomic(const Type::NumericBase *type, const Type::TypeCo // Otherwise, it's an atomic or else { assert(op == AtomicOperation::OR); - assert(typeName == Type::Uint32::getInstance()->getName() || typeName == Type::Int32::getInstance()->getName()); + assert(type == Type::Uint32 || type == Type::Int32); return "atomicOr"; } } @@ -412,124 +409,143 @@ std::string Backend::genGlobalRNGSkipAhead(CodeStream &os, const std::string &se return "&localRNG"; } //-------------------------------------------------------------------------- -void Backend::genNeuronUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +Type::ResolvedType Backend::getPopulationRNGType() const +{ + return CURandState; +} 
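// (Editorial sketch, not part of the patch.) Exposing the RNG state type through
// the backend means group-merged code no longer has to name curandState itself;
// a hypothetical caller could build a per-population RNG field along the lines of:
//
//     const auto rngPtrType = backend.getPopulationRNGType().createPointer();
//     groupEnv.addField(rngPtrType, "rng", ...);
//
// which is what replacing getMergedGroupSimRNGType() with getPopulationRNGType()
// is working towards.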
+//-------------------------------------------------------------------------- +void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); - // Generate struct definitions - modelMerged.genMergedNeuronUpdateGroupStructs(os, *this); - modelMerged.genMergedNeuronSpikeQueueUpdateStructs(os, *this); - modelMerged.genMergedNeuronPrevSpikeTimeUpdateStructs(os, *this); - - // Generate arrays of merged structs and functions to push them - genMergedStructArrayPush(os, modelMerged.getMergedNeuronSpikeQueueUpdateGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedNeuronUpdateGroups()); - - // Generate preamble - preambleHandler(os); + // Generate stream with neuron update code + std::ostringstream neuronUpdateStream; + CodeStream neuronUpdate(neuronUpdateStream); - // Generate data structure for accessing merged groups - // **NOTE** constant cache is preferentially given to synapse groups as, typically, more synapse kernels are launched - // so subtract constant memory requirements of synapse group start ids from total constant memory - const size_t synapseGroupStartIDSize = (getGroupStartIDSize(modelMerged.getMergedPresynapticUpdateGroups()) + - getGroupStartIDSize(modelMerged.getMergedPostsynapticUpdateGroups()) + - getGroupStartIDSize(modelMerged.getMergedSynapseDynamicsGroups())); - size_t totalConstMem = (getChosenDeviceSafeConstMemBytes() > synapseGroupStartIDSize) ? (getChosenDeviceSafeConstMemBytes() - synapseGroupStartIDSize) : 0; - genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedNeuronUpdateGroups(), - [this](const NeuronGroupInternal &ng){ return padKernelSize(ng.getNumNeurons(), KernelNeuronUpdate); }); - genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups(), - [this](const NeuronGroupInternal &ng){ return padKernelSize(ng.getNumNeurons(), KernelNeuronPrevSpikeTimeUpdate); }); - os << std::endl; + // Begin environment with standard library + EnvironmentLibrary neuronUpdateEnv(neuronUpdate, StandardLibrary::getMathsFunctions()); // If any neuron groups require their previous spike times updating size_t idNeuronPrevSpikeTimeUpdate = 0; - if(!modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups().empty()) { - os << "extern \"C\" __global__ void " << KernelNames[KernelNeuronPrevSpikeTimeUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)"; + //if(!modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups().empty()) { + neuronUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelNeuronPrevSpikeTimeUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(neuronUpdateEnv.getStream()); - Substitutions kernelSubs(getFunctionTemplates(model.getPrecision())); - os << "const unsigned int id = " << getKernelBlockSize(KernelNeuronPrevSpikeTimeUpdate) << " * blockIdx.x + threadIdx.x;" << std::endl; + EnvironmentExternal funcEnv(neuronUpdateEnv); + funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); + + funcEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelNeuronPrevSpikeTimeUpdate) << " * blockIdx.x + threadIdx.x;" << std::endl; if(model.getBatchSize() > 1) { - os << "const unsigned int batch = blockIdx.y;" << std::endl; + funcEnv.getStream() << "const unsigned 
int batch = blockIdx.y;" << std::endl; + funcEnv.add(Type::Uint32.addConst(), "batch", "batch"); + } + else { + funcEnv.add(Type::Uint32.addConst(), "batch", "0"); } - kernelSubs.addVarSubstitution("t", "t"); - genNeuronPrevSpikeTimeUpdateKernel(os, kernelSubs, modelMerged, idNeuronPrevSpikeTimeUpdate); + genNeuronPrevSpikeTimeUpdateKernel(funcEnv, modelMerged, idNeuronPrevSpikeTimeUpdate); } - os << std::endl; - } + neuronUpdateEnv.getStream() << std::endl; + //} // Generate reset kernel to be run before the neuron kernel size_t idNeuronSpikeQueueUpdate = 0; - os << "extern \"C\" __global__ void " << KernelNames[KernelNeuronSpikeQueueUpdate] << "()"; + neuronUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelNeuronSpikeQueueUpdate] << "()"; { - CodeStream::Scope b(os); + CodeStream::Scope b(neuronUpdateEnv.getStream()); - os << "const unsigned int id = " << getKernelBlockSize(KernelNeuronSpikeQueueUpdate) << " * blockIdx.x + threadIdx.x;" << std::endl; + neuronUpdateEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelNeuronSpikeQueueUpdate) << " * blockIdx.x + threadIdx.x;" << std::endl; - genNeuronSpikeQueueUpdateKernel(os, modelMerged, idNeuronSpikeQueueUpdate); + genNeuronSpikeQueueUpdateKernel(neuronUpdateEnv, modelMerged, idNeuronSpikeQueueUpdate); } - os << std::endl; + neuronUpdateEnv.getStream() << std::endl; size_t idStart = 0; - os << "extern \"C\" __global__ void " << KernelNames[KernelNeuronUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t"; + neuronUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelNeuronUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t"; if(model.isRecordingInUse()) { - os << ", unsigned int recordingTimestep"; + neuronUpdateEnv.getStream() << ", unsigned int recordingTimestep"; } - os << ")" << std::endl; + neuronUpdateEnv.getStream() << ")" << std::endl; { - CodeStream::Scope b(os); + CodeStream::Scope b(neuronUpdateEnv.getStream()); - Substitutions kernelSubs(getFunctionTemplates(model.getPrecision())); - kernelSubs.addVarSubstitution("t", "t"); + EnvironmentExternal funcEnv(neuronUpdateEnv); + funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); - os << "const unsigned int id = " << getKernelBlockSize(KernelNeuronUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; + funcEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelNeuronUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; if(model.getBatchSize() > 1) { - os << "const unsigned int batch = blockIdx.y;" << std::endl; - kernelSubs.addVarSubstitution("batch", "batch"); + funcEnv.getStream() << "const unsigned int batch = blockIdx.y;" << std::endl; + funcEnv.add(Type::Uint32.addConst(), "batch", "batch"); } else { - kernelSubs.addVarSubstitution("batch", "0"); + funcEnv.add(Type::Uint32.addConst(), "batch", "0"); } - genNeuronUpdateKernel(os, kernelSubs, modelMerged, idStart); + genNeuronUpdateKernel(funcEnv, modelMerged, idStart); } - os << "void updateNeurons(" << modelMerged.getModel().getTimePrecision().getName() << " t"; + neuronUpdateEnv.getStream() << "void updateNeurons(" << modelMerged.getModel().getTimePrecision().getName() << " t"; if(model.isRecordingInUse()) { - os << ", unsigned int recordingTimestep"; + neuronUpdateEnv.getStream() << ", unsigned int recordingTimestep"; } - os << ")"; + neuronUpdateEnv.getStream() << ")"; { - CodeStream::Scope b(os); + CodeStream::Scope b(neuronUpdateEnv.getStream()); 
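        // (Editorial note.) idNeuronPrevSpikeTimeUpdate, idNeuronSpikeQueueUpdate
        // and idStart are the thread counts accumulated while the corresponding
        // kernel bodies were generated above; each launch below is guarded on a
        // non-zero count, so models without the corresponding groups emit no
        // kernel launch at all.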
if(idNeuronPrevSpikeTimeUpdate > 0) { - CodeStream::Scope b(os); - genKernelDimensions(os, KernelNeuronPrevSpikeTimeUpdate, idNeuronPrevSpikeTimeUpdate, model.getBatchSize()); - os << KernelNames[KernelNeuronPrevSpikeTimeUpdate] << "<<>>(t);" << std::endl; - os << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl; + CodeStream::Scope b(neuronUpdateEnv.getStream()); + genKernelDimensions(neuronUpdateEnv.getStream(), KernelNeuronPrevSpikeTimeUpdate, idNeuronPrevSpikeTimeUpdate, model.getBatchSize()); + neuronUpdateEnv.getStream() << KernelNames[KernelNeuronPrevSpikeTimeUpdate] << "<<>>(t);" << std::endl; + neuronUpdateEnv.getStream() << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl; } if(idNeuronSpikeQueueUpdate > 0) { - CodeStream::Scope b(os); - genKernelDimensions(os, KernelNeuronSpikeQueueUpdate, idNeuronSpikeQueueUpdate, 1); - os << KernelNames[KernelNeuronSpikeQueueUpdate] << "<<>>();" << std::endl; - os << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl; + CodeStream::Scope b(neuronUpdateEnv.getStream()); + genKernelDimensions(neuronUpdateEnv.getStream(), KernelNeuronSpikeQueueUpdate, idNeuronSpikeQueueUpdate, 1); + neuronUpdateEnv.getStream() << KernelNames[KernelNeuronSpikeQueueUpdate] << "<<>>();" << std::endl; + neuronUpdateEnv.getStream() << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl; } if(idStart > 0) { - CodeStream::Scope b(os); + CodeStream::Scope b(neuronUpdateEnv.getStream()); - Timer t(os, "neuronUpdate", model.isTimingEnabled()); + Timer t(neuronUpdateEnv.getStream(), "neuronUpdate", model.isTimingEnabled()); - genKernelDimensions(os, KernelNeuronUpdate, idStart, model.getBatchSize()); - os << KernelNames[KernelNeuronUpdate] << "<<>>(t"; + genKernelDimensions(neuronUpdateEnv.getStream(), KernelNeuronUpdate, idStart, model.getBatchSize()); + neuronUpdateEnv.getStream() << KernelNames[KernelNeuronUpdate] << "<<>>(t"; if(model.isRecordingInUse()) { - os << ", recordingTimestep"; + neuronUpdateEnv.getStream() << ", recordingTimestep"; } - os << ");" << std::endl; - os << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl; + neuronUpdateEnv.getStream() << ");" << std::endl; + neuronUpdateEnv.getStream() << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl; } } + + + // Generate struct definitions + modelMerged.genMergedNeuronUpdateGroupStructs(os, *this); + modelMerged.genMergedNeuronSpikeQueueUpdateStructs(os, *this); + modelMerged.genMergedNeuronPrevSpikeTimeUpdateStructs(os, *this); + + // Generate arrays of merged structs and functions to push them + genMergedStructArrayPush(os, modelMerged.getMergedNeuronSpikeQueueUpdateGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedNeuronUpdateGroups()); + + // Generate preamble + preambleHandler(os); + + // Generate data structure for accessing merged groups + // **NOTE** constant cache is preferentially given to synapse groups as, typically, more synapse kernels are launched + // so subtract constant memory requirements of synapse group start ids from total constant memory + const size_t synapseGroupStartIDSize = (getGroupStartIDSize(modelMerged.getMergedPresynapticUpdateGroups()) + + getGroupStartIDSize(modelMerged.getMergedPostsynapticUpdateGroups()) + + getGroupStartIDSize(modelMerged.getMergedSynapseDynamicsGroups())); + size_t totalConstMem = (getChosenDeviceSafeConstMemBytes() > synapseGroupStartIDSize) ? 
+ genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedNeuronUpdateGroups(), + [this](const NeuronGroupInternal &ng){ return padKernelSize(ng.getNumNeurons(), KernelNeuronUpdate); }); + genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups(), + [this](const NeuronGroupInternal &ng){ return padKernelSize(ng.getNumNeurons(), KernelNeuronPrevSpikeTimeUpdate); }); + os << std::endl; + os << neuronUpdateStream.str(); } //-------------------------------------------------------------------------- void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const @@ -1803,13 +1819,8 @@ std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) con return type->getName(); } //-------------------------------------------------------------------------- -const Type::ValueBase *Backend::getMergedGroupSimRNGType() const -{ - return CURandState::getInstance(); -} -//-------------------------------------------------------------------------- void Backend::genGlobalDeviceRNG(CodeStream &, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &, CodeStream &, - const Type::TypeContext &typeContext, MemAlloc &memAlloc) const + MemAlloc &memAlloc) const { // Define global Philox RNG // **NOTE** this is actually accessed as a global so, unlike other variables, needs device global @@ -1818,14 +1829,14 @@ void Backend::genGlobalDeviceRNG(CodeStream &, CodeStream &definitionsInternal, // Implement global Philox RNG runner << "__device__ curandStatePhilox4_32_10_t d_rng;" << std::endl; - memAlloc += MemAlloc::device(CURandStatePhilox43210::getInstance()->getSizeBytes(typeContext)); + memAlloc += MemAlloc::device(CURandStatePhilox43210.getSize(getPointerBytes())); } //-------------------------------------------------------------------------- void Backend::genPopulationRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const Type::TypeContext &typeContext, const std::string &name, size_t count, MemAlloc &memAlloc) const + const std::string &name, size_t count, MemAlloc &memAlloc) const { // Create an array of XORWOW RNGs - genArray(definitions, definitionsInternal, runner, allocations, free, typeContext, name, VarLocation::DEVICE, count, memAlloc); + genArray(definitions, definitionsInternal, runner, allocations, free, CURandState, name, VarLocation::DEVICE, count, memAlloc); } //-------------------------------------------------------------------------- void Backend::genTimer(CodeStream &, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, diff --git a/src/genn/backends/opencl/backend.cc b/src/genn/backends/opencl/backend.cc index 734ba3b418..a9b872808e 100644 --- a/src/genn/backends/opencl/backend.cc +++ b/src/genn/backends/opencl/backend.cc @@ -2089,11 +2089,6 @@ std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type, con } } //-------------------------------------------------------------------------- -const Type::Base *Backend::getMergedGroupSimRNGType() const -{ - return CLRRNGLFSR113Stream::getInstance(); -} -//-------------------------------------------------------------------------- void Backend::genVariablePush(CodeStream &os, const std::string &type, const std::string &name, VarLocation loc, bool autoInitialized, size_t count) const { if (!(loc &
VarLocation::ZERO_COPY)) { diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index e3cc5001e2..7d1c510f25 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -268,7 +268,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host }); } - // Generate struct definitions + // Generate struct definitions modelMerged.genMergedNeuronUpdateGroupStructs(os, *this); modelMerged.genMergedNeuronSpikeQueueUpdateStructs(os, *this); modelMerged.genMergedNeuronPrevSpikeTimeUpdateStructs(os, *this); @@ -1451,12 +1451,6 @@ std::string Backend::getMergedGroupFieldHostTypeName(const Type::ResolvedType &t return type.getName(); } //-------------------------------------------------------------------------- -std::optional<Type::ResolvedType> Backend::getMergedGroupSimRNGType() const -{ - assert(false); - return std::nullopt; -} -//-------------------------------------------------------------------------- void Backend::genPopVariableInit(EnvironmentExternalBase &env, HandlerEnv handler) const { handler(env); diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 4d441a6b5e..70540524b2 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -322,66 +322,70 @@ void BackendSIMT::genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &en { CodeStream::Scope b(popEnv.getStream()); + // Create matching environment + EnvironmentGroupMergedField<NeuronPrevSpikeTimeUpdateGroupMerged> neuronEnv(popEnv, ng); + genNeuronIndexCalculation(neuronEnv, batchSize); + // If neuron group requires delays if(ng.getArchetype().isDelayRequired()) { if(batchSize == 1) { - popEnv.printLine("const unsigned int lastTimestepDelaySlot = *$(_spk_que_ptr);"); + neuronEnv.printLine("const unsigned int lastTimestepDelaySlot = *$(_spk_que_ptr);"); } else { - popEnv.printLine("const unsigned int lastTimestepDelaySlot = *$(_spk_que_ptr) + (batch * " + std::to_string(ng.getArchetype().getNumDelaySlots()) + ");"); + neuronEnv.printLine("const unsigned int lastTimestepDelaySlot = *$(_spk_que_ptr) + ($(batch) * " + std::to_string(ng.getArchetype().getNumDelaySlots()) + ");"); } - popEnv.printLine("const unsigned int lastTimestepDelayOffset = lastTimestepDelaySlot * $(num_neurons);"); + neuronEnv.printLine("const unsigned int lastTimestepDelayOffset = lastTimestepDelaySlot * $(num_neurons);"); if(ng.getArchetype().isPrevSpikeTimeRequired()) { // If there is a spike for this thread, set previous spike time to time of last timestep // **NOTE** spkQuePtr is updated below so this already points to last timestep - popEnv.print("if($(id) < $(_spk_cnt)[lastTimestepDelaySlot])"); + neuronEnv.print("if($(id) < $(_spk_cnt)[lastTimestepDelaySlot])"); { - CodeStream::Scope b(popEnv.getStream()); - popEnv.printLine("$(_prev_spk_time)[lastTimestepDelayOffset + $(_spk)[lastTimestepDelayOffset + $(id)]] = $(t) - DT;"); + CodeStream::Scope b(neuronEnv.getStream()); + neuronEnv.printLine("$(_prev_spk_time)[lastTimestepDelayOffset + $(_spk)[lastTimestepDelayOffset + $(id)]] = $(t) - DT;"); } } if(ng.getArchetype().isPrevSpikeEventTimeRequired()) { // If there is a spike-like-event for this thread, set previous spike-like-event time to time of last timestep // **NOTE** spkQuePtr is updated below so this already points to last timestep - popEnv.print("if($(id) < $(_spk_cnt_evnt)[lastTimestepDelaySlot])"); + neuronEnv.print("if($(id) <
$(_spk_cnt_evnt)[lastTimestepDelaySlot])"); { - CodeStream::Scope b(popEnv.getStream()); - popEnv.printLine("$(_prev_spk_evnt_time)[lastTimestepDelayOffset + $(_spk_evnt)[lastTimestepDelayOffset + $(id)]] = $(t) - DT;"); + CodeStream::Scope b(neuronEnv.getStream()); + neuronEnv.printLine("$(_prev_spk_evnt_time)[lastTimestepDelayOffset + $(_spk_evnt)[lastTimestepDelayOffset + $(id)]] = $(t) - DT;"); } } } // Otherwise else { if(batchSize > 1) { - popEnv.printLine("const unsigned int batchOffset = $(num_neurons) * batch;"); + neuronEnv.printLine("const unsigned int batchOffset = $(num_neurons) * $(batch);"); } if(ng.getArchetype().isPrevSpikeTimeRequired()) { // If there is a spike for this thread, set previous spike time to time of last timestep - popEnv.print("if($(id) < $(_spk_cnt)[" + std::string{(batchSize == 1) ? "0" : "batch"} + "])"); + neuronEnv.print("if($(id) < $(_spk_cnt)[$(batch)])"); { - CodeStream::Scope b(popEnv.getStream()); - popEnv.print("$(_prev_spk_time)[$(_spk)["); + CodeStream::Scope b(neuronEnv.getStream()); + neuronEnv.print("$(_prev_spk_time)[$(_spk)["); if(batchSize > 1) { - popEnv.getStream() << "batchOffset + "; + neuronEnv.getStream() << "batchOffset + "; } - popEnv.printLine("$(id)]] = $(t) - DT;"); + neuronEnv.printLine("$(id)]] = $(t) - DT;"); } } if(ng.getArchetype().isPrevSpikeEventTimeRequired()) { // If there is a spike-like-event for this thread, set previous spike-like-event time to time of last timestep - popEnv.print("if($(id) < $(_spk_cnt_evnt)[" + std::string{(batchSize == 1) ? "0" : "batch"} + "])"); + neuronEnv.print("if($(id) < $(_spk_cnt_evnt)[$(batch)])"); { - CodeStream::Scope b(popEnv.getStream()); - popEnv.print("$(_prev_spk_evnt_time)[$(_spk_evnt)["); + CodeStream::Scope b(neuronEnv.getStream()); + neuronEnv.print("$(_prev_spk_evnt_time)[$(_spk_evnt)["); if(batchSize > 1) { - popEnv.getStream() << "batchOffset + "; + neuronEnv.getStream() << "batchOffset + "; } - popEnv.printLine("$(id)]] = $(t) - DT;"); + neuronEnv.printLine("$(id)]] = $(t) - DT;"); } } } - popEnv.getStream() << std::endl; + neuronEnv.getStream() << std::endl; }); } @@ -392,70 +396,82 @@ void BackendSIMT::genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, // Loop through local neuron groups idStart = 0; - for(const auto &n : modelMerged.getMergedNeuronSpikeQueueUpdateGroups()) { - if(idStart == 0) { - env.getStream() << "if(id < " << n.getGroups().size() << ")"; - } - else { - env.getStream() << "if(id >= " << idStart << " && id < " << idStart + n.getGroups().size() << ")"; - } + modelMerged.genMergedNeuronSpikeQueueUpdateGroups( + *this, + [&env, &idStart, batchSize, this](const auto &n) { - CodeStream::Scope b(env.getStream()); + if(idStart == 0) { + env.getStream() << "if(id < " << n.getGroups().size() << ")"; + } + else { + env.getStream() << "if(id >= " << idStart << " && id < " << idStart + n.getGroups().size() << ")"; + } + { + CodeStream::Scope b(env.getStream()); - // Use this to get reference to merged group structure - env.getStream() << getPointerPrefix() << "struct MergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << " *group = &d_mergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << "[id - " << idStart << "]; " << std::endl; + // Use this to get reference to merged group structure + env.getStream() << getPointerPrefix() << "struct MergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << " *group = &d_mergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << "[id - " << idStart << "]; " << std::endl; + + // Create matching environment +
EnvironmentGroupMergedField<NeuronSpikeQueueUpdateGroupMerged> neuronEnv(env, n); + genNeuronIndexCalculation(neuronEnv, batchSize); - if(n.getArchetype().isDelayRequired()) { // with delay - env.getStream() << "*" << env["_spk_que_ptr"] << " = (*" << env["_spk_que_ptr"] << " + 1) % " << n.getArchetype().getNumDelaySlots() << ";" << std::endl; - } + if(n.getArchetype().isDelayRequired()) { // with delay + neuronEnv.printLine("*$(_spk_que_ptr) = (*$(_spk_que_ptr) + 1) % " + std::to_string(n.getArchetype().getNumDelaySlots()) + ";"); + } - if(batchSize > 1) { - env.getStream() << "for(unsigned int batch = 0; batch < " << batchSize << "; batch++)" << CodeStream::OB(1); - } - n.genMergedGroupSpikeCountReset(env, batchSize); - if(batchSize > 1) { - env.getStream() << CodeStream::CB(1); + if(batchSize > 1) { + neuronEnv.getStream() << "for(unsigned int batch = 0; batch < " << batchSize << "; batch++)" << CodeStream::OB(1); + } + n.genMergedGroupSpikeCountReset(neuronEnv, batchSize); + if(batchSize > 1) { + neuronEnv.getStream() << CodeStream::CB(1); + } } - } - idStart += n.getGroups().size(); - } + idStart += n.getGroups().size(); + }); } //-------------------------------------------------------------------------- void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const { const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - // If any neuron groups emit spike events - if(std::any_of(modelMerged.getMergedNeuronUpdateGroups().cbegin(), modelMerged.getMergedNeuronUpdateGroups().cend(), - [](const NeuronUpdateGroupMerged &n) { return n.getArchetype().isSpikeEventRequired(); })) + // Generate code to zero shared memory spike count using thread 1 + std::ostringstream shSpkCountInitStream; + CodeStream shSpkCountInit(shSpkCountInitStream); + shSpkCountInit << getSharedPrefix() << "unsigned int shSpkCount;" << std::endl; + shSpkCountInit << "if (" << getThreadID() << " == 1)"; { - env.getStream() << getSharedPrefix() << "unsigned int shSpkEvnt[" << getKernelBlockSize(KernelNeuronUpdate) << "];" << std::endl; - env.getStream() << getSharedPrefix() << "unsigned int shPosSpkEvnt;" << std::endl; - env.getStream() << getSharedPrefix() << "unsigned int shSpkEvntCount;" << std::endl; - env.getStream() << std::endl; - env.getStream() << "if (" << getThreadID() << " == 1)"; - { - CodeStream::Scope b(env.getStream()); - env.getStream() << "shSpkEvntCount = 0;" << std::endl; - } - env.getStream() << std::endl; + CodeStream::Scope b(shSpkCountInit); + shSpkCountInit << "shSpkCount = 0;" << std::endl; } - // If any neuron groups emit true spikes - if(std::any_of(modelMerged.getMergedNeuronUpdateGroups().cbegin(), modelMerged.getMergedNeuronUpdateGroups().cend(), - [](const NeuronUpdateGroupMerged &n) { return !n.getArchetype().getNeuronModel()->getThresholdConditionCode().empty(); })) + // Generate code to zero shared memory spike event count using thread 1 + std::ostringstream shSpkEvntCountInitStream; + CodeStream shSpkEvntCountInit(shSpkEvntCountInitStream); + shSpkEvntCountInit << getSharedPrefix() << "unsigned int shSpkEvntCount;" << std::endl; + shSpkEvntCountInit << "if (" << getThreadID() << " == 1)"; { - env.getStream() << getSharedPrefix() << "unsigned int shSpk[" << getKernelBlockSize(KernelNeuronUpdate) << "];" << std::endl; - env.getStream() << getSharedPrefix() << "unsigned int shPosSpk;" << std::endl; - env.getStream() << getSharedPrefix() << "unsigned int shSpkCount;" << std::endl; - env.getStream() << "if (" << getThreadID() << " ==
0)"; - { - CodeStream::Scope b(env.getStream()); - env.getStream() << "shSpkCount = 0;" << std::endl; - } - env.getStream() << std::endl; + CodeStream::Scope b(shSpkEvntCountInit); + shSpkEvntCountInit << "shSpkEvntCount = 0;" << std::endl; } + // Add shared memory substitutions so they're only instantiated as required + EnvironmentExternal neuronEnv(env); + const std::string blockSizeStr = std::to_string(getKernelBlockSize(KernelNeuronUpdate)); + neuronEnv.add(Type::Void, "_sh_spk", "shSpk", + {neuronEnv.addInitialiser(getSharedPrefix() + "unsigned int shSpk[" + blockSizeStr + "];")}); + neuronEnv.add(Type::Void, "_sh_spk_pos", "shSpkPos", + {neuronEnv.addInitialiser(getSharedPrefix() + "unsigned int shSpkPos;")}); + neuronEnv.add(Type::Void, "_sh_spk_count", "shSpkCount", + {neuronEnv.addInitialiser(shSpkCountInitStream.str()}); + neuronEnv.add(Type::Void, "_sh_spk_evnt", "shSpkEvnt", + {neuronEnv.addInitialiser(getSharedPrefix() + "unsigned int shSpkEvnt[" + blockSizeStr + "];")}); + neuronEnv.add(Type::Void, "_sh_spk_evnt_pos", "shSpkEvntPos", + {neuronEnv.addInitialiser(getSharedPrefix() + "unsigned int shSpkEvntPos;")}); + neuronEnv.add(Type::Void, "_sh_spk_evnt_count", "shSpkEvntCount", + {neuronEnv.addInitialiser(shSpkEvntCountInitStream.str()}); + // If any neuron groups record spikes if(std::any_of(modelMerged.getMergedNeuronUpdateGroups().cbegin(), modelMerged.getMergedNeuronUpdateGroups().cend(), [](const NeuronUpdateGroupMerged &n) { return n.getArchetype().isSpikeRecordingEnabled(); })) @@ -470,31 +486,29 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM genRecordingSharedMemInit(env.getStream(), "Evnt"); } - genSharedMemBarrier(env.getStream()); + genSharedMemBarrier(neuronEnv.getStream()); // Parallelise over neuron groups idStart = 0; genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedNeuronUpdateGroups, + neuronEnv, modelMerged, idStart, &ModelSpecMerged::genMergedNeuronUpdateGroups, [this](const NeuronGroupInternal &ng) { return padKernelSize(ng.getNumNeurons(), KernelNeuronUpdate); }, [batchSize, &modelMerged, this](EnvironmentExternalBase &popEnv, NeuronUpdateGroupMerged &ng) { - EnvironmentGroupMergedField neuronEnv(popEnv, ng); - genNeuronIndexCalculation(neuronEnv, batchSize); - neuronEnv.getStream() << std::endl; - // Call handler to generate generic neuron code - neuronEnv.print("if($(id) < $(num_neurons))"); + popEnv.print("if($(id) < $(num_neurons))"); { - CodeStream::Scope b(neuronEnv.getStream()); + CodeStream::Scope b(popEnv.getStream()); + EnvironmentGroupMergedField groupEnv(popEnv, ng); + genNeuronIndexCalculation(groupEnv, batchSize); - // Copy global RNG stream to local and use pointer to this for rng - const std::string rng = printSubs("$(_rng)[" + ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)") + "]", neuronEnv); - if(ng.getArchetype().isSimRNGRequired()) { - neuronEnv.add(Type::Void, "rng", genPopulationRNGPreamble(neuronEnv.getStream(), rng)); - } + // Add population RNG field + groupEnv.addField(getPopulationRNGType().createPointer(), "_rng", "rng", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "rng" + g.getName(); }, + ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)")); + // **TODO** for OCL do genPopulationRNGPreamble(os, popSubs, "group->rng[" + ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)") + "]") in initialiser - ng.generateNeuronUpdate(*this, neuronEnv, modelMerged, + ng.generateNeuronUpdate(*this, groupEnv, 
modelMerged, // Emit true spikes [&modelMerged, this](EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng) { @@ -503,13 +517,14 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM // Emit spike-like events [&modelMerged, this](EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng) { - genEmitSpike(env, modelMerged, "Evnt", ng.getArchetype().isSpikeEventRecordingEnabled()); + genEmitSpike(env, modelMerged, "_evnt", ng.getArchetype().isSpikeEventRecordingEnabled()); }); // Copy local stream back to local - if(ng.getArchetype().isSimRNGRequired()) { - genPopulationRNGPostamble(neuronEnv.getStream(), rng); - } + // **TODO** postamble for OCL + //if(ng.getArchetype().isSimRNGRequired()) { + // genPopulationRNGPostamble(neuronEnv.getStream(), rng); + //} } genSharedMemBarrier(neuronEnv.getStream()); @@ -1697,8 +1712,8 @@ size_t BackendSIMT::padKernelSize(size_t size, Kernel kernel) const //-------------------------------------------------------------------------- void BackendSIMT::genEmitSpike(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, const std::string &suffix, bool recordingEnabled) const { - env.getStream() << "const unsigned int spk" << suffix << "Idx = " << getAtomic(Type::Uint32, AtomicOperation::ADD, AtomicMemSpace::SHARED) << "(&shSpk" << suffix << "Count, 1);" << std::endl; - env.getStream() << "shSpk" << suffix << "[spk" << suffix << "Idx] = " << env["id"] << ";" << std::endl; + env.printLine("const unsigned int spk" + suffix + "_idx = " + getAtomic(Type::Uint32, AtomicOperation::ADD, AtomicMemSpace::SHARED) + "(&$(_sh_spk" + suffix + "_count), 1);"); + env.printLine("$(_sh_spk" + suffix + ")[spk" + suffix + "_idx] = $(id);"); // If recording is enabled, set bit in recording word if(recordingEnabled) { diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index b0b00cbec4..b2b7cef8cf 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -43,7 +43,7 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [&modelMerged, &ng](const std::string&, VarAccessDuplication d) { - return ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, "id"); + return ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, "$(id)"); }); // Pretty print code back to environment @@ -84,7 +84,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); // Read into local variable - const std::string idx = ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"); + const std::string idx = ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "$(id)"); psmEnv.getStream() << "// postsynaptic model " << getIndex() << std::endl; psmEnv.printLine(getScalarType().getName() + " linSyn = $(_out_post)[" + idx + "];"); @@ -122,7 +122,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [&modelMerged, &ng](const std::string&, VarAccessDuplication d) { - return ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, "id"); + return 
ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, "$(id)"); }); // Pretty print code back to environment @@ -133,7 +133,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env prettyPrintStatements(getArchetype().getPSDecayCodeTokens(), getTypeContext(), varEnv, decayErrorHandler); // Write back linSyn - varEnv.printLine("$(_out_post)[" + ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id") + "] = linSyn;"); + varEnv.printLine("$(_out_post)[" + ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "$(id)") + "] = linSyn;"); } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &hash) const @@ -167,7 +167,7 @@ void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backe [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); // Add reverse insyn variable to - const std::string idx = ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "id"); + const std::string idx = ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "$(id)"); outSynEnv.printLine(getArchetype().getPreTargetVar() + " += $(_out_pre)[" + idx + "];"); // Zero it again @@ -204,11 +204,11 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) { - return ng.getReadVarIndex(delayed, batchSize, d, "id"); + return ng.getReadVarIndex(delayed, batchSize, d, "$(id)"); }, [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) { - return ng.getWriteVarIndex(delayed, batchSize, d, "id"); + return ng.getWriteVarIndex(delayed, batchSize, d, "$(id)"); }); /*neuronSubstitutionsInSynapticCode(varEnv, &ng.getArchetype(), "", "_post", "", "", "", dynamicsNotSpike, @@ -239,8 +239,8 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentEx for(const auto &v : getArchetype().getWUModel()->getPostVars()) { if(v.access & VarAccessMode::READ_WRITE) { const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - env.print("$(" + v.name + ")[" + ng.getWriteVarIndex(true, batchSize, getVarAccessDuplication(v.access), "id") + "] = "); - env.printLine("$(" + v.name + ")[" + ng.getReadVarIndex(true, batchSize, getVarAccessDuplication(v.access), "id") + "];"); + env.print("$(" + v.name + ")[" + ng.getWriteVarIndex(true, batchSize, getVarAccessDuplication(v.access), "$(id)") + "] = "); + env.printLine("$(" + v.name + ")[" + ng.getReadVarIndex(true, batchSize, getVarAccessDuplication(v.access), "$(id)") + "];"); } } } @@ -291,11 +291,11 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [batchSize, delayed, &ng](const std::string&, VarAccessDuplication d) { - return ng.getReadVarIndex(delayed, batchSize, d, "id"); + return ng.getReadVarIndex(delayed, batchSize, d, "$(id)"); }, [batchSize, delayed, &ng](const std::string&, VarAccessDuplication d) { - return ng.getWriteVarIndex(delayed, batchSize, d, "id"); + return ng.getWriteVarIndex(delayed, batchSize, d, "$(id)"); }); /*neuronSubstitutionsInSynapticCode(subs, &ng.getArchetype(), "", 
"_pre", "", "", "", dynamicsNotSpike, @@ -326,8 +326,8 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(EnvironmentEx for(const auto &v : getArchetype().getWUModel()->getPreVars()) { if(v.access & VarAccessMode::READ_WRITE) { const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - env.print("$(" + v.name + ")[" + ng.getWriteVarIndex(true, batchSize, getVarAccessDuplication(v.access), "id") + "] = "); - env.printLine("$(" + v.name + ")[" + ng.getReadVarIndex(true, batchSize, getVarAccessDuplication(v.access), "id") + "];"); + env.print("$(" + v.name + ")[" + ng.getWriteVarIndex(true, batchSize, getVarAccessDuplication(v.access), "$(id)") + "] = "); + env.printLine("$(" + v.name + ")[" + ng.getReadVarIndex(true, batchSize, getVarAccessDuplication(v.access), "$(id)") + "];"); } } } @@ -504,7 +504,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Substitute spike times const std::string timePrecision = modelMerged.getModel().getTimePrecision().getName(); - const std::string spikeTimeReadIndex = getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "id"); + const std::string spikeTimeReadIndex = getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "$(id)"); neuronEnv.add(getTimeType().addConst(), "sT", "lsT", {neuronEnv.addInitialiser("const " + timePrecision + " lsT = $(_spk_time)[" + spikeTimeReadIndex + "];")}); neuronEnv.add(getTimeType().addConst(), "prev_sT", "lprevST", @@ -521,12 +521,12 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E [batchSize, &neuronEnv, this](const std::string &varName, VarAccessDuplication d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); - return getReadVarIndex(delayed, batchSize, d, "id") ; + return getReadVarIndex(delayed, batchSize, d, "$(id)") ; }, [batchSize, &neuronEnv, this](const std::string &varName, VarAccessDuplication d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); - return getWriteVarIndex(delayed, batchSize, d, "id") ; + return getWriteVarIndex(delayed, batchSize, d, "$(id)") ; }); @@ -720,12 +720,12 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // If spike times are required, copy times from register if(getArchetype().isSpikeTimeRequired()) { - neuronVarEnv.printLine("$(_spk_time)[" + getWriteVarIndex(true, batchSize, VarAccessDuplication::DUPLICATE, "id") + "] = $(sT);"); + neuronVarEnv.printLine("$(_spk_time)[" + getWriteVarIndex(true, batchSize, VarAccessDuplication::DUPLICATE, "$(id)") + "] = $(sT);"); } // If previous spike times are required, copy times from register if(getArchetype().isPrevSpikeTimeRequired()) { - neuronVarEnv.printLine("$(_prev_spk_time)[" + getWriteVarIndex(true, batchSize, VarAccessDuplication::DUPLICATE, "id") + "] = $(prev_sT);"); + neuronVarEnv.printLine("$(_prev_spk_time)[" + getWriteVarIndex(true, batchSize, VarAccessDuplication::DUPLICATE, "$(id)") + "] = $(prev_sT);"); } // Loop through outgoing synapse groups with some sort of presynaptic code @@ -764,10 +764,10 @@ std::string NeuronUpdateGroupMerged::getVarIndex(unsigned int batchSize, VarAcce return (batchSize == 1) ? 
"0" : "$(batch)"; } else if(varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { - return "$(" + index + ")"; + return index; } else { - return "$(_batch_offset) + $(" + index + ")"; + return "$(_batch_offset) " + index; } } //-------------------------------------------------------------------------- @@ -778,10 +778,10 @@ std::string NeuronUpdateGroupMerged::getReadVarIndex(bool delay, unsigned int ba return (batchSize == 1) ? "$(_read_delay_slot)" : "$(_read_batch_delay_slot)"; } else if (varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { - return "$(_read_delay_offset) + $(" + index + ")"; + return "$(_read_delay_offset) + " + index; } else { - return "$(_read_batch_delay_offset) + $(" + index + ")"; + return "$(_read_batch_delay_offset) + " + index; } } else { @@ -796,10 +796,10 @@ std::string NeuronUpdateGroupMerged::getWriteVarIndex(bool delay, unsigned int b return (batchSize == 1) ? "$(_write_delay_slot)" : "$(_write_batch_delay_slot)"; } else if (varDuplication == VarAccessDuplication::SHARED || batchSize == 1) { - return "$(_write_delay_offset) + $(" + index + ")"; + return "$(_write_delay_offset) + " + index; } else { - return "$(_write_batch_delay_offset) + $(" + index + ")"; + return "$(_write_batch_delay_offset) + " + index; } } else { From 489ae36426bf99b5131197d9bb69d984a47e0606 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 7 Jul 2023 16:48:47 +0100 Subject: [PATCH 347/725] corrected order of code generation in CUDA optimiser --- src/genn/backends/cuda/optimiser.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/backends/cuda/optimiser.cc b/src/genn/backends/cuda/optimiser.cc index 9c7a3403ea..9a9ed57c2e 100644 --- a/src/genn/backends/cuda/optimiser.cc +++ b/src/genn/backends/cuda/optimiser.cc @@ -460,11 +460,11 @@ KernelOptimisationOutput optimizeBlockSize(int deviceID, const cudaDeviceProp &d // Generate code with suffix so it doesn't interfere with primary generated code // **NOTE** we don't really need to generate all the code but, on windows, generating code selectively seems to result in werid b const std::string dryRunSuffix = "CUDAOptim"; - generateRunner(outputPath, modelMerged, backend, dryRunSuffix); generateSynapseUpdate(outputPath, modelMerged, backend, dryRunSuffix); generateNeuronUpdate(outputPath, modelMerged, backend, dryRunSuffix); generateCustomUpdate(outputPath, modelMerged, backend, dryRunSuffix); generateInit(outputPath, modelMerged, backend, dryRunSuffix); + generateRunner(outputPath, modelMerged, backend, dryRunSuffix); // Generate support code module if the backend supports namespaces if (backend.supportsNamespace()) { From 895f317e51569ec8ab8bd3d3fdb636598b48f02f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 7 Jul 2023 16:49:03 +0100 Subject: [PATCH 348/725] more backend SIMT hacking --- src/genn/backends/cuda/backend.cc | 135 +++++++++++++++++------------- 1 file changed, 76 insertions(+), 59 deletions(-) diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index d0714c1d67..28795def1f 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -481,7 +481,10 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host else { funcEnv.add(Type::Uint32.addConst(), "batch", "0"); } - genNeuronUpdateKernel(funcEnv, modelMerged, idStart); + + // Add RNG functions to environment and generate kernel + EnvironmentLibrary rngEnv(funcEnv, getRNGFunctions(model.getPrecision())); + 
+ genNeuronUpdateKernel(rngEnv, modelMerged, idStart); } neuronUpdateEnv.getStream() << "void updateNeurons(" << modelMerged.getModel().getTimePrecision().getName() << " t"; @@ -548,102 +551,85 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host os << neuronUpdateStream.str(); } //-------------------------------------------------------------------------- -void Backend::genSynapseUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const { - // Generate struct definitions - modelMerged.genMergedSynapseDendriticDelayUpdateStructs(os, *this); - modelMerged.genMergedPresynapticUpdateGroupStructs(os, *this); - modelMerged.genMergedPostsynapticUpdateGroupStructs(os, *this); - modelMerged.genMergedSynapseDynamicsGroupStructs(os, *this); - - // Generate arrays of merged structs and functions to push them - genMergedStructArrayPush(os, modelMerged.getMergedSynapseDendriticDelayUpdateGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedPresynapticUpdateGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedPostsynapticUpdateGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedSynapseDynamicsGroups()); - - // Generate preamble - preambleHandler(os); - - // Generate data structure for accessing merged groups - size_t totalConstMem = getChosenDeviceSafeConstMemBytes(); - genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedPresynapticUpdateGroups(), - [this](const SynapseGroupInternal &sg) - { - return padKernelSize(getNumPresynapticUpdateThreads(sg, getPreferences()), KernelPresynapticUpdate); - }); - genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedPostsynapticUpdateGroups(), - [this](const SynapseGroupInternal &sg){ return padKernelSize(getNumPostsynapticUpdateThreads(sg), KernelPostsynapticUpdate); }); + // Generate stream with synapse update code + std::ostringstream synapseUpdateStream; + CodeStream synapseUpdate(synapseUpdateStream); - genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedSynapseDynamicsGroups(), - [this](const SynapseGroupInternal &sg){ return padKernelSize(getNumSynapseDynamicsThreads(sg), KernelSynapseDynamicsUpdate); }); + // Begin environment with standard library + EnvironmentLibrary synapseUpdateEnv(synapseUpdate, StandardLibrary::getMathsFunctions()); // If any synapse groups require dendritic delay, a reset kernel is required to be run before the synapse kernel const ModelSpecInternal &model = modelMerged.getModel(); size_t idSynapseDendricDelayUpdate = 0; - if(!modelMerged.getMergedSynapseDendriticDelayUpdateGroups().empty()) { - os << "extern \"C\" __global__ void " << KernelNames[KernelSynapseDendriticDelayUpdate] << "()"; + //if(!modelMerged.getMergedSynapseDendriticDelayUpdateGroups().empty()) { + synapseUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelSynapseDendriticDelayUpdate] << "()"; { - CodeStream::Scope b(os); + CodeStream::Scope b(synapseUpdateEnv.getStream()); - os << "const unsigned int id = " << getKernelBlockSize(KernelSynapseDendriticDelayUpdate) << " * blockIdx.x + threadIdx.x;" << std::endl; - genSynapseDendriticDelayUpdateKernel(os, modelMerged, idSynapseDendricDelayUpdate); + synapseUpdateEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelSynapseDendriticDelayUpdate) << " * blockIdx.x + threadIdx.x;" << std::endl; + genSynapseDendriticDelayUpdateKernel(synapseUpdateEnv, modelMerged,
idSynapseDendricDelayUpdate); } - os << std::endl; - } + synapseUpdateEnv.getStream() << std::endl; + //} // If there are any presynaptic update groups size_t idPresynapticStart = 0; - if(!modelMerged.getMergedPresynapticUpdateGroups().empty()) { - os << "extern \"C\" __global__ void " << KernelNames[KernelPresynapticUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; // end of synapse kernel header + //if(!modelMerged.getMergedPresynapticUpdateGroups().empty()) { + synapseUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelPresynapticUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; // end of synapse kernel header { - CodeStream::Scope b(os); + CodeStream::Scope b(synapseUpdateEnv.getStream()); - Substitutions kernelSubs(getFunctionTemplates(model.getPrecision())); - kernelSubs.addVarSubstitution("t", "t"); + EnvironmentExternal funcEnv(synapseUpdateEnv); + funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); - os << "const unsigned int id = " << getKernelBlockSize(KernelPresynapticUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; + funcEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelPresynapticUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; if(model.getBatchSize() > 1) { - os << "const unsigned int batch = blockIdx.y;" << std::endl; - kernelSubs.addVarSubstitution("batch", "batch"); + funcEnv.getStream() << "const unsigned int batch = blockIdx.y;" << std::endl; + funcEnv.add(Type::Uint32.addConst(), "batch", "batch"); } else { - kernelSubs.addVarSubstitution("batch", "0"); + funcEnv.add(Type::Uint32.addConst(), "batch", "0"); } - genPresynapticUpdateKernel(os, kernelSubs, modelMerged, idPresynapticStart); + + // Add RNG functions to environment and generate kernel + EnvironmentLibrary rngEnv(funcEnv, getRNGFunctions(model.getPrecision())); + genPresynapticUpdateKernel(rngEnv, modelMerged, idPresynapticStart); } - } + //} // If any synapse groups require postsynaptic learning size_t idPostsynapticStart = 0; if(!modelMerged.getMergedPostsynapticUpdateGroups().empty()) { - os << "extern \"C\" __global__ void " << KernelNames[KernelPostsynapticUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; + synapseUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelPostsynapticUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; { - CodeStream::Scope b(os); + CodeStream::Scope b(synapseUpdateEnv.getStream()); - Substitutions kernelSubs(getFunctionTemplates(model.getPrecision())); - kernelSubs.addVarSubstitution("t", "t"); + EnvironmentExternal funcEnv(synapseUpdateEnv); + funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); - os << "const unsigned int id = " << getKernelBlockSize(KernelPostsynapticUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; + funcEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelPostsynapticUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; if(model.getBatchSize() > 1) { - os << "const unsigned int batch = blockIdx.y;" << std::endl; - kernelSubs.addVarSubstitution("batch", "batch"); + funcEnv.getStream() << "const unsigned int batch = blockIdx.y;" << std::endl; + funcEnv.add(Type::Uint32.addConst(), "batch", "batch"); } else { - kernelSubs.addVarSubstitution("batch", "0"); + funcEnv.add(Type::Uint32.addConst(), "batch", "0"); } - 
genPostsynapticUpdateKernel(os, kernelSubs, modelMerged, idPostsynapticStart); + genPostsynapticUpdateKernel(funcEnv, modelMerged, idPostsynapticStart); } } size_t idSynapseDynamicsStart = 0; if(!modelMerged.getMergedSynapseDynamicsGroups().empty()) { - os << "extern \"C\" __global__ void " << KernelNames[KernelSynapseDynamicsUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; // end of synapse kernel header + synapseUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelSynapseDynamicsUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; // end of synapse kernel header { - CodeStream::Scope b(os); + CodeStream::Scope b(synapseUpdateEnv.getStream()); + EnvironmentExternal funcEnv(synapseUpdateEnv); Substitutions kernelSubs(getFunctionTemplates(model.getPrecision())); - kernelSubs.addVarSubstitution("t", "t"); + funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); os << "const unsigned int id = " << getKernelBlockSize(KernelSynapseDynamicsUpdate) << " * blockIdx.x + threadIdx.x;" << std::endl; if(model.getBatchSize() > 1) { @@ -657,9 +643,9 @@ } } - os << "void updateSynapses(" << modelMerged.getModel().getTimePrecision().getName() << " t)"; + synapseUpdateEnv.getStream() << "void updateSynapses(" << modelMerged.getModel().getTimePrecision().getName() << " t)"; { - CodeStream::Scope b(os); + CodeStream::Scope b(synapseUpdateEnv.getStream()); // Launch pre-synapse reset kernel if required if(idSynapseDendricDelayUpdate > 0) { @@ -699,6 +685,37 @@ os << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl; } } + + // Generate struct definitions + modelMerged.genMergedSynapseDendriticDelayUpdateStructs(os, *this); + modelMerged.genMergedPresynapticUpdateGroupStructs(os, *this); + modelMerged.genMergedPostsynapticUpdateGroupStructs(os, *this); + modelMerged.genMergedSynapseDynamicsGroupStructs(os, *this); + + // Generate arrays of merged structs and functions to push them + genMergedStructArrayPush(os, modelMerged.getMergedSynapseDendriticDelayUpdateGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedPresynapticUpdateGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedPostsynapticUpdateGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedSynapseDynamicsGroups()); + + // Generate preamble + preambleHandler(os); + + // Generate data structure for accessing merged groups + size_t totalConstMem = getChosenDeviceSafeConstMemBytes(); + genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedPresynapticUpdateGroups(), + [this](const SynapseGroupInternal &sg) + { + return padKernelSize(getNumPresynapticUpdateThreads(sg, getPreferences()), KernelPresynapticUpdate); + }); + genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedPostsynapticUpdateGroups(), + [this](const SynapseGroupInternal &sg){ return padKernelSize(getNumPostsynapticUpdateThreads(sg), KernelPostsynapticUpdate); }); + + genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedSynapseDynamicsGroups(), + [this](const SynapseGroupInternal &sg){ return padKernelSize(getNumSynapseDynamicsThreads(sg), KernelSynapseDynamicsUpdate); }); + + os << synapseUpdateStream.str(); + } //-------------------------------------------------------------------------- void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const
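A note on the print()/printLine() calls used throughout these kernels: they write through the environment rather than the raw stream, expanding $(name) placeholders against whatever the enclosing layers have registered. A minimal illustration of the convention as it is used above (not a verbatim API):

    EnvironmentExternal env(parentEnv);
    env.add(Type::Uint32.addConst(), "batch", "batch");
    // expands to roughly: const unsigned int batchOffset = group->numNeurons * batch;
    env.printLine("const unsigned int batchOffset = $(num_neurons) * $(batch);");

Identifiers with a leading underscore, such as $(_spk_cnt) or $(_spk_que_ptr), resolve against fields of the merged group structure rather than plain substitutions.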
From 24faacbe8ea598d74bd419bb27f91c3d84f7f963 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 09:41:31 +0100 Subject: [PATCH 349/725] fixed typo --- include/genn/genn/code_generator/modelSpecMerged.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 2c47375bf6..1991183fa8 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -180,7 +180,7 @@ class GENN_EXPORT ModelSpecMerged GenMergedGroupFn<CustomWUUpdateHostReductionGroupMerged> generateGroup); void genMergedCustomConnectivityUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, GenMergedGroupFn<CustomConnectivityUpdateGroupMerged> generateGroup); - void genMergedCustomConnectivityHostUpdateGroups(BackendBase &backend, const std::string &updateGroupName, + void genMergedCustomConnectivityHostUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, GenMergedGroupFn<CustomConnectivityHostUpdateGroupMerged> generateGroup); void genMergedNeuronSpikeQueueUpdateGroups(const BackendBase &backend, GenMergedGroupFn<NeuronSpikeQueueUpdateGroupMerged> generateGroup); void genMergedNeuronPrevSpikeTimeUpdateGroups(const BackendBase &backend, GenMergedGroupFn<NeuronPrevSpikeTimeUpdateGroupMerged> generateGroup); From 20467ad6d29373fb4c18b8df8840957f9a4d67f9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 09:41:44 +0100 Subject: [PATCH 350/725] value types have ``device`` flag --- include/genn/genn/type.h | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 98b11e66f0..91597dc09f 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -91,6 +91,7 @@ struct ResolvedType { std::string name; size_t size; + bool device; std::optional<Numeric> numeric; //------------------------------------------------------------------------ //------------------------------------------------------------------------ bool operator == (const Value &other) const { - return (std::tie(size, numeric) == std::tie(other.size, other.numeric)); + return (std::tie(size, numeric, device) == std::tie(other.size, other.numeric, other.device)); } bool operator != (const Value &other) const { - return (std::tie(size, numeric) != std::tie(other.size, other.numeric)); + return (std::tie(size, numeric, device) != std::tie(other.size, other.numeric, other.device)); } bool operator < (const Value &other) const { - return (std::tie(size, numeric) < std::tie(other.size, other.numeric)); + return (std::tie(size, numeric, device) < std::tie(other.size, other.numeric, other.device)); } }; @@ -260,18 +261,18 @@ struct ResolvedType // Static API //------------------------------------------------------------------------ template<typename T> - static ResolvedType createNumeric(const std::string &name, int rank, const std::string &literalSuffix = "", Qualifier qualifiers = Qualifier{0}) + static ResolvedType createNumeric(const std::string &name, int rank, const std::string &literalSuffix = "", Qualifier qualifiers = Qualifier{0}, bool device = false) { - return ResolvedType{Value{name, sizeof(T), Numeric{rank, std::numeric_limits<T>::min(), static_cast<double>(std::numeric_limits<T>::max()), - std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max_digits10, - std::is_signed<T>::value, std::is_integral<T>::value, literalSuffix}}, qualifiers}; + return ResolvedType{Value{name, sizeof(T), device, Numeric{rank, std::numeric_limits<T>::min(), static_cast<double>(std::numeric_limits<T>::max()), +
std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max_digits10, + std::is_signed<T>::value, std::is_integral<T>::value, literalSuffix}}, qualifiers}; } template<typename T> - static ResolvedType createValue(const std::string &name, Qualifier qualifiers = Qualifier{0}) + static ResolvedType createValue(const std::string &name, Qualifier qualifiers = Qualifier{0}, bool device = false) { - return ResolvedType{Value{name, sizeof(T), std::nullopt}, qualifiers}; + return ResolvedType{Value{name, sizeof(T), device, std::nullopt}, qualifiers}; } static ResolvedType createFunction(const ResolvedType &returnType, const std::vector<ResolvedType> &argTypes, bool variadic=false) From 05d906a5e86799d5a4e4d42fd1d5c52a975b6822 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 09:42:11 +0100 Subject: [PATCH 351/725] fixed whitespace --- .../genn/code_generator/modelSpecMerged.cc | 196 +++++++++--------- 1 file changed, 98 insertions(+), 98 deletions(-) diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index c7b349d4f6..9aa7b8dcaf 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -100,117 +100,117 @@ ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBa void ModelSpecMerged::genMergedNeuronUpdateGroups(const BackendBase &backend, GenMergedGroupFn<NeuronUpdateGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronUpdateGroups, - [](const NeuronGroupInternal &){ return true; }, - &NeuronGroupInternal::getHashDigest, generateGroup); + [](const NeuronGroupInternal &){ return true; }, + &NeuronGroupInternal::getHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedPresynapticUpdateGroups(const BackendBase &backend, GenMergedGroupFn<PresynapticUpdateGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedPresynapticUpdateGroups, - [](const SynapseGroupInternal &sg) { return (sg.isSpikeEventRequired() || sg.isTrueSpikeRequired()); }, - &SynapseGroupInternal::getWUHashDigest, generateGroup); + [](const SynapseGroupInternal &sg) { return (sg.isSpikeEventRequired() || sg.isTrueSpikeRequired()); }, + &SynapseGroupInternal::getWUHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedPostsynapticUpdateGroups(const BackendBase &backend, GenMergedGroupFn<PostsynapticUpdateGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedPostsynapticUpdateGroups, - [](const SynapseGroupInternal &sg){ return !sg.getWUModel()->getLearnPostCode().empty(); }, - &SynapseGroupInternal::getWUHashDigest, generateGroup); + [](const SynapseGroupInternal &sg){ return !Utils::areTokensEmpty(sg.getWUPostLearnCodeTokens()); }, + &SynapseGroupInternal::getWUHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedSynapseDynamicsGroups(const BackendBase &backend, GenMergedGroupFn<SynapseDynamicsGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseDynamicsGroups, - [](const SynapseGroupInternal &sg){ return !sg.getWUModel()->getSynapseDynamicsCode().empty(); }, - &SynapseGroupInternal::getWUHashDigest, generateGroup); + [](const SynapseGroupInternal &sg){ return !Utils::areTokensEmpty(sg.getWUSynapseDynamicsCodeTokens()); }, + &SynapseGroupInternal::getWUHashDigest, generateGroup); }
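Returning to the ``device`` flag introduced in the type.h change above: it lets a backend declare value types that only exist on-device. A plausible (hypothetical) declaration for the CUDA backend's CURandState, which the earlier genPopulationRNG() change passes straight to genArray(), assuming curandState is visible from curand_kernel.h:

    const Type::ResolvedType CURandState =
        Type::ResolvedType::createValue<curandState>("curandState", Type::Qualifier{0}, /*device=*/true);

Because ``device`` now participates in operator==/!=/<, a device-only type can never compare equal to a host type of the same name and size.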
//---------------------------------------------------------------------------- void ModelSpecMerged::genMergedCustomUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, - GenMergedGroupFn<CustomUpdateGroupMerged> generateGroup) + GenMergedGroupFn<CustomUpdateGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateGroups, - [&updateGroupName](const CustomUpdateInternal &cg) { return cg.getUpdateGroupName() == updateGroupName; }, - &CustomUpdateInternal::getHashDigest, generateGroup); + [&updateGroupName](const CustomUpdateInternal &cg) { return cg.getUpdateGroupName() == updateGroupName; }, + &CustomUpdateInternal::getHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedCustomUpdateWUGroups(const BackendBase &backend, const std::string &updateGroupName, - GenMergedGroupFn<CustomUpdateWUGroupMerged> generateGroup) + GenMergedGroupFn<CustomUpdateWUGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomUpdateWUGroups, - [&updateGroupName](const CustomUpdateWUInternal &cg) - { - return (!cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName); - }, - &CustomUpdateWUInternal::getHashDigest, generateGroup); + [&updateGroupName](const CustomUpdateWUInternal &cg) + { + return (!cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName); + }, + &CustomUpdateWUInternal::getHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedCustomUpdateTransposeWUGroups(const BackendBase &backend, const std::string &updateGroupName, - GenMergedGroupFn<CustomUpdateTransposeWUGroupMerged> generateGroup) + GenMergedGroupFn<CustomUpdateTransposeWUGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomUpdateTransposeWUGroups, - [&updateGroupName](const CustomUpdateWUInternal &cg) - { - return (cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName); - }, - &CustomUpdateWUInternal::getHashDigest, generateGroup); + [&updateGroupName](const CustomUpdateWUInternal &cg) + { + return (cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName); + }, + &CustomUpdateWUInternal::getHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedCustomUpdateHostReductionGroups(const BackendBase &backend, const std::string &updateGroupName, - GenMergedGroupFn<CustomUpdateHostReductionGroupMerged> generateGroup) + GenMergedGroupFn<CustomUpdateHostReductionGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateHostReductionGroups, - [&updateGroupName](const CustomUpdateInternal &cg) - { - return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName); - }, - &CustomUpdateInternal::getHashDigest, generateGroup, true); + [&updateGroupName](const CustomUpdateInternal &cg) + { + return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName); + }, + &CustomUpdateInternal::getHashDigest, generateGroup, true); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedCustomWUUpdateHostReductionGroups(const BackendBase &backend, const std::string &updateGroupName, - GenMergedGroupFn<CustomWUUpdateHostReductionGroupMerged> generateGroup) + GenMergedGroupFn<CustomWUUpdateHostReductionGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateHostReductionGroups, - [&updateGroupName](const CustomUpdateWUInternal &cg) - { - return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName); - },
- &CustomUpdateWUInternal::getHashDigest, generateGroup, true); + [&updateGroupName](const CustomUpdateWUInternal &cg) + { + return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName); + }, + &CustomUpdateWUInternal::getHashDigest, generateGroup, true); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedCustomConnectivityUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, - GenMergedGroupFn<CustomConnectivityUpdateGroupMerged> generateGroup) + GenMergedGroupFn<CustomConnectivityUpdateGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateGroups, - [&updateGroupName](const CustomConnectivityUpdateInternal &cg) - { - return (!cg.getCustomConnectivityUpdateModel()->getRowUpdateCode().empty() && cg.getUpdateGroupName() == updateGroupName); - }, - &CustomConnectivityUpdateInternal::getHashDigest, generateGroup); + [&updateGroupName](const CustomConnectivityUpdateInternal &cg) + { + return (!Utils::areTokensEmpty(cg.getRowUpdateCodeTokens()) && cg.getUpdateGroupName() == updateGroupName); + }, + &CustomConnectivityUpdateInternal::getHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedCustomConnectivityHostUpdateGroups(BackendBase &backend, const std::string &updateGroupName, - GenMergedGroupFn<CustomConnectivityHostUpdateGroupMerged> generateGroup) +void ModelSpecMerged::genMergedCustomConnectivityHostUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, + GenMergedGroupFn<CustomConnectivityHostUpdateGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityHostUpdateGroups, - [&updateGroupName](const CustomConnectivityUpdateInternal &cg) - { - return (!cg.getCustomConnectivityUpdateModel()->getHostUpdateCode().empty() && cg.getUpdateGroupName() == updateGroupName); - }, - &CustomConnectivityUpdateInternal::getHashDigest, generateGroup, true); + [&updateGroupName](const CustomConnectivityUpdateInternal &cg) + { + return (!Utils::areTokensEmpty(cg.getHostUpdateCodeTokens()) && cg.getUpdateGroupName() == updateGroupName); + }, + &CustomConnectivityUpdateInternal::getHashDigest, generateGroup, true); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedNeuronSpikeQueueUpdateGroups(const BackendBase &backend, GenMergedGroupFn<NeuronSpikeQueueUpdateGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronSpikeQueueUpdateGroups, - [](const NeuronGroupInternal &){ return true; }, - &NeuronGroupInternal::getSpikeQueueUpdateHashDigest, generateGroup); + [](const NeuronGroupInternal &){ return true; }, + &NeuronGroupInternal::getSpikeQueueUpdateHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedNeuronPrevSpikeTimeUpdateGroups(const BackendBase &backend, GenMergedGroupFn<NeuronPrevSpikeTimeUpdateGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronPrevSpikeTimeUpdateGroups, - [](const NeuronGroupInternal &ng){ return (ng.isPrevSpikeTimeRequired() || ng.isPrevSpikeEventTimeRequired()); }, - &NeuronGroupInternal::getPrevSpikeTimeUpdateHashDigest, generateGroup); + [](const NeuronGroupInternal &ng){ return (ng.isPrevSpikeTimeRequired() || ng.isPrevSpikeEventTimeRequired()); }, + &NeuronGroupInternal::getPrevSpikeTimeUpdateHashDigest, generateGroup); } //---------------------------------------------------------------------------- void
ModelSpecMerged::genMergedSynapseDendriticDelayUpdateGroups(const BackendBase &backend, GenMergedGroupFn<SynapseDendriticDelayUpdateGroupMerged> generateGroup) @@ -224,108 +224,108 @@ void ModelSpecMerged::genMergedSynapseDendriticDelayUpdateGroups(const BackendBa } } createMergedGroups(backend, synapseGroupsWithDendriticDelay, m_MergedSynapseDendriticDelayUpdateGroups, - &SynapseGroupInternal::getDendriticDelayUpdateHashDigest, generateGroup); + &SynapseGroupInternal::getDendriticDelayUpdateHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedNeuronInitGroups(const BackendBase &backend, GenMergedGroupFn<NeuronInitGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronInitGroups, - [](const NeuronGroupInternal &){ return true; }, - &NeuronGroupInternal::getInitHashDigest, generateGroup); + [](const NeuronGroupInternal &){ return true; }, + &NeuronGroupInternal::getInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedCustomUpdateInitGroups(const BackendBase &backend, GenMergedGroupFn<CustomUpdateInitGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateInitGroups, - [](const CustomUpdateInternal &cg) { return cg.isVarInitRequired(); }, - &CustomUpdateInternal::getInitHashDigest, generateGroup); + [](const CustomUpdateInternal &cg) { return cg.isVarInitRequired(); }, + &CustomUpdateInternal::getInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedCustomWUUpdateInitGroups(const BackendBase &backend, GenMergedGroupFn<CustomWUUpdateInitGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateInitGroups, - [](const CustomUpdateWUInternal &cg) - { - return (((cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::DENSE) + [](const CustomUpdateWUInternal &cg) + { + return (((cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::DENSE) || (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL)) - && cg.isVarInitRequired()); - }, - &CustomUpdateWUInternal::getInitHashDigest, generateGroup); + && cg.isVarInitRequired()); + }, + &CustomUpdateWUInternal::getInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedSynapseInitGroups(const BackendBase &backend, GenMergedGroupFn<SynapseInitGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseInitGroups, - [](const SynapseGroupInternal &sg) - { - return (((sg.getMatrixType() & SynapseMatrixConnectivity::DENSE) - || (sg.getMatrixType() & SynapseMatrixWeight::KERNEL)) - && sg.isWUVarInitRequired()); - }, - &SynapseGroupInternal::getWUInitHashDigest, generateGroup); + [](const SynapseGroupInternal &sg) + { + return (((sg.getMatrixType() & SynapseMatrixConnectivity::DENSE) + || (sg.getMatrixType() & SynapseMatrixWeight::KERNEL)) + && sg.isWUVarInitRequired()); + }, + &SynapseGroupInternal::getWUInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedSynapseConnectivityInitGroups(const BackendBase &backend, GenMergedGroupFn<SynapseConnectivityInitGroupMerged> generateGroup) { createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseConnectivityInitGroups, - [](const SynapseGroupInternal &sg){ return sg.isSparseConnectivityInitRequired(); }, -
&SynapseGroupInternal::getConnectivityInitHashDigest, generateGroup); + [](const SynapseGroupInternal &sg){ return sg.isSparseConnectivityInitRequired(); }, + &SynapseGroupInternal::getConnectivityInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedSynapseSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) { createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseSparseInitGroups, - [&backend](const SynapseGroupInternal &sg) - { - return ((sg.getMatrixType() & SynapseMatrixConnectivity::SPARSE) && - (sg.isWUVarInitRequired() - || (backend.isPostsynapticRemapRequired() && !sg.getWUModel()->getLearnPostCode().empty()))); - }, - &SynapseGroupInternal::getWUInitHashDigest, generateGroup); + [&backend](const SynapseGroupInternal &sg) + { + return ((sg.getMatrixType() & SynapseMatrixConnectivity::SPARSE) && + (sg.isWUVarInitRequired() + || (backend.isPostsynapticRemapRequired() && !sg.getWUModel()->getLearnPostCode().empty()))); + }, + &SynapseGroupInternal::getWUInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedCustomWUUpdateSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) { createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateSparseInitGroups, - [](const CustomUpdateWUInternal &cg) - { - return (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) && cg.isVarInitRequired(); - }, - &CustomUpdateWUInternal::getInitHashDigest, generateGroup); + [](const CustomUpdateWUInternal &cg) + { + return (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) && cg.isVarInitRequired(); + }, + &CustomUpdateWUInternal::getInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedCustomConnectivityUpdatePreInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) { createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePreInitGroups, - [&backend](const CustomConnectivityUpdateInternal &cg) - { - return (cg.isPreVarInitRequired() || (backend.isPopulationRNGInitialisedOnDevice() && Utils::isRNGRequired(cg.getRowUpdateCodeTokens()))); - }, - &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); + [&backend](const CustomConnectivityUpdateInternal &cg) + { + return (cg.isPreVarInitRequired() || (backend.isPopulationRNGInitialisedOnDevice() && Utils::isRNGRequired(cg.getRowUpdateCodeTokens()))); + }, + &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedCustomConnectivityUpdatePostInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) { createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePostInitGroups, - [](const CustomConnectivityUpdateInternal &cg) { return cg.isPostVarInitRequired(); }, - &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); + [](const CustomConnectivityUpdateInternal &cg) { return cg.isPostVarInitRequired(); }, + &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- void 
ModelSpecMerged::genMergedCustomConnectivityUpdateSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) { createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateSparseInitGroups, - [](const CustomConnectivityUpdateInternal &cg) { return cg.isVarInitRequired(); }, - &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); + [](const CustomConnectivityUpdateInternal &cg) { return cg.isVarInitRequired(); }, + &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- void ModelSpecMerged::genMergedSynapseConnectivityHostInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) { createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseConnectivityHostInitGroups, - [](const SynapseGroupInternal &sg) - { - return !sg.getConnectivityInitialiser().getSnippet()->getHostInitCode().empty(); - }, - &SynapseGroupInternal::getConnectivityHostInitHashDigest, generateGroup, true); + [](const SynapseGroupInternal &sg) + { + return !sg.getConnectivityInitialiser().getSnippet()->getHostInitCode().empty(); + }, + &SynapseGroupInternal::getConnectivityHostInitHashDigest, generateGroup, true); } //---------------------------------------------------------------------------- boost::uuids::detail::sha1::digest_type ModelSpecMerged::getHashDigest(const BackendBase &backend) const From 6d3216f9c1207b07eee97d6a56b8e91dcbc12b66 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 09:44:49 +0100 Subject: [PATCH 352/725] ``isGlobalHostRNGRequired`` and ``isGlobalDeviceRNGRequired`` can work on plain unmerged ModelSpecInternal --- .../genn/genn/code_generator/backendSIMT.h | 4 +- src/genn/genn/code_generator/backendSIMT.cc | 60 +++++++++---------- 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 2120596714..8f55e8462c 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -142,8 +142,8 @@ class GENN_EXPORT BackendSIMT : public BackendBase //! Should 'scalar' variables be implemented on device or can host variables be used directly? 
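    // A minimal usage sketch for the isGlobalHostRNGRequired/isGlobalDeviceRNGRequired
    // signature change below (illustrative only, not part of the patch; backend
    // construction elided): taking a plain ModelSpecInternal instead of
    // ModelSpecMerged lets callers ask about RNG requirements before any group
    // merging has been performed, e.g.
    //   if(backend.isGlobalHostRNGRequired(model)) { /* seed a host RNG */ }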
virtual bool isDeviceScalarRequired() const final { return true; } - virtual bool isGlobalHostRNGRequired(const ModelSpecMerged &modelMerged) const final; - virtual bool isGlobalDeviceRNGRequired(const ModelSpecMerged &modelMerged) const final; + virtual bool isGlobalHostRNGRequired(const ModelSpecInternal &model) const final; + virtual bool isGlobalDeviceRNGRequired(const ModelSpecInternal &model) const final; virtual bool isPostsynapticRemapRequired() const final { return true; } diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 70540524b2..dd0843beaa 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -98,21 +98,19 @@ void BackendSIMT::genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env handler(varEnv); } //-------------------------------------------------------------------------- -bool BackendSIMT::isGlobalHostRNGRequired(const ModelSpecMerged &modelMerged) const +bool BackendSIMT::isGlobalHostRNGRequired(const ModelSpecInternal &model) const { // Host RNG is required if any synapse groups or custom connectivity updates require a host RNG - const ModelSpecInternal &model = modelMerged.getModel(); return (std::any_of(model.getSynapseGroups().cbegin(), model.getSynapseGroups().cend(), [](const ModelSpec::SynapseGroupValueType &s){ return s.second.getConnectivityInitialiser().isHostRNGRequired(); }) || std::any_of(model.getCustomConnectivityUpdates().cbegin(), model.getCustomConnectivityUpdates().cend(), [](const ModelSpec::CustomConnectivityUpdateValueType &c){ return Utils::isRNGRequired(c.second.getHostUpdateCodeTokens()); })); } //-------------------------------------------------------------------------- -bool BackendSIMT::isGlobalDeviceRNGRequired(const ModelSpecMerged &modelMerged) const +bool BackendSIMT::isGlobalDeviceRNGRequired(const ModelSpecInternal &model) const { // If any neuron groups require RNG for initialisation, return true // **NOTE** this takes postsynaptic model initialisation into account - const ModelSpecInternal &model = modelMerged.getModel(); if(std::any_of(model.getNeuronGroups().cbegin(), model.getNeuronGroups().cend(), [](const ModelSpec::NeuronGroupValueType &n){ return n.second.isInitRNGRequired(); })) { @@ -464,13 +462,13 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM neuronEnv.add(Type::Void, "_sh_spk_pos", "shSpkPos", {neuronEnv.addInitialiser(getSharedPrefix() + "unsigned int shSpkPos;")}); neuronEnv.add(Type::Void, "_sh_spk_count", "shSpkCount", - {neuronEnv.addInitialiser(shSpkCountInitStream.str()}); + {neuronEnv.addInitialiser(shSpkCountInitStream.str())}); neuronEnv.add(Type::Void, "_sh_spk_evnt", "shSpkEvnt", {neuronEnv.addInitialiser(getSharedPrefix() + "unsigned int shSpkEvnt[" + blockSizeStr + "];")}); neuronEnv.add(Type::Void, "_sh_spk_evnt_pos", "shSpkEvntPos", {neuronEnv.addInitialiser(getSharedPrefix() + "unsigned int shSpkEvntPos;")}); neuronEnv.add(Type::Void, "_sh_spk_evnt_count", "shSpkEvntCount", - {neuronEnv.addInitialiser(shSpkEvntCountInitStream.str()}); + {neuronEnv.addInitialiser(shSpkEvntCountInitStream.str())}); // If any neuron groups record spikes if(std::any_of(modelMerged.getMergedNeuronUpdateGroups().cbegin(), modelMerged.getMergedNeuronUpdateGroups().cend(), @@ -495,13 +493,13 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM [this](const NeuronGroupInternal &ng) { return padKernelSize(ng.getNumNeurons(), KernelNeuronUpdate); 
}, [batchSize, &modelMerged, this](EnvironmentExternalBase &popEnv, NeuronUpdateGroupMerged &ng) { + CodeStream::Scope b(popEnv.getStream()); + EnvironmentGroupMergedField groupEnv(popEnv, ng); + genNeuronIndexCalculation(groupEnv, batchSize); + // Call handler to generate generic neuron code popEnv.print("if($(id) < $(num_neurons))"); { - CodeStream::Scope b(popEnv.getStream()); - EnvironmentGroupMergedField groupEnv(popEnv, ng); - genNeuronIndexCalculation(groupEnv, batchSize); - // Add population RNG field groupEnv.addField(getPopulationRNGType().createPointer(), "_rng", "rng", [this](const auto &g, size_t) { return getDeviceVarPrefix() + "rng" + g.getName(); }, @@ -527,52 +525,52 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM //} } - genSharedMemBarrier(neuronEnv.getStream()); + genSharedMemBarrier(groupEnv.getStream()); if(ng.getArchetype().isSpikeEventRequired()) { - neuronEnv.getStream() << "if (" << getThreadID() << " == 1)"; + groupEnv.getStream() << "if (" << getThreadID() << " == 1)"; { - CodeStream::Scope b(neuronEnv.getStream()); - neuronEnv.getStream() << "if (shSpkEvntCount > 0)"; + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.print("if($(_sh_spk_evnt_count) > 0)"); { - CodeStream::Scope b(neuronEnv.getStream()); - neuronEnv.getStream() << "shPosSpkEvnt = " << getAtomic(Type::Uint32) << "(&group->spkCntEvnt"; + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.print("$(_sh_spk_evnt_pos) = " + getAtomic(Type::Uint32) + "(&$(_spk_cnt_evnt)"); if(ng.getArchetype().isDelayRequired()) { - neuronEnv.getStream() << "[*" << neuronEnv["_spk_que_ptr"]; + groupEnv.print("[*$(_spk_que_ptr)"); if(batchSize > 1) { - neuronEnv.getStream() << " + (batch * " << ng.getArchetype().getNumDelaySlots() << ")"; + groupEnv.getStream() << " + (batch * " << ng.getArchetype().getNumDelaySlots() << ")"; } - neuronEnv.getStream() << "], shSpkEvntCount);" << std::endl; + groupEnv.printLine("], $(_sh_spk_evnt_count));"); } else { - neuronEnv.getStream() << "[" << ((batchSize > 1) ? "batch" : "0") << "], shSpkEvntCount);" << std::endl; + groupEnv.printLine("[$(batch)], $(_sh_spk_evnt_count));"); } } } - genSharedMemBarrier(neuronEnv.getStream()); + genSharedMemBarrier(groupEnv.getStream()); } if(!ng.getArchetype().getNeuronModel()->getThresholdConditionCode().empty()) { - neuronEnv.getStream() << "if(" << getThreadID() << " == 0)"; + groupEnv.getStream() << "if(" << getThreadID() << " == 0)"; { - CodeStream::Scope b(neuronEnv.getStream()); - neuronEnv.getStream() << "if (shSpkCount > 0)"; + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.print("if ($(_sh_spk_count) > 0)"); { - CodeStream::Scope b(neuronEnv.getStream()); - neuronEnv.getStream() << "shPosSpk = " << getAtomic(Type::Uint32) << "(&group->spkCnt"; + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.getStream() << "shPosSpk = " << getAtomic(Type::Uint32) << "(&group->spkCnt"; if(ng.getArchetype().isDelayRequired() && ng.getArchetype().isTrueSpikeRequired()) { - neuronEnv.getStream() << "[*" << neuronEnv["_spk_que_ptr"]; + groupEnv.getStream() << "[*" << groupEnv["_spk_que_ptr"]; if(batchSize > 1) { - neuronEnv.getStream() << " + (batch * " << ng.getArchetype().getNumDelaySlots() << ")"; + groupEnv.getStream() << " + (batch * " << ng.getArchetype().getNumDelaySlots() << ")"; } - neuronEnv.getStream() << "], shSpkCount);" << std::endl; + groupEnv.getStream() << "], shSpkCount);" << std::endl; } else { - neuronEnv.getStream() << "[" << ((batchSize > 1) ? 
"batch" : "0") << "], shSpkCount);" << std::endl; + groupEnv.getStream() << "[" << ((batchSize > 1) ? "batch" : "0") << "], shSpkCount);" << std::endl; } } } - genSharedMemBarrier(neuronEnv.getStream()); + genSharedMemBarrier(groupEnv.getStream()); } const std::string queueOffset = ng.getWriteVarIndex(ng.getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, ""); From afab3e3155c4021b4795540adf8254e1d68dce89 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 09:44:59 +0100 Subject: [PATCH 353/725] re-enabled generation of custom connectivity host update code in single-threaded CPU backend --- src/genn/backends/single_threaded_cpu/backend.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 7d1c510f25..29714ef225 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -554,13 +554,13 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host writePreciseLiteral(modelMerged.getModel().getDT(), modelMerged.getModel().getTimePrecision())); // Loop through host update groups and generate code for those in this custom update group - for (const auto &cg : modelMerged.getMergedCustomConnectivityHostUpdateGroups()) { - if (cg.getArchetype().getUpdateGroupName() == g) { - assert(false); - //cg.generateUpdate(*this, os); - } - } - + modelMerged.genMergedCustomConnectivityHostUpdateGroups( + *this, g, + [this, &customUpdateEnv, &modelMerged](auto &c) + { + c.generateUpdate(*this, customUpdateEnv, modelMerged); + }); + { Timer t(funcEnv.getStream(), "customUpdate" + g, model.isTimingEnabled()); modelMerged.genMergedCustomUpdateGroups( From c769983f71cdc65b54d8bf7f003eeeda4dacf1d9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 11:26:18 +0100 Subject: [PATCH 354/725] more work on CUDA backend --- include/genn/backends/cuda/backend.h | 44 +- .../backends/single_threaded_cpu/backend.h | 4 +- .../genn/genn/code_generator/backendBase.h | 5 +- src/genn/backends/cuda/backend.cc | 763 +++++++++--------- .../backends/single_threaded_cpu/backend.cc | 9 +- 5 files changed, 411 insertions(+), 414 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index b56653e091..f25662584a 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -312,6 +312,11 @@ class BACKEND_EXPORT Backend : public BackendSIMT //-------------------------------------------------------------------------- // Private methods //-------------------------------------------------------------------------- + std::string getNCCLReductionType(VarAccessMode mode) const; + std::string getNCCLType(const Type::ResolvedType &type) const; + + void genKernelDimensions(CodeStream &os, Kernel kernel, size_t numThreadsX, size_t batchSize, size_t numBlockThreadsY = 1) const; + template void genMergedStructArrayPush(CodeStream &os, const std::vector &groups) const { @@ -345,6 +350,43 @@ class BACKEND_EXPORT Backend : public BackendSIMT } } + template + void genNCCLReduction(EnvironmentExternal &env, G &cg) const + { + CodeStream::Scope b(env.getStream()); + env.getStream() << "// merged custom update host reduction group " << cg.getIndex() << std::endl; + env.getStream() << "for(unsigned int g = 0; g < " << cg.getGroups().size() << "; g++)"; + { + CodeStream::Scope b(env.getStream()); + + // 
Get reference to group
+            env.getStream() << "const auto *group = &merged" << G::name << "Group" << cg.getIndex() << "[g]; " << std::endl;
+            EnvironmentGroupMergedField<G> groupEnv(env, cg);
+
+            // Loop through variables and add pointers if they are reduction targets
+            const CustomUpdateModels::Base *cm = cg.getArchetype().getCustomUpdateModel();
+            for(const auto &v : cm->getVars()) {
+                if(v.access & VarAccessModeAttribute::REDUCE) {
+                    groupEnv.addField(v.type.resolve(cg.getTypeContext()).createPointer(), "_" + v.name, v.name,
+                                      [this, v](const auto &g, size_t)
+                                      {
+                                          return getDeviceVarPrefix() + v.name + g.getName();
+                                      });
+
+                    groupEnv.print("CHECK_NCCL_ERRORS(ncclAllReduce($(_" + v.name + "), $(_" + v.name + "), $(_size)");
+                    groupEnv.printLine(", " + getNCCLType(v.type.resolve(cg.getTypeContext())) + ", " + getNCCLReductionType(getVarAccessMode(v.access)) + ", ncclCommunicator, 0));");
+                }
+            }
+
+            // Loop through variable references and generate reductions for those that are reduction targets
+            for(const auto &v : cm->getVarRefs()) {
+                if(v.access & VarAccessModeAttribute::REDUCE) {
+                    groupEnv.getStream() << "CHECK_NCCL_ERRORS(ncclAllReduce(group->" << v.name << ", group->" << v.name << ", group->size";
+                    groupEnv.getStream() << ", " << getNCCLType(v.type.resolve(cg.getTypeContext())) << ", " << getNCCLReductionType(v.access) << ", ncclCommunicator, 0));" << std::endl;
+                }
+            }
+        }
+    }

    //! Get the safe amount of constant cache we can use
    size_t getChosenDeviceSafeConstMemBytes() const
    {
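        // A worked example of this calculation (figures are illustrative
        // assumptions, not taken from the patch): on a device reporting 64 KiB
        // (65536 bytes) of __constant__ memory, reserving a constantCacheOverhead
        // of, say, 536 bytes for kernel parameters leaves 65000 bytes that the
        // merged group start ID arrays may occupy; anything beyond this budget
        // falls back to global memory.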
Does this one require a global host RNG for the specified model? - virtual bool isGlobalHostRNGRequired(const ModelSpecMerged &modelMerged) const = 0; + virtual bool isGlobalHostRNGRequired(const ModelSpecInternal &model) const = 0; //! Different backends use different RNGs for different things. Does this one require a global device RNG for the specified model? - virtual bool isGlobalDeviceRNGRequired(const ModelSpecMerged &modelMerged) const = 0; + virtual bool isGlobalDeviceRNGRequired(const ModelSpecInternal &model) const = 0; //! Different backends seed RNGs in different ways. Does this one initialise population RNGS on device? virtual bool isPopulationRNGInitialisedOnDevice() const = 0; diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 28795def1f..e886889097 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -50,8 +50,8 @@ const EnvironmentLibrary::Library doubleRandomFunctions = { //-------------------------------------------------------------------------- // CUDADeviceType //-------------------------------------------------------------------------- -const Type::ResolvedType CURandState = Type::ResolvedType::createValue("curandState"); -const Type::ResolvedType CURandStatePhilox43210 = Type::ResolvedType::createValue("curandStatePhilox4_32_10_t"); +const Type::ResolvedType CURandState = Type::ResolvedType::createValue("curandState", Type::Qualifier{0}, true); +const Type::ResolvedType CURandStatePhilox43210 = Type::ResolvedType::createValue("curandStatePhilox4_32_10_t", Type::Qualifier{0}, true); //-------------------------------------------------------------------------- // Timer @@ -148,35 +148,6 @@ void genMergedKernelDataStructures(CodeStream &os, size_t &totalConstMem, genGroupStartIDs(os, std::ref(idStart), std::ref(totalConstMem), args...); } //----------------------------------------------------------------------- -void genFilteredGroupStartIDs(CodeStream &, size_t&, size_t&) -{ -} -//----------------------------------------------------------------------- -template -void genFilteredGroupStartIDs(CodeStream &os, size_t &idStart, size_t &totalConstMem, - const std::vector &mergedGroups, G getPaddedNumThreads, F filter, - Args... args) -{ - // Loop through merged groups - for(const auto &m : mergedGroups) { - if(filter(m)) { - genGroupStartID(os, idStart, totalConstMem, m, getPaddedNumThreads); - } - } - - // Generate any remaining groups - genFilteredGroupStartIDs(os, idStart, totalConstMem, args...); -} -//----------------------------------------------------------------------- -template -void genFilteredMergedKernelDataStructures(CodeStream &os, size_t &totalConstMem, - Args... 
args) -{ - // Generate group start id arrays - size_t idStart = 0; - genFilteredGroupStartIDs(os, std::ref(idStart), std::ref(totalConstMem), args...); -} -//----------------------------------------------------------------------- template size_t getNumMergedGroupThreads(const std::vector &groups, G getNumThreads) { @@ -214,82 +185,6 @@ const EnvironmentLibrary::Library &getRNGFunctions(const Type::ResolvedType &pre return doubleRandomFunctions; } } -//----------------------------------------------------------------------- -std::string getNCCLReductionType(VarAccessMode mode) -{ - // Convert GeNN reduction types to NCCL - if(mode & VarAccessModeAttribute::MAX) { - return "ncclMax"; - } - else if(mode & VarAccessModeAttribute::SUM) { - return "ncclSum"; - } - else { - throw std::runtime_error("Reduction type unsupported by NCCL"); - } -} -//----------------------------------------------------------------------- -std::string getNCCLType(const Type::ResolvedType &type) -{ - assert(type.isNumeric()); - - // Convert GeNN types to NCCL types - if(type == Type::Int8) { - return "ncclInt8"; - } - else if(type == Type::Uint8) { - return "ncclUint8"; - } - else if(type == Type::Int32) { - return "ncclInt32"; - } - else if(type == Type::Uint32){ - return "ncclUint32"; - } - /*else if(type == "half") { - return "ncclFloat16"; - }*/ - else if(type == Type::Float){ - return "ncclFloat32"; - } - else if(type == Type::Double) { - return "ncclFloat64"; - } - else { - throw std::runtime_error("Data type '" + type.getName() + "' unsupported by NCCL"); - } -} -//----------------------------------------------------------------------- -template -void genNCCLReduction(CodeStream &os, const G &cg) -{ - CodeStream::Scope b(os); - os << "// merged custom update host reduction group " << cg.getIndex() << std::endl; - os << "for(unsigned int g = 0; g < " << cg.getGroups().size() << "; g++)"; - { - CodeStream::Scope b(os); - - // Get reference to group - os << "const auto *group = &merged" << G::name << "Group" << cg.getIndex() << "[g]; " << std::endl; - - // Loop through variables and add pointers if they are reduction targets - const CustomUpdateModels::Base *cm = cg.getArchetype().getCustomUpdateModel(); - for(const auto &v : cm->getVars()) { - if(v.access & VarAccessModeAttribute::REDUCE) { - os << "CHECK_NCCL_ERRORS(ncclAllReduce(group->" << v.name << ", group->" << v.name << ", group->size"; - os << ", " << getNCCLType(v.type, cg.getTypeContext()) << ", " << getNCCLReductionType(getVarAccessMode(v.access)) << ", ncclCommunicator, 0)); " << std::endl; - } - } - - // Loop through variable references and add pointers if they are reduction targets - for(const auto &v : cm->getVarRefs()) { - if(v.access & VarAccessModeAttribute::REDUCE) { - os << "CHECK_NCCL_ERRORS(ncclAllReduce(group->" << v.name << ", group->" << v.name << ", group->size"; - os << ", " << getNCCLType(v.type, cg.getTypeContext()) << ", " << getNCCLReductionType(v.access) << ", ncclCommunicator, 0));" << std::endl; - } - } - } -} } // Anonymous namespace //-------------------------------------------------------------------------- @@ -427,7 +322,9 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // If any neuron groups require their previous spike times updating size_t idNeuronPrevSpikeTimeUpdate = 0; - //if(!modelMerged.getMergedNeuronPrevSpikeTimeUpdateGroups().empty()) { + if(std::any_of(model.getNeuronGroups().cbegin(), model.getNeuronGroups().cend(), + [](const auto &ng){ return (ng.isPrevSpikeTimeRequired() || 
ng.isPrevSpikeEventTimeRequired()); }))
+    {
        neuronUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelNeuronPrevSpikeTimeUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)";
        {
            CodeStream::Scope b(neuronUpdateEnv.getStream());
@@ -447,7 +344,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host
            genNeuronPrevSpikeTimeUpdateKernel(funcEnv, modelMerged, idNeuronPrevSpikeTimeUpdate);
        }
        neuronUpdateEnv.getStream() << std::endl;
-    //}
+    }

    // Generate reset kernel to be run before the neuron kernel
    size_t idNeuronSpikeQueueUpdate = 0;
@@ -563,6 +460,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos
    // If any synapse groups require dendritic delay, a reset kernel is required to be run before the synapse kernel
    const ModelSpecInternal &model = modelMerged.getModel();
    size_t idSynapseDendricDelayUpdate = 0;
+    //**TODO** slightly tricky check to do on models
    //if(!modelMerged.getMergedSynapseDendriticDelayUpdateGroups().empty()) {
        synapseUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelSynapseDendriticDelayUpdate] << "()";
        {
@@ -576,7 +474,9 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos
    // If there are any presynaptic update groups
    size_t idPresynapticStart = 0;
-    //if(!modelMerged.getMergedPresynapticUpdateGroups().empty()) {
+    if(std::any_of(model.getSynapseGroups().cbegin(), model.getSynapseGroups().cend(),
+                   [](const auto &sg){ return (sg.isSpikeEventRequired() || sg.isTrueSpikeRequired()); }))
+    {
        synapseUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelPresynapticUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; // end of synapse kernel header
        {
            CodeStream::Scope b(synapseUpdateEnv.getStream());
@@ -597,11 +497,13 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos
            EnvironmentLibrary rngEnv(funcEnv, getRNGFunctions(model.getPrecision()));
            genPresynapticUpdateKernel(rngEnv, modelMerged, idPresynapticStart);
        }
-    //}
+    }

    // If any synapse groups require postsynaptic learning
    size_t idPostsynapticStart = 0;
-    if(!modelMerged.getMergedPostsynapticUpdateGroups().empty()) {
+    if(std::any_of(model.getSynapseGroups().cbegin(), model.getSynapseGroups().cend(),
+                   [](const auto &sg){ return !Utils::areTokensEmpty(sg.getWUPostLearnCodeTokens()); }))
+    {
        synapseUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelPostsynapticUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl;
        {
            CodeStream::Scope b(synapseUpdateEnv.getStream());
@@ -620,26 +522,28 @@
            genPostsynapticUpdateKernel(funcEnv, modelMerged, idPostsynapticStart);
        }
    }
-
+
+    // If any synapse groups require synapse dynamics
    size_t idSynapseDynamicsStart = 0;
-    if(!modelMerged.getMergedSynapseDynamicsGroups().empty()) {
+    if(std::any_of(model.getSynapseGroups().cbegin(), model.getSynapseGroups().cend(),
+                   [](const auto &sg){ return !Utils::areTokensEmpty(sg.getWUSynapseDynamicsCodeTokens()); }))
+    {
        synapseUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelSynapseDynamicsUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; // end of synapse kernel header
        {
            CodeStream::Scope b(synapseUpdateEnv.getStream());
            EnvironmentExternal funcEnv(synapseUpdateEnv);
-            Substitutions kernelSubs(getFunctionTemplates(model.getPrecision()));
-            funcEnv.add("t", "t");
+            funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t");

-            os << "const unsigned int id = " << getKernelBlockSize(KernelSynapseDynamicsUpdate) << " * blockIdx.x + threadIdx.x;" << std::endl;
+            funcEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelSynapseDynamicsUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl;
            if(model.getBatchSize() > 1) {
-                os << "const unsigned int batch = blockIdx.y;" << std::endl;
-                kernelSubs.addVarSubstitution("batch", "batch");
+                funcEnv.getStream() << "const unsigned int batch = blockIdx.y;" << std::endl;
+                funcEnv.add(Type::Uint32.addConst(), "batch", "batch");
            }
            else {
-                kernelSubs.addVarSubstitution("batch", "0");
+                funcEnv.add(Type::Uint32.addConst(), "batch", "0");
            }
-            genSynapseDynamicsKernel(os, kernelSubs, modelMerged, idSynapseDynamicsStart);
+            genSynapseDynamicsKernel(funcEnv, modelMerged, idSynapseDynamicsStart);
        }
    }
@@ -649,44 +553,44 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos
    // Launch pre-synapse reset kernel if required
    if(idSynapseDendricDelayUpdate > 0) {
-        CodeStream::Scope b(os);
-        genKernelDimensions(os, KernelSynapseDendriticDelayUpdate, idSynapseDendricDelayUpdate, 1);
-        os << KernelNames[KernelSynapseDendriticDelayUpdate] << "<<<grid, threads>>>();" << std::endl;
-        os << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl;
+        CodeStream::Scope b(synapseUpdateEnv.getStream());
+        genKernelDimensions(synapseUpdateEnv.getStream(), KernelSynapseDendriticDelayUpdate, idSynapseDendricDelayUpdate, 1);
+        synapseUpdateEnv.getStream() << KernelNames[KernelSynapseDendriticDelayUpdate] << "<<<grid, threads>>>();" << std::endl;
+        synapseUpdateEnv.getStream() << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl;
    }

    // Launch synapse dynamics kernel if required
    if(idSynapseDynamicsStart > 0) {
-        CodeStream::Scope b(os);
-        Timer t(os, "synapseDynamics", model.isTimingEnabled());
+        CodeStream::Scope b(synapseUpdateEnv.getStream());
+        Timer t(synapseUpdateEnv.getStream(), "synapseDynamics", model.isTimingEnabled());

-        genKernelDimensions(os, KernelSynapseDynamicsUpdate, idSynapseDynamicsStart, model.getBatchSize());
-        os << KernelNames[KernelSynapseDynamicsUpdate] << "<<<grid, threads>>>(t);" << std::endl;
-        os << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl;
+        genKernelDimensions(synapseUpdateEnv.getStream(), KernelSynapseDynamicsUpdate, idSynapseDynamicsStart, model.getBatchSize());
+        synapseUpdateEnv.getStream() << KernelNames[KernelSynapseDynamicsUpdate] << "<<<grid, threads>>>(t);" << std::endl;
+        synapseUpdateEnv.getStream() << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl;
    }

    // Launch presynaptic update kernel
    if(idPresynapticStart > 0) {
        CodeStream::Scope b(os);
-        Timer t(os, "presynapticUpdate", model.isTimingEnabled());
+        Timer t(synapseUpdateEnv.getStream(), "presynapticUpdate", model.isTimingEnabled());

-        genKernelDimensions(os, KernelPresynapticUpdate, idPresynapticStart, model.getBatchSize());
-        os << KernelNames[KernelPresynapticUpdate] << "<<<grid, threads>>>(t);" << std::endl;
-        os << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl;
+        genKernelDimensions(synapseUpdateEnv.getStream(), KernelPresynapticUpdate, idPresynapticStart, model.getBatchSize());
+        synapseUpdateEnv.getStream() << KernelNames[KernelPresynapticUpdate] << "<<<grid, threads>>>(t);" << std::endl;
+        synapseUpdateEnv.getStream() << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl;
    }
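    // For reference, each launch block here expands to generated CUDA of roughly
    // this shape (an illustrative sketch, not verbatim backend output; the kernel
    // name and sizes are made-up, the dim3 names come from genKernelDimensions):
    //   const dim3 threads(32, 1);
    //   const dim3 grid(128, 1);
    //   updatePresynapticKernel<<<grid, threads>>>(t);
    //   CHECK_CUDA_ERRORS(cudaPeekAtLastError());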
    // Launch postsynaptic update kernel
    if(idPostsynapticStart > 0) {
-        CodeStream::Scope b(os);
-        Timer t(os, "postsynapticUpdate", model.isTimingEnabled());
+        CodeStream::Scope b(synapseUpdateEnv.getStream());
+        Timer t(synapseUpdateEnv.getStream(), "postsynapticUpdate", model.isTimingEnabled());

-        genKernelDimensions(os, KernelPostsynapticUpdate, idPostsynapticStart, model.getBatchSize());
-        os << KernelNames[KernelPostsynapticUpdate] << "<<<grid, threads>>>(t);" << std::endl;
-        os << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl;
+        genKernelDimensions(synapseUpdateEnv.getStream(), KernelPostsynapticUpdate, idPostsynapticStart, model.getBatchSize());
+        synapseUpdateEnv.getStream() << KernelNames[KernelPostsynapticUpdate] << "<<<grid, threads>>>(t);" << std::endl;
+        synapseUpdateEnv.getStream() << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl;
    }
}

-    // Generate struct definitions
+    // Generate struct definitions
    modelMerged.genMergedSynapseDendriticDelayUpdateStructs(os, *this);
    modelMerged.genMergedPresynapticUpdateGroupStructs(os, *this);
    modelMerged.genMergedPostsynapticUpdateGroupStructs(os, *this);
@@ -718,35 +622,18 @@
}
//--------------------------------------------------------------------------
-void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const
+void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const
{
    const ModelSpecInternal &model = modelMerged.getModel();

-    // Generate struct definitions
-    modelMerged.genMergedCustomUpdateStructs(os, *this);
-    modelMerged.genMergedCustomUpdateWUStructs(os, *this);
-    modelMerged.genMergedCustomUpdateTransposeWUStructs(os, *this);
-    modelMerged.genMergedCustomConnectivityUpdateStructs(os, *this);
-
-    // Generate arrays of merged structs and functions to push them
-    genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateGroups());
-    genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateWUGroups());
-    genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateTransposeWUGroups());
-    genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdateGroups());
-
-    // Generate preamble
-    preambleHandler(os);
+    // Generate stream with custom update code
+    std::ostringstream customUpdateStream;
+    CodeStream customUpdate(customUpdateStream);

-    // Generate data structure for accessing merged groups
-    // **NOTE** constant cache is preferentially given to neuron and synapse groups as, typically, they are launched more often
-    // than custom update kernels so subtract constant memory requirements of synapse group start ids from total constant memory
-    const size_t timestepGroupStartIDSize = (getGroupStartIDSize(modelMerged.getMergedPresynapticUpdateGroups()) +
-                                             getGroupStartIDSize(modelMerged.getMergedPostsynapticUpdateGroups()) +
-                                             getGroupStartIDSize(modelMerged.getMergedSynapseDynamicsGroups()) +
-                                             getGroupStartIDSize(modelMerged.getMergedNeuronUpdateGroups()));
-    size_t totalConstMem = (getChosenDeviceSafeConstMemBytes() > timestepGroupStartIDSize) ?
(getChosenDeviceSafeConstMemBytes() - timestepGroupStartIDSize) : 0; + // Begin environment with standard library + EnvironmentLibrary customUpdateEnv(customUpdate, StandardLibrary::getMathsFunctions()); - // Build set containing union of all custom update groupsnames + // Build set containing union of all custom update group names std::set customUpdateGroups; std::transform(model.getCustomUpdates().cbegin(), model.getCustomUpdates().cend(), std::inserter(customUpdateGroups, customUpdateGroups.end()), @@ -762,116 +649,100 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged for(const auto &g : customUpdateGroups) { // Generate kernel size_t idCustomUpdateStart = 0; - if(std::any_of(modelMerged.getMergedCustomUpdateGroups().cbegin(), modelMerged.getMergedCustomUpdateGroups().cend(), - [&g](const CustomUpdateGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g); }) - || std::any_of(modelMerged.getMergedCustomUpdateWUGroups().cbegin(), modelMerged.getMergedCustomUpdateWUGroups().cend(), - [&g](const CustomUpdateWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g); }) - || std::any_of(modelMerged.getMergedCustomConnectivityUpdateGroups().cbegin(), modelMerged.getMergedCustomConnectivityUpdateGroups().cend(), - [&g](const CustomConnectivityUpdateGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g); })) + if(std::any_of(model.getCustomUpdates().cbegin(), model.getCustomUpdates().cend(), + [&g](const auto &cg) { return (cg.getUpdateGroupName() == g); }) + || std::any_of(model.getCustomWUUpdates().cbegin(), model.getCustomWUUpdates().cend(), + [&g](const auto &cg) { return (!cg.isTransposeOperation() && cg.getUpdateGroupName() == g); }) + || std::any_of(model.getCustomConnectivityUpdates().cbegin(), model.getCustomConnectivityUpdates().cend(), + [&g](const auto &cg) { return (!Utils::areTokensEmpty(cg.getRowUpdateCodeTokens()) && cg.getUpdateGroupName() == g); })) { - genFilteredMergedKernelDataStructures(os, totalConstMem, - modelMerged.getMergedCustomUpdateGroups(), - [&model, this](const CustomUpdateInternal &cg){ return getPaddedNumCustomUpdateThreads(cg, model.getBatchSize()); }, - [g](const CustomUpdateGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; }, - - modelMerged.getMergedCustomUpdateWUGroups(), - [&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateWUThreads(cg, model.getBatchSize()); }, - [g](const CustomUpdateWUGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; }, - - modelMerged.getMergedCustomConnectivityUpdateGroups(), - [this](const CustomConnectivityUpdateInternal &cg){ return padKernelSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelCustomUpdate); }, - [g](const CustomConnectivityUpdateGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; }); - - os << "extern \"C\" __global__ void " << KernelNames[KernelCustomUpdate] << g << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; + customUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelCustomUpdate] << g << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; { - CodeStream::Scope b(os); + CodeStream::Scope b(customUpdateEnv.getStream()); - Substitutions kernelSubs(getFunctionTemplates(model.getPrecision())); - kernelSubs.addVarSubstitution("t", "t"); + EnvironmentExternal funcEnv(customUpdateEnv); + 
funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); - os << "const unsigned int id = " << getKernelBlockSize(KernelCustomUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; + funcEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelCustomUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Custom updates" << std::endl; - genCustomUpdateKernel(os, kernelSubs, modelMerged, g, idCustomUpdateStart); + funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; + funcEnv.getStream() << "// Custom updates" << std::endl; + genCustomUpdateKernel(funcEnv, modelMerged, g, idCustomUpdateStart); - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Custom WU updates" << std::endl; - genCustomUpdateWUKernel(os, kernelSubs, modelMerged, g, idCustomUpdateStart); + funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; + funcEnv.getStream() << "// Custom WU updates" << std::endl; + genCustomUpdateWUKernel(funcEnv, modelMerged, g, idCustomUpdateStart); - os << "// ------------------------------------------------------------------------" << std::endl; - os << "// Custom connectivity updates" << std::endl; - genCustomConnectivityUpdateKernel(os, kernelSubs, modelMerged, g, idCustomUpdateStart); + funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; + funcEnv.getStream() << "// Custom connectivity updates" << std::endl; + genCustomConnectivityUpdateKernel(funcEnv, modelMerged, g, idCustomUpdateStart); } } size_t idCustomTransposeUpdateStart = 0; - if(std::any_of(modelMerged.getMergedCustomUpdateTransposeWUGroups().cbegin(), modelMerged.getMergedCustomUpdateTransposeWUGroups().cend(), - [&g](const CustomUpdateTransposeWUGroupMerged &c){ return (c.getArchetype().getUpdateGroupName() == g); })) + if(std::any_of(model.getCustomWUUpdates().cbegin(), model.getCustomWUUpdates().cend(), + [&g](const auto &cg){ return (cg.isTransposeOperation() && cg.getUpdateGroupName() == g); })) { - genFilteredMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedCustomUpdateTransposeWUGroups(), - [&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); }, - [g](const CustomUpdateTransposeWUGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; }); - - os << "extern \"C\" __global__ void " << KernelNames[KernelCustomTransposeUpdate] << g << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; + customUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelCustomTransposeUpdate] << g << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; { - CodeStream::Scope b(os); + CodeStream::Scope b(customUpdateEnv.getStream()); - Substitutions kernelSubs(getFunctionTemplates(model.getPrecision())); - kernelSubs.addVarSubstitution("t", "t"); + EnvironmentExternal funcEnv(customUpdateEnv); + funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); - os << "const unsigned int id = " << getKernelBlockSize(KernelCustomTransposeUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; + funcEnv.getStream() << "const unsigned int id = " << 
getKernelBlockSize(KernelCustomTransposeUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl;

-            os << "// ------------------------------------------------------------------------" << std::endl;
-            os << "// Custom WU transpose updates" << std::endl;
-            genCustomTransposeUpdateWUKernel(os, kernelSubs, modelMerged, g, idCustomTransposeUpdateStart);
+            funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl;
+            funcEnv.getStream() << "// Custom WU transpose updates" << std::endl;
+            genCustomTransposeUpdateWUKernel(funcEnv, modelMerged, g, idCustomTransposeUpdateStart);
        }
    }

-    os << "void update" << g << "()";
+    customUpdateEnv.getStream() << "void update" << g << "()";
    {
-        CodeStream::Scope b(os);
+        CodeStream::Scope b(customUpdateEnv.getStream());

        // Loop through host update groups and generate code for those in this custom update group
-        for (const auto &cg : modelMerged.getMergedCustomConnectivityHostUpdateGroups()) {
-            if (cg.getArchetype().getUpdateGroupName() == g) {
-                cg.generateUpdate(*this, os);
-            }
-        }
+        modelMerged.genMergedCustomConnectivityHostUpdateGroups(
+            *this, g,
+            [this, &customUpdateEnv, &modelMerged](auto &c)
+            {
+                c.generateUpdate(*this, customUpdateEnv, modelMerged);
+            });

        // Launch custom update kernel if required
        if(idCustomUpdateStart > 0) {
-            CodeStream::Scope b(os);
-            genKernelDimensions(os, KernelCustomUpdate, idCustomUpdateStart, 1);
-            Timer t(os, "customUpdate" + g, model.isTimingEnabled());
-            os << KernelNames[KernelCustomUpdate] << g << "<<<grid, threads>>>(t);" << std::endl;
-            os << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl;
+            CodeStream::Scope b(customUpdateEnv.getStream());
+            genKernelDimensions(customUpdateEnv.getStream(), KernelCustomUpdate, idCustomUpdateStart, 1);
+            Timer t(customUpdateEnv.getStream(), "customUpdate" + g, model.isTimingEnabled());
+            customUpdateEnv.getStream() << KernelNames[KernelCustomUpdate] << g << "<<<grid, threads>>>(t);" << std::endl;
+            customUpdateEnv.getStream() << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl;
        }

        // Launch custom transpose update kernel if required
        if(idCustomTransposeUpdateStart > 0) {
-            CodeStream::Scope b(os);
+            CodeStream::Scope b(customUpdateEnv.getStream());
            // **TODO** make block height parameterizable
-            genKernelDimensions(os, KernelCustomTransposeUpdate, idCustomTransposeUpdateStart, 1, 8);
-            Timer t(os, "customUpdate" + g + "Transpose", model.isTimingEnabled());
-            os << KernelNames[KernelCustomTransposeUpdate] << g << "<<<grid, threads>>>(t);" << std::endl;
-            os << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl;
+            genKernelDimensions(customUpdateEnv.getStream(), KernelCustomTransposeUpdate, idCustomTransposeUpdateStart, 1, 8);
+            Timer t(customUpdateEnv.getStream(), "customUpdate" + g + "Transpose", model.isTimingEnabled());
+            customUpdateEnv.getStream() << KernelNames[KernelCustomTransposeUpdate] << g << "<<<grid, threads>>>(t);" << std::endl;
+            customUpdateEnv.getStream() << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl;
        }

        // If NCCL reductions are enabled
        if(getPreferences().enableNCCLReductions) {
            // Loop through custom update host reduction groups and
            // generate reductions for those in this custom update group
-            for(const auto &cg : modelMerged.getMergedCustomUpdateHostReductionGroups()) {
+            for(auto &cg : modelMerged.getMergedCustomUpdateHostReductionGroups()) {
                if(cg.getArchetype().getUpdateGroupName() == g) {
-                    genNCCLReduction(os, cg);
+                    genNCCLReduction(customUpdateEnv, cg);
                }
            }

            // Loop through custom WU update host reduction groups and
            // generate 
reductions for those in this custom update group - for(const auto &cg : modelMerged.getMergedCustomWUUpdateHostReductionGroups()) { + for(auto &cg : modelMerged.getMergedCustomWUUpdateHostReductionGroups()) { if(cg.getArchetype().getUpdateGroupName() == g) { - genNCCLReduction(os, cg); + genNCCLReduction(customUpdateEnv, cg); } } } @@ -879,164 +750,155 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged // If timing is enabled if(model.isTimingEnabled()) { // Synchronise last event - os << "CHECK_CUDA_ERRORS(cudaEventSynchronize(customUpdate" << g; + customUpdateEnv.getStream() << "CHECK_CUDA_ERRORS(cudaEventSynchronize(customUpdate" << g; if(idCustomTransposeUpdateStart > 0) { - os << "Transpose"; + customUpdateEnv.getStream() << "Transpose"; } - os << "Stop)); " << std::endl; + customUpdateEnv.getStream() << "Stop)); " << std::endl; if(idCustomUpdateStart > 0) { - CodeGenerator::CodeStream::Scope b(os); - os << "float tmp;" << std::endl; - os << "CHECK_CUDA_ERRORS(cudaEventElapsedTime(&tmp, customUpdate" << g << "Start, customUpdate" << g << "Stop));" << std::endl; - os << "customUpdate" << g << "Time += tmp / 1000.0;" << std::endl; + CodeGenerator::CodeStream::Scope b(customUpdateEnv.getStream()); + customUpdateEnv.getStream() << "float tmp;" << std::endl; + customUpdateEnv.getStream() << "CHECK_CUDA_ERRORS(cudaEventElapsedTime(&tmp, customUpdate" << g << "Start, customUpdate" << g << "Stop));" << std::endl; + customUpdateEnv.getStream() << "customUpdate" << g << "Time += tmp / 1000.0;" << std::endl; } if(idCustomTransposeUpdateStart > 0) { - CodeGenerator::CodeStream::Scope b(os); - os << "float tmp;" << std::endl; - os << "CHECK_CUDA_ERRORS(cudaEventElapsedTime(&tmp, customUpdate" << g << "TransposeStart, customUpdate" << g << "TransposeStop));" << std::endl; - os << "customUpdate" << g << "TransposeTime += tmp / 1000.0;" << std::endl; + CodeGenerator::CodeStream::Scope b(customUpdateEnv.getStream()); + customUpdateEnv.getStream() << "float tmp;" << std::endl; + customUpdateEnv.getStream() << "CHECK_CUDA_ERRORS(cudaEventElapsedTime(&tmp, customUpdate" << g << "TransposeStart, customUpdate" << g << "TransposeStop));" << std::endl; + customUpdateEnv.getStream() << "customUpdate" << g << "TransposeTime += tmp / 1000.0;" << std::endl; } } } } -} -//-------------------------------------------------------------------------- -void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHandler preambleHandler) const -{ - os << "#include " << std::endl; - os << "#include " << std::endl; - os << "#include " << std::endl; - os << std::endl; // Generate struct definitions - modelMerged.genMergedNeuronInitGroupStructs(os, *this); - modelMerged.genMergedSynapseInitGroupStructs(os, *this); - modelMerged.genMergedSynapseConnectivityInitGroupStructs(os, *this); - modelMerged.genMergedSynapseSparseInitGroupStructs(os, *this); - modelMerged.genMergedCustomUpdateInitGroupStructs(os, *this); - modelMerged.genMergedCustomWUUpdateInitGroupStructs(os, *this); - modelMerged.genMergedCustomWUUpdateSparseInitGroupStructs(os, *this); - modelMerged.genMergedCustomConnectivityUpdatePreInitStructs(os, *this); - modelMerged.genMergedCustomConnectivityUpdatePostInitStructs(os, *this); - modelMerged.genMergedCustomConnectivityUpdateSparseInitStructs(os, *this); + modelMerged.genMergedCustomUpdateStructs(os, *this); + modelMerged.genMergedCustomUpdateWUStructs(os, *this); + modelMerged.genMergedCustomUpdateTransposeWUStructs(os, *this); + 
modelMerged.genMergedCustomConnectivityUpdateStructs(os, *this); // Generate arrays of merged structs and functions to push them - genMergedStructArrayPush(os, modelMerged.getMergedNeuronInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedSynapseInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedSynapseConnectivityInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedSynapseSparseInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomWUUpdateInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomWUUpdateSparseInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdatePreInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdatePostInitGroups()); - genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups()); - + genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateWUGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateTransposeWUGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdateGroups()); + // Generate preamble preambleHandler(os); - // Generate data structure for accessing merged groups from within initialisation kernel - // **NOTE** pass in zero constant cache here as it's precious and would be wasted on init kernels which are only launched once - const ModelSpecInternal &model = modelMerged.getModel(); - size_t totalConstMem = 0; + // Generate data structure for accessing merged groups + // **THINK** I don't think there was any need for these to be filtered + // **NOTE** constant cache is preferentially given to neuron and synapse groups as, typically, they are launched more often + // than custom update kernels so subtract constant memory requirements of synapse group start ids from total constant memory + const size_t timestepGroupStartIDSize = (getGroupStartIDSize(modelMerged.getMergedPresynapticUpdateGroups()) + + getGroupStartIDSize(modelMerged.getMergedPostsynapticUpdateGroups()) + + getGroupStartIDSize(modelMerged.getMergedSynapseDynamicsGroups()) + + getGroupStartIDSize(modelMerged.getMergedNeuronUpdateGroups())); + size_t totalConstMem = (getChosenDeviceSafeConstMemBytes() > timestepGroupStartIDSize) ? 
(getChosenDeviceSafeConstMemBytes() - timestepGroupStartIDSize) : 0; + const unsigned int batchSize = model.getBatchSize(); genMergedKernelDataStructures( - os, totalConstMem, - modelMerged.getMergedNeuronInitGroups(), [this](const NeuronGroupInternal &ng){ return padKernelSize(ng.getNumNeurons(), KernelInitialize); }, - modelMerged.getMergedSynapseInitGroups(), [this](const SynapseGroupInternal &sg){ return padKernelSize(getNumInitThreads(sg), KernelInitialize); }, - modelMerged.getMergedCustomUpdateInitGroups(), [this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelInitialize); }, - modelMerged.getMergedCustomConnectivityUpdatePreInitGroups(), [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelInitialize); }, - modelMerged.getMergedCustomConnectivityUpdatePostInitGroups(), [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons(), KernelInitialize); }, - modelMerged.getMergedCustomWUUpdateInitGroups(), [this](const CustomUpdateWUInternal &cg){ return padKernelSize(getNumInitThreads(cg), KernelInitialize); }, - modelMerged.getMergedSynapseConnectivityInitGroups(), [this](const SynapseGroupInternal &sg){ return padKernelSize(getNumConnectivityInitThreads(sg), KernelInitialize); }); + os, totalConstMem, + modelMerged.getMergedCustomUpdateGroups(), [batchSize, this](const CustomUpdateInternal &cg){ return getPaddedNumCustomUpdateThreads(cg, batchSize); }, + modelMerged.getMergedCustomUpdateWUGroups(), [batchSize, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateWUThreads(cg, batchSize); }, + modelMerged.getMergedCustomConnectivityUpdateGroups(), [this](const CustomConnectivityUpdateInternal &cg){ return padKernelSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelCustomUpdate); }, + modelMerged.getMergedCustomUpdateTransposeWUGroups(), [batchSize, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateTransposeWUThreads(cg, batchSize); }); - // Generate data structure for accessing merged groups from within sparse initialisation kernel - genMergedKernelDataStructures( - os, totalConstMem, - modelMerged.getMergedSynapseSparseInitGroups(), [this](const SynapseGroupInternal &sg){ return padKernelSize(sg.getMaxConnections(), KernelInitializeSparse); }, - modelMerged.getMergedCustomWUUpdateSparseInitGroups(), [this](const CustomUpdateWUInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }, - modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups(), [this](const CustomConnectivityUpdateInternal &cg){ return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }); - os << std::endl; + os << customUpdateStream.str(); +} +//-------------------------------------------------------------------------- +void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +{ + const ModelSpecInternal &model = modelMerged.getModel(); + + // Generate stream with synapse update code + std::ostringstream initStream; + CodeStream init(initStream); + + // Begin environment with standard library + EnvironmentLibrary initEnv(init, StandardLibrary::getMathsFunctions()); // If device RNG is required, generate kernel to initialise it - if(isGlobalDeviceRNGRequired(modelMerged)) { - os << "extern \"C\" __global__ void initializeRNGKernel(unsigned long 
long deviceRNGSeed)"; + if(isGlobalDeviceRNGRequired(model)) { + initEnv.getStream() << "extern \"C\" __global__ void initializeRNGKernel(unsigned long long deviceRNGSeed)"; { - CodeStream::Scope b(os); - os << "if(threadIdx.x == 0)"; + CodeStream::Scope b(initEnv.getStream()); + initEnv.getStream() << "if(threadIdx.x == 0)"; { - CodeStream::Scope b(os); - os << "curand_init(deviceRNGSeed, 0, 0, &d_rng);" << std::endl; + CodeStream::Scope b(initEnv.getStream()); + initEnv.getStream() << "curand_init(deviceRNGSeed, 0, 0, &d_rng);" << std::endl; } } - os << std::endl; + initEnv.getStream() << std::endl; } // init kernel header - os << "extern \"C\" __global__ void " << KernelNames[KernelInitialize] << "(unsigned long long deviceRNGSeed)"; + initEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelInitialize] << "(unsigned long long deviceRNGSeed)"; // initialization kernel code size_t idInitStart = 0; { - Substitutions kernelSubs(getFunctionTemplates(model.getPrecision())); - // common variables for all cases - CodeStream::Scope b(os); + CodeStream::Scope b(initEnv.getStream()); - os << "const unsigned int id = " << getKernelBlockSize(KernelInitialize) << " * blockIdx.x + threadIdx.x;" << std::endl; - genInitializeKernel(os, kernelSubs, modelMerged, idInitStart); + initEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelInitialize) << " * blockIdx.x + threadIdx.x;" << std::endl; + genInitializeKernel(initEnv, modelMerged, idInitStart); } const size_t numStaticInitThreads = idInitStart; + /*((sg.getMatrixType() & SynapseMatrixConnectivity::SPARSE) && + (sg.isWUVarInitRequired() + || (backend.isPostsynapticRemapRequired() && !sg.getWUModel()->getLearnPostCode().empty())));*/ + // (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) && cg.isVarInitRequired(); + // return cg.isVarInitRequired(); // Sparse initialization kernel code size_t idSparseInitStart = 0; + //if(std::any_of(model.getSynapseGroups().cbegin(), model.getSynapseGroups().cend(), + // [](const auto &sg){}) if(!modelMerged.getMergedSynapseSparseInitGroups().empty() || !modelMerged.getMergedCustomWUUpdateSparseInitGroups().empty() || !modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups().empty()) { - os << "extern \"C\" __global__ void " << KernelNames[KernelInitializeSparse] << "()"; + initEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelInitializeSparse] << "()"; { - CodeStream::Scope b(os); - - // common variables for all cases - Substitutions kernelSubs(getFunctionTemplates(model.getPrecision())); + CodeStream::Scope b(initEnv.getStream()); - os << "const unsigned int id = " << getKernelBlockSize(KernelInitializeSparse) << " * blockIdx.x + threadIdx.x;" << std::endl; - genInitializeSparseKernel(os, kernelSubs, modelMerged, numStaticInitThreads, idSparseInitStart); + initEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelInitializeSparse) << " * blockIdx.x + threadIdx.x;" << std::endl; + genInitializeSparseKernel(initEnv, modelMerged, numStaticInitThreads, idSparseInitStart); } } - os << "void initialize()"; + initEnv.getStream() << "void initialize()"; { - CodeStream::Scope b(os); + CodeStream::Scope b(initEnv.getStream()); - os << "unsigned long long deviceRNGSeed = 0;" << std::endl; + initEnv.getStream() << "unsigned long long deviceRNGSeed = 0;" << std::endl; // If any sort of on-device global RNG is required const bool simRNGRequired = std::any_of(model.getNeuronGroups().cbegin(), 
model.getNeuronGroups().cend(), [](const ModelSpec::NeuronGroupValueType &n) { return n.second.isSimRNGRequired(); }); - const bool globalDeviceRNGRequired = isGlobalDeviceRNGRequired(modelMerged); + const bool globalDeviceRNGRequired = isGlobalDeviceRNGRequired(model); if(simRNGRequired || globalDeviceRNGRequired) { // If no seed is specified if (model.getSeed() == 0) { - CodeStream::Scope b(os); + CodeStream::Scope b(initEnv.getStream()); // Use system randomness to generate one unsigned long long worth of seed words - os << "std::random_device seedSource;" << std::endl; - os << "uint32_t *deviceRNGSeedWord = reinterpret_cast<uint32_t*>(&deviceRNGSeed);" << std::endl; - os << "for(int i = 0; i < " << sizeof(unsigned long long) / sizeof(uint32_t) << "; i++)"; + initEnv.getStream() << "std::random_device seedSource;" << std::endl; + initEnv.getStream() << "uint32_t *deviceRNGSeedWord = reinterpret_cast<uint32_t*>(&deviceRNGSeed);" << std::endl; + initEnv.getStream() << "for(int i = 0; i < " << sizeof(unsigned long long) / sizeof(uint32_t) << "; i++)"; { - CodeStream::Scope b(os); - os << "deviceRNGSeedWord[i] = seedSource();" << std::endl; + CodeStream::Scope b(initEnv.getStream()); + initEnv.getStream() << "deviceRNGSeedWord[i] = seedSource();" << std::endl; } } // Otherwise, use model seed else { - os << "deviceRNGSeed = " << model.getSeed() << ";" << std::endl; + initEnv.getStream() << "deviceRNGSeed = " << model.getSeed() << ";" << std::endl; } // If global RNG is required, launch kernel to initialize it if (globalDeviceRNGRequired) { - os << "initializeRNGKernel<<<1, 1>>>(deviceRNGSeed);" << std::endl; - os << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl; + initEnv.getStream() << "initializeRNGKernel<<<1, 1>>>(deviceRNGSeed);" << std::endl; + initEnv.getStream() << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl; } } @@ -1047,56 +909,110 @@ void Backend::genInit(CodeStream &os, const ModelSpecMerged &modelMerged, HostHa // If this synapse population has BITMASK connectivity and is initialised on device, insert a call to cudaMemset to zero the whole bitmask if(s.second.isSparseConnectivityInitRequired() && s.second.getMatrixType() & SynapseMatrixConnectivity::BITMASK) { const size_t gpSize = ceilDivide((size_t)s.second.getSrcNeuronGroup()->getNumNeurons() * getSynapticMatrixRowStride(s.second), 32); - os << "CHECK_CUDA_ERRORS(cudaMemset(d_gp" << s.first << ", 0, " << gpSize << " * sizeof(uint32_t)));" << std::endl; + initEnv.getStream() << "CHECK_CUDA_ERRORS(cudaMemset(d_gp" << s.first << ", 0, " << gpSize << " * sizeof(uint32_t)));" << std::endl; } // If this synapse population has SPARSE connectivity and column-based on device connectivity, insert a call to cudaMemset to zero row lengths // **NOTE** we could also use this code path for row-based connectivity but, leaving this in the kernel is much better as it gets merged if(s.second.getMatrixType() & SynapseMatrixConnectivity::SPARSE && !s.second.getConnectivityInitialiser().getSnippet()->getColBuildCode().empty()) { - os << "CHECK_CUDA_ERRORS(cudaMemset(d_rowLength" << s.first << ", 0, " << s.second.getSrcNeuronGroup()->getNumNeurons() << " * sizeof(unsigned int)));" << std::endl; + initEnv.getStream() << "CHECK_CUDA_ERRORS(cudaMemset(d_rowLength" << s.first << ", 0, " << s.second.getSrcNeuronGroup()->getNumNeurons() << " * sizeof(unsigned int)));" << std::endl; } // If this synapse population has SPARSE connectivity and has postsynaptic learning, insert a call to cudaMemset to zero column lengths if((s.second.getMatrixType() &
SynapseMatrixConnectivity::SPARSE) && !s.second.getWUModel()->getLearnPostCode().empty()) { - os << "CHECK_CUDA_ERRORS(cudaMemset(d_colLength" << s.first << ", 0, " << s.second.getTrgNeuronGroup()->getNumNeurons() << " * sizeof(unsigned int)));" << std::endl; + initEnv.getStream() << "CHECK_CUDA_ERRORS(cudaMemset(d_colLength" << s.first << ", 0, " << s.second.getTrgNeuronGroup()->getNumNeurons() << " * sizeof(unsigned int)));" << std::endl; } } // If there are any initialisation threads if(idInitStart > 0) { - CodeStream::Scope b(os); + CodeStream::Scope b(initEnv.getStream()); { - Timer t(os, "init", model.isTimingEnabled(), true); + Timer t(initEnv.getStream(), "init", model.isTimingEnabled(), true); - genKernelDimensions(os, KernelInitialize, idInitStart, 1); - os << KernelNames[KernelInitialize] << "<<<grid, threads>>>(deviceRNGSeed);" << std::endl; - os << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl; + genKernelDimensions(initEnv.getStream(), KernelInitialize, idInitStart, 1); + initEnv.getStream() << KernelNames[KernelInitialize] << "<<<grid, threads>>>(deviceRNGSeed);" << std::endl; + initEnv.getStream() << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl; } } } - os << std::endl; - os << "void initializeSparse()"; + initEnv.getStream() << std::endl; + initEnv.getStream() << "void initializeSparse()"; { - CodeStream::Scope b(os); + CodeStream::Scope b(initEnv.getStream()); // Copy all uninitialised state variables to device if(!getPreferences().automaticCopy) { - os << "copyStateToDevice(true);" << std::endl; - os << "copyConnectivityToDevice(true);" << std::endl << std::endl; + initEnv.getStream() << "copyStateToDevice(true);" << std::endl; + initEnv.getStream() << "copyConnectivityToDevice(true);" << std::endl << std::endl; } // If there are any sparse initialisation threads if(idSparseInitStart > 0) { - CodeStream::Scope b(os); + CodeStream::Scope b(initEnv.getStream()); { - Timer t(os, "initSparse", model.isTimingEnabled(), true); - genKernelDimensions(os, KernelInitializeSparse, idSparseInitStart, 1); - os << KernelNames[KernelInitializeSparse] << "<<<grid, threads>>>();" << std::endl; - os << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl; + Timer t(initEnv.getStream(), "initSparse", model.isTimingEnabled(), true); + genKernelDimensions(initEnv.getStream(), KernelInitializeSparse, idSparseInitStart, 1); + initEnv.getStream() << KernelNames[KernelInitializeSparse] << "<<<grid, threads>>>();" << std::endl; + initEnv.getStream() << "CHECK_CUDA_ERRORS(cudaPeekAtLastError());" << std::endl; } } } + + os << "#include " << std::endl; + os << "#include " << std::endl; + os << "#include " << std::endl; + os << std::endl; + + // Generate struct definitions + modelMerged.genMergedNeuronInitGroupStructs(os, *this); + modelMerged.genMergedSynapseInitGroupStructs(os, *this); + modelMerged.genMergedSynapseConnectivityInitGroupStructs(os, *this); + modelMerged.genMergedSynapseSparseInitGroupStructs(os, *this); + modelMerged.genMergedCustomUpdateInitGroupStructs(os, *this); + modelMerged.genMergedCustomWUUpdateInitGroupStructs(os, *this); + modelMerged.genMergedCustomWUUpdateSparseInitGroupStructs(os, *this); + modelMerged.genMergedCustomConnectivityUpdatePreInitStructs(os, *this); + modelMerged.genMergedCustomConnectivityUpdatePostInitStructs(os, *this); + modelMerged.genMergedCustomConnectivityUpdateSparseInitStructs(os, *this); + + // Generate arrays of merged structs and functions to push them + genMergedStructArrayPush(os, modelMerged.getMergedNeuronInitGroups()); + genMergedStructArrayPush(os,
modelMerged.getMergedSynapseInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedSynapseConnectivityInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedSynapseSparseInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomUpdateInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomWUUpdateInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomWUUpdateSparseInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdatePreInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdatePostInitGroups()); + genMergedStructArrayPush(os, modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups()); + + // Generate preamble + preambleHandler(os); + + // Generate data structure for accessing merged groups from within initialisation kernel + // **NOTE** pass in zero constant cache here as it's precious and would be wasted on init kernels which are only launched once + const ModelSpecInternal &model = modelMerged.getModel(); + size_t totalConstMem = 0; + genMergedKernelDataStructures( + os, totalConstMem, + modelMerged.getMergedNeuronInitGroups(), [this](const NeuronGroupInternal &ng){ return padKernelSize(ng.getNumNeurons(), KernelInitialize); }, + modelMerged.getMergedSynapseInitGroups(), [this](const SynapseGroupInternal &sg){ return padKernelSize(getNumInitThreads(sg), KernelInitialize); }, + modelMerged.getMergedCustomUpdateInitGroups(), [this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelInitialize); }, + modelMerged.getMergedCustomConnectivityUpdatePreInitGroups(), [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelInitialize); }, + modelMerged.getMergedCustomConnectivityUpdatePostInitGroups(), [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons(), KernelInitialize); }, + modelMerged.getMergedCustomWUUpdateInitGroups(), [this](const CustomUpdateWUInternal &cg){ return padKernelSize(getNumInitThreads(cg), KernelInitialize); }, + modelMerged.getMergedSynapseConnectivityInitGroups(), [this](const SynapseGroupInternal &sg){ return padKernelSize(getNumConnectivityInitThreads(sg), KernelInitialize); }); + + // Generate data structure for accessing merged groups from within sparse initialisation kernel + genMergedKernelDataStructures( + os, totalConstMem, + modelMerged.getMergedSynapseSparseInitGroups(), [this](const SynapseGroupInternal &sg){ return padKernelSize(sg.getMaxConnections(), KernelInitializeSparse); }, + modelMerged.getMergedCustomWUUpdateSparseInitGroups(), [this](const CustomUpdateWUInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }, + modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups(), [this](const CustomConnectivityUpdateInternal &cg){ return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }); + os << std::endl; } //-------------------------------------------------------------------------- void Backend::genDefinitionsPreamble(CodeStream &os, const ModelSpecMerged &) const @@ -1564,60 +1480,59 @@ void Backend::genStepTimeFinalisePreamble(CodeStream &os, const ModelSpecMerged } //-------------------------------------------------------------------------- void Backend::genVariableDefinition(CodeStream &definitions, CodeStream &definitionsInternal, - 
const Type::ValueBase *type, const std::string &name, VarLocation loc) const + const Type::ResolvedType &type, const std::string &name, VarLocation loc) const { - const bool deviceType = dynamic_cast(type); - CodeStream &d = deviceType ? definitionsInternal : definitions; + CodeStream &d = type.getValue().device ? definitionsInternal : definitions; if(getPreferences().automaticCopy) { // Export pointer, either in definitionsInternal if variable has a device type // or to definitions if it should be accessible on host - d << "EXPORT_VAR " << type->getPointerType()->getName() << " " << name << ";" << std::endl; + d << "EXPORT_VAR " << type.getValue().name << "* " << name << ";" << std::endl; } else { if(loc & VarLocation::HOST) { - if(deviceType) { - throw std::runtime_error("Variable '" + name + "' is of device-only type '" + type->getPointerType()->getName() + "' but is located on the host"); + if(type.getValue().device) { + throw std::runtime_error("Variable '" + name + "' is of device-only type '" + type.getValue().name + "' but is located on the host"); } - definitions << "EXPORT_VAR " << type->getPointerType()->getName() << " " << name << ";" << std::endl; + definitions << "EXPORT_VAR " << type.getValue().name << "* " << name << ";" << std::endl; } if(loc & VarLocation::DEVICE) { // Write host definition to internal definitions stream if type is device only - d << "EXPORT_VAR " << type->getPointerType()->getName() << " d_" << name << ";" << std::endl; + d << "EXPORT_VAR " << type.getValue().name << "* d_" << name << ";" << std::endl; } } } //-------------------------------------------------------------------------- void Backend::genVariableInstantiation(CodeStream &os, - const Type::ValueBase *type, const std::string &name, VarLocation loc) const + const Type::ResolvedType &type, const std::string &name, VarLocation loc) const { if(getPreferences().automaticCopy) { - os << type->getPointerType()->getName() << " " << name << ";" << std::endl; + os << type.getValue().name << "* " << name << ";" << std::endl; } else { if(loc & VarLocation::HOST) { - os << type->getPointerType()->getName() << " " << name << ";" << std::endl; + os << type.getValue().name << "* " << name << ";" << std::endl; } if(loc & VarLocation::DEVICE) { - os << type->getPointerType()->getName() << " d_" << name << ";" << std::endl; + os << type.getValue().name << "* d_" << name << ";" << std::endl; } } } //-------------------------------------------------------------------------- void Backend::genVariableAllocation(CodeStream &os, - const Type::ValueBase *type, const Type::TypeContext &typeContext, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, size_t count, MemAlloc &memAlloc) const { if(getPreferences().automaticCopy) { - os << "CHECK_CUDA_ERRORS(cudaMallocManaged(&" << name << ", " << count << " * sizeof(" << type->getName() << ")));" << std::endl; - memAlloc += MemAlloc::device(count * type->getSizeBytes(typeContext)); + os << "CHECK_CUDA_ERRORS(cudaMallocManaged(&" << name << ", " << count << " * sizeof(" << type.getName() << ")));" << std::endl; + memAlloc += MemAlloc::device(count * type.getSize(getPointerBytes())); } else { if(loc & VarLocation::HOST) { const char *flags = (loc & VarLocation::ZERO_COPY) ?
"cudaHostAllocMapped" : "cudaHostAllocPortable"; - os << "CHECK_CUDA_ERRORS(cudaHostAlloc(&" << name << ", " << count << " * sizeof(" << type->getName() << "), " << flags << "));" << std::endl; - memAlloc += MemAlloc::host(count * type->getSizeBytes(typeContext)); + os << "CHECK_CUDA_ERRORS(cudaHostAlloc(&" << name << ", " << count << " * sizeof(" << type.getName() << "), " << flags << "));" << std::endl; + memAlloc += MemAlloc::host(count * type.getSize(getPointerBytes())); } // If variable is present on device at all @@ -1625,32 +1540,31 @@ void Backend::genVariableAllocation(CodeStream &os, // Insert call to correct helper depending on whether variable should be allocated in zero-copy mode or not if(loc & VarLocation::ZERO_COPY) { os << "CHECK_CUDA_ERRORS(cudaHostGetDevicePointer((void **)&d_" << name << ", (void *)" << name << ", 0));" << std::endl; - memAlloc += MemAlloc::zeroCopy(count * type->getSizeBytes(typeContext)); + memAlloc += MemAlloc::zeroCopy(count * type.getSize(getPointerBytes())); } else { - os << "CHECK_CUDA_ERRORS(cudaMalloc(&d_" << name << ", " << count << " * sizeof(" << type->getName() << ")));" << std::endl; - memAlloc += MemAlloc::device(count * type->getSizeBytes(typeContext)); + os << "CHECK_CUDA_ERRORS(cudaMalloc(&d_" << name << ", " << count << " * sizeof(" << type.getName() << ")));" << std::endl; + memAlloc += MemAlloc::device(count * type.getSize(getPointerBytes())); } } } } //-------------------------------------------------------------------------- void Backend::genVariableDynamicAllocation(CodeStream &os, - const Type::Base *type, const std::string &name, VarLocation loc, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName, const std::string &prefix) const { - const auto *pointerType = dynamic_cast(type); - const auto *underlyingType = pointerType ? pointerType->getValueType() : type; - const std::string hostPointer = pointerType ? ("*" + prefix + name) : (prefix + name); - const std::string hostPointerToPointer = pointerType ? (prefix + name) : ("&" + prefix + name); - const std::string devicePointerToPointer = pointerType ? (prefix + "d_" + name) : ("&" + prefix + "d_" + name); + const auto &underlyingType = type.isPointer() ? *type.getPointer().valueType : type; + const std::string hostPointer = type.isPointer() ? ("*" + prefix + name) : (prefix + name); + const std::string hostPointerToPointer = type.isPointer() ? (prefix + name) : ("&" + prefix + name); + const std::string devicePointerToPointer = type.isPointer() ? (prefix + "d_" + name) : ("&" + prefix + "d_" + name); if(getPreferences().automaticCopy) { - os << "CHECK_CUDA_ERRORS(cudaMallocManaged(" << hostPointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType->getName() << ")));" << std::endl; + os << "CHECK_CUDA_ERRORS(cudaMallocManaged(" << hostPointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType.getName() << ")));" << std::endl; } else { if(loc & VarLocation::HOST) { const char *flags = (loc & VarLocation::ZERO_COPY) ? 
"cudaHostAllocMapped" : "cudaHostAllocPortable"; - os << "CHECK_CUDA_ERRORS(cudaHostAlloc(" << hostPointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType->getName() << "), " << flags << "));" << std::endl; + os << "CHECK_CUDA_ERRORS(cudaHostAlloc(" << hostPointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType.getName() << "), " << flags << "));" << std::endl; } // If variable is present on device at all @@ -1659,7 +1573,7 @@ void Backend::genVariableDynamicAllocation(CodeStream &os, os << "CHECK_CUDA_ERRORS(cudaHostGetDevicePointer((void**)" << devicePointerToPointer << ", (void*)" << hostPointer << ", 0));" << std::endl; } else { - os << "CHECK_CUDA_ERRORS(cudaMalloc(" << devicePointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType->getName() << ")));" << std::endl; + os << "CHECK_CUDA_ERRORS(cudaMalloc(" << devicePointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType.getName() << ")));" << std::endl; } } } @@ -1684,8 +1598,8 @@ void Backend::genVariableFree(CodeStream &os, const std::string &name, VarLocati } //-------------------------------------------------------------------------- void Backend::genVariablePush(CodeStream &os, - const Type::ValueBase *type, const std::string &name, VarLocation loc, - bool autoInitialized, size_t count) const + const Type::ResolvedType &type, const std::string &name, + VarLocation loc, bool autoInitialized, size_t count) const { assert(!getPreferences().automaticCopy); @@ -1697,7 +1611,7 @@ void Backend::genVariablePush(CodeStream &os, os << "CHECK_CUDA_ERRORS(cudaMemcpy(d_" << name; os << ", " << name; - os << ", " << count << " * sizeof(" << type->getName() << "), cudaMemcpyHostToDevice));" << std::endl; + os << ", " << count << " * sizeof(" << type.getName() << "), cudaMemcpyHostToDevice));" << std::endl; if(autoInitialized) { os << CodeStream::CB(1101); @@ -1706,7 +1620,7 @@ void Backend::genVariablePush(CodeStream &os, } //-------------------------------------------------------------------------- void Backend::genVariablePull(CodeStream &os, - const Type::ValueBase *type, const std::string &name, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, size_t count) const { assert(!getPreferences().automaticCopy); @@ -1714,13 +1628,13 @@ void Backend::genVariablePull(CodeStream &os, if(!(loc & VarLocation::ZERO_COPY)) { os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << name; os << ", d_" << name; - os << ", " << count << " * sizeof(" << type->getName() << "), cudaMemcpyDeviceToHost));" << std::endl; + os << ", " << count << " * sizeof(" << type.getName() << "), cudaMemcpyDeviceToHost));" << std::endl; } } //-------------------------------------------------------------------------- void Backend::genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal &ng, - const Type::ValueBase *type, const std::string &name, - VarLocation loc, unsigned int batchSize) const + const Type::ResolvedType &type, const std::string &name, + VarLocation loc, unsigned int batchSize) const { assert(!getPreferences().automaticCopy); @@ -1730,15 +1644,15 @@ void Backend::genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal & if(batchSize == 1) { os << "CHECK_CUDA_ERRORS(cudaMemcpy(d_" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; os << ", " << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() << " * sizeof(" << type->getName() << "), 
cudaMemcpyHostToDevice));" << std::endl; + os << ", " << ng.getNumNeurons() << " * sizeof(" << type.getName() << "), cudaMemcpyHostToDevice));" << std::endl; } // Otherwise, perform a 2D memcpy to copy current timestep's data from each batch else { os << "CHECK_CUDA_ERRORS(cudaMemcpy2D(d_" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type->getName() << ")"; + os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type.getName() << ")"; os << ", " << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type->getName() << ")"; - os << ", " << ng.getNumNeurons() << " * sizeof(" << type->getName() << ")"; + os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type.getName() << ")"; + os << ", " << ng.getNumNeurons() << " * sizeof(" << type.getName() << ")"; os << ", " << batchSize << ", cudaMemcpyHostToDevice));" << std::endl; } } @@ -1749,8 +1663,8 @@ void Backend::genCurrentVariablePush(CodeStream &os, const NeuronGroupInternal & } //-------------------------------------------------------------------------- void Backend::genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal &ng, - const Type::ValueBase *type, const std::string &name, - VarLocation loc, unsigned int batchSize) const + const Type::ResolvedType &type, const std::string &name, + VarLocation loc, unsigned int batchSize) const { assert(!getPreferences().automaticCopy); @@ -1760,14 +1674,14 @@ void Backend::genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal & if(batchSize == 1) { os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; os << ", d_" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() << " * sizeof(" << type->getName() << "), cudaMemcpyDeviceToHost));" << std::endl; + os << ", " << ng.getNumNeurons() << " * sizeof(" << type.getName() << "), cudaMemcpyDeviceToHost));" << std::endl; } else { os << "CHECK_CUDA_ERRORS(cudaMemcpy2D(" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type->getName() << ")"; + os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type.getName() << ")"; os << ", d_" << name << ng.getName() << " + (spkQuePtr" << ng.getName() << " * " << ng.getNumNeurons() << ")"; - os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type->getName() << ")"; - os << ", " << ng.getNumNeurons() << " * sizeof(" << type->getName() << ")"; + os << ", " << ng.getNumNeurons() * ng.getNumDelaySlots() << " * sizeof(" << type.getName() << ")"; + os << ", " << ng.getNumNeurons() << " * sizeof(" << type.getName() << ")"; os << ", " << batchSize << ", cudaMemcpyDeviceToHost));" << std::endl; } } @@ -1778,52 +1692,50 @@ void Backend::genCurrentVariablePull(CodeStream &os, const NeuronGroupInternal & } //-------------------------------------------------------------------------- void Backend::genVariableDynamicPush(CodeStream &os, - const Type::Base *type, const std::string &name, VarLocation loc, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, const 
std::string &countVarName, const std::string &prefix) const { assert(!getPreferences().automaticCopy); if(!(loc & VarLocation::ZERO_COPY)) { - const auto *pointerType = dynamic_cast<const Type::Pointer*>(type); - if (pointerType) { + if (type.isPointer()) { os << "CHECK_CUDA_ERRORS(cudaMemcpy(*" << prefix << "d_" << name; os << ", *" << prefix << name; - os << ", " << countVarName << " * sizeof(" << pointerType->getValueType()->getName() << "), cudaMemcpyHostToDevice));" << std::endl; + os << ", " << countVarName << " * sizeof(" << type.getPointer().valueType->getName() << "), cudaMemcpyHostToDevice));" << std::endl; } else { - os << prefix << name << " = new " << type->getName() << "[" << countVarName << "];" << std::endl; + os << prefix << name << " = new " << type.getName() << "[" << countVarName << "];" << std::endl; os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << prefix << "d_" << name; os << ", " << prefix << name; - os << ", " << countVarName << " * sizeof(" << type->getName() << "), cudaMemcpyHostToDevice));" << std::endl; + os << ", " << countVarName << " * sizeof(" << type.getName() << "), cudaMemcpyHostToDevice));" << std::endl; } } } //-------------------------------------------------------------------------- void Backend::genVariableDynamicPull(CodeStream &os, - const Type::Base *type, const std::string &name, VarLocation loc, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName, const std::string &prefix) const { assert(!getPreferences().automaticCopy); if(!(loc & VarLocation::ZERO_COPY)) { - const auto *pointerType = dynamic_cast<const Type::Pointer*>(type); - if (pointerType) { + if (type.isPointer()) { os << "CHECK_CUDA_ERRORS(cudaMemcpy(*" << prefix << name; os << ", *" << prefix << "d_" << name; - os << ", " << countVarName << " * sizeof(" << pointerType->getValueType()->getName() << "), cudaMemcpyDeviceToHost));" << std::endl; + os << ", " << countVarName << " * sizeof(" << type.getPointer().valueType->getName() << "), cudaMemcpyDeviceToHost));" << std::endl; } else { os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << prefix << name; os << ", " << prefix << "d_" << name; - os << ", " << countVarName << " * sizeof(" << type->getName() << "), cudaMemcpyDeviceToHost));" << std::endl; + os << ", " << countVarName << " * sizeof(" << type.getName() << "), cudaMemcpyDeviceToHost));" << std::endl; } } } //-------------------------------------------------------------------------- void Backend::genMergedDynamicVariablePush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, - const std::string &groupIdx, const std::string &fieldName, - const std::string &egpName) const + const std::string &groupIdx, const std::string &fieldName, + const std::string &egpName) const { const std::string structName = "Merged" + suffix + "Group" + std::to_string(mergedGroupIdx); os << "CHECK_CUDA_ERRORS(cudaMemcpyToSymbolAsync(d_merged" << suffix << "Group" << mergedGroupIdx; @@ -1831,9 +1743,9 @@ void Backend::genMergedDynamicVariablePush(CodeStream &os, const std::string &su os << ", (sizeof(" << structName << ") * (" << groupIdx << ")) + offsetof(" << structName << ", " << fieldName << ")));" << std::endl; } //-------------------------------------------------------------------------- -std::string Backend::getMergedGroupFieldHostTypeName(const Type::Base *type) const +std::string Backend::getMergedGroupFieldHostTypeName(const Type::ResolvedType &type) const { - return type->getName(); + return type.getName(); } //-------------------------------------------------------------------------- void
Backend::genGlobalDeviceRNG(CodeStream &, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &, CodeStream &, @@ -2085,6 +1997,51 @@ std::string Backend::getNVCCFlags() const #endif return nvccFlags; } +//----------------------------------------------------------------------- +std::string Backend::getNCCLReductionType(VarAccessMode mode) const +{ + // Convert GeNN reduction types to NCCL + if(mode & VarAccessModeAttribute::MAX) { + return "ncclMax"; + } + else if(mode & VarAccessModeAttribute::SUM) { + return "ncclSum"; + } + else { + throw std::runtime_error("Reduction type unsupported by NCCL"); + } +} +//----------------------------------------------------------------------- +std::string Backend::getNCCLType(const Type::ResolvedType &type) const +{ + assert(type.isNumeric()); + + // Convert GeNN types to NCCL types + if(type == Type::Int8) { + return "ncclInt8"; + } + else if(type == Type::Uint8) { + return "ncclUint8"; + } + else if(type == Type::Int32) { + return "ncclInt32"; + } + else if(type == Type::Uint32){ + return "ncclUint32"; + } + /*else if(type == "half") { + return "ncclFloat16"; + }*/ + else if(type == Type::Float){ + return "ncclFloat32"; + } + else if(type == Type::Double) { + return "ncclFloat64"; + } + else { + throw std::runtime_error("Data type '" + type.getName() + "' unsupported by NCCL"); + } +} //-------------------------------------------------------------------------- void Backend::genKernelDimensions(CodeStream &os, Kernel kernel, size_t numThreadsX, size_t batchSize, size_t numBlockThreadsY) const { diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 29714ef225..bcf810b2c7 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1294,7 +1294,7 @@ void Backend::genDefinitionsPreamble(CodeStream &os, const ModelSpecMerged &mode os << "#include " << std::endl; // If a global RNG is required, define standard host distributions as recreating them each call is slow - if(isGlobalHostRNGRequired(modelMerged)) { + if(isGlobalHostRNGRequired(model)) { os << "EXPORT_VAR " << "std::uniform_real_distribution<" << model.getPrecision().getName() << "> standardUniformDistribution;" << std::endl; os << "EXPORT_VAR " << "std::normal_distribution<" << model.getPrecision().getName() << "> standardNormalDistribution;" << std::endl; os << "EXPORT_VAR " << "std::exponential_distribution<" << model.getPrecision().getName() << "> standardExponentialDistribution;" << std::endl; @@ -1342,7 +1342,7 @@ void Backend::genRunnerPreamble(CodeStream &os, const ModelSpecMerged &modelMerg const ModelSpecInternal &model = modelMerged.getModel(); // If a global RNG is required, implement standard host distributions as recreating them each call is slow - if(isGlobalHostRNGRequired(modelMerged)) { + if(isGlobalHostRNGRequired(model)) { os << "std::uniform_real_distribution<" << model.getPrecision().getName() << "> standardUniformDistribution(" << writePreciseLiteral(0.0, model.getPrecision()) << ", " << writePreciseLiteral(1.0, model.getPrecision()) << ");" << std::endl; os << "std::normal_distribution<" << model.getPrecision().getName() << "> standardNormalDistribution(" << writePreciseLiteral(0.0, model.getPrecision()) << ", " << writePreciseLiteral(1.0, model.getPrecision()) << ");" << std::endl; os << "std::exponential_distribution<" << model.getPrecision().getName() << "> standardExponentialDistribution(" << writePreciseLiteral(1.0, 
model.getPrecision()) << ");" << std::endl; @@ -1612,11 +1612,10 @@ void Backend::genMSBuildImportTarget(std::ostream&) const { } //-------------------------------------------------------------------------- -bool Backend::isGlobalHostRNGRequired(const ModelSpecMerged &modelMerged) const +bool Backend::isGlobalHostRNGRequired(const ModelSpecInternal &model) const { // If any neuron groups require simulation RNGs or require RNG for initialisation, return true // **NOTE** this takes postsynaptic model initialisation into account - const ModelSpecInternal &model = modelMerged.getModel(); if(std::any_of(model.getNeuronGroups().cbegin(), model.getNeuronGroups().cend(), [](const ModelSpec::NeuronGroupValueType &n) { @@ -1672,7 +1671,7 @@ bool Backend::isGlobalHostRNGRequired(const ModelSpecMerged &modelMerged) const return false; } //-------------------------------------------------------------------------- -bool Backend::isGlobalDeviceRNGRequired(const ModelSpecMerged &) const +bool Backend::isGlobalDeviceRNGRequired(const ModelSpecInternal &) const { return false; } From 265ced4ba4f23e732a8d39a553322b48ced0e096 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 15:05:40 +0100 Subject: [PATCH 355/725] CUDA backend compiling --- include/genn/backends/cuda/backend.h | 39 +++++-- .../genn/genn/code_generator/backendBase.h | 2 +- .../genn/genn/code_generator/backendSIMT.h | 4 +- src/genn/backends/cuda/backend.cc | 45 ++++---- src/genn/genn/code_generator/backendSIMT.cc | 104 +++++++++--------- .../genn/code_generator/generateRunner.cc | 4 +- 6 files changed, 109 insertions(+), 89 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index f25662584a..b3a5ab410d 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -246,17 +246,17 @@ class BACKEND_EXPORT Backend : public BackendSIMT const std::string &egpName) const final; //! When generating function calls to push to merged groups, backend without equivalent of Unified Virtual Addressing e.g. OpenCL 1.2 may use different types on host - virtual std::string getMergedGroupFieldHostTypeName(const Type::ResolvedType &type) const = 0; + virtual std::string getMergedGroupFieldHostTypeName(const Type::ResolvedType &type) const final; //! Generate a single RNG instance /*! On single-threaded platforms this can be a standard RNG like M.T. but, on parallel platforms, it is likely to be a counter-based RNG */ virtual void genGlobalDeviceRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, - CodeStream &allocations, CodeStream &free, MemAlloc &memAlloc) const = 0; + CodeStream &allocations, CodeStream &free, MemAlloc &memAlloc) const final; //! 
Generate an RNG with a state per population member virtual void genPopulationRNG(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, - const std::string &name, size_t count, MemAlloc &memAlloc) const = 0; + const std::string &name, size_t count, MemAlloc &memAlloc) const final; virtual void genTimer(CodeStream &definitions, CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, CodeStream &free, CodeStream &stepTimeFinalise, @@ -351,7 +351,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT } template<typename G> - void genNCCLReduction(EnvironmentExternal &env, G &cg) const + void genNCCLReduction(EnvironmentExternalBase &env, G &cg) const { CodeStream::Scope b(env.getStream()); env.getStream() << "// merged custom update host reduction group " << cg.getIndex() << std::endl; @@ -361,28 +361,43 @@ class BACKEND_EXPORT Backend : public BackendSIMT // Get reference to group env.getStream() << "const auto *group = &merged" << G::name << "Group" << cg.getIndex() << "[g]; " << std::endl; - EnvironmentGroupMergedField groupEnv(env, g); + EnvironmentGroupMergedField groupEnv(env, cg); - // Loop through variables and add pointers if they are reduction targets + // Loop through variables const CustomUpdateModels::Base *cm = cg.getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { + // If variable is reduction target if(v.access & VarAccessModeAttribute::REDUCE) { - groupEnv.addField(v.type.resolve(getGroup().getTypeContext()).createPointer(), "_" + v.name, v.name, + // Add pointer field + const auto resolvedType = v.type.resolve(cg.getTypeContext()); + groupEnv.addField(resolvedType.createPointer(), "_" + v.name, v.name, [this, v](const auto &g, size_t) { return getDeviceVarPrefix() + v.name + g.getName(); }); - + + // Add NCCL reduction groupEnv.print("CHECK_NCCL_ERRORS(ncclAllReduce($(_" + v.name + "), $(_" + v.name + "), $(_size)"); - groupEnv.printLine(", " + getNCCLType(v.type, cg.getTypeContext()) + ", " + getNCCLReductionType(getVarAccessMode(v.access)) + ", ncclCommunicator, 0));"); + groupEnv.printLine(", " + getNCCLType(resolvedType) + ", " + getNCCLReductionType(getVarAccessMode(v.access)) + ", ncclCommunicator, 0));"); } } - // Loop through variable references and add pointers if they are reduction targets + // Loop through variable references for(const auto &v : cm->getVarRefs()) { + // If variable reference is reduction target if(v.access & VarAccessModeAttribute::REDUCE) { - os << "CHECK_NCCL_ERRORS(ncclAllReduce(group->" << v.name << ", group->" << v.name << ", group->size"; - os << ", " << getNCCLType(v.type, cg.getTypeContext()) << ", " << getNCCLReductionType(v.access) << ", ncclCommunicator, 0));" << std::endl; + // Add pointer field + const auto resolvedType = v.type.resolve(cg.getTypeContext()); + groupEnv.addField(resolvedType.createPointer(), "_" + v.name, v.name, + [this, v](const auto &g, size_t) + { + const auto varRef = g.getVarReferences().at(v.name); + return getDeviceVarPrefix() + varRef.getVar().name + varRef.getTargetName(); + }); + + // Add NCCL reduction + groupEnv.print("CHECK_NCCL_ERRORS(ncclAllReduce($(_" + v.name + "), $(_" + v.name + "), $(_size)"); + groupEnv.printLine(", " + getNCCLType(v.type.resolve(cg.getTypeContext())) + ", " + getNCCLReductionType(v.access) + ", ncclCommunicator, 0));"); } } } diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 7de35ad48a..e8e8f63b09 100644 ---
a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -31,6 +31,7 @@ namespace GeNN { class CustomUpdateInternal; class CustomUpdateWUInternal; +class ModelSpecInternal; class NeuronGroupInternal; class SynapseGroupInternal; @@ -39,7 +40,6 @@ namespace CodeGenerator template class EnvironmentGroupMergedField; class EnvironmentExternalBase; -class ModelSpecInternal; class ModelSpecMerged; template<typename G> class GroupMerged; diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 8f55e8462c..b0fceafd36 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -136,8 +136,8 @@ class GENN_EXPORT BackendSIMT : public BackendBase genSynapseVariableRowInit(env, handler); } - virtual void genKernelSynapseVariableInit(EnvironmentExternalBase &env, const SynapseInitGroupMerged &sg, HandlerEnv handler) const final; - virtual void genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, const CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const final; + virtual void genKernelSynapseVariableInit(EnvironmentExternalBase &env, SynapseInitGroupMerged &sg, HandlerEnv handler) const final; + virtual void genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const final; //! Should 'scalar' variables be implemented on device or can host variables be used directly? virtual bool isDeviceScalarRequired() const final { return true; } diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index e886889097..f323f6cd32 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -128,7 +128,7 @@ void genGroupStartIDs(CodeStream &, size_t&, size_t&) template<typename T, typename G, typename ...Args> void genGroupStartIDs(CodeStream &os, size_t &idStart, size_t &totalConstMem, const std::vector<T> &mergedGroups, G getPaddedNumThreads, - Args... args) + Args&&... args) { // Loop through merged groups for(const auto &m : mergedGroups) { } // Generate any remaining groups - genGroupStartIDs(os, idStart, totalConstMem, args...); + genGroupStartIDs(os, idStart, totalConstMem, std::forward<Args>(args)...); } //----------------------------------------------------------------------- template<typename ...Args> void genMergedKernelDataStructures(CodeStream &os, size_t &totalConstMem, - Args... args) + Args&&...
args) { // Generate group start id arrays size_t idStart = 0; - genGroupStartIDs(os, std::ref(idStart), std::ref(totalConstMem), args...); + genGroupStartIDs(os, std::ref(idStart), std::ref(totalConstMem), std::forward<Args>(args)...); } //----------------------------------------------------------------------- template @@ -323,7 +323,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // If any neuron groups require their previous spike times updating size_t idNeuronPrevSpikeTimeUpdate = 0; if(std::any_of(model.getNeuronGroups().cbegin(), model.getNeuronGroups().cend(), - [](const auto &ng){ return (ng.isPrevSpikeTimeRequired() || ng.isPrevSpikeEventTimeRequired())})) + [](const auto &ng){ return (ng.second.isPrevSpikeTimeRequired() || ng.second.isPrevSpikeEventTimeRequired()); })) { neuronUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelNeuronPrevSpikeTimeUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)"; { @@ -475,7 +475,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // If there are any presynaptic update groups size_t idPresynapticStart = 0; if(std::any_of(model.getSynapseGroups().cbegin(), model.getSynapseGroups().cend(), - [](const auto &sg){ return (sg.isSpikeEventRequired() || sg.isTrueSpikeRequired()); })) + [](const auto &sg){ return (sg.second.isSpikeEventRequired() || sg.second.isTrueSpikeRequired()); })) { synapseUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelPresynapticUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; // end of synapse kernel header { @@ -502,7 +502,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // If any synapse groups require postsynaptic learning size_t idPostsynapticStart = 0; if(std::any_of(model.getSynapseGroups().cbegin(), model.getSynapseGroups().cend(), - [](const auto &sg){ return !Utils::areTokensEmpty(sg.getWUPostLearnCodeTokens()); })) + [](const auto &sg){ return !Utils::areTokensEmpty(sg.second.getWUPostLearnCodeTokens()); })) { synapseUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelPostsynapticUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; { @@ -526,7 +526,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // If any synapse groups require synapse dynamics size_t idSynapseDynamicsStart = 0; if(std::any_of(model.getSynapseGroups().cbegin(), model.getSynapseGroups().cend(), - [](const auto &sg){ return !Utils::areTokensEmpty(sg.getWUSynapseDynamicsCodeTokens()); })) + [](const auto &sg){ return !Utils::areTokensEmpty(sg.second.getWUSynapseDynamicsCodeTokens()); })) { synapseUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelSynapseDynamicsUpdate] << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; // end of synapse kernel header { @@ -650,11 +650,11 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Generate kernel size_t idCustomUpdateStart = 0; if(std::any_of(model.getCustomUpdates().cbegin(), model.getCustomUpdates().cend(), - [&g](const auto &cg) { return (cg.getUpdateGroupName() == g); }) + [&g](const auto &cg) { return (cg.second.getUpdateGroupName() == g); }) || std::any_of(model.getCustomWUUpdates().cbegin(), model.getCustomWUUpdates().cend(), - [&g](const auto &cg) { return (!cg.isTransposeOperation()
&& cg.getUpdateGroupName() == g); }) + [&g](const auto &cg) { return (!cg.second.isTransposeOperation() && cg.second.getUpdateGroupName() == g); }) || std::any_of(model.getCustomConnectivityUpdates().cbegin(), model.getCustomConnectivityUpdates().cend(), - [&g](const auto &cg) { return (!Utils::areTokensEmpty(cg.getRowUpdateCodeTokens()) && cg.getUpdateGroupName() == g); })) + [&g](const auto &cg) { return (!Utils::areTokensEmpty(cg.second.getRowUpdateCodeTokens()) && cg.second.getUpdateGroupName() == g); })) { customUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelCustomUpdate] << g << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; { @@ -681,7 +681,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host size_t idCustomTransposeUpdateStart = 0; if(std::any_of(model.getCustomWUUpdates().cbegin(), model.getCustomWUUpdates().cend(), - [&g](const auto &cg){ return (cg.isTransposeOperation() && cg.getUpdateGroupName() == g); })) + [&g](const auto &cg){ return (cg.second.isTransposeOperation() && cg.second.getUpdateGroupName() == g); })) { customUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelCustomTransposeUpdate] << g << "(" << modelMerged.getModel().getTimePrecision().getName() << " t)" << std::endl; { @@ -732,19 +732,21 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host if(getPreferences().enableNCCLReductions) { // Loop through custom update host reduction groups and // generate reductions for those in this custom update group - for(auto &cg : modelMerged.getMergedCustomUpdateHostReductionGroups()) { - if(cg.getArchetype().getUpdateGroupName() == g) { + modelMerged.genMergedCustomUpdateHostReductionGroups( + *this, g, + [this, &customUpdateEnv, &modelMerged](auto &cg) + { genNCCLReduction(customUpdateEnv, cg); - } - } + }); - // Loop through custom update host reduction groups and + // Loop through custom WU update host reduction groups and // generate reductions for those in this custom update group - for(auto &cg : modelMerged.getMergedCustomWUUpdateHostReductionGroups()) { - if(cg.getArchetype().getUpdateGroupName() == g) { + modelMerged.genMergedCustomWUUpdateHostReductionGroups( + *this, g, + [this, &customUpdateEnv, &modelMerged](auto &cg) + { genNCCLReduction(customUpdateEnv, cg); - } - } + }); } // If timing is enabled @@ -994,7 +996,6 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler // Generate data structure for accessing merged groups from within initialisation kernel // **NOTE** pass in zero constant cache here as it's precious and would be wasted on init kernels which are only launched once - const ModelSpecInternal &model = modelMerged.getModel(); size_t totalConstMem = 0; genMergedKernelDataStructures( os, totalConstMem, diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index dd0843beaa..f4bf853ffb 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -76,7 +76,7 @@ void BackendSIMT::genVariableInit(EnvironmentExternalBase &env, const std::strin handler(env); } //-------------------------------------------------------------------------- -void BackendSIMT::genKernelSynapseVariableInit(EnvironmentExternalBase &env, const SynapseInitGroupMerged &sg, HandlerEnv handler) const +void BackendSIMT::genKernelSynapseVariableInit(EnvironmentExternalBase &env, SynapseInitGroupMerged &sg, 
HandlerEnv handler) const { // Variable should already be provided via parallelism //assert(kernelSubs.hasVarSubstitution("id")); @@ -87,7 +87,7 @@ void BackendSIMT::genKernelSynapseVariableInit(EnvironmentExternalBase &env, con handler(varEnv); } //-------------------------------------------------------------------------- -void BackendSIMT::genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, const CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const +void BackendSIMT::genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const { // Variable should already be provided via parallelism //assert(kernelSubs.hasVarSubstitution("id")); @@ -396,7 +396,7 @@ void BackendSIMT::genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, idStart = 0; modelMerged.genMergedNeuronSpikeQueueUpdateGroups( *this, - [&env, &idStart, batchSize, this](const auto &n) + [&env, &idStart, batchSize, this](auto &n) { if(idStart == 0) { env.getStream() << "if(id < " << n.getGroups().size() << ")"; @@ -411,7 +411,7 @@ void BackendSIMT::genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, env.getStream() << getPointerPrefix() << "struct MergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << " *group = &d_mergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << "[id - " << idStart << "]; " << std::endl; // Create matching environment - EnvironmentGroupMergedField neuronEnv(env, batchSize); + EnvironmentGroupMergedField neuronEnv(env, n); genNeuronIndexCalculation(neuronEnv, batchSize); if(n.getArchetype().isDelayRequired()) { // with delay @@ -527,96 +527,100 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM genSharedMemBarrier(groupEnv.getStream()); - if(ng.getArchetype().isSpikeEventRequired()) { - groupEnv.getStream() << "if (" << getThreadID() << " == 1)"; + // Use first thread to 'allocate' block of $(_spk) array for this block's spikes + if(!ng.getArchetype().getNeuronModel()->getThresholdConditionCode().empty()) { + groupEnv.getStream() << "if(" << getThreadID() << " == 0)"; { CodeStream::Scope b(groupEnv.getStream()); - groupEnv.print("if($(_sh_spk_evnt_count) > 0)"); + groupEnv.print("if ($(_sh_spk_count) > 0)"); { CodeStream::Scope b(groupEnv.getStream()); - groupEnv.print("$(_sh_spk_evnt_pos) = " + getAtomic(Type::Uint32) + "(&$(_spk_cnt_evnt)"); - if(ng.getArchetype().isDelayRequired()) { + groupEnv.print("$(_sh_spk_pos) = " + getAtomic(Type::Uint32) + "(&$(_spk_cnt)"); + if(ng.getArchetype().isDelayRequired() && ng.getArchetype().isTrueSpikeRequired()) { groupEnv.print("[*$(_spk_que_ptr)"); if(batchSize > 1) { groupEnv.getStream() << " + (batch * " << ng.getArchetype().getNumDelaySlots() << ")"; } - groupEnv.printLine("], $(_sh_spk_evnt_count));"); + groupEnv.printLine("], $(_sh_spk_count));"); } else { - groupEnv.printLine("[$(batch)], $(_sh_spk_evnt_count));"); + groupEnv.printLine("[$(batch)], $(_sh_spk_count));"); } } } genSharedMemBarrier(groupEnv.getStream()); } - if(!ng.getArchetype().getNeuronModel()->getThresholdConditionCode().empty()) { - groupEnv.getStream() << "if(" << getThreadID() << " == 0)"; + // Use second thread to 'allocate' block of $(_spk_evnt) array for this block's spike-like events + if(ng.getArchetype().isSpikeEventRequired()) { + groupEnv.getStream() << "if (" << getThreadID() << " == 1)"; { CodeStream::Scope b(groupEnv.getStream()); - groupEnv.print("if ($(_sh_spk_count) > 0)"); + groupEnv.print("if($(_sh_spk_evnt_count) > 0)"); { CodeStream::Scope 
b(groupEnv.getStream()); - groupEnv.getStream() << "shPosSpk = " << getAtomic(Type::Uint32) << "(&group->spkCnt"; - if(ng.getArchetype().isDelayRequired() && ng.getArchetype().isTrueSpikeRequired()) { - groupEnv.getStream() << "[*" << groupEnv["_spk_que_ptr"]; + groupEnv.print("$(_sh_spk_evnt_pos) = " + getAtomic(Type::Uint32) + "(&$(_spk_cnt_evnt)"); + if(ng.getArchetype().isDelayRequired()) { + groupEnv.print("[*$(_spk_que_ptr)"); if(batchSize > 1) { groupEnv.getStream() << " + (batch * " << ng.getArchetype().getNumDelaySlots() << ")"; } - groupEnv.getStream() << "], shSpkCount);" << std::endl; + groupEnv.printLine("], $(_sh_spk_evnt_count));"); } else { - groupEnv.getStream() << "[" << ((batchSize > 1) ? "batch" : "0") << "], shSpkCount);" << std::endl; + groupEnv.printLine("[$(batch)], $(_sh_spk_evnt_count));"); } } } genSharedMemBarrier(groupEnv.getStream()); } + // Copy spikes into block of $(_spk) const std::string queueOffset = ng.getWriteVarIndex(ng.getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, ""); - if(ng.getArchetype().isSpikeEventRequired()) { - neuronEnv.getStream() << "if(" << getThreadID() << " < shSpkEvntCount)"; - { - CodeStream::Scope b(neuronEnv.getStream()); - neuronEnv.getStream() << "const unsigned int n = shSpkEvnt[" << getThreadID() << "];" << std::endl; - - neuronEnv.printLine("$(_spk_evnt)[" + queueOffset + "shPosSpkEvnt + " + getThreadID() + "] = n;"); - if(ng.getArchetype().isSpikeEventTimeRequired()) { - neuronEnv.printLine("$(_spk_evnt_time)[" + queueOffset + "n] = t;"); - } - } - } - if(!ng.getArchetype().getNeuronModel()->getThresholdConditionCode().empty()) { const std::string queueOffsetTrueSpk = ng.getWriteVarIndex(ng.getArchetype().isTrueSpikeRequired() && ng.getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, ""); - neuronEnv.getStream() << "if(" << getThreadID() << " < shSpkCount)"; + groupEnv.print("if(" + getThreadID() + " < $(_sh_spk_count))"); { - CodeStream::Scope b(neuronEnv.getStream()); + CodeStream::Scope b(groupEnv.getStream()); - neuronEnv.getStream() << "const unsigned int n = shSpk[" << getThreadID() << "];" << std::endl; + groupEnv.printLine("const unsigned int n = $(_sh_spk)[" + getThreadID() + "];"); // Create new substitution stack and explicitly replace id with 'n' and perform WU var update - EnvironmentExternal wuEnv(neuronEnv); + EnvironmentExternal wuEnv(groupEnv); wuEnv.add(Type::Uint32.addConst(), "id", "n"); ng.generateWUVarUpdate(*this, wuEnv, modelMerged); - neuronEnv.printLine("$(_spk)[" + queueOffsetTrueSpk + "shPosSpk + " + getThreadID() + "] = n;"); + groupEnv.printLine("$(_spk)[" + queueOffsetTrueSpk + "$(_sh_spk_pos) + " + getThreadID() + "] = n;"); if(ng.getArchetype().isSpikeTimeRequired()) { - neuronEnv.printLine("$(_spk_time)[" + queueOffset + "n] = t;"); + groupEnv.printLine("$(_spk_time)[" + queueOffset + "n] = $(t);"); } } } + + // Copy spike-like events into block of $(_spk_evnt) + if(ng.getArchetype().isSpikeEventRequired()) { + groupEnv.print("if(" + getThreadID() + " < $(_sh_spk_evnt_count))"); + { + CodeStream::Scope b(groupEnv.getStream()); + groupEnv.printLine("const unsigned int n = $(_sh_spk_evnt)[" + getThreadID() + "];"); + + groupEnv.printLine("$(_spk_evnt)[" + queueOffset + "$(_sh_spk_evnt_pos) + " + getThreadID() + "] = n;"); + if(ng.getArchetype().isSpikeEventTimeRequired()) { + groupEnv.printLine("$(_spk_evnt_time)[" + queueOffset + "n] = $(t);"); + } + } + } // If we're recording spikes or spike-like events, use enough threads to copy this
block's recording words if(ng.getArchetype().isSpikeRecordingEnabled() || ng.getArchetype().isSpikeEventRecordingEnabled()) { - neuronEnv.getStream() << "if(" << getThreadID() << " < " << m_KernelBlockSizes[KernelNeuronUpdate] / 32 << ")"; + groupEnv.getStream() << "if(" << getThreadID() << " < " << m_KernelBlockSizes[KernelNeuronUpdate] / 32 << ")"; { - CodeStream::Scope b(neuronEnv.getStream()); + CodeStream::Scope b(groupEnv.getStream()); // Calculate number of words which will be used to record this population's spikes in each batch - neuronEnv.printLine("const unsigned int numRecordingWords = ($(num_neurons) + 31) / 32;"); - neuronEnv.printLine("const unsigned int popWordIdx = ($(id) / 32) + " + getThreadID() + ";"); + groupEnv.printLine("const unsigned int numRecordingWords = ($(num_neurons) + 31) / 32;"); + groupEnv.printLine("const unsigned int popWordIdx = ($(id) / 32) + " + getThreadID() + ";"); // Build global index std::string globalIndex = "(recordingTimestep * numRecordingWords * " + std::to_string(batchSize) + ") + popWordIdx"; @@ -624,25 +628,25 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM globalIndex += " + (batch * numRecordingWords)"; } - neuronEnv.getStream() << "if(popWordIdx < numRecordingWords)"; + groupEnv.getStream() << "if(popWordIdx < numRecordingWords)"; { - CodeStream::Scope c(neuronEnv.getStream()); + CodeStream::Scope c(groupEnv.getStream()); // If we are recording spikes, copy word to correct location in global memory if(ng.getArchetype().isSpikeRecordingEnabled()) { - neuronEnv.getStream() << neuronEnv["_record_spk"] << "[" << globalIndex << "] = shSpkRecord"; + groupEnv.getStream() << groupEnv["_record_spk"] << "[" << globalIndex << "] = shSpkRecord"; if(m_KernelBlockSizes[KernelNeuronUpdate] != 32) { - neuronEnv.getStream() << "[" << getThreadID() << "]"; + groupEnv.getStream() << "[" << getThreadID() << "]"; } - neuronEnv.getStream() << ";" << std::endl; + groupEnv.getStream() << ";" << std::endl; } // If we are recording spike-like events, copy word to correct location in global memory if(ng.getArchetype().isSpikeEventRecordingEnabled()) { - neuronEnv.getStream() << neuronEnv["_record_spk_evnt"] << "[" << globalIndex << "] = shSpkEvntRecord"; + groupEnv.getStream() << groupEnv["_record_spk_evnt"] << "[" << globalIndex << "] = shSpkEvntRecord"; if(m_KernelBlockSizes[KernelNeuronUpdate] != 32) { - neuronEnv.getStream() << "[" << getThreadID() << "]"; + groupEnv.getStream() << "[" << getThreadID() << "]"; } - neuronEnv.getStream() << ";" << std::endl; + groupEnv.getStream() << ";" << std::endl; } } } diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index ac8c5c82d8..530924c214 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -625,12 +625,12 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, runnerVarDecl << "unsigned long long numRecordingTimesteps = 0;" << std::endl; } // If backend requires a global device RNG to simulate (or initialize) this model - if(backend.isGlobalDeviceRNGRequired(modelMerged)) { + if(backend.isGlobalDeviceRNGRequired(model)) { backend.genGlobalDeviceRNG(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, mem); } // If backend requires a global host RNG to simulate (or initialize) this model, generate a standard Mersenne Twister -
if(backend.isGlobalHostRNGRequired(model)) { genGlobalHostRNG(definitionsVar, runnerVarDecl, runnerVarAlloc, model.getSeed(), mem); } allVarStreams << std::endl; From ff33af2986bdc44e8ecd8de5ee7de36dc7a52dec Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 16:02:25 +0100 Subject: [PATCH 356/725] re-attached memoryspaces mechanism --- include/genn/backends/cuda/backend.h | 12 +- .../backends/single_threaded_cpu/backend.h | 12 +- .../genn/genn/code_generator/backendBase.h | 12 +- .../genn/genn/code_generator/backendSIMT.h | 52 +++-- .../genn/code_generator/generateModules.h | 16 +- .../genn/genn/code_generator/generateRunner.h | 2 +- .../genn/code_generator/modelSpecMerged.h | 89 +++++--- src/genn/backends/cuda/backend.cc | 44 ++-- src/genn/backends/cuda/optimiser.cc | 14 +- .../backends/single_threaded_cpu/backend.cc | 54 +++-- src/genn/genn/code_generator/backendSIMT.cc | 74 +++--- .../genn/code_generator/generateModules.cc | 38 ++-- .../genn/code_generator/generateRunner.cc | 4 +- .../genn/code_generator/modelSpecMerged.cc | 215 ++++++------------ 14 files changed, 323 insertions(+), 315 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index b3a5ab410d..b883bd32aa 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -173,13 +173,17 @@ class BACKEND_EXPORT Backend : public BackendSIMT //-------------------------------------------------------------------------- // CodeGenerator::BackendBase virtuals //-------------------------------------------------------------------------- - virtual void genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; + virtual void genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const final; - virtual void genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; + virtual void genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const final; - virtual void genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; + virtual void genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const final; - virtual void genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; + virtual void genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const final; virtual void genDefinitionsPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final; virtual void genDefinitionsInternalPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final; diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index 2832a088c8..268372d158 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -42,13 +42,17 @@ class BACKEND_EXPORT Backend : public BackendBase //-------------------------------------------------------------------------- // CodeGenerator::BackendBase virtuals //-------------------------------------------------------------------------- - virtual void genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, 
HostHandler preambleHandler) const final; + virtual void genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const final; - virtual void genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; + virtual void genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const final; - virtual void genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; + virtual void genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const final; - virtual void genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const final; + virtual void genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const final; virtual size_t getSynapticMatrixRowStride(const SynapseGroupInternal &sg) const final; diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index e8e8f63b09..84b489d29c 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -207,25 +207,29 @@ class GENN_EXPORT BackendBase /*! \param os CodeStream to write function to \param modelMerged merged model to generate code for \param preambleHandler callback to write functions for pushing extra-global parameters*/ - virtual void genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; + virtual void genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const = 0; //! Generate platform-specific function to update the state of all synapses /*! \param os CodeStream to write function to \param modelMerged merged model to generate code for \param preambleHandler callback to write functions for pushing extra-global parameters*/ - virtual void genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; + virtual void genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const = 0; //! Generate platform-specific functions to perform custom updates /*! \param os CodeStream to write function to \param modelMerged merged model to generate code for \param preambleHandler callback to write functions for pushing extra-global parameters*/ - virtual void genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; + virtual void genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const = 0; //! Generate platform-specific function to initialise model /*! \param os CodeStream to write function to \param modelMerged merged model to generate code for \param preambleHandler callback to write functions for pushing extra-global parameters*/ - virtual void genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const = 0; + virtual void genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const = 0; //! 
Gets the stride used to access synaptic matrix rows, taking into account sparse data structure, padding etc virtual size_t getSynapticMatrixRowStride(const SynapseGroupInternal &sg) const = 0; diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index b0fceafd36..12894b1274 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -184,32 +184,39 @@ class GENN_EXPORT BackendSIMT : public BackendBase //------------------------------------------------------------------------ // Protected API //------------------------------------------------------------------------ - void genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; - void genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; - - void genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; - - void genSynapseDendriticDelayUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; - void genPresynapticUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; - void genPostsynapticUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; - void genSynapseDynamicsKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; + void genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const; + void genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const; + void genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const; + + void genSynapseDendriticDelayUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const; + void genPresynapticUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const; + void genPostsynapticUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const; + void genSynapseDynamicsKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const; void genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerged &modelMerged, - const std::string &updateGroup, size_t &idStart) const; + BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroup, size_t &idStart) const; void genCustomUpdateWUKernel(EnvironmentExternal &env, ModelSpecMerged &modelMerged, - const std::string &updateGroup, size_t &idStart) const; + BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroup, size_t &idStart) const; void genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, ModelSpecMerged &modelMerged, - const std::string &updateGroup, size_t &idStart) const; + BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroup, size_t &idStart) const; void genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, - const std::string &updateGroup, size_t 
&idStart) const; + BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroup, size_t &idStart) const; - void genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const; + void genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const; void genInitializeSparseKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, - size_t numInitializeThreads, size_t &idStart) const; + size_t numInitializeThreads, BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const; //! Helper wrapper around padSize to pad size to a kernel size size_t padKernelSize(size_t size, Kernel kernel) const; @@ -222,10 +229,10 @@ class GENN_EXPORT BackendSIMT : public BackendBase // Type definitions //-------------------------------------------------------------------------- template - using GenMergedGroupsFn = void (ModelSpecMerged::*)(const BackendBase&, std::function); + using GenMergedGroupsFn = void (ModelSpecMerged::*)(const BackendBase&, BackendBase::MemorySpaces&, std::function); template - using GenMergedCustomUpdateGroupsFn = void (ModelSpecMerged::*)(const BackendBase&, const std::string &, std::function); + using GenMergedCustomUpdateGroupsFn = void (ModelSpecMerged::*)(const BackendBase&, BackendBase::MemorySpaces&, const std::string &, std::function); //-------------------------------------------------------------------------- // Private methods @@ -306,10 +313,10 @@ class GENN_EXPORT BackendSIMT : public BackendBase template - void genParallelGroup(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart, + void genParallelGroup(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, size_t &idStart, GenMergedGroupsFn generateGroupFn, S getPaddedSizeFunc, GroupHandlerEnv handler) const { - std::invoke(generateGroupFn, modelMerged, *this, + std::invoke(generateGroupFn, modelMerged, *this, memorySpaces, [this, getPaddedSizeFunc, handler, &env, &idStart](T &g) { genGroup(env, g, idStart, getPaddedSizeFunc, handler); @@ -317,10 +324,11 @@ class GENN_EXPORT BackendSIMT : public BackendBase } template - void genParallelGroup(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, const std::string &updateGroupName, size_t &idStart, - GenMergedCustomUpdateGroupsFn generateGroupFn, S getPaddedSizeFunc, GroupHandlerEnv handler) const + void genParallelGroup(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + const std::string &updateGroupName, size_t &idStart, GenMergedCustomUpdateGroupsFn generateGroupFn, + S getPaddedSizeFunc, GroupHandlerEnv handler) const { - std::invoke(generateGroupFn, modelMerged, *this, updateGroupName, + std::invoke(generateGroupFn, modelMerged, *this, memorySpaces, updateGroupName, [this, getPaddedSizeFunc, handler, &env, &idStart](T &g) { genGroup(env, g, idStart, getPaddedSizeFunc, handler); diff --git a/include/genn/genn/code_generator/generateModules.h b/include/genn/genn/code_generator/generateModules.h index 0498fecab8..8b8fbc0acd 100644 --- a/include/genn/genn/code_generator/generateModules.h +++ b/include/genn/genn/code_generator/generateModules.h @@ -30,15 +30,15 @@ GENN_EXPORT std::pair, MemAlloc> generateAll(const Mode const filesystem::path &sharePath, const filesystem::path &outputPath, bool forceRebuild = false); -GENN_EXPORT void generateNeuronUpdate(const filesystem::path &outputPath, ModelSpecMerged 
&modelMerged, - const BackendBase &backend, const std::string &suffix = ""); +GENN_EXPORT void generateNeuronUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, + BackendBase::MemorySpaces &memorySpaces, const std::string &suffix = ""); -GENN_EXPORT void generateCustomUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, - const BackendBase &backend, const std::string &suffix = ""); +GENN_EXPORT void generateCustomUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, + BackendBase::MemorySpaces &memorySpaces, const std::string &suffix = ""); -GENN_EXPORT void generateSynapseUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, - const BackendBase &backend, const std::string &suffix = ""); +GENN_EXPORT void generateSynapseUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, + BackendBase::MemorySpaces &memorySpaces, const std::string &suffix = ""); -GENN_EXPORT void generateInit(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, - const BackendBase &backend, const std::string &suffix = ""); +GENN_EXPORT void generateInit(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, + BackendBase::MemorySpaces &memorySpaces, const std::string &suffix = ""); } diff --git a/include/genn/genn/code_generator/generateRunner.h b/include/genn/genn/code_generator/generateRunner.h index 1650d79197..12450fba5e 100644 --- a/include/genn/genn/code_generator/generateRunner.h +++ b/include/genn/genn/code_generator/generateRunner.h @@ -26,5 +26,5 @@ class path; namespace GeNN::CodeGenerator { GENN_EXPORT MemAlloc generateRunner(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, - const BackendBase &backend, const std::string &suffix = ""); + const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, const std::string &suffix = ""); } diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 1991183fa8..3194494151 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -32,7 +32,12 @@ namespace GeNN::CodeGenerator class GENN_EXPORT ModelSpecMerged { public: - ModelSpecMerged(const ModelSpecInternal &model, const BackendBase &backend); + ModelSpecMerged(const ModelSpecInternal &model, const BackendBase &backend) + : m_Model(model), m_NeuronUpdateSupportCode("NeuronUpdateSupportCode"), m_PostsynapticDynamicsSupportCode("PostsynapticDynamicsSupportCode"), + m_PresynapticUpdateSupportCode("PresynapticUpdateSupportCode"), m_PostsynapticUpdateSupportCode("PostsynapticUpdateSupportCode"), + m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode"), m_TypeContext{{"scalar", model.getPrecision()}, {"timepoint", model.getTimePrecision()}} + { + } //-------------------------------------------------------------------------- // CodeGenerator::ModelSpecMerged::EGPField @@ -164,38 +169,56 @@ class GENN_EXPORT ModelSpecMerged //! 
Get merged custom connectivity update groups where host processing needs to be performed const std::vector &getMergedCustomConnectivityHostUpdateGroups() const { return m_MergedCustomConnectivityHostUpdateGroups; } - void genMergedNeuronUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedPresynapticUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedPostsynapticUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedSynapseDynamicsGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedCustomUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, + void genMergedNeuronUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedPresynapticUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedPostsynapticUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedSynapseDynamicsGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedCustomUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroupName, GenMergedGroupFn generateGroup); - void genMergedCustomUpdateWUGroups(const BackendBase &backend, const std::string &updateGroupName, + void genMergedCustomUpdateWUGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroupName, GenMergedGroupFn generateGroup); - void genMergedCustomUpdateTransposeWUGroups(const BackendBase &backend, const std::string &updateGroupName, + void genMergedCustomUpdateTransposeWUGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroupName, GenMergedGroupFn generateGroup); - void genMergedCustomUpdateHostReductionGroups(const BackendBase &backend, const std::string &updateGroupName, + void genMergedCustomUpdateHostReductionGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroupName, GenMergedGroupFn generateGroup); - void genMergedCustomWUUpdateHostReductionGroups(const BackendBase &backend, const std::string &updateGroupName, + void genMergedCustomWUUpdateHostReductionGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroupName, GenMergedGroupFn generateGroup); - void genMergedCustomConnectivityUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, + void genMergedCustomConnectivityUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroupName, GenMergedGroupFn generateGroup); - void genMergedCustomConnectivityHostUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, + void genMergedCustomConnectivityHostUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroupName, GenMergedGroupFn generateGroup); - void genMergedNeuronSpikeQueueUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedNeuronPrevSpikeTimeUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedSynapseDendriticDelayUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - 
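// The genMerged*Groups() entry points declared here now all receive the same
// BackendBase::MemorySpaces instance. In GeNN this type is a vector of
// (prefix, bytes-free) pairs; the helper below is a simplified, hypothetical
// stand-in for the first-come, first-served assignment performed while merging.
#include <cstddef>
#include <optional>
#include <string>
#include <utility>
#include <vector>

using MemorySpaces = std::vector<std::pair<std::string, size_t>>;

// Return the prefix of the first space with room and deduct the allocation,
// so groups generated later see a correspondingly smaller budget
std::optional<std::string> assignFirstFit(MemorySpaces &memorySpaces, size_t bytesRequired)
{
    for (auto &space : memorySpaces) {
        if (space.second >= bytesRequired) {
            space.second -= bytesRequired;
            return space.first;
        }
    }
    return std::nullopt;    // nothing fits - caller falls back to plain global memory
}

int main()
{
    MemorySpaces spaces{{"__device__ __constant__", 1024}, {"__device__", 1 << 20}};
    const auto location = assignFirstFit(spaces, 2048);    // too big for the first space, lands in the second
    return location ? 0 : 1;
}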
void genMergedNeuronInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedCustomUpdateInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedCustomWUUpdateInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedSynapseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedSynapseConnectivityInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedSynapseSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedCustomWUUpdateSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedCustomConnectivityUpdatePreInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedCustomConnectivityUpdatePostInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedCustomConnectivityUpdateSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); - void genMergedSynapseConnectivityHostInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup); + void genMergedNeuronSpikeQueueUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedNeuronPrevSpikeTimeUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedSynapseDendriticDelayUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedNeuronInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedCustomUpdateInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedCustomWUUpdateInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedSynapseInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedSynapseConnectivityInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedSynapseSparseInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedCustomWUUpdateSparseInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedCustomConnectivityUpdatePreInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedCustomConnectivityUpdatePostInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedCustomConnectivityUpdateSparseInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); + void genMergedSynapseConnectivityHostInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup); void genMergedNeuronUpdateGroupStructs(CodeStream &os, const BackendBase &backend) const { genMergedStructures(os, backend, m_MergedNeuronUpdateGroups); } @@ -325,7 +348,7 @@ class GENN_EXPORT ModelSpecMerged } template - void createMergedGroups(const BackendBase &backend, + void 
createMergedGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, const std::vector> &unmergedGroups, std::vector &mergedGroups, D getHashDigest, GenMergedGroupFn generateGroup, bool host = false) { @@ -345,10 +368,15 @@ class GENN_EXPORT ModelSpecMerged // Loop through resultant merged groups size_t i = 0; for(const auto &p : protoMergedGroups) { - // Add group to vector + // Construct new merged group object mergedGroups.emplace_back(i, m_TypeContext, p.second); + + // Call generate function generateGroup(mergedGroups.back()); + // Assign memory spaces + mergedGroups.back().assignMemorySpaces(backend, memorySpaces); + // Loop through fields for(const auto &f : mergedGroups.back().getFields()) { // If field is dynamic, add record to merged EGPS @@ -372,7 +400,7 @@ class GENN_EXPORT ModelSpecMerged } template - void createMergedGroups(const BackendBase &backend, + void createMergedGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, const std::map &groups, std::vector &mergedGroups, F filter, D getHashDigest, G generateGroup, bool host = false) { @@ -385,7 +413,8 @@ class GENN_EXPORT ModelSpecMerged } // Merge filtered vector - createMergedGroups(backend, unmergedGroups, mergedGroups, getHashDigest, generateGroup, host); + createMergedGroups(backend, memorySpaces, unmergedGroups, mergedGroups, + getHashDigest, generateGroup, host); } //-------------------------------------------------------------------------- diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index f323f6cd32..7bd8684e36 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -309,7 +309,8 @@ Type::ResolvedType Backend::getPopulationRNGType() const return CURandState; } //-------------------------------------------------------------------------- -void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); @@ -341,7 +342,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host funcEnv.add(Type::Uint32.addConst(), "batch", "0"); } - genNeuronPrevSpikeTimeUpdateKernel(funcEnv, modelMerged, idNeuronPrevSpikeTimeUpdate); + genNeuronPrevSpikeTimeUpdateKernel(funcEnv, modelMerged, memorySpaces, idNeuronPrevSpikeTimeUpdate); } neuronUpdateEnv.getStream() << std::endl; } @@ -354,7 +355,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host neuronUpdateEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelNeuronSpikeQueueUpdate) << " * blockIdx.x + threadIdx.x;" << std::endl; - genNeuronSpikeQueueUpdateKernel(neuronUpdateEnv, modelMerged, idNeuronSpikeQueueUpdate); + genNeuronSpikeQueueUpdateKernel(neuronUpdateEnv, modelMerged, memorySpaces, idNeuronSpikeQueueUpdate); } neuronUpdateEnv.getStream() << std::endl; @@ -381,7 +382,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Add RNG functions to environment and generate kernel EnvironmentLibrary rngEnv(funcEnv, getRNGFunctions(model.getPrecision())); - genNeuronUpdateKernel(rngEnv, modelMerged, idStart); + genNeuronUpdateKernel(rngEnv, modelMerged, memorySpaces, idStart); } neuronUpdateEnv.getStream() << "void updateNeurons(" << modelMerged.getModel().getTimePrecision().getName() << " t"; @@ 
-448,7 +449,8 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host os << neuronUpdateStream.str(); } //-------------------------------------------------------------------------- -void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const { // Generate stream with synapse update code std::ostringstream synapseUpdateStream; @@ -467,7 +469,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos CodeStream::Scope b(os); synapseUpdateEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelSynapseDendriticDelayUpdate) << " * blockIdx.x + threadIdx.x;" << std::endl; - genSynapseDendriticDelayUpdateKernel(synapseUpdateEnv, modelMerged, idSynapseDendricDelayUpdate); + genSynapseDendriticDelayUpdateKernel(synapseUpdateEnv, modelMerged, memorySpaces, idSynapseDendricDelayUpdate); } synapseUpdateEnv.getStream() << std::endl; //} @@ -495,7 +497,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos // Add RNG functions to environment and generate kernel EnvironmentLibrary rngEnv(funcEnv, getRNGFunctions(model.getPrecision())); - genPresynapticUpdateKernel(rngEnv, modelMerged, idPresynapticStart); + genPresynapticUpdateKernel(rngEnv, modelMerged, memorySpaces, idPresynapticStart); } } @@ -519,7 +521,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos else { funcEnv.add(Type::Uint32.addConst(), "batch", "0"); } - genPostsynapticUpdateKernel(funcEnv, modelMerged, idPostsynapticStart); + genPostsynapticUpdateKernel(funcEnv, modelMerged, memorySpaces, idPostsynapticStart); } } @@ -543,7 +545,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos else { funcEnv.add(Type::Uint32.addConst(), "batch", "0"); } - genSynapseDynamicsKernel(funcEnv, modelMerged, idSynapseDynamicsStart); + genSynapseDynamicsKernel(funcEnv, modelMerged, memorySpaces, idSynapseDynamicsStart); } } @@ -622,7 +624,8 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos } //-------------------------------------------------------------------------- -void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); @@ -667,15 +670,15 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Custom updates" << std::endl; - genCustomUpdateKernel(funcEnv, modelMerged, g, idCustomUpdateStart); + genCustomUpdateKernel(funcEnv, modelMerged, memorySpaces, g, idCustomUpdateStart); funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Custom WU updates" << std::endl; - genCustomUpdateWUKernel(funcEnv, modelMerged, g, idCustomUpdateStart); + genCustomUpdateWUKernel(funcEnv, modelMerged, memorySpaces, g, idCustomUpdateStart); funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; 
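// Every block in this function follows the same shape: ModelSpecMerged owns the
// merging logic and the backend passes a lambda that is invoked once per merged
// group to emit its update code. A stripped-down, self-contained sketch of that
// inversion of control (all types here are illustrative, not the GeNN classes):
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

struct MergedGroup { size_t index; };

struct Merged
{
    // Mirrors genMerged*Groups(): merging lives here, code generation in the callback
    void genGroups(std::function<void(MergedGroup&)> generateGroup)
    {
        for (auto &g : m_Groups) {
            generateGroup(g);
        }
    }

    std::vector<MergedGroup> m_Groups{{0}, {1}, {2}};
};

int main()
{
    Merged merged;
    merged.genGroups([](MergedGroup &g) { std::cout << "// code for merged group " << g.index << "\n"; });
    return 0;
}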
funcEnv.getStream() << "// Custom connectivity updates" << std::endl; - genCustomConnectivityUpdateKernel(funcEnv, modelMerged, g, idCustomUpdateStart); + genCustomConnectivityUpdateKernel(funcEnv, modelMerged, memorySpaces, g, idCustomUpdateStart); } } @@ -694,7 +697,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Custom WU transpose updates" << std::endl; - genCustomTransposeUpdateWUKernel(funcEnv, modelMerged, g, idCustomTransposeUpdateStart); + genCustomTransposeUpdateWUKernel(funcEnv, modelMerged, memorySpaces, g, idCustomTransposeUpdateStart); } } customUpdateEnv.getStream() << "void update" << g << "()"; @@ -703,7 +706,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Loop through host update groups and generate code for those in this custom update group modelMerged.genMergedCustomConnectivityHostUpdateGroups( - *this, g, + *this, memorySpaces, g, [this, &customUpdateEnv, &modelMerged](auto &c) { c.generateUpdate(*this, customUpdateEnv, modelMerged); @@ -733,7 +736,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Loop through custom update host reduction groups and // generate reductions for those in this custom update group modelMerged.genMergedCustomUpdateHostReductionGroups( - *this, g, + *this, memorySpaces, g, [this, &customUpdateEnv, &modelMerged](auto &cg) { genNCCLReduction(customUpdateEnv, cg); @@ -742,7 +745,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Loop through custom WU update host reduction groups and // generate reductions for those in this custom update group modelMerged.genMergedCustomWUUpdateHostReductionGroups( - *this, g, + *this, memorySpaces, g, [this, &customUpdateEnv, &modelMerged](auto &cg) { genNCCLReduction(customUpdateEnv, cg); @@ -809,7 +812,8 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host os << customUpdateStream.str(); } //-------------------------------------------------------------------------- -void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); @@ -844,7 +848,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler CodeStream::Scope b(initEnv.getStream()); initEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelInitialize) << " * blockIdx.x + threadIdx.x;" << std::endl; - genInitializeKernel(initEnv, modelMerged, idInitStart); + genInitializeKernel(initEnv, modelMerged, memorySpaces, idInitStart); } const size_t numStaticInitThreads = idInitStart; @@ -864,7 +868,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler CodeStream::Scope b(initEnv.getStream()); initEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelInitializeSparse) << " * blockIdx.x + threadIdx.x;" << std::endl; - genInitializeSparseKernel(initEnv, modelMerged, numStaticInitThreads, idSparseInitStart); + genInitializeSparseKernel(initEnv, modelMerged, numStaticInitThreads, memorySpaces, idSparseInitStart); } } diff --git a/src/genn/backends/cuda/optimiser.cc b/src/genn/backends/cuda/optimiser.cc 
index 9a9ed57c2e..763ef94105 100644 --- a/src/genn/backends/cuda/optimiser.cc +++ b/src/genn/backends/cuda/optimiser.cc @@ -457,14 +457,18 @@ KernelOptimisationOutput optimizeBlockSize(int deviceID, const cudaDeviceProp &d // Create merged model ModelSpecMerged modelMerged(model, backend); + // Get memory spaces available to this backend + // **NOTE** Memory spaces are given out on a first-come, first-served basis so subsequent groups are in preferential order + auto memorySpaces = backend.getMergedGroupMemorySpaces(modelMerged); + // Generate code with suffix so it doesn't interfere with primary generated code // **NOTE** we don't really need to generate all the code but, on windows, generating code selectively seems to result in weird b const std::string dryRunSuffix = "CUDAOptim"; - generateSynapseUpdate(outputPath, modelMerged, backend, dryRunSuffix); - generateNeuronUpdate(outputPath, modelMerged, backend, dryRunSuffix); - generateCustomUpdate(outputPath, modelMerged, backend, dryRunSuffix); - generateInit(outputPath, modelMerged, backend, dryRunSuffix); - generateRunner(outputPath, modelMerged, backend, dryRunSuffix); + generateSynapseUpdate(outputPath, modelMerged, backend, memorySpaces, dryRunSuffix); + generateNeuronUpdate(outputPath, modelMerged, backend, memorySpaces, dryRunSuffix); + generateCustomUpdate(outputPath, modelMerged, backend, memorySpaces, dryRunSuffix); + generateInit(outputPath, modelMerged, backend, memorySpaces, dryRunSuffix); + generateRunner(outputPath, modelMerged, backend, memorySpaces, dryRunSuffix); // Generate support code module if the backend supports namespaces if (backend.supportsNamespace()) { diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index bcf810b2c7..28f8d3ee29 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -98,7 +98,8 @@ void genKernelIteration(EnvironmentExternalBase &env, G &g, size_t numKernelDims //-------------------------------------------------------------------------- namespace GeNN::CodeGenerator::SingleThreadedCPU { -void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const { if(modelMerged.getModel().getBatchSize() != 1) { throw std::runtime_error("The single-threaded CPU backend only supports simulations with a batch size of 1"); @@ -127,7 +128,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host Timer t(funcEnv.getStream(), "neuronUpdate", modelMerged.getModel().isTimingEnabled()); modelMerged.genMergedNeuronPrevSpikeTimeUpdateGroups( - *this, + *this, memorySpaces, [this, &funcEnv, &modelMerged](auto &n) { CodeStream::Scope
b(funcEnv.getStream()); @@ -284,7 +285,8 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host os << neuronUpdateStream.str(); } //-------------------------------------------------------------------------- -void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const { if (modelMerged.getModel().getBatchSize() != 1) { throw std::runtime_error("The single-threaded CPU backend only supports simulations with a batch size of 1"); @@ -311,7 +313,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos { Timer t(funcEnv.getStream(), "synapseDynamics", modelMerged.getModel().isTimingEnabled()); modelMerged.genMergedSynapseDynamicsGroups( - *this, + *this, memorySpaces, [this, &funcEnv, &modelMerged](auto &s) { CodeStream::Scope b(funcEnv.getStream()); @@ -387,7 +389,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos { Timer t(funcEnv.getStream(), "presynapticUpdate", modelMerged.getModel().isTimingEnabled()); modelMerged.genMergedPresynapticUpdateGroups( - *this, + *this, memorySpaces, [this, &funcEnv, &modelMerged](auto &s) { CodeStream::Scope b(funcEnv.getStream()); @@ -422,7 +424,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos { Timer t(funcEnv.getStream(), "postsynapticUpdate", modelMerged.getModel().isTimingEnabled()); modelMerged.genMergedPostsynapticUpdateGroups( - *this, + *this, memorySpaces, [this, &funcEnv, &modelMerged](auto &s) { CodeStream::Scope b(funcEnv.getStream()); @@ -518,7 +520,8 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Hos os << synapseUpdateStream.str(); } //-------------------------------------------------------------------------- -void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); @@ -555,7 +558,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Loop through host update groups and generate code for those in this custom update group modelMerged.genMergedCustomConnectivityHostUpdateGroups( - *this, g, + *this, memorySpaces, g, [this, &customUpdateEnv, &modelMerged](auto &c) { c.generateUpdate(*this, customUpdateEnv, modelMerged); @@ -564,7 +567,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host { Timer t(funcEnv.getStream(), "customUpdate" + g, model.isTimingEnabled()); modelMerged.genMergedCustomUpdateGroups( - *this, g, + *this, memorySpaces, g, [this, &funcEnv](auto &c) { CodeStream::Scope b(funcEnv.getStream()); @@ -628,7 +631,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Loop through merged custom WU update groups modelMerged.genMergedCustomUpdateWUGroups( - *this, g, + *this, memorySpaces, g, [this, &funcEnv](auto &c) { CodeStream::Scope b(funcEnv.getStream()); @@ -709,7 +712,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host // Loop through merged custom connectivity update groups modelMerged.genMergedCustomConnectivityUpdateGroups( - *this, g, + *this, memorySpaces, g, [this, 
&funcEnv](auto &c) { CodeStream::Scope b(funcEnv.getStream()); @@ -746,7 +749,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host Timer t(funcEnv.getStream(), "customUpdate" + g + "Transpose", model.isTimingEnabled()); // Loop through merged custom connectivity update groups modelMerged.genMergedCustomUpdateTransposeWUGroups( - *this, g, + *this, memorySpaces, g, [this, &funcEnv](auto &c) { CodeStream::Scope b(funcEnv.getStream()); @@ -819,7 +822,8 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Host } //-------------------------------------------------------------------------- -void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler preambleHandler) const +void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces, + HostHandler preambleHandler) const { const ModelSpecInternal &model = modelMerged.getModel(); if(model.getBatchSize() != 1) { @@ -847,7 +851,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Neuron groups" << std::endl; modelMerged.genMergedNeuronInitGroups( - *this, + *this, memorySpaces, [this, &funcEnv, &modelMerged](auto &n) { CodeStream::Scope b(funcEnv.getStream()); @@ -865,7 +869,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Synapse groups" << std::endl; modelMerged.genMergedSynapseInitGroups( - *this, + *this, memorySpaces, [this, &funcEnv, &modelMerged](auto &s) { CodeStream::Scope b(funcEnv.getStream()); @@ -883,7 +887,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Custom update groups" << std::endl; modelMerged.genMergedCustomUpdateInitGroups( - *this, + *this, memorySpaces, [this, &funcEnv, &modelMerged](auto &c) { CodeStream::Scope b(funcEnv.getStream()); @@ -901,7 +905,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Custom connectivity presynaptic update groups" << std::endl; modelMerged.genMergedCustomConnectivityUpdatePreInitGroups( - *this, + *this, memorySpaces, [this, &funcEnv, &modelMerged](auto &c) { CodeStream::Scope b(funcEnv.getStream()); @@ -919,7 +923,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Custom connectivity postsynaptic update groups" << std::endl; modelMerged.genMergedCustomConnectivityUpdatePostInitGroups( - *this, + *this, memorySpaces, [this, &funcEnv, &modelMerged](auto &c) { CodeStream::Scope b(funcEnv.getStream()); @@ -937,7 +941,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Custom WU update groups" << std::endl; 
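// The SIMT kernel generators (see the backendSIMT.cc changes further below)
// funnel through genParallelGroup(), which receives the target
// ModelSpecMerged::genMerged*Groups() member as a pointer-to-member-function
// and forwards the shared MemorySpaces to it. A self-contained sketch of that
// std::invoke dispatch, with illustrative stand-in types:
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Backend {};
using MemorySpaces = std::vector<std::pair<std::string, size_t>>;

struct MergedModel
{
    void genMergedNeuronUpdateGroups(const Backend&, MemorySpaces&, std::function<void(int&)> generateGroup)
    {
        std::vector<int> groups{0, 1};      // stand-ins for merged neuron update groups
        for (int &g : groups) {
            generateGroup(g);
        }
    }
};

using GenMergedGroupsFn = void (MergedModel::*)(const Backend&, MemorySpaces&, std::function<void(int&)>);

void genParallelGroup(MergedModel &merged, const Backend &backend,
                      MemorySpaces &memorySpaces, GenMergedGroupsFn generateGroupFn)
{
    // std::invoke calls the chosen member on 'merged', threading the spaces through
    std::invoke(generateGroupFn, merged, backend, memorySpaces,
                [](int &g) { std::cout << "// kernel body for group " << g << "\n"; });
}

int main()
{
    Backend backend;
    MemorySpaces memorySpaces;
    MergedModel merged;
    genParallelGroup(merged, backend, memorySpaces, &MergedModel::genMergedNeuronUpdateGroups);
    return 0;
}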
modelMerged.genMergedCustomWUUpdateInitGroups( - *this, + *this, memorySpaces, [this, &funcEnv, &modelMerged](auto &c) { CodeStream::Scope b(funcEnv.getStream()); @@ -955,7 +959,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Synapse sparse connectivity" << std::endl; modelMerged.genMergedSynapseConnectivityInitGroups( - *this, + *this, memorySpaces, [this, &funcEnv, &modelMerged](auto &s) { CodeStream::Scope b(funcEnv.getStream()); @@ -1112,7 +1116,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Synapse groups with sparse connectivity" << std::endl; modelMerged.genMergedSynapseSparseInitGroups( - *this, + *this, memorySpaces, [this, &funcEnv, &modelMerged](auto &s) { CodeStream::Scope b(funcEnv.getStream()); @@ -1173,7 +1177,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Custom sparse WU update groups" << std::endl; modelMerged.genMergedCustomWUUpdateSparseInitGroups( - *this, + *this, memorySpaces, [this, &funcEnv, &modelMerged](auto &c) { CodeStream::Scope b(funcEnv.getStream()); @@ -1202,7 +1206,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, HostHandler funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Custom connectivity update sparse init groups" << std::endl; modelMerged.genMergedCustomConnectivityUpdateSparseInitGroups( - *this, + *this, memorySpaces, [this, &funcEnv, &modelMerged](auto &c) { CodeStream::Scope b(funcEnv.getStream()); diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index f4bf853ffb..0361826cf2 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -307,14 +307,15 @@ void BackendSIMT::addPresynapticUpdateStrategy(PresynapticUpdateStrategySIMT::Ba s_PresynapticUpdateStrategies.push_back(strategy); } //-------------------------------------------------------------------------- -void BackendSIMT::genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const { const unsigned int batchSize = modelMerged.getModel().getBatchSize(); // Parallelise over neuron groups idStart = 0; genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedNeuronPrevSpikeTimeUpdateGroups, + env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedNeuronPrevSpikeTimeUpdateGroups, [this](const NeuronGroupInternal &ng) { return padKernelSize(ng.getNumNeurons(), KernelNeuronUpdate); }, [batchSize, this](EnvironmentExternalBase &popEnv, NeuronPrevSpikeTimeUpdateGroupMerged &ng) { @@ -388,14 +389,15 @@ void BackendSIMT::genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &en } //-------------------------------------------------------------------------- -void 
BackendSIMT::genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const { const unsigned int batchSize = modelMerged.getModel().getBatchSize(); // Loop through local neuron groups idStart = 0; modelMerged.genMergedNeuronSpikeQueueUpdateGroups( - *this, + *this, memorySpaces, [&env, &idStart, batchSize, this](auto &n) { if(idStart == 0) { @@ -430,7 +432,8 @@ void BackendSIMT::genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, }); } //-------------------------------------------------------------------------- -void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const { const unsigned int batchSize = modelMerged.getModel().getBatchSize(); @@ -489,7 +492,7 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM // Parallelise over neuron groups idStart = 0; genParallelGroup( - neuronEnv, modelMerged, idStart, &ModelSpecMerged::genMergedNeuronUpdateGroups, + neuronEnv, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedNeuronUpdateGroups, [this](const NeuronGroupInternal &ng) { return padKernelSize(ng.getNumNeurons(), KernelNeuronUpdate); }, [batchSize, &modelMerged, this](EnvironmentExternalBase &popEnv, NeuronUpdateGroupMerged &ng) { @@ -654,7 +657,8 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM }); } //-------------------------------------------------------------------------- -void BackendSIMT::genSynapseDendriticDelayUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genSynapseDendriticDelayUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const { // Loop through merged synapse groups idStart = 0; @@ -679,7 +683,8 @@ void BackendSIMT::genSynapseDendriticDelayUpdateKernel(EnvironmentExternalBase & env.getStream() << std::endl; } //-------------------------------------------------------------------------- -void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const { EnvironmentExternal kernelEnv(env); @@ -709,7 +714,7 @@ void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, Model // Parallelise over synapse groups idStart = 0; genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedPresynapticUpdateGroups, + kernelEnv, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedPresynapticUpdateGroups, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumPresynapticUpdateThreads(sg, getPreferences()), KernelPresynapticUpdate); }, [&modelMerged, this](EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg) { @@ -744,7 +749,8 @@ void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, Model }); } //-------------------------------------------------------------------------- -void 
BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const { EnvironmentExternal kernelEnv(env); @@ -757,7 +763,7 @@ void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, Mode // Parallelise over postsynaptic update groups idStart = 0; genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedPostsynapticUpdateGroups, + kernelEnv, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedPostsynapticUpdateGroups, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumPostsynapticUpdateThreads(sg), KernelPostsynapticUpdate); }, [&modelMerged, this](EnvironmentExternalBase &env, PostsynapticUpdateGroupMerged &sg) { @@ -834,12 +840,13 @@ void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, Mode ); } //-------------------------------------------------------------------------- -void BackendSIMT::genSynapseDynamicsKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genSynapseDynamicsKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const { // Parallelise over synapse groups whose weight update models have code for synapse dynamics idStart = 0; genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedSynapseDynamicsGroups, + env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedSynapseDynamicsGroups, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumSynapseDynamicsThreads(sg), KernelSynapseDynamicsUpdate); }, [&modelMerged, this](EnvironmentExternalBase &env, SynapseDynamicsGroupMerged &sg) { @@ -897,11 +904,11 @@ void BackendSIMT::genSynapseDynamicsKernel(EnvironmentExternalBase &env, ModelSp } //-------------------------------------------------------------------------- void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerged &modelMerged, - const std::string &updateGroup, size_t &idStart) const + BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroup, size_t &idStart) const { const unsigned int batchSize = modelMerged.getModel().getBatchSize(); genParallelGroup( - env, modelMerged, updateGroup, idStart, &ModelSpecMerged::genMergedCustomUpdateGroups, + env, modelMerged, memorySpaces, updateGroup, idStart, &ModelSpecMerged::genMergedCustomUpdateGroups, [batchSize, this](const CustomUpdateInternal &cu) { return getPaddedNumCustomUpdateThreads(cu, batchSize); }, [batchSize, this](EnvironmentExternalBase &env, CustomUpdateGroupMerged &cg) { @@ -1033,11 +1040,11 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge } //-------------------------------------------------------------------------- void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, ModelSpecMerged &modelMerged, - const std::string &updateGroup, size_t &idStart) const + BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroup, size_t &idStart) const { const unsigned int batchSize = modelMerged.getModel().getBatchSize(); genParallelGroup( - env, modelMerged, updateGroup, idStart, &ModelSpecMerged::genMergedCustomUpdateWUGroups, + env, modelMerged, memorySpaces, updateGroup, idStart, &ModelSpecMerged::genMergedCustomUpdateWUGroups, [batchSize, 
this](const CustomUpdateWUInternal &cu) { return getPaddedNumCustomUpdateWUThreads(cu, batchSize); }, [batchSize, this](EnvironmentExternalBase &env, CustomUpdateWUGroupMerged &cg) { @@ -1156,13 +1163,13 @@ void BackendSIMT::genCustomUpdateWUKernel(EnvironmentExternal &env, ModelSpecMer } //-------------------------------------------------------------------------- void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, ModelSpecMerged &modelMerged, - const std::string &updateGroup, size_t &idStart) const + BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroup, size_t &idStart) const { // Generate 2D array const size_t blockSize = getKernelBlockSize(KernelCustomTransposeUpdate); env.getStream() << getSharedPrefix() << " float shTile[" << blockSize << "][" << (blockSize + 1) << "];" << std::endl; genParallelGroup( - env, modelMerged, updateGroup, idStart, &ModelSpecMerged::genMergedCustomUpdateTransposeWUGroups, + env, modelMerged, memorySpaces, updateGroup, idStart, &ModelSpecMerged::genMergedCustomUpdateTransposeWUGroups, [&modelMerged, this](const CustomUpdateWUInternal &cu) { return getPaddedNumCustomUpdateTransposeWUThreads(cu, modelMerged.getModel().getBatchSize()); }, [blockSize, this](EnvironmentExternalBase &env, CustomUpdateTransposeWUGroupMerged &cg) { @@ -1271,11 +1278,11 @@ void BackendSIMT::genCustomTransposeUpdateWUKernel(EnvironmentExternal &env, Mod } //-------------------------------------------------------------------------- void BackendSIMT::genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, - const std::string &updateGroup, size_t &idStart) const + BackendBase::MemorySpaces &memorySpaces, const std::string &updateGroup, size_t &idStart) const { // Parallelise across presynaptic neurons genParallelGroup( - env, modelMerged, updateGroup, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdateGroups, + env, modelMerged, memorySpaces, updateGroup, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdateGroups, [this](const CustomConnectivityUpdateInternal &cg) { return padSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelCustomUpdate); }, [&modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdateGroupMerged &cg) { @@ -1307,13 +1314,14 @@ void BackendSIMT::genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env }); } //-------------------------------------------------------------------------- -void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, size_t &idStart) const +void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, + BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const { env.getStream() << "// ------------------------------------------------------------------------" << std::endl; env.getStream() << "// Local neuron groups" << std::endl; idStart = 0; genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedNeuronInitGroups, + env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedNeuronInitGroups, [this](const NeuronGroupInternal &ng) { return padKernelSize(ng.getNumNeurons(), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, NeuronInitGroupMerged &ng) { @@ -1357,7 +1365,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer env.getStream() << "// ------------------------------------------------------------------------" << std::endl; env.getStream() << "// 
Synapse groups" << std::endl; genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedSynapseInitGroups, + env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedSynapseInitGroups, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumInitThreads(sg), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, SynapseInitGroupMerged &sg) { @@ -1370,7 +1378,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer env.getStream() << "// ------------------------------------------------------------------------" << std::endl; env.getStream() << "// Custom update groups" << std::endl; genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedCustomUpdateInitGroups, + env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomUpdateInitGroups, [this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, CustomUpdateInitGroupMerged &cg) { @@ -1395,7 +1403,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer env.getStream() << "// ------------------------------------------------------------------------" << std::endl; env.getStream() << "// Custom WU update groups" << std::endl; genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedCustomWUUpdateInitGroups, + env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomWUUpdateInitGroups, [this](const CustomUpdateWUInternal &cg) { return padKernelSize(getNumInitThreads(cg), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, CustomWUUpdateInitGroupMerged &cg) { @@ -1408,7 +1416,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer env.getStream() << "// ------------------------------------------------------------------------" << std::endl; env.getStream() << "// Custom connectivity presynaptic update groups" << std::endl; genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdatePreInitGroups, + env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdatePreInitGroups, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdatePreInitGroupMerged &cg) { @@ -1440,7 +1448,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer env.getStream() << "// ------------------------------------------------------------------------" << std::endl; env.getStream() << "// Custom connectivity postsynaptic update groups" << std::endl; genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdatePostInitGroups, + env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdatePostInitGroups, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons(), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdatePostInitGroupMerged &cg) { @@ -1472,7 +1480,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer env.getStream() << "// ------------------------------------------------------------------------" << std::endl; env.getStream() << "// Synapse groups with sparse connectivity" 
<< std::endl; genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedSynapseConnectivityInitGroups, + env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedSynapseConnectivityInitGroups, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumConnectivityInitThreads(sg), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, SynapseConnectivityInitGroupMerged &sg) { @@ -1615,7 +1623,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer } //-------------------------------------------------------------------------- void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelSpecMerged &modelMerged, - size_t numInitializeThreads, size_t &idStart) const + size_t numInitializeThreads, BackendBase::MemorySpaces &memorySpaces, size_t &idStart) const { EnvironmentExternal envKernel(env); envKernel.add(Type::Void, "_sh_row_length", "shRowLength", @@ -1623,7 +1631,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS // Initialise weight update variables for synapse groups with sparse connectivity genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedSynapseSparseInitGroups, + env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedSynapseSparseInitGroups, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumConnectivityInitThreads(sg), KernelInitializeSparse); }, [&modelMerged, numInitializeThreads, this](EnvironmentExternalBase &env, SynapseSparseInitGroupMerged &sg) { @@ -1664,7 +1672,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS // Initialise weight update variables for synapse groups with sparse connectivity genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedCustomWUUpdateSparseInitGroups, + env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomWUUpdateSparseInitGroups, [this](const CustomUpdateWUInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }, [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, CustomWUUpdateSparseInitGroupMerged &cg) { @@ -1686,7 +1694,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS // Initialise weight update variables for synapse groups with sparse connectivity genParallelGroup( - env, modelMerged, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdateSparseInitGroups, + env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdateSparseInitGroups, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }, [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdateSparseInitGroupMerged &cg) { diff --git a/src/genn/genn/code_generator/generateModules.cc b/src/genn/genn/code_generator/generateModules.cc index 155ee53872..da1009c8ce 100644 --- a/src/genn/genn/code_generator/generateModules.cc +++ b/src/genn/genn/code_generator/generateModules.cc @@ -112,13 +112,17 @@ std::pair, MemAlloc> generateAll(const ModelSpecInterna //const auto hashDigest = modelMerged.getHashDigest(backend); MemAlloc mem = MemAlloc::zero(); if(true/*forceRebuild || shouldRebuildModel(outputPath, hashDigest, mem)*/) { + // Get memory spaces available to this backend + // **NOTE** Memory spaces are given out on a first-come, first-serve basis so subsequent 
groups are in preferential order + auto memorySpaces = backend.getMergedGroupMemorySpaces(modelMerged); + // Generate modules // **NOTE** these are ordered in terms of memory-space priority - generateSynapseUpdate(outputPath, modelMerged, backend); - generateNeuronUpdate(outputPath, modelMerged, backend); - generateCustomUpdate(outputPath, modelMerged, backend); - generateInit(outputPath, modelMerged, backend); - mem = generateRunner(outputPath, modelMerged, backend); + generateSynapseUpdate(outputPath, modelMerged, backend, memorySpaces); + generateNeuronUpdate(outputPath, modelMerged, backend, memorySpaces); + generateCustomUpdate(outputPath, modelMerged, backend, memorySpaces); + generateInit(outputPath, modelMerged, backend, memorySpaces); + mem = generateRunner(outputPath, modelMerged, backend, memorySpaces); // Generate support code module if the backend supports namespaces if(backend.supportsNamespace()) { @@ -188,8 +192,8 @@ std::pair, MemAlloc> generateAll(const ModelSpecInterna return std::make_pair(modules, mem); } //-------------------------------------------------------------------------- -void generateNeuronUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, - const BackendBase &backend, const std::string &suffix) +void generateNeuronUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, + BackendBase::MemorySpaces &memorySpaces, const std::string &suffix) { // Create output stream to write to file and wrap in CodeStream std::ofstream neuronUpdateStream((outputPath / ("neuronUpdate" + suffix + ".cc")).str()); @@ -202,7 +206,7 @@ void generateNeuronUpdate(const filesystem::path &outputPath, ModelSpecMerged &m neuronUpdate << std::endl; // Neuron update kernel - backend.genNeuronUpdate(neuronUpdate, modelMerged, + backend.genNeuronUpdate(neuronUpdate, modelMerged, memorySpaces, // Preamble handler [&modelMerged, &backend](CodeStream &os) { @@ -212,8 +216,8 @@ void generateNeuronUpdate(const filesystem::path &outputPath, ModelSpecMerged &m }); } //-------------------------------------------------------------------------- -void generateCustomUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, - const BackendBase &backend, const std::string &suffix) +void generateCustomUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, + BackendBase::MemorySpaces &memorySpaces, const std::string &suffix) { // Create output stream to write to file and wrap in CodeStream std::ofstream customUpdateStream((outputPath / ("customUpdate" + suffix + ".cc")).str()); @@ -223,7 +227,7 @@ void generateCustomUpdate(const filesystem::path &outputPath, ModelSpecMerged &m customUpdate << std::endl; // Neuron update kernel - backend.genCustomUpdate(customUpdate, modelMerged, + backend.genCustomUpdate(customUpdate, modelMerged, memorySpaces, // Preamble handler [&modelMerged, &backend](CodeStream &os) { @@ -235,8 +239,8 @@ void generateCustomUpdate(const filesystem::path &outputPath, ModelSpecMerged &m }); } //-------------------------------------------------------------------------- -void generateSynapseUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, - const BackendBase &backend, const std::string &suffix) +void generateSynapseUpdate(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, + BackendBase::MemorySpaces &memorySpaces, const std::string &suffix) { // Create output stream to write to file and wrap in 
CodeStream std::ofstream synapseUpdateStream((outputPath / ("synapseUpdate" + suffix + ".cc")).str()); @@ -249,7 +253,7 @@ void generateSynapseUpdate(const filesystem::path &outputPath, ModelSpecMerged & synapseUpdate << std::endl; // Synaptic update kernels - backend.genSynapseUpdate(synapseUpdate, modelMerged, + backend.genSynapseUpdate(synapseUpdate, modelMerged, memorySpaces, // Preamble handler [&modelMerged, &backend](CodeStream &os) { @@ -260,8 +264,8 @@ void generateSynapseUpdate(const filesystem::path &outputPath, ModelSpecMerged & }); } //-------------------------------------------------------------------------- -void generateInit(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, - const BackendBase &backend, const std::string &suffix) +void generateInit(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, const BackendBase &backend, + BackendBase::MemorySpaces &memorySpaces, const std::string &suffix) { // Create output stream to write to file and wrap in CodeStream std::ofstream initStream((outputPath / ("init" + suffix + ".cc")).str()); @@ -269,7 +273,7 @@ void generateInit(const filesystem::path &outputPath, ModelSpecMerged &modelMerg init << "#include \"definitionsInternal" << suffix << ".h\"" << std::endl; - backend.genInit(init, modelMerged, + backend.genInit(init, modelMerged, memorySpaces, // Preamble handler [&modelMerged, &backend](CodeStream &os) { diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 530924c214..3a37578bf0 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -523,7 +523,7 @@ void genCustomUpdate(const ModelSpecMerged &modelMerged, const BackendBase &back // GeNN::CodeGenerator //-------------------------------------------------------------------------- MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, ModelSpecMerged &modelMerged, - const BackendBase &backend, const std::string &suffix) + const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, const std::string &suffix) { // Create output streams to write to file and wrap in CodeStreams std::ofstream definitionsStream((outputPath / ("definitions" + suffix + ".h")).str()); @@ -738,7 +738,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, std::ostringstream synapseConnectivityHostInitStream; CodeStream synapseConnectivityHostInit(synapseConnectivityHostInitStream); modelMerged.genMergedSynapseConnectivityHostInitGroups( - backend, + backend, memorySpaces, [&backend, &modelMerged, &synapseConnectivityHostInit](auto &sg) { EnvironmentExternal env(synapseConnectivityHostInit); diff --git a/src/genn/genn/code_generator/modelSpecMerged.cc b/src/genn/genn/code_generator/modelSpecMerged.cc index 9aa7b8dcaf..fdf66fea47 100644 --- a/src/genn/genn/code_generator/modelSpecMerged.cc +++ b/src/genn/genn/code_generator/modelSpecMerged.cc @@ -11,132 +11,53 @@ using namespace GeNN; using namespace GeNN::CodeGenerator; -//---------------------------------------------------------------------------- -// Anonymous namespace -//---------------------------------------------------------------------------- -namespace -{ -template -void assignGroups(const BackendBase &backend, std::vector &groups, BackendBase::MemorySpaces &memorySpaces) -{ - // Loop through groups and assign groups - for(auto &g : groups) { - g.assignMemorySpaces(backend, memorySpaces); - } -} -} - 
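The removed assignGroups() helper above, together with the constructor body deleted just below, implemented a separate post-hoc pass over every merged group vector; memory spaces are now handed out while each merged group is generated, via the memorySpaces reference threaded through the new signatures. For reference, a minimal sketch of the first-come, first-served assignment this relies on, assuming MemorySpaces is a list of (name, bytes remaining) pairs as returned by getMergedGroupMemorySpaces(); the function name and size handling here are illustrative, not GeNN's implementation:

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

using MemorySpaces = std::vector<std::pair<std::string, std::size_t>>;

// Take the first space with enough capacity left; groups generated later get
// whatever the earlier, higher-priority groups have left over
std::string assignMemorySpaceSketch(MemorySpaces &memorySpaces, std::size_t requiredBytes)
{
    for(auto &m : memorySpaces) {
        if(m.second >= requiredBytes) {
            m.second -= requiredBytes;
            return m.first;
        }
    }
    return "";  // no dedicated space left - fall back to default global memory
}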
//---------------------------------------------------------------------------- // GeNN::CodeGenerator::ModelSpecMerged //---------------------------------------------------------------------------- -ModelSpecMerged::ModelSpecMerged(const ModelSpecInternal &model, const BackendBase &backend) -: m_Model(model), m_NeuronUpdateSupportCode("NeuronUpdateSupportCode"), m_PostsynapticDynamicsSupportCode("PostsynapticDynamicsSupportCode"), - m_PresynapticUpdateSupportCode("PresynapticUpdateSupportCode"), m_PostsynapticUpdateSupportCode("PostsynapticUpdateSupportCode"), - m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode"), m_TypeContext{{"scalar", model.getPrecision()}, {"timepoint", model.getTimePrecision()}} -{ - - // Get memory spaces available to this backend - // **NOTE** Memory spaces are given out on a first-come, first-serve basis so subsequent groups are in preferential order - auto memorySpaces = backend.getMergedGroupMemorySpaces(*this); - - // Loop through dendritic delay update groups and assign memory spaces - assignGroups(backend, m_MergedSynapseDendriticDelayUpdateGroups, memorySpaces); - - // Loop through merged presynaptic update groups, assign memory spaces and add support code - for(auto &sg : m_MergedPresynapticUpdateGroups) { - sg.assignMemorySpaces(backend, memorySpaces); - m_PresynapticUpdateSupportCode.addSupportCode(sg.getArchetype().getWUModel()->getSimSupportCode()); - } - - // Loop through merged postsynaptic update groups, assign memory spaces and add support code - for(auto &sg : m_MergedPostsynapticUpdateGroups) { - sg.assignMemorySpaces(backend, memorySpaces); - m_PostsynapticUpdateSupportCode.addSupportCode(sg.getArchetype().getWUModel()->getLearnPostSupportCode()); - } - - // Loop through merged synapse dynamics groups, assign memory spaces and add support code - for(auto &sg : m_MergedSynapseDynamicsGroups) { - sg.assignMemorySpaces(backend, memorySpaces); - m_SynapseDynamicsSupportCode.addSupportCode(sg.getArchetype().getWUModel()->getSynapseDynamicsSuppportCode()); - } - - // Loop through previous spike time and spike queue update groups and assign memory spaces - assignGroups(backend, m_MergedNeuronPrevSpikeTimeUpdateGroups, memorySpaces); - assignGroups(backend, m_MergedNeuronSpikeQueueUpdateGroups, memorySpaces); - - // Loop through merged neuron groups - for(auto &ng : m_MergedNeuronUpdateGroups) { - // Assign memory spaces - ng.assignMemorySpaces(backend, memorySpaces); - - // Add neuron support code - m_NeuronUpdateSupportCode.addSupportCode(ng.getArchetype().getNeuronModel()->getSupportCode()); - - // Loop through merged postsynaptic models and add their support code - for(const auto &sg : ng.getArchetype().getFusedPSMInSyn()) { - m_PostsynapticDynamicsSupportCode.addSupportCode(sg->getPSModel()->getSupportCode()); - } - } - - // Loop through custom update groups and assign memory spaces - assignGroups(backend, m_MergedCustomUpdateGroups, memorySpaces); - assignGroups(backend, m_MergedCustomUpdateWUGroups, memorySpaces); - assignGroups(backend, m_MergedCustomUpdateTransposeWUGroups, memorySpaces); - assignGroups(backend, m_MergedCustomConnectivityUpdateGroups, memorySpaces); - - // Loop through init groups and assign memory spaces - assignGroups(backend, m_MergedNeuronInitGroups, memorySpaces); - assignGroups(backend, m_MergedSynapseInitGroups, memorySpaces); - assignGroups(backend, m_MergedSynapseSparseInitGroups, memorySpaces); - assignGroups(backend, m_MergedSynapseConnectivityInitGroups, memorySpaces); - assignGroups(backend, 
m_MergedCustomUpdateInitGroups, memorySpaces); - assignGroups(backend, m_MergedCustomWUUpdateInitGroups, memorySpaces); - assignGroups(backend, m_MergedCustomWUUpdateSparseInitGroups, memorySpaces); - assignGroups(backend, m_MergedCustomConnectivityUpdatePreInitGroups, memorySpaces); - assignGroups(backend, m_MergedCustomConnectivityUpdatePostInitGroups, memorySpaces); - assignGroups(backend, m_MergedCustomConnectivityUpdateSparseInitGroups, memorySpaces); -} -//---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedNeuronUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedNeuronUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronUpdateGroups, + createMergedGroups(backend, memorySpaces, getModel().getNeuronGroups(), m_MergedNeuronUpdateGroups, [](const NeuronGroupInternal &){ return true; }, &NeuronGroupInternal::getHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedPresynapticUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedPresynapticUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedPresynapticUpdateGroups, + createMergedGroups(backend, memorySpaces, getModel().getSynapseGroups(), m_MergedPresynapticUpdateGroups, [](const SynapseGroupInternal &sg) { return (sg.isSpikeEventRequired() || sg.isTrueSpikeRequired()); }, &SynapseGroupInternal::getWUHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedPostsynapticUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedPostsynapticUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedPostsynapticUpdateGroups, + createMergedGroups(backend, memorySpaces, getModel().getSynapseGroups(), m_MergedPostsynapticUpdateGroups, [](const SynapseGroupInternal &sg){ return !Utils::areTokensEmpty(sg.getWUPostLearnCodeTokens()); }, &SynapseGroupInternal::getWUHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedSynapseDynamicsGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedSynapseDynamicsGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseDynamicsGroups, + createMergedGroups(backend, memorySpaces, getModel().getSynapseGroups(), m_MergedSynapseDynamicsGroups, [](const SynapseGroupInternal &sg){ return !Utils::areTokensEmpty(sg.getWUSynapseDynamicsCodeTokens()); }, &SynapseGroupInternal::getWUHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedCustomUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, - GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedCustomUpdateGroups(const 
BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + const std::string &updateGroupName, GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateGroups, + createMergedGroups(backend, memorySpaces, getModel().getCustomUpdates(), m_MergedCustomUpdateGroups, [&updateGroupName](const CustomUpdateInternal &cg) { return cg.getUpdateGroupName() == updateGroupName; }, &CustomUpdateInternal::getHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedCustomUpdateWUGroups(const BackendBase &backend, const std::string &updateGroupName, - GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedCustomUpdateWUGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + const std::string &updateGroupName, GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomUpdateWUGroups, + createMergedGroups(backend, memorySpaces, getModel().getCustomWUUpdates(), m_MergedCustomUpdateWUGroups, [&updateGroupName](const CustomUpdateWUInternal &cg) { return (!cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName); @@ -144,10 +65,10 @@ void ModelSpecMerged::genMergedCustomUpdateWUGroups(const BackendBase &backend, &CustomUpdateWUInternal::getHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedCustomUpdateTransposeWUGroups(const BackendBase &backend, const std::string &updateGroupName, - GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedCustomUpdateTransposeWUGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + const std::string &updateGroupName, GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomUpdateTransposeWUGroups, + createMergedGroups(backend, memorySpaces, getModel().getCustomWUUpdates(), m_MergedCustomUpdateTransposeWUGroups, [&updateGroupName](const CustomUpdateWUInternal &cg) { return (cg.isTransposeOperation() && cg.getUpdateGroupName() == updateGroupName); @@ -155,10 +76,10 @@ void ModelSpecMerged::genMergedCustomUpdateTransposeWUGroups(const BackendBase & &CustomUpdateWUInternal::getHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedCustomUpdateHostReductionGroups(const BackendBase &backend, const std::string &updateGroupName, - GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedCustomUpdateHostReductionGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + const std::string &updateGroupName, GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateHostReductionGroups, + createMergedGroups(backend, memorySpaces, getModel().getCustomUpdates(), m_MergedCustomUpdateHostReductionGroups, [&updateGroupName](const CustomUpdateInternal &cg) { return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName); @@ -166,10 +87,10 @@ void ModelSpecMerged::genMergedCustomUpdateHostReductionGroups(const BackendBase &CustomUpdateInternal::getHashDigest, generateGroup, true); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedCustomWUUpdateHostReductionGroups(const BackendBase &backend, const std::string &updateGroupName, - GenMergedGroupFn 
generateGroup) +void ModelSpecMerged::genMergedCustomWUUpdateHostReductionGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + const std::string &updateGroupName, GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateHostReductionGroups, + createMergedGroups(backend, memorySpaces, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateHostReductionGroups, [&updateGroupName](const CustomUpdateWUInternal &cg) { return (cg.isBatchReduction() && cg.getUpdateGroupName() == updateGroupName); @@ -177,10 +98,10 @@ void ModelSpecMerged::genMergedCustomWUUpdateHostReductionGroups(const BackendBa &CustomUpdateWUInternal::getHashDigest, generateGroup, true); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedCustomConnectivityUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, - GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedCustomConnectivityUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + const std::string &updateGroupName, GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateGroups, + createMergedGroups(backend, memorySpaces, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateGroups, [&updateGroupName](const CustomConnectivityUpdateInternal &cg) { return (!Utils::areTokensEmpty(cg.getRowUpdateCodeTokens()) && cg.getUpdateGroupName() == updateGroupName); @@ -188,10 +109,10 @@ void ModelSpecMerged::genMergedCustomConnectivityUpdateGroups(const BackendBase &CustomConnectivityUpdateInternal::getHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedCustomConnectivityHostUpdateGroups(const BackendBase &backend, const std::string &updateGroupName, - GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedCustomConnectivityHostUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + const std::string &updateGroupName, GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityHostUpdateGroups, + createMergedGroups(backend, memorySpaces, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityHostUpdateGroups, [&updateGroupName](const CustomConnectivityUpdateInternal &cg) { return (!Utils::areTokensEmpty(cg.getHostUpdateCodeTokens()) && cg.getUpdateGroupName() == updateGroupName); @@ -199,21 +120,24 @@ void ModelSpecMerged::genMergedCustomConnectivityHostUpdateGroups(const BackendB &CustomConnectivityUpdateInternal::getHashDigest, generateGroup, true); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedNeuronSpikeQueueUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedNeuronSpikeQueueUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronSpikeQueueUpdateGroups, + createMergedGroups(backend, memorySpaces, getModel().getNeuronGroups(), m_MergedNeuronSpikeQueueUpdateGroups, [](const NeuronGroupInternal &){ return true; }, &NeuronGroupInternal::getSpikeQueueUpdateHashDigest, generateGroup); } 
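The remaining genMergedXXXGroups methods in this file change in the same mechanical way; the common shape, with XXX a hypothetical placeholder and template parameters elided as elsewhere in this patch, is:

void ModelSpecMerged::genMergedXXXGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces,
                                         GenMergedGroupFn generateGroup)
{
    // Filter, hash and generate exactly as before; the only change is that
    // memorySpaces is forwarded so a memory space can be assigned to each
    // merged group as it is created
    createMergedGroups(backend, memorySpaces, getModel().getXXXs(), m_MergedXXXGroups,
                       [](const XXXInternal &) { return true; },
                       &XXXInternal::getHashDigest, generateGroup);
}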
//---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedNeuronPrevSpikeTimeUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedNeuronPrevSpikeTimeUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronPrevSpikeTimeUpdateGroups, + createMergedGroups(backend, memorySpaces, getModel().getNeuronGroups(), m_MergedNeuronPrevSpikeTimeUpdateGroups, [](const NeuronGroupInternal &ng){ return (ng.isPrevSpikeTimeRequired() || ng.isPrevSpikeEventTimeRequired()); }, &NeuronGroupInternal::getPrevSpikeTimeUpdateHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedSynapseDendriticDelayUpdateGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedSynapseDendriticDelayUpdateGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { std::vector> synapseGroupsWithDendriticDelay; for(const auto &n : getModel().getNeuronGroups()) { @@ -223,27 +147,30 @@ void ModelSpecMerged::genMergedSynapseDendriticDelayUpdateGroups(const BackendBa } } } - createMergedGroups(backend, synapseGroupsWithDendriticDelay, m_MergedSynapseDendriticDelayUpdateGroups, + createMergedGroups(backend, memorySpaces, synapseGroupsWithDendriticDelay, m_MergedSynapseDendriticDelayUpdateGroups, &SynapseGroupInternal::getDendriticDelayUpdateHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedNeuronInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedNeuronInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getNeuronGroups(), m_MergedNeuronInitGroups, + createMergedGroups(backend, memorySpaces, getModel().getNeuronGroups(), m_MergedNeuronInitGroups, [](const NeuronGroupInternal &){ return true; }, &NeuronGroupInternal::getInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedCustomUpdateInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedCustomUpdateInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getCustomUpdates(), m_MergedCustomUpdateInitGroups, + createMergedGroups(backend, memorySpaces, getModel().getCustomUpdates(), m_MergedCustomUpdateInitGroups, [](const CustomUpdateInternal &cg) { return cg.isVarInitRequired(); }, &CustomUpdateInternal::getInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedCustomWUUpdateInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedCustomWUUpdateInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateInitGroups, + createMergedGroups(backend, memorySpaces, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateInitGroups, 
[](const CustomUpdateWUInternal &cg) { return (((cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::DENSE) @@ -253,9 +180,10 @@ void ModelSpecMerged::genMergedCustomWUUpdateInitGroups(const BackendBase &backe &CustomUpdateWUInternal::getInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedSynapseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedSynapseInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseInitGroups, + createMergedGroups(backend, memorySpaces, getModel().getSynapseGroups(), m_MergedSynapseInitGroups, [](const SynapseGroupInternal &sg) { return (((sg.getMatrixType() & SynapseMatrixConnectivity::DENSE) @@ -265,16 +193,18 @@ void ModelSpecMerged::genMergedSynapseInitGroups(const BackendBase &backend, Gen &SynapseGroupInternal::getWUInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedSynapseConnectivityInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedSynapseConnectivityInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseConnectivityInitGroups, + createMergedGroups(backend, memorySpaces, getModel().getSynapseGroups(), m_MergedSynapseConnectivityInitGroups, [](const SynapseGroupInternal &sg){ return sg.isSparseConnectivityInitRequired(); }, &SynapseGroupInternal::getConnectivityInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedSynapseSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedSynapseSparseInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseSparseInitGroups, + createMergedGroups(backend, memorySpaces, getModel().getSynapseGroups(), m_MergedSynapseSparseInitGroups, [&backend](const SynapseGroupInternal &sg) { return ((sg.getMatrixType() & SynapseMatrixConnectivity::SPARSE) && @@ -284,9 +214,10 @@ void ModelSpecMerged::genMergedSynapseSparseInitGroups(const BackendBase &backen &SynapseGroupInternal::getWUInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedCustomWUUpdateSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedCustomWUUpdateSparseInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateSparseInitGroups, + createMergedGroups(backend, memorySpaces, getModel().getCustomWUUpdates(), m_MergedCustomWUUpdateSparseInitGroups, [](const CustomUpdateWUInternal &cg) { return (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) && cg.isVarInitRequired(); @@ -294,9 +225,10 @@ void ModelSpecMerged::genMergedCustomWUUpdateSparseInitGroups(const BackendBase &CustomUpdateWUInternal::getInitHashDigest, 
generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedCustomConnectivityUpdatePreInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedCustomConnectivityUpdatePreInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePreInitGroups, + createMergedGroups(backend, memorySpaces, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePreInitGroups, [&backend](const CustomConnectivityUpdateInternal &cg) { return (cg.isPreVarInitRequired() || (backend.isPopulationRNGInitialisedOnDevice() && Utils::isRNGRequired(cg.getRowUpdateCodeTokens()))); @@ -304,23 +236,26 @@ void ModelSpecMerged::genMergedCustomConnectivityUpdatePreInitGroups(const Backe &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedCustomConnectivityUpdatePostInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedCustomConnectivityUpdatePostInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePostInitGroups, + createMergedGroups(backend, memorySpaces, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdatePostInitGroups, [](const CustomConnectivityUpdateInternal &cg) { return cg.isPostVarInitRequired(); }, &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedCustomConnectivityUpdateSparseInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedCustomConnectivityUpdateSparseInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateSparseInitGroups, + createMergedGroups(backend, memorySpaces, getModel().getCustomConnectivityUpdates(), m_MergedCustomConnectivityUpdateSparseInitGroups, [](const CustomConnectivityUpdateInternal &cg) { return cg.isVarInitRequired(); }, &CustomConnectivityUpdateInternal::getInitHashDigest, generateGroup); } //---------------------------------------------------------------------------- -void ModelSpecMerged::genMergedSynapseConnectivityHostInitGroups(const BackendBase &backend, GenMergedGroupFn generateGroup) +void ModelSpecMerged::genMergedSynapseConnectivityHostInitGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + GenMergedGroupFn generateGroup) { - createMergedGroups(backend, getModel().getSynapseGroups(), m_MergedSynapseConnectivityHostInitGroups, + createMergedGroups(backend, memorySpaces, getModel().getSynapseGroups(), m_MergedSynapseConnectivityHostInitGroups, [](const SynapseGroupInternal &sg) { return !sg.getConnectivityInitialiser().getSnippet()->getHostInitCode().empty(); From d8d2c1d8be7d24ed18132f13fdb65c47059ee5f6 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 16:09:17 +0100 Subject: [PATCH 357/725] fixed typo --- 
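The guard was being printed through popEnv, an environment from an enclosing scope that does not define the $(id) and $(num_neurons) substitutions; printing it through groupEnv, on which genNeuronIndexCalculation has just been called, lets both resolve.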
src/genn/genn/code_generator/backendSIMT.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 0361826cf2..a97b926f3c 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -501,7 +501,7 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM genNeuronIndexCalculation(groupEnv, batchSize); // Call handler to generate generic neuron code - popEnv.print("if($(id) < $(num_neurons))"); + groupEnv.print("if($(id) < $(num_neurons))"); { // Add population RNG field groupEnv.addField(getPopulationRNGType().createPointer(), "_rng", "rng", From d7ec40d32979f342c4f0cf22f35a3c29e471afc0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 16:18:00 +0100 Subject: [PATCH 358/725] dt --- src/genn/backends/cuda/backend.cc | 46 ++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 7bd8684e36..096508796c 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -331,7 +331,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back CodeStream::Scope b(neuronUpdateEnv.getStream()); EnvironmentExternal funcEnv(neuronUpdateEnv); - funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); + funcEnv.add(model.getTimePrecision().addConst(), "t", "t"); funcEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelNeuronPrevSpikeTimeUpdate) << " * blockIdx.x + threadIdx.x;" << std::endl; if(model.getBatchSize() > 1) { @@ -370,7 +370,8 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back EnvironmentExternal funcEnv(neuronUpdateEnv); funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); - + funcEnv.add(model.getTimePrecision().addConst(), "dt", + writePreciseLiteral(model.getDT(), model.getTimePrecision())); funcEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelNeuronUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; if(model.getBatchSize() > 1) { funcEnv.getStream() << "const unsigned int batch = blockIdx.y;" << std::endl; @@ -484,8 +485,9 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Bac CodeStream::Scope b(synapseUpdateEnv.getStream()); EnvironmentExternal funcEnv(synapseUpdateEnv); - funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); - + funcEnv.add(model.getTimePrecision().addConst(), "t", "t"); + funcEnv.add(model.getTimePrecision().addConst(), "dt", + writePreciseLiteral(model.getDT(), model.getTimePrecision())); funcEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelPresynapticUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; if(model.getBatchSize() > 1) { funcEnv.getStream() << "const unsigned int batch = blockIdx.y;" << std::endl; @@ -511,8 +513,9 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Bac CodeStream::Scope b(synapseUpdateEnv.getStream()); EnvironmentExternal funcEnv(synapseUpdateEnv); - funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); - + funcEnv.add(model.getTimePrecision().addConst(), "t", "t"); + funcEnv.add(model.getTimePrecision().addConst(), "dt", + writePreciseLiteral(model.getDT(), model.getTimePrecision())); funcEnv.getStream() << "const unsigned 
int id = " << getKernelBlockSize(KernelPostsynapticUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; if(model.getBatchSize() > 1) { funcEnv.getStream() << "const unsigned int batch = blockIdx.y;" << std::endl; @@ -535,8 +538,9 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Bac CodeStream::Scope b(synapseUpdateEnv.getStream()); EnvironmentExternal funcEnv(synapseUpdateEnv); - funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); - + funcEnv.add(model.getTimePrecision().addConst(), "t", "t"); + funcEnv.add(model.getTimePrecision().addConst(), "dt", + writePreciseLiteral(model.getDT(), model.getTimePrecision())); funcEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelSynapseDynamicsUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; if(model.getBatchSize() > 1) { funcEnv.getStream() << "const unsigned int batch = blockIdx.y;" << std::endl; @@ -664,8 +668,9 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back CodeStream::Scope b(customUpdateEnv.getStream()); EnvironmentExternal funcEnv(customUpdateEnv); - funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); - + funcEnv.add(model.getTimePrecision().addConst(), "t", "t"); + funcEnv.add(model.getTimePrecision().addConst(), "dt", + writePreciseLiteral(model.getDT(), model.getTimePrecision())); funcEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelCustomUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; @@ -691,8 +696,9 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back CodeStream::Scope b(customUpdateEnv.getStream()); EnvironmentExternal funcEnv(customUpdateEnv); - funcEnv.add(modelMerged.getModel().getTimePrecision().addConst(), "t", "t"); - + funcEnv.add(model.getTimePrecision().addConst(), "t", "t"); + funcEnv.add(model.getTimePrecision().addConst(), "dt", + writePreciseLiteral(model.getDT(), model.getTimePrecision())); funcEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelCustomTransposeUpdate) << " * blockIdx.x + threadIdx.x; " << std::endl; funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; @@ -846,9 +852,13 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: { // common variables for all cases CodeStream::Scope b(initEnv.getStream()); + + EnvironmentExternal funcEnv(initEnv); + funcEnv.add(model.getTimePrecision().addConst(), "dt", + writePreciseLiteral(model.getDT(), model.getTimePrecision())); - initEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelInitialize) << " * blockIdx.x + threadIdx.x;" << std::endl; - genInitializeKernel(initEnv, modelMerged, memorySpaces, idInitStart); + funcEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelInitialize) << " * blockIdx.x + threadIdx.x;" << std::endl; + genInitializeKernel(funcEnv, modelMerged, memorySpaces, idInitStart); } const size_t numStaticInitThreads = idInitStart; @@ -867,8 +877,12 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: { CodeStream::Scope b(initEnv.getStream()); - initEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelInitializeSparse) << " * blockIdx.x + threadIdx.x;" << std::endl; - genInitializeSparseKernel(initEnv, modelMerged, 
numStaticInitThreads, memorySpaces, idSparseInitStart); + EnvironmentExternal funcEnv(initEnv); + funcEnv.add(model.getTimePrecision().addConst(), "dt", + writePreciseLiteral(model.getDT(), model.getTimePrecision())); + + funcEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelInitializeSparse) << " * blockIdx.x + threadIdx.x;" << std::endl; + genInitializeSparseKernel(funcEnv, modelMerged, numStaticInitThreads, memorySpaces, idSparseInitStart); } } From 7fe7e7b9c2e11865750ee35d635a1f55f73b6b07 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 16:25:01 +0100 Subject: [PATCH 359/725] in SIMT backend, spike recording fields aren't accessed from within emitSpike so need adding at the top-level scope --- include/genn/genn/code_generator/backendBase.h | 6 ++++++ src/genn/genn/code_generator/backendSIMT.cc | 12 ++++++------ .../code_generator/neuronUpdateGroupMerged.cc | 16 ---------------- 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 84b489d29c..03bd805ff5 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -481,6 +481,12 @@ class GENN_EXPORT BackendBase env.addField(env.getGroup().getTimeType().createPointer(), "_prev_spk_evnt_time", "prevSET", [this](const auto &g, size_t) { return getDeviceVarPrefix() + "prevSET" + g.getName(); }); + env.addField(Type::Uint32.createPointer(), "_record_spk", "recordSpk", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "recordSpk" + g.getName(); }, + "", GroupMergedFieldType::DYNAMIC); + env.addField(Type::Uint32.createPointer(), "_record_spk_event", "recordSpkEvent", + [this](const auto &g, size_t){ return getDeviceVarPrefix() + "recordSpkEvent" + g.getName(); }, + "", GroupMergedFieldType::DYNAMIC); // If batching is enabled, calculate batch offset if(batchSize > 1) { diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index a97b926f3c..d0a3b834f9 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -636,20 +636,20 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM CodeStream::Scope c(groupEnv.getStream()); // If we are recording spikes, copy word to correct location in global memory if(ng.getArchetype().isSpikeRecordingEnabled()) { - groupEnv.getStream() << groupEnv["_record_spk"] << "[" << globalIndex << "] = shSpkRecord"; + groupEnv.print("$(_record_spk)[" + globalIndex + "] = shSpkRecord"); if(m_KernelBlockSizes[KernelNeuronUpdate] != 32) { - groupEnv.getStream() << "[" << getThreadID() << "]"; + groupEnv.print("[" + getThreadID() + "]"); } - groupEnv.getStream() << ";" << std::endl; + groupEnv.printLine(";"); } // If we are recording spike-like events, copy word to correct location in global memory if(ng.getArchetype().isSpikeEventRecordingEnabled()) { - groupEnv.getStream() << groupEnv["_record_spk_evnt"] << "[" << globalIndex << "] = shSpkEvntRecord"; + groupEnv.print("$(_record_spk_evnt)[" + globalIndex + "] = shSpkEvntRecord"); if(m_KernelBlockSizes[KernelNeuronUpdate] != 32) { - groupEnv.getStream() << "[" << getThreadID() << "]"; + groupEnv.print("[" + getThreadID() + "]"); } - groupEnv.getStream() << ";" << std::endl; + groupEnv.printLine(";"); } } } diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc 
b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index b2b7cef8cf..2612d3ed71 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -470,22 +470,6 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E EnvironmentGroupMergedField neuronEnv(env, *this); - // Add field for spike recording - neuronEnv.addField(Type::Uint32.createPointer(), "_record_spk", "recordSpk", - [&backend](const auto &ng, size_t) - { - return backend.getDeviceVarPrefix() + "recordSpk" + ng.getName(); - }, - "", GroupMergedFieldType::DYNAMIC); - - // Add field for spike event recording - neuronEnv.addField(Type::Uint32.createPointer(), "_record_spk_event", "recordSpkEvent", - [&backend](const auto &ng, size_t) - { - return backend.getDeviceVarPrefix() + "recordSpkEvent" + ng.getName(); - }, - "", GroupMergedFieldType::DYNAMIC); - // Add default input variable neuronEnv.add(modelMerged.getModel().getPrecision(), "Isyn", "Isyn", {neuronEnv.addInitialiser(getScalarType().getName() + " Isyn = 0;")}); From 3d84aad85716b1471d50f8b14a270cddb5d41b25 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 16:43:47 +0100 Subject: [PATCH 360/725] apply neuron index calculation to NeuronInitGroupMerged in backend --- src/genn/backends/single_threaded_cpu/backend.cc | 5 ++++- src/genn/genn/code_generator/backendSIMT.cc | 10 ++++++---- src/genn/genn/code_generator/initGroupMerged.cc | 1 - 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 28f8d3ee29..f21231cb33 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -862,7 +862,10 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedNeuronInitGroup" << n.getIndex() << "[g]; " << std::endl; - n.generateInit(*this, funcEnv, modelMerged); + + EnvironmentGroupMergedField groupEnv(funcEnv, n); + genNeuronIndexCalculation(groupEnv, 1); + n.generateInit(*this, groupEnv, modelMerged); } }); diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index d0a3b834f9..be8760f6ba 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -1325,11 +1325,13 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer [this](const NeuronGroupInternal &ng) { return padKernelSize(ng.getNumNeurons(), KernelInitialize); }, [&modelMerged, this](EnvironmentExternalBase &env, NeuronInitGroupMerged &ng) { - env.getStream() << "// only do this for existing neurons" << std::endl; - env.print("if($(id) < $(num_neurons))"); + EnvironmentGroupMergedField groupEnv(env, ng); + genNeuronIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); + + groupEnv.getStream() << "// only do this for existing neurons" << std::endl; + groupEnv.print("if($(id) < $(num_neurons))"); { - CodeStream::Scope b(env.getStream()); - EnvironmentGroupMergedField groupEnv(env, ng); + CodeStream::Scope b(groupEnv.getStream()); // If population RNGs are initialised on device and this neuron is going to require one, if(isPopulationRNGInitialisedOnDevice() && ng.getArchetype().isSimRNGRequired()) { diff --git a/src/genn/genn/code_generator/initGroupMerged.cc 
b/src/genn/genn/code_generator/initGroupMerged.cc index 1e5089d1f8..dce60da533 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -369,7 +369,6 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, Environment // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this); - backend.genNeuronIndexCalculation(groupEnv, model.getBatchSize()); // Initialise spike counts genInitSpikeCount(backend, groupEnv, false, model.getBatchSize()); From c0d43886315e14232ffd0a1407ac02b0be61df46 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 16:59:14 +0100 Subject: [PATCH 361/725] invoke! --- src/genn/genn/neuronGroup.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/genn/genn/neuronGroup.cc b/src/genn/genn/neuronGroup.cc index 8a954fdf5b..8cc15934c4 100644 --- a/src/genn/genn/neuronGroup.cc +++ b/src/genn/genn/neuronGroup.cc @@ -60,12 +60,12 @@ void fuseSynapseGroups(const std::vector &unmergedSyn, bo } // If this synapse group can be merged at all - if(!(a->*isSynMergableFunc)()) { + if(!std::invoke(isSynMergableFunc, a)) { continue; } // Get hash digest used for checking compatibility - const auto aHashDigest = (a->*getSynMergeHashFunc)(); + const auto aHashDigest = std::invoke(getSynMergeHashFunc, a); // Create a name for merged groups const std::string mergedTargetName = mergedTargetPrefix + std::to_string(i) + "_" + mergedTargetSuffix; @@ -74,11 +74,11 @@ void fuseSynapseGroups(const std::vector &unmergedSyn, bo bool anyMerged = false; for(auto b = syn.begin(); b != syn.end();) { // If synapse group b can be merged with others and it's compatible with a - if(((*b)->*isSynMergableFunc)() && (aHashDigest == ((*b)->*getSynMergeHashFunc)())) { + if(std::invoke(isSynMergableFunc, *b) && (aHashDigest == std::invoke(getSynMergeHashFunc, *b))) { LOGD_GENN << "Merging " << logDescription << " of '" << (*b)->getName() << "' with '" << a->getName() << "' into '" << mergedTargetName << "'"; // Set b's merge target to our unique name - ((*b)->*setSynMergeTargetFunc)(mergedTargetName); + std::invoke(setSynMergeTargetFunc, *b, mergedTargetName); // Remove from temporary vector b = syn.erase(b); @@ -95,7 +95,7 @@ void fuseSynapseGroups(const std::vector &unmergedSyn, bo // If synapse group A was successfully merged with anything, set it's merge target to the unique name if(anyMerged) { - (a->*setSynMergeTargetFunc)(mergedTargetName); + std::invoke(setSynMergeTargetFunc, a, mergedTargetName); } } } From 0cda97e38d3bd6809c05d7bcd3f7055174647627 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 17:01:01 +0100 Subject: [PATCH 362/725] fixed typos - CUDA VA benchmark runs! 
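Three fixes, all visible in the diff below: the dendritic delay update kernel opened its CodeStream scope on the raw os stream instead of synapseUpdateEnv.getStream(); the init module now layers the precision-appropriate RNG functions beneath the standard maths library, so gennrand_* calls made from initialisation code resolve; and the sparse connectivity init path gains the previously missing genSynapseIndexCalculation call. The layering pattern nests environments, with identifier lookups falling through to the enclosing one:

EnvironmentLibrary rngEnv(init, getRNGFunctions(model.getPrecision()));
EnvironmentLibrary initEnv(rngEnv, StandardLibrary::getMathsFunctions());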
--- src/genn/backends/cuda/backend.cc | 7 ++++--- src/genn/genn/code_generator/backendSIMT.cc | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 096508796c..835a884db9 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -467,7 +467,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Bac //if(!modelMerged.getMergedSynapseDendriticDelayUpdateGroups().empty()) { synapseUpdateEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelSynapseDendriticDelayUpdate] << "()"; { - CodeStream::Scope b(os); + CodeStream::Scope b(synapseUpdateEnv.getStream()); synapseUpdateEnv.getStream() << "const unsigned int id = " << getKernelBlockSize(KernelSynapseDendriticDelayUpdate) << " * blockIdx.x + threadIdx.x;" << std::endl; genSynapseDendriticDelayUpdateKernel(synapseUpdateEnv, modelMerged, memorySpaces, idSynapseDendricDelayUpdate); @@ -827,8 +827,9 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: std::ostringstream initStream; CodeStream init(initStream); - // Begin environment with standard library - EnvironmentLibrary initEnv(init, StandardLibrary::getMathsFunctions()); + // Begin environment with RNG library and standard library + EnvironmentLibrary rngEnv(init, getRNGFunctions(model.getPrecision())); + EnvironmentLibrary initEnv(rngEnv, StandardLibrary::getMathsFunctions()); // If device RNG is required, generate kernel to initialise it if(isGlobalDeviceRNGRequired(model)) { diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index be8760f6ba..1798dee5a3 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -1487,6 +1487,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer [&modelMerged, this](EnvironmentExternalBase &env, SynapseConnectivityInitGroupMerged &sg) { EnvironmentGroupMergedField groupEnv(env, sg); + genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); // If there is row-building code in this snippet const auto &connectInit = sg.getArchetype().getConnectivityInitialiser(); From 046a6bc37bf16b291ca20b0d095d0fbaa7ccd65e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 17:47:19 +0100 Subject: [PATCH 363/725] updated bit of syntax --- include/genn/genn/currentSourceModels.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/genn/genn/currentSourceModels.h b/include/genn/genn/currentSourceModels.h index ac7196d225..c44979d7d2 100644 --- a/include/genn/genn/currentSourceModels.h +++ b/include/genn/genn/currentSourceModels.h @@ -98,7 +98,7 @@ class PoissonExp : public Base "{\n" " numSpikes++;\n" " p *= gennrand_uniform();\n" - "} while (p > $(ExpMinusLambda));\n" + "} while (p > ExpMinusLambda);\n" "current += Init * (scalar)(numSpikes - 1);\n" "injectCurrent(current);\n" "current *= ExpDecay;\n"); From 36f31645d5e51fd2c5e315cc3045cdc580179680 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 17:53:37 +0100 Subject: [PATCH 364/725] fixed various RNG stuff * CUDA RNG functions use $(_rng) * ``printSubs`` now operates on ``PrettyPrinter::EnvironmentBase`` * Moved [], print and printLine down to ``PrettyPrinter::EnvironmentBase`` * Pretty printer now prints function string after substituting in parameters --- .../genn/genn/code_generator/codeGenUtils.h | 2 +- 
.../genn/genn/code_generator/environment.h | 14 ------ include/genn/genn/transpiler/prettyPrinter.h | 14 ++++++ src/genn/backends/cuda/backend.cc | 26 +++++------ src/genn/genn/code_generator/backendSIMT.cc | 43 +++++++++++++------ src/genn/genn/code_generator/codeGenUtils.cc | 5 ++- src/genn/genn/code_generator/environment.cc | 10 ----- src/genn/genn/code_generator/lazyString.cc | 2 +- src/genn/genn/transpiler/prettyPrinter.cc | 22 ++++++++-- 9 files changed, 80 insertions(+), 58 deletions(-) diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h index 406ece7faa..51056cd01f 100644 --- a/include/genn/genn/code_generator/codeGenUtils.h +++ b/include/genn/genn/code_generator/codeGenUtils.h @@ -77,7 +77,7 @@ GENN_EXPORT void prettyPrintStatements(const std::vector &tok Transpiler::ErrorHandlerBase &errorHandler, Transpiler::TypeChecker::StatementHandler forEachSynapseTypeCheckHandler = nullptr, Transpiler::PrettyPrinter::StatementHandler forEachSynapsePrettyPrintHandler = nullptr); -GENN_EXPORT std::string printSubs(const std::string &format, EnvironmentExternalBase &env); +GENN_EXPORT std::string printSubs(const std::string &format, Transpiler::PrettyPrinter::EnvironmentBase &env); template diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 8a1523de55..64702d6e18 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -64,20 +64,6 @@ class EnvironmentExternalBase : public Transpiler::PrettyPrinter::EnvironmentBas virtual void define(const Transpiler::Token &name, const GeNN::Type::ResolvedType &type, Transpiler::ErrorHandlerBase &errorHandler) override; - //------------------------------------------------------------------------ - // Public API - //------------------------------------------------------------------------ - void print(const std::string &format); - void printLine(const std::string &format); - - //------------------------------------------------------------------------ - // Operators - //------------------------------------------------------------------------ - std::string operator[] (const std::string &name) - { - return getName(name); - } - protected: //------------------------------------------------------------------------ // Protected API diff --git a/include/genn/genn/transpiler/prettyPrinter.h b/include/genn/genn/transpiler/prettyPrinter.h index 14cd30034c..bd3a2a8346 100644 --- a/include/genn/genn/transpiler/prettyPrinter.h +++ b/include/genn/genn/transpiler/prettyPrinter.h @@ -36,6 +36,20 @@ class EnvironmentBase //! 
Get stream to write code within this environment to virtual CodeGenerator::CodeStream &getStream() = 0; + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + void print(const std::string &format); + void printLine(const std::string &format); + + //------------------------------------------------------------------------ + // Operators + //------------------------------------------------------------------------ + std::string operator[] (const std::string &name) + { + return getName(name); + } }; typedef std::function)> StatementHandler; diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 835a884db9..5c0577a0bf 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -30,21 +30,21 @@ using namespace GeNN::CodeGenerator; namespace { const EnvironmentLibrary::Library floatRandomFunctions = { - {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Float, {}), "curand_uniform($(rng))"}}, - {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Float, {}), "curand_normal($(rng))"}}, - {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Float, {}), "exponentialDistFloat($(rng))"}}, - {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Float, {Type::Float, Type::Float}), "curand_log_normal_float($(rng), $(0), $(1))"}}, - {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Float, {Type::Float}), "gammaDistFloat($(rng), $(0))"}}, - {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Float}), "binomialDistFloat($(rng), $(0), $(1))"}}, + {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Float, {}), "curand_uniform($(_rng))"}}, + {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Float, {}), "curand_normal($(_rng))"}}, + {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Float, {}), "exponentialDistFloat($(_rng))"}}, + {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Float, {Type::Float, Type::Float}), "curand_log_normal_float($(_rng), $(0), $(1))"}}, + {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Float, {Type::Float}), "gammaDistFloat($(_rng), $(0))"}}, + {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Float}), "binomialDistFloat($(_rng), $(0), $(1))"}}, }; const EnvironmentLibrary::Library doubleRandomFunctions = { - {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Double, {}), "curand_uniform_double($(rng))"}}, - {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Double, {}), "curand_normal_double($(rng))"}}, - {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Double, {}), "exponentialDistDouble($(rng))"}}, - {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Double, {Type::Double, Type::Double}), "curand_log_normal_double($(rng), $(0), $(1))"}}, - {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Double, {Type::Double}), "gammaDistDouble($(rng), $(0))"}}, - {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Double}), "binomialDistDouble($(rng), $(0), $(1))"}}, + {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Double, {}), "curand_uniform_double($(_rng))"}}, + {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Double, {}), "curand_normal_double($(_rng))"}}, + 
{"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Double, {}), "exponentialDistDouble($(_rng))"}}, + {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Double, {Type::Double, Type::Double}), "curand_log_normal_double($(_rng), $(0), $(1))"}}, + {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Double, {Type::Double}), "gammaDistDouble($(_rng), $(0))"}}, + {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Double}), "binomialDistDouble($(_rng), $(0), $(1))"}}, }; //-------------------------------------------------------------------------- @@ -577,7 +577,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Bac // Launch presynaptic update kernel if(idPresynapticStart > 0) { - CodeStream::Scope b(os); + CodeStream::Scope b(synapseUpdateEnv.getStream()); Timer t(synapseUpdateEnv.getStream(), "presynapticUpdate", model.isTimingEnabled()); genKernelDimensions(synapseUpdateEnv.getStream(), KernelPresynapticUpdate, idPresynapticStart, model.getBatchSize()); diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 1798dee5a3..d0b560a1c2 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -503,6 +503,8 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM // Call handler to generate generic neuron code groupEnv.print("if($(id) < $(num_neurons))"); { + CodeStream::Scope b(groupEnv.getStream()); + // Add population RNG field groupEnv.addField(getPopulationRNGType().createPointer(), "_rng", "rng", [this](const auto &g, size_t) { return getDeviceVarPrefix() + "rng" + g.getName(); }, @@ -1335,17 +1337,22 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // If population RNGs are initialised on device and this neuron is going to require one, if(isPopulationRNGInitialisedOnDevice() && ng.getArchetype().isSimRNGRequired()) { + // Add field for RNG + EnvironmentGroupMergedField rngInitEnv(groupEnv, ng); + rngInitEnv.addField(getPopulationRNGType().createPointer(), "_rng", "rng", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "rng" + g.getName(); }); + // If batch size is 1, initialise single RNG using GLOBAL thread id for sequence if(modelMerged.getModel().getBatchSize() == 1) { - genPopulationRNGInit(groupEnv.getStream(), printSubs("$(_rng)[$(id)]", groupEnv), + genPopulationRNGInit(rngInitEnv.getStream(), printSubs("$(_rng)[$(id)]", rngInitEnv), "deviceRNGSeed", "id"); } // Otherwise, loop through batches and initialise independent RNGs using GLOBAL thread id as basis of sequence else { env.getStream() << "for(unsigned int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; { - CodeStream::Scope b(groupEnv.getStream()); - genPopulationRNGInit(groupEnv.getStream(), printSubs("$(_rng)[(b * $(num_neurons)) + $(id)]", groupEnv), + CodeStream::Scope b(rngInitEnv.getStream()); + genPopulationRNGInit(rngInitEnv.getStream(), printSubs("$(_rng)[(b * $(num_neurons)) + $(id)]", rngInitEnv), "deviceRNGSeed", "(b * " + std::to_string(getNumInitialisationRNGStreams(modelMerged)) + ") + id"); } } @@ -1356,7 +1363,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id if(ng.getArchetype().isInitRNGRequired()) { - groupEnv.add(Type::Void, "rng", genGlobalRNGSkipAhead(groupEnv.getStream(), 
"id")); + groupEnv.add(Type::Void, "_rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } ng.generateInit(*this, groupEnv, modelMerged); @@ -1394,7 +1401,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id if(cg.getArchetype().isInitRNGRequired()) { - groupEnv.add(Type::Void, "rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); + groupEnv.add(Type::Void, "_rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } cg.generateInit(*this, groupEnv, modelMerged); @@ -1431,7 +1438,12 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // If population RNGs are initialised on device and this custom connectivity update // required one, initialise single RNG using GLOBAL thread id for sequence if(isPopulationRNGInitialisedOnDevice() && Utils::isRNGRequired(cg.getArchetype().getRowUpdateCodeTokens())) { - genPopulationRNGInit(groupEnv.getStream(), printSubs("$(_rng)[$(id)]", groupEnv), + // Add field for RNG + EnvironmentGroupMergedField rngInitEnv(groupEnv, cg); + rngInitEnv.addField(getPopulationRNGType().createPointer(), "_rng", "rng", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "rng" + g.getName(); }); + + genPopulationRNGInit(rngInitEnv.getStream(), printSubs("$(_rng)[$(id)]", groupEnv), "deviceRNGSeed", "id"); } @@ -1439,7 +1451,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id if(Utils::isRNGRequired(cg.getArchetype().getPreVarInitialisers())) { - groupEnv.add(Type::Void, "rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); + groupEnv.add(Type::Void, "_rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } cg.generateInit(*this, groupEnv, modelMerged); @@ -1463,7 +1475,12 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // If population RNGs are initialised on device and this custom connectivity update // required one, initialise single RNG using GLOBAL thread id for sequence if(isPopulationRNGInitialisedOnDevice() && Utils::isRNGRequired(cg.getArchetype().getRowUpdateCodeTokens())) { - genPopulationRNGInit(groupEnv.getStream(), printSubs("$(_rng)[$(id)]", groupEnv), + // Add field for RNG + EnvironmentGroupMergedField rngInitEnv(groupEnv, cg); + rngInitEnv.addField(getPopulationRNGType().createPointer(), "_rng", "rng", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "rng" + g.getName(); }); + + genPopulationRNGInit(rngInitEnv.getStream(), printSubs("$(_rng)[$(id)]", groupEnv), "deviceRNGSeed", "id"); } @@ -1471,7 +1488,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id if(Utils::isRNGRequired(cg.getArchetype().getPostVarInitialisers())) { - groupEnv.add(Type::Void, "rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); + groupEnv.add(Type::Void, "_rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } cg.generateInit(*this, groupEnv, modelMerged); @@ -1602,7 +1619,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id if(connectInit.isRNGRequired()) { - groupEnv.add(Type::Void, "rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); + groupEnv.add(Type::Void, "_rng", 
genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } // If there is row-building code in this snippet @@ -1644,7 +1661,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id if(sg.getArchetype().isWUInitRNGRequired()) { - groupEnv.add(Type::Void, "rng", + groupEnv.add(Type::Void, "_rng", genGlobalRNGSkipAhead(groupEnv.getStream(), std::to_string(numInitializeThreads) + " + id")); } @@ -1685,7 +1702,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id if(cg.getArchetype().isInitRNGRequired()) { - groupEnv.add(Type::Void, "rng", + groupEnv.add(Type::Void, "_rng", genGlobalRNGSkipAhead(groupEnv.getStream(), std::to_string(numInitializeThreads) + " + id")); } @@ -1707,7 +1724,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS // make copy of global phillox RNG and skip ahead by thread id // **NOTE** not LOCAL id if(Utils::isRNGRequired(cg.getArchetype().getVarInitialisers())) { - groupEnv.add(Type::Void, "rng", + groupEnv.add(Type::Void, "_rng", genGlobalRNGSkipAhead(groupEnv.getStream(), std::to_string(numInitializeThreads) + " + id")); } diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index a5a7d6b87e..3baf29bb32 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -153,10 +153,11 @@ void prettyPrintStatements(const std::vector &tokens, const T PrettyPrinter::print(updateStatements, env, typeContext, resolvedTypes, forEachSynapsePrettyPrintHandler); } //-------------------------------------------------------------------------- -std::string printSubs(const std::string &format, EnvironmentExternalBase &env) +std::string printSubs(const std::string &format, Transpiler::PrettyPrinter::EnvironmentBase &env) { // Create regex iterator to iterate over $(XXX) style varibles in format string - std::regex regex("\\$\\(([\\w]+)\\)"); + // **NOTE** this doesn't match function argument $(0) + std::regex regex("\\$\\(([a-zA-Z_][\\w]+)\\)"); std::sregex_iterator matchesBegin(format.cbegin(), format.cend(), regex); std::sregex_iterator matchesEnd; diff --git a/src/genn/genn/code_generator/environment.cc b/src/genn/genn/code_generator/environment.cc index b9ae2d3828..3b7aa0a92f 100644 --- a/src/genn/genn/code_generator/environment.cc +++ b/src/genn/genn/code_generator/environment.cc @@ -30,16 +30,6 @@ void EnvironmentExternalBase::define(const Token&, const Type::ResolvedType&, Er throw std::runtime_error("Cannot declare variable in external environment"); } //---------------------------------------------------------------------------- -void EnvironmentExternalBase::print(const std::string &format) -{ - getStream() << printSubs(format, *this); -} -//---------------------------------------------------------------------------- -void EnvironmentExternalBase::printLine(const std::string &format) -{ - getStream() << printSubs(format, *this) << std::endl; -} -//---------------------------------------------------------------------------- CodeStream &EnvironmentExternalBase::getContextStream() const { return std::visit( diff --git a/src/genn/genn/code_generator/lazyString.cc b/src/genn/genn/code_generator/lazyString.cc index d40ff44727..e1ec627278 100644 --- a/src/genn/genn/code_generator/lazyString.cc +++ 
b/src/genn/genn/code_generator/lazyString.cc @@ -18,7 +18,7 @@ using namespace GeNN::CodeGenerator; LazyString::LazyString(const std::string &format, EnvironmentExternalBase &env) { // Create regex iterator to iterate over $(XXX) style varibles in format string - // **NOTE** this doesn't match function argument $(0) + // **NOTE** this doesn't match function argument $(0) std::regex regex("\\$\\(([a-zA-Z_][\\w]+)\\)"); std::sregex_iterator matchesBegin(format.cbegin(), format.cend(), regex); std::sregex_iterator matchesEnd; diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index 8d7588db0d..100b42fa50 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -9,6 +9,7 @@ #include // GeNN code generator includes +#include "code_generator/codeGenUtils.h" #include "code_generator/codeStream.h" // Transpiler includes @@ -262,8 +263,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const auto &type = m_ResolvedTypes.at(&variable); std::string name = m_Environment.get().getName(variable.getName().lexeme, type); - // If identifier is function and name isn't empty i.e. it contains a function template - if (type.isFunction() && !name.empty()) { + // If identifier is function i.e. name is a function template + if (type.isFunction()) { // Check that there are call arguments on the stack assert(!m_CallArguments.empty()); @@ -313,8 +314,8 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } } // Print out name - // **NOTE** in case of function this will be full pretty-printed call - m_Environment.get().getStream() << name; + // **NOTE** this will apply any remaining substitutions + m_Environment.get().print(name); } virtual void visit(const Expression::Unary &unary) final @@ -492,6 +493,19 @@ class Visitor : public Expression::Visitor, public Statement::Visitor }; } // Anonymous namespace +//--------------------------------------------------------------------------- +// GeNN::Transpiler::PrettyPrinter::EnvironmentBase +//--------------------------------------------------------------------------- +void EnvironmentBase::print(const std::string &format) +{ + getStream() << printSubs(format, *this); +} +//---------------------------------------------------------------------------- +void EnvironmentBase::printLine(const std::string &format) +{ + getStream() << printSubs(format, *this) << std::endl; +} + //--------------------------------------------------------------------------- // GeNN::Transpiler::PrettyPrinter //--------------------------------------------------------------------------- From 76fd3036c00f108815b281d815c233ea2e7d47e3 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 18:15:20 +0100 Subject: [PATCH 365/725] call CUDA rngs with address of RNG --- src/genn/backends/cuda/backend.cc | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 5c0577a0bf..7be92e1316 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -30,21 +30,21 @@ using namespace GeNN::CodeGenerator; namespace { const EnvironmentLibrary::Library floatRandomFunctions = { - {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Float, {}), "curand_uniform($(_rng))"}}, - {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Float, {}), "curand_normal($(_rng))"}}, - {"gennrand_exponential", 
{Type::ResolvedType::createFunction(Type::Float, {}), "exponentialDistFloat($(_rng))"}}, - {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Float, {Type::Float, Type::Float}), "curand_log_normal_float($(_rng), $(0), $(1))"}}, - {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Float, {Type::Float}), "gammaDistFloat($(_rng), $(0))"}}, - {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Float}), "binomialDistFloat($(_rng), $(0), $(1))"}}, + {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Float, {}), "curand_uniform(&$(_rng))"}}, + {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Float, {}), "curand_normal(&$(_rng))"}}, + {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Float, {}), "exponentialDistFloat(&$(_rng))"}}, + {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Float, {Type::Float, Type::Float}), "curand_log_normal_float(&$(_rng), $(0), $(1))"}}, + {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Float, {Type::Float}), "gammaDistFloat(&$(_rng), $(0))"}}, + {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Float}), "binomialDistFloat(&$(_rng), $(0), $(1))"}}, }; const EnvironmentLibrary::Library doubleRandomFunctions = { - {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Double, {}), "curand_uniform_double($(_rng))"}}, - {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Double, {}), "curand_normal_double($(_rng))"}}, - {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Double, {}), "exponentialDistDouble($(_rng))"}}, - {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Double, {Type::Double, Type::Double}), "curand_log_normal_double($(_rng), $(0), $(1))"}}, - {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Double, {Type::Double}), "gammaDistDouble($(_rng), $(0))"}}, - {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Double}), "binomialDistDouble($(_rng), $(0), $(1))"}}, + {"gennrand_uniform", {Type::ResolvedType::createFunction(Type::Double, {}), "curand_uniform_double(&$(_rng))"}}, + {"gennrand_normal", {Type::ResolvedType::createFunction(Type::Double, {}), "curand_normal_double(&$(_rng))"}}, + {"gennrand_exponential", {Type::ResolvedType::createFunction(Type::Double, {}), "exponentialDistDouble(&$(_rng))"}}, + {"gennrand_log_normal", {Type::ResolvedType::createFunction(Type::Double, {Type::Double, Type::Double}), "curand_log_normal_double(&$(_rng), $(0), $(1))"}}, + {"gennrand_gamma", {Type::ResolvedType::createFunction(Type::Double, {Type::Double}), "gammaDistDouble(&$(_rng), $(0))"}}, + {"gennrand_binomial", {Type::ResolvedType::createFunction(Type::Uint32, {Type::Uint32, Type::Double}), "binomialDistDouble(&$(_rng), $(0), $(1))"}}, }; //-------------------------------------------------------------------------- From c991bb53f401204dc7a41096e6a97d0fac36ee4a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 10 Jul 2023 18:15:45 +0100 Subject: [PATCH 366/725] fix recording logic and fix up synapse dendritic delay pointer update --- src/genn/genn/code_generator/backendSIMT.cc | 44 +++++++++++---------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index d0b560a1c2..a2d5fcb1c6 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ 
b/src/genn/genn/code_generator/backendSIMT.cc @@ -474,15 +474,15 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM {neuronEnv.addInitialiser(shSpkEvntCountInitStream.str())}); // If any neuron groups record spikes - if(std::any_of(modelMerged.getMergedNeuronUpdateGroups().cbegin(), modelMerged.getMergedNeuronUpdateGroups().cend(), - [](const NeuronUpdateGroupMerged &n) { return n.getArchetype().isSpikeRecordingEnabled(); })) + if(std::any_of(modelMerged.getModel().getNeuronGroups().cbegin(), modelMerged.getModel().getNeuronGroups().cend(), + [](const auto &n) { return n.second.isSpikeRecordingEnabled(); })) { genRecordingSharedMemInit(env.getStream(), ""); } // If any neuron groups record spike-like events - if(std::any_of(modelMerged.getMergedNeuronUpdateGroups().cbegin(), modelMerged.getMergedNeuronUpdateGroups().cend(), - [](const NeuronUpdateGroupMerged &n) { return n.getArchetype().isSpikeEventRecordingEnabled(); })) + if(std::any_of(modelMerged.getModel().getNeuronGroups().cbegin(), modelMerged.getModel().getNeuronGroups().cend(), + [](const auto &n) { return n.second.isSpikeEventRecordingEnabled(); })) { genRecordingSharedMemInit(env.getStream(), "Evnt"); } @@ -664,24 +664,28 @@ void BackendSIMT::genSynapseDendriticDelayUpdateKernel(EnvironmentExternalBase & { // Loop through merged synapse groups idStart = 0; - for(const auto &n : modelMerged.getMergedSynapseDendriticDelayUpdateGroups()) { - env.getStream() << "// merged" << n.getIndex() << std::endl; - if(idStart == 0) { - env.getStream() << "if(id < " << n.getGroups().size() << ")"; - } - else { - env.getStream() << "if(id >= " << idStart << " && id < " << idStart + n.getGroups().size() << ")"; - } + modelMerged.genMergedSynapseDendriticDelayUpdateGroups( + *this, memorySpaces, + [&env, &idStart, &modelMerged, this](auto &sg) { - CodeStream::Scope b(env.getStream()); - - // Use this to get reference to merged group structure - env.getStream() << getPointerPrefix() << "struct MergedSynapseDendriticDelayUpdateGroup" << n.getIndex() << " *group = &d_mergedSynapseDendriticDelayUpdateGroup" << n.getIndex() << "[id - " << idStart << "]; " << std::endl; + env.getStream() << "// merged" << sg.getIndex() << std::endl; + if(idStart == 0) { + env.getStream() << "if(id < " << sg.getGroups().size() << ")"; + } + else { + env.getStream() << "if(id >= " << idStart << " && id < " << idStart + sg.getGroups().size() << ")"; + } + { + CodeStream::Scope b(env.getStream()); - env.printLine("*$(_den_delay_ptr) = (*$(_den_delay_ptr) + 1) % " + std::to_string(n.getArchetype().getMaxDendriticDelayTimesteps()) + ";"); - } - idStart += n.getGroups().size(); - } + // Use this to get reference to merged group structure + env.getStream() << getPointerPrefix() << "struct MergedSynapseDendriticDelayUpdateGroup" << sg.getIndex() << " *group = &d_mergedSynapseDendriticDelayUpdateGroup" << sg.getIndex() << "[id - " << idStart << "]; " << std::endl; + EnvironmentGroupMergedField groupEnv(env, sg); + genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); + groupEnv.printLine("*$(_den_delay_ptr) = (*$(_den_delay_ptr) + 1) % " + std::to_string(sg.getArchetype().getMaxDendriticDelayTimesteps()) + ";"); + } + idStart += sg.getGroups().size(); + }); env.getStream() << std::endl; } //-------------------------------------------------------------------------- From aa62e9bbcd8e044feaaaf5d0ffa1a57799d1e0ce Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 10:20:53 +0100 Subject: [PATCH 367/725] fixed 
small issue with phillox RNG --- src/genn/backends/cuda/backend.cc | 67 ++++++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 7be92e1316..d5b60bb1ed 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -301,7 +301,7 @@ std::string Backend::genGlobalRNGSkipAhead(CodeStream &os, const std::string &se // Skipahead RNG os << "curandStatePhilox4_32_10_t localRNG = d_rng;" << std::endl; os << "skipahead_sequence((unsigned long long)" << sequence << ", &localRNG);" << std::endl; - return "&localRNG"; + return "localRNG"; } //-------------------------------------------------------------------------- Type::ResolvedType Backend::getPopulationRNGType() const @@ -1033,6 +1033,8 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: modelMerged.getMergedCustomWUUpdateSparseInitGroups(), [this](const CustomUpdateWUInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }, modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups(), [this](const CustomConnectivityUpdateInternal &cg){ return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }); os << std::endl; + + os << initStream.str(); } //-------------------------------------------------------------------------- void Backend::genDefinitionsPreamble(CodeStream &os, const ModelSpecMerged &) const @@ -1599,6 +1601,35 @@ void Backend::genVariableDynamicAllocation(CodeStream &os, } } //-------------------------------------------------------------------------- +void Backend::genLazyVariableDynamicAllocation(CodeStream &os, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, + const std::string &countVarName) const +{ + const auto &underlyingType = type.isPointer() ? *type.getPointer().valueType : type; + const std::string hostPointer = type.isPointer() ? ("*$(_" + name + ")") : ("$(_" + name + ")"); + const std::string hostPointerToPointer = type.isPointer() ? ("$(_" + name + ")") : ("&$(_" + name + ")"); + const std::string devicePointerToPointer = type.isPointer() ? ("$(_d_" + name + ")") : ("&$(_d_" + name + ")"); + if(getPreferences().automaticCopy) { + os << "CHECK_CUDA_ERRORS(cudaMallocManaged(" << hostPointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType.getName() << ")));" << std::endl; + } + else { + if(loc & VarLocation::HOST) { + const char *flags = (loc & VarLocation::ZERO_COPY) ? 
"cudaHostAllocMapped" : "cudaHostAllocPortable"; + os << "CHECK_CUDA_ERRORS(cudaHostAlloc(" << hostPointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType.getName() << "), " << flags << "));" << std::endl; + } + + // If variable is present on device at all + if(loc & VarLocation::DEVICE) { + if(loc & VarLocation::ZERO_COPY) { + os << "CHECK_CUDA_ERRORS(cudaHostGetDevicePointer((void**)" << devicePointerToPointer << ", (void*)" << hostPointer << ", 0));" << std::endl; + } + else { + os << "CHECK_CUDA_ERRORS(cudaMalloc(" << devicePointerToPointer << ", " << countVarName << " * sizeof(" << underlyingType.getName() << ")));" << std::endl; + } + } + } +} +//-------------------------------------------------------------------------- void Backend::genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const { if(getPreferences().automaticCopy) { @@ -1732,6 +1763,23 @@ void Backend::genVariableDynamicPush(CodeStream &os, } } //-------------------------------------------------------------------------- +void Backend::genLazyVariableDynamicPush(CodeStream &os, + const Type::ResolvedType &type, const std::string &name, + VarLocation loc, const std::string &countVarName) const +{ + if(!(loc & VarLocation::ZERO_COPY)) { + if (type.isPointer()) { + os << "CHECK_CUDA_ERRORS(cudaMemcpy(*$(_d_" << name << "), *$(_" << name << "), "; + os << countVarName << " * sizeof(" << type.getPointer().valueType->getName() << "), cudaMemcpyHostToDevice));" << std::endl; + } + else { + os << "$(d_" << name << ") = new " << type.getName() << "[" << countVarName << "];" << std::endl; + os << "CHECK_CUDA_ERRORS(cudaMemcpy($(_d_" << name << "), $(_" << name << "), "; + os << countVarName << " * sizeof(" << type.getName() << "), cudaMemcpyHostToDevice));" << std::endl; + } + } +} +//-------------------------------------------------------------------------- void Backend::genVariableDynamicPull(CodeStream &os, const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName, const std::string &prefix) const @@ -1753,6 +1801,23 @@ void Backend::genVariableDynamicPull(CodeStream &os, } } //-------------------------------------------------------------------------- +void Backend::genLazyVariableDynamicPull(CodeStream &os, + const Type::ResolvedType &type, const std::string &name, + VarLocation loc, const std::string &countVarName) const +{ + if(!(loc & VarLocation::ZERO_COPY)) { + if (type.isPointer()) { + os << "CHECK_CUDA_ERRORS(cudaMemcpy(*$(_" << name << "), *$(_d_" << name << "), "; + os << countVarName << " * sizeof(" << type.getPointer().valueType->getName() << "), cudaMemcpyDeviceToHost));" << std::endl; + } + else { + os << "CHECK_CUDA_ERRORS(cudaMemcpy($(_" << name << "), $(_d_" << name << "), "; + os << countVarName << " * sizeof(" << type.getName() << "), cudaMemcpyDeviceToHost));" << std::endl; + } + + } +} +//-------------------------------------------------------------------------- void Backend::genMergedDynamicVariablePush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, const std::string &groupIdx, const std::string &fieldName, const std::string &egpName) const From 062a310f21157797b5161cbe1d30253606575467 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 10:21:26 +0100 Subject: [PATCH 368/725] pretty printer should replace **all** placeholders with arguments not just first --- src/genn/genn/transpiler/prettyPrinter.cc | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git 
a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc
index 100b42fa50..65bd572130 100644
--- a/src/genn/genn/transpiler/prettyPrinter.cc
+++ b/src/genn/genn/transpiler/prettyPrinter.cc
@@ -273,14 +273,18 @@ class Visitor : public Expression::Visitor, public Statement::Visitor
         for (i = 0; i < m_CallArguments.top().size(); i++) {
             // If name contains a $(i) placeholder to replace with this argument, replace with pretty-printed argument
             const std::string placeholder = "$(" + std::to_string(i) + ")";
-            const size_t found = name.find(placeholder);
-            if (found != std::string::npos) {
-                name.replace(found, placeholder.length(), m_CallArguments.top().at(i));
-            }
-            // Otherwise, stop searching
-            else {
+
+            // If placeholder isn't found at all, stop looking for arguments
+            size_t found = name.find(placeholder);
+            if(found == std::string::npos) {
                 break;
             }
+
+            // Keep replacing placeholders
+            do {
+                name.replace(found, placeholder.length(), m_CallArguments.top().at(i));
+                found = name.find(placeholder, found);
+            } while(found != std::string::npos);
         }
 
         // If all arguments haven't been substituted

From 45bd75bf5ea8fdd1acee7c86169fe69d54f9cf26 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 11 Jul 2023 10:22:14 +0100
Subject: [PATCH 369/725] lazy versions of dynamic push pull backend functions
 for use in host code (otherwise e.g. device fields don't get created)

---
 include/genn/backends/cuda/backend.h       | 14 +++++++++++++
 .../genn/genn/code_generator/backendBase.h | 15 +++++++++++++
 .../genn/code_generator/initGroupMerged.cc | 21 ++++++++++---------
 3 files changed, 40 insertions(+), 10 deletions(-)

diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h
index b883bd32aa..a44c10d99f 100644
--- a/include/genn/backends/cuda/backend.h
+++ b/include/genn/backends/cuda/backend.h
@@ -210,6 +210,10 @@ class BACKEND_EXPORT Backend : public BackendSIMT
                                              const Type::ResolvedType &type, const std::string &name, VarLocation loc,
                                              const std::string &countVarName = "count", const std::string &prefix = "") const final;
 
+    //! Generate code to allocate variable with a size known at runtime
+    virtual void genLazyVariableDynamicAllocation(CodeStream &os,
+                                                  const Type::ResolvedType &type, const std::string &name, VarLocation loc,
+                                                  const std::string &countVarName) const final;
 
     //! Generate code to free a variable
     virtual void genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const final;
@@ -239,11 +243,21 @@ class BACKEND_EXPORT Backend : public BackendSIMT
                                        const Type::ResolvedType &type, const std::string &name, VarLocation loc,
                                        const std::string &countVarName = "count", const std::string &prefix = "") const final;
 
+    //! Generate code for pushing a variable with a size known at runtime to the 'device'
+    virtual void genLazyVariableDynamicPush(CodeStream &os,
+                                            const Type::ResolvedType &type, const std::string &name,
+                                            VarLocation loc, const std::string &countVarName) const final;
+
     //! Generate code for pulling a variable with a size known at runtime from the 'device'
     virtual void genVariableDynamicPull(CodeStream &os,
                                         const Type::ResolvedType &type, const std::string &name, VarLocation loc,
                                         const std::string &countVarName = "count", const std::string &prefix = "") const final;
 
+    //! Generate code for pulling a variable with a size known at runtime from the 'device'
+    virtual void genLazyVariableDynamicPull(CodeStream &os,
+                                            const Type::ResolvedType &type, const std::string &name,
+                                            VarLocation loc, const std::string &countVarName) const final;
+
     //! Generate code for pushing a new pointer to a dynamic variable into the merged group structure on 'device'
     virtual void genMergedDynamicVariablePush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx,
                                               const std::string &groupIdx, const std::string &fieldName,
diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h
index 03bd805ff5..3f83c4d8e3 100644
--- a/include/genn/genn/code_generator/backendBase.h
+++ b/include/genn/genn/code_generator/backendBase.h
@@ -272,6 +272,11 @@ class GENN_EXPORT BackendBase
     virtual void genVariableDynamicAllocation(CodeStream &os,
                                               const Type::ResolvedType &type, const std::string &name, VarLocation loc,
                                               const std::string &countVarName = "count", const std::string &prefix = "") const = 0;
+
+    //! Generate code to allocate variable with a size known at runtime
+    virtual void genLazyVariableDynamicAllocation(CodeStream &os,
+                                                  const Type::ResolvedType &type, const std::string &name, VarLocation loc,
+                                                  const std::string &countVarName) const = 0;
 
     //! Generate code to free a variable
     virtual void genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const = 0;
@@ -301,11 +306,21 @@ class GENN_EXPORT BackendBase
                                        const Type::ResolvedType &type, const std::string &name, VarLocation loc,
                                        const std::string &countVarName = "count", const std::string &prefix = "") const = 0;
 
+    //! Generate code for pushing a variable with a size known at runtime to the 'device'
+    virtual void genLazyVariableDynamicPush(CodeStream &os,
+                                            const Type::ResolvedType &type, const std::string &name,
+                                            VarLocation loc, const std::string &countVarName) const = 0;
+
     //! Generate code for pulling a variable with a size known at runtime from the 'device'
     virtual void genVariableDynamicPull(CodeStream &os,
                                         const Type::ResolvedType &type, const std::string &name, VarLocation loc,
                                         const std::string &countVarName = "count", const std::string &prefix = "") const = 0;
 
+    //! Generate code for pulling a variable with a size known at runtime from the 'device'
+    virtual void genLazyVariableDynamicPull(CodeStream &os,
+                                            const Type::ResolvedType &type, const std::string &name,
+                                            VarLocation loc, const std::string &countVarName) const = 0;
+
    //!
Generate code for pushing a new pointer to a dynamic variable into the merged group structure on 'device' virtual void genMergedDynamicVariablePush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, const std::string &groupIdx, const std::string &fieldName, diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index dce60da533..29f949f93e 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -786,11 +786,12 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac const auto pointerToPointerType = pointerType.createPointer(); // Add field for host pointer - // **NOTE** use [0] to dereference on access to obta - groupEnv.addField(pointerType, egp.name, - pointerToPointerType, egp.name, + groupEnv.addField(pointerToPointerType, "_" + egp.name, egp.name, [egp](const auto &g, size_t) { return "&" + egp.name + g.getName(); }, - "0", GroupMergedFieldType::HOST_DYNAMIC); + "", GroupMergedFieldType::HOST_DYNAMIC); + + // Add substitution for dereferenced access to field + groupEnv.add(pointerType, egp.name, "*$(_" + egp.name + ")"); // If backend requires seperate device variables, add additional (private) field) if(!backend.getDeviceVarPrefix().empty()) { @@ -816,9 +817,9 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac std::stringstream allocStream; const auto &pointerToEGP = resolvedType.createPointer(); CodeGenerator::CodeStream alloc(allocStream); - backend.genVariableDynamicAllocation(alloc, - pointerToEGP, egp.name, - loc, "$(0)", "group->"); + backend.genLazyVariableDynamicAllocation(alloc, + pointerToEGP, egp.name, + loc, "$(0)"); // Add substitution groupEnv.add(Type::AllocatePushPullEGP, "allocate" + egp.name, allocStream.str()); @@ -826,9 +827,9 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac // Generate code to push this EGP with count specified by $(0) std::stringstream pushStream; CodeStream push(pushStream); - backend.genVariableDynamicPush(push, - pointerToEGP, egp.name, - loc, "$(0)", "group->"); + backend.genLazyVariableDynamicPush(push, + pointerToEGP, egp.name, + loc, "$(0)"); // Add substitution From 0714d0b984b9a52c31d229ab62bee6eac77f195f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 10:30:40 +0100 Subject: [PATCH 370/725] fixed typo --- src/genn/genn/code_generator/initGroupMerged.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 29f949f93e..b0dca92e17 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -797,7 +797,7 @@ void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &bac if(!backend.getDeviceVarPrefix().empty()) { groupEnv.addField(pointerToPointerType, "_" + backend.getDeviceVarPrefix() + egp.name, backend.getDeviceVarPrefix() + egp.name, - [egp](const auto &g, size_t) { return "&" + egp.name + g.getName(); }, + [egp, &backend](const auto &g, size_t) { return "&" + backend.getDeviceVarPrefix() + egp.name + g.getName(); }, "", GroupMergedFieldType::DYNAMIC); } From 6429ef13dc3620f0bb2dab9d7fd39a3c8ae248a2 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 10:45:46 +0100 Subject: [PATCH 371/725] fixed sparse init check --- src/genn/backends/cuda/backend.cc | 17 ++++++++--------- 1 file 
changed, 8 insertions(+), 9 deletions(-) diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index d5b60bb1ed..39d775a2b2 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -863,17 +863,16 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: } const size_t numStaticInitThreads = idInitStart; - /*((sg.getMatrixType() & SynapseMatrixConnectivity::SPARSE) && - (sg.isWUVarInitRequired() - || (backend.isPostsynapticRemapRequired() && !sg.getWUModel()->getLearnPostCode().empty())));*/ - // (cg.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) && cg.isVarInitRequired(); - // return cg.isVarInitRequired(); // Sparse initialization kernel code size_t idSparseInitStart = 0; - //if(std::any_of(model.getSynapseGroups().cbegin(), model.getSynapseGroups().cend(), - // [](const auto &sg){}) - if(!modelMerged.getMergedSynapseSparseInitGroups().empty() || !modelMerged.getMergedCustomWUUpdateSparseInitGroups().empty() - || !modelMerged.getMergedCustomConnectivityUpdateSparseInitGroups().empty()) { + if(std::any_of(model.getSynapseGroups().cbegin(), model.getSynapseGroups().cend(), + [](const auto &sg){ return ((sg.second.getMatrixType() & SynapseMatrixConnectivity::SPARSE) && + (sg.second.isWUVarInitRequired() || !Utils::areTokensEmpty(sg.second.getWUPostLearnCodeTokens()))); }) + || std::any_of(model.getCustomWUUpdates().cbegin(), model.getCustomWUUpdates().cend(), + [](const auto &cg){ return (cg.second.getSynapseGroup()->getMatrixType() & SynapseMatrixConnectivity::SPARSE) && cg.second.isVarInitRequired(); }) + || std::any_of(model.getCustomConnectivityUpdates().cbegin(), model.getCustomConnectivityUpdates().cend(), + [](const auto &cg){ return cg.second.isVarInitRequired(); })) + { initEnv.getStream() << "extern \"C\" __global__ void " << KernelNames[KernelInitializeSparse] << "()"; { CodeStream::Scope b(initEnv.getStream()); From 1f98d12041f6874d6ce93fa0132e560449e8870f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 11:11:18 +0100 Subject: [PATCH 372/725] fixed some bugs in sparse init --- include/genn/genn/code_generator/backendSIMT.h | 12 ++++++------ src/genn/backends/single_threaded_cpu/backend.cc | 5 ++++- src/genn/genn/code_generator/backendSIMT.cc | 7 ++++--- src/genn/genn/code_generator/initGroupMerged.cc | 1 - 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 12894b1274..26981608cf 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -433,9 +433,9 @@ class GENN_EXPORT BackendSIMT : public BackendBase { // Calculate how many blocks rows need to be processed in (in order to store row lengths in shared memory) const size_t blockSize = getKernelBlockSize(KernelInitializeSparse); - env.getStream() << "const unsigned int numBlocks = (" << env["num_pre"] << " + " << blockSize << " - 1) / " << blockSize << ";" << std::endl; - - env.getStream() << "unsigned int idx = " << env["id"] << ";" << std::endl; + const std::string blockSizeStr = std::to_string(blockSize); + env.printLine("const unsigned int numBlocks = ($(num_pre) + " + blockSizeStr + " - 1) / " + blockSizeStr + ";"); + env.printLine("unsigned int idx = $(id);"); // Loop through blocks env.getStream() << "for(unsigned int r = 0; r < numBlocks; r++)"; @@ -452,7 +452,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase 
env.getStream() << "if (" << getThreadID() << " < numRowsInBlock)"; { CodeStream::Scope b(env.getStream()); - env.getStream() << "shRowLength[" << getThreadID() << "] = " << env["_row_length"] << "[(r * " << blockSize << ") + " << getThreadID() << "];" << std::endl; + env.printLine("$(_sh_row_length)[" + getThreadID() + "] = $(_row_length)[(r * " + blockSizeStr + ") + " + getThreadID() + "];"); } genSharedMemBarrier(env.getStream()); @@ -462,7 +462,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase CodeStream::Scope b(env.getStream()); // If there is a synapse for this thread to initialise - env.getStream() << "if(" << env["id"] << " < shRowLength[i])"; + env.print("if($(id) < $(_sh_row_length)[i])"); { CodeStream::Scope b(env.getStream()); @@ -479,7 +479,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase } // If matrix is ragged, advance index to next row by adding stride - env.getStream() << "idx += " << env["_row_stride"] << ";" << std::endl; + env.printLine("idx += $(_row_stride);"); } } } diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index f21231cb33..87b06ebc0c 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -883,7 +883,10 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedSynapseInitGroup" << s.getIndex() << "[g]; " << std::endl; - s.generateInit(*this, funcEnv, modelMerged); + EnvironmentGroupMergedField groupEnv(funcEnv, s); + genSynapseIndexCalculation(groupEnv, 1); + + s.generateInit(*this, groupEnv, modelMerged); } }); diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index a2d5fcb1c6..129320335f 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -1655,11 +1655,12 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS // Initialise weight update variables for synapse groups with sparse connectivity genParallelGroup( - env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedSynapseSparseInitGroups, + envKernel, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedSynapseSparseInitGroups, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumConnectivityInitThreads(sg), KernelInitializeSparse); }, [&modelMerged, numInitializeThreads, this](EnvironmentExternalBase &env, SynapseSparseInitGroupMerged &sg) { EnvironmentGroupMergedField groupEnv(env, sg); + genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); // If this post synapse requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id @@ -1696,7 +1697,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS // Initialise weight update variables for synapse groups with sparse connectivity genParallelGroup( - env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomWUUpdateSparseInitGroups, + envKernel, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomWUUpdateSparseInitGroups, [this](const CustomUpdateWUInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }, [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, CustomWUUpdateSparseInitGroupMerged &cg) { @@ -1718,7 +1719,7 @@ void 
BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS
     // Initialise weight update variables for synapse groups with sparse connectivity
     genParallelGroup(
-        env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdateSparseInitGroups,
+        envKernel, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdateSparseInitGroups,
         [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); },
         [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdateSparseInitGroupMerged &cg)
         {
diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc
index b0dca92e17..058961550b 100644
--- a/src/genn/genn/code_generator/initGroupMerged.cc
+++ b/src/genn/genn/code_generator/initGroupMerged.cc
@@ -535,7 +535,6 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, Environmen
 {
     // Create environment for group
     EnvironmentGroupMergedField groupEnv(env, *this);
-    backend.genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize());
 
     // If model is batched and has kernel weights
     const bool kernel = (getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL);

From c6ade21a7574e70c877ca35395f4812e00cdb043 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 11 Jul 2023 12:19:49 +0100
Subject: [PATCH 373/725] added missing lazy dynamic allocate, push and pull
 methods to single-threaded CPU backend

---
 .../backends/single_threaded_cpu/backend.h  | 15 ++++++++++
 .../backends/single_threaded_cpu/backend.cc | 28 ++++++++++++++++++-
 2 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h
index 268372d158..e5f5c4cc86 100644
--- a/include/genn/backends/single_threaded_cpu/backend.h
+++ b/include/genn/backends/single_threaded_cpu/backend.h
@@ -81,6 +81,11 @@ class BACKEND_EXPORT Backend : public BackendBase
                                              const Type::ResolvedType &type, const std::string &name, VarLocation loc,
                                              const std::string &countVarName = "count", const std::string &prefix = "") const final;
 
+    //! Generate code to allocate variable with a size known at runtime
+    virtual void genLazyVariableDynamicAllocation(CodeStream &os,
+                                                  const Type::ResolvedType &type, const std::string &name, VarLocation loc,
+                                                  const std::string &countVarName) const final;
+
     //! Generate code to free a variable
     virtual void genVariableFree(CodeStream &os, const std::string &name, VarLocation loc) const final;
@@ -109,11 +114,21 @@ class BACKEND_EXPORT Backend : public BackendBase
                                        const Type::ResolvedType &type, const std::string &name, VarLocation loc,
                                        const std::string &countVarName = "count", const std::string &prefix = "") const final;
 
+    //! Generate code for pushing a variable with a size known at runtime to the 'device'
+    virtual void genLazyVariableDynamicPush(CodeStream &os,
+                                            const Type::ResolvedType &type, const std::string &name, VarLocation loc,
+                                            const std::string &countVarName) const final;
+
     //! Generate code for pulling a variable with a size known at runtime from the 'device'
     virtual void genVariableDynamicPull(CodeStream &os,
                                         const Type::ResolvedType &type, const std::string &name, VarLocation loc,
                                         const std::string &countVarName = "count", const std::string &prefix = "") const final;
 
+    //!
Generate code for pulling a variable with a size known at runtime from the 'device' + virtual void genLazyVariableDynamicPull(CodeStream &os, + const Type::ResolvedType &type, const std::string &name, VarLocation loc, + const std::string &countVarName) const final; + //! Generate code for pushing a new pointer to a dynamic variable into the merged group structure on 'device' virtual void genMergedDynamicVariablePush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, const std::string &groupIdx, const std::string &fieldName, diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 87b06ebc0c..8ce9acd98b 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -883,7 +883,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedSynapseInitGroup" << s.getIndex() << "[g]; " << std::endl; - EnvironmentGroupMergedField groupEnv(funcEnv, s); + EnvironmentGroupMergedField groupEnv(funcEnv, s); genSynapseIndexCalculation(groupEnv, 1); s.generateInit(*this, groupEnv, modelMerged); @@ -1405,6 +1405,18 @@ void Backend::genVariableDynamicAllocation(CodeStream &os, } } //-------------------------------------------------------------------------- +void Backend::genLazyVariableDynamicAllocation(CodeStream &os, + const Type::ResolvedType &type, const std::string &name, VarLocation, + const std::string &countVarName) const +{ + if (type.isPointer()) { + os << "*$(_" << name << ") = new " << type.getPointer().valueType->getValue().name << "[" << countVarName << "];" << std::endl; + } + else { + os << "$(_" << name << ") = new " << type.getValue().name << "[" << countVarName << "];" << std::endl; + } +} +//-------------------------------------------------------------------------- void Backend::genVariableFree(CodeStream &os, const std::string &name, VarLocation) const { os << "delete[] " << name << ";" << std::endl; @@ -1441,6 +1453,13 @@ void Backend::genVariableDynamicPush(CodeStream&, assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- +void Backend::genLazyVariableDynamicPush(CodeStream&, + const Type::ResolvedType&, const std::string&, + VarLocation, const std::string&) const +{ + assert(!getPreferences().automaticCopy); +} +//-------------------------------------------------------------------------- void Backend::genVariableDynamicPull(CodeStream&, const Type::ResolvedType&, const std::string&, VarLocation, const std::string&, const std::string&) const @@ -1448,6 +1467,13 @@ void Backend::genVariableDynamicPull(CodeStream&, assert(!getPreferences().automaticCopy); } //-------------------------------------------------------------------------- +void Backend::genLazyVariableDynamicPull(CodeStream&, + const Type::ResolvedType&, const std::string&, + VarLocation, const std::string&) const +{ + assert(!getPreferences().automaticCopy); +} +//-------------------------------------------------------------------------- void Backend::genMergedDynamicVariablePush(CodeStream &os, const std::string &suffix, size_t mergedGroupIdx, const std::string &groupIdx, const std::string &fieldName, const std::string &egpName) const From 0801e27a06049cd9bb9db5bdd7e1f78d4f0b5e35 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 12:20:05 +0100 Subject: [PATCH 374/725] remove codeGenUtils unit tests --- 
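
These tests exercised the legacy string-substitution helpers (``ensureFtype``,
``regexVarSubstitute`` and ``functionSubstitute``), which the transpiler-based
pipeline replaces; positional ``$(i)`` placeholder expansion now happens when
the pretty printer visits a function identifier (see PATCH 368 above). For
reference, a small self-contained sketch of that replace-all behaviour, with
illustrative names only, not the GeNN implementation:

    #include <cstddef>
    #include <string>
    #include <vector>

    // Replace every occurrence of each positional $(i) placeholder in a
    // function template with the corresponding pretty-printed argument
    std::string substitutePlaceholders(std::string name, const std::vector<std::string> &args)
    {
        for(std::size_t i = 0; i < args.size(); i++) {
            const std::string placeholder = "$(" + std::to_string(i) + ")";

            // If this placeholder isn't found at all, stop looking for arguments
            std::size_t found = name.find(placeholder);
            if(found == std::string::npos) {
                break;
            }

            // Keep replacing until no occurrence of $(i) remains
            do {
                name.replace(found, placeholder.length(), args.at(i));
                found = name.find(placeholder, found);
            } while(found != std::string::npos);
        }
        return name;
    }

    int main()
    {
        // Both uses of $(0) are expanded, matching the PATCH 368 behaviour
        const std::string out = substitutePlaceholders("fmaf($(0), $(0), $(1))", {"x", "y"});
        return (out == "fmaf(x, x, y)") ? 0 : 1;
    }

The loop keeps replacing a placeholder until no occurrence remains, so
templates that use the same argument more than once expand correctly, which is
exactly the fix PATCH 368 introduces.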
tests/unit/codeGenUtils.cc | 138 ------------------------------------- tests/unit/unit.vcxproj | 1 - 2 files changed, 139 deletions(-) delete mode 100644 tests/unit/codeGenUtils.cc diff --git a/tests/unit/codeGenUtils.cc b/tests/unit/codeGenUtils.cc deleted file mode 100644 index a935829b21..0000000000 --- a/tests/unit/codeGenUtils.cc +++ /dev/null @@ -1,138 +0,0 @@ -// C++ standard includes -#include -#include - -// C standard includes -#include - -// Google test includes -#include "gtest/gtest.h" - -// GeNN code generator includes -#include "code_generator/codeGenUtils.h" -#include "code_generator/substitutions.h" - -using namespace GeNN; -using namespace GeNN::CodeGenerator; - -// Test based on original issue found in https://github.com/brian-team/brian2genn/pull/60 to make sure that ensureFtype doesn't break functions it shouldn't -TEST(CodeGenUtils, ISinF) { - const std::string code = - "const int _infinity_int = 1073741823; // maximum 32bit integer divided by 2\n" - "if (std::isinf(t))\n" - "{\n"; - - std::string substitutedCode = ensureFtype(code, "double"); - ASSERT_EQ(code, substitutedCode); -} - -// Test based on comments by Marcel in https://github.com/brian-team/brian2genn/pull/60 -TEST(CodeGenUtils, foo123) { - const std::string code = "int foo123 = 6;"; - - std::string substitutedCode = code; - regexVarSubstitute(substitutedCode, "foo", "bar"); - ASSERT_EQ(code, substitutedCode); -} - -// Test based on comments by Thomas in https://github.com/brian-team/brian2genn/pull/60 -TEST(CodeGenUtils, not2well) { - const std::string code = "int not2well = 6;"; - - std::string substitutedCode = code; - regexVarSubstitute(substitutedCode, "well", "hell"); - ASSERT_EQ(code, substitutedCode); -} - -// Check that generic maths functions DON'T get messed with -TEST(CodeGenUtils, rint) { - const std::string code = "$(value) = (uint8_t)rint(normal / DT);"; - - const std::string substitutedCode = ensureFtype(code, "float"); - ASSERT_EQ(substitutedCode, "$(value) = (uint8_t)rint(normal / DT);"); -} - -// Check that old-style single-precision maths functions get replaced with generic version -TEST(CodeGenUtils, rintf) { - const std::string code = "$(value) = (uint8_t)rintf(normal / DT);"; - - const std::string substitutedCode = ensureFtype(code, "float"); - ASSERT_EQ(substitutedCode, "$(value) = (uint8_t)rint(normal / DT);"); -} - -// Check that namespace substitution in support code works -TEST(CodeGenUtils, supportCodeFunc) { - const std::string supportCode = "SUPPORT_CODE_FUNC scalar supportCodeFunc(scalar x){ return x; }"; - const std::string code = "supportCodeFunc(x);"; - const std::string substitutedCode = disambiguateNamespaceFunction(supportCode, code, "TestNamespace"); - ASSERT_EQ(substitutedCode, "TestNamespace_supportCodeFunc(x);"); -} - -TEST(CodeGenUtils, FunctionSubstitute) -{ - std::string code = "$(print, $(id_pre), sin($(id_post)));"; - functionSubstitute(code, "print", 2, "printf(\"%d,%d\n\", $(0), $(1))"); - ASSERT_EQ(code, "printf(\"%d,%d\n\", $(id_pre), sin($(id_post)));"); -} -TEST(CodeGenUtils, MissingFunctionClosedBracket) -{ - const std::string code = "$(for_each_synapse, printf(\"%d,\", j)"; - std::string mutableCode = code; - functionSubstitute(mutableCode, "for_each_synapse", 1, "for(int j = 0; j < 10; j++){ $(0) }"); - ASSERT_EQ(code, mutableCode); -} - -//-------------------------------------------------------------------------- -// SingleValueSubstitutionTest -//-------------------------------------------------------------------------- -class 
SingleValueSubstitutionTest : public ::testing::TestWithParam<double>
-{
-protected:
-    //--------------------------------------------------------------------------
-    // Test virtuals
-    //--------------------------------------------------------------------------
-    virtual void SetUp()
-    {
-        // Substitute variable for value
-        m_Code = "$(test)";
-
-        // Substitute test parameter for value
-        Substitutions subs;
-        subs.addParamValueSubstitution({"test"}, {{"test", GetParam()}});
-        subs.apply(m_Code);
-
-        // For safety, value_substitutions adds brackets around substituted values - trim these out
-        m_Code = m_Code.substr(1, m_Code.size() - 2);
-    }
-
-    //--------------------------------------------------------------------------
-    // Protected API
-    //--------------------------------------------------------------------------
-    const std::string &GetCode() const { return m_Code; }
-
-private:
-    //--------------------------------------------------------------------------
-    // Private API
-    //--------------------------------------------------------------------------
-    std::string m_Code;
-};
-
-//--------------------------------------------------------------------------
-// Tests
-//--------------------------------------------------------------------------
-TEST_P(SingleValueSubstitutionTest, CorrectGeneratedValue)
-{
-    // Convert results back to double and check they match
-    double result = std::atof(GetCode().c_str());
-    ASSERT_DOUBLE_EQ(result, GetParam());
-}
-
-//--------------------------------------------------------------------------
-// Instatiations
-//--------------------------------------------------------------------------
-INSTANTIATE_TEST_SUITE_P(CodeGenUtils,
-                         SingleValueSubstitutionTest,
-                         ::testing::Values(std::numeric_limits<double>::min(),
-                                           std::numeric_limits<double>::max(),
-                                           1.0,
-                                           -1.0));
diff --git a/tests/unit/unit.vcxproj b/tests/unit/unit.vcxproj
index 87b2db7cbb..4289174ba7 100644
--- a/tests/unit/unit.vcxproj
+++ b/tests/unit/unit.vcxproj
@@ -15,7 +15,6 @@
-    <ClCompile Include="codeGenUtils.cc" />

From 4c3f463746e19321856d0d90eda43a88a4fa27ce Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 11 Jul 2023 15:39:02 +0100
Subject: [PATCH 375/725] removed duplicate unused ``upgradeCodeString``

---
 .../genn/genn/code_generator/codeGenUtils.h  |  6 ----
 src/genn/genn/code_generator/codeGenUtils.cc | 30 -------------------
 2 files changed, 36 deletions(-)

diff --git a/include/genn/genn/code_generator/codeGenUtils.h b/include/genn/genn/code_generator/codeGenUtils.h
index 51056cd01f..236f032be2 100644
--- a/include/genn/genn/code_generator/codeGenUtils.h
+++ b/include/genn/genn/code_generator/codeGenUtils.h
@@ -57,12 +57,6 @@ GENN_EXPORT void genTypeRange(CodeStream &os, const Type::ResolvedType &type, co
 //--------------------------------------------------------------------------
 GENN_EXPORT std::string disambiguateNamespaceFunction(const std::string supportCode, const std::string code, std::string namespaceName);

-//--------------------------------------------------------------------------
-/*! \brief This function automatically replaces old style $(variable) variable references and $(function, arg1, arg2) syntax with new form.
- */
- //--------------------------------------------------------------------------
-GENN_EXPORT std::string upgradeCodeString(const std::string &codeString);
-
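Only the duplicate declared here goes away; the canonical ``upgradeCodeString`` lives in ``gennUtils`` and, as the next patch shows, its replacement table keeps growing. The upgrade itself is a purely regex-driven rewrite (functions first, then bare ``$(...)`` variable references, since the two were indistinguishable in the old syntax), for example:

    // old-style code string
    $(value) = $(min) + ($(gennrand_uniform) * scale);
    // after upgradeCodeString
    value = min + (gennrand_uniform() * scale);

 //--------------------------------------------------------------------------
 /*!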
\brief This function uses the transpiler to parse, type check and pretty print previously scanned vector of tokens representing an expression */ diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index 3baf29bb32..7d2ebe3d07 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -93,36 +93,6 @@ std::string disambiguateNamespaceFunction(const std::string supportCode, const s return newCode; } //---------------------------------------------------------------------------- -std::string upgradeCodeString(const std::string &codeString) -{ - - // Build vector of regular expressions to replace old style function calls - const std::vector> functionReplacements{ - {std::regex(R"(\$\(gennrand_uniform\))"), "gennrand_uniform()"}, - {std::regex(R"(\$\(gennrand_normal\))"), "gennrand_normal()"}, - {std::regex(R"(\$\(gennrand_exponential\))"), "gennrand_exponential()"}, - {std::regex(R"(\$\(gennrand_log_normal,(.*)\))"), "gennrand_log_normal($1)"}, - {std::regex(R"(\$\(gennrand_gamma,(.*)\))"), "gennrand_gamma($1)"}, - {std::regex(R"(\$\(gennrand_binomial,(.*)\))"), "gennrand_binomial($1)"}, - {std::regex(R"(\$\(addSynapse,(.*)\))"), "addSynapse($1)"}, - {std::regex(R"(\$\(endRow\))"), "endRow()"}, - {std::regex(R"(\$\(endCol\))"), "endCol()"}}; - - // Apply sustitutions to upgraded code string - std::string upgradedCodeString = codeString; - for(const auto &f : functionReplacements) { - upgradedCodeString = std::regex_replace(upgradedCodeString, f.first, f.second); - } - - // **TODO** snake-case -> camel case known built in variables e.g id_pre -> idPre - - // Replace old style $(XX) variables with plain XX - // **NOTE** this is done after functions as single-parameter function calls and variables were indistinguishable with old syntax - const std::regex variable(R"(\$\(([_a-zA-Z][_a-zA-Z0-9]*)\))"); - upgradedCodeString = std::regex_replace(upgradedCodeString, variable, "$1"); - return upgradedCodeString; -} -//---------------------------------------------------------------------------- void prettyPrintExpression(const std::vector &tokens, const Type::TypeContext &typeContext, EnvironmentExternalBase &env, Transpiler::ErrorHandlerBase &errorHandler) { using namespace Transpiler; From 4138618405870d433f4bda2f0db8c367a1765579 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 15:39:29 +0100 Subject: [PATCH 376/725] add some extra functions to ``upgradeCodeString`` --- src/genn/genn/gennUtils.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/genn/genn/gennUtils.cc b/src/genn/genn/gennUtils.cc index 41a575423f..d1f74ca804 100644 --- a/src/genn/genn/gennUtils.cc +++ b/src/genn/genn/gennUtils.cc @@ -38,6 +38,9 @@ std::string upgradeCodeString(const std::string &codeString) {std::regex(R"(\$\(gennrand_log_normal,(.*)\))"), "gennrand_log_normal($1)"}, {std::regex(R"(\$\(gennrand_gamma,(.*)\))"), "gennrand_gamma($1)"}, {std::regex(R"(\$\(gennrand_binomial,(.*)\))"), "gennrand_binomial($1)"}, + {std::regex(R"(\$\(addToPre,(.*)\))"), "addToPre($1)"}, + {std::regex(R"(\$\(addToInSyn,(.*)\))"), "addToPost($1)"}, + {std::regex(R"(\$\(addToInSynDelay,(.*),(.*)\))"), "addToPostDelay($1,$2)"}, {std::regex(R"(\$\(addSynapse,(.*)\))"), "addSynapse($1)"}, {std::regex(R"(\$\(endRow\))"), "endRow()"}, {std::regex(R"(\$\(endCol\))"), "endCol()"}}; From ebbf66497f794a433df0636f07cd3f1ec8959b2c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 15:39:47 +0100 Subject: 
[PATCH 377/725] made some methods public for unit tests --- include/genn/genn/code_generator/initGroupMerged.h | 12 ++++++------ .../genn/code_generator/neuronUpdateGroupMerged.h | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index 0fbf554743..fceec9bdd3 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -301,6 +301,12 @@ class GENN_EXPORT SynapseConnectivityInitGroupMerged : public GroupMerged &getMergedOutSynPreOutputGroups() const { return m_MergedOutSynPreOutputGroups; } const std::vector &getMergedInSynWUMPostCodeGroups() const { return m_MergedInSynWUMPostCodeGroups; } const std::vector &getMergedOutSynWUMPreCodeGroups() const { return m_MergedOutSynWUMPreCodeGroups; } + + //! Should the parameter be implemented heterogeneously? + bool isParamHeterogeneous(const std::string ¶mName) const; + + //! Should the derived parameter be implemented heterogeneously? + bool isDerivedParamHeterogeneous(const std::string ¶mName) const; //---------------------------------------------------------------------------- // Static constants @@ -201,12 +207,6 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase static const std::string name; private: - //! Should the parameter be implemented heterogeneously? - bool isParamHeterogeneous(const std::string ¶mName) const; - - //! Should the derived parameter be implemented heterogeneously? - bool isDerivedParamHeterogeneous(const std::string ¶mName) const; - //------------------------------------------------------------------------ // Members //------------------------------------------------------------------------ From f98f326e5e6f5a0e3f4532ef8cb892bcfecafdfd Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 15:40:04 +0100 Subject: [PATCH 378/725] renamed ``PROCEDURAL_PROCEDURALG`` to ``PROCEDURAL`` --- include/genn/genn/synapseMatrixType.h | 14 +++++++------- src/genn/genn/synapseGroup.cc | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/include/genn/genn/synapseMatrixType.h b/include/genn/genn/synapseMatrixType.h index c6bd0d5a42..48abb6d53b 100644 --- a/include/genn/genn/synapseMatrixType.h +++ b/include/genn/genn/synapseMatrixType.h @@ -26,13 +26,13 @@ enum class SynapseMatrixWeight : unsigned int //! 
Supported combinations of SynapticMatrixConnectivity and SynapticMatrixWeight
 enum class SynapseMatrixType : unsigned int
 {
-    DENSE =                  static_cast<unsigned int>(SynapseMatrixConnectivity::DENSE) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL),
-    DENSE_PROCEDURALG =      static_cast<unsigned int>(SynapseMatrixConnectivity::DENSE) | static_cast<unsigned int>(SynapseMatrixWeight::PROCEDURAL),
-    BITMASK =                static_cast<unsigned int>(SynapseMatrixConnectivity::BITMASK) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL),
-    SPARSE =                 static_cast<unsigned int>(SynapseMatrixConnectivity::SPARSE) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL),
-    PROCEDURAL_PROCEDURALG = static_cast<unsigned int>(SynapseMatrixConnectivity::PROCEDURAL) | static_cast<unsigned int>(SynapseMatrixWeight::PROCEDURAL),
-    PROCEDURAL_KERNELG =     static_cast<unsigned int>(SynapseMatrixConnectivity::PROCEDURAL) | static_cast<unsigned int>(SynapseMatrixWeight::KERNEL),
-    TOEPLITZ =               static_cast<unsigned int>(SynapseMatrixConnectivity::TOEPLITZ) | static_cast<unsigned int>(SynapseMatrixWeight::KERNEL),
+    DENSE =              static_cast<unsigned int>(SynapseMatrixConnectivity::DENSE) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL),
+    DENSE_PROCEDURALG =  static_cast<unsigned int>(SynapseMatrixConnectivity::DENSE) | static_cast<unsigned int>(SynapseMatrixWeight::PROCEDURAL),
+    BITMASK =            static_cast<unsigned int>(SynapseMatrixConnectivity::BITMASK) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL),
+    SPARSE =             static_cast<unsigned int>(SynapseMatrixConnectivity::SPARSE) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL),
+    PROCEDURAL =         static_cast<unsigned int>(SynapseMatrixConnectivity::PROCEDURAL) | static_cast<unsigned int>(SynapseMatrixWeight::PROCEDURAL),
+    PROCEDURAL_KERNELG = static_cast<unsigned int>(SynapseMatrixConnectivity::PROCEDURAL) | static_cast<unsigned int>(SynapseMatrixWeight::KERNEL),
+    TOEPLITZ =           static_cast<unsigned int>(SynapseMatrixConnectivity::TOEPLITZ) | static_cast<unsigned int>(SynapseMatrixWeight::KERNEL),
 };
 //----------------------------------------------------------------------------
diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc
index f92b5c7957..64a24651e5 100644
--- a/src/genn/genn/synapseGroup.cc
+++ b/src/genn/genn/synapseGroup.cc
@@ -473,7 +473,7 @@ SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType
     }

     // If connectivity initialisation snippet defines a kernel and matrix type doesn't support it, give error
-    if(!m_KernelSize.empty() && (m_MatrixType != SynapseMatrixType::PROCEDURAL_PROCEDURALG) && (m_MatrixType != SynapseMatrixType::TOEPLITZ)
+    if(!m_KernelSize.empty() && (m_MatrixType != SynapseMatrixType::PROCEDURAL) && (m_MatrixType != SynapseMatrixType::TOEPLITZ)
        && (m_MatrixType != SynapseMatrixType::SPARSE) && (m_MatrixType != SynapseMatrixType::PROCEDURAL_KERNELG))
     {
         throw std::runtime_error("BITMASK connectivity can only be used with weight update models without variables like StaticPulseConstantWeight.");
@@ -489,7 +489,7 @@

     // If synapse group uses sparse or procedural connectivity but no kernel size is provided,
     // check that no variable's initialisation snippets require a kernel
-    if(((m_MatrixType == SynapseMatrixType::SPARSE) || (m_MatrixType == SynapseMatrixType::PROCEDURAL_PROCEDURALG)) &&
+    if(((m_MatrixType == SynapseMatrixType::SPARSE) || (m_MatrixType == SynapseMatrixType::PROCEDURAL)) &&
        m_KernelSize.empty() && std::any_of(getWUVarInitialisers().cbegin(), getWUVarInitialisers().cend(),
                                            [](const auto &v) { return v.second.isKernelRequired(); }))
     {

From 13fe8f81edd6ae828b3925d6c34850c957935a12 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 11 Jul 2023 15:40:13 +0100
Subject: [PATCH 379/725] unit tests now compile

---
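Getting the tests to compile is mostly mechanical, and the same three substitutions recur throughout the diffs below: ``model.finalize()`` becomes ``model.finalise()``, code strings drop their ``$(...)`` wrappers, and the weight-suffixed ``SynapseMatrixType`` values collapse onto connectivity-only names, with former ``GLOBALG`` weights passed as weight-update parameters instead of variables. A representative before/after (the template arguments are elided in this rendering of the diffs, so the model types shown here are plausible stand-ins only):

    // before: constant weight expressed as a GLOBALG variable
    model.addSynapsePopulation<WeightUpdateModels::StaticPulse, PostsynapticModels::DeltaCurr>(
        "Synapses1", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY,
        "Pre", "Post",
        {}, {{"g", 1.0}},
        {}, {});
    // after: constant weight passed as a parameter
    model.addSynapsePopulation<WeightUpdateModels::StaticPulseConstantWeight, PostsynapticModels::DeltaCurr>(
        "Synapses1", SynapseMatrixType::SPARSE, NO_DELAY,
        "Pre", "Post",
        {{"g", 1.0}}, {},
        {}, {});

 tests/unit/currentSource.cc | 6 +-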
tests/unit/currentSourceModels.cc | 2 +- tests/unit/customConnectivityUpdate.cc | 110 +++--- tests/unit/customUpdate.cc | 108 +++--- tests/unit/initSparseConnectivitySnippet.cc | 32 +- tests/unit/initVarSnippet.cc | 4 +- tests/unit/modelSpec.cc | 15 +- tests/unit/modelSpecMerged.cc | 63 ++-- tests/unit/models.cc | 100 ++--- tests/unit/neuronGroup.cc | 384 ++++++++++---------- tests/unit/scanner.cc | 12 +- tests/unit/synapseGroup.cc | 303 +++++++-------- tests/unit/typeChecker.cc | 56 ++- tests/unit/weightUpdateModels.cc | 16 +- 14 files changed, 630 insertions(+), 581 deletions(-) diff --git a/tests/unit/currentSource.cc b/tests/unit/currentSource.cc index 1af2790d43..38792963d7 100644 --- a/tests/unit/currentSource.cc +++ b/tests/unit/currentSource.cc @@ -29,7 +29,7 @@ TEST(CurrentSource, CompareDifferentModel) cs1ParamVals, {}); // Finalize model - model.finalize(); + model.finalise(); CurrentSourceInternal *cs0Internal = static_cast(cs0); CurrentSourceInternal *cs1Internal = static_cast(cs1); @@ -56,7 +56,7 @@ TEST(CurrentSource, CompareDifferentParameters) cs1ParamVals, {}); // Finalize model - model.finalize(); + model.finalise(); CurrentSourceInternal *cs0Internal = static_cast(cs0); CurrentSourceInternal *cs1Internal = static_cast(cs1); @@ -83,7 +83,7 @@ TEST(CurrentSource, CompareSameParameters) cs1ParamVals, {}); // Finalize model - model.finalize(); + model.finalise(); CurrentSourceInternal *cs0Internal = static_cast(cs0); CurrentSourceInternal *cs1Internal = static_cast(cs1); diff --git a/tests/unit/currentSourceModels.cc b/tests/unit/currentSourceModels.cc index 5a46bd9df9..ff4e72a0f0 100644 --- a/tests/unit/currentSourceModels.cc +++ b/tests/unit/currentSourceModels.cc @@ -11,7 +11,7 @@ using namespace GeNN; //-------------------------------------------------------------------------- class GaussianNoiseCopy : public CurrentSourceModels::Base { - SET_INJECTION_CODE("$(injectCurrent, $(mean) + $(gennrand_normal) * $(sd));\n"); + SET_INJECTION_CODE("injectCurrent(mean + (gennrand_normal() * sd));\n"); SET_PARAM_NAMES({"mean", "sd"} ); }; diff --git a/tests/unit/customConnectivityUpdate.cc b/tests/unit/customConnectivityUpdate.cc index 8dc4b64d8b..3df14d5a5a 100644 --- a/tests/unit/customConnectivityUpdate.cc +++ b/tests/unit/customConnectivityUpdate.cc @@ -21,7 +21,7 @@ class StaticPulseDendriticDelayReverse : public WeightUpdateModels::Base SET_VARS({{"d", "uint8_t", VarAccess::READ_ONLY}, {"g", "scalar", VarAccess::READ_ONLY}}); - SET_SIM_CODE("$(addToInSynDelay, $(g), $(d));\n"); + SET_SIM_CODE("addToInSynDelay(g, d);\n"); }; IMPLEMENT_SNIPPET(StaticPulseDendriticDelayReverse); @@ -29,7 +29,7 @@ class Sum : public CustomUpdateModels::Base { DECLARE_SNIPPET(Sum); - SET_UPDATE_CODE("$(sum) += $(a);\n"); + SET_UPDATE_CODE("sum += a;\n"); SET_VARS({{"sum", "scalar"}}); SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_ONLY}}); @@ -43,13 +43,12 @@ class RemoveSynapse : public CustomConnectivityUpdateModels::Base SET_VARS({{"a", "scalar"}}); SET_ROW_UPDATE_CODE( - "$(for_each_synapse,\n" - "{\n" - " if($(id_post) == ($(id_pre) + 1)) {\n" - " $(remove_synapse);\n" + "for_each_synapse {\n" + " if(id_post == (id_pre + 1)) {\n" + " remove_synapse();\n" " break;\n" " }\n" - "});\n"); + "};\n"); }; IMPLEMENT_SNIPPET(RemoveSynapse); @@ -61,13 +60,12 @@ class RemoveSynapseVarRef : public CustomConnectivityUpdateModels::Base SET_VARS({{"a", "scalar"}}); SET_VAR_REFS({{"b", "scalar"}}); SET_ROW_UPDATE_CODE( - "$(for_each_synapse,\n" - "{\n" - " if($(id_post) == ($(id_pre) + 1)) {\n" - " 
$(remove_synapse);\n" + "for_each_synapse {\n" + " if(id_post == (id_pre + 1)) {\n" + " remove_synapse();\n" " break;\n" " }\n" - "});\n"); + "};\n"); }; IMPLEMENT_SNIPPET(RemoveSynapseVarRef); @@ -79,13 +77,12 @@ class RemoveSynapsePre : public CustomConnectivityUpdateModels::Base SET_VAR_REFS({{"g", "scalar"}}); SET_PRE_VAR_REFS({{"threshLow", "scalar"}, {"threshHigh", "scalar"}}); SET_ROW_UPDATE_CODE( - "$(for_each_synapse,\n" - "{\n" - " if($(g) < $(threshLow) || $(g) > $(threshHigh)) {\n" - " $(remove_synapse);\n" + "for_each_synapse {\n" + " if(g < threshLow || g > threshHigh) {\n" + " remove_synapse();\n" " break;\n" " }\n" - "});\n"); + "};\n"); }; IMPLEMENT_SNIPPET(RemoveSynapsePre); @@ -97,13 +94,12 @@ class RemoveSynapsePost : public CustomConnectivityUpdateModels::Base SET_VAR_REFS({{"g", "scalar"}}); SET_POST_VAR_REFS({{"threshLow", "scalar"}, {"threshHigh", "scalar"}}); SET_ROW_UPDATE_CODE( - "$(for_each_synapse,\n" - "{\n" - " if($(g) < $(threshLow) || $(g) > $(threshHigh)) {\n" - " $(remove_synapse);\n" + "for_each_synapse {\n" + " if(g < threshLow || g > threshHigh) {\n" + " remove_synapse();\n" " break;\n" " }\n" - "});\n"); + "};\n"); }; IMPLEMENT_SNIPPET(RemoveSynapsePost); @@ -115,7 +111,7 @@ class Cont : public WeightUpdateModels::Base SET_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( - "$(addToInSyn, $(g) * $(V_pre));\n"); + "addToInSyn(g * V_pre);\n"); }; IMPLEMENT_SNIPPET(Cont); @@ -127,7 +123,7 @@ class ContPost : public WeightUpdateModels::Base SET_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( - "$(addToInSyn, $(g) * $(V_post));\n"); + "addToInSyn(g * V_post);\n"); }; IMPLEMENT_SNIPPET(ContPost); @@ -155,10 +151,10 @@ TEST(CustomConnectivityUpdate, DependentVariables) model.addNeuronPopulation("Post", 10, paramVals, varVals); // Create synapse group with global weights - model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation( + "Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", - {}, {{"g", 1.0}}, + {{"g", 1.0}}, {}, {}, {}); // Attach custom connectivity update @@ -168,7 +164,7 @@ TEST(CustomConnectivityUpdate, DependentVariables) // Create synapse group with individual weights model.addSynapsePopulation( - "Synapses2", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapses2", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {}, {}); @@ -180,7 +176,7 @@ TEST(CustomConnectivityUpdate, DependentVariables) // Create synapse group with individual weights auto *sg3 = model.addSynapsePopulation( - "Synapses3", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapses3", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {}, {}); @@ -196,7 +192,7 @@ TEST(CustomConnectivityUpdate, DependentVariables) // Create synapse group with individual weights model.addSynapsePopulation( - "Synapses4", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapses4", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {}, {}); @@ -211,7 +207,7 @@ TEST(CustomConnectivityUpdate, DependentVariables) {}, {{"a", 1.0}}, {}, {}, {}, {}, {}); - model.finalize(); + model.finalise(); // Check no dependencies for CCU1 auto ccu1DependentVars = static_cast(ccu1)->getDependentVariables(); @@ -252,7 +248,7 @@ TEST(CustomConnectivityUpdate, DependentVariablesManualReferences) for (int i = 0; i < 3; i++) { // Create synapse group with individual weights synapseGroups[i] = model.addSynapsePopulation( - "Synapses" + std::to_string(i), 
SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapses" + std::to_string(i), SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {}, {}); @@ -286,7 +282,7 @@ TEST(CustomConnectivityUpdate, DependentVariablesManualReferences) {}, {{"a", 1.0}}, {}, {}, ccu32VarRefs, {}, {}); - model.finalize(); + model.finalise(); // Check synapse group variable has been removed from CCU12 dependent variables as it's manually referenced auto ccu12DependentVars = static_cast(ccu12)->getDependentVariables(); @@ -318,19 +314,19 @@ TEST(CustomConnectivityUpdate, CompareDifferentDependentVars) model.addNeuronPopulation("Post", 10, paramVals, varVals); model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}, {"d", 1.0}}, {}, {}); model.addSynapsePopulation( - "Synapses2", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapses2", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}, {"d", 1.0}}, {}, {}); model.addSynapsePopulation( - "Synapses3", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapses3", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {}, {}); @@ -344,7 +340,7 @@ TEST(CustomConnectivityUpdate, CompareDifferentDependentVars) auto *ccu3 = model.addCustomConnectivityUpdate("CustomConnectivityUpdate3", "Test2", "Synapses3", {}, {{"a", 1.0}}, {}, {}, {}, {}, {}); - model.finalize(); + model.finalise(); auto *ccu1Internal = static_cast(ccu1); auto *ccu2Internal = static_cast(ccu2); @@ -381,10 +377,10 @@ TEST(CustomConnectivityUpdate, BitmaskConnectivity) model.addNeuronPopulation("Post", 10, paramVals, varVals); // Create synapse group with global weights - model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::BITMASK_GLOBALG, NO_DELAY, + model.addSynapsePopulation( + "Synapses1", SynapseMatrixType::BITMASK, NO_DELAY, "Pre", "Post", - {}, {{"g", 1.0}}, + {{"g", 1.0}}, {}, {}, {}); try { @@ -411,7 +407,7 @@ TEST(CustomConnectivityUpdate, WrongPrePostSize) // Create synapse group with global weights auto *syn = model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {}, {}); @@ -481,14 +477,14 @@ TEST(CustomConnectivityUpdate, WrongSG) // Create synapse group with global weights auto *syn1 = model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {}, {}); // Create synapse group with global weights auto *syn2 = model.addSynapsePopulation( - "Synapses2", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapses2", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {}, {}); @@ -526,7 +522,7 @@ TEST(CustomConnectivityUpdate, DuplicatePrePost) // Create synapse group with global weights auto *syn = model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {}, {}); @@ -540,7 +536,7 @@ TEST(CustomConnectivityUpdate, DuplicatePrePost) {"threshHigh", createVarRef(pre, "U")}}, {}); try { - model.finalize(); + model.finalise(); FAIL(); } catch(const std::runtime_error &) { @@ -562,14 +558,14 @@ TEST(CustomConnectivityUpdate, MixedPreDelayGroups) // Create synapse group with global weights auto *syn1 = model.addSynapsePopulation( - "Synapses1", 
SynapseMatrixType::SPARSE_INDIVIDUALG, 5, + "Synapses1", SynapseMatrixType::SPARSE, 5, "Pre1", "Post1", {}, {{"g", 1.0}}, {}, {}); // Create synapse group with global weights model.addSynapsePopulation( - "Synapses2", SynapseMatrixType::SPARSE_INDIVIDUALG, 10, + "Synapses2", SynapseMatrixType::SPARSE, 10, "Pre2", "Post2", {}, {{"g", 1.0}}, {}, {}); @@ -588,7 +584,7 @@ TEST(CustomConnectivityUpdate, MixedPreDelayGroups) VarReferences{{"threshLow", createVarRef(pre1, "V")}, {"threshHigh", createVarRef(pre2, "V")}}, {}); try { - model.finalize(); + model.finalise(); FAIL(); } catch(const std::runtime_error &) { @@ -609,7 +605,7 @@ TEST(CustomConnectivityUpdate, MixedPostDelayGroups) // Create synapse group with global weights auto *syn1 = model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Pre1", "Post1", {}, {{"g", 1.0}}, {}, {}); @@ -617,7 +613,7 @@ TEST(CustomConnectivityUpdate, MixedPostDelayGroups) // Create synapse group with global weights auto *syn2 = model.addSynapsePopulation( - "Synapses2", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapses2", SynapseMatrixType::SPARSE, NO_DELAY, "Pre2", "Post2", {}, {{"g", 1.0}}, {}, {}); @@ -637,7 +633,7 @@ TEST(CustomConnectivityUpdate, MixedPostDelayGroups) {}, VarReferences{{"threshLow", createVarRef(post1, "V")}, {"threshHigh", createVarRef(post2, "V")}}); try { - model.finalize(); + model.finalise(); FAIL(); } catch(const std::runtime_error &) { @@ -655,10 +651,10 @@ TEST(CustomConnectivityUpdate, InvalidName) model.addNeuronPopulation("Post", 10, paramVals, varVals); // Create synapse group with global weights - model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation( + "Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", - {}, {{"g", 1.0}}, + {{"g", 1.0}}, {}, {}, {}); try { @@ -682,10 +678,10 @@ TEST(CustomConnectivityUpdate, InvalidUpdateGroupName) model.addNeuronPopulation("Post", 10, paramVals, varVals); // Create synapse group with global weights - model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation( + "Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", - {}, {{"g", 1.0}}, + {{"g", 1.0}}, {}, {}, {}); try { diff --git a/tests/unit/customUpdate.cc b/tests/unit/customUpdate.cc index da32220f7a..fff2fd04af 100644 --- a/tests/unit/customUpdate.cc +++ b/tests/unit/customUpdate.cc @@ -87,7 +87,7 @@ class Cont : public WeightUpdateModels::Base SET_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( - "$(addToInSyn, $(g) * $(V_pre));\n"); + "addToInSyn(g * V_pre);\n"); }; IMPLEMENT_SNIPPET(Cont); @@ -99,7 +99,7 @@ class Cont2 : public WeightUpdateModels::Base SET_VARS({{"g", "scalar"}, {"x", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( - "$(addToInSyn, ($(g) + $(x)) * $(V_pre));\n"); + "addToInSyn((g + x) * V_pre);\n"); }; IMPLEMENT_SNIPPET(Cont2); @@ -107,7 +107,7 @@ class Reduce : public CustomUpdateModels::Base { DECLARE_SNIPPET(Reduce); - SET_UPDATE_CODE("$(reduction) = $(var);\n"); + SET_UPDATE_CODE("reduction = var;\n"); SET_VAR_REFS({{"var", "scalar", VarAccessMode::READ_ONLY}, {"reduction", "scalar", VarAccessMode::REDUCE_SUM}}); @@ -119,8 +119,8 @@ class ReduceDouble : public CustomUpdateModels::Base DECLARE_SNIPPET(ReduceDouble); SET_UPDATE_CODE( - "$(reduction1) = $(var1);\n" - "$(reduction2) = $(var2);\n"); + "reduction1 = var1;\n" + "reduction2 = var2;\n"); SET_VARS({{"reduction1", 
"scalar", VarAccess::REDUCE_BATCH_SUM}, {"reduction2", "scalar", VarAccess::REDUCE_NEURON_SUM}}); @@ -134,7 +134,7 @@ class ReduceSharedVar : public CustomUpdateModels::Base { DECLARE_SNIPPET(ReduceSharedVar); - SET_UPDATE_CODE("$(reduction) = $(var);\n"); + SET_UPDATE_CODE("reduction = var;\n"); SET_VARS({{"reduction", "scalar", VarAccess::REDUCE_BATCH_SUM}}) SET_VAR_REFS({{"var", "scalar", VarAccessMode::READ_ONLY}}); @@ -146,7 +146,7 @@ class ReduceNeuronSharedVar : public CustomUpdateModels::Base { DECLARE_SNIPPET(ReduceNeuronSharedVar); - SET_UPDATE_CODE("$(reduction) = $(var);\n"); + SET_UPDATE_CODE("reduction = var;\n"); SET_VARS({{"reduction", "scalar", VarAccess::REDUCE_NEURON_SUM}}) SET_VAR_REFS({{"var", "scalar", VarAccessMode::READ_ONLY}}); @@ -171,7 +171,7 @@ TEST(CustomUpdates, ConstantVarSum) CustomUpdate *cu = model.addCustomUpdate("Sum", "CustomUpdate", {}, sumVarValues, sumVarReferences1); - model.finalize(); + model.finalise(); CustomUpdateInternal *cuInternal = static_cast(cu); ASSERT_FALSE(cuInternal->isZeroCopyEnabled()); @@ -180,11 +180,7 @@ TEST(CustomUpdates, ConstantVarSum) // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; CodeGenerator::SingleThreadedCPU::Backend backend(preferences); - - // Merge model - CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); - - ASSERT_FALSE(backend.isGlobalHostRNGRequired(modelSpecMerged)); + ASSERT_FALSE(backend.isGlobalHostRNGRequired(model)); } //-------------------------------------------------------------------------- TEST(CustomUpdates, UninitialisedVarSum) @@ -200,7 +196,7 @@ TEST(CustomUpdates, UninitialisedVarSum) CustomUpdate *cu = model.addCustomUpdate("Sum", "CustomUpdate", {}, sumVarValues, sumVarReferences1); - model.finalize(); + model.finalise(); CustomUpdateInternal *cuInternal = static_cast(cu); ASSERT_FALSE(cuInternal->isZeroCopyEnabled()); @@ -209,11 +205,7 @@ TEST(CustomUpdates, UninitialisedVarSum) // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; CodeGenerator::SingleThreadedCPU::Backend backend(preferences); - - // Merge model - CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); - - ASSERT_FALSE(backend.isGlobalHostRNGRequired(modelSpecMerged)); + ASSERT_FALSE(backend.isGlobalHostRNGRequired(model)); } //-------------------------------------------------------------------------- TEST(CustomUpdates, RandVarSum) @@ -230,7 +222,7 @@ TEST(CustomUpdates, RandVarSum) CustomUpdate *cu = model.addCustomUpdate("Sum", "CustomUpdate", {}, sumVarValues, sumVarReferences1); - model.finalize(); + model.finalise(); CustomUpdateInternal *cuInternal = static_cast(cu); ASSERT_FALSE(cuInternal->isZeroCopyEnabled()); @@ -239,11 +231,7 @@ TEST(CustomUpdates, RandVarSum) // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; CodeGenerator::SingleThreadedCPU::Backend backend(preferences); - - // Merge model - CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); - - ASSERT_TRUE(backend.isGlobalHostRNGRequired(modelSpecMerged)); + ASSERT_TRUE(backend.isGlobalHostRNGRequired(model)); } //-------------------------------------------------------------------------- TEST(CustomUpdates, VarReferenceTypeChecks) @@ -257,7 +245,7 @@ TEST(CustomUpdates, VarReferenceTypeChecks) model.addNeuronPopulation("Post", 25, paramVals, varVals); auto *sg1 = model.addSynapsePopulation( - "Synapses", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapses", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}, {"d", 4}}, 
{}, {}); @@ -277,7 +265,7 @@ TEST(CustomUpdates, VarReferenceTypeChecks) catch(const std::runtime_error &) { } - model.finalize(); + model.finalise(); } //-------------------------------------------------------------------------- TEST(CustomUpdates, VarSizeChecks) @@ -309,7 +297,7 @@ TEST(CustomUpdates, VarSizeChecks) catch(const std::runtime_error &) { } - model.finalize(); + model.finalise(); } //-------------------------------------------------------------------------- TEST(CustomUpdates, VarDelayChecks) @@ -323,7 +311,7 @@ TEST(CustomUpdates, VarDelayChecks) auto *post = model.addNeuronPopulation("Post", 10, paramVals, varVals); // Add synapse groups to force pre1's v to be delayed by 10 timesteps and pre2's v to be delayed by 5 timesteps - model.addSynapsePopulation("Syn1", SynapseMatrixType::DENSE_INDIVIDUALG, + model.addSynapsePopulation("Syn1", SynapseMatrixType::DENSE, 10, "Pre1", "Post", {}, {{"g", 0.1}}, {}, {}); @@ -334,7 +322,7 @@ TEST(CustomUpdates, VarDelayChecks) model.addCustomUpdate("Sum1", "CustomUpdate", {}, sumVarValues, sumVarReferences1); - model.finalize(); + model.finalise(); } //-------------------------------------------------------------------------- TEST(CustomUpdates, VarMixedDelayChecks) @@ -349,11 +337,11 @@ TEST(CustomUpdates, VarMixedDelayChecks) model.addNeuronPopulation("Post", 10, paramVals, varVals); // Add synapse groups to force pre1's v to be delayed by 10 timesteps and pre2's v to be delayed by 5 timesteps - model.addSynapsePopulation("Syn1", SynapseMatrixType::DENSE_INDIVIDUALG, + model.addSynapsePopulation("Syn1", SynapseMatrixType::DENSE, 10, "Pre1", "Post", {}, {{"g", 0.1}}, {}, {}); - model.addSynapsePopulation("Syn2", SynapseMatrixType::DENSE_INDIVIDUALG, + model.addSynapsePopulation("Syn2", SynapseMatrixType::DENSE, 5, "Pre2", "Post", {}, {{"g", 0.1}}, {}, {}); @@ -364,7 +352,7 @@ TEST(CustomUpdates, VarMixedDelayChecks) {}, sumVarValues, sumVarReferences2); try { - model.finalize(); + model.finalise(); FAIL(); } catch(const std::runtime_error &) { @@ -382,12 +370,12 @@ TEST(CustomUpdates, WUVarSynapseGroupChecks) model.addNeuronPopulation("Post", 25, paramVals, varVals); auto *sg1 = model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapses1", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}, {"d", 4}}, {}, {}); auto *sg2 = model.addSynapsePopulation( - "Synapses2", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapses2", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}, {"d", 4}}, {}, {}); @@ -407,7 +395,7 @@ TEST(CustomUpdates, WUVarSynapseGroupChecks) catch(const std::runtime_error &) { } - model.finalize(); + model.finalise(); } //-------------------------------------------------------------------------- TEST(CustomUpdates, BatchingVars) @@ -438,7 +426,7 @@ TEST(CustomUpdates, BatchingVars) auto *sum4 = model.addCustomUpdate("Sum4", "CustomUpdate", {}, sumVarValues, sumVarReferences4); - model.finalize(); + model.finalise(); EXPECT_TRUE(static_cast(sum1)->isBatched()); EXPECT_FALSE(static_cast(sum2)->isBatched()); @@ -481,7 +469,7 @@ TEST(CustomUpdates, ReduceDuplicate) model.addCustomUpdate("Sum1", "CustomUpdate", {}, sum2VarValues, sum2VarReferences); try { - model.finalize(); + model.finalise(); FAIL(); } catch(const std::runtime_error &) { @@ -502,7 +490,7 @@ TEST(CustomUpdates, ReductionTypeDuplicateNeuron) VarReferences reduceVarReferences{{"var", createVarRef(pop, "V")}, {"reduction", createVarRef(pop, "a")}}; auto *cu = 
model.addCustomUpdate("Reduction", "CustomUpdate", {}, {}, reduceVarReferences); - model.finalize(); + model.finalise(); auto *cuInternal = static_cast(cu); ASSERT_TRUE(cuInternal->isBatched()); ASSERT_FALSE(cuInternal->isBatchReduction()); @@ -524,7 +512,7 @@ TEST(CustomUpdates, ReductionTypeDuplicateNeuronInternal) VarValues reduceVars{{"reduction", 0.0}}; auto *cu = model.addCustomUpdate("Reduction", "CustomUpdate", {}, reduceVars, reduceVarReferences); - model.finalize(); + model.finalise(); auto *cuInternal = static_cast(cu); ASSERT_TRUE(cuInternal->isBatched()); ASSERT_FALSE(cuInternal->isBatchReduction()); @@ -546,7 +534,7 @@ TEST(CustomUpdates, ReductionTypeSharedNeuronInternal) VarValues reduceVars{{"reduction", 0.0}}; auto *cu = model.addCustomUpdate("Reduction", "CustomUpdate", {}, reduceVars, reduceVarReferences); - model.finalize(); + model.finalise(); auto *cuInternal = static_cast(cu); ASSERT_FALSE(cuInternal->isBatched()); ASSERT_FALSE(cuInternal->isBatchReduction()); @@ -567,7 +555,7 @@ TEST(CustomUpdates, ReductionTypeDuplicateBatch) VarReferences reduceVarReferences{{"var", createVarRef(pop, "V")}, {"reduction", createVarRef(pop, "a")}}; auto *cu = model.addCustomUpdate("Reduction", "CustomUpdate", {}, {}, reduceVarReferences); - model.finalize(); + model.finalise(); auto *cuInternal = static_cast(cu); ASSERT_TRUE(cuInternal->isBatched()); ASSERT_TRUE(cuInternal->isBatchReduction()); @@ -589,7 +577,7 @@ TEST(CustomUpdates, ReductionTypeDuplicateBatchInternal) VarValues reduceVars{{"reduction", 0.0}}; auto *cu = model.addCustomUpdate("Reduction", "CustomUpdate", {}, reduceVars, reduceVarReferences); - model.finalize(); + model.finalise(); auto *cuInternal = static_cast(cu); ASSERT_TRUE(cuInternal->isBatched()); ASSERT_TRUE(cuInternal->isBatchReduction()); @@ -607,7 +595,7 @@ TEST(CustomUpdates, NeuronSharedCustomUpdateWU) model.addNeuronPopulation("Post", 25, paramVals, varVals); auto *sg1 = model.addSynapsePopulation( - "Synapses", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapses", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {}, {}); @@ -667,7 +655,7 @@ TEST(CustomUpdates, CompareDifferentModel) {}, {{"mult", 1.0}}, sum2VarReferences); // Finalize model - model.finalize(); + model.finalise(); CustomUpdateInternal *sum0Internal = static_cast(sum0); CustomUpdateInternal *sum1Internal = static_cast(sum1); @@ -707,7 +695,7 @@ TEST(CustomUpdates, CompareDifferentUpdateGroup) auto *sum2 = model.addCustomUpdate("Sum2", "CustomUpdate1", {}, {{"sum", 1.0}}, sumVarReferences); // Finalize model - model.finalize(); + model.finalise(); CustomUpdateInternal *sum0Internal = static_cast(sum0); CustomUpdateInternal *sum1Internal = static_cast(sum1); @@ -743,16 +731,16 @@ TEST(CustomUpdates, CompareDifferentDelay) model.addNeuronPopulation("Post", 10, paramVals, varVals); // Add synapse groups to force pre1's v to be delayed by 0 timesteps, pre2 and pre3's v to be delayed by 5 timesteps and pre4's to be delayed by 10 timesteps - model.addSynapsePopulation("Syn1", SynapseMatrixType::DENSE_INDIVIDUALG, + model.addSynapsePopulation("Syn1", SynapseMatrixType::DENSE, NO_DELAY, "Pre1", "Post", {}, {{"g", 0.1}}, {}, {}); - model.addSynapsePopulation("Syn2", SynapseMatrixType::DENSE_INDIVIDUALG, + model.addSynapsePopulation("Syn2", SynapseMatrixType::DENSE, 5, "Pre2", "Post", {}, {{"g", 0.1}}, {}, {}); - model.addSynapsePopulation("Syn3", SynapseMatrixType::DENSE_INDIVIDUALG, + model.addSynapsePopulation("Syn3", SynapseMatrixType::DENSE, 5, "Pre3", "Post", {}, 
{{"g", 0.1}}, {}, {}); - model.addSynapsePopulation("Syn4", SynapseMatrixType::DENSE_INDIVIDUALG, + model.addSynapsePopulation("Syn4", SynapseMatrixType::DENSE, 10, "Pre4", "Post", {}, {{"g", 0.1}}, {}, {}); @@ -772,7 +760,7 @@ TEST(CustomUpdates, CompareDifferentDelay) {}, {{"sum", 0.0}}, sumVarReferences4); // Finalize model - model.finalize(); + model.finalise(); // No delay group can't be merged with any others CustomUpdateInternal *sum1Internal = static_cast(sum1); @@ -825,7 +813,7 @@ TEST(CustomUpdates, CompareDifferentBatched) auto *sum3 = model.addCustomUpdate("Sum3", "CustomUpdate", {}, {{"sum", 0.0}}, sumVarReferences3); - model.finalize(); + model.finalise(); // Check that sum1 and sum3 are batched and sum2 is not CustomUpdateInternal *sum1Internal = static_cast(sum1); @@ -866,10 +854,10 @@ TEST(CustomUpdates, CompareDifferentWUTranspose) model.addNeuronPopulation("Post", 10, paramVals, varVals); // Add synapse groups to force pre1's v to be delayed by 0 timesteps, pre2 and pre3's v to be delayed by 5 timesteps and pre4's to be delayed by 10 timesteps - auto *fwdSyn = model.addSynapsePopulation("fwdSyn", SynapseMatrixType::DENSE_INDIVIDUALG, + auto *fwdSyn = model.addSynapsePopulation("fwdSyn", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, {{"g", 0.0}, {"x", 0.0}}, {}, {}); - auto *backSyn = model.addSynapsePopulation("backSyn", SynapseMatrixType::DENSE_INDIVIDUALG, + auto *backSyn = model.addSynapsePopulation("backSyn", SynapseMatrixType::DENSE, NO_DELAY, "Post", "Pre", {}, {{"g", 0.0}}, {}, {}); @@ -882,7 +870,7 @@ TEST(CustomUpdates, CompareDifferentWUTranspose) {}, {{"mult", 0.0}}, sumVarReferences2); // Finalize model - model.finalize(); + model.finalise(); // Updates which transpose different variables can't be merged with any others CustomUpdateWUInternal *sum1Internal = static_cast(sum1); @@ -918,11 +906,11 @@ TEST(CustomUpdates, CompareDifferentWUConnectivity) // Add a sparse and a dense synapse group auto *syn1 = model.addSynapsePopulation( - "Syn1", SynapseMatrixType::DENSE_INDIVIDUALG, + "Syn1", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, {{"g", 0.0}, {"x", 0.0}}, {}, {}); auto *syn2 = model.addSynapsePopulation( - "Syn2", SynapseMatrixType::SPARSE_INDIVIDUALG, + "Syn2", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", {}, {{"g", 0.0}, {"x", 0.0}}, {}, {}, initConnectivity({{"prob", 0.1}})); @@ -936,7 +924,7 @@ TEST(CustomUpdates, CompareDifferentWUConnectivity) {}, {{"sum", 0.0}}, sumVarReferences2); // Finalize model - model.finalize(); + model.finalise(); // Updates and initialisation with different connectivity can't be merged with any others CustomUpdateWUInternal *sum1Internal = static_cast(sum1); @@ -972,7 +960,7 @@ TEST(CustomUpdates, CompareDifferentWUBatched) // Add synapse group VarValues synVarInit{{"gCommon", 1.0}, {"g", 1.0}, {"dCommon",1.0}, {"d", 1.0}}; auto *sg1 = model.addSynapsePopulation( - "Synapses", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapses", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, synVarInit, {}, {}); @@ -987,7 +975,7 @@ TEST(CustomUpdates, CompareDifferentWUBatched) {}, {{"sum", 0.0}}, sumVarReferences2); auto *sum3 = model.addCustomUpdate("Sum3", "CustomUpdate", {}, {{"sum", 0.0}}, sumVarReferences3); - model.finalize(); + model.finalise(); // Check that sum1 and sum3 are batched and sum2 is not CustomUpdateWUInternal *sum1Internal = static_cast(sum1); diff --git a/tests/unit/initSparseConnectivitySnippet.cc b/tests/unit/initSparseConnectivitySnippet.cc index 0d2aa5a23c..f4da306ac2 
100644 --- a/tests/unit/initSparseConnectivitySnippet.cc +++ b/tests/unit/initSparseConnectivitySnippet.cc @@ -30,19 +30,15 @@ class FixedNumberTotalWithReplacement : public InitSparseConnectivitySnippet::Ba public: DECLARE_SNIPPET(FixedNumberTotalWithReplacement); - SET_ROW_BUILD_CODE( - "const unsigned int rowLength = $(preCalcRowLength)[($(id_pre) * $(num_threads)) + $(id_thread)];\n" - "if(c >= rowLength) {\n" - " $(endRow);\n" - "}\n" - "const scalar u = $(gennrand_uniform);\n" - "x += (1.0 - x) * (1.0 - pow(u, 1.0 / (scalar)(rowLength - c)));\n" - "unsigned int postIdx = (unsigned int)(x * $(num_post));\n" - "postIdx = (postIdx < $(num_post)) ? postIdx : ($(num_post) - 1);\n" - "$(addSynapse, postIdx + $(id_post_begin));\n" - "c++;\n"); - SET_ROW_BUILD_STATE_VARS({{"x", "scalar", 0.0},{"c", "unsigned int", 0}}); - + SET_ROW_BUILD_CODE( + "scalar x = 0.0;\n" + "for(unsigned int c = 0; c < preCalcRowLength[(id_pre * num_threads) + id_thread]; c++) {\n" + " const scalar u = gennrand_uniform();\n" + " x += (1.0 - x) * (1.0 - pow(u, 1.0 / (scalar)(rowLength - c)));\n" + " unsigned int postIdx = (unsigned int)(x * num_post);\n" + " postIdx = (postIdx < num_post) ? postIdx : (num_post - 1);\n" + " addSynapse(postIdx + id_post_begin);\n" + "}\n"); SET_PARAM_NAMES({"total"}); SET_EXTRA_GLOBAL_PARAMS({{"preCalcRowLength", "unsigned int*"}}) @@ -100,9 +96,9 @@ TEST(InitSparseConnectivitySnippet, CompareVarInitParameters) auto connectivityInit1 = initConnectivity(fixedProbParamsA); auto connectivityInit2 = initConnectivity(fixedProbParamsB); - connectivityInit0.initDerivedParams(0.1); - connectivityInit1.initDerivedParams(0.1); - connectivityInit2.initDerivedParams(0.1); + connectivityInit0.finalise(0.1); + connectivityInit1.finalise(0.1); + connectivityInit2.finalise(0.1); ASSERT_EQ(connectivityInit0.getHashDigest(), connectivityInit1.getHashDigest()); ASSERT_EQ(connectivityInit0.getHashDigest(), connectivityInit2.getHashDigest()); @@ -116,8 +112,8 @@ TEST(InitSparseConnectivitySnippet, CompareUnusedParameters) auto connectivityInit0 = initConnectivity(fixedNumberParamsA); auto connectivityInit1 = initConnectivity(fixedNumberParamsB); - connectivityInit0.initDerivedParams(0.1); - connectivityInit1.initDerivedParams(0.1); + connectivityInit0.finalise(0.1); + connectivityInit1.finalise(0.1); ASSERT_EQ(connectivityInit0.getHashDigest(), connectivityInit1.getHashDigest()); } diff --git a/tests/unit/initVarSnippet.cc b/tests/unit/initVarSnippet.cc index 895bf988a3..fcd5958d87 100644 --- a/tests/unit/initVarSnippet.cc +++ b/tests/unit/initVarSnippet.cc @@ -13,8 +13,8 @@ class UniformCopy : public InitVarSnippet::Base { public: SET_CODE( - "const scalar scale = $(max) - $(min);\n" - "$(value) = $(min) + ($(gennrand_uniform) * scale);"); + "const scalar scale = max - min;\n" + "value = min + (gennrand_uniform() * scale);"); SET_PARAM_NAMES({"min", "max"}); }; diff --git a/tests/unit/modelSpec.cc b/tests/unit/modelSpec.cc index e8e60e8f5d..2118481ece 100644 --- a/tests/unit/modelSpec.cc +++ b/tests/unit/modelSpec.cc @@ -52,13 +52,12 @@ class RemoveSynapse : public CustomConnectivityUpdateModels::Base SET_VARS({{"a", "scalar"}}); SET_ROW_UPDATE_CODE( - "$(for_each_synapse,\n" - "{\n" - " if($(id_post) == ($(id_pre) + 1)) {\n" - " $(remove_synapse);\n" + "for_each_synapse{\n" + " if(id_post == (id_pre + 1)) {\n" + " remove_synapse();\n" " break;\n" " }\n" - "});\n"); + "};\n"); }; IMPLEMENT_SNIPPET(RemoveSynapse); } @@ -104,7 +103,7 @@ TEST(ModelSpec, PSMZeroCopy) model.addNeuronPopulation("Neurons1", 10, 
paramVals, varVals); SynapseGroup *sg = model.addSynapsePopulation( - "Synapse", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapse", SynapseMatrixType::DENSE, NO_DELAY, "Neurons0", "Neurons1", {}, {{"g", 1.0}}, {{"tau", 5.0}}, {{"x", 0.0}}); @@ -123,7 +122,7 @@ TEST(ModelSpec, WUZeroCopy) model.addNeuronPopulation("Neurons1", 10, paramVals, varVals); SynapseGroup *sg = model.addSynapsePopulation( - "Synapse", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapse", SynapseMatrixType::DENSE, NO_DELAY, "Neurons0", "Neurons1", {}, {{"g", 1.0}}, {}, {}); @@ -157,7 +156,7 @@ TEST(ModelSpec, CustomConnectivityUpdateZeroCopy) model.addNeuronPopulation("Neurons1", 10, paramVals, varVals); model.addSynapsePopulation( - "Synapse", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapse", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", {}, {{"g", 1.0}, {"d", 1}}, {}, {}); diff --git a/tests/unit/modelSpecMerged.cc b/tests/unit/modelSpecMerged.cc index 018a05ea36..7e78558ed6 100644 --- a/tests/unit/modelSpecMerged.cc +++ b/tests/unit/modelSpecMerged.cc @@ -190,7 +190,7 @@ void test(const std::pair (&modelModifiers)[N], M applyModifierFn) applyModifierFn(modelModifiers[i].first, model); // Finalize model - model.finalize(); + model.finalise(); // Create suitable backend to build model CodeGenerator::SingleThreadedCPU::Backend backend(preferences); @@ -262,7 +262,7 @@ void testSynapseVarLocation(S setVarLocationFn) VarValues psmVarValues{{"x", 0.0}}; auto *sg = model.addSynapsePopulation( - "Synapse", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapse", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", params, varValues, preVarValues, postVarValues, psmParams, psmVarValues, @@ -291,10 +291,10 @@ void testCustomConnectivityUpdateVarLocation(S setVarLocationFn) model.addNeuronPopulation("Post", 10, paramVals, varVals); // Create synapse group with global weights - model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation( + "Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", - {}, {{"g", 1.0}}, + {{"g", 1.0}}, {}, {}, {}); auto *ccu = model.addCustomConnectivityUpdate( @@ -609,7 +609,7 @@ TEST(ModelSpecMerged, CompareSynapseNameChanges) neuronParamVals, neuronVarVals); model.addSynapsePopulation( - name, SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + name, SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {}, {}); @@ -647,7 +647,7 @@ TEST(ModelSpecMerged, ComparePSMParamChanges) neuronParamVals, neuronVarVals); model.addSynapsePopulation( - "Synapse" + std::to_string(p), SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapse" + std::to_string(p), SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post" + std::to_string(p), {}, {{"g", 1.0}}, psmParams[p], {}); @@ -688,7 +688,7 @@ TEST(ModelSpecMerged, ComparePSMVarInitParamChanges) ParamValues params{{"tau", 5.0}}; VarValues varValues{{"x", initVar(psmVarInitParams[p])}}; model.addSynapsePopulation( - "Synapse" + std::to_string(p), SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapse" + std::to_string(p), SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post" + std::to_string(p), {}, {{"g", 1.0}}, params, varValues); @@ -736,7 +736,7 @@ TEST(ModelSpecMerged, CompareWUMParamChanges) VarValues varInit{{"g", 0.0}, {"gRaw", uninitialisedVar()}}; model.addSynapsePopulation( - "Synapse" + std::to_string(p), SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapse" + std::to_string(p), SynapseMatrixType::DENSE, NO_DELAY, "Pre", 
"Post" + std::to_string(p), wumParams[p], varInit, {}, {}); @@ -747,12 +747,12 @@ TEST(ModelSpecMerged, CompareWUMParamChanges) TEST(ModelSpecMerged, CompareWUMGlobalGVarChanges) { // Weight update model variable initialisers - const VarValues varVals1{{"g", 1.0}}; - const VarValues varVals2{{"g", 0.2}}; - const VarValues varVals3{{"g", 0.9}}; + const ParamValues varVals1{{"g", 1.0}}; + const ParamValues varVals2{{"g", 0.2}}; + const ParamValues varVals3{{"g", 0.9}}; // Make array of population parameters to build model with and flags determining whether the hashes should match baseline - const std::pair, bool> modelModifiers[] = { + const std::pair, bool> modelModifiers[] = { {{varVals1, varVals2}, true}, {{varVals1, varVals2}, true}, {{varVals1, varVals1}, false}, @@ -761,7 +761,7 @@ TEST(ModelSpecMerged, CompareWUMGlobalGVarChanges) {{varVals1}, false}}; test(modelModifiers, - [](const std::vector &wumVarVals, ModelSpecInternal &model) + [](const std::vector &wumParamVals, ModelSpecInternal &model) { // Add pre population VarValues neuronVarVals{{"V", 0.0}, {"U", 0.0}}; @@ -770,14 +770,14 @@ TEST(ModelSpecMerged, CompareWUMGlobalGVarChanges) neuronParamVals, neuronVarVals); // Add desired number of post populations - for(size_t p = 0; p < wumVarVals.size(); p++) { + for(size_t p = 0; p < wumParamVals.size(); p++) { model.addNeuronPopulation("Post" + std::to_string(p), 100, neuronParamVals, neuronVarVals); - model.addSynapsePopulation( - "Synapse" + std::to_string(p), SynapseMatrixType::DENSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation( + "Synapse" + std::to_string(p), SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post" + std::to_string(p), - {}, wumVarVals[p], + wumParamVals[p], {}, {}, {}); } }); @@ -815,7 +815,7 @@ TEST(ModelSpecMerged, CompareWUMVarInitParamChanges) VarValues varValues{{"g", initVar(wumVarInitParams[p])}}; model.addSynapsePopulation( - "Synapse" + std::to_string(p), SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapse" + std::to_string(p), SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post" + std::to_string(p), {}, varValues, {}, {}); @@ -864,7 +864,7 @@ TEST(ModelSpecMerged, CompareWUMPreVarInitParamChanges) VarValues postVarValues{{"postTrace", 0.0}}; model.addSynapsePopulation( - "Synapse" + std::to_string(p), SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapse" + std::to_string(p), SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post" + std::to_string(p), params, varValues, preVarValues, postVarValues, {}, {}); @@ -913,7 +913,7 @@ TEST(ModelSpecMerged, CompareWUMPostVarInitParamChanges) VarValues postVarValues{{"postTrace", initVar(wumPostVarInitParams[p])}}; model.addSynapsePopulation( - "Synapse" + std::to_string(p), SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapse" + std::to_string(p), SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post" + std::to_string(p), params, varValues, preVarValues, postVarValues, {}, {}); @@ -956,11 +956,10 @@ TEST(ModelSpecMerged, CompareConnectivityParamChanges) model.addNeuronPopulation("Post" + std::to_string(p), 100, neuronParamVals, neuronVarVals); - VarValues varValues{{"g", 1.0}}; - model.addSynapsePopulation( - "Synapse" + std::to_string(p), SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation( + "Synapse" + std::to_string(p), SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post" + std::to_string(p), - {}, varValues, + {{"g", 1.0}}, {}, {}, {}, initConnectivity(connectivityParams[p])); } @@ -996,10 +995,10 @@ TEST(ModelSpecMerged, CompareConnectivityModelChanges) 
model.addNeuronPopulation("Post" + std::to_string(p), 100, neuronParamVals, neuronVarVals); - model.addSynapsePopulation( - "Synapse" + std::to_string(p), SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation( + "Synapse" + std::to_string(p), SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post" + std::to_string(p), - {}, {{"g", 1.0}}, + {{"g", 1.0}}, {}, {}, {}, InitSparseConnectivitySnippet::Init(connectivityModels[p], {})); } @@ -1207,7 +1206,7 @@ TEST(ModelSpecMerged, CompareCustomConnectivityUpdateParamChanges) // Create synapse group with global weights auto *syn = model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {}, {}); @@ -1245,10 +1244,10 @@ TEST(ModelSpecMerged, CompareCustomConnectivityUpdateVarInitParamChanges) model.addNeuronPopulation("Post", 10, paramVals, varVals); // Create synapse group with global weights - model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation( + "Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", - {}, {{"g", 1.0}}, + {{"g", 1.0}}, {}, {}, {}); model.addCustomConnectivityUpdate( diff --git a/tests/unit/models.cc b/tests/unit/models.cc index 5558fda620..b5af0cdb2a 100644 --- a/tests/unit/models.cc +++ b/tests/unit/models.cc @@ -37,7 +37,7 @@ class StaticPulseUInt : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulseUInt); - SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}}); + SET_PARAM_NAMES({"g"}); SET_SIM_CODE("$(addToInSyn, $(g));\n"); }; @@ -76,6 +76,28 @@ class ContPrePost : public WeightUpdateModels::Base "$(addToInSyn, $(g) * $(V_pre));\n"); }; IMPLEMENT_SNIPPET(ContPrePost); + +class ContPrePostConstantWeight : public WeightUpdateModels::Base +{ +public: + DECLARE_SNIPPET(ContPrePostConstantWeight); + + SET_PARAM_NAMES({"g"}); + SET_PRE_VARS({{"preTrace", "scalar"}}); + SET_POST_VARS({{"postTrace", "scalar"}}); + + SET_PRE_SPIKE_CODE( + "scalar dt = $(t) - $(sT_pre);\n" + "$(preTrace) = ($(preTrace) * exp(-dt / $(tauPlus))) + 1.0;\n"); + + SET_POST_SPIKE_CODE( + "scalar dt = $(t) - $(sT_post);\n" + "$(postTrace) = ($(postTrace) * exp(-dt / $(tauMinus))) + 1.0;\n"); + + SET_SYNAPSE_DYNAMICS_CODE( + "$(addToInSyn, $(g) * $(V_pre));\n"); +}; +IMPLEMENT_SNIPPET(ContPrePostConstantWeight); } //-------------------------------------------------------------------------- @@ -110,7 +132,7 @@ TEST(Models, NeuronVarReferenceDelay) VarValues varVals{{"V", 0.0}, {"U", 0.0}}; auto *pre = model.addNeuronPopulation("Neurons0", 10, paramVals, varVals); model.addNeuronPopulation("Neurons1", 10, paramVals, varVals); - model.addSynapsePopulation("Syn", SynapseMatrixType::DENSE_INDIVIDUALG, + model.addSynapsePopulation("Syn", SynapseMatrixType::DENSE, 10, "Neurons0", "Neurons1", {}, {{"g", 0.1}}, {}, {}); @@ -118,7 +140,7 @@ TEST(Models, NeuronVarReferenceDelay) auto neuronU = createVarRef(pre, "U"); // Finalize model - model.finalize(); + model.finalise(); // Check ASSERT_EQ(neuronV.getDelayNeuronGroup(), pre); @@ -161,10 +183,11 @@ TEST(Models, PSMVarReference) model.addNeuronPopulation("Pre", 10, paramVals, varVals); model.addNeuronPopulation("Post", 25, paramVals, varVals); - auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::DENSE_GLOBALG, NO_DELAY, - "Pre", "Post", - {}, {{"g", 1.0}}, - {{"tau", 5.0}}, {{"x", 0.0}}); + auto *sg1 = model.addSynapsePopulation( + "Synapses1", SynapseMatrixType::DENSE, NO_DELAY, 
+ "Pre", "Post", + {{"g", 1.0}}, {}, + {{"tau", 5.0}}, {{"x", 0.0}}); auto psmX = createPSMVarRef(sg1, "x"); ASSERT_EQ(psmX.getSize(), 25); @@ -189,15 +212,15 @@ TEST(Models, WUPreVarReference) model.addNeuronPopulation("Post", 25, paramVals, varVals); auto *sg1 = model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapses1", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {{"preTrace", 0.0}}, {{"postTrace", 0.0}}, {}, {}); - auto *sg2 = model.addSynapsePopulation( - "Synapses2", SynapseMatrixType::DENSE_GLOBALG, 5, + auto *sg2 = model.addSynapsePopulation( + "Synapses2", SynapseMatrixType::DENSE, 5, "Pre", "Post", - {}, {{"g", 1.0}}, {{"preTrace", 0.0}}, {{"postTrace", 0.0}}, + {{"g", 1.0}}, {}, {{"preTrace", 0.0}}, {{"postTrace", 0.0}}, {}, {}); auto wuPre = createWUPreVarRef(sg1, "preTrace"); @@ -212,7 +235,7 @@ TEST(Models, WUPreVarReference) } // Finalize model - model.finalize(); + model.finalise(); ASSERT_EQ(wuPre.getSize(), 10); ASSERT_EQ(wuPre.getDelayNeuronGroup(), nullptr); @@ -230,15 +253,15 @@ TEST(Models, WUPostVarReference) auto *post = model.addNeuronPopulation("Post", 25, paramVals, varVals); auto *sg1 = model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapses1", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {{"preTrace", 0.0}}, {{"postTrace", 0.0}}, {}, {}); - auto *sg2 = model.addSynapsePopulation( - "Synapses2", SynapseMatrixType::DENSE_GLOBALG, NO_DELAY, + auto *sg2 = model.addSynapsePopulation( + "Synapses2", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", - {}, {{"g", 1.0}}, {{"preTrace", 0.0}}, {{"postTrace", 0.0}}, + {{"g", 1.0}}, {}, {{"preTrace", 0.0}}, {{"postTrace", 0.0}}, {}, {}); auto wuPost = createWUPostVarRef(sg1, "postTrace"); @@ -255,7 +278,7 @@ TEST(Models, WUPostVarReference) } // Finalize model - model.finalize(); + model.finalise(); ASSERT_EQ(wuPost.getSize(), 25); ASSERT_EQ(wuPost.getDelayNeuronGroup(), nullptr); @@ -272,15 +295,12 @@ TEST(Models, WUMVarReference) model.addNeuronPopulation("Pre", 10, paramVals, varVals); model.addNeuronPopulation("Post", 25, paramVals, varVals); - auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, - "Pre", "Post", - {}, {{"g", 1.0}}, - {{"tau", 5.0}}, {{"x", 0.0}}); + auto *sg1 = model.addSynapsePopulation( + "Synapses1", SynapseMatrixType::DENSE, NO_DELAY, + "Pre", "Post", + {}, {{"g", 1.0}}, + {{"tau", 5.0}}, {{"x", 0.0}}); - auto *sg2 = model.addSynapsePopulation("Synapses2", SynapseMatrixType::DENSE_GLOBALG, NO_DELAY, - "Pre", "Post", - {}, {{"g", 1.0}}, - {{"tau", 5.0}}, {{"x", 0.0}}); auto wuG1 = createWUVarRef(sg1, "g"); // Test error if variable doesn't exist @@ -290,14 +310,6 @@ TEST(Models, WUMVarReference) } catch(const std::runtime_error &) { } - - // Test error if GLOBALG - try { - auto wuG2 = createWUVarRef(sg2, "x"); - FAIL(); - } - catch(const std::runtime_error &) { - } } //-------------------------------------------------------------------------- TEST(Models, WUMTransposeVarReference) @@ -311,39 +323,39 @@ TEST(Models, WUMTransposeVarReference) model.addNeuronPopulation("Post", 25, paramVals, varVals); auto *sgForward = model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapses1", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {}, {}); auto *sgBackwardIndividualG = model.addSynapsePopulation( - "Synapses2", SynapseMatrixType::DENSE_INDIVIDUALG, 
NO_DELAY, + "Synapses2", SynapseMatrixType::DENSE, NO_DELAY, "Post", "Pre", {}, {{"g", 1.0}}, {}, {}); - auto *sgBackwardGlobalG = model.addSynapsePopulation( - "Synapses3", SynapseMatrixType::DENSE_GLOBALG, NO_DELAY, + auto *sgBackwardGlobalG = model.addSynapsePopulation( + "Synapses3", SynapseMatrixType::DENSE, NO_DELAY, "Post", "Pre", - {}, {{"g", 1.0}}, + {{"g", 1.0}}, {}, {}, {}); auto *sgBackwardBadShape = model.addSynapsePopulation( - "Synapses4", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapses4", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Pre", {}, {{"g", 1.0}}, {}, {}); - auto *sgBackwardSparse = model.addSynapsePopulation( - "Synapses5", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + auto *sgBackwardSparse = model.addSynapsePopulation( + "Synapses5", SynapseMatrixType::SPARSE, NO_DELAY, "Post", "Pre", - {}, {{"g", 1.0}}, + {{"g", 1.0}}, {}, {}, {}); auto *sgBackwardBadType = model.addSynapsePopulation( - "Synapses6", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + "Synapses6", SynapseMatrixType::SPARSE, NO_DELAY, "Post", "Pre", - {}, {{"g", 1.0}}, + {{"g", 1.0}}, {}, {}, {}); auto wuG1 = createWUVarRef(sgForward, "g", sgBackwardIndividualG, "g"); diff --git a/tests/unit/neuronGroup.cc b/tests/unit/neuronGroup.cc index 22aece424d..ba834866ea 100644 --- a/tests/unit/neuronGroup.cc +++ b/tests/unit/neuronGroup.cc @@ -27,13 +27,25 @@ class StaticPulseBack : public WeightUpdateModels::Base }; IMPLEMENT_SNIPPET(StaticPulseBack); +class StaticPulseBackConstantWeight : public WeightUpdateModels::Base +{ +public: + DECLARE_SNIPPET(StaticPulseBackConstantWeight); + + SET_PARAM_NAMES({"g"}); + + SET_SIM_CODE( + "$(addToInSyn, $(g));\n" + "$(addToPre, $(g));\n"); +}; +IMPLEMENT_SNIPPET(StaticPulseBackConstantWeight); + class WeightUpdateModelPost : public WeightUpdateModels::Base { public: DECLARE_SNIPPET(WeightUpdateModelPost); - SET_VARS({{"w", "scalar"}}); - SET_PARAM_NAMES({"p"}); + SET_PARAM_NAMES({"w", "p"}); SET_POST_VARS({{"s", "scalar"}}); SET_SIM_CODE("$(w)= $(s);\n"); @@ -46,8 +58,7 @@ class WeightUpdateModelPre : public WeightUpdateModels::Base public: DECLARE_SNIPPET(WeightUpdateModelPre); - SET_VARS({{"w", "scalar"}}); - SET_PARAM_NAMES({"p"}); + SET_PARAM_NAMES({"w", "p"}); SET_PRE_VARS({{"s", "scalar"}}); SET_SIM_CODE("$(w)= $(s);\n"); @@ -219,20 +230,17 @@ TEST(NeuronGroup, ConstantVarIzhikevich) VarValues varVals{{"V", 0.0}, {"U", 0.0}}; NeuronGroup *ng = model.addNeuronPopulation("Neurons0", 10, paramVals, varVals); - model.finalize(); + model.finalise(); - ASSERT_FALSE(ng->isZeroCopyEnabled()); - ASSERT_FALSE(ng->isSimRNGRequired()); - ASSERT_FALSE(ng->isInitRNGRequired()); + auto ngInternal = static_cast(ng); + ASSERT_FALSE(ngInternal->isZeroCopyEnabled()); + ASSERT_FALSE(ngInternal->isSimRNGRequired()); + ASSERT_FALSE(ngInternal->isInitRNGRequired()); // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; CodeGenerator::SingleThreadedCPU::Backend backend(preferences); - - // Merge model - CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); - - ASSERT_FALSE(backend.isGlobalHostRNGRequired(modelSpecMerged)); + ASSERT_FALSE(backend.isGlobalHostRNGRequired(model)); } TEST(NeuronGroup, UninitialisedVarIzhikevich) @@ -243,20 +251,17 @@ TEST(NeuronGroup, UninitialisedVarIzhikevich) VarValues varVals{{"V", uninitialisedVar()}, {"U", uninitialisedVar()}}; NeuronGroup *ng = model.addNeuronPopulation("Neurons0", 10, paramVals, varVals); - model.finalize(); + model.finalise(); - ASSERT_FALSE(ng->isZeroCopyEnabled()); - 
ASSERT_FALSE(ng->isSimRNGRequired()); - ASSERT_FALSE(ng->isInitRNGRequired()); + auto ngInternal = static_cast(ng); + ASSERT_FALSE(ngInternal->isZeroCopyEnabled()); + ASSERT_FALSE(ngInternal->isSimRNGRequired()); + ASSERT_FALSE(ngInternal->isInitRNGRequired()); // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; CodeGenerator::SingleThreadedCPU::Backend backend(preferences); - - // Merge model - CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); - - ASSERT_FALSE(backend.isGlobalHostRNGRequired(modelSpecMerged)); + ASSERT_FALSE(backend.isGlobalHostRNGRequired(model)); } TEST(NeuronGroup, RandVarIzhikevich) @@ -268,20 +273,17 @@ TEST(NeuronGroup, RandVarIzhikevich) VarValues varVals{{"V", 0.0}, {"U", initVar(dist)}}; NeuronGroup *ng = model.addNeuronPopulation("Neurons0", 10, paramVals, varVals); - model.finalize(); + model.finalise(); - ASSERT_FALSE(ng->isZeroCopyEnabled()); - ASSERT_FALSE(ng->isSimRNGRequired()); - ASSERT_TRUE(ng->isInitRNGRequired()); + auto ngInternal = static_cast(ng); + ASSERT_FALSE(ngInternal->isZeroCopyEnabled()); + ASSERT_FALSE(ngInternal->isSimRNGRequired()); + ASSERT_TRUE(ngInternal->isInitRNGRequired()); // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; CodeGenerator::SingleThreadedCPU::Backend backend(preferences); - - // Merge model - CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); - - ASSERT_TRUE(backend.isGlobalHostRNGRequired(modelSpecMerged)); + ASSERT_TRUE(backend.isGlobalHostRNGRequired(model)); } TEST(NeuronGroup, Poisson) @@ -292,20 +294,17 @@ TEST(NeuronGroup, Poisson) VarValues varVals{{"timeStepToSpike", 0.0}}; NeuronGroup *ng = model.addNeuronPopulation("Neurons0", 10, paramVals, varVals); - model.finalize(); + model.finalise(); - ASSERT_FALSE(ng->isZeroCopyEnabled()); - ASSERT_TRUE(ng->isSimRNGRequired()); - ASSERT_FALSE(ng->isInitRNGRequired()); + auto ngInternal = static_cast(ng); + ASSERT_FALSE(ngInternal->isZeroCopyEnabled()); + ASSERT_TRUE(ngInternal->isSimRNGRequired()); + ASSERT_FALSE(ngInternal->isInitRNGRequired()); // Create a backend CodeGenerator::SingleThreadedCPU::Preferences preferences; CodeGenerator::SingleThreadedCPU::Backend backend(preferences); - - // Merge model - CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); - - ASSERT_TRUE(backend.isGlobalHostRNGRequired(modelSpecMerged)); + ASSERT_TRUE(backend.isGlobalHostRNGRequired(model)); } TEST(NeuronGroup, FuseWUMPrePost) @@ -332,62 +331,62 @@ TEST(NeuronGroup, FuseWUMPrePost) // Create baseline synapse group auto *syn = model.addSynapsePopulation( - "Syn", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Syn", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParams, wumVarVals, wumPreVarVals, wumPostVarVals, {}, {}); // Create synapse group with different value for parameter accessed in presynaptic code auto *synPreParam = model.addSynapsePopulation( - "SynPreParam", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "SynPreParam", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParamsPre, wumVarVals, wumPreVarVals, wumPostVarVals, {}, {}); // Create synapse group with different value for parameter accessed in presynaptic code auto *synPostParam = model.addSynapsePopulation( - "SynPostParam", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "SynPostParam", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParamsPost, wumVarVals, wumPreVarVals, wumPostVarVals, {}, {}); // Create synapse group with different value for parameter only accessed in synapse code auto 
*synSynParam = model.addSynapsePopulation( - "SynSynParam", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "SynSynParam", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParamsSyn, wumVarVals, wumPreVarVals, wumPostVarVals, {}, {}); // Create synapse group with different presynaptic variable initialiser auto *synPreVar2 = model.addSynapsePopulation( - "SynPreVar2", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "SynPreVar2", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParams, wumVarVals, wumPreVarVals2, wumPostVarVals, {}, {}); // Create synapse group with different postsynaptic variable initialiser auto *synPostVar2 = model.addSynapsePopulation( - "SynPostVar2", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "SynPostVar2", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParams, wumVarVals, wumPreVarVals, wumPostVarVals2, {}, {}); // Create synapse group with axonal delay auto *synAxonalDelay = model.addSynapsePopulation( - "SynAxonalDelay", SynapseMatrixType::DENSE_INDIVIDUALG, 10, + "SynAxonalDelay", SynapseMatrixType::DENSE, 10, "Pre", "Post", wumParams, wumVarVals, wumPreVarVals, wumPostVarVals, {}, {}); // Create synapse group with backprop delay auto *synBackPropDelay = model.addSynapsePopulation( - "SynBackPropDelay", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "SynBackPropDelay", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParams, wumVarVals, wumPreVarVals, wumPostVarVals, {}, {}); synBackPropDelay->setBackPropDelaySteps(10); - model.finalize(); + model.finalise(); // Cast synapse groups to internal types auto synInternal = static_cast(syn); @@ -447,28 +446,28 @@ TEST(NeuronGroup, FusePSM) // Create baseline synapse group auto *syn = model.addSynapsePopulation( - "Syn", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Syn", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, wumVarVals, psmParamVals, {}); // Create second synapse group auto *syn2 = model.addSynapsePopulation( - "Syn2", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Syn2", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, wumVarVals, psmParamVals, {}); // Create synapse group with different value for PSM parameter auto *synParam = model.addSynapsePopulation( - "SynParam", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "SynParam", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, wumVarVals, psmParamVals2, {}); // Create synapse group with different target variable auto *synTarget = model.addSynapsePopulation( - "SynTarget", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "SynTarget", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, wumVarVals, psmParamVals, {}); @@ -476,13 +475,13 @@ TEST(NeuronGroup, FusePSM) // Create synapse group with different max dendritic delay auto *synDelay = model.addSynapsePopulation( - "SynDelay", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "SynDelay", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, wumVarVals, psmParamVals, {}); synDelay->setMaxDendriticDelayTimesteps(20); - model.finalize(); + model.finalise(); // Cast synapse groups to internal types auto synInternal = static_cast(syn); @@ -521,27 +520,27 @@ TEST(NeuronGroup, FusePreOutput) // Create baseline synapse group auto *syn = model.addSynapsePopulation( - "Syn", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Syn", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, wumVarVals, {}, {}); // Create second synapse group auto *syn2 = model.addSynapsePopulation( - "Syn2", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Syn2", 
SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, wumVarVals, {}, {}); // Create synapse group with different target variable auto *synTarget = model.addSynapsePopulation( - "SynTarget", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "SynTarget", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, wumVarVals, {}, {}); synTarget->setPreTargetVar("Isyn2"); - model.finalize(); + model.finalise(); // Cast synapse groups to internal types auto synInternal = static_cast(syn); @@ -569,7 +568,7 @@ TEST(NeuronGroup, CompareNeuronModels) auto *ng1 = model.addNeuronPopulation("Neurons1", 10, paramValsA, varVals2); auto *ng2 = model.addNeuronPopulation("Neurons2", 10, paramValsB, varVals3); - model.finalize(); + model.finalise(); // Check that all groups can be merged NeuronGroupInternal *ng0Internal = static_cast(ng0); @@ -622,7 +621,7 @@ TEST(NeuronGroup, CompareHeterogeneousParamVarState) auto *ng0 = model.addNeuronPopulation("Neurons0", 10, paramValsA, varVals); auto *ng1 = model.addNeuronPopulation("Neurons1", 10, paramValsB, varVals); - model.finalize(); + model.finalise(); // Check that all groups can be merged NeuronGroupInternal *ng0Internal = static_cast(ng0); @@ -662,7 +661,7 @@ TEST(NeuronGroup, CompareSimRNG) auto *ng0 = model.addNeuronPopulation("Neurons0", 10, paramVals, varVals); auto *ng1 = model.addNeuronPopulation("Neurons1", 10, paramVals, varVals); - model.finalize(); + model.finalise(); // Check that groups cannot be merged NeuronGroupInternal *ng0Internal = static_cast(ng0); @@ -670,8 +669,8 @@ TEST(NeuronGroup, CompareSimRNG) ASSERT_NE(ng0Internal->getHashDigest(), ng1Internal->getHashDigest()); ASSERT_NE(ng0Internal->getInitHashDigest(), ng1Internal->getInitHashDigest()); - ASSERT_TRUE(!ng0->isSimRNGRequired()); - ASSERT_TRUE(ng1->isSimRNGRequired()); + ASSERT_TRUE(!ng0Internal->isSimRNGRequired()); + ASSERT_TRUE(ng1Internal->isSimRNGRequired()); } TEST(NeuronGroup, CompareCurrentSources) @@ -713,7 +712,7 @@ TEST(NeuronGroup, CompareCurrentSources) model.addCurrentSource("CS9", "Neurons4", cs2ParamVals, {}); // **TODO** heterogeneous params - model.finalize(); + model.finalise(); NeuronGroupInternal *ng0Internal = static_cast(ng0); NeuronGroupInternal *ng1Internal = static_cast(ng1); @@ -776,62 +775,71 @@ TEST(NeuronGroup, ComparePostsynapticModels) auto *ng4 = model.addNeuronPopulation("Neurons4", 10, paramVals, varVals); // Add incoming synapse groups with Delta and DeltaCurr postsynaptic models to Neurons0 - VarValues staticPulseVarVals{{"g", 0.1}}; + ParamValues staticPulseParamVals{{"g", 0.1}}; ParamValues alphaCurrParamVals{{"tau", 0.5}}; ParamValues alphaCurrParamVals1{{"tau", 0.75}}; VarValues alphaCurrVarVals{{"x", 0.0}}; VarValues alphaCurrVarVals1{{"x", 0.1}}; - model.addSynapsePopulation("SG0", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "SpikeSource", "Neurons0", - {}, staticPulseVarVals, - {}, {}); - model.addSynapsePopulation("SG1", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "SpikeSource", "Neurons0", - {}, staticPulseVarVals, - alphaCurrParamVals, alphaCurrVarVals); + model.addSynapsePopulation( + "SG0", SynapseMatrixType::SPARSE, NO_DELAY, + "SpikeSource", "Neurons0", + staticPulseParamVals, {}, + {}, {}); + model.addSynapsePopulation( + "SG1", SynapseMatrixType::SPARSE, NO_DELAY, + "SpikeSource", "Neurons0", + staticPulseParamVals, {}, + alphaCurrParamVals, alphaCurrVarVals); // Do the same for Neuron1 - model.addSynapsePopulation("SG2", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "SpikeSource", "Neurons1", - {}, staticPulseVarVals, - {}, 
{}); - model.addSynapsePopulation("SG3", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "SpikeSource", "Neurons1", - {}, staticPulseVarVals, - alphaCurrParamVals, alphaCurrVarVals); + model.addSynapsePopulation( + "SG2", SynapseMatrixType::SPARSE, NO_DELAY, + "SpikeSource", "Neurons1", + staticPulseParamVals, {}, + {}, {}); + model.addSynapsePopulation( + "SG3", SynapseMatrixType::SPARSE, NO_DELAY, + "SpikeSource", "Neurons1", + staticPulseParamVals, {}, + alphaCurrParamVals, alphaCurrVarVals); // Do the same, but with different parameters for Neuron2, - model.addSynapsePopulation("SG4", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "SpikeSource", "Neurons2", - {}, staticPulseVarVals, - {}, {}); - model.addSynapsePopulation("SG5", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "SpikeSource", "Neurons2", - {}, staticPulseVarVals, - alphaCurrParamVals1, alphaCurrVarVals1); + model.addSynapsePopulation( + "SG4", SynapseMatrixType::SPARSE, NO_DELAY, + "SpikeSource", "Neurons2", + staticPulseParamVals, {}, + {}, {}); + model.addSynapsePopulation( + "SG5", SynapseMatrixType::SPARSE, NO_DELAY, + "SpikeSource", "Neurons2", + staticPulseParamVals, {}, + alphaCurrParamVals1, alphaCurrVarVals1); // Do the same, but in the opposite order for Neuron3 - model.addSynapsePopulation("SG6", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "SpikeSource", "Neurons3", - {}, staticPulseVarVals, - alphaCurrParamVals, alphaCurrVarVals); - model.addSynapsePopulation("SG7", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "SpikeSource", "Neurons3", - {}, staticPulseVarVals, - {}, {}); + model.addSynapsePopulation( + "SG6", SynapseMatrixType::SPARSE, NO_DELAY, + "SpikeSource", "Neurons3", + staticPulseParamVals, {}, + alphaCurrParamVals, alphaCurrVarVals); + model.addSynapsePopulation( + "SG7", SynapseMatrixType::SPARSE, NO_DELAY, + "SpikeSource", "Neurons3", + staticPulseParamVals, {}, + {}, {}); // Add two incoming synapse groups with DeltaCurr postsynaptic models sources to Neurons4 - model.addSynapsePopulation("SG8", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "SpikeSource", "Neurons4", - {}, staticPulseVarVals, - {}, {}); - - model.addSynapsePopulation("SG9", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "SpikeSource", "Neurons4", - {}, staticPulseVarVals, - {}, {}); + model.addSynapsePopulation( + "SG8", SynapseMatrixType::SPARSE, NO_DELAY, + "SpikeSource", "Neurons4", + staticPulseParamVals, {}, + {}, {}); + model.addSynapsePopulation( + "SG9", SynapseMatrixType::SPARSE, NO_DELAY, + "SpikeSource", "Neurons4", + staticPulseParamVals, {}, + {}, {}); - model.finalize(); + model.finalise(); NeuronGroupInternal *ng0Internal = static_cast(ng0); NeuronGroupInternal *ng1Internal = static_cast(ng1); @@ -896,47 +904,55 @@ TEST(NeuronGroup, ComparePreOutput) model.addNeuronPopulation("NeuronsPost2", 10, paramVals, varVals); // Add two outgoing synapse groups to NeuronsPre0 - VarValues staticPulseVarVals{{"g", 0.1}}; - model.addSynapsePopulation("SG0", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "NeuronsPre0", "NeuronsPost0", - {}, staticPulseVarVals, - {}, {}); - model.addSynapsePopulation("SG1", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "NeuronsPre0", "NeuronsPost1", - {}, staticPulseVarVals, - {}, {}); + ParamValues staticPulseParamVals{{"g", 0.1}}; + model.addSynapsePopulation( + "SG0", SynapseMatrixType::SPARSE, NO_DELAY, + "NeuronsPre0", "NeuronsPost0", + staticPulseParamVals, {}, + {}, {}); + model.addSynapsePopulation( + "SG1", SynapseMatrixType::SPARSE, NO_DELAY, + "NeuronsPre0", "NeuronsPost1", + 
staticPulseParamVals, {}, + {}, {}); // Do the same for NeuronsPre1 - model.addSynapsePopulation("SG2", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "NeuronsPre1", "NeuronsPost0", - {}, staticPulseVarVals, - {}, {}); - model.addSynapsePopulation("SG3", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "NeuronsPre1", "NeuronsPost1", - {}, staticPulseVarVals, - {}, {}); + model.addSynapsePopulation( + "SG2", SynapseMatrixType::SPARSE, NO_DELAY, + "NeuronsPre1", "NeuronsPost0", + staticPulseParamVals, {}, + {}, {}); + model.addSynapsePopulation( + "SG3", SynapseMatrixType::SPARSE, NO_DELAY, + "NeuronsPre1", "NeuronsPost1", + staticPulseParamVals, {}, + {}, {}); // Add three outgoing groups to NeuronPre2 - model.addSynapsePopulation("SG4", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "NeuronsPre2", "NeuronsPost0", - {}, staticPulseVarVals, - {}, {}); - model.addSynapsePopulation("SG5", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "NeuronsPre2", "NeuronsPost1", - {}, staticPulseVarVals, - {}, {}); - model.addSynapsePopulation("SG6", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "NeuronsPre2", "NeuronsPost2", - {}, staticPulseVarVals, - {}, {}); + model.addSynapsePopulation( + "SG4", SynapseMatrixType::SPARSE, NO_DELAY, + "NeuronsPre2", "NeuronsPost0", + staticPulseParamVals, {}, + {}, {}); + model.addSynapsePopulation( + "SG5", SynapseMatrixType::SPARSE, NO_DELAY, + "NeuronsPre2", "NeuronsPost1", + staticPulseParamVals, {}, + {}, {}); + model.addSynapsePopulation( + "SG6", SynapseMatrixType::SPARSE, NO_DELAY, + "NeuronsPre2", "NeuronsPost2", + staticPulseParamVals, {}, + {}, {}); // Add one outgoing groups to NeuronPre3 - model.addSynapsePopulation("SG7", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "NeuronsPre3", "NeuronsPost0", - {}, staticPulseVarVals, - {}, {}); + model.addSynapsePopulation( + "SG7", SynapseMatrixType::SPARSE, NO_DELAY, + "NeuronsPre3", "NeuronsPost0", + staticPulseParamVals, {}, + {}, {}); - model.finalize(); + model.finalise(); NeuronGroupInternal *ngPre0Internal = static_cast(ngPre0); NeuronGroupInternal *ngPre1Internal = static_cast(ngPre1); @@ -976,51 +992,51 @@ TEST(NeuronGroup, CompareWUPreUpdate) auto *ng5 = model.addNeuronPopulation("Neurons5", 10, paramVals, varVals); // Add incoming synapse groups with Delta and DeltaCurr postsynaptic models to Neurons0 - VarValues staticPulseVarVals{{"g", 0.1}}; - VarValues testVarVals{{"w", 0.0}}; - ParamValues testParams{{"p", 1.0}}; - ParamValues testParams2{{"p", 2.0}}; + ParamValues staticPulseParamVals{{"g", 0.1}}; + ParamValues testParams{{"w", 0.0}, {"p", 1.0}}; + ParamValues testParams2{{"w", 0.0}, {"p", 2.0}}; VarValues testPreVarVals1{{"s", 0.0}}; VarValues testPreVarVals2{{"s", 2.0}}; // Connect neuron group 1 to neuron group 0 with pre weight update model - model.addSynapsePopulation("SG0", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation("SG0", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons1", "Neurons0", - testParams, testVarVals, testPreVarVals1, {}, + testParams, {}, testPreVarVals1, {}, {}, {}); // Also connect neuron group 2 to neuron group 0 with pre weight update model - model.addSynapsePopulation("SG1", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation("SG1", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons2", "Neurons0", - testParams, testVarVals, testPreVarVals1, {}, + testParams, {}, testPreVarVals1, {}, {}, {}); // Also connect neuron group 3 to neuron group 0 with pre weight update model, but different parameters - model.addSynapsePopulation("SG2", 
SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation("SG2", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons3", "Neurons0", - testParams2, testVarVals, testPreVarVals1, {}, + testParams2, {}, testPreVarVals1, {}, {}, {}); // Connect neuron group 4 to neuron group 0 with 2*pre weight update model - model.addSynapsePopulation("SG3", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation("SG3", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons4", "Neurons0", - testParams, testVarVals, testPreVarVals1, {}, + testParams, {}, testPreVarVals1, {}, {}, {}); - model.addSynapsePopulation("SG4", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation("SG4", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons4", "Neurons0", - testParams, testVarVals, testPreVarVals1, {}, + testParams, {}, testPreVarVals1, {}, {}, {}); // Connect neuron group 5 to neuron group 0 with pre weight update model and static pulse - model.addSynapsePopulation("SG5", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation("SG5", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons5", "Neurons0", - testParams, testVarVals, testPreVarVals2, {}, + testParams, {}, testPreVarVals2, {}, {}, {}); - model.addSynapsePopulation("SG6", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "Neurons5", "Neurons0", - {}, staticPulseVarVals, - {}, {}); - model.finalize(); + model.addSynapsePopulation( + "SG6", SynapseMatrixType::SPARSE, NO_DELAY, + "Neurons5", "Neurons0", + staticPulseParamVals, {}, + {}, {}); + model.finalise(); // Check which groups can be merged together // **NOTE** NG1 and NG5 can be merged because the additional static pulse synapse population doesn't add any presynaptic update @@ -1078,51 +1094,51 @@ TEST(NeuronGroup, CompareWUPostUpdate) auto *ng5 = model.addNeuronPopulation("Neurons5", 10, paramVals, varVals); // Add incoming synapse groups with Delta and DeltaCurr postsynaptic models to Neurons0 - VarValues staticPulseVarVals{{"g", 0.1}}; - VarValues testVarVals{{"w", 0.0}}; - ParamValues testParams{{"p", 1.0}}; - ParamValues testParams2{{"p", 2.0}}; + ParamValues staticPulseParamVals{{"g", 0.1}}; + ParamValues testParams{{"w", 0.0}, {"p", 1.0}}; + ParamValues testParams2{{"w", 0.0}, {"p", 2.0}}; VarValues testPostVarVals1{{"s", 0.0}}; VarValues testPostVarVals2{{"s", 2.0}}; // Connect neuron group 0 to neuron group 1 with post weight update model - model.addSynapsePopulation("SG0", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation("SG0", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", - testParams, testVarVals, {}, testPostVarVals1, + testParams, {}, {}, testPostVarVals1, {}, {}); // Also connect neuron group 0 to neuron group 2 with post weight update model - model.addSynapsePopulation("SG1", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation("SG1", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons2", - testParams, testVarVals, {}, testPostVarVals1, + testParams, {}, {}, testPostVarVals1, {}, {}); // Also connect neuron group 0 to neuron group 3 with post weight update model but different parameters - model.addSynapsePopulation("SG2", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation("SG2", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons3", - testParams2, testVarVals, {}, testPostVarVals1, + testParams2, {}, {}, testPostVarVals1, {}, {}); // Connect neuron group 0 to neuron group 3 with 2*post weight update model - model.addSynapsePopulation("SG3", 
SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation("SG3", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons4", - testParams, testVarVals, {}, testPostVarVals1, + testParams, {}, {}, testPostVarVals1, {}, {}); - model.addSynapsePopulation("SG4", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation("SG4", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons4", - testParams, testVarVals, {}, testPostVarVals1, + testParams, {}, {}, testPostVarVals1, {}, {}); // Connect neuron group 0 to neuron group 4 with post weight update model and static pulse - model.addSynapsePopulation("SG5", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation("SG5", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons5", - testParams, testVarVals, {}, testPostVarVals2, + testParams, {}, {}, testPostVarVals2, {}, {}); - model.addSynapsePopulation("SG6", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "Neurons0", "Neurons5", - {}, staticPulseVarVals, - {}, {}); - model.finalize(); + model.addSynapsePopulation( + "SG6", SynapseMatrixType::SPARSE, NO_DELAY, + "Neurons0", "Neurons5", + staticPulseParamVals, {}, + {}, {}); + model.finalise(); // **NOTE** NG1 and NG5 can be merged because the additional static pulse synapse population doesn't add any presynaptic update NeuronGroupInternal *ng1Internal = static_cast(ng1); diff --git a/tests/unit/scanner.cc b/tests/unit/scanner.cc index 412415259a..2f40c7368c 100644 --- a/tests/unit/scanner.cc +++ b/tests/unit/scanner.cc @@ -56,7 +56,7 @@ class TestErrorHandler : public ErrorHandlerBase TEST(Scanner, DecimalInt) { TestErrorHandler errorHandler; - const auto tokens = Scanner::scanSource("1234 4294967295U -2345 -2147483647", {}, errorHandler); + const auto tokens = Scanner::scanSource("1234 4294967295U -2345 -2147483647", errorHandler); ASSERT_FALSE(errorHandler.hasError()); ASSERT_EQ(tokens.size(), 7); @@ -77,7 +77,7 @@ TEST(Scanner, DecimalInt) TEST(Scanner, HexInt) { TestErrorHandler errorHandler; - const auto tokens = Scanner::scanSource("0x1234 0xFFFFFFFFU -0x1234 -0x7FFFFFFF", {}, errorHandler); + const auto tokens = Scanner::scanSource("0x1234 0xFFFFFFFFU -0x1234 -0x7FFFFFFF", errorHandler); ASSERT_FALSE(errorHandler.hasError()); ASSERT_EQ(tokens.size(), 7); @@ -98,12 +98,12 @@ TEST(Scanner, HexInt) TEST(Scanner, DecimalFloat) { TestErrorHandler errorHandler; - const auto tokens = Scanner::scanSource("1.0 0.2 100.0f 0.2f -12.0d -0.0004f", {{"scalar", Type::Float}}, errorHandler); + const auto tokens = Scanner::scanSource("1.0 0.2 100.0f 0.2f -12.0d -0.0004f", errorHandler); ASSERT_FALSE(errorHandler.hasError()); ASSERT_EQ(tokens.size(), 9); - ASSERT_EQ(tokens[0].type, Token::Type::FLOAT_NUMBER); - ASSERT_EQ(tokens[1].type, Token::Type::FLOAT_NUMBER); + ASSERT_EQ(tokens[0].type, Token::Type::SCALAR_NUMBER); + ASSERT_EQ(tokens[1].type, Token::Type::SCALAR_NUMBER); ASSERT_EQ(tokens[2].type, Token::Type::FLOAT_NUMBER); ASSERT_EQ(tokens[3].type, Token::Type::FLOAT_NUMBER); ASSERT_EQ(tokens[4].type, Token::Type::MINUS); @@ -123,7 +123,7 @@ TEST(Scanner, DecimalFloat) TEST(Scanner, String) { TestErrorHandler errorHandler; - const auto tokens = Scanner::scanSource("\"hello world\" \"pre-processor\"", {}, errorHandler); + const auto tokens = Scanner::scanSource("\"hello world\" \"pre-processor\"", errorHandler); ASSERT_FALSE(errorHandler.hasError()); ASSERT_EQ(tokens.size(), 3); diff --git a/tests/unit/synapseGroup.cc b/tests/unit/synapseGroup.cc index ee8b9fc815..0e674eacc6 100644 --- 
a/tests/unit/synapseGroup.cc +++ b/tests/unit/synapseGroup.cc @@ -31,22 +31,22 @@ class STDPAdditive : public WeightUpdateModels::Base SET_POST_VARS({{"postTrace", "scalar"}}); SET_SIM_CODE( - "$(addToInSyn, $(g));\n" - "const scalar dt = $(t) - $(sT_post); \n" + "addToPost(g);\n" + "const scalar dt = t - sT_post; \n" "if (dt > 0) {\n" - " const scalar newWeight = $(g) - ($(Aminus) * $(postTrace));\n" - " $(g) = fmax($(Wmin), fmin($(Wmax), newWeight));\n" + " const scalar newWeight = g - (Aminus * postTrace);\n" + " g = fmax(Wmin, fmin(Wmax, newWeight));\n" "}\n"); SET_LEARN_POST_CODE( "const scalar dt = $(t) - $(sT_pre);\n" "if (dt > 0) {\n" - " const scalar newWeight = $(g) + ($(Aplus) * $(preTrace));\n" - " $(g) = fmax($(Wmin), fmin($(Wmax), newWeight));\n" + " const scalar newWeight = g + (Aplus * preTrace);\n" + " g = fmax(Wmin, fmin(Wmax, newWeight));\n" "}\n"); - SET_PRE_SPIKE_CODE("$(preTrace) += 1.0;\n"); - SET_POST_SPIKE_CODE("$(postTrace) += 1.0;\n"); - SET_PRE_DYNAMICS_CODE("$(preTrace) *= $(tauPlusDecay);\n"); - SET_POST_DYNAMICS_CODE("$(postTrace) *= $(tauMinusDecay);\n"); + SET_PRE_SPIKE_CODE("preTrace += 1.0;\n"); + SET_POST_SPIKE_CODE("postTrace += 1.0;\n"); + SET_PRE_DYNAMICS_CODE("preTrace *= tauPlusDecay;\n"); + SET_POST_DYNAMICS_CODE("postTrace *= tauMinusDecay;\n"); SET_NEEDS_PRE_SPIKE_TIME(true); SET_NEEDS_POST_SPIKE_TIME(true); @@ -67,22 +67,22 @@ class STDPAdditiveEGPWMinMax : public WeightUpdateModels::Base SET_EXTRA_GLOBAL_PARAMS({{"Wmin", "scalar"}, {"Wmax", "scalar"}}); SET_SIM_CODE( - "$(addToInSyn, $(g));\n" - "const scalar dt = $(t) - $(sT_post); \n" + "addToPost(g);\n" + "const scalar dt = t - sT_post; \n" "if (dt > 0) {\n" - " const scalar newWeight = $(g) - ($(Aminus) * $(postTrace));\n" - " $(g) = fmax($(Wmin), fmin($(Wmax), newWeight));\n" + " const scalar newWeight = g - (Aminus * postTrace);\n" + " g = fmax(Wmin, fmin(Wmax, newWeight));\n" "}\n"); SET_LEARN_POST_CODE( - "const scalar dt = $(t) - $(sT_pre);\n" + "const scalar dt = t - sT_pre;\n" "if (dt > 0) {\n" - " const scalar newWeight = $(g) + ($(Aplus) * $(preTrace));\n" - " $(g) = fmax($(Wmin), fmin($(Wmax), newWeight));\n" + " const scalar newWeight = g + (Aplus * preTrace);\n" + " g = fmax(Wmin, fmin(Wmax, newWeight));\n" "}\n"); - SET_PRE_SPIKE_CODE("$(preTrace) += 1.0;\n"); - SET_POST_SPIKE_CODE("$(postTrace) += 1.0;\n"); - SET_PRE_DYNAMICS_CODE("$(preTrace) *= $(tauPlusDecay);\n"); - SET_POST_DYNAMICS_CODE("$(postTrace) *= $(tauMinusDecay);\n"); + SET_PRE_SPIKE_CODE("preTrace += 1.0;\n"); + SET_POST_SPIKE_CODE("postTrace += 1.0;\n"); + SET_PRE_DYNAMICS_CODE("preTrace *= tauPlusDecay;\n"); + SET_POST_DYNAMICS_CODE("postTrace *= tauMinusDecay;\n"); SET_NEEDS_PRE_SPIKE_TIME(true); SET_NEEDS_POST_SPIKE_TIME(true); @@ -104,7 +104,7 @@ class STDPAdditiveEGPSpike : public WeightUpdateModels::Base SET_EXTRA_GLOBAL_PARAMS({{"S", "scalar"}}); SET_SIM_CODE( - "$(addToInSyn, $(g));\n" + "addToPost(g);\n" "const scalar dt = $(t) - $(sT_post); \n" "if (dt > 0) {\n" " const scalar newWeight = $(g) - ($(Aminus) * $(postTrace));\n" @@ -137,7 +137,7 @@ class STDPAdditiveEGPDynamics : public WeightUpdateModels::Base SET_EXTRA_GLOBAL_PARAMS({{"tauPlusDecay", "scalar"}, {"tauMinusDecay", "scalar"}}); SET_SIM_CODE( - "$(addToInSyn, $(g));\n" + "addToPost(g);\n" "const scalar dt = $(t) - $(sT_post); \n" "if (dt > 0) {\n" " const scalar newWeight = $(g) - ($(Aminus) * $(postTrace));\n" @@ -166,7 +166,7 @@ class Continuous : public WeightUpdateModels::Base SET_VARS({{"g", "scalar"}}); - 
SET_SYNAPSE_DYNAMICS_CODE("$(addToInSyn, $(g) * $(V_pre));\n"); + SET_SYNAPSE_DYNAMICS_CODE("addToPost(g * V_pre);\n"); }; IMPLEMENT_SNIPPET(Continuous); @@ -175,9 +175,9 @@ class ContinuousDenDelay : public WeightUpdateModels::Base public: DECLARE_SNIPPET(ContinuousDenDelay); - SET_VARS({{"g", "scalar"}}); + SET_PARAM_NAMES({"g"}); - SET_SYNAPSE_DYNAMICS_CODE("$(addToInSynDelay, $(g) * $(V_pre), 1);\n"); + SET_SYNAPSE_DYNAMICS_CODE("addToPostDelay(g * V_pre, 1);\n"); }; IMPLEMENT_SNIPPET(ContinuousDenDelay); @@ -186,12 +186,23 @@ class GradedDenDelay : public WeightUpdateModels::Base public: DECLARE_SNIPPET(GradedDenDelay); - SET_VARS({{"g", "scalar"}}); - SET_EVENT_THRESHOLD_CONDITION_CODE("$(V_pre) >= 0.1"); - SET_EVENT_CODE("$(addToInSynDelay, $(g)*$(V_pre), 1);"); + SET_PARAM_NAMES({"g"}); + SET_EVENT_THRESHOLD_CONDITION_CODE("V_pre >= 0.1"); + SET_EVENT_CODE("addToInSynDelay(g * V_pre, 1);"); }; IMPLEMENT_SNIPPET(GradedDenDelay); +class StaticPulseDendriticDelayConstantWeight : public WeightUpdateModels::Base +{ +public: + DECLARE_SNIPPET(StaticPulseDendriticDelayConstantWeight); + + SET_PARAM_NAMES({"g", "d"}); + + SET_SIM_CODE("addToPostDelay(g, (uint8_t)d);\n"); +}; +IMPLEMENT_SNIPPET(StaticPulseDendriticDelayConstantWeight); + class StaticPulseDynamics : public WeightUpdateModels::Base { public: @@ -199,8 +210,8 @@ class StaticPulseDynamics : public WeightUpdateModels::Base SET_VARS({ {"g", "scalar", VarAccess::READ_ONLY} }); - SET_SIM_CODE("$(addToInSyn, $(g));\n"); - SET_SYNAPSE_DYNAMICS_CODE("$(g) *= 0.99;\n"); + SET_SIM_CODE("addToInSyn(g);\n"); + SET_SYNAPSE_DYNAMICS_CODE("g *= 0.99;\n"); }; IMPLEMENT_SNIPPET(StaticPulseDynamics); @@ -211,8 +222,8 @@ class StaticPulsePostLearn : public WeightUpdateModels::Base SET_VARS({ {"g", "scalar", VarAccess::READ_ONLY} }); - SET_SIM_CODE("$(addToInSyn, $(g));\n"); - SET_LEARN_POST_CODE("$(g) *= 0.99;\n"); + SET_SIM_CODE("addToInSyn(g);\n"); + SET_LEARN_POST_CODE("g *= 0.99;\n"); }; IMPLEMENT_SNIPPET(StaticPulsePostLearn); @@ -221,7 +232,7 @@ class PostRepeatVal : public InitVarSnippet::Base public: DECLARE_SNIPPET(PostRepeatVal); - SET_CODE("$(value) = $(values)[$(id_post) % 10];"); + SET_CODE("value = values[id_post % 10];"); SET_EXTRA_GLOBAL_PARAMS({{"values", "scalar*"}}); }; @@ -232,7 +243,7 @@ class PreRepeatVal : public InitVarSnippet::Base public: DECLARE_SNIPPET(PreRepeatVal); - SET_CODE("$(value) = $(values)[$(id_re) % 10];"); + SET_CODE("value = values[id_pre % 10];"); SET_EXTRA_GLOBAL_PARAMS({{"values", "scalar*"}}); }; @@ -242,7 +253,7 @@ class Sum : public CustomUpdateModels::Base { DECLARE_SNIPPET(Sum); - SET_UPDATE_CODE("$(sum) = $(a) + $(b);\n"); + SET_UPDATE_CODE("sum = a + b;\n"); SET_VARS({{"sum", "scalar"}}); SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_ONLY}, @@ -312,17 +323,17 @@ TEST(SynapseGroup, WUVarReferencedByCustomUpdate) VarValues wumPostVarVals{{"postTrace", 0.0}}; auto *sg1 = model.addSynapsePopulation( - "Synapses1", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapses1", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParams, wumVarVals, wumPreVarVals, wumPostVarVals, {}, {}); auto *sg2 = model.addSynapsePopulation( - "Synapses2", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapses2", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParams, wumVarVals, wumPreVarVals, wumPostVarVals, {}, {}); auto *sg3 = model.addSynapsePopulation( - "Synapses3", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Synapses3", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParams, 
wumVarVals, wumPreVarVals, wumPostVarVals, {}, {}); @@ -336,7 +347,7 @@ TEST(SynapseGroup, WUVarReferencedByCustomUpdate) {}, sumVarValues, sumVarReferences2); model.addCustomUpdate("SumWeight3", "CustomUpdate", {}, sumVarValues, sumVarReferences3); - model.finalize(); + model.finalise(); ASSERT_TRUE(static_cast(sg1)->getCustomUpdateReferences().empty()); ASSERT_FALSE(static_cast(sg2)->getCustomUpdateReferences().empty()); @@ -355,16 +366,16 @@ TEST(SynapseGroup, CompareWUDifferentModel) VarValues staticPulseVarVals{{"g", 0.1}}; VarValues staticPulseDendriticVarVals{{"g", 0.1}, {"d", 1}}; - auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", {}, staticPulseVarVals, {}, {}); - auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", {}, staticPulseDendriticVarVals, {}, {}); // Finalize model - model.finalize(); + model.finalise(); SynapseGroupInternal *sg0Internal = static_cast(sg0); SynapseGroupInternal *sg1Internal = static_cast(sg1); @@ -398,22 +409,25 @@ TEST(SynapseGroup, CompareWUDifferentGlobalG) model.addNeuronPopulation("Neurons0", 10, paramVals, varVals); model.addNeuronPopulation("Neurons1", 10, paramVals, varVals); - VarValues staticPulseAVarVals{{"g", 0.1}}; - VarValues staticPulseBVarVals{{"g", 0.2}}; - auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "Neurons0", "Neurons1", - {}, staticPulseAVarVals, - {}, {}); - auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "Neurons0", "Neurons1", - {}, staticPulseAVarVals, - {}, {}); - auto *sg2 = model.addSynapsePopulation("Synapses2", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, - "Neurons0", "Neurons1", - {}, staticPulseBVarVals, - {}, {}); + ParamValues staticPulseAParamVals{{"g", 0.1}}; + ParamValues staticPulseBParamVals{{"g", 0.2}}; + auto *sg0 = model.addSynapsePopulation( + "Synapses0", SynapseMatrixType::SPARSE, NO_DELAY, + "Neurons0", "Neurons1", + staticPulseAParamVals, {}, + {}, {}); + auto *sg1 = model.addSynapsePopulation( + "Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, + "Neurons0", "Neurons1", + staticPulseAParamVals, {}, + {}, {}); + auto *sg2 = model.addSynapsePopulation( + "Synapses2", SynapseMatrixType::SPARSE, NO_DELAY, + "Neurons0", "Neurons1", + staticPulseBParamVals, {}, + {}, {}); // Finalize model - model.finalize(); + model.finalise(); SynapseGroupInternal *sg0Internal = static_cast(sg0); SynapseGroupInternal *sg1Internal = static_cast(sg1); @@ -436,7 +450,7 @@ TEST(SynapseGroup, CompareWUDifferentGlobalG) ASSERT_TRUE(modelSpecMerged.getMergedSynapseSparseInitGroups().empty()); // Check that global g var is heterogeneous - ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().at(0).isWUGlobalVarHeterogeneous("g")); + ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().at(0).isWUParamHeterogeneous("g")); } TEST(SynapseGroup, CompareWUDifferentProceduralConnectivity) @@ -451,24 +465,27 @@ TEST(SynapseGroup, CompareWUDifferentProceduralConnectivity) ParamValues fixedProbParamsA{{"prob", 0.1}}; ParamValues fixedProbParamsB{{"prob", 0.4}}; - VarValues staticPulseVarVals{{"g", 0.1}}; - auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::PROCEDURAL_GLOBALG, 
NO_DELAY, - "Neurons0", "Neurons1", - {}, staticPulseVarVals, - {}, {}, - initConnectivity(fixedProbParamsA)); - auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::PROCEDURAL_GLOBALG, NO_DELAY, - "Neurons0", "Neurons1", - {}, staticPulseVarVals, - {}, {}, - initConnectivity(fixedProbParamsA)); - auto *sg2 = model.addSynapsePopulation("Synapses2", SynapseMatrixType::PROCEDURAL_GLOBALG, NO_DELAY, - "Neurons0", "Neurons1", - {}, staticPulseVarVals, - {}, {}, - initConnectivity(fixedProbParamsB)); + ParamValues staticPulseParamVals{{"g", 0.1}}; + auto *sg0 = model.addSynapsePopulation( + "Synapses0", SynapseMatrixType::PROCEDURAL, NO_DELAY, + "Neurons0", "Neurons1", + staticPulseParamVals, {}, + {}, {}, + initConnectivity(fixedProbParamsA)); + auto *sg1 = model.addSynapsePopulation( + "Synapses1", SynapseMatrixType::PROCEDURAL, NO_DELAY, + "Neurons0", "Neurons1", + staticPulseParamVals, {}, + {}, {}, + initConnectivity(fixedProbParamsA)); + auto *sg2 = model.addSynapsePopulation( + "Synapses2", SynapseMatrixType::PROCEDURAL, NO_DELAY, + "Neurons0", "Neurons1", + staticPulseParamVals, {}, + {}, {}, + initConnectivity(fixedProbParamsB)); // Finalize model - model.finalize(); + model.finalise(); SynapseGroupInternal *sg0Internal = static_cast(sg0); SynapseGroupInternal *sg1Internal = static_cast(sg1); @@ -518,23 +535,23 @@ TEST(SynapseGroup, CompareWUDifferentToeplitzConnectivity) {"conv_ih", 64}, {"conv_iw", 64}, {"conv_ic", 1}, {"conv_oh", 64}, {"conv_ow", 64}, {"conv_oc", 1}}; VarValues staticPulseVarVals{{"g", 0.1}}; - auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::TOEPLITZ_KERNELG, NO_DELAY, + auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::TOEPLITZ, NO_DELAY, "Pre", "Post1", {}, staticPulseVarVals, {}, {}, initToeplitzConnectivity(convParamsA)); - auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::TOEPLITZ_KERNELG, NO_DELAY, + auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::TOEPLITZ, NO_DELAY, "Pre", "Post1", {}, staticPulseVarVals, {}, {}, initToeplitzConnectivity(convParamsA)); - auto *sg2 = model.addSynapsePopulation("Synapses2", SynapseMatrixType::TOEPLITZ_KERNELG, NO_DELAY, + auto *sg2 = model.addSynapsePopulation("Synapses2", SynapseMatrixType::TOEPLITZ, NO_DELAY, "Pre", "Post2", {}, staticPulseVarVals, {}, {}, initToeplitzConnectivity(convParamsB)); // Finalize model - model.finalize(); + model.finalise(); SynapseGroupInternal *sg0Internal = static_cast(sg0); SynapseGroupInternal *sg1Internal = static_cast(sg1); @@ -584,23 +601,23 @@ TEST(SynapseGroup, CompareWUDifferentProceduralVars) ParamValues uniformParamsB{{"min", 0.25}, {"max", 0.5}}; VarValues staticPulseVarValsA{{"g", initVar(uniformParamsA)}}; VarValues staticPulseVarValsB{{"g", initVar(uniformParamsB)}}; - auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::PROCEDURAL_PROCEDURALG, NO_DELAY, + auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::PROCEDURAL, NO_DELAY, "Neurons0", "Neurons1", {}, staticPulseVarValsA, {}, {}, initConnectivity(fixedProbParams)); - auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::PROCEDURAL_PROCEDURALG, NO_DELAY, + auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::PROCEDURAL, NO_DELAY, "Neurons0", "Neurons1", {}, staticPulseVarValsA, {}, {}, initConnectivity(fixedProbParams)); - auto *sg2 = model.addSynapsePopulation("Synapses2", SynapseMatrixType::PROCEDURAL_PROCEDURALG, NO_DELAY, + auto *sg2 = 
model.addSynapsePopulation("Synapses2", SynapseMatrixType::PROCEDURAL, NO_DELAY, "Neurons0", "Neurons1", {}, staticPulseVarValsB, {}, {}, initConnectivity(fixedProbParams)); // Finalize model - model.finalize(); + model.finalise(); SynapseGroupInternal *sg0Internal = static_cast(sg0); SynapseGroupInternal *sg1Internal = static_cast(sg1); @@ -625,8 +642,8 @@ TEST(SynapseGroup, CompareWUDifferentProceduralVars) // Check that only synaptic weight initialistion parameters are heterogeneous ASSERT_FALSE(modelSpecMerged.getMergedPresynapticUpdateGroups().at(0).isSparseConnectivityInitParamHeterogeneous("prob")); ASSERT_FALSE(modelSpecMerged.getMergedPresynapticUpdateGroups().at(0).isSparseConnectivityInitDerivedParamHeterogeneous("probLogRecip")); - ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().at(0).isWUVarInitParamHeterogeneous("g", "min")); - ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().at(0).isWUVarInitParamHeterogeneous("g", "max")); + ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().at(0).isVarInitParamHeterogeneous("g", "min")); + ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().at(0).isVarInitParamHeterogeneous("g", "max")); } TEST(SynapseGroup, CompareWUDifferentProceduralSnippet) @@ -654,7 +671,7 @@ TEST(SynapseGroup, CompareWUDifferentProceduralSnippet) {}, staticPulseVarValsB, {}, {}); // Finalize model - model.finalize(); + model.finalise(); SynapseGroupInternal *sg0Internal = static_cast(sg0); SynapseGroupInternal *sg1Internal = static_cast(sg1); @@ -693,23 +710,23 @@ TEST(SynapseGroup, InitCompareWUDifferentVars) VarValues varValsB{{"g", 1.0}}; VarValues preVarVals{{"preTrace", 0.0}}; VarValues postVarVals{{"postTrace", 0.0}}; - auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", params, varValsA, preVarVals, postVarVals, {}, {}, initConnectivity(fixedProbParams)); - auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", params, varValsA, preVarVals, postVarVals, {}, {}, initConnectivity(fixedProbParams)); - auto *sg2 = model.addSynapsePopulation("Synapses2", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto *sg2 = model.addSynapsePopulation("Synapses2", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", params, varValsB, preVarVals, postVarVals, {}, {}, initConnectivity(fixedProbParams)); // Finalize model - model.finalize(); + model.finalise(); SynapseGroupInternal *sg0Internal = static_cast(sg0); SynapseGroupInternal *sg1Internal = static_cast(sg1); @@ -738,7 +755,7 @@ TEST(SynapseGroup, InitCompareWUDifferentVars) // Check that only synaptic weight initialistion parameters are heterogeneous ASSERT_FALSE(modelSpecMerged.getMergedSynapseConnectivityInitGroups().at(0).isSparseConnectivityInitParamHeterogeneous("prob")); ASSERT_FALSE(modelSpecMerged.getMergedSynapseConnectivityInitGroups().at(0).isSparseConnectivityInitDerivedParamHeterogeneous("prob")); - ASSERT_TRUE(modelSpecMerged.getMergedSynapseSparseInitGroups().at(0).isWUVarInitParamHeterogeneous("g", "constant")); + ASSERT_TRUE(modelSpecMerged.getMergedSynapseSparseInitGroups().at(0).isVarInitParamHeterogeneous("g", "constant")); } TEST(SynapseGroup, InitCompareWUDifferentPreVars) @@ -757,23 +774,23 @@ TEST(SynapseGroup, 
InitCompareWUDifferentPreVars) VarValues preVarValsA{{"preTrace", 0.0}}; VarValues preVarValsB{{"preTrace", 1.0}}; VarValues postVarVals{{"postTrace", 0.0}}; - auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", params, synVarVals, preVarValsA, postVarVals, {}, {}, initConnectivity(fixedProbParams)); - auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", params, synVarVals, preVarValsA, postVarVals, {}, {}, initConnectivity(fixedProbParams)); - auto *sg2 = model.addSynapsePopulation("Synapses2", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto *sg2 = model.addSynapsePopulation("Synapses2", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", params, synVarVals, preVarValsB, postVarVals, {}, {}, initConnectivity(fixedProbParams)); // Finalize model - model.finalize(); + model.finalise(); SynapseGroupInternal *sg0Internal = static_cast(sg0); SynapseGroupInternal *sg1Internal = static_cast(sg1); @@ -802,23 +819,23 @@ TEST(SynapseGroup, InitCompareWUDifferentPostVars) VarValues preVarVals{{"preTrace", 0.0}}; VarValues postVarValsA{{"postTrace", 0.0}}; VarValues postVarValsB{{"postTrace", 0.0}}; - auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", params, synVarVals, preVarVals, postVarValsA, {}, {}, initConnectivity(fixedProbParams)); - auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", params, synVarVals, preVarVals, postVarValsA, {}, {}, initConnectivity(fixedProbParams)); - auto *sg2 = model.addSynapsePopulation("Synapses2", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto *sg2 = model.addSynapsePopulation("Synapses2", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", params, synVarVals, preVarVals, postVarValsB, {}, {}, initConnectivity(fixedProbParams)); // Finalize model - model.finalize(); + model.finalise(); SynapseGroupInternal *sg0Internal = static_cast(sg0); SynapseGroupInternal *sg1Internal = static_cast(sg1); @@ -844,18 +861,18 @@ TEST(SynapseGroup, InitCompareWUDifferentHeterogeneousParamVarState) ParamValues fixedNumberPostParamsA{{"rowLength", 4}}; ParamValues fixedNumberPostParamsB{{"rowLength", 8}}; VarValues staticPulseVarVals{{"g", 0.1}}; - auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto *sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", {}, staticPulseVarVals, {}, {}, initConnectivity(fixedNumberPostParamsA)); - auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto *sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", {}, staticPulseVarVals, {}, {}, initConnectivity(fixedNumberPostParamsB)); // Finalize model - model.finalize(); + model.finalise(); SynapseGroupInternal *sg0Internal = static_cast(sg0); SynapseGroupInternal *sg1Internal = static_cast(sg1); @@ -894,23 +911,23 @@ 
TEST(SynapseGroup, InitCompareWUSynapseDynamicsPostLearn) ParamValues fixedNumberPostParams{{"rowLength", 8}}; VarValues staticPulseVarVals{{"g", 0.1}}; - auto* sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto* sg0 = model.addSynapsePopulation("Synapses0", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", {}, staticPulseVarVals, {}, {}, initConnectivity(fixedNumberPostParams)); - auto* sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto* sg1 = model.addSynapsePopulation("Synapses1", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", {}, staticPulseVarVals, {}, {}, initConnectivity(fixedNumberPostParams)); - auto* sg2 = model.addSynapsePopulation("Synapses2", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + auto* sg2 = model.addSynapsePopulation("Synapses2", SynapseMatrixType::SPARSE, NO_DELAY, "Neurons0", "Neurons1", {}, staticPulseVarVals, {}, {}, initConnectivity(fixedNumberPostParams)); // Finalize model - model.finalize(); + model.finalise(); SynapseGroupInternal* sg0Internal = static_cast(sg0); SynapseGroupInternal* sg1Internal = static_cast(sg1); @@ -949,10 +966,10 @@ TEST(SynapseGroup, InvalidMatrixTypes) // Check that making a synapse group with procedural connectivity fails if no connectivity initialiser is specified try { - model.addSynapsePopulation( - "NeuronsA_NeuronsB_1", SynapseMatrixType::PROCEDURAL_GLOBALG, NO_DELAY, + model.addSynapsePopulation( + "NeuronsA_NeuronsB_1", SynapseMatrixType::PROCEDURAL, NO_DELAY, "NeuronsA", "NeuronsB", - {}, {{"g", 1.0}}, + {{"g", 1.0}}, {}, {}, {}); FAIL(); } @@ -967,7 +984,7 @@ TEST(SynapseGroup, InvalidMatrixTypes) VarValues preVarVals{{"preTrace", 0.0}}; VarValues postVarVals{{"postTrace", 0.0}}; - model.addSynapsePopulation("NeuronsA_NeuronsB_2", SynapseMatrixType::PROCEDURAL_GLOBALG, NO_DELAY, + model.addSynapsePopulation("NeuronsA_NeuronsB_2", SynapseMatrixType::PROCEDURAL, NO_DELAY, "NeuronsA", "NeuronsB", params, varVals, preVarVals, postVarVals, {}, {}, @@ -980,7 +997,7 @@ TEST(SynapseGroup, InvalidMatrixTypes) // Check that making a synapse group with procedural connectivity and synapse dynamics fails try { ParamValues fixedProbParams{{"prob", 0.1}}; - model.addSynapsePopulation("NeuronsA_NeuronsB_3", SynapseMatrixType::PROCEDURAL_GLOBALG, NO_DELAY, + model.addSynapsePopulation("NeuronsA_NeuronsB_3", SynapseMatrixType::PROCEDURAL, NO_DELAY, "NeuronsA", "NeuronsB", {}, {{"g", 0.0}}, {}, {}, {}, {}, @@ -1012,31 +1029,31 @@ TEST(SynapseGroup, IsDendriticDelayRequired) model.addNeuronPopulation("Pre", 10, paramVals, varVals); model.addNeuronPopulation("Post", 10, paramVals, varVals); - VarValues staticPulseDendriticVarVals{{"g", 0.1}, {"d", 1}}; - VarValues gradedDenDelayVarVars{{"g", 0.1}}; - VarValues contDenDelayVarVars{{"g", 0.1}}; + ParamValues staticPulseDendriticParamVals{{"g", 0.1}, {"d", 1}}; + ParamValues gradedDenDelayParamVars{{"g", 0.1}}; + ParamValues contDenDelayParamVars{{"g", 0.1}}; - auto *syn = model.addSynapsePopulation( - "Syn", SynapseMatrixType::DENSE_GLOBALG, NO_DELAY, + auto *syn = model.addSynapsePopulation( + "Syn", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", - {}, staticPulseDendriticVarVals, + staticPulseDendriticParamVals, {}, {}, {}); auto *synGraded = model.addSynapsePopulation( - "SynGraded", SynapseMatrixType::DENSE_GLOBALG, NO_DELAY, + "SynGraded", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", - {}, gradedDenDelayVarVars, + gradedDenDelayParamVars, {}, {}, {}); 
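    // For contrast (a hedged sketch, not part of the original test): a weight update
    // model whose sim code only calls addToPost() with no delay argument, such as the
    // built-in StaticPulseConstantWeight used elsewhere in this suite, should leave
    // isDendriticDelayRequired() false. Assuming the same "Pre"/"Post" populations and
    // a DeltaCurr postsynaptic model, that check would look something like:
    //
    //     auto *synNoDelay = model.addSynapsePopulation<
    //         WeightUpdateModels::StaticPulseConstantWeight, PostsynapticModels::DeltaCurr>(
    //             "SynNoDelay", SynapseMatrixType::DENSE, NO_DELAY,
    //             "Pre", "Post",
    //             {{"g", 0.1}}, {},
    //             {}, {});
    //     ASSERT_FALSE(static_cast<SynapseGroupInternal *>(synNoDelay)->isDendriticDelayRequired());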
auto *synContinuous = model.addSynapsePopulation( - "SynContinuous", SynapseMatrixType::DENSE_GLOBALG, NO_DELAY, + "SynContinuous", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", - {}, contDenDelayVarVars, + contDenDelayParamVars, {}, {}, {}); - ASSERT_TRUE(syn->isDendriticDelayRequired()); - ASSERT_TRUE(synGraded->isDendriticDelayRequired()); - ASSERT_TRUE(synContinuous->isDendriticDelayRequired()); + ASSERT_TRUE(static_cast(syn)->isDendriticDelayRequired()); + ASSERT_TRUE(static_cast(synGraded)->isDendriticDelayRequired()); + ASSERT_TRUE(static_cast(synContinuous)->isDendriticDelayRequired()); } TEST(SynapseGroup, InvalidName) @@ -1048,10 +1065,10 @@ TEST(SynapseGroup, InvalidName) model.addNeuronPopulation("Pre", 10, {}, {}); model.addNeuronPopulation("Post", 10, paramVals, varVals); try { - model.addSynapsePopulation( - "Syn-6", SynapseMatrixType::DENSE_GLOBALG, NO_DELAY, + model.addSynapsePopulation( + "Syn-6", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", - {}, {{"g", 1.0}}, + {{"g", 1.0}}, {}, {}, {}); FAIL(); } @@ -1078,40 +1095,40 @@ TEST(SynapseGroup, CanWUMPreUpdateBeFused) VarValues wumPostVarVals{{"postTrace", 0.0}}; auto *constPre = model.addSynapsePopulation( - "Pre_Post_ConstPre", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Pre_Post_ConstPre", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParams, wumVarVals, wumConstPreVarVals, wumPostVarVals, {}, {}); auto *nonConstPre = model.addSynapsePopulation( - "Pre_Post_NonConstPre", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Pre_Post_NonConstPre", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParams, wumVarVals, wumNonConstPreVarVals, wumPostVarVals, {}, {}); auto *egpWMinMax = model.addSynapsePopulation( - "Pre_Post_EGPWMinMax", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Pre_Post_EGPWMinMax", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumEGPWMinMaxParams, wumVarVals, wumConstPreVarVals, wumPostVarVals, {}, {}); - auto *egpSpike = model.addSynapsePopulation( - "Pre_Post_EGPSpike", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + /*auto *egpSpike = model.addSynapsePopulation( + "Pre_Post_EGPSpike", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParams, wumVarVals, wumConstPreVarVals, wumPostVarVals, {}, {}); auto *egpDynamics = model.addSynapsePopulation( - "Pre_Post_EGPDynamics", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Pre_Post_EGPDynamics", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumEGPDynamicsParams, wumVarVals, wumConstPreVarVals, wumPostVarVals, - {}, {}); + {}, {});*/ ASSERT_TRUE(static_cast(constPre)->canWUMPreUpdateBeFused()); ASSERT_FALSE(static_cast(nonConstPre)->canWUMPreUpdateBeFused()); ASSERT_TRUE(static_cast(egpWMinMax)->canWUMPreUpdateBeFused()); - ASSERT_FALSE(static_cast(egpSpike)->canWUMPreUpdateBeFused()); - ASSERT_FALSE(static_cast(egpDynamics)->canWUMPreUpdateBeFused()); + //ASSERT_FALSE(static_cast(egpSpike)->canWUMPreUpdateBeFused()); + //ASSERT_FALSE(static_cast(egpDynamics)->canWUMPreUpdateBeFused()); } TEST(SynapseGroup, CanWUMPostUpdateBeFused) @@ -1133,31 +1150,31 @@ TEST(SynapseGroup, CanWUMPostUpdateBeFused) VarValues wumNonConstPostVarVals{{"postTrace", initVar({{"min", 0.0}, {"max", 1.0}})}}; auto *constPost = model.addSynapsePopulation( - "Pre_Post_ConstPost", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Pre_Post_ConstPost", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParams, wumVarVals, wumPreVarVals, wumConstPostVarVals, {}, {}); auto *nonConstPost = model.addSynapsePopulation( - 
"Pre_Post_NonConstPost", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Pre_Post_NonConstPost", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParams, wumVarVals, wumPreVarVals, wumNonConstPostVarVals, {}, {}); auto *egpWMinMax = model.addSynapsePopulation( - "Pre_Post_EGPWMinMax", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Pre_Post_EGPWMinMax", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumEGPWMinMaxParams, wumVarVals, wumPreVarVals, wumConstPostVarVals, {}, {}); auto *egpSpike = model.addSynapsePopulation( - "Pre_Post_EGPSpike", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Pre_Post_EGPSpike", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumParams, wumVarVals, wumPreVarVals, wumConstPostVarVals, {}, {}); auto *egpDynamics = model.addSynapsePopulation( - "Pre_Post_EGPDynamics", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Pre_Post_EGPDynamics", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", wumEGPDynamicsParams, wumVarVals, wumPreVarVals, wumConstPostVarVals, {}, {}); @@ -1178,7 +1195,7 @@ TEST(SynapseGroup, InvalidPSOutputVar) model.addNeuronPopulation("Pre", 10, {}, {}); model.addNeuronPopulation("Post", 10, paramVals, varVals); auto *prePost = model.addSynapsePopulation( - "PrePost", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, + "PrePost", SynapseMatrixType::SPARSE, NO_DELAY, "Pre", "Post", {}, {{"g", 1.0}}, {}, {}); diff --git a/tests/unit/typeChecker.cc b/tests/unit/typeChecker.cc index edd9ecdedf..859e6c5392 100644 --- a/tests/unit/typeChecker.cc +++ b/tests/unit/typeChecker.cc @@ -4,11 +4,13 @@ // GeNN includes #include "type.h" +// GeNN code generator includes +#include "code_generator/standardLibrary.h" + // GeNN transpiler includes #include "transpiler/errorHandler.h" #include "transpiler/parser.h" #include "transpiler/scanner.h" -#include "transpiler/standardLibrary.h" #include "transpiler/typeChecker.h" using namespace GeNN; @@ -93,11 +95,43 @@ class TestEnvironment : public TypeChecker::EnvironmentBase std::unordered_map m_Types; }; +class TestLibraryEnvironment : public TypeChecker::EnvironmentBase +{ +public: + explicit TestLibraryEnvironment(const CodeGenerator::EnvironmentLibrary::Library &library) + : m_Library(library) + {} + //--------------------------------------------------------------------------- + // EnvironmentBase virtuals + //--------------------------------------------------------------------------- + virtual void define(const Token&, const Type::ResolvedType&, ErrorHandlerBase&) final + { + throw TypeChecker::TypeCheckError(); + } + + virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final + { + const auto [typeBegin, typeEnd] = m_Library.get().equal_range(name.lexeme); + if (typeBegin == typeEnd) { + throw TypeChecker::TypeCheckError(); + } + else { + std::vector types; + types.reserve(std::distance(typeBegin, typeEnd)); + std::transform(typeBegin, typeEnd, std::back_inserter(types), + [](const auto &t) { return t.second.first; }); + return types; + } + } +private: + std::reference_wrapper m_Library; +}; + void typeCheckStatements(std::string_view code, TypeChecker::EnvironmentBase &typeEnvironment, const Type::TypeContext &typeContext = {}) { // Scan TestErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(code, typeContext, errorHandler); + const auto tokens = Scanner::scanSource(code, errorHandler); ASSERT_FALSE(errorHandler.hasError()); // Parse @@ -105,7 +139,7 @@ void typeCheckStatements(std::string_view code, TypeChecker::EnvironmentBase &ty 
ASSERT_FALSE(errorHandler.hasError()); // Typecheck - TypeChecker::typeCheck(statements, typeEnvironment, errorHandler); + TypeChecker::typeCheck(statements, typeEnvironment, typeContext, errorHandler); ASSERT_FALSE(errorHandler.hasError()); } @@ -113,7 +147,7 @@ Type::ResolvedType typeCheckExpression(std::string_view code, TypeChecker::Envir { // Scan TestErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(code, typeContext, errorHandler); + const auto tokens = Scanner::scanSource(code, errorHandler); EXPECT_FALSE(errorHandler.hasError()); // Parse @@ -121,9 +155,9 @@ Type::ResolvedType typeCheckExpression(std::string_view code, TypeChecker::Envir EXPECT_FALSE(errorHandler.hasError()); // Typecheck - const auto type = TypeChecker::typeCheck(expression.get(), typeEnvironment, errorHandler); + const auto resolvedTypes = TypeChecker::typeCheck(expression.get(), typeEnvironment, typeContext, errorHandler); EXPECT_FALSE(errorHandler.hasError()); - return type; + return resolvedTypes.at(expression.get()); } } // Anonymous namespace @@ -250,7 +284,7 @@ TEST(TypeChecker, Binary) TEST(TypeChecker, Call) { // Too few arguments - StandardLibrary::FunctionTypes stdLibraryEnv; + TestLibraryEnvironment stdLibraryEnv(CodeGenerator::StandardLibrary::getMathsFunctions()); EXPECT_THROW({ typeCheckExpression("sin()", stdLibraryEnv);}, TypeChecker::TypeCheckError); @@ -346,14 +380,6 @@ TEST(TypeChecker, Cast) EXPECT_EQ(*type.getPointer().valueType, Type::Int32); } - // Can't remove value const from numeric - // **THINK** why not? it's a copy - EXPECT_THROW({ - TestEnvironment typeEnvironment; - typeEnvironment.define(Type::Int32.addQualifier(Type::Qualifier::CONSTANT), "intVal"); - typeCheckExpression("(int)intVal", typeEnvironment);}, - TypeChecker::TypeCheckError); - // Can't remove value const from pointer EXPECT_THROW({ TestEnvironment typeEnvironment; diff --git a/tests/unit/weightUpdateModels.cc b/tests/unit/weightUpdateModels.cc index 6817197e40..017767972e 100644 --- a/tests/unit/weightUpdateModels.cc +++ b/tests/unit/weightUpdateModels.cc @@ -16,19 +16,19 @@ class PiecewiseSTDPCopy : public WeightUpdateModels::Base { public: SET_PARAM_NAMES({"tLrn", "tChng", "tDecay", "tPunish10", "tPunish01", - "gMax", "gMid", "gSlope", "tauShift", "gSyn0"}); + "gMax", "gMid", "gSlope", "tauShift", "gSyn0"}); SET_VARS({{"g", "scalar"}, {"gRaw", "scalar"}}); SET_SIM_CODE( - "$(addToInSyn, $(g));\n" - "scalar dt = $(sT_post) - $(t) - ($(tauShift)); \n" + "addToPost(g);\n" + "scalar dt = sT_post - t - tauShift; \n" "scalar dg = 0;\n" - "if (dt > $(lim0)) \n" - " dg = -($(off0)) ; \n" + "if (dt > lim0) \n" + " dg = -off0 ; \n" "else if (dt > 0) \n" - " dg = $(slope0) * dt + ($(off1)); \n" - "else if (dt > $(lim1)) \n" - " dg = $(slope1) * dt + ($(off1)); \n" + " dg = slope0 * dt + off1; \n" + "else if (dt > lim1) \n" + " dg = slope1 * dt + ($(off1)); \n" "else dg = - ($(off2)) ; \n" "$(gRaw) += dg; \n" "$(g)=$(gMax)/2 *(tanh($(gSlope)*($(gRaw) - ($(gMid))))+1); \n"); From 538f7b48ae03e8e6be79aa9f3862ef7ecdc00da1 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 17:05:09 +0100 Subject: [PATCH 380/725] fixed bug in ``genKernelIteration`` --- src/genn/backends/single_threaded_cpu/backend.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 8ce9acd98b..b328a09eeb 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ 
b/src/genn/backends/single_threaded_cpu/backend.cc
@@ -57,9 +57,9 @@ void genKernelIteration(EnvironmentExternalBase &env, G &g, size_t numKernelDims
 {
     // Define recursive function to generate nested kernel initialisation loops
     // **NOTE** this is a std::function as type of auto lambda couldn't be determined inside for recursive call
-    std::function<void(size_t)> generateRecursive =
-        [&handler, &env, &g, &generateRecursive, numKernelDims]
-        (size_t depth)
+    std::function<void(EnvironmentExternalBase&, size_t)> generateRecursive =
+        [&handler, &g, &generateRecursive, numKernelDims]
+        (EnvironmentExternalBase &env, size_t depth)
     {
         // Loop through this kernel dimensions
         const std::string idxVar = "k" + std::to_string(depth);
@@ -83,13 +83,13 @@ void genKernelIteration(EnvironmentExternalBase &env, G &g, size_t numKernelDims
             }
             // Otherwise, recurse
             else {
-                generateRecursive(depth + 1);
+                generateRecursive(loopEnv, depth + 1);
             }
         }
     };

     // Generate loops through kernel indices recursively
-    generateRecursive(0);
+    generateRecursive(env, 0);
 }
 }

From f9c50e2a05521e1561f4b0e63c871daffeb1a4e2 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 11 Jul 2023 17:06:08 +0100
Subject: [PATCH 381/725] generated models where required/possible in unit tests

---
 include/genn/genn/initSparseConnectivitySnippet.h |   3 +-
 include/genn/genn/neuronModels.h                  |   6 +-
 tests/unit/customConnectivityUpdate.cc            |  18 ++-
 tests/unit/customUpdate.cc                        |  57 ++++++---
 tests/unit/modelSpecMerged.cc                     |  59 +++++----
 tests/unit/models.cc                              |  30 ++---
 tests/unit/neuronGroup.cc                         |  70 +++++++++--
 tests/unit/neuronModels.cc                        |   2 +-
 tests/unit/synapseGroup.cc                        | 112 +++++++++---------
 9 files changed, 221 insertions(+), 136 deletions(-)

diff --git a/include/genn/genn/initSparseConnectivitySnippet.h b/include/genn/genn/initSparseConnectivitySnippet.h
index 5ed77b014f..fb04af8d0e 100644
--- a/include/genn/genn/initSparseConnectivitySnippet.h
+++ b/include/genn/genn/initSparseConnectivitySnippet.h
@@ -109,8 +109,7 @@ class OneToOne : public Base
     DECLARE_SNIPPET(InitSparseConnectivitySnippet::OneToOne);

     SET_ROW_BUILD_CODE(
-        "$(addSynapse, $(id_pre));\n"
-        "$(endRow);\n");
+        "addSynapse(id_pre);\n");
     SET_MAX_ROW_LENGTH(1);
     SET_MAX_COL_LENGTH(1);
diff --git a/include/genn/genn/neuronModels.h b/include/genn/genn/neuronModels.h
index e2d969fef4..7994045b1f 100644
--- a/include/genn/genn/neuronModels.h
+++ b/include/genn/genn/neuronModels.h
@@ -165,9 +165,9 @@ class Izhikevich : public Base
         "    $(V)=$(c);\n"
         "    $(U)+=$(d);\n"
         "} \n"
-        "$(V)+=0.5*(0.04*$(V)*$(V)+5.0*$(V)+140.0-$(U)+$(Isyn))*DT; //at two times for numerical stability\n"
-        "$(V)+=0.5*(0.04*$(V)*$(V)+5.0*$(V)+140.0-$(U)+$(Isyn))*DT;\n"
-        "$(U)+=$(a)*($(b)*$(V)-$(U))*DT;\n"
+        "$(V)+=0.5*(0.04*$(V)*$(V)+5.0*$(V)+140.0-$(U)+$(Isyn))*dt; //at two times for numerical stability\n"
+        "$(V)+=0.5*(0.04*$(V)*$(V)+5.0*$(V)+140.0-$(U)+$(Isyn))*dt;\n"
+        "$(U)+=$(a)*($(b)*$(V)-$(U))*dt;\n"
         "if ($(V) > 30.0){ //keep this to not confuse users with unrealistiv voltage values \n"
         "    $(V)=30.0; \n"
         "}\n");
diff --git a/tests/unit/customConnectivityUpdate.cc b/tests/unit/customConnectivityUpdate.cc
index 3df14d5a5a..1b28ccd552 100644
--- a/tests/unit/customConnectivityUpdate.cc
+++ b/tests/unit/customConnectivityUpdate.cc
@@ -1,3 +1,7 @@
+// Standard C++ includes
+#include <filesystem>
+#undef DUPLICATE
+
 // Google test includes
 #include "gtest/gtest.h"

 // GeNN includes
 #include "modelSpecInternal.h"

 // GeNN code generator includes
+#include "code_generator/generateModules.h"
 #include "code_generator/modelSpecMerged.h"

 // (Single-threaded CPU) backend includes
@@ -21,7 +26,7 @@
class StaticPulseDendriticDelayReverse : public WeightUpdateModels::Base SET_VARS({{"d", "uint8_t", VarAccess::READ_ONLY}, {"g", "scalar", VarAccess::READ_ONLY}}); - SET_SIM_CODE("addToInSynDelay(g, d);\n"); + SET_SIM_CODE("addToPostDelay(g, d);\n"); }; IMPLEMENT_SNIPPET(StaticPulseDendriticDelayReverse); @@ -111,7 +116,7 @@ class Cont : public WeightUpdateModels::Base SET_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( - "addToInSyn(g * V_pre);\n"); + "addToPost(g * V_pre);\n"); }; IMPLEMENT_SNIPPET(Cont); @@ -123,7 +128,7 @@ class ContPost : public WeightUpdateModels::Base SET_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( - "addToInSyn(g * V_post);\n"); + "addToPost(g * V_post);\n"); }; IMPLEMENT_SNIPPET(ContPost); @@ -359,6 +364,13 @@ TEST(CustomConnectivityUpdate, CompareDifferentDependentVars) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateCustomUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + + // Check correct groups are merged ASSERT_EQ(modelSpecMerged.getMergedCustomConnectivityUpdateGroups().size(), 2); ASSERT_EQ(modelSpecMerged.getMergedCustomConnectivityUpdatePreInitGroups().size(), 0); diff --git a/tests/unit/customUpdate.cc b/tests/unit/customUpdate.cc index fff2fd04af..15719b42d2 100644 --- a/tests/unit/customUpdate.cc +++ b/tests/unit/customUpdate.cc @@ -1,3 +1,7 @@ +// Standard C++ includes +#include +#undef DUPLICATE + // Google test includes #include "gtest/gtest.h" @@ -5,6 +9,7 @@ #include "modelSpecInternal.h" // GeNN code generator includes +#include "code_generator/generateModules.h" #include "code_generator/modelSpecMerged.h" // (Single-threaded CPU) backend includes @@ -87,7 +92,7 @@ class Cont : public WeightUpdateModels::Base SET_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( - "addToInSyn(g * V_pre);\n"); + "addToPost(g * V_pre);\n"); }; IMPLEMENT_SNIPPET(Cont); @@ -99,7 +104,7 @@ class Cont2 : public WeightUpdateModels::Base SET_VARS({{"g", "scalar"}, {"x", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( - "addToInSyn((g + x) * V_pre);\n"); + "addToPost((g + x) * V_pre);\n"); }; IMPLEMENT_SNIPPET(Cont2); @@ -672,6 +677,12 @@ TEST(CustomUpdates, CompareDifferentModel) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateCustomUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check correct groups are merged ASSERT_TRUE(modelSpecMerged.getMergedCustomUpdateGroups().size() == 2); ASSERT_TRUE(modelSpecMerged.getMergedCustomUpdateInitGroups().size() == 2); @@ -712,6 +723,12 @@ TEST(CustomUpdates, CompareDifferentUpdateGroup) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateCustomUpdate(outputPath, modelSpecMerged, backend, 
CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check correct groups are merged // **NOTE** update groups don't matter for initialization ASSERT_TRUE(modelSpecMerged.getMergedCustomUpdateGroups().size() == 2); @@ -787,6 +804,12 @@ TEST(CustomUpdates, CompareDifferentDelay) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateCustomUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check correct groups are merged // **NOTE** delay groups don't matter for initialization ASSERT_TRUE(modelSpecMerged.getMergedCustomUpdateGroups().size() == 3); @@ -830,18 +853,6 @@ TEST(CustomUpdates, CompareDifferentBatched) // Check that initialisation of batched and mixed can be merged but not update ASSERT_EQ(sum1Internal->getInitHashDigest(), sum3Internal->getInitHashDigest()); ASSERT_NE(sum1Internal->getHashDigest(), sum3Internal->getHashDigest()); - - // Create a backend - CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend backend(preferences); - - // Merge model - CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); - - // Check correct groups are merged - // **NOTE** delay groups don't matter for initialization - ASSERT_TRUE(modelSpecMerged.getMergedCustomUpdateGroups().size() == 3); - ASSERT_TRUE(modelSpecMerged.getMergedCustomUpdateInitGroups().size() == 2); } //-------------------------------------------------------------------------- TEST(CustomUpdates, CompareDifferentWUTranspose) @@ -887,6 +898,12 @@ TEST(CustomUpdates, CompareDifferentWUTranspose) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateCustomUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check correct groups are merged // **NOTE** transpose variables don't matter for initialization ASSERT_TRUE(modelSpecMerged.getMergedCustomUpdateTransposeWUGroups().size() == 2); @@ -939,6 +956,12 @@ TEST(CustomUpdates, CompareDifferentWUConnectivity) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateCustomUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check correct groups are merged ASSERT_TRUE(modelSpecMerged.getMergedCustomUpdateTransposeWUGroups().empty()); ASSERT_TRUE(modelSpecMerged.getMergedCustomUpdateWUGroups().size() == 2); @@ -1000,6 +1023,12 @@ TEST(CustomUpdates, CompareDifferentWUBatched) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in 
terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateCustomUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check correct groups are merged // **NOTE** delay groups don't matter for initialization ASSERT_TRUE(modelSpecMerged.getMergedCustomUpdateWUGroups().size() == 3); diff --git a/tests/unit/modelSpecMerged.cc b/tests/unit/modelSpecMerged.cc index 7e78558ed6..411669a53a 100644 --- a/tests/unit/modelSpecMerged.cc +++ b/tests/unit/modelSpecMerged.cc @@ -1,12 +1,16 @@ // Standard C++ includes #include +#include #include #include +#undef DUPLICATE + // Google test includes #include "gtest/gtest.h" // GeNN code generator includes +#include "code_generator/generateModules.h" #include "code_generator/modelSpecMerged.h" // (Single-threaded CPU) backend includes @@ -25,10 +29,10 @@ class AlphaCurr : public PostsynapticModels::Base DECLARE_SNIPPET(AlphaCurr); SET_DECAY_CODE( - "$(x) = (DT * $(expDecay) * $(inSyn) * $(init)) + ($(expDecay) * $(x));\n" - "$(inSyn)*=$(expDecay);\n"); + "x = (dt * expDecay * inSyn * init) + (expDecay * x);\n" + "inSyn *= expDecay;\n"); - SET_CURRENT_CONVERTER_CODE("$(x)"); + SET_CURRENT_CONVERTER_CODE("x"); SET_PARAM_NAMES({"tau"}); @@ -90,7 +94,7 @@ class Sum : public CustomUpdateModels::Base { DECLARE_SNIPPET(Sum); - SET_UPDATE_CODE("$(sum) = $(a) + $(b);\n"); + SET_UPDATE_CODE("sum = a + b;\n"); SET_VARS({{"sum", "scalar"}}); SET_PARAM_NAMES({"b"}); @@ -103,9 +107,7 @@ class OneToOneOff : public InitSparseConnectivitySnippet::Base public: DECLARE_SNIPPET(OneToOneOff); - SET_ROW_BUILD_CODE( - "$(addSynapse, $(id_pre) + 1);\n" - "$(endRow);\n"); + SET_ROW_BUILD_CODE("addSynapse(id_pre + 1);\n"); SET_MAX_ROW_LENGTH(1); SET_MAX_COL_LENGTH(1); @@ -118,13 +120,12 @@ class RemoveSynapse : public CustomConnectivityUpdateModels::Base DECLARE_SNIPPET(RemoveSynapse); SET_ROW_UPDATE_CODE( - "$(for_each_synapse,\n" - "{\n" - " if($(id_post) == ($(id_pre) + 1)) {\n" - " $(remove_synapse);\n" + "for_each_synapse {\n" + " if(id_post == (id_pre + 1)) {\n" + " remove_synapse();\n" " break;\n" " }\n" - "});\n"); + "};\n"); }; IMPLEMENT_SNIPPET(RemoveSynapse); @@ -137,13 +138,12 @@ class RemoveSynapsePrePost : public CustomConnectivityUpdateModels::Base SET_PRE_VARS({{"preThresh", "scalar"}}); SET_POST_VARS({{"postThresh", "scalar"}}); SET_ROW_UPDATE_CODE( - "$(for_each_synapse,\n" - "{\n" - " if($(g) < $(preThresh) || $(g) < $(postThresh)) {\n" - " $(remove_synapse);\n" + "for_each_synapse {\n" + " if(g < preThresh || g < postThresh) {\n" + " remove_synapse();\n" " break;\n" " }\n" - "});\n"); + "};\n"); }; IMPLEMENT_SNIPPET(RemoveSynapsePrePost); @@ -156,13 +156,12 @@ class RemoveSynapseParam : public CustomConnectivityUpdateModels::Base SET_PARAM_NAMES({"thresh"}); SET_ROW_UPDATE_CODE( - "$(for_each_synapse,\n" - "{\n" - " if($(g) < $(thresh)) {\n" - " $(remove_synapse);\n" + "for_each_synapse {\n" + " if(g < thresh) {\n" + " remove_synapse();\n" " break;\n" " }\n" - "});\n"); + "};\n"); }; IMPLEMENT_SNIPPET(RemoveSynapseParam); @@ -195,11 +194,20 @@ void test(const std::pair (&modelModifiers)[N], M applyModifierFn) // Create suitable backend to build model CodeGenerator::SingleThreadedCPU::Backend backend(preferences); - // Merge model - CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Created merged model + CodeGenerator::ModelSpecMerged 
modelMerged(model, backend); + + // Generate modules + // **NOTE** these are ordered in terms of memory-space priority + auto memorySpaces = backend.getMergedGroupMemorySpaces(modelMerged); + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateSynapseUpdate(outputPath, modelMerged, backend, memorySpaces); + generateNeuronUpdate(outputPath, modelMerged, backend, memorySpaces); + generateCustomUpdate(outputPath, modelMerged, backend, memorySpaces); + generateInit(outputPath, modelMerged, backend, memorySpaces); // Write hash digests of model to array - moduleHash[i] = modelSpecMerged.getHashDigest(backend); + moduleHash[i] = modelMerged.getHashDigest(backend); } // Loop through modified models @@ -320,7 +328,6 @@ TEST(ModelSpecMerged, CompareModelChanges) {[](ModelSpecInternal &model) { model.setTiming(true); }, false}, {[](ModelSpecInternal &model) { model.setPrecision(Type::Double); }, false}, {[](ModelSpecInternal &model) { model.setTimePrecision(Type::Double); }, false}, - {[](ModelSpecInternal &model) { model.setBatchSize(10); }, false}, {[](ModelSpecInternal &model) { model.setSeed(1234); }, false}}; test(modelModifiers, diff --git a/tests/unit/models.cc b/tests/unit/models.cc index b5af0cdb2a..18335c61a2 100644 --- a/tests/unit/models.cc +++ b/tests/unit/models.cc @@ -17,10 +17,10 @@ class AlphaCurr : public PostsynapticModels::Base DECLARE_SNIPPET(AlphaCurr); SET_DECAY_CODE( - "$(x) = (DT * $(expDecay) * $(inSyn) * $(init)) + ($(expDecay) * $(x));\n" - "$(inSyn)*=$(expDecay);\n"); + "x = (dt * expDecay * inSyn * init) + (expDecay * x);\n" + "inSyn *= expDecay;\n"); - SET_CURRENT_CONVERTER_CODE("$(x)"); + SET_CURRENT_CONVERTER_CODE("x"); SET_PARAM_NAMES({"tau"}); @@ -39,7 +39,7 @@ class StaticPulseUInt : public WeightUpdateModels::Base SET_PARAM_NAMES({"g"}); - SET_SIM_CODE("$(addToInSyn, $(g));\n"); + SET_SIM_CODE("addToPost(g);\n"); }; IMPLEMENT_SNIPPET(StaticPulseUInt); @@ -51,7 +51,7 @@ class Cont : public WeightUpdateModels::Base SET_VARS({{"g", "scalar"}}); SET_SYNAPSE_DYNAMICS_CODE( - "$(addToInSyn, $(g) * $(V_pre));\n"); + "addToPost(g * V_pre);\n"); }; IMPLEMENT_SNIPPET(Cont); @@ -65,15 +65,15 @@ class ContPrePost : public WeightUpdateModels::Base SET_POST_VARS({{"postTrace", "scalar"}}); SET_PRE_SPIKE_CODE( - "scalar dt = $(t) - $(sT_pre);\n" - "$(preTrace) = ($(preTrace) * exp(-dt / $(tauPlus))) + 1.0;\n"); + "scalar dt = t - sT_pre;\n" + "preTrace = (preTrace * exp(-dt / tauPlus)) + 1.0;\n"); SET_POST_SPIKE_CODE( - "scalar dt = $(t) - $(sT_post);\n" - "$(postTrace) = ($(postTrace) * exp(-dt / $(tauMinus))) + 1.0;\n"); + "scalar dt = t - sT_post;\n" + "postTrace = (postTrace * exp(-dt / tauMinus)) + 1.0;\n"); SET_SYNAPSE_DYNAMICS_CODE( - "$(addToInSyn, $(g) * $(V_pre));\n"); + "addToPost(g * V_pre);\n"); }; IMPLEMENT_SNIPPET(ContPrePost); @@ -87,15 +87,15 @@ class ContPrePostConstantWeight : public WeightUpdateModels::Base SET_POST_VARS({{"postTrace", "scalar"}}); SET_PRE_SPIKE_CODE( - "scalar dt = $(t) - $(sT_pre);\n" - "$(preTrace) = ($(preTrace) * exp(-dt / $(tauPlus))) + 1.0;\n"); + "scalar dt = t - sT_pre;\n" + "preTrace = (preTrace * exp(-dt / tauPlus)) + 1.0;\n"); SET_POST_SPIKE_CODE( - "scalar dt = $(t) - $(sT_post);\n" - "$(postTrace) = ($(postTrace) * exp(-dt / $(tauMinus))) + 1.0;\n"); + "scalar dt = t - sT_post;\n" + "postTrace = (postTrace * exp(-dt / tauMinus)) + 1.0;\n"); SET_SYNAPSE_DYNAMICS_CODE( - "$(addToInSyn, $(g) * $(V_pre));\n"); + "addToPost(g * V_pre);\n"); }; IMPLEMENT_SNIPPET(ContPrePostConstantWeight); } diff 
--git a/tests/unit/neuronGroup.cc b/tests/unit/neuronGroup.cc index ba834866ea..c1356edd3b 100644 --- a/tests/unit/neuronGroup.cc +++ b/tests/unit/neuronGroup.cc @@ -1,3 +1,7 @@ +// Standard C++ includes +#include +#undef DUPLICATE + // Google test includes #include "gtest/gtest.h" @@ -5,6 +9,7 @@ #include "modelSpecInternal.h" // GeNN code generator includes +#include "code_generator/generateModules.h" #include "code_generator/modelSpecMerged.h" // (Single-threaded CPU) backend includes @@ -72,7 +77,7 @@ class AlphaCurr : public PostsynapticModels::Base DECLARE_SNIPPET(AlphaCurr); SET_DECAY_CODE( - "$(x) = (DT * $(expDecay) * $(inSyn) * $(init)) + ($(expDecay) * $(x));\n" + "$(x) = (dt * $(expDecay) * $(inSyn) * $(init)) + ($(expDecay) * $(x));\n" "$(inSyn)*=$(expDecay);\n"); SET_CURRENT_CONVERTER_CODE("$(x)"); @@ -99,7 +104,7 @@ class LIFAdditional : public NeuronModels::Base " $(V) = alpha - ($(ExpTC) * (alpha - $(V)));\n" "}\n" "else {\n" - " $(RefracTime) -= DT;\n" + " $(RefracTime) -= dt;\n" "}\n" ); @@ -586,28 +591,34 @@ TEST(NeuronGroup, CompareNeuronModels) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateNeuronUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check all groups are merged ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().size() == 1); ASSERT_TRUE(modelSpecMerged.getMergedNeuronInitGroups().size() == 2); - // Find which merged neuron init group is the one with the single population i.e. the one with constant initialisers - const size_t constantInitIndex = (modelSpecMerged.getMergedNeuronInitGroups().at(0).getGroups().size() == 1) ? 0 : 1; - const auto &constantInitMergedGroup = modelSpecMerged.getMergedNeuronInitGroups().at(constantInitIndex); - const auto &uniformInitMergedGroup = modelSpecMerged.getMergedNeuronInitGroups().at(1 - constantInitIndex); - // Check that only 'd' parameter is heterogeneous in neuron update group ASSERT_FALSE(modelSpecMerged.getMergedNeuronUpdateGroups().at(0).isParamHeterogeneous("a")); ASSERT_FALSE(modelSpecMerged.getMergedNeuronUpdateGroups().at(0).isParamHeterogeneous("b")); ASSERT_FALSE(modelSpecMerged.getMergedNeuronUpdateGroups().at(0).isParamHeterogeneous("c")); ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().at(0).isParamHeterogeneous("d")); + // Find which merged neuron init group is the one with the single population i.e. the one with constant initialisers + const size_t constantInitIndex = (modelSpecMerged.getMergedNeuronInitGroups().at(0).getGroups().size() == 1) ? 
0 : 1; + const auto &constantInitMergedGroup = modelSpecMerged.getMergedNeuronInitGroups().at(constantInitIndex); + const auto &uniformInitMergedGroup = modelSpecMerged.getMergedNeuronInitGroups().at(1 - constantInitIndex); + // Check that only 'V' init 'min' parameter is heterogeneous ASSERT_FALSE(constantInitMergedGroup.isVarInitParamHeterogeneous("V", "constant")); ASSERT_FALSE(constantInitMergedGroup.isVarInitParamHeterogeneous("U", "constant")); ASSERT_TRUE(uniformInitMergedGroup.isVarInitParamHeterogeneous("V", "min")); ASSERT_FALSE(uniformInitMergedGroup.isVarInitParamHeterogeneous("V", "max")); - ASSERT_FALSE(uniformInitMergedGroup.isVarInitParamHeterogeneous("U", "min")); - ASSERT_FALSE(uniformInitMergedGroup.isVarInitParamHeterogeneous("U", "max")); + ASSERT_FALSE(uniformInitMergedGroup.isVarInitParamHeterogeneous("U", "constant")); + ASSERT_FALSE(uniformInitMergedGroup.isVarInitParamHeterogeneous("U", "constant")); } TEST(NeuronGroup, CompareHeterogeneousParamVarState) @@ -636,6 +647,12 @@ TEST(NeuronGroup, CompareHeterogeneousParamVarState) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateNeuronUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check all groups are merged ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().size() == 1); ASSERT_TRUE(modelSpecMerged.getMergedNeuronInitGroups().size() == 1); @@ -735,6 +752,11 @@ TEST(NeuronGroup, CompareCurrentSources) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateNeuronUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check neurons are merged into two groups ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().size() == 2); @@ -748,11 +770,10 @@ TEST(NeuronGroup, CompareCurrentSources) const size_t poissonIndex = (dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(0).getArchetype().getCurrentSourceModel() == CurrentSourceModels::PoissonExp::getInstance()) ? 
0 : 1; // Check that only the ExpDecay and Init derived parameters of the poisson exp current sources are heterogeneous - // **NOTE** tauSyn is not heterogeneous because it's not referenced directly ASSERT_FALSE(dcDCMergedGroup.getMergedCurrentSourceGroups().at(0).isParamHeterogeneous("amp")); ASSERT_FALSE(dcDCMergedGroup.getMergedCurrentSourceGroups().at(1).isParamHeterogeneous("amp")); ASSERT_FALSE(dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(poissonIndex).isParamHeterogeneous("weight")); - ASSERT_FALSE(dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(poissonIndex).isParamHeterogeneous("tauSyn")); + ASSERT_TRUE(dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(poissonIndex).isParamHeterogeneous("tauSyn")); ASSERT_FALSE(dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(poissonIndex).isParamHeterogeneous("rate")); ASSERT_FALSE(dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(1 - poissonIndex).isParamHeterogeneous("amp")); ASSERT_TRUE(dcPoissonMergedGroup.getMergedCurrentSourceGroups().at(poissonIndex).isDerivedParamHeterogeneous("ExpDecay")); @@ -862,6 +883,12 @@ TEST(NeuronGroup, ComparePostsynapticModels) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateNeuronUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check neurons are merged into three groups ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().size() == 3); ASSERT_TRUE(modelSpecMerged.getMergedNeuronInitGroups().size() == 3); @@ -879,8 +906,7 @@ TEST(NeuronGroup, ComparePostsynapticModels) const size_t alphaInitIndex = (deltaAlphaMergedInitGroup->getMergedInSynPSMGroups().at(0).getArchetype().getPSModel() == AlphaCurr::getInstance()) ? 
0 : 1; // Check that parameter and both derived parameters are heterogeneous - // **NOTE** tau is NOT heterogeneous because it's unused - ASSERT_FALSE(deltaAlphaMergedUpdateGroup->getMergedInSynPSMGroups().at(alphaUpdateIndex).isParamHeterogeneous("tau")); + ASSERT_TRUE(deltaAlphaMergedUpdateGroup->getMergedInSynPSMGroups().at(alphaUpdateIndex).isParamHeterogeneous("tau")); ASSERT_TRUE(deltaAlphaMergedUpdateGroup->getMergedInSynPSMGroups().at(alphaUpdateIndex).isDerivedParamHeterogeneous("expDecay")); ASSERT_TRUE(deltaAlphaMergedUpdateGroup->getMergedInSynPSMGroups().at(alphaUpdateIndex).isDerivedParamHeterogeneous("init")); ASSERT_TRUE(deltaAlphaMergedInitGroup->getMergedInSynPSMGroups().at(alphaInitIndex).isVarInitParamHeterogeneous("x", "constant")); @@ -972,6 +998,12 @@ TEST(NeuronGroup, ComparePreOutput) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateNeuronUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check neurons are merged into six groups (one for each output group and one for each number of incoming synapses) ASSERT_EQ(modelSpecMerged.getMergedNeuronUpdateGroups().size(), 6); ASSERT_EQ(modelSpecMerged.getMergedNeuronInitGroups().size(), 6); @@ -1061,6 +1093,12 @@ TEST(NeuronGroup, CompareWUPreUpdate) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateNeuronUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check neuron init and update is merged into three groups (NG0 with no outsyns, NG1, NG2, NG3 and NG5 with 1 outsyn and NG4with 2 outsyns) ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().size() == 3); ASSERT_TRUE(modelSpecMerged.getMergedNeuronInitGroups().size() == 3); @@ -1162,6 +1200,12 @@ TEST(NeuronGroup, CompareWUPostUpdate) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateNeuronUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check neurons are merged into three groups ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().size() == 3); ASSERT_TRUE(modelSpecMerged.getMergedNeuronInitGroups().size() == 3); diff --git a/tests/unit/neuronModels.cc b/tests/unit/neuronModels.cc index af6912cda5..722752882f 100644 --- a/tests/unit/neuronModels.cc +++ b/tests/unit/neuronModels.cc @@ -19,7 +19,7 @@ class LIFCopy : public NeuronModels::Base " $(V) = alpha - ($(ExpTC) * (alpha - $(V)));\n" "}\n" "else {\n" - " $(RefracTime) -= DT;\n" + " $(RefracTime) -= dt;\n" "}\n" ); diff --git a/tests/unit/synapseGroup.cc b/tests/unit/synapseGroup.cc index 0e674eacc6..b139200e40 100644 --- a/tests/unit/synapseGroup.cc +++ 
b/tests/unit/synapseGroup.cc @@ -1,3 +1,7 @@ +// Standard C++ includes +#include +#undef DUPLICATE + // Google test includes #include "gtest/gtest.h" @@ -5,6 +9,7 @@ #include "modelSpecInternal.h" // GeNN code generator includes +#include "code_generator/generateModules.h" #include "code_generator/modelSpecMerged.h" // (Single-threaded CPU) backend includes @@ -188,7 +193,7 @@ class GradedDenDelay : public WeightUpdateModels::Base SET_PARAM_NAMES({"g"}); SET_EVENT_THRESHOLD_CONDITION_CODE("V_pre >= 0.1"); - SET_EVENT_CODE("addToInSynDelay(g * V_pre, 1);"); + SET_EVENT_CODE("addToPostDelay(g * V_pre, 1);"); }; IMPLEMENT_SNIPPET(GradedDenDelay); @@ -208,9 +213,9 @@ class StaticPulseDynamics : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulseDynamics); - SET_VARS({ {"g", "scalar", VarAccess::READ_ONLY} }); + SET_VARS({ {"g", "scalar"} }); - SET_SIM_CODE("addToInSyn(g);\n"); + SET_SIM_CODE("addToPost(g);\n"); SET_SYNAPSE_DYNAMICS_CODE("g *= 0.99;\n"); }; IMPLEMENT_SNIPPET(StaticPulseDynamics); @@ -220,9 +225,9 @@ class StaticPulsePostLearn : public WeightUpdateModels::Base public: DECLARE_SNIPPET(StaticPulsePostLearn); - SET_VARS({ {"g", "scalar", VarAccess::READ_ONLY} }); + SET_VARS({ {"g", "scalar"} }); - SET_SIM_CODE("addToInSyn(g);\n"); + SET_SIM_CODE("addToPost(g);\n"); SET_LEARN_POST_CODE("g *= 0.99;\n"); }; IMPLEMENT_SNIPPET(StaticPulsePostLearn); @@ -389,6 +394,13 @@ TEST(SynapseGroup, CompareWUDifferentModel) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateNeuronUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateSynapseUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check all groups are merged ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().size() == 2); ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().size() == 2); @@ -399,7 +411,7 @@ TEST(SynapseGroup, CompareWUDifferentModel) } -TEST(SynapseGroup, CompareWUDifferentGlobalG) +TEST(SynapseGroup, CompareWUDifferentParams) { ModelSpecInternal model; @@ -442,6 +454,13 @@ TEST(SynapseGroup, CompareWUDifferentGlobalG) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateNeuronUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateSynapseUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check all groups are merged ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().size() == 2); ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().size() == 1); @@ -492,25 +511,6 @@ TEST(SynapseGroup, CompareWUDifferentProceduralConnectivity) SynapseGroupInternal *sg2Internal = static_cast(sg2); ASSERT_EQ(sg0Internal->getWUHashDigest(), sg1Internal->getWUHashDigest()); ASSERT_EQ(sg0Internal->getWUHashDigest(), sg2Internal->getWUHashDigest()); - - // Create a backend - 
CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend backend(preferences); - - // Merge model - CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); - - // Check all groups are merged - ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().size() == 2); - ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().size() == 1); - ASSERT_TRUE(modelSpecMerged.getMergedSynapseConnectivityInitGroups().empty()); - ASSERT_TRUE(modelSpecMerged.getMergedSynapseInitGroups().empty()); - ASSERT_TRUE(modelSpecMerged.getMergedSynapseSparseInitGroups().empty()); - - // Check that connectivity parameter is heterogeneous - // **NOTE** raw parameter is NOT as only derived parameter is used in code - ASSERT_FALSE(modelSpecMerged.getMergedPresynapticUpdateGroups().at(0).isSparseConnectivityInitParamHeterogeneous("prob")); - ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().at(0).isSparseConnectivityInitDerivedParamHeterogeneous("probLogRecip")); } @@ -566,6 +566,13 @@ TEST(SynapseGroup, CompareWUDifferentToeplitzConnectivity) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateNeuronUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateSynapseUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check all groups are merged ASSERT_EQ(modelSpecMerged.getMergedNeuronUpdateGroups().size(), 3); ASSERT_EQ(modelSpecMerged.getMergedPresynapticUpdateGroups().size(), 1); @@ -624,26 +631,6 @@ TEST(SynapseGroup, CompareWUDifferentProceduralVars) SynapseGroupInternal *sg2Internal = static_cast(sg2); ASSERT_EQ(sg0Internal->getWUHashDigest(), sg1Internal->getWUHashDigest()); ASSERT_EQ(sg0Internal->getWUHashDigest(), sg2Internal->getWUHashDigest()); - - // Create a backend - CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend backend(preferences); - - // Merge model - CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); - - // Check all groups are merged - ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().size() == 2); - ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().size() == 1); - ASSERT_TRUE(modelSpecMerged.getMergedSynapseConnectivityInitGroups().empty()); - ASSERT_TRUE(modelSpecMerged.getMergedSynapseInitGroups().empty()); - ASSERT_TRUE(modelSpecMerged.getMergedSynapseSparseInitGroups().empty()); - - // Check that only synaptic weight initialistion parameters are heterogeneous - ASSERT_FALSE(modelSpecMerged.getMergedPresynapticUpdateGroups().at(0).isSparseConnectivityInitParamHeterogeneous("prob")); - ASSERT_FALSE(modelSpecMerged.getMergedPresynapticUpdateGroups().at(0).isSparseConnectivityInitDerivedParamHeterogeneous("probLogRecip")); - ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().at(0).isVarInitParamHeterogeneous("g", "min")); - ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().at(0).isVarInitParamHeterogeneous("g", "max")); } TEST(SynapseGroup, CompareWUDifferentProceduralSnippet) @@ -678,20 +665,6 @@ TEST(SynapseGroup, CompareWUDifferentProceduralSnippet) SynapseGroupInternal *sg2Internal = static_cast(sg2); 
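// ---------------------------------------------------------------------------
// Editorial sketch (not part of the patch): the getWUHashDigest() comparisons
// below pin down the merging rule that two synapse groups can share generated
// code when a digest over everything affecting code generation (model code,
// connectivity snippet) matches, while plain value differences are ignored.
// A toy, hypothetical illustration of that idea, runnable stand-alone:
//
//     #include <cassert>
//     #include <functional>
//     #include <string>
//
//     struct ToyGroup
//     {
//         std::string simCode;       // affects generated code, so hashed
//         std::string connectivity;  // affects generated code, so hashed
//         double gValue;             // runtime value only, deliberately NOT hashed
//     };
//
//     std::size_t codegenDigest(const ToyGroup &g)
//     {
//         // Simple hash combine; the real code uses a proper SHA-1 digest
//         std::size_t seed = std::hash<std::string>{}(g.simCode);
//         seed ^= std::hash<std::string>{}(g.connectivity) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
//         return seed;
//     }
//
//     int main()
//     {
//         const ToyGroup a{"addToPost(g);", "FixedProbability", 0.1};
//         const ToyGroup b{"addToPost(g);", "FixedProbability", 0.7};
//         const ToyGroup c{"addToPost(2.0 * g);", "FixedProbability", 0.1};
//         assert(codegenDigest(a) == codegenDigest(b)); // only a value differs: mergeable
//         assert(codegenDigest(a) != codegenDigest(c)); // different code: separate groups
//     }
// ---------------------------------------------------------------------------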
ASSERT_EQ(sg0Internal->getWUHashDigest(), sg1Internal->getWUHashDigest()); ASSERT_NE(sg0Internal->getWUHashDigest(), sg2Internal->getWUHashDigest()); - - // Create a backend - CodeGenerator::SingleThreadedCPU::Preferences preferences; - CodeGenerator::SingleThreadedCPU::Backend backend(preferences); - - // Merge model - CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); - - // Check all groups are merged - ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().size() == 2); - ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().size() == 2); - ASSERT_TRUE(modelSpecMerged.getMergedSynapseConnectivityInitGroups().empty()); - ASSERT_TRUE(modelSpecMerged.getMergedSynapseInitGroups().empty()); - ASSERT_TRUE(modelSpecMerged.getMergedSynapseSparseInitGroups().empty()); } TEST(SynapseGroup, InitCompareWUDifferentVars) @@ -745,6 +718,13 @@ TEST(SynapseGroup, InitCompareWUDifferentVars) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateNeuronUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateSynapseUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check all groups are merged ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().size() == 2); ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().size() == 1); @@ -886,6 +866,13 @@ TEST(SynapseGroup, InitCompareWUDifferentHeterogeneousParamVarState) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateNeuronUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateSynapseUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check all groups are merged ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().size() == 2); ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().size() == 1); @@ -946,6 +933,13 @@ TEST(SynapseGroup, InitCompareWUSynapseDynamicsPostLearn) // Merge model CodeGenerator::ModelSpecMerged modelSpecMerged(model, backend); + // Generate required modules + // **NOTE** these are ordered in terms of memory-space priority + const filesystem::path outputPath = std::filesystem::temp_directory_path(); + generateNeuronUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateSynapseUpdate(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + generateInit(outputPath, modelSpecMerged, backend, CodeGenerator::BackendBase::MemorySpaces{}); + // Check all groups are merged ASSERT_TRUE(modelSpecMerged.getMergedNeuronUpdateGroups().size() == 2); ASSERT_TRUE(modelSpecMerged.getMergedPresynapticUpdateGroups().size() == 3); From 9283d919e3dba7107d4b25d94f6a9227336ec7ca Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 20 Feb 2023 17:57:37 +0000 Subject: [PATCH 382/725] mirror logic for fusing wum pre and postsynaptic updates for PSM --- 
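[Editor's note, not part of the original patch] The change below replaces the blanket
"no variables, no extra global parameters" test with two finer-grained checks: every
postsynaptic variable initialiser must be a compile-time constant, and no extra global
parameter may appear in the model's code strings. A stand-alone sketch of that rule
using toy types (ToyPSM and friends are illustrative names only, not the GeNN API):

#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <vector>

struct ToyPSM
{
    // nullopt stands in for a non-constant (e.g. random) variable initialiser
    std::map<std::string, std::optional<double>> varInitialisers;
    std::vector<std::string> egpNames;
    std::string decayCode;
    std::string applyInputCode;
};

bool canPSBeFused(const ToyPSM &psm)
{
    // Every postsynaptic variable must be initialised to a compile-time constant
    for(const auto &var : psm.varInitialisers) {
        if(!var.second) {
            return false;
        }
    }

    // No extra global parameter may be referenced in either code string
    // (a raw string search stands in for the token-based check adopted later in this series)
    for(const auto &egp : psm.egpNames) {
        if(psm.decayCode.find(egp) != std::string::npos
           || psm.applyInputCode.find(egp) != std::string::npos)
        {
            return false;
        }
    }
    return true;
}

int main()
{
    const ToyPSM constantInit{{{"x", 0.0}}, {}, "x *= expDecay;", "x"};
    const ToyPSM randomInit{{{"x", std::nullopt}}, {}, "x *= expDecay;", "x"};
    const ToyPSM egpInCode{{{"x", 0.0}}, {"tauSyn"}, "x *= exp(-dt / tauSyn);", "x"};

    assert(canPSBeFused(constantInit));
    assert(!canPSBeFused(randomInit));
    assert(!canPSBeFused(egpInCode));
}

The constant initialiser values that pass this check are then folded into the fuse
hash (see getPSFuseHashDigest below), so groups using different constants still end
up in different fused groups.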
src/genn/genn/synapseGroup.cc | 42 ++++++++++++++++++++++++++++++-----
 1 file changed, 37 insertions(+), 5 deletions(-)

diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc
index 64a24651e5..4c1ed361ec 100644
--- a/src/genn/genn/synapseGroup.cc
+++ b/src/genn/genn/synapseGroup.cc
@@ -567,11 +567,34 @@ void SynapseGroup::finalise(double dt)
 //----------------------------------------------------------------------------
 bool SynapseGroup::canPSBeFused() const
 {
-    // Return true if there are no variables or extra global parameters
-    // **NOTE** many models with variables would work fine, but
-    // nothing stops initialisers being used to configure PS models
-    // to behave totally different, similarly with EGPs
-    return (getPSVarInitialisers().empty() && getPSModel()->getExtraGlobalParams().empty());
+    // If any postsynaptic model variables aren't initialised to constant values, this synapse group's postsynaptic model can't be merged
+    // **NOTE** hash check will compare these constant values
+    if(std::any_of(getPSVarInitialisers().cbegin(), getPSVarInitialisers().cend(),
+                   [](const Models::VarInit &v){ return (dynamic_cast<const InitVarSnippet::Constant*>(v.getSnippet()) == nullptr); }))
+    {
+        return false;
+    }
+
+    // Loop through EGPs
+    // **NOTE** this is kind of silly as, if it's not referenced in either of
+    // these code strings, there wouldn't be a lot of point in a PSM EGP existing!
+    const auto psmEGPs = getPSModel()->getExtraGlobalParams();
+    const std::string decayCode = getPSModel()->getDecayCode();
+    const std::string applyInputCode = getPSModel()->getApplyInputCode();
+    for(const auto &egp : psmEGPs) {
+        // If this EGP is referenced in decay code, return false
+        const std::string egpName = "$(" + egp.name + ")";
+        if(decayCode.find(egpName) != std::string::npos) {
+            return false;
+        }
+
+        // If this EGP is referenced in apply input code, return false
+        if(applyInputCode.find(egpName) != std::string::npos) {
+            return false;
+        }
+    }
+
+    return true;
 }
 //----------------------------------------------------------------------------
 bool SynapseGroup::canWUMPreUpdateBeFused() const
@@ -821,6 +844,15 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getPSFuseHashDigest() cons
     Utils::updateHash(getPSTargetVar(), hash);
     Utils::updateHash(getPSParams(), hash);
     Utils::updateHash(getPSDerivedParams(), hash);
+
+    // Loop through PSM variable initialisers and hash first parameter.
+ // Due to SynapseGroup::canPSBeFused, all initialiser snippets + // will be constant and have a single parameter containing the value + for(const auto &w : getPSVarInitialisers()) { + assert(w.getParams().size() == 1); + Utils::updateHash(w.getParams().at(0), hash); + } + return hash.get_digest(); } //---------------------------------------------------------------------------- From d2086747a9ab9a51f9b54e00ef546905ec2a3560 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 20 Feb 2023 18:05:46 +0000 Subject: [PATCH 383/725] unit test for fusing PSMs with variables --- tests/unit/neuronGroup.cc | 78 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/tests/unit/neuronGroup.cc b/tests/unit/neuronGroup.cc index c1356edd3b..6a50eecbc9 100644 --- a/tests/unit/neuronGroup.cc +++ b/tests/unit/neuronGroup.cc @@ -495,6 +495,13 @@ TEST(NeuronGroup, FusePSM) auto synTargetInternal = static_cast(synTarget); auto synDelayInternal = static_cast(synDelay); + // Check all groups can be fused + ASSERT_TRUE(synInternal->canPSBeFused()); + ASSERT_TRUE(syn2Internal->canPSBeFused()); + ASSERT_TRUE(synParamInternal->canPSBeFused()); + ASSERT_TRUE(synTargetInternal->canPSBeFused()); + ASSERT_TRUE(synDelayInternal->canPSBeFused()); + // Check that identically configured PSMs can be merged ASSERT_EQ(synInternal->getFusedPSVarSuffix(), syn2Internal->getFusedPSVarSuffix()); @@ -508,6 +515,77 @@ TEST(NeuronGroup, FusePSM) ASSERT_NE(synInternal->getFusedPSVarSuffix(), synDelayInternal->getFusedPSVarSuffix()); } +TEST(NeuronGroup, FuseVarPSM) +{ + ModelSpecInternal model; + model.setMergePostsynapticModels(true); + + LIFAdditional::ParamValues paramVals(0.25, 10.0, 0.0, 0.0, 20.0, 0.0, 5.0); + LIFAdditional::VarValues varVals(0.0, 0.0); + AlphaCurr::ParamValues psmParamVals(5.0); + AlphaCurr::VarValues psmVarValsConst1(0.0); + AlphaCurr::VarValues psmVarValsConst2(1.0); + AlphaCurr::VarValues psmVarValsRand(initVar({0.0, 1.0})); + WeightUpdateModels::StaticPulseDendriticDelay::VarValues wumVarVals(0.1, 10); + + // Add two neuron groups to model + auto *pre = model.addNeuronPopulation("Pre", 10, paramVals, varVals); + auto *post = model.addNeuronPopulation("Post", 10, paramVals, varVals); + + // Create baseline synapse group + auto *syn1 = model.addSynapsePopulation( + "Syn1", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Pre", "Post", + {}, wumVarVals, + psmParamVals, psmVarValsConst1); + + // Create second synapse group with same model and constant initialisers + auto *syn2 = model.addSynapsePopulation( + "Syn2", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Pre", "Post", + {}, wumVarVals, + psmParamVals, psmVarValsConst1); + + // Create third synapse group with same model and different constant initialisers + auto *syn3 = model.addSynapsePopulation( + "Syn3", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Pre", "Post", + {}, wumVarVals, + psmParamVals, psmVarValsConst2); + + // Create fourth synapse group with same model and random variable initialisers + auto *syn4 = model.addSynapsePopulation( + "Syn4", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Pre", "Post", + {}, wumVarVals, + psmParamVals, psmVarValsRand); + + + // **TODO** third safe group with different variable initialisers + model.finalize(); + + // Cast neuron groups to internal types + auto preInternal = static_cast(pre); + auto postInternal = static_cast(post); + + // Cast synapse groups to internal types + auto syn1Internal = static_cast(syn1); + auto syn2Internal = static_cast(syn2); + auto 
syn3Internal = static_cast<SynapseGroupInternal*>(syn3);
+    auto syn4Internal = static_cast<SynapseGroupInternal*>(syn4);
+
+    // Check only groups with 'safe' model can be fused
+    ASSERT_TRUE(syn1Internal->canPSBeFused());
+    ASSERT_TRUE(syn2Internal->canPSBeFused());
+    ASSERT_TRUE(syn3Internal->canPSBeFused());
+    ASSERT_FALSE(syn4Internal->canPSBeFused());
+
+    // Check that identically configured PSMs can be merged
+    ASSERT_EQ(syn1Internal->getFusedPSVarSuffix(), syn2Internal->getFusedPSVarSuffix());
+
+    ASSERT_TRUE(preInternal->getFusedPSMInSyn().empty());
+    ASSERT_EQ(postInternal->getFusedPSMInSyn().size(), 3);
+}

 TEST(NeuronGroup, FusePreOutput)
 {
     ModelSpecInternal model;

From 962e3801000ead9b074bee1e1447effeb8fa3f63 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 11 Jul 2023 17:34:00 +0100
Subject: [PATCH 384/725] need to turn on WIN32_LEAN_AND_MEAN to get rid of stupid DUPLICATE macro clash

---
 tests/unit/unit.vcxproj | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/unit.vcxproj b/tests/unit/unit.vcxproj
index 4289174ba7..cf2c7c772e 100644
--- a/tests/unit/unit.vcxproj
+++ b/tests/unit/unit.vcxproj
@@ -68,7 +68,7 @@
       MaxSpeed
       true
       ..\..\include\genn\genn;..\..\include\genn\third_party;..\..\include\genn\backends\single_threaded_cpu;$(GTEST_DIR);$(GTEST_DIR)/include;%(AdditionalIncludeDirectories)
-      NOMINMAX;_SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions)
+      NOMINMAX;WIN32_LEAN_AND_MEAN;_SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions)
       stdcpp17
       true

From 11d31bd6c26ffe29772014a3591a033dc23941fd Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 11 Jul 2023 17:40:26 +0100
Subject: [PATCH 385/725] fixed up fusing logic

---
 src/genn/genn/synapseGroup.cc               | 36 +++++++--------------
 tests/unit/customUpdate.cc                  |  1 -
 tests/unit/initSparseConnectivitySnippet.cc |  3 +-
 tests/unit/modelSpecMerged.cc               |  2 --
 tests/unit/neuronGroup.cc                   | 27 ++++++++--------
 tests/unit/synapseGroup.cc                  |  1 -
 6 files changed, 26 insertions(+), 44 deletions(-)

diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc
index 4c1ed361ec..91db21c27c 100644
--- a/src/genn/genn/synapseGroup.cc
+++ b/src/genn/genn/synapseGroup.cc
@@ -570,7 +570,7 @@ bool SynapseGroup::canPSBeFused() const
     // If any postsynaptic model variables aren't initialised to constant values, this synapse group's postsynaptic model can't be merged
     // **NOTE** hash check will compare these constant values
     if(std::any_of(getPSVarInitialisers().cbegin(), getPSVarInitialisers().cend(),
-                   [](const Models::VarInit &v){ return (dynamic_cast<const InitVarSnippet::Constant*>(v.getSnippet()) == nullptr); }))
+                   [](const auto &v){ return (dynamic_cast<const InitVarSnippet::Constant*>(v.second.getSnippet()) == nullptr); }))
     {
         return false;
     }
@@ -578,18 +578,14 @@ bool SynapseGroup::canPSBeFused() const

     // Loop through EGPs
     // **NOTE** this is kind of silly as, if it's not referenced in either of
     // these code strings, there wouldn't be a lot of point in a PSM EGP existing!
- const auto psmEGPs = getPSModel()->getExtraGlobalParams(); - const std::string decayCode = getPSModel()->getDecayCode(); - const std::string applyInputCode = getPSModel()->getApplyInputCode(); - for(const auto &egp : psmEGPs) { + for(const auto &egp : getPSModel()->getExtraGlobalParams()) { // If this EGP is referenced in decay code, return false - const std::string egpName = "$(" + egp.name + ")"; - if(decayCode.find(egpName) != std::string::npos) { + if(Utils::isIdentifierReferenced(egp.name, getPSDecayCodeTokens())) { return false; } // If this EGP is referenced in apply input code, return false - if(applyInputCode.find(egpName) != std::string::npos) { + if(Utils::isIdentifierReferenced(egp.name, getPSApplyInputCodeTokens())) { return false; } } @@ -608,18 +604,14 @@ bool SynapseGroup::canWUMPreUpdateBeFused() const } // Loop through EGPs - const auto wumEGPs = getWUModel()->getExtraGlobalParams(); - const std::string preSpikeCode = getWUModel()->getPreSpikeCode(); - const std::string preDynamicsCode = getWUModel()->getPreDynamicsCode(); - for(const auto &egp : wumEGPs) { + for(const auto &egp : getWUModel()->getExtraGlobalParams()) { // If this EGP is referenced in presynaptic spike code, return false - const std::string egpName = "$(" + egp.name + ")"; - if(preSpikeCode.find(egpName) != std::string::npos) { + if(Utils::isIdentifierReferenced(egp.name, getWUPreSpikeCodeTokens())) { return false; } // If this EGP is referenced in presynaptic dynamics code, return false - if(preDynamicsCode.find(egpName) != std::string::npos) { + if(Utils::isIdentifierReferenced(egp.name, getWUPreDynamicsCodeTokens())) { return false; } } @@ -637,18 +629,14 @@ bool SynapseGroup::canWUMPostUpdateBeFused() const } // Loop through EGPs - const auto wumEGPs = getWUModel()->getExtraGlobalParams(); - const std::string postSpikeCode = getWUModel()->getPostSpikeCode(); - const std::string postDynamicsCode = getWUModel()->getPostDynamicsCode(); - for(const auto &egp : wumEGPs) { + for(const auto &egp : getWUModel()->getExtraGlobalParams()) { // If this EGP is referenced in postsynaptic spike code, return false - const std::string egpName = "$(" + egp.name + ")"; - if(postSpikeCode.find(egpName) != std::string::npos) { + if(Utils::isIdentifierReferenced(egp.name, getWUPostSpikeCodeTokens())) { return false; } // If this EGP is referenced in postsynaptic dynamics code, return false - if(postDynamicsCode.find(egpName) != std::string::npos) { + if(Utils::isIdentifierReferenced(egp.name, getWUPostDynamicsCodeTokens())) { return false; } } @@ -849,8 +837,8 @@ boost::uuids::detail::sha1::digest_type SynapseGroup::getPSFuseHashDigest() cons // Due to SynapseGroup::canPSBeFused, all initialiser snippets // will be constant and have a single parameter containing the value for(const auto &w : getPSVarInitialisers()) { - assert(w.getParams().size() == 1); - Utils::updateHash(w.getParams().at(0), hash); + assert(w.second.getParams().size() == 1); + Utils::updateHash(w.second.getParams().at("constant"), hash); } return hash.get_digest(); diff --git a/tests/unit/customUpdate.cc b/tests/unit/customUpdate.cc index 15719b42d2..a7e1ac20fc 100644 --- a/tests/unit/customUpdate.cc +++ b/tests/unit/customUpdate.cc @@ -1,6 +1,5 @@ // Standard C++ includes #include -#undef DUPLICATE // Google test includes #include "gtest/gtest.h" diff --git a/tests/unit/initSparseConnectivitySnippet.cc b/tests/unit/initSparseConnectivitySnippet.cc index f4da306ac2..b93dbb6a28 100644 --- a/tests/unit/initSparseConnectivitySnippet.cc +++ 
b/tests/unit/initSparseConnectivitySnippet.cc @@ -13,8 +13,7 @@ class OneToOneCopy : public InitSparseConnectivitySnippet::Base { public: SET_ROW_BUILD_CODE( - "$(addSynapse, $(id_pre));\n" - "$(endRow);\n"); + "addSynapse(id_pre);\n"); SET_MAX_ROW_LENGTH(1); SET_MAX_COL_LENGTH(1); diff --git a/tests/unit/modelSpecMerged.cc b/tests/unit/modelSpecMerged.cc index 411669a53a..0cf2892fa9 100644 --- a/tests/unit/modelSpecMerged.cc +++ b/tests/unit/modelSpecMerged.cc @@ -4,8 +4,6 @@ #include #include -#undef DUPLICATE - // Google test includes #include "gtest/gtest.h" diff --git a/tests/unit/neuronGroup.cc b/tests/unit/neuronGroup.cc index 6a50eecbc9..8c7c7a8136 100644 --- a/tests/unit/neuronGroup.cc +++ b/tests/unit/neuronGroup.cc @@ -1,6 +1,5 @@ // Standard C++ includes #include -#undef DUPLICATE // Google test includes #include "gtest/gtest.h" @@ -520,49 +519,49 @@ TEST(NeuronGroup, FuseVarPSM) ModelSpecInternal model; model.setMergePostsynapticModels(true); - LIFAdditional::ParamValues paramVals(0.25, 10.0, 0.0, 0.0, 20.0, 0.0, 5.0); - LIFAdditional::VarValues varVals(0.0, 0.0); - AlphaCurr::ParamValues psmParamVals(5.0); - AlphaCurr::VarValues psmVarValsConst1(0.0); - AlphaCurr::VarValues psmVarValsConst2(1.0); - AlphaCurr::VarValues psmVarValsRand(initVar<InitVarSnippet::Uniform>({0.0, 1.0})); - WeightUpdateModels::StaticPulseDendriticDelay::VarValues wumVarVals(0.1, 10); - + ParamValues paramVals{{"C", 0.25}, {"TauM", 10.0}, {"Vrest", 0.0}, {"Vreset", 0.0}, {"Vthresh", 20.0}, {"Ioffset", 0.0}, {"TauRefrac", 5.0}}; + VarValues varVals{{"V", 0.0}, {"RefracTime", 0.0}}; + ParamValues psmParamVals{{"tau", 5.0}}; + VarValues psmVarValsConst1{{"x", 0.0}}; + VarValues psmVarValsConst2{{"x", 1.0}}; + VarValues psmVarValsRand{{"x", initVar<InitVarSnippet::Uniform>({{"min", 0.0}, {"max", 1.0}})}}; + VarValues wumVarVals{{"g", 0.1}, {"d", 10}}; + // Add two neuron groups to model auto *pre = model.addNeuronPopulation<LIFAdditional>("Pre", 10, paramVals, varVals); auto *post = model.addNeuronPopulation<LIFAdditional>("Post", 10, paramVals, varVals); // Create baseline synapse group auto *syn1 = model.addSynapsePopulation<WeightUpdateModels::StaticPulseDendriticDelay, AlphaCurr>( - "Syn1", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Syn1", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, wumVarVals, psmParamVals, psmVarValsConst1); // Create second synapse group with same model and constant initialisers auto *syn2 = model.addSynapsePopulation<WeightUpdateModels::StaticPulseDendriticDelay, AlphaCurr>( - "Syn2", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Syn2", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, wumVarVals, psmParamVals, psmVarValsConst1); // Create third synapse group with same model and different constant initialisers auto *syn3 = model.addSynapsePopulation<WeightUpdateModels::StaticPulseDendriticDelay, AlphaCurr>( - "Syn3", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Syn3", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, wumVarVals, psmParamVals, psmVarValsConst2); // Create fourth synapse group with same model and random variable initialisers auto *syn4 = model.addSynapsePopulation<WeightUpdateModels::StaticPulseDendriticDelay, AlphaCurr>( - "Syn4", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, + "Syn4", SynapseMatrixType::DENSE, NO_DELAY, "Pre", "Post", {}, wumVarVals, psmParamVals, psmVarValsRand); // **TODO** third safe group with different variable initialisers - model.finalize(); + model.finalise(); // Cast neuron groups to internal types auto preInternal = static_cast<NeuronGroupInternal*>(pre); diff --git a/tests/unit/synapseGroup.cc b/tests/unit/synapseGroup.cc index b139200e40..e8d8aecbe1 100644 --- a/tests/unit/synapseGroup.cc +++ b/tests/unit/synapseGroup.cc @@ -1,6 +1,5 @@ // Standard C++ includes #include -#undef DUPLICATE // Google test includes #include
"gtest/gtest.h" From 9b766fc9c828a9ce15be2c7eafd695e1ea8ce67e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 17:47:24 +0100 Subject: [PATCH 386/725] manually cherry picked batched prev spike logic --- src/genn/genn/code_generator/backendSIMT.cc | 24 +++++++++++++-------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 129320335f..987efa971f 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -354,7 +354,7 @@ void BackendSIMT::genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &en } } } - // Otherwises + // Otherwise else { if(batchSize > 1) { neuronEnv.printLine("const unsigned int batchOffset = $(num_neurons) * $(batch);"); @@ -364,11 +364,14 @@ void BackendSIMT::genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &en neuronEnv.print("if($(id) < $(_spk_cnt)[$(batch)])"); { CodeStream::Scope b(neuronEnv.getStream()); - neuronEnv.print("$(_prev_spk_time)[$(_spk)["); - if(batchSize > 1) { - neuronEnv.getStream() << "batchOffset + "; + neuronEnv.print("$(_prev_spk_time)["); + if (batchSize == 1) { + neuronEnv.print("$(_spk)[$(id)]"); + } + else { + neuronEnv.print("batchOffset + $(_spk)[batchOffset + $(id)]"); } - neuronEnv.printLine("$(id)]] = $(t) - DT;"); + neuronEnv.printLine("] = $(t) - $(dt);"); } } if(ng.getArchetype().isPrevSpikeEventTimeRequired()) { @@ -376,11 +379,14 @@ void BackendSIMT::genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &en neuronEnv.print("if($(id) < $(_spk_cnt_evnt)[$(batch)])"); { CodeStream::Scope b(neuronEnv.getStream()); - neuronEnv.print("$(_prev_spk_evnt_time)[$(_spk_evnt)["); - if(batchSize > 1) { - neuronEnv.getStream() << "batchOffset + "; + neuronEnv.print("$(_prev_spk_evnt_time)["); + if (batchSize == 1) { + neuronEnv.print("$(_spk_evnt)[$(id)]"); + } + else { + neuronEnv.print("batchOffset + $(_spk_evnt)[batchOffset + $(id)]"); } - neuronEnv.printLine("$(id)]] = $(t) - DT;"); + neuronEnv.printLine("] = $(t) - $(dt);"); } } } From 6fa67b91bd6cf689d3242123f60578c8d4be9a2e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 6 Feb 2023 15:43:39 +0000 Subject: [PATCH 387/725] updated batch_prev_pre_spike_time_in_sim test to cover non-delayed --- .../batch_prev_pre_spike_time_in_sim/model.cc | 11 +++++++++-- .../features/batch_prev_pre_spike_time_in_sim/test.cc | 11 +++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/tests/features/batch_prev_pre_spike_time_in_sim/model.cc b/tests/features/batch_prev_pre_spike_time_in_sim/model.cc index 464d0eeb4d..2e4aaffd83 100644 --- a/tests/features/batch_prev_pre_spike_time_in_sim/model.cc +++ b/tests/features/batch_prev_pre_spike_time_in_sim/model.cc @@ -66,13 +66,20 @@ void modelDefinition(ModelSpec &model) model.setBatchSize(2); model.addNeuronPopulation("pre", 10, {}, {}); + model.addNeuronPopulation("preDelay", 10, {}, {}); model.addNeuronPopulation("post", 10, {}, {}); + model.addNeuronPopulation("postDelay", 10, {}, {}); model.addSynapsePopulation( - "syn", SynapseMatrixType::SPARSE_INDIVIDUALG, 20, "pre", "post", + "syn", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, "pre", "post", + {}, WeightUpdateModel::VarValues(-std::numeric_limits::max()), + {}, {}, + initConnectivity({})); + + model.addSynapsePopulation( + "synDelay", SynapseMatrixType::SPARSE_INDIVIDUALG, 20, "preDelay", "postDelay", {}, WeightUpdateModel::VarValues(-std::numeric_limits::max()), {}, {}, 
initConnectivity({})); - model.setPrecision(GENN_FLOAT); } diff --git a/tests/features/batch_prev_pre_spike_time_in_sim/test.cc b/tests/features/batch_prev_pre_spike_time_in_sim/test.cc index bd8429bc97..e99db835bd 100644 --- a/tests/features/batch_prev_pre_spike_time_in_sim/test.cc +++ b/tests/features/batch_prev_pre_spike_time_in_sim/test.cc @@ -37,9 +37,16 @@ class SimTest : public SimulationTest // 2) PREVIOUS spike occurred (-)20 timesteps before // 3) t is incremented one timestep at the end of StepGeNN const float delayedLastSpikeTime = (scalar)i + 1.0f + (20.0f * std::floor((t - 22.0f - (scalar)i) / 20.0f)); - - // If, theoretically, spike would have arrived before delay it's impossible so time should be a very large negative number + + // Check wsynDelay read from delayed population if(delayedLastSpikeTime < 21.0f) { + ASSERT_LT(wsynDelay[i], -1.0E6); + } + else { + ASSERT_FLOAT_EQ(wsynDelay[i], delayedLastSpikeTime); + } + // Check wsyn read from non-delayed population + if(delayedLastSpikeTime < 1.0f) { ASSERT_LT(wsyn[i], -1.0E6); } else { From c3f3e43353e032429712161dd079fcfe76aa661e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 16 Feb 2023 14:11:37 +0000 Subject: [PATCH 388/725] increase constant cache estimate --- include/genn/backends/cuda/backend.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index a44c10d99f..fadf3d0096 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -94,10 +94,10 @@ struct Preferences : public PreferencesBase KernelBlockSize manualBlockSizes; //! How much constant cache is already used and therefore can't be used by GeNN? - /*! Each of the four modules which includes CUDA headers(neuronUpdate, synapseUpdate, init and runner) + /*! Each of the five modules which include CUDA headers (neuronUpdate, synapseUpdate, custom update, init and runner) Takes 72 bytes of constant memory for a lookup table used by cuRAND. If your application requires additional constant cache, increase this */ - size_t constantCacheOverhead = 72 * 4; + size_t constantCacheOverhead = 72 * 5; //! NVCC compiler options for all GPU code std::string userNvccFlags = ""; From 3229cf139ddb2affc2f2848fe848dfc3e2c1446e Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 17:52:32 +0100 Subject: [PATCH 389/725] manually cherry pick correct handling of variables with VarAccessDuplication_SHARED_NEURON in PyGeNN --- pygenn/genn_groups.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index 53a48f5e5e..e16391ff92 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -216,9 +216,13 @@ def _load_vars(self, vars, size=None, var_dict=None, get_location_fn=None): # Determine how many copies of this variable are present num_copies = (1 if v.access & VarAccessDuplication.SHARED else self._model.batch_size) - + + # Determine size of this variable + var_size = (1 if v.access & VarAccessDuplication.SHARED_NEURON + else size) + # Get view - var_data.view = self._assign_ext_ptr_array(v.name, size * num_copies, + var_data.view = self._assign_ext_ptr_array(v.name, var_size * num_copies, var_data.type) # If there is more than one copy, reshape view to 2D @@ -816,8 +820,8 @@ def load(self): # If variable is located on host var_loc = self.get_var_location(v.name) if var_loc & VarLocation.HOST: - # **TODO** WHAT IS HAPPENING HERE?
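# ---- Editor's illustration (not part of the patch) --------------------------
# The _load_vars change above can be summarised by a small hypothetical helper
# (names invented for illustration):
#
#     def _var_view_len(pop_size, batch_size, access):
#         # SHARED_NEURON variables hold one element per batch entry...
#         var_size = 1 if access & VarAccessDuplication.SHARED_NEURON else pop_size
#         # ...and SHARED variables are not duplicated across batches
#         num_copies = 1 if access & VarAccessDuplication.SHARED else batch_size
#         return var_size * num_copies
#
# e.g. with pop_size=100 and batch_size=5, a SHARED_NEURON variable maps to a
# view of length 5 rather than 500.
# ------------------------------------------------------------------------------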
# Determine how many copies of this variable are present + # **YUCK** this isn't quite right - really should look at is_batched() #num_copies = (1 if (v.access & VarAccessDuplication_SHARED) != 0 # else self._model.batch_size) num_copies = 1 From a8c117034760c1254306e1de7564bc08a637b813 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 17:56:06 +0100 Subject: [PATCH 390/725] add failing unit test for custom updates attached to SHARED_NEURON variables --- tests/unit/customUpdate.cc | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/unit/customUpdate.cc b/tests/unit/customUpdate.cc index a7e1ac20fc..4d726c5b5f 100644 --- a/tests/unit/customUpdate.cc +++ b/tests/unit/customUpdate.cc @@ -499,6 +499,7 @@ TEST(CustomUpdates, ReductionTypeDuplicateNeuron) ASSERT_TRUE(cuInternal->isBatched()); ASSERT_FALSE(cuInternal->isBatchReduction()); ASSERT_TRUE(cuInternal->isNeuronReduction()); + ASSERT_EQ(cuInternal->getSize(), 10); } //-------------------------------------------------------------------------- TEST(CustomUpdates, ReductionTypeDuplicateNeuronInternal) @@ -639,6 +640,31 @@ TEST(CustomUpdates, NeuronBatchReduction) } } //-------------------------------------------------------------------------- +TEST(CustomUpdates, SharedNeuronVariable) +{ + ModelSpecInternal model; + model.setBatchSize(5); + + // Add neuron (copy of izhikevich model where a, b, c and d are shared_neuron) to model + VarValues izkVarVals{{"V", 0.0}, {"U", 0.0}, + {"a", 0.02}, {"b", 0.2}, {"c", -65.0}, {"d", 8.0}}; + auto *pop = model.addNeuronPopulation("Pop", 10, {}, izkVarVals); + + // Add custom update to sum A and B + VarReferences sumVarReferences{{"a", createVarRef(pop, "a")}, {"b", createVarRef(pop, "b")}}; + auto *cu = model.addCustomUpdate("Sum", "CustomUpdate", + {}, {{"sum", 0.0}}, sumVarReferences); + + // Finalize model + model.finalise(); + + auto *cuInternal = static_cast(cu); + ASSERT_TRUE(cuInternal->isBatched()); + ASSERT_FALSE(cuInternal->isBatchReduction()); + ASSERT_FALSE(cuInternal->isNeuronReduction()); + ASSERT_EQ(cuInternal->getSize(), 1); +} +//-------------------------------------------------------------------------- TEST(CustomUpdates, CompareDifferentModel) { ModelSpecInternal model; From 19b573a3ff14c80aa7acf1ccdb6e8d43c908b8a7 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 18:10:25 +0100 Subject: [PATCH 391/725] Fixed test for duplicatedness in VarReferenceBase::isDuplicated --- include/genn/genn/models.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 4e1a35fab8..6002d61c93 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -161,7 +161,7 @@ class GENN_EXPORT VarReferenceBase bool isDuplicated() const { - return m_IsBatched() && (m_Var.access & VarAccessDuplication::DUPLICATE); + return m_IsBatched() && !(m_Var.access & VarAccessDuplication::SHARED); } bool operator < (const VarReferenceBase &other) const From 13258b3c4e255ef8ba54bae1e63cd4ecfda7a347 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 18:23:40 +0100 Subject: [PATCH 392/725] Proper handling of custom updates to SHARED_NEURON variables * ``CustomUpdate`` detects whether it is per-neuron or not based on whether any variables or variable reference targets point to non SHARED_NEURON variables. 
If this is the case, SHARED_NEURON variables can only be read. * Backend code for generating SHARED_NEURON updates # Conflicts: # src/genn/backends/single_threaded_cpu/backend.cc # src/genn/genn/customUpdate.cc --- include/genn/genn/customUpdate.h | 4 +++ include/genn/genn/customUpdateInternal.h | 1 + .../backends/single_threaded_cpu/backend.cc | 16 ++++---- src/genn/genn/code_generator/backendSIMT.cc | 24 +++++++++++++- src/genn/genn/customUpdate.cc | 32 +++++++++++++++++-- 5 files changed, 68 insertions(+), 9 deletions(-) diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index f170355b7e..bf50bcf64b 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -243,6 +243,7 @@ class GENN_EXPORT CustomUpdate : public CustomUpdateBase //------------------------------------------------------------------------ bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDuplication::SHARED); } bool isNeuronReduction() const { return isReduction(getVarReferences(), VarAccessDuplication::SHARED_NEURON); } + bool isPerNeuron() const { return m_PerNeuron; } //! Updates hash with custom update /*! NOTE: this can only be called after model is finalized */ @@ -261,6 +262,9 @@ class GENN_EXPORT CustomUpdate : public CustomUpdateBase const std::unordered_map<std::string, Models::VarReference> m_VarReferences; const unsigned int m_Size; const NeuronGroup *m_DelayNeuronGroup; + + //! Is this custom update per-neuron i.e. run in parallel across all neurons + bool m_PerNeuron; }; //------------------------------------------------------------------------ diff --git a/include/genn/genn/customUpdateInternal.h b/include/genn/genn/customUpdateInternal.h index c4c78b5ca5..9407b71066 100644 --- a/include/genn/genn/customUpdateInternal.h +++ b/include/genn/genn/customUpdateInternal.h @@ -25,6 +25,7 @@ class CustomUpdateInternal : public CustomUpdate using CustomUpdateBase::isInitRNGRequired; using CustomUpdateBase::isZeroCopyEnabled; using CustomUpdateBase::isBatched; + using CustomUpdate::isPerNeuron; using CustomUpdateBase::getVarLocationHashDigest; using CustomUpdateBase::getUpdateCodeTokens; diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index b328a09eeb..63a30e814c 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -590,13 +590,16 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back const auto reductionTargets = genInitReductionTargets(groupEnv.getStream(), c); // Loop through group members - groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["size"] << "; i++)"; - { - CodeStream::Scope b(groupEnv.getStream()); - - // Generate custom update - EnvironmentGroupMergedField<CustomUpdateGroupMerged> memberEnv(groupEnv, c); + EnvironmentGroupMergedField<CustomUpdateGroupMerged> memberEnv(groupEnv, c); + if (c.getArchetype().isPerNeuron()) { + memberEnv.print("for(unsigned int i = 0; i < $(size); i++)"); memberEnv.add(Type::Uint32.addConst(), "id", "i"); + } + else { + memberEnv.add(Type::Uint32.addConst(), "id", "0"); + } { + CodeStream::Scope b(memberEnv.getStream()); c.generateCustomUpdate(*this, memberEnv); // Loop through reduction targets and generate reduction diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 987efa971f..8556c2bd33 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++
b/src/genn/genn/code_generator/backendSIMT.cc @@ -211,7 +211,8 @@ size_t BackendSIMT::getPaddedNumCustomUpdateThreads(const CustomUpdateInternal & return padKernelSize(32 * numCopies, KernelCustomUpdate); } else { - return numCopies * padKernelSize(cg.getSize(), KernelCustomUpdate); + const size_t numElements = cg.isPerNeuron() ? cg.getSize() : 1; + return numCopies * padKernelSize(numElements, KernelCustomUpdate); } } //-------------------------------------------------------------------------- @@ -1018,6 +1019,27 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge } } } + // Otherwise, if this update isn't per-neuron + else if (!cg.getArchetype().isPerNeuron()) { + EnvironmentGroupMergedField groupEnv(env, cg); + if(cg.getArchetype().isBatched()) { + groupEnv.add(Type::Uint32.addConst(), "batch", "$(id)"); + groupEnv.add(Type::Uint32.addConst(), "id", "0"); + } + // Otherwise, just substitute "batch" for 0 + else { + groupEnv.add(Type::Uint32.addConst(), "batch", "0"); + } + + groupEnv.getStream() << "// only do this for existing neurons" << std::endl; + groupEnv.getStream() << "if(" << groupEnv["batch"] << " < " << (cg.getArchetype().isBatched() ? batchSize : 1) << ")"; + { + CodeStream::Scope b(groupEnv.getStream()); + + genCustomUpdateIndexCalculation(groupEnv); + cg.generateCustomUpdate(*this, groupEnv); + } + } // Otherwise else { EnvironmentGroupMergedField groupEnv(env, cg); diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index da8c594247..7d856529fa 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -112,7 +112,7 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) : CustomUpdateBase(name, updateGroupName, customUpdateModel, params, varInitialisers, defaultVarLocation, defaultExtraGlobalParamLocation), - m_VarReferences(varReferences), m_Size(varReferences.empty() ? 0 : varReferences.begin()->second.getSize()), m_DelayNeuronGroup(nullptr) + m_VarReferences(varReferences), m_Size(varReferences.empty() ? 
0 : varReferences.begin()->second.getSize()), m_DelayNeuronGroup(nullptr), m_PerNeuron(false) { // Validate parameters, variables and variable references getCustomUpdateModel()->validate(getParams(), getVarInitialisers(), getVarReferences(), "Custom update " + getName()); @@ -124,6 +124,32 @@ CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGro // Check variable reference types Models::checkVarReferences(m_VarReferences, getCustomUpdateModel()->getVarRefs()); + // Update is per-neuron if any variables or variable reference targets AREN'T SHARED_NEURON + const auto modelVars = getCustomUpdateModel()->getVars(); + m_PerNeuron = std::any_of(m_VarReferences.cbegin(), m_VarReferences.cend(), + [](const auto& v) + { + return !(v.second.getVar().access & VarAccessDuplication::SHARED_NEURON); + }); + m_PerNeuron |= std::any_of(modelVars.cbegin(), modelVars.cend(), + [](const Models::Base::Var& v) + { + return !(v.access & VarAccessDuplication::SHARED_NEURON); + }); + + // Loop through all variable references + for(const auto &modelVarRef : getCustomUpdateModel()->getVarRefs()) { + const auto &varRef = m_VarReferences.at(modelVarRef.name); + + // If custom update is per-neuron, check that any variable references to SHARED_NEURON variables are read-only + // **NOTE** if custom update isn't per-neuron, it's totally fine to write to SHARED_NEURON variables + if(m_PerNeuron && (varRef.getVar().access & VarAccessDuplication::SHARED_NEURON) + && (modelVarRef.access == VarAccessMode::READ_WRITE)) + { + throw std::runtime_error("Variable references to SHARED_NEURON variables in per-neuron custom updates cannot be read-write."); + } + } + // Check only one type of reduction is specified if (isBatchReduction() && isNeuronReduction()) { throw std::runtime_error("Custom updates cannot perform batch and neuron reductions simultaneously."); @@ -167,7 +193,8 @@ boost::uuids::detail::sha1::digest_type CustomUpdate::getHashDigest() const boost::uuids::detail::sha1 hash; CustomUpdateBase::updateHash(hash); - // Update hash with whether delay is required + // Update hash with whether custom update is per-neuron and if delay is required + Utils::updateHash(isPerNeuron(), hash); const bool delayed = (getDelayNeuronGroup() != nullptr); Utils::updateHash(delayed, hash); @@ -192,6 +219,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdate::getInitHashDigest() const // Superclass boost::uuids::detail::sha1 hash; CustomUpdateBase::updateInitHash(hash); + Utils::updateHash(isPerNeuron(), hash); return hash.get_digest(); } From ca546aaee11c99d22ae1ce69f7bc65c640f6396a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 16 Mar 2023 09:25:40 +0000 Subject: [PATCH 393/725] unit tests for various SHARED_NEURON custom update behaviour # Conflicts: # tests/unit/customUpdate.cc --- tests/unit/customUpdate.cc | 89 ++++++++++++++++++++++++++------------ 1 file changed, 61 insertions(+), 28 deletions(-) diff --git a/tests/unit/customUpdate.cc b/tests/unit/customUpdate.cc index 4d726c5b5f..78a32f33ef 100644 --- a/tests/unit/customUpdate.cc +++ b/tests/unit/customUpdate.cc @@ -83,6 +83,17 @@ class Sum3 : public CustomUpdateModels::Base }; IMPLEMENT_SNIPPET(Sum3); +class Copy : public CustomUpdateModels::Base +{ + DECLARE_SNIPPET(Copy); + + SET_UPDATE_CODE("a = b;\n"); + + SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_ONLY}, + {"b", "scalar", VarAccessMode::READ_ONLY}}); +}; +IMPLEMENT_SNIPPET(Copy); + class Cont : public WeightUpdateModels::Base { public: @@ -410,7 +421,6 @@ TEST(CustomUpdates, 
BatchingVars) // Add neuron and spike source (arbitrary choice of model with read_only variables) to model VarValues izkVarVals{{"V", 0.0}, {"U", 0.0}, {"a", 0.02}, {"b", 0.2}, {"c", -65.0}, {"d", 8.0}}; auto *pop = model.addNeuronPopulation("Pop", 10, {}, izkVarVals); - // Create updates where variable is shared and references vary VarValues sumVarValues{{"sum", 1.0}}; @@ -433,9 +443,32 @@ TEST(CustomUpdates, BatchingVars) model.finalise(); EXPECT_TRUE(static_cast<CustomUpdateInternal*>(sum1)->isBatched()); + EXPECT_TRUE(static_cast<CustomUpdateInternal*>(sum1)->isPerNeuron()); EXPECT_FALSE(static_cast<CustomUpdateInternal*>(sum2)->isBatched()); + EXPECT_TRUE(static_cast<CustomUpdateInternal*>(sum2)->isPerNeuron()); EXPECT_TRUE(static_cast<CustomUpdateInternal*>(sum3)->isBatched()); + EXPECT_TRUE(static_cast<CustomUpdateInternal*>(sum3)->isPerNeuron()); EXPECT_FALSE(static_cast<CustomUpdateInternal*>(sum4)->isBatched()); + EXPECT_TRUE(static_cast<CustomUpdateInternal*>(sum4)->isPerNeuron()); +} +//-------------------------------------------------------------------------- +TEST(CustomUpdates, NeuronSharedVars) +{ + ModelSpecInternal model; + model.setBatchSize(5); + + // Add neuron and spike source (arbitrary choice of model with read-only neuron shared variables) to model + VarValues izkVarVals{{"V", 0.0}, {"U", 0.0}, {"a", 0.02}, {"b", 0.2}, {"c", -65.0}, {"d", 8.0}}; + auto *pop = model.addNeuronPopulation("Pop", 10, {}, izkVarVals); + + VarReferences copyVarReferences1{{"a", createVarRef(pop, "a")}, {"b", createVarRef(pop, "b")}}; + CustomUpdate *cu = model.addCustomUpdate<Copy>("Copy", "CustomUpdate", + {}, {}, copyVarReferences1); + model.finalise(); + + auto *cuInternal = static_cast<CustomUpdateInternal*>(cu); + EXPECT_TRUE(cuInternal->isBatched()); + EXPECT_FALSE(cuInternal->isPerNeuron()); } //-------------------------------------------------------------------------- TEST(CustomUpdates, BatchingWriteShared) @@ -458,7 +491,28 @@ TEST(CustomUpdates, BatchingWriteShared) } } //-------------------------------------------------------------------------- -TEST(CustomUpdates, ReduceDuplicate) +TEST(CustomUpdates, WriteNeuronShared) +{ + ModelSpecInternal model; + model.setBatchSize(5); + + // Add neuron and spike source (arbitrary choice of model with read-only neuron shared variables) to model + VarValues izkVarVals{{"V", 0.0}, {"U", 0.0}, {"a", 0.02}, {"b", 0.2}, {"c", -65.0}, {"d", 8.0}}; + auto *pop = model.addNeuronPopulation("Pop", 10, {}, izkVarVals); + + // Create custom update which tries to create a read-write reference to a (which isn't per-neuron) + VarValues sum2VarValues{{"mult", 1.0}}; + VarReferences sum2VarReferences{{"a", createVarRef(pop, "a")}, {"b", createVarRef(pop, "V")}}; + try { + model.addCustomUpdate<Sum2>("Sum1", "CustomUpdate", + {}, sum2VarValues, sum2VarReferences); + FAIL(); + } + catch(const std::runtime_error &) { + } +} +//-------------------------------------------------------------------------- +TEST(CustomUpdates, WriteBatchShared) { ModelSpecInternal model; model.setBatchSize(5); @@ -499,7 +553,7 @@ TEST(CustomUpdates, ReductionTypeDuplicateNeuron) ASSERT_TRUE(cuInternal->isBatched()); ASSERT_FALSE(cuInternal->isBatchReduction()); ASSERT_TRUE(cuInternal->isNeuronReduction()); - ASSERT_EQ(cuInternal->getSize(), 10); + ASSERT_TRUE(cuInternal->isPerNeuron()); } //-------------------------------------------------------------------------- TEST(CustomUpdates, ReductionTypeDuplicateNeuronInternal) @@ -522,6 +576,7 @@ TEST(CustomUpdates, ReductionTypeDuplicateNeuronInternal) ASSERT_TRUE(cuInternal->isBatched()); ASSERT_FALSE(cuInternal->isBatchReduction()); ASSERT_TRUE(cuInternal->isNeuronReduction()); + ASSERT_TRUE(cuInternal->isPerNeuron()); }
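// ---- Editor's illustration (not part of the patch) -------------------------
// The assertions above and below exercise the per-neuron detection rule from
// [PATCH 392/725]: a custom update is per-neuron if any of its own variables,
// or any of the targets of its variable references, is NOT SHARED_NEURON.
// Condensed from the logic added to src/genn/genn/customUpdate.cc:
//
//     const bool perNeuron =
//         std::any_of(varRefs.cbegin(), varRefs.cend(),
//                     [](const auto &v){ return !(v.second.getVar().access & VarAccessDuplication::SHARED_NEURON); })
//         || std::any_of(vars.cbegin(), vars.cend(),
//                        [](const auto &v){ return !(v.access & VarAccessDuplication::SHARED_NEURON); });
//
// A per-neuron update may therefore only *read* SHARED_NEURON targets, while
// an update whose targets are all SHARED_NEURON runs once per batch entry.
// -----------------------------------------------------------------------------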
//-------------------------------------------------------------------------- TEST(CustomUpdates, ReductionTypeSharedNeuronInternal) @@ -544,6 +599,7 @@ TEST(CustomUpdates, ReductionTypeSharedNeuronInternal) ASSERT_FALSE(cuInternal->isBatched()); ASSERT_FALSE(cuInternal->isBatchReduction()); ASSERT_TRUE(cuInternal->isNeuronReduction()); + ASSERT_TRUE(cuInternal->isPerNeuron()); } //-------------------------------------------------------------------------- TEST(CustomUpdates, ReductionTypeDuplicateBatch) @@ -565,6 +621,7 @@ TEST(CustomUpdates, ReductionTypeDuplicateBatch) ASSERT_TRUE(cuInternal->isBatched()); ASSERT_TRUE(cuInternal->isBatchReduction()); ASSERT_FALSE(cuInternal->isNeuronReduction()); + ASSERT_TRUE(cuInternal->isPerNeuron()); } //-------------------------------------------------------------------------- TEST(CustomUpdates, ReductionTypeDuplicateBatchInternal) @@ -587,6 +644,7 @@ TEST(CustomUpdates, ReductionTypeDuplicateBatchInternal) ASSERT_TRUE(cuInternal->isBatched()); ASSERT_TRUE(cuInternal->isBatchReduction()); ASSERT_FALSE(cuInternal->isNeuronReduction()); + ASSERT_TRUE(cuInternal->isPerNeuron()); } //-------------------------------------------------------------------------- TEST(CustomUpdates, NeuronSharedCustomUpdateWU) @@ -640,31 +698,6 @@ TEST(CustomUpdates, NeuronBatchReduction) } } //-------------------------------------------------------------------------- -TEST(CustomUpdates, SharedNeuronVariable) -{ - ModelSpecInternal model; - model.setBatchSize(5); - - // Add neuron (copy of izhikevich model where a, b, c and d are shared_neuron) to model - VarValues izkVarVals{{"V", 0.0}, {"U", 0.0}, - {"a", 0.02}, {"b", 0.2}, {"c", -65.0}, {"d", 8.0}}; - auto *pop = model.addNeuronPopulation("Pop", 10, {}, izkVarVals); - - // Add custom update to sum A and B - VarReferences sumVarReferences{{"a", createVarRef(pop, "a")}, {"b", createVarRef(pop, "b")}}; - auto *cu = model.addCustomUpdate("Sum", "CustomUpdate", - {}, {{"sum", 0.0}}, sumVarReferences); - - // Finalize model - model.finalise(); - - auto *cuInternal = static_cast(cu); - ASSERT_TRUE(cuInternal->isBatched()); - ASSERT_FALSE(cuInternal->isBatchReduction()); - ASSERT_FALSE(cuInternal->isNeuronReduction()); - ASSERT_EQ(cuInternal->getSize(), 1); -} -//-------------------------------------------------------------------------- TEST(CustomUpdates, CompareDifferentModel) { ModelSpecInternal model; From e1c2961d260c491cec1f888f46bb67766adbdb61 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 16 Mar 2023 09:31:57 +0000 Subject: [PATCH 394/725] extended custom update feature test to cover shared neuron variables --- tests/features/custom_update/model.cc | 69 +++++++++++++++++++-------- tests/features/custom_update/test.cc | 20 +++++--- 2 files changed, 62 insertions(+), 27 deletions(-) diff --git a/tests/features/custom_update/model.cc b/tests/features/custom_update/model.cc index 8d4f12d68d..b6174bb2f4 100644 --- a/tests/features/custom_update/model.cc +++ b/tests/features/custom_update/model.cc @@ -12,38 +12,38 @@ suite of minimal models with known analytic outcomes that are used for continuou class TestNeuron : public NeuronModels::Base { public: - DECLARE_MODEL(TestNeuron, 0, 1); + DECLARE_MODEL(TestNeuron, 0, 2); - SET_VARS({{"V","scalar"}}); + SET_VARS({{"V","scalar"}, {"VShared", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}}); }; IMPLEMENT_MODEL(TestNeuron); class TestCurrentSource : public CurrentSourceModels::Base { - DECLARE_MODEL(TestCurrentSource, 0, 1); + DECLARE_MODEL(TestCurrentSource, 0, 2); - 
SET_VARS({{"C", "scalar"}}); + SET_VARS({{"C", "scalar"}, {"CShared", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}}); }; IMPLEMENT_MODEL(TestCurrentSource); class TestPSM : public PostsynapticModels::Base { public: - DECLARE_MODEL(TestPSM, 0, 1); + DECLARE_MODEL(TestPSM, 0, 2); SET_CURRENT_CONVERTER_CODE("$(inSyn); $(inSyn) = 0"); - SET_VARS({{"P", "scalar"}}); + SET_VARS({{"P", "scalar"}, {"PShared", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}}); }; IMPLEMENT_MODEL(TestPSM); class TestWUM : public WeightUpdateModels::Base { public: - DECLARE_WEIGHT_UPDATE_MODEL(TestWUM, 0, 1, 1, 1); + DECLARE_WEIGHT_UPDATE_MODEL(TestWUM, 0, 1, 2, 2); SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}}); - SET_PRE_VARS({{"Pre", "scalar"}}); - SET_POST_VARS({{"Post", "scalar"}}); + SET_PRE_VARS({{"Pre", "scalar"}, {"PreShared", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}}); + SET_POST_VARS({{"Post", "scalar"}, {"PostShared", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}}); SET_SIM_CODE("$(addToInSyn, $(g));\n"); }; IMPLEMENT_MODEL(TestWUM); @@ -72,6 +72,18 @@ class SetTime : public CustomUpdateModels::Base }; IMPLEMENT_MODEL(SetTime); +class SetTimeShared : public CustomUpdateModels::Base +{ +public: + DECLARE_CUSTOM_UPDATE_MODEL(SetTimeShared, 0, 0, 1); + + SET_UPDATE_CODE( + "$(R) = $(t);\n"); + + SET_VAR_REFS({{"R", "scalar", VarAccessMode::READ_WRITE}}) +}; +IMPLEMENT_MODEL(SetTimeShared); + void modelDefinition(ModelSpec &model) { #ifdef CL_HPP_TARGET_OPENCL_VERSION @@ -92,24 +104,24 @@ void modelDefinition(ModelSpec &model) 10, 10, 1); // conv_oh, conv_ow, conv_oc model.addNeuronPopulation("SpikeSource", 100, {}, {}); - auto *ng = model.addNeuronPopulation("Neuron", 100, {}, {0.0}); - auto *cs = model.addCurrentSource("CurrentSource", "Neuron", {}, {0.0}); + auto *ng = model.addNeuronPopulation("Neuron", 100, {}, {0.0, 0.0}); + auto *cs = model.addCurrentSource("CurrentSource", "Neuron", {}, {0.0, 0.0}); auto *denseSG = model.addSynapsePopulation( "Dense", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, "SpikeSource", "Neuron", - {}, {0.0}, {0.0}, {0.0}, - {}, {0.0}); + {}, {0.0}, {0.0, 0.0}, {0.0, 0.0}, + {}, {0.0, 0.0}); auto *sparseSG = model.addSynapsePopulation( "Sparse", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, "SpikeSource", "Neuron", - {}, {0.0}, {0.0}, {0.0}, - {}, {0.0}, + {}, {0.0}, {0.0, 0.0}, {0.0, 0.0}, + {}, {0.0, 0.0}, initConnectivity({0.1})); auto *kernelSG = model.addSynapsePopulation( "Kernel", SynapseMatrixType::TOEPLITZ_KERNELG, NO_DELAY, "SpikeSource", "Neuron", - {}, {0.0}, {0.0}, {0.0}, - {}, {0.0}, + {}, {0.0}, {0.0, 0.0}, {0.0, 0.0}, + {}, {0.0, 0.0}, initToeplitzConnectivity(convParams)); TestCU::VarReferences cuTestVarReferences(createVarRef(ng, "V")); @@ -124,10 +136,16 @@ void modelDefinition(ModelSpec &model) SetTime::VarReferences neuronVarReferences(createVarRef(ng, "V")); // R model.addCustomUpdate("NeuronSetTime", "Test", {}, {0.0}, neuronVarReferences); - + SetTimeShared::VarReferences neuronSharedVarReferences(createVarRef(ng, "VShared")); // R + model.addCustomUpdate("NeuronSharedSetTime", "Test", + {}, {}, neuronSharedVarReferences); + SetTime::VarReferences csVarReferences(createVarRef(cs, "C")); // R model.addCustomUpdate("CurrentSourceSetTime", "Test", {}, {0.0}, csVarReferences); + SetTimeShared::VarReferences csSharedVarReferences(createVarRef(cs, "CShared")); // R + model.addCustomUpdate("CurrentSourceSharedSetTime", "Test", + {}, {}, csSharedVarReferences); SetTime::VarReferences cuVarReferences(createVarRef(cu, "C")); // R 
model.addCustomUpdate("CustomUpdateSetTime", "Test", @@ -136,15 +154,24 @@ void modelDefinition(ModelSpec &model) SetTime::VarReferences psmVarReferences(createPSMVarRef(denseSG, "P")); // R model.addCustomUpdate("PSMSetTime", "Test", {}, {0.0}, psmVarReferences); - + SetTimeShared::VarReferences psmSharedVarReferences(createPSMVarRef(denseSG, "PShared")); // R + model.addCustomUpdate("PSMSharedSetTime", "Test", + {}, {}, psmSharedVarReferences); + SetTime::VarReferences wuPreVarReferences(createWUPreVarRef(denseSG, "Pre")); // R model.addCustomUpdate("WUPreSetTime", "Test", {}, {0.0}, wuPreVarReferences); - + SetTimeShared::VarReferences wuPreSharedVarReferences(createWUPreVarRef(denseSG, "PreShared")); // R + model.addCustomUpdate("WUPreSharedSetTime", "Test", + {}, {}, wuPreSharedVarReferences); + SetTime::VarReferences wuPostVarReferences(createWUPostVarRef(sparseSG, "Post")); // R model.addCustomUpdate("WUPostSetTime", "Test", {}, {0.0}, wuPostVarReferences); - + SetTimeShared::VarReferences wuPostSharedVarReferences(createWUPostVarRef(sparseSG, "PostShared")); // R + model.addCustomUpdate("WUPostSharedSetTime", "Test", + {}, {}, wuPostSharedVarReferences); + SetTime::WUVarReferences wuDenseVarReferences(createWUVarRef(denseSG, "g")); // R model.addCustomUpdate("WUDenseSetTime", "Test", {}, {0.0}, wuDenseVarReferences); diff --git a/tests/features/custom_update/test.cc b/tests/features/custom_update/test.cc index 741b96f0b2..49e2ebfb65 100644 --- a/tests/features/custom_update/test.cc +++ b/tests/features/custom_update/test.cc @@ -39,16 +39,24 @@ TEST_F(SimTest, CustomUpdate) // Pull variables pullVNeuronSetTimeFromDevice(); pullVNeuronFromDevice(); + pullVSharedNeuronFromDevice(); + pullVCurrentSourceSetTimeFromDevice(); pullCCurrentSourceFromDevice(); + pullCSharedCurrentSourceFromDevice(); + pullVCustomUpdateSetTimeFromDevice(); pullCCustomUpdateFromDevice(); + pullVPSMSetTimeFromDevice(); pullPDenseFromDevice(); + pullPSharedDenseFromDevice(); pullVWUPreSetTimeFromDevice(); pullPreDenseFromDevice(); + pullPreSharedDenseFromDevice(); pullVWUPostSetTimeFromDevice(); pullPostSparseFromDevice(); + pullPostSharedSparseFromDevice(); pullVWUDenseSetTimeFromDevice(); pullgDenseFromDevice(); pullVWUSparseSetTimeFromDevice(); @@ -64,12 +72,14 @@ TEST_F(SimTest, CustomUpdate) [](scalar v) { return v == t; })); EXPECT_TRUE(std::all_of(&VNeuronSetTime[0], &VNeuronSetTime[100], [](scalar v) { return v == t; })); + EXPECT_EQ(VSharedNeuron[0], t); EXPECT_TRUE(std::all_of(&VCurrentSourceSetTime[0], &VCurrentSourceSetTime[100], [](scalar v) { return v == t; })); EXPECT_TRUE(std::all_of(&CCurrentSource[0], &CCurrentSource[100], [](scalar v) { return v == t; })); - + EXPECT_EQ(CSharedCurrentSource[0], t); + EXPECT_TRUE(std::all_of(&VCustomUpdateSetTime[0], &VCustomUpdateSetTime[100], [](scalar v) { return v == t; })); EXPECT_TRUE(std::all_of(&CCustomUpdate[0], &CCustomUpdate[100], @@ -79,21 +89,19 @@ TEST_F(SimTest, CustomUpdate) [](scalar v) { return v == t; })); EXPECT_TRUE(std::all_of(&PDense[0], &PDense[100], [](scalar v) { return v == t; })); + EXPECT_EQ(PSharedDense[0], t); EXPECT_TRUE(std::all_of(&VWUPreSetTime[0], &VWUPreSetTime[100], [](scalar v) { return v == t; })); EXPECT_TRUE(std::all_of(&PreDense[0], &PreDense[100], [](scalar v) { return v == t; })); + EXPECT_EQ(PreSharedDense[0], t); EXPECT_TRUE(std::all_of(&VWUPostSetTime[0], &VWUPostSetTime[100], [](scalar v) { return v == t; })); EXPECT_TRUE(std::all_of(&PostSparse[0], &PostSparse[100], [](scalar v) { return v == t; })); - - 
EXPECT_TRUE(std::all_of(&VPSMSetTime[0], &VPSMSetTime[100], - [](scalar v) { return v == t; })); - EXPECT_TRUE(std::all_of(&PDense[0], &PDense[100], - [](scalar v) { return v == t; })); + EXPECT_EQ(PostSharedSparse[0], t); EXPECT_TRUE(std::all_of(&VWUDenseSetTime[0], &VWUDenseSetTime[100 * 100], [](scalar v) { return v == t; })); From 3ee60548aaf83f5c560633d46413973779b4badc Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 16 Mar 2023 09:53:12 +0000 Subject: [PATCH 395/725] extended batch custom update test to cover shared neuron variables --- tests/features/custom_update_batch/model.cc | 25 +++++++++++++++++---- tests/features/custom_update_batch/test.cc | 2 ++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/tests/features/custom_update_batch/model.cc b/tests/features/custom_update_batch/model.cc index 911513c1bc..d093f49d12 100644 --- a/tests/features/custom_update_batch/model.cc +++ b/tests/features/custom_update_batch/model.cc @@ -12,9 +12,10 @@ suite of minimal models with known analytic outcomes that are used for continuou class TestNeuron : public NeuronModels::Base { public: - DECLARE_MODEL(TestNeuron, 0, 2); + DECLARE_MODEL(TestNeuron, 0, 3); - SET_VARS({{"V","scalar"}, {"U", "scalar", VarAccess::READ_ONLY}}); + SET_VARS({{"V","scalar"}, {"U", "scalar", VarAccess::READ_ONLY}, + {"S", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}}); }; IMPLEMENT_MODEL(TestNeuron); @@ -41,6 +42,18 @@ class SetTimeBatch : public CustomUpdateModels::Base }; IMPLEMENT_MODEL(SetTimeBatch); +class SetTimeShared : public CustomUpdateModels::Base +{ +public: + DECLARE_CUSTOM_UPDATE_MODEL(SetTimeShared, 0, 0, 1); + + SET_UPDATE_CODE( + "$(R) = ($(batch) * 1000.0) + $(t);\n"); + + SET_VAR_REFS({{"R", "scalar", VarAccessMode::READ_WRITE}}) +}; +IMPLEMENT_MODEL(SetTimeShared); + class SetTime : public CustomUpdateModels::Base { public: @@ -99,7 +112,7 @@ void modelDefinition(ModelSpec &model) model.setBatchSize(5); model.addNeuronPopulation("SpikeSource", 50, {}, {}); - auto *ng = model.addNeuronPopulation("Neuron", 50, {}, {0.0, 0.0}); + auto *ng = model.addNeuronPopulation("Neuron", 50, {}, {0.0, 0.0, 0.0}); auto *denseSG = model.addSynapsePopulation( "Dense", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, "SpikeSource", "Neuron", @@ -124,7 +137,11 @@ void modelDefinition(ModelSpec &model) SetTimeBatch::VarReferences neuronDuplicateVarReferences(createVarRef(ng, "V")); // R model.addCustomUpdate("NeuronDuplicateSetTime", "Test", {}, {0.0}, neuronDuplicateVarReferences); - + + SetTimeShared::VarReferences neuronSharedNeuronVarReferences(createVarRef(ng, "S")); // R + model.addCustomUpdate("NeuronSharedNeuronSetTime", "Test", + {}, {}, neuronSharedNeuronVarReferences); + SetTime::VarReferences neuronSharedVarReferences(createVarRef(ng, "U")); // R model.addCustomUpdate("NeuronSharedSetTime", "Test", {}, {0.0}, neuronSharedVarReferences); diff --git a/tests/features/custom_update_batch/test.cc b/tests/features/custom_update_batch/test.cc index 2286c7acd7..6948e12344 100644 --- a/tests/features/custom_update_batch/test.cc +++ b/tests/features/custom_update_batch/test.cc @@ -55,6 +55,7 @@ TEST_F(SimTest, CustomUpdateBatch) pullVWUMKernelDuplicateSetTimeFromDevice(); pullVNeuronFromDevice(); pullUNeuronFromDevice(); + pullSNeuronFromDevice(); pullVDenseFromDevice(); pullUDenseFromDevice(); pullVSparseFromDevice(); @@ -82,6 +83,7 @@ TEST_F(SimTest, CustomUpdateBatch) const unsigned int startSparseSynIdx = b * (50 * maxRowLengthSparse); const float batchOffset = b * 1000.0f; + 
ASSERT_EQ(SNeuron[b], batchOffset + t); // Check batched variables match expectations ASSERT_TRUE(std::all_of(&VNeuronDuplicateSetTime[startNeuronIdx], &VNeuronDuplicateSetTime[endNeuronIdx], [batchOffset](scalar v) { return v == (batchOffset + t); })); From 3b691ac02122b8199854e39be808fd8b4070b688 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 18:33:12 +0100 Subject: [PATCH 396/725] **CHERRYPICK** fix overallocation of threads to shared-neuron updates --- src/genn/genn/code_generator/backendSIMT.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 8556c2bd33..e9a9bee8f1 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -210,9 +210,11 @@ size_t BackendSIMT::getPaddedNumCustomUpdateThreads(const CustomUpdateInternal & if (cg.isNeuronReduction()) { return padKernelSize(32 * numCopies, KernelCustomUpdate); } + else if (cg.isPerNeuron()) { + return numCopies * padKernelSize(cg.getSize(), KernelCustomUpdate); + } else { - const size_t numElements = cg.isPerNeuron() ? cg.getSize() : 1; - return numCopies * padKernelSize(numElements, KernelCustomUpdate); + return padKernelSize(numCopies, KernelCustomUpdate); } } //-------------------------------------------------------------------------- From 86fd82a9af73502ed15bf30f26248bde92e49706 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 18:34:38 +0100 Subject: [PATCH 397/725] manually cherry pick allow kernels to be retrieved via PyGeNN ``SynapseGroup.get_var_values`` --- pygenn/genn_groups.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index e16391ff92..b5ded4dfb2 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -464,6 +464,8 @@ def get_var_values(self, var_name): if self.matrix_type & SynapseMatrixConnectivity.DENSE: return np.copy(var_view) + elif self.matrix_type & SynapseMatrixConnectivity.KERNEL: + return np.copy(var_view) elif self.matrix_type & SynapseMatrixConnectivity.SPARSE: max_rl = self.max_row_length row_ls = self._row_lengths if self._connectivity_initialiser_provided else self.row_lengths From c19c940de15de6ae1de17c0078565f11389c454c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 11 Jul 2023 18:36:41 +0100 Subject: [PATCH 398/725] On Windows, call ``os.add_dll_directory`` before trying to load CUDA backend # Conflicts: # pygenn/genn_model.py --- pygenn/genn_model.py | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index e3798e03a7..ab81fb589b 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -48,6 +48,7 @@ from psutil import cpu_count from setuptools import msvc from subprocess import check_call # to call make +import sys from textwrap import dedent from warnings import warn from weakref import proxy @@ -78,23 +79,6 @@ init_toeplitz_connectivity_snippets, init_var_snippets, neuron_models, postsynaptic_models, weight_update_models) -# Loop through backends in preferential order -backend_modules = OrderedDict() -for b in ["cuda", "single_threaded_cpu", "opencl"]: - # Try and import - try: - m = import_module("." 
+ b + "_backend", "pygenn") - # Ignore failed imports - likely due to non-supported backends - except ImportError as ex: - pass - # Raise any other errors - except: - raise - # Otherwise add to (ordered) dictionary - else: - backend_modules[b] = m - - # Dynamically add Python mixin to wrapped class CurrentSource.__bases__ += (CurrentSourceMixin,) CustomUpdate.__bases__ += (CustomUpdateMixin,) @@ -120,6 +104,30 @@ # **NOTE** shutil.which would be nicer, but isn't in Python < 3.3 _msbuild = find_executable("msbuild", _msvc_env["PATH"]) + # If Python version is newer than 3.8 and CUDA path is in environment + if sys.version_info >= (3, 8) and "CUDA_PATH" in environ: + # Add CUDA bin directory to DLL search directories + from os import add_dll_directory + add_dll_directory(path.join(environ["CUDA_PATH"], "bin")) + + +# Loop through backends in preferential order +backend_modules = OrderedDict() +for b in ["cuda", "single_threaded_cpu", "opencl"]: + # Try and import + try: + m = import_module("." + b + "_backend", "pygenn") + # Ignore failed imports - likely due to non-supported backends + except ImportError as ex: + pass + # Raise any other errors + except: + raise + # Otherwise add to (ordered) dictionary + else: + backend_modules[b] = m + + GeNNType = namedtuple("GeNNType", ["np_dtype", "assign_ext_ptr_array", "assign_ext_ptr_single"]) class GeNNModel(ModelSpecInternal): From 347e725bdcc25f3949f58f99050d1a248cd530f0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 17 Jul 2023 17:56:27 +0100 Subject: [PATCH 399/725] removed some crud from groupMerged.h --- include/genn/genn/code_generator/groupMerged.h | 8 -------- 1 file changed, 8 deletions(-) diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 3a9a61c3e4..77cbea89bd 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -19,20 +19,12 @@ #include "code_generator/backendBase.h" #include "code_generator/codeGenUtils.h" -// GeNN transpiler includes -#include "transpiler/statement.h" -#include "transpiler/typeChecker.h" - // Forward declarations namespace GeNN::CodeGenerator { class CodeStream; } -namespace GeNN::Transpiler::TypeChecker -{ -class EnvironmentBase; -} //------------------------------------------------------------------------ // GeNN::CodeGenerator::GroupMergedFieldType From 14c662e8290c402ab7408dcd0c498ceed7dab461 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 17 Jul 2023 17:56:50 +0100 Subject: [PATCH 400/725] fixed GCC warnings --- include/genn/genn/code_generator/customUpdateGroupMerged.h | 6 +++--- include/genn/genn/code_generator/environment.h | 6 +++--- include/genn/genn/code_generator/initGroupMerged.h | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/include/genn/genn/code_generator/customUpdateGroupMerged.h b/include/genn/genn/code_generator/customUpdateGroupMerged.h index 1c5c111079..37afe86d7a 100644 --- a/include/genn/genn/code_generator/customUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customUpdateGroupMerged.h @@ -137,7 +137,7 @@ template class CustomUpdateHostReductionGroupMergedBase : public GroupMerged { protected: - using GroupMerged::GroupMerged; + using GroupMerged::GroupMerged; template void generateCustomUpdateBase(const BackendBase &backend, EnvironmentGroupMergedField &env) @@ -146,7 +146,7 @@ class CustomUpdateHostReductionGroupMergedBase : public GroupMerged const auto *cm = this->getArchetype().getCustomUpdateModel(); for(const 
auto &v : cm->getVars()) { if(v.access & VarAccessModeAttribute::REDUCE) { - const auto fieldType = v.type.resolve(getTypeContext()).createPointer(); + const auto fieldType = v.type.resolve(this->getTypeContext()).createPointer(); env.addField(fieldType, v.name, v.name, [&backend, v](const auto &g, size_t) { @@ -158,7 +158,7 @@ class CustomUpdateHostReductionGroupMergedBase : public GroupMerged // Loop through variable references and add pointers if they are reduction targets for(const auto &v : cm->getVarRefs()) { if(v.access & VarAccessModeAttribute::REDUCE) { - const auto fieldType = v.type.resolve(getTypeContext()).createPointer(); + const auto fieldType = v.type.resolve(this->getTypeContext()).createPointer(); env.addField(fieldType, v.name, v.name, [&backend, v](const auto &g, size_t) { diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 64702d6e18..115ffa96eb 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -288,9 +288,9 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P } // Perform any type-specific logic to mark this identifier as required - setRequired(std::get<2>(env->second)); + this->setRequired(std::get<2>(env->second)); - return getNameInternal(std::get<2>(env->second)); + return this->getNameInternal(std::get<2>(env->second)); } } @@ -314,7 +314,7 @@ class EnvironmentExternalDynamicBase : public EnvironmentExternalBase, public P } // Perform any type-specific logic to mark this identifier as required - setRequired(std::get<2>(env->second)); + this->setRequired(std::get<2>(env->second)); // Return type of variables return {std::get<0>(env->second)}; diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index fceec9bdd3..948aa6fe40 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -461,7 +461,7 @@ class GENN_EXPORT CustomConnectivityUpdatePreInitGroupMerged : public InitGroupM CustomConnectivityUpdatePreVarAdapter> { public: - InitGroupMergedBase::InitGroupMergedBase; + using InitGroupMergedBase::InitGroupMergedBase; //---------------------------------------------------------------------------- // Public API From 7af6e43dce13e075e07484dba73968124c60e7ae Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 09:29:29 +0100 Subject: [PATCH 401/725] removed all trace of groupMergedTypeEnvironment --- .../groupMergedTypeEnvironment.h | 272 ------------------ .../code_generator/customUpdateGroupMerged.cc | 1 - .../genn/code_generator/initGroupMerged.cc | 1 - .../code_generator/neuronUpdateGroupMerged.cc | 1 - src/genn/genn/genn.vcxproj | 1 - 5 files changed, 276 deletions(-) delete mode 100644 include/genn/genn/code_generator/groupMergedTypeEnvironment.h diff --git a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h b/include/genn/genn/code_generator/groupMergedTypeEnvironment.h deleted file mode 100644 index 15b779d150..0000000000 --- a/include/genn/genn/code_generator/groupMergedTypeEnvironment.h +++ /dev/null @@ -1,272 +0,0 @@ -#pragma once - -// Standard C++ includes -#include - -// GeNN code generator includes -#include "code_generator/groupMerged.h" - -// GeNN transpiler includes -#include "transpiler/errorHandler.h" -#include "transpiler/typeChecker.h" - -//---------------------------------------------------------------------------- -// 
GeNN::CodeGenerator::GroupMergedTypeEnvironment -//---------------------------------------------------------------------------- -namespace GeNN::CodeGenerator -{ -template -class GroupMergedTypeEnvironment : public Transpiler::TypeChecker::EnvironmentBase -{ - using Token = Transpiler::Token; - using ErrorHandlerBase = Transpiler::ErrorHandlerBase; - using EnvironmentBase = Transpiler::TypeChecker::EnvironmentBase; - using TypeCheckError = Transpiler::TypeChecker::TypeCheckError; - - using IsHeterogeneousFn = bool (G::*)(const std::string&) const; - using IsVarInitHeterogeneousFn = bool (G::*)(const std::string&, const std::string&) const; - - using GroupInternal = typename G::GroupInternal; - using GetVarSuffixFn = const std::string &(GroupInternal::*)(void) const; - using GetParamValuesFn = const std::unordered_map &(GroupInternal::*)(void) const; - - template - using GetVarReferencesFn = const std::unordered_map &(GroupInternal::*)(void) const; - -public: - GroupMergedTypeEnvironment(G &groupMerged, EnvironmentBase *enclosing = nullptr) - : m_GroupMerged(groupMerged), m_Enclosing(enclosing) - { - } - - //--------------------------------------------------------------------------- - // EnvironmentBase virtuals - //--------------------------------------------------------------------------- - virtual void define(const Transpiler::Token &name, const Type::ResolvedType&, ErrorHandlerBase &errorHandler) final - { - errorHandler.error(name, "Cannot declare variable in external environment"); - throw TypeCheckError(); - } - - virtual std::vector getTypes(const Token &name, ErrorHandlerBase &errorHandler) final - { - auto type = m_Types.find(name.lexeme); - if(type == m_Types.end()) { - if(m_Enclosing) { - return m_Enclosing->getTypes(name, errorHandler); - } - else { - errorHandler.error(name, "Undefined identifier"); - throw TypeCheckError(); - } - } - else { - // Add field to merged group if required - addField(type->second); - - return {type->second.first}; - } - } - - //--------------------------------------------------------------------------- - // Public API - //--------------------------------------------------------------------------- - void defineField(const Type::ResolvedType &type, const std::string &name) - { - if(!m_Types.try_emplace(name, type, std::nullopt).second) { - throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); - } - } - - - void defineField(const Type::ResolvedType &type, const std::string &name, - const Type::ResolvedType &fieldType, std::string_view fieldName, typename G::GetFieldValueFunc getFieldValue, - GroupMergedFieldType mergedFieldType = GroupMergedFieldType::STANDARD) - { - if(!m_Types.try_emplace(name, std::piecewise_construct, - std::forward_as_tuple(type), - std::forward_as_tuple(std::in_place, fieldType, fieldName, getFieldValue, mergedFieldType)).second) - { - throw std::runtime_error("Redeclaration of '" + std::string{name} + "'"); - } - } - - void definePointerField(const Type::ResolvedType &type, const std::string &name, const std::string &prefix, VarAccessMode access, - const std::string &fieldSuffix = "", GetVarSuffixFn getVarSuffixFn = &GroupInternal::getName) - { - const auto qualifiedType = (access & VarAccessModeAttribute::READ_ONLY) ? 
type.addQualifier(Type::Qualifier::CONSTANT) : type; - defineField(qualifiedType, name, - type.createPointer(), name + fieldSuffix, - [name, prefix, getVarSuffixFn](const auto &g, size_t) { return prefix + name + std::invoke(getVarSuffixFn, g); }); - } - - void definePointerField(const Type::UnresolvedType &type, const std::string &name, const std::string &prefix, VarAccessMode access, - const std::string &fieldSuffix = "", GetVarSuffixFn getVarSuffixFn = &GroupInternal::getName) - { - definePointerField(type.resolve(m_GroupMerged.getTypeContext()), name, prefix, access, fieldSuffix, getVarSuffixFn); - } - - void defineScalarField(const std::string &name, const std::string &fieldSuffix, typename G::GetFieldDoubleValueFunc getFieldValue) - { - defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), name, - m_GroupMerged.getScalarType(), name + fieldSuffix, - [getFieldValue, this](const auto &g, size_t i) - { - return (Utils::writePreciseString(getFieldValue(g, i), m_GroupMerged.getScalarType().getNumeric().maxDigits10) - + m_GroupMerged.getScalarType().getNumeric().literalSuffix); - }); - } - - void defineHeterogeneousParams(const Snippet::Base::StringVec ¶mNames, const std::string &fieldSuffix, - GetParamValuesFn getParamValues, IsHeterogeneousFn isHeterogeneous) - { - // Loop through params - for(const auto &p : paramNames) { - if (std::invoke(isHeterogeneous, m_GroupMerged, p)) { - defineScalarField(p, fieldSuffix, - [p, getParamValues](const auto &g, size_t) - { - return std::invoke(getParamValues, g).at(p); - }); - } - // Otherwise, just add a const-qualified scalar to the type environment - else { - defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), p); - } - } - } - - void defineHeterogeneousDerivedParams(const Snippet::Base::DerivedParamVec &derivedParams, const std::string &fieldSuffix, - GetParamValuesFn getDerivedParamValues, IsHeterogeneousFn isHeterogeneous) - { - // Loop through derived params - for(const auto &d : derivedParams) { - if (std::invoke(isHeterogeneous, m_GroupMerged, d.name)) { - defineScalarField(d.name, fieldSuffix, - [d, getDerivedParamValues](const auto &g, size_t) - { - return std::invoke(getDerivedParamValues, g).at(d.name); - }); - } - else { - defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), d.name); - } - } - } - - template - void defineHeterogeneousVarInitParams(IsVarInitHeterogeneousFn isHeterogeneous, const std::string &fieldSuffix = "") - { - // Loop through weight update model variables - const A archetypeAdaptor(m_GroupMerged.getArchetype()); - for(const auto &v : archetypeAdaptor.getDefs()) { - // Loop through parameters - for(const auto &p : archetypeAdaptor.getInitialisers().at(v.name).getParams()) { - if(std::invoke(isHeterogeneous, m_GroupMerged, v.name, p.first)) { - defineScalarField(p.first, v.name + fieldSuffix, - [p, v](const auto &g, size_t) - { - return A(g).getInitialisers().at(v.name).getParams().at(p.first); - }); - } - else { - defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), p.first + v.name); - } - } - } - } - - template - void defineHeterogeneousVarInitDerivedParams(IsVarInitHeterogeneousFn isHeterogeneous, const std::string &fieldSuffix = "") - { - // Loop through weight update model variables - const A archetypeAdaptor(m_GroupMerged.getArchetype()); - for(const auto &v : archetypeAdaptor.getDefs()) { - // Loop through parameters - for(const auto &p : 
archetypeAdaptor.getInitialisers().at(v.name).getDerivedParams()) { - if(std::invoke(isHeterogeneous, m_GroupMerged, v.name, p.first)) { - defineScalarField(p.first, v.name + fieldSuffix, - [p, v](const auto &g, size_t) - { - return A(g).getInitialisers().at(v.name).getDerivedParams().at(p.first); - }); - } - else { - defineField(m_GroupMerged.getScalarType().addQualifier(Type::Qualifier::CONSTANT), p.first + v.name); - } - } - } - } - - void defineVars(const Models::Base::VarVec &vars, const std::string &arrayPrefix, - const std::string &fieldSuffix = "", GetVarSuffixFn getVarSuffixFn = &GroupInternal::getName) - { - // Loop through variables - for(const auto &v : vars) { - definePointerField(v.type, v.name, arrayPrefix, getVarAccessMode(v.access), - fieldSuffix, getVarSuffixFn); - } - } - - template - void defineVarReferences(const Models::Base::VarRefVec &varReferences, const std::string &arrayPrefix, - const std::string &fieldSuffix = "", GetVarReferencesFn getVarRefFn = &GroupInternal::getVarReferences) - { - // Loop through variables - for(const auto &v : varReferences) { - // If variable access is read-only, qualify type with const - const auto resolvedType = v.type.resolve(m_GroupMerged.getTypeContext()); - const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addQualifier(Type::Qualifier::CONSTANT) : resolvedType; - defineField(qualifiedType, v.name, - resolvedType.createPointer(), v.name + fieldSuffix, - [arrayPrefix, getVarRefFn, v](const auto &g, size_t) - { - const auto varRef = std::invoke(getVarRefFn, g).at(v.name); - return arrayPrefix + varRef.getVar().name + varRef.getTargetName(); - }); - } - } - - void defineEGPs(const Snippet::Base::EGPVec &egps, const std::string &arrayPrefix, const std::string &varName = "", - const std::string &fieldSuffix = "", GetVarSuffixFn getVarSuffixFn = &GroupInternal::getName) - { - for(const auto &e : egps) { - const auto pointerType = e.type.resolve(m_GroupMerged.getTypeContext()).createPointer(); - defineField(pointerType, e.name, - pointerType, e.name + varName + fieldSuffix, - [arrayPrefix, e, varName, getVarSuffixFn](const auto &g, size_t) - { - return arrayPrefix + e.name + varName + std::invoke(getVarSuffixFn, g); - }, - GroupMergedFieldType::DYNAMIC); - } - } - -private: - //--------------------------------------------------------------------------- - // Private methods - //--------------------------------------------------------------------------- - void addField(std::pair> &type) - { - // If this type has an associated field - if (type.second) { - // Call function to add field to underlying merge group - // **THINK** std::apply should work here but doesn't seem to - /*std::apply(&G::addField, std::tuple_cat(std::make_tuple(m_GroupMerged), - *type.second));*/ - m_GroupMerged.addField(std::get<0>(*type.second), std::get<1>(*type.second), - std::get<2>(*type.second), std::get<3>(*type.second)); - - // Reset optional field so it doesn't get added again - type.second.reset(); - } - } - //--------------------------------------------------------------------------- - // Members - //--------------------------------------------------------------------------- - G &m_GroupMerged; - EnvironmentBase *m_Enclosing; - - std::unordered_map>> m_Types; -}; -} // namespace GeNN::CodeGenerator diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 1daa1c0cbd..3b33c4cb0e 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ 
b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -5,7 +5,6 @@ // GeNN code generator includes #include "code_generator/environment.h" -#include "code_generator/groupMergedTypeEnvironment.h" #include "code_generator/modelSpecMerged.h" // GeNN transpiler includes diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 058961550b..bfeefdad27 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -1,7 +1,6 @@ #include "code_generator/initGroupMerged.h" // GeNN code generator includes -#include "code_generator/groupMergedTypeEnvironment.h" #include "code_generator/modelSpecMerged.h" #include "code_generator/standardLibrary.h" diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 2612d3ed71..d07c366ad8 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -1,7 +1,6 @@ #include "code_generator/neuronUpdateGroupMerged.h" // GeNN code generator includes -#include "code_generator/groupMergedTypeEnvironment.h" #include "code_generator/modelSpecMerged.h" // GeNN transpiler includes diff --git a/src/genn/genn/genn.vcxproj b/src/genn/genn/genn.vcxproj index 66b9642e74..70cb4851e3 100644 --- a/src/genn/genn/genn.vcxproj +++ b/src/genn/genn/genn.vcxproj @@ -66,7 +66,6 @@ - From 46647b48dd7d4c47406887fe4c53467a465f5014 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 10:26:09 +0100 Subject: [PATCH 402/725] Fixed a circular dependency caused by the templated ``genXXXIndexCalculation`` methods in ``BackendBase``: moved the templated code into the anonymous namespace of backendBase.cc and exposed it only for the sensible concrete types, so the header now needs only forward declarations --- .../genn/genn/code_generator/backendBase.h | 275 +-------------- .../backends/single_threaded_cpu/backend.cc | 28 +- src/genn/genn/code_generator/backendBase.cc | 329 +++++++++++++++++- src/genn/genn/code_generator/backendSIMT.cc | 30 +- 4 files changed, 371 insertions(+), 291 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 3f83c4d8e3..ed578e78be 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -44,9 +44,12 @@ class ModelSpecMerged; template class GroupMerged; class NeuronUpdateGroupMerged; +class NeuronPrevSpikeTimeUpdateGroupMerged; +class NeuronSpikeQueueUpdateGroupMerged; class PresynapticUpdateGroupMerged; class PostsynapticUpdateGroupMerged; class SynapseDynamicsGroupMerged; +class SynapseDendriticDelayUpdateGroupMerged; class CustomConnectivityUpdateGroupMerged; class CustomUpdateGroupMerged; class CustomUpdateWUGroupMerged; @@ -470,265 +473,21 @@ class GENN_EXPORT BackendBase bool areSixtyFourBitSynapseIndicesRequired(const GroupMerged &sg) const; - template - void genNeuronIndexCalculation(EnvironmentGroupMergedField &env, unsigned int batchSize) const - { - env.addField(Type::Uint32.addConst(), "num_neurons", - Type::Uint32, "numNeurons", - [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); - env.addField(Type::Uint32.createPointer(), "_spk_cnt", "spkCnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getName(); }); - env.addField(Type::Uint32.createPointer(), "_spk", "spk", - [this](const auto &g, size_t) { return getDeviceVarPrefix() +
"glbSpk" + g.getName(); }); - env.addField(Type::Uint32.createPointer(), "_spk_cnt_evnt", "spkCntEvnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCntEvnt" + g.getName(); }); - env.addField(Type::Uint32.createPointer(), "_spk_evnt", "spkEvnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkEvnt" + g.getName(); }); - env.addField(Type::Uint32.createPointer(), "_spk_que_ptr", "spkQuePtr", - [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getName(); }); - - env.addField(env.getGroup().getTimeType().createPointer(), "_spk_time", "sT", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "sT" + g.getName(); }); - env.addField(env.getGroup().getTimeType().createPointer(), "_spk_evnt_time", "seT", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "seT" + g.getName(); }); - env.addField(env.getGroup().getTimeType().createPointer(), "_prev_spk_time", "prevST", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "prevST" + g.getName(); }); - env.addField(env.getGroup().getTimeType().createPointer(), "_prev_spk_evnt_time", "prevSET", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "prevSET" + g.getName(); }); - - env.addField(Type::Uint32.createPointer(), "_record_spk", "recordSpk", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "recordSpk" + g.getName(); }, - "", GroupMergedFieldType::DYNAMIC); - env.addField(Type::Uint32.createPointer(), "_record_spk_event", "recordSpkEvent", - [this](const auto &g, size_t){ return getDeviceVarPrefix() + "recordSpkEvent" + g.getName(); }, - "", GroupMergedFieldType::DYNAMIC); - - // If batching is enabled, calculate batch offset - if(batchSize > 1) { - env.add(Type::Uint32.addConst(), "_batchOffset", "batchOffset", - {env.addInitialiser("const unsigned int batchOffset = $(num_neurons) * $(batch);")}); - } - - // If axonal delays are required - if(env.getGroup().getArchetype().isDelayRequired()) { - // We should READ from delay slot before spkQuePtr - const unsigned int numDelaySlots = env.getGroup().getArchetype().getNumDelaySlots(); - const std::string numDelaySlotsStr = std::to_string(numDelaySlots); - env.add(Type::Uint32.addConst(), "_read_delay_slot", "readDelaySlot", - {env.addInitialiser("const unsigned int readDelaySlot = (*$(_spk_que_ptr) + " + std::to_string(numDelaySlots - 1) + ") % " + numDelaySlotsStr+ ";")}); - env.add(Type::Uint32.addConst(), "_read_delay_offset", "readDelayOffset", - {env.addInitialiser("const unsigned int readDelayOffset = $(_read_delay_slot) * $(num_neurons);")}); - - // And we should WRITE to delay slot pointed to be spkQuePtr - env.add(Type::Uint32.addConst(), "_write_delay_slot", "writeDelaySlot", - {env.addInitialiser("const unsigned int writeDelaySlot = * $(_spk_que_ptr);")}); - env.add(Type::Uint32.addConst(), "_write_delay_offset", "writeDelayOffset", - {env.addInitialiser("const unsigned int writeDelayOffset = $(_write_delay_slot) * $(num_neurons);")}); - - // If batching is also enabled - if(batchSize > 1) { - // Calculate batched delay slots - env.add(Type::Uint32.addConst(), "_read_batch_delay_slot", "readBatchDelaySlot", - {env.addInitialiser("const unsigned int readBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_read_delay_slot);")}); - env.add(Type::Uint32.addConst(), "_write_batch_delay_slot", "writeBatchDelaySlot", - {env.addInitialiser("const unsigned int writeBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_write_delay_slot);")}); - - // 
Calculate current batch offset - env.add(Type::Uint32.addConst(), "_batch_delay_offset", "batchDelayOffset", - {env.addInitialiser("const unsigned int batchDelayOffset = $(_batch_offset) * " + numDelaySlotsStr + ";")}); - - // Calculate further offsets to include delay and batch - env.add(Type::Uint32.addConst(), "_read_batch_delay_offset", "readBatchDelayOffset", - {env.addInitialiser("const unsigned int readBatchDelayOffset = $(_read_delay_offset) + $(_batch_delay_offset);")}); - env.add(Type::Uint32.addConst(), "_write_batch_delay_offset", "writeBatchDelayOffset", - {env.addInitialiser("const unsigned int writeBatchDelayOffset = $(_write_delay_offset)+ $(_batch_delay_offset);")}); - } - } - } - - template - void genSynapseIndexCalculation(EnvironmentGroupMergedField &env, unsigned int batchSize) const - { - // Synapse group fields - env.addField(Type::Uint32.addConst(), "num_pre", - Type::Uint32, "numSrcNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); - env.addField(Type::Uint32.addConst(), "num_post", - Type::Uint32, "numTrgNeurons", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); }); - env.addField(Type::Uint32, "_row_stride", "rowStride", - [this](const SynapseGroupInternal &sg, size_t) { return std::to_string(getSynapticMatrixRowStride(sg)); }); - env.addField(Type::Uint32, "_col_stride", "colStride", - [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getMaxSourceConnections()); }); - - // Postsynaptic model fields - env.addField(env.getGroup().getScalarType().createPointer(), "_out_post", "outPost", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); - env.addField(env.getGroup().getScalarType().createPointer(), "_den_delay", "denDelay", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); - env.addField(Type::Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr", - [this](const auto &g, size_t) { return getScalarAddressPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); - - // Presynaptic output fields - env.addField(env.getGroup().getScalarType().createPointer(), "_out_pre", "outPre", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); - - // Source neuron fields - env.addField(Type::Uint32.createPointer(), "_src_spk_que_ptr", "srcSpkQuePtr", - [this](const auto &g, size_t) { return getScalarAddressPrefix() + "spkQuePtr" + g.getSrcNeuronGroup()->getName(); }); - env.addField(Type::Uint32.createPointer(), "_src_spk_cnt", "srcSpkCnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getSrcNeuronGroup()->getName(); }); - env.addField(Type::Uint32.createPointer(), "_src_spk", "srcSpk", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getSrcNeuronGroup()->getName(); }); - env.addField(Type::Uint32.createPointer(), "_src_spk_evnt_cnt", "srcSpkCntEvnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCntEvnt" + g.getSrcNeuronGroup()->getName(); }); - env.addField(Type::Uint32.createPointer(), "_src_spk_evnt", "srcSpkEvnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkEvnt" + g.getSrcNeuronGroup()->getName(); }); - - // Target neuron fields - env.addField(Type::Uint32.createPointer(), "_trg_spk_que_ptr", "trgSpkQuePtr", - [this](const auto &g, size_t) { return 
getScalarAddressPrefix() + "spkQuePtr" + g.getTrgNeuronGroup()->getName(); }); - env.addField(Type::Uint32.createPointer(), "_trg_spk_cnt", "trgSpkCnt", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpkCnt" + g.getTrgNeuronGroup()->getName(); }); - env.addField(Type::Uint32.createPointer(), "_trg_spk", "trgSpk", - [this](const auto &g, size_t) { return getDeviceVarPrefix() + "glbSpk" + g.getTrgNeuronGroup()->getName(); }); - - // Connectivity fields - if(env.getGroup().getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { - env.addField(Type::Uint32.createPointer(), "_gp", "gp", - [this](const auto &sg, size_t) { return getDeviceVarPrefix() + "gp" + sg.getName(); }); - } - else if(env.getGroup().getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - env.addField(Type::Uint32.createPointer(), "_row_length", "rowLength", - [this](const auto &sg, size_t) { return getDeviceVarPrefix() + "rowLength" + sg.getName(); }); - env.addField(env.getGroup().getArchetype().getSparseIndType().createPointer(), "_ind", "ind", - [this](const auto &sg, size_t) { return getDeviceVarPrefix() + "ind" + sg.getName(); }); - env.addField(Type::Uint32.createPointer(), "_col_length", "colLength", - [this](const auto &sg, size_t) { return getDeviceVarPrefix() + "colLength" + sg.getName(); }); - env.addField(Type::Uint32.createPointer(), "_remap", "remap", - [this](const auto &sg, size_t) { return getDeviceVarPrefix() + "remap" + sg.getName(); }); - } - - // If batching is enabled - if(batchSize > 1) { - // Calculate batch offsets into pre and postsynaptic populations - env.add(Type::Uint32.addConst(), "_pre_batch_offset", "preBatchOffset", - {env.addInitialiser("const unsigned int preBatchOffset = $(num_pre) * $(batch);")}); - env.add(Type::Uint32.addConst(), "_post_batch_offset", "postBatchOffset", - {env.addInitialiser("const unsigned int preBatchOffset = $(num_post) * $(batch);")}); - - // Calculate batch offsets into synapse arrays, using 64-bit arithmetic if necessary - if(areSixtyFourBitSynapseIndicesRequired(env.getGroup())) { - assert(false); - //os << "const uint64_t synBatchOffset = (uint64_t)preBatchOffset * (uint64_t)group->rowStride;" << std::endl; - } - else { - env.add(Type::Uint32.addConst(), "_syn_batch_offset", "synBatchOffset", - {env.addInitialiser("const unsigned int synBatchOffset = $(_pre_batch_offset) * $(_row_stride);")}); - } - - // If synapse group has kernel - const auto &kernelSize = env.getGroup().getArchetype().getKernelSize(); - if(!kernelSize.empty()) { - // Loop through kernel dimensions and multiply together - // **TODO** extract list of kernel size variables referenced - std::ostringstream kernBatchOffsetInit; - kernBatchOffsetInit << "const unsigned int kernBatchOffset = "; - for(size_t i = 0; i < kernelSize.size(); i++) { - kernBatchOffsetInit << getKernelSize(env.getGroup(), i) << " * "; - } - - // And finally by batch - kernBatchOffsetInit << "$(batch);" << std::endl; - - env.add(Type::Uint32.addConst(), "_kern_batch_offset", "kernBatchOffset", - {env.addInitialiser(kernBatchOffsetInit.str())}); - } - } - - // If presynaptic neuron group has variable queues, calculate offset to read from its variables with axonal delay - if(env.getGroup().getArchetype().getSrcNeuronGroup()->isDelayRequired()) { - const unsigned int numDelaySteps = env.getGroup().getArchetype().getDelaySteps(); - const unsigned int numSrcDelaySlots = env.getGroup().getArchetype().getSrcNeuronGroup()->getNumDelaySlots(); - - std::ostringstream 
preDelaySlotInit; - preDelaySlotInit << "const unsigned int preDelaySlot = "; - if(numDelaySteps == 0) { - preDelaySlotInit << "*$(_src_spk_que_ptr);" << std::endl; - } - else { - preDelaySlotInit << "(*$(_src_spk_que_ptr) + " << (numSrcDelaySlots - numDelaySteps) << ") % " << numSrcDelaySlots << ";" << std::endl; - } - env.add(Type::Uint32, "_pre_delay_slot", "preDelaySlot", - {env.addInitialiser(preDelaySlotInit.str())}); - - env.add(Type::Uint32, "_pre_delay_offset", "preDelayOffset", - {env.addInitialiser("const unsigned int preDelayOffset = $(_pre_delay_slot) * $(num_pre);")}); + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; + + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; - if(batchSize > 1) { - env.add(Type::Uint32, "_pre_batch_delay_slot", "preBatchDelaySlot", - {env.addInitialiser("const unsigned int preBatchDelaySlot = $(_pre_delay_slot) + ($(batch) * " + std::to_string(numSrcDelaySlots) + ");")}); - env.add(Type::Uint32, "_pre_batch_delay_offset", "preBatchDelayOffset", - {env.addInitialiser("const unsigned int preBatchDelayOffset = $(_pre_delay_offset) + ($(_pre_batch_offset) * " + std::to_string(numSrcDelaySlots) + ");")}); - } - - if(env.getGroup().getArchetype().getWUModel()->isPrevPreSpikeTimeRequired() - || env.getGroup().getArchetype().getWUModel()->isPrevPreSpikeEventTimeRequired()) - { - env.add(Type::Uint32, "_pre_prev_spike_time_delay_offset", "prePrevSpikeTimeDelayOffset", - {env.addInitialiser("const unsigned int prePrevSpikeTimeDelayOffset = ((*$(_src_spk_que_ptr) + " - + std::to_string(numSrcDelaySlots - numDelaySteps - 1) + ") % " + std::to_string(numSrcDelaySlots) + ") * $(num_pre);")}); - - if(batchSize > 1) { - env.add(Type::Uint32, "_pre_prev_spike_time_batch_delay_offset", "prePrevSpikeTimeBatchDelayOffset", - {env.addInitialiser("const unsigned int prePrevSpikeTimeBatchDelayOffset = $(_pre_prev_spike_time_delay_offset) + ($(_pre_batch_offset) * " + std::to_string(numSrcDelaySlots) + ");")}); - } - } - } - - // If postsynaptic neuron group has variable queues, calculate offset to read from its variables at current time - if(env.getGroup().getArchetype().getTrgNeuronGroup()->isDelayRequired()) { - const unsigned int numBackPropDelaySteps = env.getGroup().getArchetype().getBackPropDelaySteps(); - const unsigned int numTrgDelaySlots = env.getGroup().getArchetype().getTrgNeuronGroup()->getNumDelaySlots(); - - std::ostringstream postDelaySlotInit; - postDelaySlotInit 
<< "const unsigned int postDelaySlot = "; - if(numBackPropDelaySteps == 0) { - postDelaySlotInit << "*$(_trg_spk_que_ptr);" << std::endl; - } - else { - postDelaySlotInit << "(*$(_trg_spk_que_ptr) + " << (numTrgDelaySlots - numBackPropDelaySteps) << ") % " << numTrgDelaySlots << ";" << std::endl; - } - env.add(Type::Uint32, "_post_delay_slot", "postDelaySlot", - {env.addInitialiser(postDelaySlotInit.str())}); - - env.add(Type::Uint32, "_post_delay_offset", "postDelayOffset", - {env.addInitialiser("const unsigned int postDelayOffset = $(_post_delay_slot) * $(num_post);")}); - - if(batchSize > 1) { - env.add(Type::Uint32, "_post_batch_delay_slot", "postBatchDelaySlot", - {env.addInitialiser("const unsigned int postBatchDelaySlot =$(_post_delay_slot) + (batch * " + std::to_string(numTrgDelaySlots) + ");")}); - env.add(Type::Uint32, "_post_batch_delay_offset", "postBatchDelayOffset", - {env.addInitialiser("const unsigned int postBatchDelayOffset = $(_post_delay_offset) + ($(_post_batch_offset) * " + std::to_string(numTrgDelaySlots) + ");")}); - } - - if(env.getGroup().getArchetype().getWUModel()->isPrevPostSpikeTimeRequired()) { - env.add(Type::Uint32, "_post_prev_spike_time_delay_offset", "postPrevSpikeTimeDelayOffset", - {env.addInitialiser("const unsigned int postPrevSpikeTimeDelayOffset = ((*$(_trg_spk_que_ptr) + " - + std::to_string(numTrgDelaySlots - numBackPropDelaySteps - 1) + ") % " + std::to_string(numTrgDelaySlots) + ") * $(num_post);")}); - - if(batchSize > 1) { - env.add(Type::Uint32, "_post_prev_spike_time_batch_delay_offset", "postPrevSpikeTimeBatchDelayOffset", - {env.addInitialiser("const unsigned int postPrevSpikeTimeBatchDelayOffset = $(_post_prev_spike_time_delay_offset) + ($(_post_batch_offset) * " + std::to_string(numTrgDelaySlots) + ");")}); - } - - } - } - } - void genCustomUpdateIndexCalculation(EnvironmentGroupMergedField &env) const; - - void genCustomConnectivityUpdateIndexCalculation(EnvironmentGroupMergedField &env) const; - //! 
Get backend-specific pointer size in bytes size_t getPointerBytes() const{ return m_PointerBytes; } diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 63a30e814c..eac43679b3 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -142,7 +142,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, n); - genNeuronIndexCalculation(groupEnv, 1); + buildStandardEnvironment(groupEnv, 1); if(n.getArchetype().isDelayRequired()) { if(n.getArchetype().isPrevSpikeTimeRequired()) { @@ -197,7 +197,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back // Get reference to group funcEnv.getStream() << "const auto *group = &mergedNeuronSpikeQueueUpdateGroup" << n.getIndex() << "[g]; " << std::endl; EnvironmentGroupMergedField groupEnv(funcEnv, n); - genNeuronIndexCalculation(groupEnv, 1); + buildStandardEnvironment(groupEnv, 1); // Generate spike count reset n.genMergedGroupSpikeCountReset(groupEnv, 1); @@ -218,7 +218,7 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back // Get reference to group funcEnv.getStream() << "const auto *group = &mergedNeuronUpdateGroup" << n.getIndex() << "[g]; " << std::endl; EnvironmentGroupMergedField groupEnv(funcEnv, n); - genNeuronIndexCalculation(groupEnv, 1); + buildStandardEnvironment(groupEnv, 1); // If spike or spike-like event recording is in use if(n.getArchetype().isSpikeRecordingEnabled() || n.getArchetype().isSpikeEventRecordingEnabled()) { @@ -327,9 +327,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Bac // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, s); - - // **TODO** rename as it does more! 
- genSynapseIndexCalculation(groupEnv, 1); + buildStandardEnvironment(groupEnv, 1); // Loop through presynaptic neurons groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["num_pre"] << "; i++)"; @@ -403,8 +401,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Bac // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, s); - - genSynapseIndexCalculation(groupEnv, 1); + buildStandardEnvironment(groupEnv, 1); // generate the code for processing spike-like events if (s.getArchetype().isSpikeEventRequired()) { @@ -438,8 +435,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Bac // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, s); - - genSynapseIndexCalculation(groupEnv, 1); + buildStandardEnvironment(groupEnv, 1); // Get number of postsynaptic spikes if (s.getArchetype().getTrgNeuronGroup()->isDelayRequired() && s.getArchetype().getTrgNeuronGroup()->isTrueSpikeRequired()) { @@ -582,7 +578,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, c); - genCustomUpdateIndexCalculation(groupEnv); + buildStandardEnvironment(groupEnv); if (c.getArchetype().isNeuronReduction()) { // Initialise reduction targets @@ -731,7 +727,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, c); - genCustomConnectivityUpdateIndexCalculation(groupEnv); + buildStandardEnvironment(groupEnv); // Loop through presynaptic neurons funcEnv.getStream() << "for(unsigned int i = 0; i < " << funcEnv["num_pre"] << "; i++)"; @@ -868,7 +864,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: funcEnv.getStream() << "const auto *group = &mergedNeuronInitGroup" << n.getIndex() << "[g]; " << std::endl; EnvironmentGroupMergedField groupEnv(funcEnv, n); - genNeuronIndexCalculation(groupEnv, 1); + buildStandardEnvironment(groupEnv, 1); n.generateInit(*this, groupEnv, modelMerged); } }); @@ -888,7 +884,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedSynapseInitGroup" << s.getIndex() << "[g]; " << std::endl; EnvironmentGroupMergedField groupEnv(funcEnv, s); - genSynapseIndexCalculation(groupEnv, 1); + buildStandardEnvironment(groupEnv, 1); s.generateInit(*this, groupEnv, modelMerged); } @@ -981,7 +977,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedSynapseConnectivityInitGroup" << s.getIndex() << "[g]; " << std::endl; EnvironmentGroupMergedField groupEnv(funcEnv, s); - genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); + buildStandardEnvironment(groupEnv, modelMerged.getModel().getBatchSize()); // If matrix connectivity is ragged if(s.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { @@ -1138,7 +1134,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedSynapseSparseInitGroup" << s.getIndex() << "[g]; " << std::endl; EnvironmentGroupMergedField groupEnv(funcEnv, s); - genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); + buildStandardEnvironment(groupEnv, 
modelMerged.getModel().getBatchSize()); // If postsynaptic learning is required, initially zero column lengths if (!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 92edc89b7b..8950739d6a 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -8,8 +8,278 @@ #include "code_generator/groupMerged.h" #include "code_generator/customConnectivityUpdateGroupMerged.h" #include "code_generator/customUpdateGroupMerged.h" +#include "code_generator/initGroupMerged.h" #include "code_generator/neuronUpdateGroupMerged.h" +#include "code_generator/synapseUpdateGroupMerged.h" +using namespace GeNN; +using namespace GeNN::CodeGenerator; + +//-------------------------------------------------------------------------- +// Anonymous namespace +//-------------------------------------------------------------------------- +namespace +{ +template +void buildStandardNeuronEnvironment(const BackendBase &backend, EnvironmentGroupMergedField &env, unsigned int batchSize) +{ + using namespace Type; + + env.addField(Uint32.addConst(), "num_neurons", + Uint32, "numNeurons", + [](const auto &ng, size_t) { return std::to_string(ng.getNumNeurons()); }); + env.addField(Uint32.createPointer(), "_spk_cnt", "spkCnt", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkCnt" + g.getName(); }); + env.addField(Uint32.createPointer(), "_spk", "spk", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpk" + g.getName(); }); + env.addField(Uint32.createPointer(), "_spk_cnt_evnt", "spkCntEvnt", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkCntEvnt" + g.getName(); }); + env.addField(Uint32.createPointer(), "_spk_evnt", "spkEvnt", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkEvnt" + g.getName(); }); + env.addField(Uint32.createPointer(), "_spk_que_ptr", "spkQuePtr", + [&backend](const auto &g, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + g.getName(); }); + + env.addField(env.getGroup().getTimeType().createPointer(), "_spk_time", "sT", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "sT" + g.getName(); }); + env.addField(env.getGroup().getTimeType().createPointer(), "_spk_evnt_time", "seT", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "seT" + g.getName(); }); + env.addField(env.getGroup().getTimeType().createPointer(), "_prev_spk_time", "prevST", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "prevST" + g.getName(); }); + env.addField(env.getGroup().getTimeType().createPointer(), "_prev_spk_evnt_time", "prevSET", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "prevSET" + g.getName(); }); + + env.addField(Uint32.createPointer(), "_record_spk", "recordSpk", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "recordSpk" + g.getName(); }, + "", GroupMergedFieldType::DYNAMIC); + env.addField(Uint32.createPointer(), "_record_spk_event", "recordSpkEvent", + [&backend](const auto &g, size_t){ return backend.getDeviceVarPrefix() + "recordSpkEvent" + g.getName(); }, + "", GroupMergedFieldType::DYNAMIC); + + // If batching is enabled, calculate batch offset + if(batchSize > 1) { + env.add(Uint32.addConst(), "_batch_offset", "batchOffset", + {env.addInitialiser("const unsigned int
batchOffset = $(num_neurons) * $(batch);")}); + } + + // If axonal delays are required + if(env.getGroup().getArchetype().isDelayRequired()) { + // We should READ from delay slot before spkQuePtr + const unsigned int numDelaySlots = env.getGroup().getArchetype().getNumDelaySlots(); + const std::string numDelaySlotsStr = std::to_string(numDelaySlots); + env.add(Uint32.addConst(), "_read_delay_slot", "readDelaySlot", + {env.addInitialiser("const unsigned int readDelaySlot = (*$(_spk_que_ptr) + " + std::to_string(numDelaySlots - 1) + ") % " + numDelaySlotsStr + ";")}); + env.add(Uint32.addConst(), "_read_delay_offset", "readDelayOffset", + {env.addInitialiser("const unsigned int readDelayOffset = $(_read_delay_slot) * $(num_neurons);")}); + + // And we should WRITE to delay slot pointed to by spkQuePtr + env.add(Uint32.addConst(), "_write_delay_slot", "writeDelaySlot", + {env.addInitialiser("const unsigned int writeDelaySlot = *$(_spk_que_ptr);")}); + env.add(Uint32.addConst(), "_write_delay_offset", "writeDelayOffset", + {env.addInitialiser("const unsigned int writeDelayOffset = $(_write_delay_slot) * $(num_neurons);")}); + + // If batching is also enabled + if(batchSize > 1) { + // Calculate batched delay slots + env.add(Uint32.addConst(), "_read_batch_delay_slot", "readBatchDelaySlot", + {env.addInitialiser("const unsigned int readBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_read_delay_slot);")}); + env.add(Uint32.addConst(), "_write_batch_delay_slot", "writeBatchDelaySlot", + {env.addInitialiser("const unsigned int writeBatchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_write_delay_slot);")}); + + // Calculate current batch offset + env.add(Uint32.addConst(), "_batch_delay_offset", "batchDelayOffset", + {env.addInitialiser("const unsigned int batchDelayOffset = $(_batch_offset) * " + numDelaySlotsStr + ";")}); + + // Calculate further offsets to include delay and batch + env.add(Uint32.addConst(), "_read_batch_delay_offset", "readBatchDelayOffset", + {env.addInitialiser("const unsigned int readBatchDelayOffset = $(_read_delay_offset) + $(_batch_delay_offset);")}); + env.add(Uint32.addConst(), "_write_batch_delay_offset", "writeBatchDelayOffset", + {env.addInitialiser("const unsigned int writeBatchDelayOffset = $(_write_delay_offset) + $(_batch_delay_offset);")}); + } + } +} +//-------------------------------------------------------------------------- +template +void buildStandardSynapseEnvironment(const BackendBase &backend, EnvironmentGroupMergedField &env, unsigned int batchSize) +{ + using namespace Type; + + // Synapse group fields + env.addField(Uint32.addConst(), "num_pre", + Uint32, "numSrcNeurons", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getSrcNeuronGroup()->getNumNeurons()); }); + env.addField(Uint32.addConst(), "num_post", + Uint32, "numTrgNeurons", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getTrgNeuronGroup()->getNumNeurons()); }); + env.addField(Uint32, "_row_stride", "rowStride", + [&backend](const SynapseGroupInternal &sg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(sg)); }); + env.addField(Uint32, "_col_stride", "colStride", + [](const SynapseGroupInternal &sg, size_t) { return std::to_string(sg.getMaxSourceConnections()); }); + + // Postsynaptic model fields + env.addField(env.getGroup().getScalarType().createPointer(), "_out_post", "outPost", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); +
env.addField(env.getGroup().getScalarType().createPointer(), "_den_delay", "denDelay", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); + env.addField(Uint32.createPointer(), "_den_delay_ptr", "denDelayPtr", + [&backend](const auto &g, size_t) { return backend.getScalarAddressPrefix() + "denDelayPtr" + g.getFusedPSVarSuffix(); }); + + // Presynaptic output fields + env.addField(env.getGroup().getScalarType().createPointer(), "_out_pre", "outPre", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); + + // Source neuron fields + env.addField(Uint32.createPointer(), "_src_spk_que_ptr", "srcSpkQuePtr", + [&backend](const auto &g, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + g.getSrcNeuronGroup()->getName(); }); + env.addField(Uint32.createPointer(), "_src_spk_cnt", "srcSpkCnt", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkCnt" + g.getSrcNeuronGroup()->getName(); }); + env.addField(Uint32.createPointer(), "_src_spk", "srcSpk", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpk" + g.getSrcNeuronGroup()->getName(); }); + env.addField(Uint32.createPointer(), "_src_spk_evnt_cnt", "srcSpkCntEvnt", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkCntEvnt" + g.getSrcNeuronGroup()->getName(); }); + env.addField(Uint32.createPointer(), "_src_spk_evnt", "srcSpkEvnt", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkEvnt" + g.getSrcNeuronGroup()->getName(); }); + + // Target neuron fields + env.addField(Uint32.createPointer(), "_trg_spk_que_ptr", "trgSpkQuePtr", + [&backend](const auto &g, size_t) { return backend.getScalarAddressPrefix() + "spkQuePtr" + g.getTrgNeuronGroup()->getName(); }); + env.addField(Uint32.createPointer(), "_trg_spk_cnt", "trgSpkCnt", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpkCnt" + g.getTrgNeuronGroup()->getName(); }); + env.addField(Uint32.createPointer(), "_trg_spk", "trgSpk", + [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "glbSpk" + g.getTrgNeuronGroup()->getName(); }); + + // Connectivity fields + if(env.getGroup().getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { + env.addField(Uint32.createPointer(), "_gp", "gp", + [&backend](const auto &sg, size_t) { return backend.getDeviceVarPrefix() + "gp" + sg.getName(); }); + } + else if(env.getGroup().getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { + env.addField(Uint32.createPointer(), "_row_length", "rowLength", + [&backend](const auto &sg, size_t) { return backend.getDeviceVarPrefix() + "rowLength" + sg.getName(); }); + env.addField(env.getGroup().getArchetype().getSparseIndType().createPointer(), "_ind", "ind", + [&backend](const auto &sg, size_t) { return backend.getDeviceVarPrefix() + "ind" + sg.getName(); }); + env.addField(Uint32.createPointer(), "_col_length", "colLength", + [&backend](const auto &sg, size_t) { return backend.getDeviceVarPrefix() + "colLength" + sg.getName(); }); + env.addField(Uint32.createPointer(), "_remap", "remap", + [&backend](const auto &sg, size_t) { return backend.getDeviceVarPrefix() + "remap" + sg.getName(); }); + } + + // If batching is enabled + if(batchSize > 1) { + // Calculate batch offsets into pre and postsynaptic populations + env.add(Uint32.addConst(), "_pre_batch_offset",
"preBatchOffset", + {env.addInitialiser("const unsigned int preBatchOffset = $(num_pre) * $(batch);")}); + env.add(Uint32.addConst(), "_post_batch_offset", "postBatchOffset", + {env.addInitialiser("const unsigned int postBatchOffset = $(num_post) * $(batch);")}); + + // Calculate batch offsets into synapse arrays, using 64-bit arithmetic if necessary + if(backend.areSixtyFourBitSynapseIndicesRequired(env.getGroup())) { + assert(false); + //os << "const uint64_t synBatchOffset = (uint64_t)preBatchOffset * (uint64_t)group->rowStride;" << std::endl; + } + else { + env.add(Uint32.addConst(), "_syn_batch_offset", "synBatchOffset", + {env.addInitialiser("const unsigned int synBatchOffset = $(_pre_batch_offset) * $(_row_stride);")}); + } + + // If synapse group has kernel + const auto &kernelSize = env.getGroup().getArchetype().getKernelSize(); + if(!kernelSize.empty()) { + // Loop through kernel dimensions and multiply together + // **TODO** extract list of kernel size variables referenced + std::ostringstream kernBatchOffsetInit; + kernBatchOffsetInit << "const unsigned int kernBatchOffset = "; + for(size_t i = 0; i < kernelSize.size(); i++) { + kernBatchOffsetInit << getKernelSize(env.getGroup(), i) << " * "; + } + + // And finally by batch + kernBatchOffsetInit << "$(batch);" << std::endl; + + env.add(Uint32.addConst(), "_kern_batch_offset", "kernBatchOffset", + {env.addInitialiser(kernBatchOffsetInit.str())}); + } + } + + // If presynaptic neuron group has variable queues, calculate offset to read from its variables with axonal delay + if(env.getGroup().getArchetype().getSrcNeuronGroup()->isDelayRequired()) { + const unsigned int numDelaySteps = env.getGroup().getArchetype().getDelaySteps(); + const unsigned int numSrcDelaySlots = env.getGroup().getArchetype().getSrcNeuronGroup()->getNumDelaySlots(); + + std::ostringstream preDelaySlotInit; + preDelaySlotInit << "const unsigned int preDelaySlot = "; + if(numDelaySteps == 0) { + preDelaySlotInit << "*$(_src_spk_que_ptr);" << std::endl; + } + else { + preDelaySlotInit << "(*$(_src_spk_que_ptr) + " << (numSrcDelaySlots - numDelaySteps) << ") % " << numSrcDelaySlots << ";" << std::endl; + } + env.add(Uint32, "_pre_delay_slot", "preDelaySlot", + {env.addInitialiser(preDelaySlotInit.str())}); + + env.add(Uint32, "_pre_delay_offset", "preDelayOffset", + {env.addInitialiser("const unsigned int preDelayOffset = $(_pre_delay_slot) * $(num_pre);")}); + + if(batchSize > 1) { + env.add(Uint32, "_pre_batch_delay_slot", "preBatchDelaySlot", + {env.addInitialiser("const unsigned int preBatchDelaySlot = $(_pre_delay_slot) + ($(batch) * " + std::to_string(numSrcDelaySlots) + ");")}); + env.add(Uint32, "_pre_batch_delay_offset", "preBatchDelayOffset", + {env.addInitialiser("const unsigned int preBatchDelayOffset = $(_pre_delay_offset) + ($(_pre_batch_offset) * " + std::to_string(numSrcDelaySlots) + ");")}); + } + + if(env.getGroup().getArchetype().getWUModel()->isPrevPreSpikeTimeRequired() + || env.getGroup().getArchetype().getWUModel()->isPrevPreSpikeEventTimeRequired() + { + env.add(Uint32, "_pre_prev_spike_time_delay_offset", "prePrevSpikeTimeDelayOffset", + {env.addInitialiser("const unsigned int prePrevSpikeTimeDelayOffset = ((*$(_src_spk_que_ptr) + " + + std::to_string(numSrcDelaySlots - numDelaySteps - 1) + ") % " + std::to_string(numSrcDelaySlots) + ") * $(num_pre);")}); + + if(batchSize > 1) { + env.add(Uint32, "_pre_prev_spike_time_batch_delay_offset", "prePrevSpikeTimeBatchDelayOffset", + {env.addInitialiser("const unsigned int
prePrevSpikeTimeBatchDelayOffset = $(_pre_prev_spike_time_delay_offset) + ($(_pre_batch_offset) * " + std::to_string(numSrcDelaySlots) + ");")}); + } + } + } + + // If postsynaptic neuron group has variable queues, calculate offset to read from its variables at current time + if(env.getGroup().getArchetype().getTrgNeuronGroup()->isDelayRequired()) { + const unsigned int numBackPropDelaySteps = env.getGroup().getArchetype().getBackPropDelaySteps(); + const unsigned int numTrgDelaySlots = env.getGroup().getArchetype().getTrgNeuronGroup()->getNumDelaySlots(); + + std::ostringstream postDelaySlotInit; + postDelaySlotInit << "const unsigned int postDelaySlot = "; + if(numBackPropDelaySteps == 0) { + postDelaySlotInit << "*$(_trg_spk_que_ptr);" << std::endl; + } + else { + postDelaySlotInit << "(*$(_trg_spk_que_ptr) + " << (numTrgDelaySlots - numBackPropDelaySteps) << ") % " << numTrgDelaySlots << ";" << std::endl; + } + env.add(Uint32, "_post_delay_slot", "postDelaySlot", + {env.addInitialiser(postDelaySlotInit.str())}); + + env.add(Uint32, "_post_delay_offset", "postDelayOffset", + {env.addInitialiser("const unsigned int postDelayOffset = $(_post_delay_slot) * $(num_post);")}); + + if(batchSize > 1) { + env.add(Uint32, "_post_batch_delay_slot", "postBatchDelaySlot", + {env.addInitialiser("const unsigned int postBatchDelaySlot =$(_post_delay_slot) + (batch * " + std::to_string(numTrgDelaySlots) + ");")}); + env.add(Uint32, "_post_batch_delay_offset", "postBatchDelayOffset", + {env.addInitialiser("const unsigned int postBatchDelayOffset = $(_post_delay_offset) + ($(_post_batch_offset) * " + std::to_string(numTrgDelaySlots) + ");")}); + } + + if(env.getGroup().getArchetype().getWUModel()->isPrevPostSpikeTimeRequired()) { + env.add(Uint32, "_post_prev_spike_time_delay_offset", "postPrevSpikeTimeDelayOffset", + {env.addInitialiser("const unsigned int postPrevSpikeTimeDelayOffset = ((*$(_trg_spk_que_ptr) + " + + std::to_string(numTrgDelaySlots - numBackPropDelaySteps - 1) + ") % " + std::to_string(numTrgDelaySlots) + ") * $(num_post);")}); + + if(batchSize > 1) { + env.add(Uint32, "_post_prev_spike_time_batch_delay_offset", "postPrevSpikeTimeBatchDelayOffset", + {env.addInitialiser("const unsigned int postPrevSpikeTimeBatchDelayOffset = $(_post_prev_spike_time_delay_offset) + ($(_post_batch_offset) * " + std::to_string(numTrgDelaySlots) + ");")}); + } + + } + } +} +} //-------------------------------------------------------------------------- // GeNN::CodeGenerator::BackendBase //-------------------------------------------------------------------------- @@ -33,7 +303,42 @@ bool BackendBase::areSixtyFourBitSynapseIndicesRequired(const GroupMerged &env) const +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const +{ + buildStandardNeuronEnvironment(*this, env, batchSize); +} +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const +{ + buildStandardNeuronEnvironment(*this, env, batchSize); +} +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const +{ + buildStandardNeuronEnvironment(*this, env, batchSize); +} +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const +{ + 
buildStandardSynapseEnvironment(*this, env, batchSize); +} +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const +{ + buildStandardSynapseEnvironment(*this, env, batchSize); +} +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const +{ + buildStandardSynapseEnvironment(*this, env, batchSize); +} +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const +{ + buildStandardSynapseEnvironment(*this, env, batchSize); +} +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const { // Add size field env.addField(Type::Uint32, "size", "size", @@ -73,7 +378,7 @@ void BackendBase::genCustomUpdateIndexCalculation(EnvironmentGroupMergedField &env) const +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const { // If there are delays on presynaptic variable references if(env.getGroup().getArchetype().getPreDelayNeuronGroup() != nullptr) { @@ -87,6 +392,26 @@ void BackendBase::genCustomConnectivityUpdateIndexCalculation(EnvironmentGroupMe {env.addInitialiser("const unsigned int postDelayOffset = (*$(_post_spk_que_ptr) * $(num_post));")}); } } +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const +{ + buildStandardNeuronEnvironment(*this, env, batchSize); +} +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const +{ + buildStandardSynapseEnvironment(*this, env, batchSize); +} +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const +{ + buildStandardSynapseEnvironment(*this, env, batchSize); +} +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const +{ + buildStandardSynapseEnvironment(*this, env, batchSize); +} //---------------------------------------------------------------------------- std::string BackendBase::getReductionInitialValue(VarAccessMode access, const Type::ResolvedType &type) const { diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index e9a9bee8f1..93e5de1cbd 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -326,7 +326,7 @@ void BackendSIMT::genNeuronPrevSpikeTimeUpdateKernel(EnvironmentExternalBase &en // Create matching environment EnvironmentGroupMergedField neuronEnv(popEnv, ng); - genNeuronIndexCalculation(neuronEnv, batchSize); + buildStandardEnvironment(neuronEnv, batchSize); // If neuron group requires delays if(ng.getArchetype().isDelayRequired()) { @@ -423,7 +423,7 @@ void BackendSIMT::genNeuronSpikeQueueUpdateKernel(EnvironmentExternalBase &env, // Create matching environment EnvironmentGroupMergedField neuronEnv(env, n); - genNeuronIndexCalculation(neuronEnv, 
batchSize); + buildStandardEnvironment(neuronEnv, batchSize); if(n.getArchetype().isDelayRequired()) { // with delay neuronEnv.printLine("*$(_spk_que_ptr) = (*$(_spk_que_ptr) + 1) % " + std::to_string(n.getArchetype().getNumDelaySlots()) + ";"); @@ -507,7 +507,7 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM { CodeStream::Scope b(popEnv.getStream()); EnvironmentGroupMergedField groupEnv(popEnv, ng); - genNeuronIndexCalculation(groupEnv, batchSize); + buildStandardEnvironment(groupEnv, batchSize); // Call handler to generate generic neuron code groupEnv.print("if($(id) < $(num_neurons))"); @@ -690,7 +690,7 @@ void BackendSIMT::genSynapseDendriticDelayUpdateKernel(EnvironmentExternalBase & // Use this to get reference to merged group structure env.getStream() << getPointerPrefix() << "struct MergedSynapseDendriticDelayUpdateGroup" << sg.getIndex() << " *group = &d_mergedSynapseDendriticDelayUpdateGroup" << sg.getIndex() << "[id - " << idStart << "]; " << std::endl; EnvironmentGroupMergedField groupEnv(env, sg); - genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); + buildStandardEnvironment(groupEnv, modelMerged.getModel().getBatchSize()); groupEnv.printLine("*$(_den_delay_ptr) = (*$(_den_delay_ptr) + 1) % " + std::to_string(sg.getArchetype().getMaxDendriticDelayTimesteps()) + ";"); } idStart += sg.getGroups().size(); @@ -740,7 +740,7 @@ void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, Model LOGD_BACKEND << "Using '" << typeid(*presynapticUpdateStrategy).name() << "' presynaptic update strategy for merged synapse group '" << sg.getIndex() << "'"; // Generate index calculation code - genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); + buildStandardEnvironment(groupEnv, modelMerged.getModel().getBatchSize()); // Generate preamble presynapticUpdateStrategy->genPreamble(groupEnv, modelMerged, sg, *this); @@ -786,7 +786,7 @@ void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, Mode // Generate index calculation code const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - genSynapseIndexCalculation(groupEnv, batchSize); + buildStandardEnvironment(groupEnv, batchSize); groupEnv.printLine("const unsigned int numSpikes = $(_trg_spk_cnt)[" + sg.getPostSlot(batchSize) + "];"); @@ -869,7 +869,7 @@ void BackendSIMT::genSynapseDynamicsKernel(EnvironmentExternalBase &env, ModelSp // Generate index calculation code const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - genSynapseIndexCalculation(groupEnv, batchSize); + buildStandardEnvironment(groupEnv, batchSize); if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { groupEnv.print("if ($(id) < ($(num_pre) * $(_row_stride)))"); @@ -948,7 +948,7 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge CodeStream::Scope b(groupEnv.getStream()); groupEnv.add(Type::Uint32.addConst(), "batch", "batch"); - genCustomUpdateIndexCalculation(groupEnv); + buildStandardEnvironment(groupEnv); // **THINK** it would be great to 'lift' reads of SHARED variables out of this loop cg.generateCustomUpdate(*this, groupEnv); @@ -978,7 +978,7 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge groupEnv.getStream() << "const unsigned int batch = " << env["id"] << " / 32;" << std::endl; groupEnv.add(Type::Uint32.addConst(), "batch", "batch"); - genCustomUpdateIndexCalculation(groupEnv); + buildStandardEnvironment(groupEnv); // 
Initialise reduction targets const auto reductionTargets = genInitReductionTargets(groupEnv.getStream(), cg); @@ -1038,7 +1038,7 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge { CodeStream::Scope b(groupEnv.getStream()); - genCustomUpdateIndexCalculation(groupEnv); + buildStandardEnvironment(groupEnv); cg.generateCustomUpdate(*this, groupEnv); } } @@ -1068,7 +1068,7 @@ void BackendSIMT::genCustomUpdateKernel(EnvironmentExternal &env, ModelSpecMerge { CodeStream::Scope b(groupEnv.getStream()); - genCustomUpdateIndexCalculation(groupEnv); + buildStandardEnvironment(groupEnv); cg.generateCustomUpdate(*this, groupEnv); } } @@ -1324,7 +1324,7 @@ void BackendSIMT::genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env { EnvironmentGroupMergedField groupEnv(env, cg); - genCustomConnectivityUpdateIndexCalculation(groupEnv); + buildStandardEnvironment(groupEnv); groupEnv.getStream() << "// only do this for existing presynaptic neurons" << std::endl; groupEnv.print("if($(id) < $(num_pre))"); @@ -1362,7 +1362,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer [&modelMerged, this](EnvironmentExternalBase &env, NeuronInitGroupMerged &ng) { EnvironmentGroupMergedField groupEnv(env, ng); - genNeuronIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); + buildStandardEnvironment(groupEnv, modelMerged.getModel().getBatchSize()); groupEnv.getStream() << "// only do this for existing neurons" << std::endl; groupEnv.print("if($(id) < $(num_neurons))"); @@ -1538,7 +1538,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer [&modelMerged, this](EnvironmentExternalBase &env, SynapseConnectivityInitGroupMerged &sg) { EnvironmentGroupMergedField groupEnv(env, sg); - genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); + buildStandardEnvironment(groupEnv, modelMerged.getModel().getBatchSize()); // If there is row-building code in this snippet const auto &connectInit = sg.getArchetype().getConnectivityInitialiser(); @@ -1690,7 +1690,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS [&modelMerged, numInitializeThreads, this](EnvironmentExternalBase &env, SynapseSparseInitGroupMerged &sg) { EnvironmentGroupMergedField groupEnv(env, sg); - genSynapseIndexCalculation(groupEnv, modelMerged.getModel().getBatchSize()); + buildStandardEnvironment(groupEnv, modelMerged.getModel().getBatchSize()); // If this post synapse requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id From 1187d3cd32a3d10099c07265ba2da81ebb2c57d8 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 13:01:48 +0100 Subject: [PATCH 403/725] Unpicked another nightmare of a circular dependency, caused by std::unordered_map not allowing a forward-declared ValueType on GCC.
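As with the ``genXXXIndexCalculation`` refactor in the previous patch, the cycle is broken by making headers depend only on declarations they can actually satisfy. A minimal sketch of the issue this patch works around (``VarInit``, ``Group`` and ``FixedGroup`` are hypothetical stand-ins for the real GeNN types, not the actual declarations):

    #include <string>
    #include <unordered_map>

    class VarInit;   // forward declaration only -- no definition visible here

    class Group
    {
    private:
        // GCC (libstdc++) rejects this member: of the standard containers only
        // std::vector, std::list and std::forward_list are required to support
        // incomplete value types, so instantiating std::unordered_map here
        // needs the complete definition of VarInit
        // std::unordered_map<std::string, VarInit> m_VarInitialisers;
    };

    // The fix: make the full definition visible before declaring the map
    class VarInitDefined { /* ... */ };

    class FixedGroup
    {
    private:
        std::unordered_map<std::string, VarInitDefined> m_VarInitialisers;   // OK
    };

Hence the header changes below, which rearrange the initialiser types so that their complete definitions are visible wherever a std::unordered_map of them is declared.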
--- include/genn/genn/currentSource.h | 6 ++-- include/genn/genn/currentSourceInternal.h | 4 +-- include/genn/genn/currentSourceModels.h | 2 +- include/genn/genn/customConnectivityUpdate.h | 16 ++++----- .../genn/customConnectivityUpdateInternal.h | 10 +++--- .../genn/customConnectivityUpdateModels.h | 6 ++-- include/genn/genn/customUpdate.h | 12 +++---- include/genn/genn/customUpdateInternal.h | 4 +-- include/genn/genn/customUpdateModels.h | 2 +- include/genn/genn/gennUtils.h | 9 ++--- .../genn/genn/initSparseConnectivitySnippet.h | 3 ++ .../genn/initToeplitzConnectivitySnippet.h | 3 ++ include/genn/genn/initVarSnippet.h | 34 ++++++++++++++++++- include/genn/genn/modelSpec.h | 18 +++++----- include/genn/genn/models.h | 32 +---------------- include/genn/genn/neuronGroup.h | 6 ++-- include/genn/genn/neuronGroupInternal.h | 4 +-- include/genn/genn/neuronModels.h | 2 +- include/genn/genn/postsynapticModels.h | 2 +- include/genn/genn/snippet.h | 14 +++----- include/genn/genn/synapseGroup.h | 20 +++++------ include/genn/genn/synapseGroupInternal.h | 12 +++---- include/genn/genn/type.h | 5 ++- include/genn/genn/weightUpdateModels.h | 6 ++-- src/genn/genn/currentSource.cc | 2 +- src/genn/genn/currentSourceModels.cc | 5 ++- src/genn/genn/customConnectivityUpdate.cc | 4 +-- .../genn/customConnectivityUpdateModels.cc | 9 +++-- src/genn/genn/customUpdate.cc | 6 ++-- src/genn/genn/customUpdateModels.cc | 3 ++ src/genn/genn/gennUtils.cc | 5 +-- .../genn/initSparseConnectivitySnippet.cc | 3 ++ .../genn/initToeplitzConnectivitySnippet.cc | 3 ++ src/genn/genn/initVarSnippet.cc | 31 +++++++++++++++++ src/genn/genn/models.cc | 29 +--------------- src/genn/genn/neuronGroup.cc | 2 +- src/genn/genn/neuronModels.cc | 5 ++- src/genn/genn/postsynapticModels.cc | 5 ++- src/genn/genn/snippet.cc | 22 +++++++++++- src/genn/genn/synapseGroup.cc | 6 ++-- src/genn/genn/transpiler/parser.cc | 1 + src/genn/genn/weightUpdateModels.cc | 9 +++-- 42 files changed, 214 insertions(+), 168 deletions(-) diff --git a/include/genn/genn/currentSource.h b/include/genn/genn/currentSource.h index f06e607028..5fee6d0cec 100644 --- a/include/genn/genn/currentSource.h +++ b/include/genn/genn/currentSource.h @@ -49,7 +49,7 @@ class GENN_EXPORT CurrentSource const CurrentSourceModels::Base *getCurrentSourceModel() const{ return m_CurrentSourceModel; } const std::unordered_map &getParams() const{ return m_Params; } - const std::unordered_map &getVarInitialisers() const{ return m_VarInitialisers; } + const std::unordered_map &getVarInitialisers() const{ return m_VarInitialisers; } //! 
Get variable location for current source model state variable VarLocation getVarLocation(const std::string &varName) const; @@ -67,7 +67,7 @@ class GENN_EXPORT CurrentSource protected: CurrentSource(const std::string &name, const CurrentSourceModels::Base *currentSourceModel, - const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, + const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, const NeuronGroupInternal *trgNeuronGroup, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation); @@ -106,7 +106,7 @@ class GENN_EXPORT CurrentSource const CurrentSourceModels::Base *m_CurrentSourceModel; std::unordered_map m_Params; std::unordered_map m_DerivedParams; - std::unordered_map m_VarInitialisers; + std::unordered_map m_VarInitialisers; const NeuronGroupInternal *m_TrgNeuronGroup; diff --git a/include/genn/genn/currentSourceInternal.h b/include/genn/genn/currentSourceInternal.h index a1a0697657..ba4edff676 100644 --- a/include/genn/genn/currentSourceInternal.h +++ b/include/genn/genn/currentSourceInternal.h @@ -12,7 +12,7 @@ class CurrentSourceInternal : public CurrentSource { public: CurrentSourceInternal(const std::string &name, const CurrentSourceModels::Base *currentSourceModel, - const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, + const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, const NeuronGroupInternal *targetNeuronGroup, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) : CurrentSource(name, currentSourceModel, params, varInitialisers, targetNeuronGroup, @@ -46,7 +46,7 @@ class CurrentSourceVarAdapter Models::Base::VarVec getDefs() const{ return m_CS.getCurrentSourceModel()->getVars(); } - const std::unordered_map &getInitialisers() const{ return m_CS.getVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_CS.getVarInitialisers(); } bool isVarDelayed(const std::string&) const{ return false; } diff --git a/include/genn/genn/currentSourceModels.h b/include/genn/genn/currentSourceModels.h index c44979d7d2..1a043b0d6d 100644 --- a/include/genn/genn/currentSourceModels.h +++ b/include/genn/genn/currentSourceModels.h @@ -40,7 +40,7 @@ class GENN_EXPORT Base : public Models::Base //! 
Validate names of parameters etc void validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, + const std::unordered_map &varValues, const std::string &description) const; }; diff --git a/include/genn/genn/customConnectivityUpdate.h b/include/genn/genn/customConnectivityUpdate.h index b19ca387cd..7639f5cc0c 100644 --- a/include/genn/genn/customConnectivityUpdate.h +++ b/include/genn/genn/customConnectivityUpdate.h @@ -46,9 +46,9 @@ class GENN_EXPORT CustomConnectivityUpdate const CustomConnectivityUpdateModels::Base *getCustomConnectivityUpdateModel() const { return m_CustomConnectivityUpdateModel; } const std::unordered_map &getParams() const { return m_Params; } - const std::unordered_map &getVarInitialisers() const { return m_VarInitialisers; } - const std::unordered_map &getPreVarInitialisers() const { return m_PreVarInitialisers; } - const std::unordered_map &getPostVarInitialisers() const { return m_PostVarInitialisers; } + const std::unordered_map &getVarInitialisers() const { return m_VarInitialisers; } + const std::unordered_map &getPreVarInitialisers() const { return m_PreVarInitialisers; } + const std::unordered_map &getPostVarInitialisers() const { return m_PostVarInitialisers; } const std::unordered_map &getVarReferences() const{ return m_VarReferences; } const std::unordered_map &getPreVarReferences() const{ return m_PreVarReferences; } @@ -75,8 +75,8 @@ class GENN_EXPORT CustomConnectivityUpdate protected: CustomConnectivityUpdate(const std::string &name, const std::string &updateGroupName, SynapseGroupInternal *synapseGroup, const CustomConnectivityUpdateModels::Base *customConnectivityUpdateModel, - const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, - const std::unordered_map &preVarInitialisers, const std::unordered_map &postVarInitialisers, + const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, + const std::unordered_map &preVarInitialisers, const std::unordered_map &postVarInitialisers, const std::unordered_map &varReferences, const std::unordered_map &preVarReferences, const std::unordered_map &postVarReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation); @@ -135,9 +135,9 @@ class GENN_EXPORT CustomConnectivityUpdate const CustomConnectivityUpdateModels::Base *m_CustomConnectivityUpdateModel; const std::unordered_map m_Params; std::unordered_map m_DerivedParams; - std::unordered_map m_VarInitialisers; - std::unordered_map m_PreVarInitialisers; - std::unordered_map m_PostVarInitialisers; + std::unordered_map m_VarInitialisers; + std::unordered_map m_PreVarInitialisers; + std::unordered_map m_PostVarInitialisers; //! 
Location of individual state variables std::vector m_VarLocation; diff --git a/include/genn/genn/customConnectivityUpdateInternal.h b/include/genn/genn/customConnectivityUpdateInternal.h index 7ef824bb8e..ca03cbe35a 100644 --- a/include/genn/genn/customConnectivityUpdateInternal.h +++ b/include/genn/genn/customConnectivityUpdateInternal.h @@ -14,8 +14,8 @@ class CustomConnectivityUpdateInternal : public CustomConnectivityUpdate public: CustomConnectivityUpdateInternal(const std::string &name, const std::string &updateGroupName, SynapseGroupInternal *synapseGroup, const CustomConnectivityUpdateModels::Base *customConnectivityUpdateModel, - const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, - const std::unordered_map &preVarInitialisers, const std::unordered_map &postVarInitialisers, + const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, + const std::unordered_map &preVarInitialisers, const std::unordered_map &postVarInitialisers, const std::unordered_map &varReferences, const std::unordered_map &preVarReferences, const std::unordered_map &postVarReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) @@ -55,7 +55,7 @@ class CustomConnectivityUpdateVarAdapter Models::Base::VarVec getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getVars(); } - const std::unordered_map &getInitialisers() const{ return m_CU.getVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_CU.getVarInitialisers(); } const std::string &getNameSuffix() const{ return m_CU.getName(); } @@ -82,7 +82,7 @@ class CustomConnectivityUpdatePreVarAdapter Models::Base::VarVec getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getPreVars(); } - const std::unordered_map &getInitialisers() const{ return m_CU.getPreVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_CU.getPreVarInitialisers(); } bool isVarDelayed(const std::string &) const { return false; } @@ -111,7 +111,7 @@ class CustomConnectivityUpdatePostVarAdapter Models::Base::VarVec getDefs() const{ return m_CU.getCustomConnectivityUpdateModel()->getPostVars(); } - const std::unordered_map &getInitialisers() const{ return m_CU.getPostVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_CU.getPostVarInitialisers(); } bool isVarDelayed(const std::string &) const { return false; } diff --git a/include/genn/genn/customConnectivityUpdateModels.h b/include/genn/genn/customConnectivityUpdateModels.h index 7c9c82a872..9d7ddeb15a 100644 --- a/include/genn/genn/customConnectivityUpdateModels.h +++ b/include/genn/genn/customConnectivityUpdateModels.h @@ -72,9 +72,9 @@ class GENN_EXPORT Base : public Models::Base //! 
Validate names of parameters etc void validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, - const std::unordered_map &preVarValues, - const std::unordered_map &postVarValues, + const std::unordered_map &varValues, + const std::unordered_map &preVarValues, + const std::unordered_map &postVarValues, const std::unordered_map &varRefTargets, const std::unordered_map &preVarRefTargets, const std::unordered_map &postVarRefTargets, diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index bf50bcf64b..81c49d4524 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -39,7 +39,7 @@ class GENN_EXPORT CustomUpdateBase const CustomUpdateModels::Base *getCustomUpdateModel() const{ return m_CustomUpdateModel; } const std::unordered_map &getParams() const{ return m_Params; } - const std::unordered_map &getVarInitialisers() const{ return m_VarInitialisers; } + const std::unordered_map &getVarInitialisers() const{ return m_VarInitialisers; } //! Get variable location for custom update model state variable VarLocation getVarLocation(const std::string &varName) const; @@ -52,7 +52,7 @@ class GENN_EXPORT CustomUpdateBase protected: CustomUpdateBase(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, - const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, + const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation); //------------------------------------------------------------------------ @@ -148,7 +148,7 @@ class GENN_EXPORT CustomUpdateBase const CustomUpdateModels::Base *m_CustomUpdateModel; const std::unordered_map m_Params; std::unordered_map m_DerivedParams; - std::unordered_map m_VarInitialisers; + std::unordered_map m_VarInitialisers; //! 
Location of individual state variables std::vector m_VarLocation; @@ -179,7 +179,7 @@ class CustomUpdateVarAdapter Models::Base::VarVec getDefs() const{ return m_CU.getCustomUpdateModel()->getVars(); } - const std::unordered_map &getInitialisers() const{ return m_CU.getVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_CU.getVarInitialisers(); } bool isVarDelayed(const std::string &) const { return false; } @@ -230,7 +230,7 @@ class GENN_EXPORT CustomUpdate : public CustomUpdateBase protected: CustomUpdate(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, - const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, + const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation); //------------------------------------------------------------------------ @@ -281,7 +281,7 @@ class GENN_EXPORT CustomUpdateWU : public CustomUpdateBase protected: CustomUpdateWU(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, - const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, + const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation); //------------------------------------------------------------------------ diff --git a/include/genn/genn/customUpdateInternal.h b/include/genn/genn/customUpdateInternal.h index 9407b71066..b6a08727f8 100644 --- a/include/genn/genn/customUpdateInternal.h +++ b/include/genn/genn/customUpdateInternal.h @@ -14,7 +14,7 @@ class CustomUpdateInternal : public CustomUpdate public: CustomUpdateInternal(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, - const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, + const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) : CustomUpdate(name, updateGroupName, customUpdateModel, params, varInitialisers, varReferences, defaultVarLocation, defaultExtraGlobalParamLocation) @@ -70,7 +70,7 @@ class CustomUpdateWUInternal : public CustomUpdateWU public: CustomUpdateWUInternal(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, - const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, + const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) : CustomUpdateWU(name, updateGroupName, customUpdateModel, params, varInitialisers, varReferences, defaultVarLocation, defaultExtraGlobalParamLocation) diff --git a/include/genn/genn/customUpdateModels.h b/include/genn/genn/customUpdateModels.h index bfc4fb6d1b..0e71afb341 100644 --- a/include/genn/genn/customUpdateModels.h +++ b/include/genn/genn/customUpdateModels.h @@ -38,7 +38,7 @@ class GENN_EXPORT Base : public Models::Base //! 
Validate names of parameters etc template void validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, + const std::unordered_map &varValues, const std::unordered_map &varRefTargets, const std::string &description) const { diff --git a/include/genn/genn/gennUtils.h b/include/genn/genn/gennUtils.h index ac674ede7a..22ba4bbc98 100644 --- a/include/genn/genn/gennUtils.h +++ b/include/genn/genn/gennUtils.h @@ -20,17 +20,12 @@ // GeNN includes #include "gennExport.h" +#include "initVarSnippet.h" #include "type.h" // GeNN code generator includes #include "transpiler/token.h" -// Forward declarations -namespace GeNN::Models -{ -class VarInit; -} - //-------------------------------------------------------------------------- // GeNN::Utils //-------------------------------------------------------------------------- @@ -53,7 +48,7 @@ GENN_EXPORT bool isRNGRequired(const std::vector &tokens); //-------------------------------------------------------------------------- //! \brief Does the model with the vectors of variable initialisers and modes require an RNG for the specified init location i.e. host or device //-------------------------------------------------------------------------- -GENN_EXPORT bool isRNGRequired(const std::unordered_map &varInitialisers); +GENN_EXPORT bool isRNGRequired(const std::unordered_map &varInitialisers); //-------------------------------------------------------------------------- //! \brief Is the variable name valid? GeNN variable names must obey C variable naming rules diff --git a/include/genn/genn/initSparseConnectivitySnippet.h b/include/genn/genn/initSparseConnectivitySnippet.h index fb04af8d0e..96c890189f 100644 --- a/include/genn/genn/initSparseConnectivitySnippet.h +++ b/include/genn/genn/initSparseConnectivitySnippet.h @@ -12,6 +12,9 @@ #include "binomial.h" #include "snippet.h" +// GeNN transpiler includes +#include "transpiler/token.h" + //---------------------------------------------------------------------------- // Macros //---------------------------------------------------------------------------- diff --git a/include/genn/genn/initToeplitzConnectivitySnippet.h b/include/genn/genn/initToeplitzConnectivitySnippet.h index a41b3d63c9..f23516732f 100644 --- a/include/genn/genn/initToeplitzConnectivitySnippet.h +++ b/include/genn/genn/initToeplitzConnectivitySnippet.h @@ -13,6 +13,9 @@ #include "binomial.h" #include "snippet.h" +// GeNN transpiler includes +#include "transpiler/token.h" + //---------------------------------------------------------------------------- // Macros //---------------------------------------------------------------------------- diff --git a/include/genn/genn/initVarSnippet.h b/include/genn/genn/initVarSnippet.h index 015976dd98..bb411aa21d 100644 --- a/include/genn/genn/initVarSnippet.h +++ b/include/genn/genn/initVarSnippet.h @@ -3,6 +3,9 @@ // GeNN includes #include "snippet.h" +// GeNN transpiler includes +#include "transpiler/token.h" + //---------------------------------------------------------------------------- // Macros //---------------------------------------------------------------------------- @@ -32,6 +35,35 @@ class GENN_EXPORT Base : public Snippet::Base void validate(const std::unordered_map ¶mValues) const; }; + +//---------------------------------------------------------------------------- +// GeNN::InitVarSnippet::Init +//---------------------------------------------------------------------------- +//! Class used to bind together everything required to initialise a variable: +//! 1. 
A pointer to a variable initialisation snippet +//! 2. The parameters required to control the variable initialisation snippet +class Init : public Snippet::Init +{ +public: + Init(const Base *snippet, const std::unordered_map ¶ms); + Init(double constant); + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + bool isRNGRequired() const; + + bool isKernelRequired() const; + + const std::vector &getCodeTokens() const{ return m_CodeTokens; } + +private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + std::vector m_CodeTokens; +}; + //---------------------------------------------------------------------------- // GeNN::InitVarSnippet::Uninitialised //---------------------------------------------------------------------------- @@ -50,7 +82,7 @@ class Uninitialised : public Base * - \c value - The value to intialise the variable to - \note This snippet type is seldom used directly - Models::VarInit + \note This snippet type is seldom used directly - InitVarSnippet::Init has an implicit constructor that, internally, creates one of these snippets*/ class Constant : public Base { diff --git a/include/genn/genn/modelSpec.h b/include/genn/genn/modelSpec.h index 1b8608ba72..6b7ea6a78d 100644 --- a/include/genn/genn/modelSpec.h +++ b/include/genn/genn/modelSpec.h @@ -44,7 +44,7 @@ Part of the code generation and generated code sections. namespace GeNN { using ParamValues = std::unordered_map; -using VarValues = std::unordered_map; +using VarValues = std::unordered_map; using VarReferences = std::unordered_map; using WUVarReferences = std::unordered_map; @@ -67,28 +67,28 @@ enum class TimePrecision //! Initialise a variable using an initialisation snippet /*! \tparam S type of variable initialisation snippet (derived from InitVarSnippet::Base). \param params parameters for snippet wrapped in ParamValues object. - \return Models::VarInit object for use within model's VarValues*/ + \return InitVarSnippet::Init object for use within model's VarValues*/ template -inline Models::VarInit initVar(const ParamValues ¶ms) +inline InitVarSnippet::Init initVar(const ParamValues ¶ms) { - return Models::VarInit(S::getInstance(), params); + return InitVarSnippet::Init(S::getInstance(), params); } //! Initialise a variable using an initialisation snippet /*! \tparam S type of variable initialisation snippet (derived from InitVarSnippet::Base). - \return Models::VarInit object for use within model's VarValues*/ + \return InitVarSnippet::Init object for use within model's VarValues*/ template -inline Models::VarInit initVar() +inline InitVarSnippet::Init initVar() { - return Models::VarInit(S::getInstance(), {}); + return InitVarSnippet::Init(S::getInstance(), {}); } //! Mark a variable as uninitialised /*! This means that the backend will not generate any automatic initialization code, but will instead copy the variable from host to device during ``initializeSparse`` function */ -inline Models::VarInit uninitialisedVar() +inline InitVarSnippet::Init uninitialisedVar() { - return Models::VarInit(InitVarSnippet::Uninitialised::getInstance(), {}); + return InitVarSnippet::Init(InitVarSnippet::Uninitialised::getInstance(), {}); } //! 
Initialise connectivity using a sparse connectivity snippet diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 6002d61c93..e9bfef8420 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -9,7 +9,6 @@ // GeNN includes #include "initVarSnippet.h" -#include "snippet.h" #include "type.h" #include "varAccess.h" @@ -112,40 +111,11 @@ class GENN_EXPORT Base : public Snippet::Base //! Validate names of parameters etc void validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, + const std::unordered_map &varValues, const std::string &description) const; }; - -//---------------------------------------------------------------------------- -// GeNN::Models::VarInit -//---------------------------------------------------------------------------- -//! Class used to bind together everything required to initialise a variable: -//! 1. A pointer to a variable initialisation snippet -//! 2. The parameters required to control the variable initialisation snippet -class VarInit : public Snippet::Init -{ -public: - VarInit(const InitVarSnippet::Base *snippet, const std::unordered_map ¶ms); - VarInit(double constant); - - //------------------------------------------------------------------------ - // Public API - //------------------------------------------------------------------------ - bool isRNGRequired() const; - - bool isKernelRequired() const; - - const std::vector &getCodeTokens() const{ return m_CodeTokens; } - -private: - //------------------------------------------------------------------------ - // Members - //------------------------------------------------------------------------ - std::vector m_CodeTokens; -}; - //---------------------------------------------------------------------------- // GeNN::Models::VarReferenceBase //---------------------------------------------------------------------------- diff --git a/include/genn/genn/neuronGroup.h b/include/genn/genn/neuronGroup.h index 8d01bdb9bf..547310df93 100644 --- a/include/genn/genn/neuronGroup.h +++ b/include/genn/genn/neuronGroup.h @@ -125,7 +125,7 @@ class GENN_EXPORT NeuronGroup const NeuronModels::Base *getNeuronModel() const{ return m_NeuronModel; } const std::unordered_map &getParams() const{ return m_Params; } - const std::unordered_map &getVarInitialisers() const{ return m_VarInitialisers; } + const std::unordered_map &getVarInitialisers() const{ return m_VarInitialisers; } bool isSpikeTimeRequired() const; bool isPrevSpikeTimeRequired() const; @@ -178,7 +178,7 @@ class GENN_EXPORT NeuronGroup protected: NeuronGroup(const std::string &name, int numNeurons, const NeuronModels::Base *neuronModel, - const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, + const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation); //------------------------------------------------------------------------ @@ -290,7 +290,7 @@ class GENN_EXPORT NeuronGroup const NeuronModels::Base *m_NeuronModel; const std::unordered_map m_Params; std::unordered_map m_DerivedParams; - std::unordered_map m_VarInitialisers; + std::unordered_map m_VarInitialisers; std::vector m_InSyn; std::vector m_OutSyn; std::vector m_FusedPSMInSyn; diff --git a/include/genn/genn/neuronGroupInternal.h b/include/genn/genn/neuronGroupInternal.h index 02625354c7..9e8da8077b 100644 --- a/include/genn/genn/neuronGroupInternal.h +++ b/include/genn/genn/neuronGroupInternal.h @@ -12,7 +12,7 @@ class NeuronGroupInternal : 
public NeuronGroup { public: NeuronGroupInternal(const std::string &name, int numNeurons, const NeuronModels::Base *neuronModel, - const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, + const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) : NeuronGroup(name, numNeurons, neuronModel, params, varInitialisers, defaultVarLocation, defaultExtraGlobalParamLocation) @@ -70,7 +70,7 @@ class NeuronVarAdapter Models::Base::VarVec getDefs() const{ return m_NG.getNeuronModel()->getVars(); } - const std::unordered_map &getInitialisers() const{ return m_NG.getVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_NG.getVarInitialisers(); } bool isVarDelayed(const std::string &varName) const{ return m_NG.isVarQueueRequired(varName); } diff --git a/include/genn/genn/neuronModels.h b/include/genn/genn/neuronModels.h index 7994045b1f..9572250d44 100644 --- a/include/genn/genn/neuronModels.h +++ b/include/genn/genn/neuronModels.h @@ -68,7 +68,7 @@ class GENN_EXPORT Base : public Models::Base //! Validate names of parameters etc void validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, + const std::unordered_map &varValues, const std::string &description) const; }; diff --git a/include/genn/genn/postsynapticModels.h b/include/genn/genn/postsynapticModels.h index 715cf5ec04..4980c514dc 100644 --- a/include/genn/genn/postsynapticModels.h +++ b/include/genn/genn/postsynapticModels.h @@ -38,7 +38,7 @@ class GENN_EXPORT Base : public Models::Base //! Validate names of parameters etc void validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, + const std::unordered_map &varValues, const std::string &description) const; }; diff --git a/include/genn/genn/snippet.h b/include/genn/genn/snippet.h index 9eff48738b..36bf8a051b 100644 --- a/include/genn/genn/snippet.h +++ b/include/genn/genn/snippet.h @@ -12,7 +12,6 @@ // GeNN includes #include "gennExport.h" -#include "gennUtils.h" #include "type.h" //---------------------------------------------------------------------------- @@ -74,14 +73,10 @@ class GENN_EXPORT Base //! 
Additional input variables, row state variables and other things have a name, a type and an initial value struct ParamVal { - ParamVal(const std::string &n, const Type::ResolvedType &t, const std::string &v) : name(n), type(t), value(v) - {} - ParamVal(const std::string &n, const Type::ResolvedType &t, double v) : ParamVal(n, t, Utils::writePreciseString(v)) - {} - ParamVal(const std::string &n, const std::string &t, const std::string &v) : name(n), type(t), value(v) - {} - ParamVal(const std::string &n, const std::string &t, double v) : ParamVal(n, t, Utils::writePreciseString(v)) - {} + ParamVal(const std::string &n, const Type::ResolvedType &t, const std::string &v); + ParamVal(const std::string &n, const Type::ResolvedType &t, double v); + ParamVal(const std::string &n, const std::string &t, const std::string &v); + ParamVal(const std::string &n, const std::string &t, double v); bool operator == (const ParamVal &other) const { @@ -106,7 +101,6 @@ class GENN_EXPORT Base std::function&, double)> func; }; - //---------------------------------------------------------------------------- // Typedefines //---------------------------------------------------------------------------- diff --git a/include/genn/genn/synapseGroup.h b/include/genn/genn/synapseGroup.h index c54565b63f..443e7e6257 100644 --- a/include/genn/genn/synapseGroup.h +++ b/include/genn/genn/synapseGroup.h @@ -155,15 +155,15 @@ class GENN_EXPORT SynapseGroup const WeightUpdateModels::Base *getWUModel() const{ return m_WUModel; } const std::unordered_map &getWUParams() const{ return m_WUParams; } - const std::unordered_map &getWUVarInitialisers() const{ return m_WUVarInitialisers; } - const std::unordered_map &getWUPreVarInitialisers() const{ return m_WUPreVarInitialisers; } - const std::unordered_map &getWUPostVarInitialisers() const{ return m_WUPostVarInitialisers; } + const std::unordered_map &getWUVarInitialisers() const{ return m_WUVarInitialisers; } + const std::unordered_map &getWUPreVarInitialisers() const{ return m_WUPreVarInitialisers; } + const std::unordered_map &getWUPostVarInitialisers() const{ return m_WUPostVarInitialisers; } const std::unordered_map getWUConstInitVals() const; const PostsynapticModels::Base *getPSModel() const{ return m_PSModel; } const std::unordered_map &getPSParams() const{ return m_PSParams; } - const std::unordered_map &getPSVarInitialisers() const{ return m_PSVarInitialisers; } + const std::unordered_map &getPSVarInitialisers() const{ return m_PSVarInitialisers; } const InitSparseConnectivitySnippet::Init &getConnectivityInitialiser() const{ return m_SparseConnectivityInitialiser; } const InitToeplitzConnectivitySnippet::Init &getToeplitzConnectivityInitialiser() const { return m_ToeplitzConnectivityInitialiser; } @@ -203,8 +203,8 @@ class GENN_EXPORT SynapseGroup protected: SynapseGroup(const std::string &name, SynapseMatrixType matrixType, unsigned int delaySteps, - const WeightUpdateModels::Base *wu, const std::unordered_map &wuParams, const std::unordered_map &wuVarInitialisers, const std::unordered_map &wuPreVarInitialisers, const std::unordered_map &wuPostVarInitialisers, - const PostsynapticModels::Base *ps, const std::unordered_map &psParams, const std::unordered_map &psVarInitialisers, + const WeightUpdateModels::Base *wu, const std::unordered_map &wuParams, const std::unordered_map &wuVarInitialisers, const std::unordered_map &wuPreVarInitialisers, const std::unordered_map &wuPostVarInitialisers, + const PostsynapticModels::Base *ps, const std::unordered_map &psParams, const 
std::unordered_map &psVarInitialisers, NeuronGroupInternal *srcNeuronGroup, NeuronGroupInternal *trgNeuronGroup, const InitSparseConnectivitySnippet::Init &connectivityInitialiser, const InitToeplitzConnectivitySnippet::Init &toeplitzInitialiser, @@ -443,13 +443,13 @@ class GENN_EXPORT SynapseGroup std::unordered_map m_WUDerivedParams; //! Initialisers for weight update model per-synapse variables - std::unordered_map m_WUVarInitialisers; + std::unordered_map m_WUVarInitialisers; //! Initialisers for weight update model per-presynaptic neuron variables - std::unordered_map m_WUPreVarInitialisers; + std::unordered_map m_WUPreVarInitialisers; //! Initialisers for weight update model post-presynaptic neuron variables - std::unordered_map m_WUPostVarInitialisers; + std::unordered_map m_WUPostVarInitialisers; //! Post synapse update model type const PostsynapticModels::Base *m_PSModel; @@ -461,7 +461,7 @@ class GENN_EXPORT SynapseGroup std::unordered_map m_PSDerivedParams; //! Initialisers for post synapse model variables - std::unordered_map m_PSVarInitialisers; + std::unordered_map m_PSVarInitialisers; //! Location of individual per-synapse state variables std::vector m_WUVarLocation; diff --git a/include/genn/genn/synapseGroupInternal.h b/include/genn/genn/synapseGroupInternal.h index 3c72c2a81f..556600a79c 100644 --- a/include/genn/genn/synapseGroupInternal.h +++ b/include/genn/genn/synapseGroupInternal.h @@ -13,8 +13,8 @@ class SynapseGroupInternal : public SynapseGroup { public: SynapseGroupInternal(const std::string &name, SynapseMatrixType matrixType, unsigned int delaySteps, - const WeightUpdateModels::Base *wu, const std::unordered_map &wuParams, const std::unordered_map &wuVarInitialisers, const std::unordered_map &wuPreVarInitialisers, const std::unordered_map &wuPostVarInitialisers, - const PostsynapticModels::Base *ps, const std::unordered_map &psParams, const std::unordered_map &psVarInitialisers, + const WeightUpdateModels::Base *wu, const std::unordered_map &wuParams, const std::unordered_map &wuVarInitialisers, const std::unordered_map &wuPreVarInitialisers, const std::unordered_map &wuPostVarInitialisers, + const PostsynapticModels::Base *ps, const std::unordered_map &psParams, const std::unordered_map &psVarInitialisers, NeuronGroupInternal *srcNeuronGroup, NeuronGroupInternal *trgNeuronGroup, const InitSparseConnectivitySnippet::Init &connectivityInitialiser, const InitToeplitzConnectivitySnippet::Init &toeplitzConnectivityInitialiser, @@ -109,7 +109,7 @@ class SynapsePSMVarAdapter Models::Base::VarVec getDefs() const{ return m_SG.getPSModel()->getVars(); } - const std::unordered_map &getInitialisers() const{ return m_SG.getPSVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_SG.getPSVarInitialisers(); } const std::string &getNameSuffix() const{ return m_SG.getFusedPSVarSuffix(); } @@ -161,7 +161,7 @@ class SynapseWUVarAdapter Models::Base::VarVec getDefs() const{ return m_SG.getWUModel()->getVars(); } - const std::unordered_map &getInitialisers() const{ return m_SG.getWUVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_SG.getWUVarInitialisers(); } const std::string &getNameSuffix() const{ return m_SG.getName(); } private: @@ -187,7 +187,7 @@ class SynapseWUPreVarAdapter Models::Base::VarVec getDefs() const{ return m_SG.getWUModel()->getPreVars(); } - const std::unordered_map &getInitialisers() const{ return m_SG.getWUPreVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return 
m_SG.getWUPreVarInitialisers(); } const std::string &getNameSuffix() const{ return m_SG.getFusedWUPreVarSuffix(); } @@ -216,7 +216,7 @@ class SynapseWUPostVarAdapter Models::Base::VarVec getDefs() const{ return m_SG.getWUModel()->getPostVars(); } - const std::unordered_map &getInitialisers() const{ return m_SG.getWUPostVarInitialisers(); } + const std::unordered_map &getInitialisers() const{ return m_SG.getWUPostVarInitialisers(); } const std::string &getNameSuffix() const{ return m_SG.getFusedWUPostVarSuffix(); } diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 91597dc09f..134ca0f033 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -6,15 +6,18 @@ // Standard C++ includes #include +#include #include #include #include #include #include +// Boost includes +#include + // GeNN includes #include "gennExport.h" -#include "gennUtils.h" //---------------------------------------------------------------------------- // Macros diff --git a/include/genn/genn/weightUpdateModels.h b/include/genn/genn/weightUpdateModels.h index b394472608..5ec3dedee3 100644 --- a/include/genn/genn/weightUpdateModels.h +++ b/include/genn/genn/weightUpdateModels.h @@ -148,9 +148,9 @@ class GENN_EXPORT Base : public Models::Base //! Validate names of parameters etc void validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, - const std::unordered_map &preVarValues, - const std::unordered_map &postVarValues, + const std::unordered_map &varValues, + const std::unordered_map &preVarValues, + const std::unordered_map &postVarValues, const std::string &description) const; }; diff --git a/src/genn/genn/currentSource.cc b/src/genn/genn/currentSource.cc index e6578bbf82..f7f1ffa40f 100644 --- a/src/genn/genn/currentSource.cc +++ b/src/genn/genn/currentSource.cc @@ -33,7 +33,7 @@ VarLocation CurrentSource::getExtraGlobalParamLocation(const std::string &varNam } //---------------------------------------------------------------------------- CurrentSource::CurrentSource(const std::string &name, const CurrentSourceModels::Base *currentSourceModel, - const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, + const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, const NeuronGroupInternal *trgNeuronGroup, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) : m_Name(name), m_CurrentSourceModel(currentSourceModel), m_Params(params), m_VarInitialisers(varInitialisers), diff --git a/src/genn/genn/currentSourceModels.cc b/src/genn/genn/currentSourceModels.cc index a277079083..d45b772c57 100644 --- a/src/genn/genn/currentSourceModels.cc +++ b/src/genn/genn/currentSourceModels.cc @@ -1,5 +1,8 @@ #include "currentSourceModels.h" +// GeNN includes +#include "gennUtils.h" + using namespace GeNN; namespace GeNN::CurrentSourceModels @@ -23,7 +26,7 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const } //---------------------------------------------------------------------------- void Base::validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, + const std::unordered_map &varValues, const std::string &description) const { // Superclass diff --git a/src/genn/genn/customConnectivityUpdate.cc b/src/genn/genn/customConnectivityUpdate.cc index c2258ce045..378cd21d33 100644 --- a/src/genn/genn/customConnectivityUpdate.cc +++ b/src/genn/genn/customConnectivityUpdate.cc @@ -92,8 +92,8 @@ bool CustomConnectivityUpdate::isPostVarInitRequired() const 
//------------------------------------------------------------------------ CustomConnectivityUpdate::CustomConnectivityUpdate(const std::string &name, const std::string &updateGroupName, SynapseGroupInternal *synapseGroup, const CustomConnectivityUpdateModels::Base *customConnectivityUpdateModel, - const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, - const std::unordered_map &preVarInitialisers, const std::unordered_map &postVarInitialisers, + const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, + const std::unordered_map &preVarInitialisers, const std::unordered_map &postVarInitialisers, const std::unordered_map &varReferences, const std::unordered_map &preVarReferences, const std::unordered_map &postVarReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) diff --git a/src/genn/genn/customConnectivityUpdateModels.cc b/src/genn/genn/customConnectivityUpdateModels.cc index 72e7d09fc1..51292a4584 100644 --- a/src/genn/genn/customConnectivityUpdateModels.cc +++ b/src/genn/genn/customConnectivityUpdateModels.cc @@ -1,5 +1,8 @@ #include "customConnectivityUpdateModels.h" +// GeNN includes +#include "gennUtils.h" + //---------------------------------------------------------------------------- // GeNN::CustomConnectivityUpdateModels::Base //---------------------------------------------------------------------------- @@ -25,9 +28,9 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const } //---------------------------------------------------------------------------- void Base::validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, - const std::unordered_map &preVarValues, - const std::unordered_map &postVarValues, + const std::unordered_map &varValues, + const std::unordered_map &preVarValues, + const std::unordered_map &postVarValues, const std::unordered_map &varRefTargets, const std::unordered_map &preVarRefTargets, const std::unordered_map &postVarRefTargets, diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index 7d856529fa..cd79e3c72b 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -33,7 +33,7 @@ bool CustomUpdateBase::isVarInitRequired() const } //---------------------------------------------------------------------------- CustomUpdateBase::CustomUpdateBase(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, - const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, + const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) : m_Name(name), m_UpdateGroupName(updateGroupName), m_CustomUpdateModel(customUpdateModel), m_Params(params), m_VarInitialisers(varInitialisers), m_VarLocation(varInitialisers.size(), defaultVarLocation), @@ -109,7 +109,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateBase::getVarLocationHashDige //---------------------------------------------------------------------------- CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, - const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, + const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) : CustomUpdateBase(name, 
updateGroupName, customUpdateModel, params, varInitialisers, defaultVarLocation, defaultExtraGlobalParamLocation), m_VarReferences(varReferences), m_Size(varReferences.empty() ? 0 : varReferences.begin()->second.getSize()), m_DelayNeuronGroup(nullptr), m_PerNeuron(false) @@ -228,7 +228,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdate::getInitHashDigest() const //---------------------------------------------------------------------------- CustomUpdateWU::CustomUpdateWU(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, - const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, + const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) : CustomUpdateBase(name, updateGroupName, customUpdateModel, params, varInitialisers, defaultVarLocation, defaultExtraGlobalParamLocation), m_VarReferences(varReferences), m_SynapseGroup(m_VarReferences.empty() ? nullptr : static_cast(m_VarReferences.begin()->second.getSynapseGroup())) diff --git a/src/genn/genn/customUpdateModels.cc b/src/genn/genn/customUpdateModels.cc index d6915807b7..6416beaca4 100644 --- a/src/genn/genn/customUpdateModels.cc +++ b/src/genn/genn/customUpdateModels.cc @@ -1,5 +1,8 @@ #include "customUpdateModels.h" +// GeNN includes +#include "gennUtils.h" + using namespace GeNN; namespace GeNN::CustomUpdateModels diff --git a/src/genn/genn/gennUtils.cc b/src/genn/genn/gennUtils.cc index d1f74ca804..b75ea836e1 100644 --- a/src/genn/genn/gennUtils.cc +++ b/src/genn/genn/gennUtils.cc @@ -7,9 +7,6 @@ // Standard C includes #include -// GeNN includes -#include "models.h" - // GeNN transpiler includes #include "transpiler/errorHandler.h" #include "transpiler/scanner.h" @@ -122,7 +119,7 @@ bool isRNGRequired(const std::vector &tokens) } //-------------------------------------------------------------------------- -bool isRNGRequired(const std::unordered_map &varInitialisers) +bool isRNGRequired(const std::unordered_map &varInitialisers) { // Return true if any of these variable initialisers require an RNG return std::any_of(varInitialisers.cbegin(), varInitialisers.cend(), diff --git a/src/genn/genn/initSparseConnectivitySnippet.cc b/src/genn/genn/initSparseConnectivitySnippet.cc index da50be7d2a..469d433e78 100644 --- a/src/genn/genn/initSparseConnectivitySnippet.cc +++ b/src/genn/genn/initSparseConnectivitySnippet.cc @@ -1,5 +1,8 @@ #include "initSparseConnectivitySnippet.h" +// GeNN includes +#include "gennUtils.h" + using namespace GeNN; namespace GeNN::InitSparseConnectivitySnippet diff --git a/src/genn/genn/initToeplitzConnectivitySnippet.cc b/src/genn/genn/initToeplitzConnectivitySnippet.cc index bfc6a3bcd9..47e8bcdabe 100644 --- a/src/genn/genn/initToeplitzConnectivitySnippet.cc +++ b/src/genn/genn/initToeplitzConnectivitySnippet.cc @@ -1,5 +1,8 @@ #include "initToeplitzConnectivitySnippet.h" +// GeNN includes +#include "gennUtils.h" + using namespace GeNN; namespace GeNN::InitToeplitzConnectivitySnippet diff --git a/src/genn/genn/initVarSnippet.cc b/src/genn/genn/initVarSnippet.cc index 2992cd4eeb..efcbdfa5c5 100644 --- a/src/genn/genn/initVarSnippet.cc +++ b/src/genn/genn/initVarSnippet.cc @@ -1,5 +1,8 @@ #include "initVarSnippet.h" +// GeNN includes +#include "gennUtils.h" + using namespace GeNN; namespace GeNN::InitVarSnippet @@ -34,4 +37,32 @@ void Base::validate(const std::unordered_map ¶mValues) // 
Superclass Snippet::Base::validate(paramValues, "Variable initialiser "); } + + +//---------------------------------------------------------------------------- +// Init +//---------------------------------------------------------------------------- +Init::Init(const Base *snippet, const std::unordered_map ¶ms) +: Snippet::Init(snippet, params) +{ + // Scan code tokens + m_CodeTokens = Utils::scanCode(getSnippet()->getCode(), "Variable initialisation code"); +} +//---------------------------------------------------------------------------- +Init::Init(double constant) +: Snippet::Init(Constant::getInstance(), {{"constant", constant}}) +{ + // Scan code tokens + m_CodeTokens = Utils::scanCode(getSnippet()->getCode(), "Variable initialisation code"); +} +//---------------------------------------------------------------------------- +bool Init::isRNGRequired() const +{ + return Utils::isRNGRequired(m_CodeTokens); +} +//---------------------------------------------------------------------------- +bool Init::isKernelRequired() const +{ + return Utils::isIdentifierReferenced("id_kernel", m_CodeTokens); +} } // namespace GeNN::InitVarSnippet \ No newline at end of file diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index ba4d018cd4..9f7ced3a0f 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -21,7 +21,7 @@ void Base::updateHash(boost::uuids::detail::sha1 &hash) const } //---------------------------------------------------------------------------- void Base::validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, + const std::unordered_map &varValues, const std::string &description) const { // Superclass @@ -34,33 +34,6 @@ void Base::validate(const std::unordered_map ¶mValues, Utils::validateInitialisers(vars, varValues, "variable", description); } -//---------------------------------------------------------------------------- -// VarInit -//---------------------------------------------------------------------------- -VarInit::VarInit(const InitVarSnippet::Base *snippet, const std::unordered_map ¶ms) -: Snippet::Init(snippet, params) -{ - // Scan code tokens - m_CodeTokens = Utils::scanCode(getSnippet()->getCode(), "Variable initialisation code"); -} -//---------------------------------------------------------------------------- -VarInit::VarInit(double constant) -: Snippet::Init(InitVarSnippet::Constant::getInstance(), {{"constant", constant}}) -{ - // Scan code tokens - m_CodeTokens = Utils::scanCode(getSnippet()->getCode(), "Variable initialisation code"); -} -//---------------------------------------------------------------------------- -bool VarInit::isRNGRequired() const -{ - return Utils::isRNGRequired(m_CodeTokens); -} -//---------------------------------------------------------------------------- -bool VarInit::isKernelRequired() const -{ - return Utils::isIdentifierReferenced("id_kernel", m_CodeTokens); -} - //---------------------------------------------------------------------------- // VarReference //---------------------------------------------------------------------------- diff --git a/src/genn/genn/neuronGroup.cc b/src/genn/genn/neuronGroup.cc index 8cc15934c4..374aa9d83b 100644 --- a/src/genn/genn/neuronGroup.cc +++ b/src/genn/genn/neuronGroup.cc @@ -305,7 +305,7 @@ bool NeuronGroup::isInitRNGRequired() const } //---------------------------------------------------------------------------- NeuronGroup::NeuronGroup(const std::string &name, int numNeurons, const NeuronModels::Base *neuronModel, - const std::unordered_map 
¶ms, const std::unordered_map &varInitialisers, + const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) : m_Name(name), m_NumNeurons(numNeurons), m_NeuronModel(neuronModel), m_Params(params), m_VarInitialisers(varInitialisers), m_NumDelaySlots(1), m_VarQueueRequired(varInitialisers.size(), false), m_SpikeLocation(defaultVarLocation), m_SpikeEventLocation(defaultVarLocation), diff --git a/src/genn/genn/neuronModels.cc b/src/genn/genn/neuronModels.cc index 2a187e8395..134ac7b3a5 100644 --- a/src/genn/genn/neuronModels.cc +++ b/src/genn/genn/neuronModels.cc @@ -1,5 +1,8 @@ #include "neuronModels.h" +// GeNN includes +#include "gennUtils.h" + using namespace GeNN; namespace GeNN::NeuronModels @@ -37,7 +40,7 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const } //---------------------------------------------------------------------------- void Base::validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, + const std::unordered_map &varValues, const std::string &description) const { // Superclass diff --git a/src/genn/genn/postsynapticModels.cc b/src/genn/genn/postsynapticModels.cc index 7f54f1a280..79e990e0d3 100644 --- a/src/genn/genn/postsynapticModels.cc +++ b/src/genn/genn/postsynapticModels.cc @@ -1,5 +1,8 @@ #include "postsynapticModels.h" +// GeNN includes +#include "gennUtils.h" + using namespace GeNN; namespace GeNN::PostsynapticModels @@ -25,7 +28,7 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const } //---------------------------------------------------------------------------- void Base::validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, + const std::unordered_map &varValues, const std::string &description) const { // Superclass diff --git a/src/genn/genn/snippet.cc b/src/genn/genn/snippet.cc index 0e511318f2..e688c28539 100644 --- a/src/genn/genn/snippet.cc +++ b/src/genn/genn/snippet.cc @@ -1,13 +1,33 @@ #include "snippet.h" // GeNN includes +#include "gennUtils.h" #include "logging.h" //---------------------------------------------------------------------------- -// GeNN::Snippet::Base +// GeNN::Snippet::Base::ParamVal //---------------------------------------------------------------------------- namespace GeNN::Snippet { + Base::ParamVal::ParamVal(const std::string &n, const Type::ResolvedType &t, const std::string &v) +: name(n), type(t), value(v) +{} + //---------------------------------------------------------------------------- + Base::ParamVal::ParamVal(const std::string &n, const Type::ResolvedType &t, double v) +: ParamVal(n, t, Utils::writePreciseString(v)) +{} + //---------------------------------------------------------------------------- + Base::ParamVal::ParamVal(const std::string &n, const std::string &t, const std::string &v) +: name(n), type(t), value(v) +{} + //---------------------------------------------------------------------------- + Base::ParamVal::ParamVal(const std::string &n, const std::string &t, double v) +: ParamVal(n, t, Utils::writePreciseString(v)) +{} + + //---------------------------------------------------------------------------- +// GeNN::Snippet::Base +//---------------------------------------------------------------------------- void Base::updateHash(boost::uuids::detail::sha1 &hash) const { Utils::updateHash(getParamNames(), hash); diff --git a/src/genn/genn/synapseGroup.cc b/src/genn/genn/synapseGroup.cc index 91db21c27c..023d0e602a 100644 --- 
a/src/genn/genn/synapseGroup.cc +++ b/src/genn/genn/synapseGroup.cc @@ -17,7 +17,7 @@ //---------------------------------------------------------------------------- namespace { -std::unordered_map getConstInitVals(const std::unordered_map &varInitialisers) +std::unordered_map getConstInitVals(const std::unordered_map &varInitialisers) { // Reserve initial values to match initialisers std::unordered_map initVals; @@ -306,8 +306,8 @@ VarLocation SynapseGroup::getSparseConnectivityExtraGlobalParamLocation(const st } //---------------------------------------------------------------------------- SynapseGroup::SynapseGroup(const std::string &name, SynapseMatrixType matrixType, unsigned int delaySteps, - const WeightUpdateModels::Base *wu, const std::unordered_map &wuParams, const std::unordered_map &wuVarInitialisers, const std::unordered_map &wuPreVarInitialisers, const std::unordered_map &wuPostVarInitialisers, - const PostsynapticModels::Base *ps, const std::unordered_map &psParams, const std::unordered_map &psVarInitialisers, + const WeightUpdateModels::Base *wu, const std::unordered_map &wuParams, const std::unordered_map &wuVarInitialisers, const std::unordered_map &wuPreVarInitialisers, const std::unordered_map &wuPostVarInitialisers, + const PostsynapticModels::Base *ps, const std::unordered_map &psParams, const std::unordered_map &psVarInitialisers, NeuronGroupInternal *srcNeuronGroup, NeuronGroupInternal *trgNeuronGroup, const InitSparseConnectivitySnippet::Init &connectivityInitialiser, const InitToeplitzConnectivitySnippet::Init &toeplitzInitialiser, diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index b11ce05971..01ae62ac59 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include #include diff --git a/src/genn/genn/weightUpdateModels.cc b/src/genn/genn/weightUpdateModels.cc index 7280e232bf..938d2887d3 100644 --- a/src/genn/genn/weightUpdateModels.cc +++ b/src/genn/genn/weightUpdateModels.cc @@ -1,5 +1,8 @@ #include "weightUpdateModels.h" +// GeNN includes +#include "gennUtils.h" + using namespace GeNN; namespace GeNN::WeightUpdateModels @@ -43,9 +46,9 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const } //---------------------------------------------------------------------------- void Base::validate(const std::unordered_map ¶mValues, - const std::unordered_map &varValues, - const std::unordered_map &preVarValues, - const std::unordered_map &postVarValues, + const std::unordered_map &varValues, + const std::unordered_map &preVarValues, + const std::unordered_map &postVarValues, const std::string &description) const { // Superclass From 5fc075d2886937c23ba635e03d952ead79950bc2 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 13:05:10 +0100 Subject: [PATCH 404/725] de-inlined another bit of code that relies on GeNNUtils --- include/genn/genn/customUpdateModels.h | 18 ++++++---------- src/genn/genn/customUpdateModels.cc | 30 ++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/include/genn/genn/customUpdateModels.h b/include/genn/genn/customUpdateModels.h index 0e71afb341..dbfa6d51fe 100644 --- a/include/genn/genn/customUpdateModels.h +++ b/include/genn/genn/customUpdateModels.h @@ -36,21 +36,15 @@ class GENN_EXPORT Base : public Models::Base boost::uuids::detail::sha1::digest_type getHashDigest() const; //! 
Validate names of parameters etc - template void validate(const std::unordered_map ¶mValues, const std::unordered_map &varValues, - const std::unordered_map &varRefTargets, - const std::string &description) const - { - // Superclass - Models::Base::validate(paramValues, varValues, description); + const std::unordered_map &varRefTargets, + const std::string &description) const; - const auto varRefs = getVarRefs(); - Utils::validateVecNames(getVarRefs(), "Variable reference"); - - // Validate variable reference initialisers - Utils::validateInitialisers(varRefs, varRefTargets, "Variable reference", description); - } + void validate(const std::unordered_map ¶mValues, + const std::unordered_map &varValues, + const std::unordered_map &varRefTargets, + const std::string &description) const; }; //---------------------------------------------------------------------------- diff --git a/src/genn/genn/customUpdateModels.cc b/src/genn/genn/customUpdateModels.cc index 6416beaca4..f450a7e898 100644 --- a/src/genn/genn/customUpdateModels.cc +++ b/src/genn/genn/customUpdateModels.cc @@ -23,4 +23,34 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const Utils::updateHash(getVarRefs(), hash); return hash.get_digest(); } +//---------------------------------------------------------------------------- +void Base::validate(const std::unordered_map ¶mValues, + const std::unordered_map &varValues, + const std::unordered_map &varRefTargets, + const std::string &description) const +{ + // Superclass + Models::Base::validate(paramValues, varValues, description); + + const auto varRefs = getVarRefs(); + Utils::validateVecNames(getVarRefs(), "Variable reference"); + + // Validate variable reference initialisers + Utils::validateInitialisers(varRefs, varRefTargets, "Variable reference", description); +} +//---------------------------------------------------------------------------- +void Base::validate(const std::unordered_map ¶mValues, + const std::unordered_map &varValues, + const std::unordered_map &varRefTargets, + const std::string &description) const +{ + // Superclass + Models::Base::validate(paramValues, varValues, description); + + const auto varRefs = getVarRefs(); + Utils::validateVecNames(getVarRefs(), "Variable reference"); + + // Validate variable reference initialisers + Utils::validateInitialisers(varRefs, varRefTargets, "Variable reference", description); +} } // namespace GeNN::CustomUpdateModels \ No newline at end of file From 269466410d48d0d111d5b70920acefe290dea0ff Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 13:08:12 +0100 Subject: [PATCH 405/725] included environment.h now it's safe to do so! 
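The safety claim presumably follows from the include-cycle unpicking in the previous two patches: once a header merely declares a function, the heavyweight include it needed can move into the corresponding .cc file, and source files cannot take part in an include cycle. A hedged sketch of that de-inlining pattern, with hypothetical file and class names:

    // fooModels.h -- declaration only; no utility includes required
    #pragma once

    #include <string>
    #include <unordered_map>

    class Base
    {
    public:
        // defined out-of-line in fooModels.cc
        void validate(const std::unordered_map<std::string, double> &paramValues,
                      const std::string &description) const;
    };

    // fooModels.cc -- the formerly cyclic include is confined to a source file
    #include "fooModels.h"

    #include "gennUtils.h"

    void Base::validate(const std::unordered_map<std::string, double> &paramValues,
                        const std::string &description) const
    {
        // helpers such as Utils::validateVecNames(...) are now freely usable
    }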
--- .../genn/code_generator/customConnectivityUpdateGroupMerged.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index 035c794e4d..c8b7b43ed6 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -5,6 +5,7 @@ // GeNN code generator includes #include "code_generator/codeGenUtils.h" +#include "code_generator/environment.h" #include "code_generator/groupMerged.h" //---------------------------------------------------------------------------- From f6c3876c97f9f3bbc71ca0d32503f1d817d191b8 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 13:30:16 +0100 Subject: [PATCH 406/725] removed default argument from EnvironmentGroupMergedField declaration --- include/genn/genn/code_generator/environment.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 115ffa96eb..e3a304052e 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -376,7 +376,7 @@ class EnvironmentExternal : public EnvironmentExternalDynamicBase +template class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase> { using GroupInternal = typename G::GroupInternal; From 96b7094bd734d9d3d340369e550c57c7be93f8df Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 13:53:56 +0100 Subject: [PATCH 407/725] fixed GCC compilation errors --- .../genn/genn/code_generator/environment.h | 125 ++++++++++-------- .../genn/code_generator/initGroupMerged.cc | 8 +- .../synapseUpdateGroupMerged.cc | 44 +++--- src/genn/genn/transpiler/parser.cc | 1 + 4 files changed, 101 insertions(+), 77 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index e3a304052e..80edc90352 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -207,8 +207,11 @@ class EnvironmentFieldPolicy } private: - std::reference_wrapper m_FieldGroup; + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ std::reference_wrapper m_Group; + std::reference_wrapper m_FieldGroup; }; //---------------------------------------------------------------------------- @@ -395,7 +398,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &(GroupInternal::*)(void) const; public: - using EnvironmentExternalDynamicBase::EnvironmentExternalDynamicBase; + using EnvironmentExternalDynamicBase>::EnvironmentExternalDynamicBase; //------------------------------------------------------------------------ // Public API @@ -404,7 +407,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &initialisers = {}) { - addInternal(type, name, std::make_tuple(false, LazyString{value, *this}, std::nullopt), initialisers); + this->addInternal(type, name, std::make_tuple(false, LazyString{value, *this}, std::nullopt), initialisers); } //! 
Map a type (for type-checking) and a group merged field to back it to an identifier @@ -413,9 +416,9 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase &initialisers = {}) { - addInternal(type, name, std::make_tuple(false, LazyString{indexSuffix, *this}, - std::make_optional(std::make_tuple(fieldType, fieldName, getFieldValue, mergedFieldType))), - initialisers); + this->addInternal(type, name, std::make_tuple(false, LazyString{indexSuffix, *this}, + std::make_optional(std::make_tuple(fieldType, fieldName, getFieldValue, mergedFieldType))), + initialisers); } //! Map a type (for type-checking) and a group merged field to back it to an identifier @@ -429,7 +432,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getScalarType(); addField(scalarType.addConst(), name, scalarType, name + fieldSuffix, [getFieldValue, scalarType](const auto &g, size_t i) @@ -444,7 +447,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup(), p)) { addScalar(p, fieldSuffix, [p, getParamValues](const auto &g, size_t) { @@ -453,9 +456,9 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getScalarType().addConst(), p, + writePreciseLiteral(std::invoke(getParamValues, this->getGroup().getArchetype()).at(p), + this->getGroup().getScalarType())); } } } @@ -466,7 +469,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup(), d.name)) { addScalar(d.name, fieldSuffix, [d, getDerivedParamValues](const auto &g, size_t) { @@ -475,9 +478,9 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getScalarType().addConst(), d.name, + writePreciseLiteral(std::invoke(getDerivedParamValues, this->getGroup().getArchetype()).at(d.name), + this->getGroup().getScalarType())); } } } @@ -487,7 +490,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getTypeContext()); assert(!resolvedType.isPointer()); const auto pointerType = resolvedType.createPointer(); addField(pointerType, e.name, @@ -505,11 +508,11 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getArchetype()); const auto *snippet = connectInit.getSnippet(); for(const auto &p : snippet->getParamNames()) { // If parameter is heterogeneous, add scalar field - if (std::invoke(isHeterogeneous, getGroup(), p)) { + if (std::invoke(isHeterogeneous, this->getGroup(), p)) { addScalar(p, fieldSuffix, [p, getConnectivity](const auto &g, size_t) { @@ -518,8 +521,8 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getScalarType().addConst(), p, + writePreciseLiteral(connectInit.getParams().at(p), this->getGroup().getScalarType())); } } } @@ -529,11 +532,11 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getArchetype()); const auto *snippet = connectInit.getSnippet(); for(const auto &d : snippet->getDerivedParams()) { // If parameter is heterogeneous, add scalar field - if (std::invoke(isHeterogeneous, getGroup(), d.name)) { + if (std::invoke(isHeterogeneous, this->getGroup(), d.name)) { addScalar(d.name, fieldSuffix, [d, getConnectivity](const auto &g, size_t) { @@ -542,8 +545,8 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getScalarType().addConst(), d.name, + writePreciseLiteral(connectInit.getDerivedParams().at(d.name), this->getGroup().getScalarType())); } } } @@ -553,9 +556,9 @@ class 
EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getArchetype()).getInitialisers().at(varName).getParams()) { // If parameter is heterogeneous, add scalar field - if(std::invoke(isHeterogeneous, getGroup(), varName, p.first)) { + if(std::invoke(isHeterogeneous, this->getGroup(), varName, p.first)) { addScalar(p.first, varName + fieldSuffix, [p, varName](const auto &g, size_t) { @@ -564,8 +567,8 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getScalarType().addConst(), p.first, + writePreciseLiteral(p.second, this->getGroup().getScalarType())); } } } @@ -575,9 +578,9 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getArchetype()).getInitialisers().at(varName).getDerivedParams()) { // If derived parameter is heterogeneous, add scalar field - if(std::invoke(isHeterogeneous, getGroup(), varName, p.first)) { + if(std::invoke(isHeterogeneous, this->getGroup(), varName, p.first)) { addScalar(p.first, varName + fieldSuffix, [p, varName](const auto &g, size_t) { @@ -586,8 +589,8 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getScalarType().addConst(), p.first, + writePreciseLiteral(p.second, this->getGroup().getScalarType())); } } } @@ -596,9 +599,9 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getArchetype()); for(const auto &v : archetypeAdaptor.getDefs()) { - const auto resolvedType = v.type.resolve(getGroup().getTypeContext()); + const auto resolvedType = v.type.resolve(this->getGroup().getTypeContext()); const auto qualifiedType = (getVarAccessMode(v.access) & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; addField(qualifiedType, v.name, resolvedType.createPointer(), v.name + fieldSuffix, @@ -621,10 +624,10 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase getIndexFn, const std::string &fieldSuffix = "") { // Loop through variable references - const A archetypeAdaptor(getGroup().getArchetype()); + const A archetypeAdaptor(this->getGroup().getArchetype()); for(const auto &v : archetypeAdaptor.getDefs()) { // If variable access is read-only, qualify type with const - const auto resolvedType = v.type.resolve(getGroup().getTypeContext()); + const auto resolvedType = v.type.resolve(this->getGroup().getTypeContext()); const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; addField(qualifiedType, v.name, resolvedType.createPointer(), v.name + fieldSuffix, @@ -652,6 +655,10 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase>> m_Environment; }; +//------------------------------------------------------------------------ +// GeNN::CodeGenerator::VarCachePolicy +//------------------------------------------------------------------------ +//! 
Policy for use with EnvironmentLocalCacheBase for caching state variables template class VarCachePolicy { @@ -667,33 +674,45 @@ class VarCachePolicy : m_GetReadIndex(getIndex), m_GetWriteIndex(getIndex) {} - std::string getReadIndex(G &g, const Models::Base::Var &var) + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + std::string getReadIndex(G&, const Models::Base::Var &var) { return m_GetReadIndex(var.name, getVarAccessDuplication(var.access)); } - std::string getWriteIndex(G &g, const Models::Base::Var &var) + std::string getWriteIndex(G&, const Models::Base::Var &var) { return m_GetWriteIndex(var.name, getVarAccessDuplication(var.access)); } - static std::string getVarSuffix(const GroupInternal &g, const Models::Base::Var &var) + //------------------------------------------------------------------------ + // Static API + //------------------------------------------------------------------------ + static std::string getVarSuffix(const GroupInternal &g, const Models::Base::Var&) { return A(g).getNameSuffix(); } private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ GetIndexFn m_GetReadIndex; GetIndexFn m_GetWriteIndex; }; +//------------------------------------------------------------------------ +// GeNN::CodeGenerator::VarRefCachePolicy +//------------------------------------------------------------------------ +//! Policy for use with EnvironmentLocalCacheBase for caching variable references template class VarRefCachePolicy { protected: using GroupInternal = typename G::GroupInternal; using GetIndexFn = std::function; - VarRefCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex) : m_GetReadIndex(getReadIndex), m_GetWriteIndex(getWriteIndex) @@ -702,7 +721,10 @@ class VarRefCachePolicy VarRefCachePolicy(GetIndexFn getIndex) : m_GetReadIndex(getIndex), m_GetWriteIndex(getIndex) {} - + + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ std::string getReadIndex(G &g, const Models::Base::VarRef &var) { return m_GetReadIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name)); @@ -713,13 +735,18 @@ class VarRefCachePolicy return m_GetWriteIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name)); } + //------------------------------------------------------------------------ + // Static API + //------------------------------------------------------------------------ static std::string getVarSuffix(const GroupInternal &g, const Models::Base::VarRef &var) { return A(g).getInitialisers().at(var.name).getTargetName(); } - private: + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ GetIndexFn m_GetReadIndex; GetIndexFn m_GetWriteIndex; }; @@ -748,14 +775,6 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P [](const auto &v){ return std::make_pair(v.name, std::make_pair(false, v)); }); } - /*template - EnvironmentLocalCacheBase(G &group, const Type::TypeContext &context, EnvironmentExternalBase &enclosing, - const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix, - PolicyArgs&&... 
policyArgs) - : EnvironmentLocalVarCache(group, group, context, enclosing, arrayPrefix, fieldSuffix, localPrefix, std::forward(policyArgs)...) - {}*/ - - EnvironmentLocalCacheBase(const EnvironmentLocalCacheBase&) = delete; ~EnvironmentLocalCacheBase() @@ -778,7 +797,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P m_FieldGroup.get().addField(resolvedType.createPointer(), v.name + m_FieldSuffix, [arrayPrefix, v, &group](const typename F::GroupInternal &, size_t i) { - return arrayPrefix + v.name + getVarSuffix(group.getGroups().at(i), v); + return arrayPrefix + v.name + P::getVarSuffix(group.getGroups().at(i), v); }); if(v.access & VarAccessMode::READ_ONLY) { @@ -790,7 +809,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P // **NOTE** by not initialising these variables for reductions, // compilers SHOULD emit a warning if user code doesn't set it to something if(!(v.access & VarAccessModeAttribute::REDUCE)) { - getContextStream() << " = group->" << v.name << m_FieldSuffix << "[" << printSubs(getReadIndex(m_Group.get(), v), *this) << "]"; + getContextStream() << " = group->" << v.name << m_FieldSuffix << "[" << printSubs(this->getReadIndex(m_Group.get(), v), *this) << "]"; } getContextStream() << ";" << std::endl; } @@ -802,7 +821,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P for(const auto &v : referencedDefs) { // If variables are read-write if(v.access & VarAccessMode::READ_WRITE) { - getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << printSubs(getWriteIndex(m_Group.get(), v), *this) << "]"; + getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << printSubs(this->getWriteIndex(m_Group.get(), v), *this) << "]"; getContextStream() << " = _" << m_LocalPrefix << v.name << ";" << std::endl; } } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index bfeefdad27..1f58962141 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -75,8 +75,8 @@ void genInitNeuronVarCode(const BackendBase &backend, EnvironmentExternalBase &e // Substitute in parameters and derived parameters for initialising variables EnvironmentGroupMergedField varEnv(env, group, fieldGroup); - varEnv.addVarInitParams(&G::isVarInitParamHeterogeneous, var.name, fieldSuffix); - varEnv.addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, var.name, fieldSuffix); + varEnv.template addVarInitParams(&G::isVarInitParamHeterogeneous, var.name, fieldSuffix); + varEnv.template addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, var.name, fieldSuffix); varEnv.addExtraGlobalParams(varInit.getSnippet()->getExtraGlobalParams(), backend.getDeviceVarPrefix(), var.name, fieldSuffix); // Add field for variable itself @@ -157,8 +157,8 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, // Substitute in parameters and derived parameters for initialising variables EnvironmentGroupMergedField varEnv(env, group); - varEnv.addVarInitParams(&G::isVarInitParamHeterogeneous, var.name); - varEnv.addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, var.name); + varEnv.template addVarInitParams(&G::isVarInitParamHeterogeneous, var.name); + varEnv.template addVarInitDerivedParams(&G::isVarInitDerivedParamHeterogeneous, var.name); varEnv.addExtraGlobalParams(varInit.getSnippet()->getExtraGlobalParams(), backend.getDeviceVarPrefix(), var.name); // Add 
field for variable itself diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index 69d7e8d690..dcc160831a 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -30,16 +30,18 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa synEnv.addExtraGlobalParams(wu->getExtraGlobalParams(), backend.getDeviceVarPrefix()); // Substitute names of pre and postsynaptic weight update variable - synEnv.addVars(backend.getDeviceVarPrefix(), - [&sg, batchSize](VarAccess a, const std::string&) - { - return sg.getPreWUVarIndex(batchSize, getVarAccessDuplication(a), "$(id_pre)"); - }); - synEnv.addVars(backend.getDeviceVarPrefix(), - [&sg, batchSize](VarAccess a, const std::string&) - { - return sg.getPostWUVarIndex(batchSize, getVarAccessDuplication(a), "$(id_post)"); - }); + synEnv.template addVars( + backend.getDeviceVarPrefix(), + [&sg, batchSize](VarAccess a, const std::string&) + { + return sg.getPreWUVarIndex(batchSize, getVarAccessDuplication(a), "$(id_pre)"); + }); + synEnv.template addVars( + backend.getDeviceVarPrefix(), + [&sg, batchSize](VarAccess a, const std::string&) + { + return sg.getPostWUVarIndex(batchSize, getVarAccessDuplication(a), "$(id_post)"); + }); // If this synapse group has a kernel @@ -51,11 +53,12 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa // If weights are individual, substitute variables for values stored in global memory if (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::INDIVIDUAL) { - synEnv.addVars(backend.getDeviceVarPrefix(), - [&sg, batchSize](VarAccess a, const std::string&) - { - return sg.getSynVarIndex(batchSize, getVarAccessDuplication(a), "$(id_syn)"); - }); + synEnv.template addVars( + backend.getDeviceVarPrefix(), + [&sg, batchSize](VarAccess a, const std::string&) + { + return sg.getSynVarIndex(batchSize, getVarAccessDuplication(a), "$(id_syn)"); + }); } // Otherwise, if weights are procedural else if (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::PROCEDURAL) { @@ -101,11 +104,12 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa else if(sg.getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL) { assert(!sg.getArchetype().getKernelSize().empty()); - synEnv.addVars(backend.getDeviceVarPrefix(), - [&sg, batchSize](VarAccess a, const std::string&) - { - return sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), "$(id_kernel)"); - }); + synEnv.template addVars( + backend.getDeviceVarPrefix(), + [&sg, batchSize](VarAccess a, const std::string&) + { + return sg.getKernelVarIndex(batchSize, getVarAccessDuplication(a), "$(id_kernel)"); + }); } // Otherwise, substitute variables for constant values else { diff --git a/src/genn/genn/transpiler/parser.cc b/src/genn/genn/transpiler/parser.cc index 01ae62ac59..6d51aebec9 100644 --- a/src/genn/genn/transpiler/parser.cc +++ b/src/genn/genn/transpiler/parser.cc @@ -2,6 +2,7 @@ // Standard C++ includes #include +#include #include #include #include
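The this-> qualifications and .template keywords added in the patch above are the standard cure for GCC's (and Clang's) strict two-phase name lookup, which MSVC historically skipped: inside a class template, an unqualified name inherited from a dependent base class is not looked up at template-definition time, and a call to a dependent member template with explicit template arguments needs the template keyword so that the < parses as an argument list rather than a less-than comparison. A minimal self-contained illustration, with hypothetical names rather than the real GeNN classes:

    template<typename T>
    struct Base
    {
        void addInternal() {}
        template<typename A> void addVars() {}
    };

    template<typename T>
    struct Derived : Base<T>
    {
        void update()
        {
            // addInternal();              // GCC/Clang error: not found in dependent base
            this->addInternal();           // OK: lookup deferred to instantiation
            this->template addVars<int>(); // 'template' disambiguates the '<'
        }
    };

The same rule applies to calls through an object of dependent type, hence the synEnv.template addVars<...>(...) spellings in the hunks above.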
From 935ca60690bc2f3e8e38a11a2c3224590f03b598 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 14:00:07 +0100 Subject: [PATCH 408/725] removed Backend parameters from ModelSpecMerged constructor --- include/genn/genn/code_generator/modelSpecMerged.h | 2 +- src/genn/backends/cuda/optimiser.cc | 4 ++-- src/genn/genn/code_generator/generateModules.cc | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 3194494151..de391d4b76 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -32,7 +32,7 @@ namespace GeNN::CodeGenerator class GENN_EXPORT ModelSpecMerged { public: - ModelSpecMerged(const ModelSpecInternal &model, const BackendBase &backend) + ModelSpecMerged(const ModelSpecInternal &model) : m_Model(model), m_NeuronUpdateSupportCode("NeuronUpdateSupportCode"), m_PostsynapticDynamicsSupportCode("PostsynapticDynamicsSupportCode"), m_PresynapticUpdateSupportCode("PresynapticUpdateSupportCode"), m_PostsynapticUpdateSupportCode("PostsynapticUpdateSupportCode"), m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode"), m_TypeContext{{"scalar", model.getPrecision()}, {"timepoint", model.getTimePrecision()}} diff --git a/src/genn/backends/cuda/optimiser.cc b/src/genn/backends/cuda/optimiser.cc index 763ef94105..49be881e37 100644 --- a/src/genn/backends/cuda/optimiser.cc +++ b/src/genn/backends/cuda/optimiser.cc @@ -455,7 +455,7 @@ KernelOptimisationOutput optimizeBlockSize(int deviceID, const cudaDeviceProp &d Backend backend(blockSize, preferences, deviceID); // Create merged model - ModelSpecMerged modelMerged(model, backend); + ModelSpecMerged modelMerged(model); // Get memory spaces available to this backend // **NOTE** Memory spaces are given out on a first-come, first-serve basis so subsequent groups are in preferential order @@ -481,7 +481,7 @@ KernelOptimisationOutput optimizeBlockSize(int deviceID, const cudaDeviceProp &d // Calculate module's hash digest // **NOTE** this COULD be done in thread functions but, because when using GeNN from Python, // this will call into Python code it would require whole Python interface to be made thread-safe - const auto hashDigest = (modelMerged.*m.getArchetypeHashDigest)(); + const auto hashDigest = std::invoke(m.getArchetypeHashDigest, modelMerged); // Launch thread to analyse kernels in this module (if required) threads.emplace_back(analyseModule, std::cref(m), r, cuContext, hashDigest, std::cref(outputPath), std::cref(nvccPath), diff --git a/src/genn/genn/code_generator/generateModules.cc b/src/genn/genn/code_generator/generateModules.cc index da1009c8ce..dfc668025e 100644 --- a/src/genn/genn/code_generator/generateModules.cc +++ b/src/genn/genn/code_generator/generateModules.cc @@ -101,7 +101,7 @@ std::pair, MemAlloc> generateAll(const ModelSpecInterna filesystem::create_directory(outputPath); // Create merged model - ModelSpecMerged modelMerged(model, backend); + ModelSpecMerged modelMerged(model); // **TODO** because merged group fields are populated in the same pass // as code is generated, we will need to ALWAYS generate code but only From 45c19554dd079e9e791b809de8d951832272753c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 14:24:41 +0100 Subject: [PATCH 409/725] fixed more GCC warnings
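These appear to be mostly -Wunused-parameter warnings: where a signature is free to change, the dead parameter is dropped outright (genEmitSpike and genPreamble lose their ModelSpecMerged argument), and where it must keep matching a virtual interface, the parameter is simply left unnamed, which keeps the override valid while telling the compiler the omission is deliberate. A minimal sketch of the idiom, with a hypothetical signature:

    // warns under -Wextra: unused parameter 'env'
    void genPreamble(EnvironmentExternalBase &env) const {}

    // silent: parameter deliberately unnamed, signature unchanged
    void genPreamble(EnvironmentExternalBase &) const {}

The removal of the now-unused trimWhitespace() helper and of dead locals such as wu falls out of the same warning sweep.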
--- .../genn/genn/code_generator/backendSIMT.h | 2 +- .../genn/genn/code_generator/environment.h | 2 +- .../presynapticUpdateStrategySIMT.h | 24 ++++++------ .../genn/code_generator/supportCodeMerged.h | 2 +- src/genn/genn/code_generator/backendSIMT.cc | 18 +++++----- src/genn/genn/code_generator/codeGenUtils.cc | 23 ------------- .../customConnectivityUpdateGroupMerged.cc | 4 +-- .../genn/code_generator/generateRunner.cc | 23 ++++++------- .../genn/code_generator/initGroupMerged.cc | 2 +- .../presynapticUpdateStrategySIMT.cc | 34 ++++++----------- 10 files changed, 53 insertions(+), 81 deletions(-) diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 26981608cf..b28e5b6432 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -484,7 +484,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase } } - void genEmitSpike(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, const std::string &suffix, bool recordingEnabled) const; + void genEmitSpike(EnvironmentExternalBase &env, const std::string &suffix, bool recordingEnabled) const; void genRecordingSharedMemInit(CodeStream &os, const std::string &suffix) const; diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 80edc90352..55a9b86ae9 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -616,7 +616,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase void addVars(const std::string &arrayPrefix, const std::string &indexSuffix, const std::string &fieldSuffix = "") { - addVars(arrayPrefix, [&indexSuffix](VarAccess a, const std::string &) { return indexSuffix; }, + addVars(arrayPrefix, [&indexSuffix](VarAccess, const std::string &) { return indexSuffix; }, fieldSuffix); } diff --git a/include/genn/genn/code_generator/presynapticUpdateStrategySIMT.h b/include/genn/genn/code_generator/presynapticUpdateStrategySIMT.h index ccfa226be4..3e589cbc65 100644 --- a/include/genn/genn/code_generator/presynapticUpdateStrategySIMT.h +++ b/include/genn/genn/code_generator/presynapticUpdateStrategySIMT.h @@ -38,8 +38,8 @@ class Base //! How many neurons does each thread accumulate the outputs of into shared memory virtual size_t getSharedMemoryPerThread(const PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const = 0; - virtual void genPreamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const = 0; + virtual void genPreamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend) const = 0; //! Generate presynaptic update code virtual void genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, @@ -71,8 +71,8 @@ class PreSpan : public Base //! How many neurons does each thread accumulate the outputs of into shared memory virtual size_t getSharedMemoryPerThread(const PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; - virtual void genPreamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; + virtual void genPreamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend) const final; //! Generate presynaptic update code virtual void genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, @@ -104,8 +104,8 @@ class PostSpan : public Base //!
How many neurons does each thread accumulate the outputs of into shared memory virtual size_t getSharedMemoryPerThread(const PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; - virtual void genPreamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; + virtual void genPreamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend) const final; //! Generate presynaptic update code virtual void genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, @@ -145,8 +145,8 @@ class PostSpanBitmask : public Base //! How many neurons does each thread accumulate the outputs of into shared memory virtual size_t getSharedMemoryPerThread(const PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; - virtual void genPreamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; + virtual void genPreamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend) const final; //! Generate presynaptic update code virtual void genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, @@ -178,8 +178,8 @@ class PreSpanProcedural : public Base //! How many neurons does each thread accumulate the outputs of into shared memory virtual size_t getSharedMemoryPerThread(const PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; - virtual void genPreamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; + virtual void genPreamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend) const final; //! Generate presynaptic update code virtual void genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, @@ -211,8 +211,8 @@ class PostSpanToeplitz : public Base //! How many neurons does each thread accumulate the outputs of into shared memory virtual size_t getSharedMemoryPerThread(const PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; - virtual void genPreamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; + virtual void genPreamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend) const final; //! Generate presynaptic update code virtual void genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, diff --git a/include/genn/genn/code_generator/supportCodeMerged.h b/include/genn/genn/code_generator/supportCodeMerged.h index f3dd100e2c..fd5b91d2a8 100644 --- a/include/genn/genn/code_generator/supportCodeMerged.h +++ b/include/genn/genn/code_generator/supportCodeMerged.h @@ -44,7 +44,7 @@ class SupportCodeMerged } //! 
Generate support code - void gen(CodeStream &os, const Type::ResolvedType &scalarType, const bool supportsNamespace = true) const + void gen(CodeStream &os, const Type::ResolvedType &, const bool supportsNamespace = true) const { // Loop through support code for(const auto &s : m_SupportCode) { diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 93e5de1cbd..a127312fc6 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -76,7 +76,7 @@ void BackendSIMT::genVariableInit(EnvironmentExternalBase &env, const std::strin handler(env); } //-------------------------------------------------------------------------- -void BackendSIMT::genKernelSynapseVariableInit(EnvironmentExternalBase &env, SynapseInitGroupMerged &sg, HandlerEnv handler) const +void BackendSIMT::genKernelSynapseVariableInit(EnvironmentExternalBase &env, SynapseInitGroupMerged&, HandlerEnv handler) const { // Variable should already be provided via parallelism //assert(kernelSubs.hasVarSubstitution("id")); @@ -87,7 +87,7 @@ void BackendSIMT::genKernelSynapseVariableInit(EnvironmentExternalBase &env, Syn handler(varEnv); } //-------------------------------------------------------------------------- -void BackendSIMT::genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, CustomWUUpdateInitGroupMerged &cu, HandlerEnv handler) const +void BackendSIMT::genKernelCustomUpdateVariableInit(EnvironmentExternalBase &env, CustomWUUpdateInitGroupMerged &, HandlerEnv handler) const { // Variable should already be provided via parallelism //assert(kernelSubs.hasVarSubstitution("id")); @@ -522,14 +522,14 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM ng.generateNeuronUpdate(*this, groupEnv, modelMerged, // Emit true spikes - [&modelMerged, this](EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng) + [this](EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng) { - genEmitSpike(env, modelMerged, "", ng.getArchetype().isSpikeRecordingEnabled()); + genEmitSpike(env, "", ng.getArchetype().isSpikeRecordingEnabled()); }, // Emit spike-like events - [&modelMerged, this](EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng) + [this](EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng) { - genEmitSpike(env, modelMerged, "_evnt", ng.getArchetype().isSpikeEventRecordingEnabled()); + genEmitSpike(env, "_evnt", ng.getArchetype().isSpikeEventRecordingEnabled()); }); // Copy local stream back to local @@ -743,7 +743,7 @@ void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, Model buildStandardEnvironment(groupEnv, modelMerged.getModel().getBatchSize()); // Generate preamble - presynapticUpdateStrategy->genPreamble(groupEnv, modelMerged, sg, *this); + presynapticUpdateStrategy->genPreamble(groupEnv, sg, *this); // If spike events should be processed if(sg.getArchetype().isSpikeEventRequired()) { @@ -1744,7 +1744,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS // Generate sparse synapse variable initialisation code genSparseSynapseVarInit( groupEnv, modelMerged, cg, true, - [](EnvironmentExternalBase &env, CustomWUUpdateSparseInitGroupMerged&){}); + [](EnvironmentExternalBase&, CustomWUUpdateSparseInitGroupMerged&){}); }); // Initialise weight update variables for synapse groups with sparse connectivity @@ -1775,7 +1775,7 @@ size_t BackendSIMT::padKernelSize(size_t size, Kernel kernel) const return 
padSize(size, getKernelBlockSize(kernel)); } //-------------------------------------------------------------------------- -void BackendSIMT::genEmitSpike(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, const std::string &suffix, bool recordingEnabled) const +void BackendSIMT::genEmitSpike(EnvironmentExternalBase &env, const std::string &suffix, bool recordingEnabled) const { env.printLine("const unsigned int spk" + suffix + "_idx = " + getAtomic(Type::Uint32, AtomicOperation::ADD, AtomicMemSpace::SHARED) + "(&$(_sh_spk" + suffix + "_count), 1);"); env.printLine("$(_sh_spk" + suffix + ")[spk" + suffix + "_idx] = $(id);"); diff --git a/src/genn/genn/code_generator/codeGenUtils.cc b/src/genn/genn/code_generator/codeGenUtils.cc index 7d2ebe3d07..0925cad075 100644 --- a/src/genn/genn/code_generator/codeGenUtils.cc +++ b/src/genn/genn/code_generator/codeGenUtils.cc @@ -32,34 +32,11 @@ #include "transpiler/parser.h" #include "transpiler/prettyPrinter.h" -//-------------------------------------------------------------------------- -// Anonymous namespace -//-------------------------------------------------------------------------- -namespace -{ -std::string trimWhitespace(const std::string& str) -{ - const std::string whitespace = " \t\r\n"; - - // If string is all whitespace, return empty - const auto strBegin = str.find_first_not_of(whitespace); - if (strBegin == std::string::npos) { - return ""; - } - - const auto strEnd = str.find_last_not_of(whitespace); - const auto strRange = strEnd - strBegin + 1; - - return str.substr(strBegin, strRange); -} -} // Anonymous namespace - //---------------------------------------------------------------------------- // GeNN::CodeGenerator //---------------------------------------------------------------------------- namespace GeNN::CodeGenerator { -//---------------------------------------------------------------------------- void genTypeRange(CodeStream &os, const Type::ResolvedType &type, const std::string &prefix) { const auto &numeric = type.getNumeric(); diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index 438b364db5..aeba04e607 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -402,7 +402,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back addPrivateVarRefAccess(bodyEnv, modelMerged.getModel().getBatchSize(), "$(id_syn)"); addPrivateVarRefAccess( bodyEnv, modelMerged.getModel().getBatchSize(), - [](VarAccessMode a, const Models::VarReference &varRef) + [](VarAccessMode, const Models::VarReference &varRef) { if(varRef.getDelayNeuronGroup() != nullptr) { return "$(_post_delay_offset) + $(id_post)"; @@ -436,7 +436,7 @@ bool CustomConnectivityUpdateGroupMerged::isDerivedParamHeterogeneous(const std: //---------------------------------------------------------------------------- const std::string CustomConnectivityHostUpdateGroupMerged::name = "CustomConnectivityHostUpdate"; //---------------------------------------------------------------------------- -void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged&) { CodeStream::Scope b(env.getStream()); diff --git 
a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 3a37578bf0..7d4ff13ecb 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -93,7 +93,7 @@ void genHostScalar(CodeStream &definitionsVar, CodeStream &runnerVarDecl, runnerVarDecl << type.getValue().name << " " << name << " = " << value << ";" << std::endl; } //-------------------------------------------------------------------------- -void genHostDeviceScalar(const ModelSpecMerged &modelMerged, const BackendBase &backend, CodeStream &definitionsVar, +void genHostDeviceScalar(const BackendBase &backend, CodeStream &definitionsVar, CodeStream &definitionsInternalVar, CodeStream &runnerVarDecl, CodeStream &runnerVarAlloc, CodeStream &runnerVarFree, const Type::ResolvedType &type, const std::string &name, const std::string &hostValue, MemAlloc &mem) { @@ -250,11 +250,10 @@ void genStatePushPull(CodeStream &definitionsFunc, CodeStream &runnerPushFunc, C } } //------------------------------------------------------------------------- -void genVariable(const ModelSpecMerged &modelMerged, const BackendBase &backend, - CodeStream &definitionsVar, CodeStream &definitionsFunc, CodeStream &definitionsInternal, - CodeStream &runner, CodeStream &allocations, CodeStream &free, CodeStream &push, CodeStream &pull, - const Type::ResolvedType &type, const std::string &name, - VarLocation loc, bool autoInitialized, size_t count, MemAlloc &mem, +void genVariable(const BackendBase &backend, CodeStream &definitionsVar, CodeStream &definitionsFunc, + CodeStream &definitionsInternal, CodeStream &runner, CodeStream &allocations, + CodeStream &free, CodeStream &push, CodeStream &pull, const Type::ResolvedType &type, + const std::string &name, VarLocation loc, bool autoInitialized, size_t count, MemAlloc &mem, std::vector &statePushPullFunction) { // Generate push and pull functions @@ -416,7 +415,7 @@ void genRunnerVars(const ModelSpecMerged &modelMerged, const BackendBase &backen const auto *varInitSnippet = varAdaptor.getInitialisers().at(var.name).getSnippet(); const bool autoInitialized = !varInitSnippet->getCode().empty(); const auto resolvedType = var.type.resolve(modelMerged.getTypeContext()); - genVariable(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, resolvedType, var.name + group.getName(), varAdaptor.getLoc(var.name), autoInitialized, getSizeFn(group, var), mem, statePushPullFunctions); @@ -995,7 +994,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, // If neuron group has axonal delays if (n.second.isDelayRequired()) { - genHostDeviceScalar(modelMerged, backend, definitionsVar, definitionsInternalVar, + genHostDeviceScalar(backend, definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, Type::Uint32, "spkQuePtr" + n.first, "0", mem); } @@ -1084,7 +1083,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const size_t count = n.second.isVarQueueRequired(var.name) ? 
numCopies * numElements * n.second.getNumDelaySlots() : numCopies * n.second.getNumNeurons(); const bool autoInitialized = !varInitSnippet->getCode().empty(); const auto resolvedType = var.type.resolve(modelMerged.getTypeContext()); - genVariable(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, + genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, resolvedType, var.name + n.first, n.second.getVarLocation(var.name), autoInitialized, count, mem, neuronStatePushPullFunctions); @@ -1260,7 +1259,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, model.getPrecision(), "denDelay" + sg->getFusedPSVarSuffix(), sg->getDendriticDelayLocation(), (size_t)sg->getMaxDendriticDelayTimesteps() * (size_t)sg->getTrgNeuronGroup()->getNumNeurons() * batchSize, mem); - genHostDeviceScalar(modelMerged, backend, definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + genHostDeviceScalar(backend, definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, Type::Uint32, "denDelayPtr" + sg->getFusedPSVarSuffix(), "0", mem); } @@ -1399,7 +1398,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const auto resolvedType = wuVar.type.resolve(modelMerged.getTypeContext()); if(individualWeights) { const size_t size = (size_t)s.second.getSrcNeuronGroup()->getNumNeurons() * (size_t)backend.getSynapticMatrixRowStride(s.second); - genVariable(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, resolvedType, wuVar.name + s.second.getName(), s.second.getWUVarLocation(wuVar.name), autoInitialized, size * getNumVarCopies(wuVar.access, batchSize), mem, synapseGroupStatePushPullFunctions); } @@ -1408,7 +1407,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const size_t size = s.second.getKernelSizeFlattened() * getNumVarCopies(wuVar.access, batchSize); // Generate variable - genVariable(modelMerged, backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, + genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, resolvedType, wuVar.name + s.second.getName(), s.second.getWUVarLocation(wuVar.name), autoInitialized, size, mem, synapseGroupStatePushPullFunctions); } diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 1f58962141..aad8351b46 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -1168,7 +1168,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdateSparseInitGroupM return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomConnectivityUpdateSparseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void CustomConnectivityUpdateSparseInitGroupMerged::generateInit(const BackendBase 
&backend, EnvironmentExternalBase &env, const ModelSpecMerged&) { // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this); diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index 4e2c615498..0e65109fc2 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -74,8 +74,8 @@ size_t PreSpan::getSharedMemoryPerThread(const PresynapticUpdateGroupMerged&, co return 0; } //---------------------------------------------------------------------------- -void PreSpan::genPreamble(EnvironmentExternalBase&, const ModelSpecMerged&, - PresynapticUpdateGroupMerged&, const BackendSIMT&) const +void PreSpan::genPreamble(EnvironmentExternalBase&, PresynapticUpdateGroupMerged&, + const BackendSIMT&) const { } //---------------------------------------------------------------------------- @@ -86,7 +86,6 @@ void PreSpan::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &mod const ModelSpecInternal &model = modelMerged.getModel(); const unsigned int batchSize = model.getBatchSize(); const std::string eventSuffix = trueSpike ? "" : "_evnt"; - const auto *wu = sg.getArchetype().getWUModel(); const size_t numThreadsPerSpike = sg.getArchetype().getNumThreadsPerSpike(); if(numThreadsPerSpike > 1) { @@ -180,7 +179,7 @@ void PreSpan::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &mod } } //---------------------------------------------------------------------------- -void PreSpan::genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged&, +void PreSpan::genPostamble(EnvironmentExternalBase &, const ModelSpecMerged&, PresynapticUpdateGroupMerged&, const BackendSIMT&) const { } @@ -217,8 +216,8 @@ bool PostSpan::isCompatible(const SynapseGroupInternal &sg, const PreferencesBas && !(sg.getMatrixType() & SynapseMatrixConnectivity::TOEPLITZ)); } //---------------------------------------------------------------------------- -void PostSpan::genPreamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const +void PostSpan::genPreamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend) const { // If data structure is dense, we can accumulate output directly into register if(shouldAccumulateInRegister(sg)) { @@ -251,7 +250,6 @@ void PostSpan::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &mo env.printLine("const unsigned int numSpikes = $(_src_spk_cnt" + eventSuffix + ")[" + sg.getPreSlot(batchSize) + "];"); env.getStream() << "const unsigned int numSpikeBlocks = (numSpikes + " << backend.getKernelBlockSize(KernelPresynapticUpdate) << " - 1) / " << backend.getKernelBlockSize(KernelPresynapticUpdate) << ";" << std::endl; - const auto *wu = sg.getArchetype().getWUModel(); env.getStream() << "for (unsigned int r = 0; r < numSpikeBlocks; r++)"; { CodeStream::Scope b(env.getStream()); @@ -453,19 +451,18 @@ size_t PreSpanProcedural::getSharedMemoryPerThread(const PresynapticUpdateGroupM return 0; } //---------------------------------------------------------------------------- -void PreSpanProcedural::genPreamble(EnvironmentExternalBase&, const ModelSpecMerged&, - PresynapticUpdateGroupMerged&, const BackendSIMT&) const +void PreSpanProcedural::genPreamble(EnvironmentExternalBase&, PresynapticUpdateGroupMerged&, + const BackendSIMT&) const { } 
//---------------------------------------------------------------------------- void PreSpanProcedural::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const + PresynapticUpdateGroupMerged &sg, const BackendSIMT&, bool trueSpike) const { // Get suffix based on type of events const ModelSpecInternal &model = modelMerged.getModel(); const unsigned int batchSize = model.getBatchSize(); const std::string eventSuffix = trueSpike ? "" : "_evnt"; - const auto *wu = sg.getArchetype().getWUModel(); const size_t numThreadsPerSpike = sg.getArchetype().getNumThreadsPerSpike(); if(numThreadsPerSpike > 1) { @@ -641,8 +638,8 @@ bool PostSpanBitmask::isCompatible(const SynapseGroupInternal &sg, const Prefere && !sg.isDendriticDelayRequired()); } //---------------------------------------------------------------------------- -void PostSpanBitmask::genPreamble(EnvironmentExternalBase &env, const ModelSpecMerged &, - PresynapticUpdateGroupMerged &, const BackendSIMT &backend) const +void PostSpanBitmask::genPreamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &, + const BackendSIMT &backend) const { // Loop through bits written by this thread for(size_t i = 0; i < 32; i++) { @@ -674,7 +671,6 @@ void PostSpanBitmask::genUpdate(EnvironmentExternalBase &env, const ModelSpecMer env.getStream() << "const unsigned int numSpikeBlocks = (numSpikes + " << blockSize << " - 1) / " << blockSize << ";" << std::endl; - const auto *wu = sg.getArchetype().getWUModel(); env.printLine("const unsigned int rowWords = ($(num_post) + 32 - 1) / 32;"); env.getStream() << "for (unsigned int r = 0; r < numSpikeBlocks; r++)"; { @@ -813,8 +809,8 @@ bool PostSpanToeplitz::isCompatible(const SynapseGroupInternal &sg, const Prefer return (sg.getMatrixType() & SynapseMatrixConnectivity::TOEPLITZ); } //---------------------------------------------------------------------------- -void PostSpanToeplitz::genPreamble(EnvironmentExternalBase &env, const ModelSpecMerged &, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const +void PostSpanToeplitz::genPreamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend) const { if(isSmallSharedMemoryPop(sg, backend)) { env.print("if(" + backend.getThreadID() + " < $(num_post))"); @@ -835,14 +831,14 @@ size_t PostSpanToeplitz::getSharedMemoryPerThread(const PresynapticUpdateGroupMe void PostSpanToeplitz::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const { - const auto &connectInit = sg.getArchetype().getToeplitzConnectivityInitialiser(); + assert(false); + /*const auto &connectInit = sg.getArchetype().getToeplitzConnectivityInitialiser(); // Get suffix based on type of events const ModelSpecInternal &model = modelMerged.getModel(); const unsigned int batchSize = model.getBatchSize(); const std::string eventSuffix = trueSpike ? 
"" : "_evnt"; - assert(false); - /* + // Create substitution stack for generating Toeplitz connectivity code Substitutions connSubs(&popSubs); connSubs.addVarSubstitution("id_diag", connSubs["id"]); From bd69484914ff6ebb8e7c0af92a84639858ae9994 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 15:20:11 +0100 Subject: [PATCH 410/725] bit of include tidying --- .../code_generator/customConnectivityUpdateGroupMerged.h | 3 +++ include/genn/genn/code_generator/environment.h | 4 +++- include/genn/genn/code_generator/groupMerged.h | 9 ++------- include/genn/genn/code_generator/initGroupMerged.h | 6 ++++++ .../genn/genn/code_generator/synapseUpdateGroupMerged.h | 3 +++ 5 files changed, 17 insertions(+), 8 deletions(-) diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index c8b7b43ed6..90f1c967a0 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -3,6 +3,9 @@ // Standard C++ includes #include +// GeNN includes +#include "customConnectivityUpdateInternal.h" + // GeNN code generator includes #include "code_generator/codeGenUtils.h" #include "code_generator/environment.h" diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 55a9b86ae9..e101bd5d21 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -383,6 +383,8 @@ template class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase> { using GroupInternal = typename G::GroupInternal; + + using GetFieldDoubleValueFunc = std::function; using IsHeterogeneousFn = bool (G::*)(const std::string&) const; using IsVarInitHeterogeneousFn = bool (G::*)(const std::string&, const std::string&) const; using GetParamValuesFn = const std::unordered_map &(GroupInternal::*)(void) const; @@ -429,7 +431,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getScalarType(); diff --git a/include/genn/genn/code_generator/groupMerged.h b/include/genn/genn/code_generator/groupMerged.h index 77cbea89bd..36367ed776 100644 --- a/include/genn/genn/code_generator/groupMerged.h +++ b/include/genn/genn/code_generator/groupMerged.h @@ -8,11 +8,7 @@ // GeNN includes #include "gennExport.h" -#include "currentSourceInternal.h" -#include "customConnectivityUpdateInternal.h" -#include "customUpdateInternal.h" #include "neuronGroupInternal.h" -#include "synapseGroupInternal.h" #include "type.h" // GeNN code generator includes @@ -63,7 +59,6 @@ class ChildGroupMerged //------------------------------------------------------------------------ typedef G GroupInternal; typedef std::function GetFieldValueFunc; - typedef std::function GetFieldDoubleValueFunc; typedef std::tuple Field; ChildGroupMerged(size_t index, const Type::TypeContext &typeContext, const std::vector> groups) @@ -424,7 +419,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged void orderNeuronGroupChildren(std::vector &childGroups, const Type::TypeContext &typeContext, G getVectorFunc, H getHashDigestFunc) const { - const std::vector &archetypeChildren = std::invoke(getVectorFunc, getArchetype()); + const auto &archetypeChildren = std::invoke(getVectorFunc, getArchetype()); // Resize vector of vectors to hold children for all neuron groups, sorted in a consistent manner std::vector>> sortedGroupChildren; @@ 
-437,7 +432,7 @@ class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged &groupChildren = (g.get().*getVectorFunc)(); + const auto &groupChildren = std::invoke(getVectorFunc, g.get()); assert(groupChildren.size() == archetypeChildren.size()); // Loop through children and add them and their digests to vector diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index 948aa6fe40..0eb1eabf69 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -1,5 +1,11 @@ #pragma once +// GeNN includes +#include "customConnectivityUpdateInternal.h" +#include "currentSourceInternal.h" +#include "customUpdateInternal.h" +#include "synapseGroupInternal.h" + // GeNN code generator includes #include "code_generator/groupMerged.h" diff --git a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h index 13830b3bca..74a1658337 100644 --- a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h @@ -1,5 +1,8 @@ #pragma once +// GeNN includes +#include "synapseGroupInternal.h" + // GeNN code generator includes #include "code_generator/groupMerged.h" From bcefef06eaeb04630c3aa67bdb6eaf36e0dd8c8c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 15:26:38 +0100 Subject: [PATCH 411/725] attempt at fixing GroupInternal/External mess --- include/genn/genn/code_generator/environment.h | 6 ++---- include/genn/genn/currentSourceInternal.h | 2 ++ include/genn/genn/customConnectivityUpdateInternal.h | 2 ++ include/genn/genn/customUpdateInternal.h | 4 ++++ include/genn/genn/neuronGroupInternal.h | 2 ++ include/genn/genn/synapseGroupInternal.h | 2 ++ 6 files changed, 14 insertions(+), 4 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index e101bd5d21..da75a7fbee 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -383,6 +383,7 @@ template class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase> { using GroupInternal = typename G::GroupInternal; + using GroupExternal = typename GroupInternal::GroupExternal; using GetFieldDoubleValueFunc = std::function; using IsHeterogeneousFn = bool (G::*)(const std::string&) const; @@ -394,10 +395,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase; template - using GetConnectivityFn = const I &(GroupInternal::*)(void) const; - - template - using GetVarReferencesFn = const std::unordered_map &(GroupInternal::*)(void) const; + using GetConnectivityFn = const I &(GroupExternal::*)(void) const; public: using EnvironmentExternalDynamicBase>::EnvironmentExternalDynamicBase; diff --git a/include/genn/genn/currentSourceInternal.h b/include/genn/genn/currentSourceInternal.h index ba4edff676..05a3d50457 100644 --- a/include/genn/genn/currentSourceInternal.h +++ b/include/genn/genn/currentSourceInternal.h @@ -11,6 +11,8 @@ namespace GeNN class CurrentSourceInternal : public CurrentSource { public: + using GroupExternal = CurrentSource; + CurrentSourceInternal(const std::string &name, const CurrentSourceModels::Base *currentSourceModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, const NeuronGroupInternal *targetNeuronGroup, VarLocation defaultVarLocation, diff --git 
a/include/genn/genn/customConnectivityUpdateInternal.h b/include/genn/genn/customConnectivityUpdateInternal.h index ca03cbe35a..e603501cf6 100644 --- a/include/genn/genn/customConnectivityUpdateInternal.h +++ b/include/genn/genn/customConnectivityUpdateInternal.h @@ -12,6 +12,8 @@ namespace GeNN class CustomConnectivityUpdateInternal : public CustomConnectivityUpdate { public: + using GroupExternal = CustomConnectivityUpdate; + CustomConnectivityUpdateInternal(const std::string &name, const std::string &updateGroupName, SynapseGroupInternal *synapseGroup, const CustomConnectivityUpdateModels::Base *customConnectivityUpdateModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, diff --git a/include/genn/genn/customUpdateInternal.h b/include/genn/genn/customUpdateInternal.h index b6a08727f8..fbb08a0d43 100644 --- a/include/genn/genn/customUpdateInternal.h +++ b/include/genn/genn/customUpdateInternal.h @@ -12,6 +12,8 @@ namespace GeNN class CustomUpdateInternal : public CustomUpdate { public: + using GroupExternal = CustomUpdate; + CustomUpdateInternal(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, @@ -68,6 +70,8 @@ class CustomUpdateVarRefAdapter class CustomUpdateWUInternal : public CustomUpdateWU { public: + using GroupExternal = CustomUpdateWU; + CustomUpdateWUInternal(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, diff --git a/include/genn/genn/neuronGroupInternal.h b/include/genn/genn/neuronGroupInternal.h index 9e8da8077b..5905735aad 100644 --- a/include/genn/genn/neuronGroupInternal.h +++ b/include/genn/genn/neuronGroupInternal.h @@ -11,6 +11,8 @@ namespace GeNN class NeuronGroupInternal : public NeuronGroup { public: + using GroupExternal = NeuronGroup; + NeuronGroupInternal(const std::string &name, int numNeurons, const NeuronModels::Base *neuronModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) diff --git a/include/genn/genn/synapseGroupInternal.h b/include/genn/genn/synapseGroupInternal.h index 556600a79c..a9c56be5f5 100644 --- a/include/genn/genn/synapseGroupInternal.h +++ b/include/genn/genn/synapseGroupInternal.h @@ -12,6 +12,8 @@ namespace GeNN class SynapseGroupInternal : public SynapseGroup { public: + using GroupExternal = SynapseGroup; + SynapseGroupInternal(const std::string &name, SynapseMatrixType matrixType, unsigned int delaySteps, const WeightUpdateModels::Base *wu, const std::unordered_map &wuParams, const std::unordered_map &wuVarInitialisers, const std::unordered_map &wuPreVarInitialisers, const std::unordered_map &wuPostVarInitialisers, const PostsynapticModels::Base *ps, const std::unordered_map &psParams, const std::unordered_map &psVarInitialisers, From 9d70d7db6c2bbb15526983b551043c439875eeeb Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 15:30:11 +0100 Subject: [PATCH 412/725] auto --- include/genn/backends/cuda/backend.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index fadf3d0096..9b9d7ca9a9 100644 --- a/include/genn/backends/cuda/backend.h +++ 
b/include/genn/backends/cuda/backend.h @@ -382,7 +382,7 @@ class BACKEND_EXPORT Backend : public BackendSIMT EnvironmentGroupMergedField groupEnv(env, cg); // Loop through variables - const CustomUpdateModels::Base *cm = cg.getArchetype().getCustomUpdateModel(); + const auto *cm = cg.getArchetype().getCustomUpdateModel(); for(const auto &v : cm->getVars()) { // If variable is reduction target if(v.access & VarAccessModeAttribute::REDUCE) { From 4e6eb77b292844990ba715c90d5f9f77b1af3d96 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 15:36:00 +0100 Subject: [PATCH 413/725] fixed small SIMT bug --- .../genn/code_generator/presynapticUpdateStrategySIMT.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index 0e65109fc2..51bb7f71ee 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -317,10 +317,11 @@ void PostSpan::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &mo {synEnv.addInitialiser( "const unsigned int synAddress = ($(_sh_spk" + eventSuffix + ")[j] * $(_row_stride)) + $(id);")}); if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - synEnv.getStream() << "const unsigned int npost = " << synEnv["_sh_row_length"] << "[j];" << std::endl; + synEnv.printLine("const unsigned int npost = $(_sh_row_length)[j];"); - synEnv.getStream() << "if (" << synEnv["id"] << " < npost)" << CodeStream::OB(140); - synEnv.getStream() << "const unsigned int ipost = " << synEnv["_ind"] << "[synAddress];" << std::endl; + synEnv.print("if ($(id) < npost)"); + synEnv.getStream() << CodeStream::OB(140); + synEnv.printLine("const unsigned int ipost = $(_ind)[$(id_syn)];"); synEnv.add(Type::Uint32.addConst(), "id_post", "ipost"); } From 758a42e4a0cc2bfcacc90b1d0087d57651c4abc2 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 15:40:48 +0100 Subject: [PATCH 414/725] tidying of Environment --- .../genn/genn/code_generator/environment.h | 22 +++++++------------ .../presynapticUpdateStrategySIMT.cc | 10 +-------- 2 files changed, 9 insertions(+), 23 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index da75a7fbee..3042da6e8c 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -677,20 +677,17 @@ class VarCachePolicy //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - std::string getReadIndex(G&, const Models::Base::Var &var) + std::string getReadIndex(G&, const Models::Base::Var &var) const { return m_GetReadIndex(var.name, getVarAccessDuplication(var.access)); } - std::string getWriteIndex(G&, const Models::Base::Var &var) + std::string getWriteIndex(G&, const Models::Base::Var &var) const { return m_GetWriteIndex(var.name, getVarAccessDuplication(var.access)); } - //------------------------------------------------------------------------ - // Static API - //------------------------------------------------------------------------ - static std::string getVarSuffix(const GroupInternal &g, const Models::Base::Var&) + std::string getVarSuffix(const GroupInternal &g, const Models::Base::Var&) const { return A(g).getNameSuffix(); } @@ -725,20 +722,17 @@ class 
VarRefCachePolicy //------------------------------------------------------------------------ // Public API //------------------------------------------------------------------------ - std::string getReadIndex(G &g, const Models::Base::VarRef &var) + std::string getReadIndex(G &g, const Models::Base::VarRef &var) const { return m_GetReadIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name)); } - std::string getWriteIndex(G &g, const Models::Base::VarRef &var) + std::string getWriteIndex(G &g, const Models::Base::VarRef &var) const { return m_GetWriteIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name)); } - //------------------------------------------------------------------------ - // Static API - //------------------------------------------------------------------------ - static std::string getVarSuffix(const GroupInternal &g, const Models::Base::VarRef &var) + std::string getVarSuffix(const GroupInternal &g, const Models::Base::VarRef &var) const { return A(g).getInitialisers().at(var.name).getTargetName(); } @@ -795,9 +789,9 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P const auto &group = m_Group.get(); const auto &arrayPrefix = m_ArrayPrefix; m_FieldGroup.get().addField(resolvedType.createPointer(), v.name + m_FieldSuffix, - [arrayPrefix, v, &group](const typename F::GroupInternal &, size_t i) + [arrayPrefix, v, &group, this](const typename F::GroupInternal &, size_t i) { - return arrayPrefix + v.name + P::getVarSuffix(group.getGroups().at(i), v); + return arrayPrefix + v.name + getVarSuffix(group.getGroups().at(i), v); }); if(v.access & VarAccessMode::READ_ONLY) { diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index 51bb7f71ee..8c9e5d685d 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -328,15 +328,7 @@ void PostSpan::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &mo else { // DENSE synEnv.add(Type::Uint32.addConst(), "id_post", "$(id)"); } - /*synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", - backend.getAtomic(model.getPrecision()) + "(&$(_den_delay)[" + sg.getPostDenDelayIndex(batchSize, "ipost", "$(1)") + "], $(0))"); - - synEnv.add(Type::AddToPost, "addToPost", - backend.getAtomic(model.getPrecision()) + "(&$(_out_post)[" + sg.getPostISynIndex(batchSize, "ipost") + "], $(0))"); - - - synEnv.add(Type::AddToPre, "addToPre", "lrevInSyn += $(0)"); - */ + // If dendritic delay is required, always use atomic operation to update dendritic delay buffer synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", backend.getAtomic(model.getPrecision()) + "(&$(_den_delay)[" + sg.getPostDenDelayIndex(batchSize, "$(id_post)", "$(1)") + "], $(0))"); From 9ff9afac18b95476c8801a0a983705400910328f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 15:42:48 +0100 Subject: [PATCH 415/725] GCC fix --- include/genn/genn/code_generator/environment.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 3042da6e8c..28cf578b71 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -791,7 +791,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P m_FieldGroup.get().addField(resolvedType.createPointer(), v.name + m_FieldSuffix, 
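// A note on the this-> qualification in the change below: EnvironmentLocalCacheBase is a
// class template whose cache policy P is a base class, so getVarSuffix() is a member of a
// dependent base. Under two-phase name lookup (enforced by GCC and Clang, historically not
// by MSVC), an unqualified call inside the lambda is not found at template-definition time;
// writing this->getVarSuffix() makes the name dependent and defers lookup to instantiation.
// A minimal sketch of the rule, with hypothetical names:
//     template<typename P>
//     struct Cache : P {
//         int use() { return this->suffix(); }   // plain suffix() fails to compile on GCC
//     };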
[arrayPrefix, v, &group, this](const typename F::GroupInternal &, size_t i) { - return arrayPrefix + v.name + getVarSuffix(group.getGroups().at(i), v); + return arrayPrefix + v.name + this->getVarSuffix(group.getGroups().at(i), v); }); if(v.access & VarAccessMode::READ_ONLY) { From 85646bff220883f1a2a7907611a70fa1cdf69a67 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 16:45:12 +0100 Subject: [PATCH 416/725] stop passing ModelSpecMerged around everywhere when all that's needed is batch size! --- .../genn/genn/code_generator/backendSIMT.h | 19 +--- .../customConnectivityUpdateGroupMerged.h | 4 +- .../genn/code_generator/initGroupMerged.h | 32 +++--- .../code_generator/neuronUpdateGroupMerged.h | 18 +-- .../presynapticUpdateStrategySIMT.h | 49 ++++---- .../code_generator/synapseUpdateGroupMerged.h | 10 +- src/genn/backends/cuda/backend.cc | 4 +- .../backends/single_threaded_cpu/backend.cc | 70 ++++++------ src/genn/genn/code_generator/backendSIMT.cc | 76 +++++++------ .../customConnectivityUpdateGroupMerged.cc | 29 +++-- .../genn/code_generator/generateRunner.cc | 2 +- .../genn/code_generator/initGroupMerged.cc | 107 +++++++++--------- .../code_generator/neuronUpdateGroupMerged.cc | 71 ++++++------ .../presynapticUpdateStrategySIMT.cc | 89 +++++++-------- .../synapseUpdateGroupMerged.cc | 22 ++-- 15 files changed, 284 insertions(+), 318 deletions(-) diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index b28e5b6432..5a766133fe 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -334,21 +334,10 @@ class GENN_EXPORT BackendSIMT : public BackendBase genGroup(env, g, idStart, getPaddedSizeFunc, handler); }); } - - - - - /*template - void genParallelGroup(EnvironmentExternalBase &env, std::vector &groups, size_t &idStart, - S getPaddedSizeFunc, GroupHandlerEnv handler) const - { - genParallelGroup(env, groups, idStart, getPaddedSizeFunc, - [](const T &) { return true; }, handler); - }*/ // Helper function to generate kernel code to initialise variables associated with synapse group or custom WU update with dense/kernel connectivity template - void genSynapseVarInit(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, G &g, + void genSynapseVarInit(EnvironmentExternalBase &env, unsigned int batchSize, G &g, bool initRNGRequired, bool kernel, size_t kernelDimensions) const { env.getStream() << "if(" << env["id"] << " < "; @@ -422,13 +411,13 @@ class GENN_EXPORT BackendSIMT : public BackendBase } // Generate init code - g.generateInit(*this, initEnv, modelMerged); + g.generateInit(*this, initEnv, batchSize); } } // Helper function to generate kernel code to initialise variables associated with synapse group or custom WU update with sparse connectivity template - void genSparseSynapseVarInit(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, G &g, + void genSparseSynapseVarInit(EnvironmentExternalBase &env, unsigned int batchSize, G &g, bool varInitRequired, GroupHandlerEnv handler) const { // Calculate how many blocks rows need to be processed in (in order to store row lengths in shared memory) @@ -471,7 +460,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase EnvironmentExternal initEnv(env); initEnv.add(Type::Uint32.addConst(), "id_pre", "((r * " + std::to_string(blockSize) + ") + i)"); initEnv.add(Type::Uint32.addConst(), "id_post", "$(_ind)[idx]"); g.generateInit(*this, 
initEnv, batchSize); } // Call handler diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index 90f1c967a0..64bc3451f7 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -36,7 +36,7 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public GroupMerged &getMergedCurrentSourceGroups() const { return m_MergedCurrentSourceGroups; } const std::vector &getMergedInSynPSMGroups() const { return m_MergedInSynPSMGroups; } @@ -241,7 +241,7 @@ class GENN_EXPORT SynapseInitGroupMerged : public InitGroupMergedBase genEmitTrueSpike, BackendBase::GroupHandlerEnv genEmitSpikeLikeEvent); - void generateWUVarUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged); + void generateWUVarUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize); std::string getVarIndex(unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; std::string getReadVarIndex(bool delay, unsigned int batchSize, VarAccessDuplication varDuplication, const std::string &index) const; diff --git a/include/genn/genn/code_generator/presynapticUpdateStrategySIMT.h b/include/genn/genn/code_generator/presynapticUpdateStrategySIMT.h index 3e589cbc65..e6ffdbea75 100644 --- a/include/genn/genn/code_generator/presynapticUpdateStrategySIMT.h +++ b/include/genn/genn/code_generator/presynapticUpdateStrategySIMT.h @@ -11,7 +11,6 @@ class SynapseGroupInternal; namespace CodeGenerator { class BackendSIMT; -class ModelSpecMerged; } } @@ -42,11 +41,11 @@ class Base const BackendSIMT &backend) const = 0; //! Generate presynaptic update code - virtual void genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const = 0; + virtual void genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize, bool trueSpike) const = 0; - virtual void genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const = 0; + virtual void genPostamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize) const = 0; }; //-------------------------------------------------------------------------- @@ -75,11 +74,11 @@ class PreSpan : public Base const BackendSIMT &backend) const final; //! 
Generate presynaptic update code - virtual void genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const final; + virtual void genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize, bool trueSpike) const final; - virtual void genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; + virtual void genPostamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize) const final; }; //-------------------------------------------------------------------------- @@ -108,11 +107,11 @@ class PostSpan : public Base const BackendSIMT &backend) const final; //! Generate presynaptic update code - virtual void genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const final; + virtual void genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize, bool trueSpike) const final; - virtual void genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; + virtual void genPostamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize) const final; private: //-------------------------------------------------------------------------- @@ -149,11 +148,11 @@ class PostSpanBitmask : public Base const BackendSIMT &backend) const final; //! Generate presynaptic update code - virtual void genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const final; + virtual void genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize, bool trueSpike) const final; - virtual void genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; + virtual void genPostamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize) const final; }; //-------------------------------------------------------------------------- @@ -182,11 +181,11 @@ class PreSpanProcedural : public Base const BackendSIMT &backend) const final; //! 
Generate presynaptic update code - virtual void genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const final; + virtual void genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize, bool trueSpike) const final; - virtual void genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; + virtual void genPostamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize) const final; }; //-------------------------------------------------------------------------- @@ -215,10 +214,10 @@ class PostSpanToeplitz : public Base const BackendSIMT &backend) const final; //! Generate presynaptic update code - virtual void genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const final; + virtual void genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize, bool trueSpike) const final; - virtual void genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const final; + virtual void genPostamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize) const final; }; } // namespace GeNN::CodeGenerator::PresynapticUpdateStrategySIMT diff --git a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h index 74a1658337..5a5b6681e1 100644 --- a/include/genn/genn/code_generator/synapseUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/synapseUpdateGroupMerged.h @@ -116,9 +116,9 @@ class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase runnerVarDecl, runnerMergedStructAlloc, name); } - void generateSpikeEventThreshold(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged); - void generateSpikeEventUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged); - void generateSpikeUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged); + void generateSpikeEventThreshold(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize); + void generateSpikeEventUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize); + void generateSpikeUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize); void generateProceduralConnectivity(const BackendBase &backend, EnvironmentExternalBase &env); void generateToeplitzConnectivity(const BackendBase &backend, EnvironmentExternalBase &env); @@ -145,7 +145,7 @@ class GENN_EXPORT PostsynapticUpdateGroupMerged : public SynapseGroupMergedBase runnerVarDecl, runnerMergedStructAlloc, name); } - void generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged); + void generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize); //---------------------------------------------------------------------------- // Static constants @@ -170,7 +170,7 
@@ class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase runnerVarDecl, runnerMergedStructAlloc, name); } - void generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged); + void generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize); //---------------------------------------------------------------------------- // Static constants diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 39d775a2b2..21eeda887d 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -713,9 +713,9 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back // Loop through host update groups and generate code for those in this custom update group modelMerged.genMergedCustomConnectivityHostUpdateGroups( *this, memorySpaces, g, - [this, &customUpdateEnv, &modelMerged](auto &c) + [this, &customUpdateEnv](auto &c) { - c.generateUpdate(*this, customUpdateEnv, modelMerged); + c.generateUpdate(*this, customUpdateEnv); }); // Launch custom update kernel if required diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index eac43679b3..99e8a9ad16 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -248,12 +248,12 @@ void Backend::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back EnvironmentLibrary rngEnv(groupEnv, StandardLibrary::getHostRNGFunctions(modelMerged.getModel().getPrecision())); // Generate neuron update - n.generateNeuronUpdate(*this, rngEnv, modelMerged, + n.generateNeuronUpdate(*this, rngEnv, 1, // Emit true spikes - [&modelMerged, this](EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng) + [this](EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng) { // Insert code to update WU vars - ng.generateWUVarUpdate(*this, env, modelMerged); + ng.generateWUVarUpdate(*this, env, 1); // Insert code to emit true spikes genEmitSpike(env, ng, true, ng.getArchetype().isSpikeRecordingEnabled()); @@ -376,7 +376,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Bac synEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + s.getPreISynIndex(1, "$(id_pre)") + "] += $(0)"); // Call synapse dynamics handler - s.generateSynapseUpdate(*this, synEnv, modelMerged); + s.generateSynapseUpdate(*this, synEnv, 1); } } } @@ -487,7 +487,7 @@ void Backend::genSynapseUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Bac synEnv.add(Type::Uint32.addConst(), "id_post", "spike"); synEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + s.getPreISynIndex(1, "$(id_pre)") + "] += $(0)"); - s.generateSynapseUpdate(*this, synEnv, modelMerged); + s.generateSynapseUpdate(*this, synEnv, 1); } } groupEnv.getStream() << std::endl; @@ -557,7 +557,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back *this, memorySpaces, g, [this, &customUpdateEnv, &modelMerged](auto &c) { - c.generateUpdate(*this, customUpdateEnv, modelMerged); + c.generateUpdate(*this, customUpdateEnv); }); { @@ -865,7 +865,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: EnvironmentGroupMergedField groupEnv(funcEnv, n); buildStandardEnvironment(groupEnv, 1); - n.generateInit(*this, groupEnv, modelMerged); + n.generateInit(*this, groupEnv, 1); } }); @@ -886,7 +886,7 @@ void Backend::genInit(CodeStream &os, 
ModelSpecMerged &modelMerged, BackendBase: EnvironmentGroupMergedField groupEnv(funcEnv, s); buildStandardEnvironment(groupEnv, 1); - s.generateInit(*this, groupEnv, modelMerged); + s.generateInit(*this, groupEnv, 1); } }); @@ -904,7 +904,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedCustomUpdateInitGroup" << c.getIndex() << "[g]; " << std::endl; - c.generateInit(*this, funcEnv, modelMerged); + c.generateInit(*this, funcEnv, 1); } }); @@ -922,7 +922,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedCustomConnectivityUpdatePreInitGroup" << c.getIndex() << "[g]; " << std::endl; - c.generateInit(*this, funcEnv, modelMerged); + c.generateInit(*this, funcEnv, 1); } }); @@ -940,7 +940,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedCustomConnectivityUpdatePostInitGroup" << c.getIndex() << "[g]; " << std::endl; - c.generateInit(*this, funcEnv, modelMerged); + c.generateInit(*this, funcEnv, 1); } }); @@ -958,7 +958,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedCustomWUUpdateInitGroup" << c.getIndex() << "[g]; " << std::endl; - c.generateInit(*this, funcEnv, modelMerged); + c.generateInit(*this, funcEnv, 1); } }); @@ -1067,7 +1067,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: } // Call handler to initialize variables - s.generateKernelInit(*this, kernelInitEnv, modelMerged); + s.generateKernelInit(*this, kernelInitEnv, 1); } // If there is row-building code in this snippet @@ -1151,28 +1151,28 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); if(s.getArchetype().isWUVarInitRequired()) { groupEnv.add(Type::Uint32.addConst(), "row_len", "$(_row_length)[i]"); - s.generateInit(*this, groupEnv, modelMerged); + s.generateInit(*this, groupEnv, 1); } // If postsynaptic learning is required if(!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { - groupEnv.getStream() << "// Loop through synapses in corresponding matrix row" << std::endl; - groupEnv.getStream() << "for(unsigned int j = 0; j < " << groupEnv["_row_length"] << "[i]; j++)" << std::endl; + groupEnv.printLine("// Loop through synapses in corresponding matrix row"); + groupEnv.print("for(unsigned int j = 0; j < $(_row_length)[i]; j++)"); { CodeStream::Scope b(groupEnv.getStream()); // If postsynaptic learning is required, calculate column length and remapping if(!s.getArchetype().getWUModel()->getLearnPostCode().empty()) { - groupEnv.getStream() << "// Calculate index of this synapse in the row-major matrix" << std::endl; - groupEnv.getStream() << "const unsigned int rowMajorIndex = (i * " << groupEnv["_row_stride"] << ") + j;" << std::endl; - groupEnv.getStream() << "// Using this, lookup postsynaptic target" << std::endl; - groupEnv.getStream() << "const unsigned int postIndex = " << groupEnv["_ind"] << "[rowMajorIndex];" << std::endl; - groupEnv.getStream() << "// From this calculate index of this synapse in the column-major matrix" << std::endl; - groupEnv.getStream() << "const unsigned int colMajorIndex = (postIndex * " << groupEnv["_col_stride"] << ") + 
" << groupEnv["_col_length"] << "[postIndex];" << std::endl; - groupEnv.getStream() << "// Increment column length corresponding to this postsynaptic neuron" << std::endl; - groupEnv.getStream() << groupEnv["_col_length"] << "[postIndex]++;" << std::endl; - groupEnv.getStream() << "// Add remapping entry" << std::endl; - groupEnv.getStream() << groupEnv["_remap"] << "p[colMajorIndex] = rowMajorIndex;" << std::endl; + groupEnv.printLine("// Calculate index of this synapse in the row-major matrix"); + groupEnv.printLine("const unsigned int rowMajorIndex = (i * $(_row_stride)) + j;"); + groupEnv.printLine("// Using this, lookup postsynaptic target"); + groupEnv.printLine("const unsigned int postIndex = $(_ind)[rowMajorIndex];"); + groupEnv.printLine("// From this calculate index of this synapse in the column-major matrix)"); + groupEnv.printLine("const unsigned int colMajorIndex = (postIndex * $(_col_stride)) + $(_col_length)[postIndex];"); + groupEnv.printLine("// Increment column length corresponding to this postsynaptic neuron"); + groupEnv.printLine("$(_col_length)[postIndex]++;"); + groupEnv.printLine("// Add remapping entry"); + groupEnv.printLine("$(_remap)[colMajorIndex] = rowMajorIndex;"); } } } @@ -1204,7 +1204,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Generate initialisation code groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); groupEnv.add(Type::Uint32.addConst(), "row_len", "$(_row_length)[i]"); - c.generateInit(*this, groupEnv, modelMerged); + c.generateInit(*this, groupEnv, 1); } } }); @@ -1233,7 +1233,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Generate initialisation code groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); groupEnv.add(Type::Uint32.addConst(), "row_len", "$(_row_length)[i]"); - c.generateInit(*this, groupEnv, modelMerged); + c.generateInit(*this, groupEnv, 1); } } }); @@ -1871,7 +1871,7 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda groupEnv.getStream() << "if("; // Generate weight update threshold condition - sg.generateSpikeEventThreshold(*this, groupEnv, modelMerged); + sg.generateSpikeEventThreshold(*this, groupEnv, 1); groupEnv.getStream() << ")"; groupEnv.getStream() << CodeStream::OB(10); @@ -1897,10 +1897,10 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda synEnv.add(Type::AddToPre, "addToPre", "$(_out_pre)[" + sg.getPreISynIndex(1, "$(id_pre)") + "] += $(0)"); if(trueSpike) { - sg.generateSpikeUpdate(*this, synEnv, modelMerged); + sg.generateSpikeUpdate(*this, synEnv, 1); } else { - sg.generateSpikeEventUpdate(*this, synEnv, modelMerged); + sg.generateSpikeEventUpdate(*this, synEnv, 1); } } } @@ -1947,10 +1947,10 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda { CodeStream::Scope b(env.getStream()); if(trueSpike) { - sg.generateSpikeUpdate(*this, groupEnv, modelMerged); + sg.generateSpikeUpdate(*this, groupEnv, 1); } else { - sg.generateSpikeEventUpdate(*this, groupEnv, modelMerged); + sg.generateSpikeEventUpdate(*this, groupEnv, 1); } } @@ -1985,10 +1985,10 @@ void Backend::genPresynapticUpdate(EnvironmentExternalBase &env, PresynapticUpda if(trueSpike) { - sg.generateSpikeUpdate(*this, synEnv, modelMerged); + sg.generateSpikeUpdate(*this, synEnv, 1); } else { - sg.generateSpikeEventUpdate(*this, synEnv, modelMerged); + sg.generateSpikeEventUpdate(*this, synEnv, 1); } if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::BITMASK) { 
diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index a127312fc6..7a95ab230f 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -18,13 +18,13 @@ size_t getNumMergedGroupThreads(const std::vector &groups, G getNumThreads) return std::accumulate( groups.cbegin(), groups.cend(), size_t{0}, [getNumThreads](size_t acc, const T &n) - { - return std::accumulate(n.getGroups().cbegin(), n.getGroups().cend(), acc, - [getNumThreads](size_t acc, std::reference_wrapper g) { - return acc + getNumThreads(g.get()); + return std::accumulate(n.getGroups().cbegin(), n.getGroups().cend(), acc, + [getNumThreads](size_t acc, std::reference_wrapper g) + { + return acc + getNumThreads(g.get()); + }); }); - }); } } // Anonymous namespace @@ -520,7 +520,7 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)")); // **TODO** for OCL do genPopulationRNGPreamble(os, popSubs, "group->rng[" + ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)") + "]") in initialiser - ng.generateNeuronUpdate(*this, groupEnv, modelMerged, + ng.generateNeuronUpdate(*this, groupEnv, batchSize, // Emit true spikes [this](EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng) { @@ -603,7 +603,7 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM // Create new substition stack and explicitly replace id with 'n' and perform WU var update EnvironmentExternal wuEnv(groupEnv); wuEnv.add(Type::Uint32.addConst(), "id", "n"); - ng.generateWUVarUpdate(*this, wuEnv, modelMerged); + ng.generateWUVarUpdate(*this, wuEnv, batchSize); groupEnv.printLine("$(_spk)[" + queueOffsetTrueSpk + "$(_sh_spk_pos) + " + getThreadID() + "] = n;"); if(ng.getArchetype().isSpikeTimeRequired()) { @@ -740,7 +740,8 @@ void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, Model LOGD_BACKEND << "Using '" << typeid(*presynapticUpdateStrategy).name() << "' presynaptic update strategy for merged synapse group '" << sg.getIndex() << "'"; // Generate index calculation code - buildStandardEnvironment(groupEnv, modelMerged.getModel().getBatchSize()); + const unsigned int batchSize = modelMerged.getModel().getBatchSize(); + buildStandardEnvironment(groupEnv, batchSize); // Generate preamble presynapticUpdateStrategy->genPreamble(groupEnv, sg, *this); @@ -748,19 +749,19 @@ void BackendSIMT::genPresynapticUpdateKernel(EnvironmentExternalBase &env, Model // If spike events should be processed if(sg.getArchetype().isSpikeEventRequired()) { CodeStream::Scope b(groupEnv.getStream()); - presynapticUpdateStrategy->genUpdate(groupEnv, modelMerged, sg, *this, false); + presynapticUpdateStrategy->genUpdate(groupEnv, sg, *this, batchSize, false); } // If true spikes should be processed if(sg.getArchetype().isTrueSpikeRequired()) { CodeStream::Scope b(groupEnv.getStream()); - presynapticUpdateStrategy->genUpdate(groupEnv, modelMerged, sg, *this, true); + presynapticUpdateStrategy->genUpdate(groupEnv, sg, *this, batchSize, true); } groupEnv.getStream() << std::endl; // Generate pre-amble - presynapticUpdateStrategy->genPostamble(groupEnv, modelMerged, sg, *this); + presynapticUpdateStrategy->genPostamble(groupEnv, sg, *this, batchSize); }); } //-------------------------------------------------------------------------- @@ -843,7 +844,7 @@ void BackendSIMT::genPostsynapticUpdateKernel(EnvironmentExternalBase &env, 
Mode synEnv.add(Type::AddToPre, "addToPre", getAtomic(modelMerged.getModel().getPrecision()) + "(&$(_out_pre)[" + sg.getPreISynIndex(batchSize, "$(id_pre)") + "], $(0))"); - sg.generateSynapseUpdate(*this, synEnv, modelMerged); + sg.generateSynapseUpdate(*this, synEnv, batchSize); if (sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { synEnv.getStream() << CodeStream::CB(1540); @@ -909,7 +910,7 @@ void BackendSIMT::genSynapseDynamicsKernel(EnvironmentExternalBase &env, ModelSp synEnv.add(Type::AddToPre, "addToPre", getAtomic(modelMerged.getModel().getPrecision()) + "(&$(_out_pre)[" + sg.getPreISynIndex(batchSize, "$(id_pre)") + "], $(0))"); - sg.generateSynapseUpdate(*this, synEnv, modelMerged); + sg.generateSynapseUpdate(*this, synEnv, batchSize); if (sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { synEnv.getStream() << CodeStream::CB(1); @@ -1340,7 +1341,7 @@ void BackendSIMT::genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env groupEnv.add(Type::Void, "rng", genPopulationRNGPreamble(groupEnv.getStream(), rng)); } - cg.generateUpdate(*this, groupEnv, modelMerged); + cg.generateUpdate(*this, groupEnv, modelMerged.getModel().getBatchSize()); // Copy local stream back to local if(Utils::isRNGRequired(cg.getArchetype().getRowUpdateCodeTokens())) { @@ -1356,13 +1357,14 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer env.getStream() << "// ------------------------------------------------------------------------" << std::endl; env.getStream() << "// Local neuron groups" << std::endl; idStart = 0; + const unsigned int batchSize = modelMerged.getModel().getBatchSize(); genParallelGroup( env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedNeuronInitGroups, [this](const NeuronGroupInternal &ng) { return padKernelSize(ng.getNumNeurons(), KernelInitialize); }, - [&modelMerged, this](EnvironmentExternalBase &env, NeuronInitGroupMerged &ng) + [&modelMerged, batchSize, this](EnvironmentExternalBase &env, NeuronInitGroupMerged &ng) { EnvironmentGroupMergedField groupEnv(env, ng); - buildStandardEnvironment(groupEnv, modelMerged.getModel().getBatchSize()); + buildStandardEnvironment(groupEnv, batchSize); groupEnv.getStream() << "// only do this for existing neurons" << std::endl; groupEnv.print("if($(id) < $(num_neurons))"); @@ -1370,6 +1372,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer CodeStream::Scope b(groupEnv.getStream()); // If population RNGs are initialised on device and this neuron is going to require one, + if(isPopulationRNGInitialisedOnDevice() && ng.getArchetype().isSimRNGRequired()) { // Add field for RNG EnvironmentGroupMergedField rngInitEnv(groupEnv, ng); @@ -1377,13 +1380,13 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer [this](const auto &g, size_t) { return getDeviceVarPrefix() + "rng" + g.getName(); }); // If batch size is 1, initialise single RNG using GLOBAL thread id for sequence - if(modelMerged.getModel().getBatchSize() == 1) { + if(batchSize == 1) { genPopulationRNGInit(rngInitEnv.getStream(), printSubs("$(_rng)[$(id)]", rngInitEnv), "deviceRNGSeed", "id"); } // Otherwise, loop through batches and initialise independent RNGs using GLOBAL thread id as basis of sequence else { - env.getStream() << "for(unsigned int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; + env.getStream() << "for(unsigned int b = 0; b < " << batchSize << "; b++)"; { CodeStream::Scope 
b(rngInitEnv.getStream()); genPopulationRNGInit(rngInitEnv.getStream(), printSubs("$(_rng)[(b * $(num_neurons)) + $(id)]", rngInitEnv), @@ -1400,7 +1403,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer groupEnv.add(Type::Void, "_rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } - ng.generateInit(*this, groupEnv, modelMerged); + ng.generateInit(*this, groupEnv, batchSize); } }); env.getStream() << std::endl; @@ -1410,9 +1413,9 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer genParallelGroup( env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedSynapseInitGroups, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumInitThreads(sg), KernelInitialize); }, - [&modelMerged, this](EnvironmentExternalBase &env, SynapseInitGroupMerged &sg) + [batchSize, this](EnvironmentExternalBase &env, SynapseInitGroupMerged &sg) { - genSynapseVarInit(env, modelMerged, sg, sg.getArchetype().isWUInitRNGRequired(), + genSynapseVarInit(env, batchSize, sg, sg.getArchetype().isWUInitRNGRequired(), (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL), sg.getArchetype().getKernelSize().size()); }); @@ -1423,7 +1426,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer genParallelGroup( env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomUpdateInitGroups, [this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelInitialize); }, - [&modelMerged, this](EnvironmentExternalBase &env, CustomUpdateInitGroupMerged &cg) + [batchSize, this](EnvironmentExternalBase &env, CustomUpdateInitGroupMerged &cg) { env.getStream() << "// only do this for existing variables" << std::endl; env.print("if($(id) < $(size))"); @@ -1438,7 +1441,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer groupEnv.add(Type::Void, "_rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } - cg.generateInit(*this, groupEnv, modelMerged); + cg.generateInit(*this, groupEnv, batchSize); } }); env.getStream() << std::endl; @@ -1448,10 +1451,10 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer genParallelGroup( env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomWUUpdateInitGroups, [this](const CustomUpdateWUInternal &cg) { return padKernelSize(getNumInitThreads(cg), KernelInitialize); }, - [&modelMerged, this](EnvironmentExternalBase &env, CustomWUUpdateInitGroupMerged &cg) + [batchSize, this](EnvironmentExternalBase &env, CustomWUUpdateInitGroupMerged &cg) { const SynapseGroup *sg = cg.getArchetype().getSynapseGroup(); - genSynapseVarInit(env, modelMerged, cg, cg.getArchetype().isInitRNGRequired(), + genSynapseVarInit(env, batchSize, cg, cg.getArchetype().isInitRNGRequired(), (sg->getMatrixType() & SynapseMatrixWeight::KERNEL), sg->getKernelSize().size()); }); env.getStream() << std::endl; @@ -1461,7 +1464,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer genParallelGroup( env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdatePreInitGroups, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelInitialize); }, - [&modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdatePreInitGroupMerged &cg) + [batchSize, this](EnvironmentExternalBase &env, CustomConnectivityUpdatePreInitGroupMerged &cg) { 
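// The genGlobalRNGSkipAhead() used below for initialisation follows the usual SIMT
// pattern for counter-based generators: each thread copies the single global Philox
// state and skips ahead by its global thread id, giving every thread an independent
// subsequence. A minimal CUDA sketch of the pattern (assuming curand's Philox
// generator; variable names hypothetical):
//     curandStatePhilox4_32_10_t localRNG = *d_rng;   // private copy of global state
//     skipahead_sequence(id, &localRNG);              // jump to this thread's subsequence
//     const float u = curand_uniform(&localRNG);      // uncorrelated draws per thread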
env.getStream() << "// only do this for existing variables" << std::endl; env.print("if($(id) < $(size))"); @@ -1488,7 +1491,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer groupEnv.add(Type::Void, "_rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } - cg.generateInit(*this, groupEnv, modelMerged); + cg.generateInit(*this, groupEnv, batchSize); } }); env.getStream() << std::endl; @@ -1498,7 +1501,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer genParallelGroup( env, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdatePostInitGroups, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons(), KernelInitialize); }, - [&modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdatePostInitGroupMerged &cg) + [batchSize, this](EnvironmentExternalBase &env, CustomConnectivityUpdatePostInitGroupMerged &cg) { env.getStream() << "// only do this for existing variables" << std::endl; env.print("if($(id) < $(size))"); @@ -1525,7 +1528,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer groupEnv.add(Type::Void, "_rng", genGlobalRNGSkipAhead(groupEnv.getStream(), "id")); } - cg.generateInit(*this, groupEnv, modelMerged); + cg.generateInit(*this, groupEnv, batchSize); } }); env.getStream() << std::endl; @@ -1684,13 +1687,14 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS {envKernel.addInitialiser(getSharedPrefix() + "unsigned int shRowLength[" + std::to_string(getKernelBlockSize(KernelInitializeSparse)) + "];")}); // Initialise weight update variables for synapse groups with sparse connectivity + const unsigned int batchSize = modelMerged.getModel().getBatchSize(); genParallelGroup( envKernel, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedSynapseSparseInitGroups, [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumConnectivityInitThreads(sg), KernelInitializeSparse); }, - [&modelMerged, numInitializeThreads, this](EnvironmentExternalBase &env, SynapseSparseInitGroupMerged &sg) + [batchSize, numInitializeThreads, this](EnvironmentExternalBase &env, SynapseSparseInitGroupMerged &sg) { EnvironmentGroupMergedField groupEnv(env, sg); - buildStandardEnvironment(groupEnv, modelMerged.getModel().getBatchSize()); + buildStandardEnvironment(groupEnv, batchSize); // If this post synapse requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id @@ -1702,7 +1706,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS // Generate sparse synapse variable initialisation code genSparseSynapseVarInit( - groupEnv, modelMerged, sg, sg.getArchetype().isWUVarInitRequired(), + groupEnv, batchSize, sg, sg.getArchetype().isWUVarInitRequired(), [this](EnvironmentExternalBase &env, SynapseSparseInitGroupMerged &sg) { // If postsynaptic learning is required @@ -1729,7 +1733,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS genParallelGroup( envKernel, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomWUUpdateSparseInitGroups, [this](const CustomUpdateWUInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }, - [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, CustomWUUpdateSparseInitGroupMerged &cg) + [batchSize, 
numInitializeThreads, this](EnvironmentExternalBase &env, CustomWUUpdateSparseInitGroupMerged &cg) { EnvironmentGroupMergedField groupEnv(env, cg); @@ -1743,7 +1747,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS // Generate sparse synapse variable initialisation code genSparseSynapseVarInit( - groupEnv, modelMerged, cg, true, + groupEnv, batchSize, cg, true, [](EnvironmentExternalBase&, CustomWUUpdateSparseInitGroupMerged&){}); }); @@ -1751,7 +1755,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS genParallelGroup( envKernel, modelMerged, memorySpaces, idStart, &ModelSpecMerged::genMergedCustomConnectivityUpdateSparseInitGroups, [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getMaxConnections(), KernelInitializeSparse); }, - [numInitializeThreads, &modelMerged, this](EnvironmentExternalBase &env, CustomConnectivityUpdateSparseInitGroupMerged &cg) + [batchSize, numInitializeThreads, this](EnvironmentExternalBase &env, CustomConnectivityUpdateSparseInitGroupMerged &cg) { EnvironmentGroupMergedField groupEnv(env, cg); @@ -1765,7 +1769,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS // Generate sparse synapse variable initialisation code genSparseSynapseVarInit( - groupEnv, modelMerged, cg, true, + groupEnv, batchSize, cg, true, [](EnvironmentExternalBase&, CustomConnectivityUpdateSparseInitGroupMerged&){}); }); } diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc index aeba04e607..7e2cadb186 100644 --- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc @@ -140,7 +140,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdateGroupMerged::get return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { // Create new environment to add current source fields to neuron update group EnvironmentGroupMergedField updateEnv(env, *this); @@ -183,7 +183,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back for(const auto &v : getArchetype().getCustomConnectivityUpdateModel()->getPreVarRefs()) { // If model isn't batched or variable isn't duplicated const auto &varRef = getArchetype().getPreVarReferences().at(v.name); - if(modelMerged.getModel().getBatchSize() == 1 || !varRef.isDuplicated()) { + if(batchSize == 1 || !varRef.isDuplicated()) { // Determine index const std::string index = (varRef.getDelayNeuronGroup() != nullptr) ? 
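// When the referenced presynaptic variable is delayed, the read below is offset into a
// circular delay buffer rather than indexed directly. An illustrative sketch of the
// indexing convention assumed here (hypothetical names):
//     const unsigned int preDelayOffset = readDelaySlot * numPreNeurons;
//     value = var[preDelayOffset + idPre];   // versus var[idPre] when undelayed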
"$(_pre_delay_offset) + $(id_pre)" : "$(id_pre)"; @@ -253,11 +253,11 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Use subsequent parameters to initialise new synapse's variables referenced via the custom connectivity update for (size_t i = 0; i < ccuVarRefs.size(); i++) { // If model is batched and this variable is duplicated - if ((modelMerged.getModel().getBatchSize() > 1) && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) + if (batchSize > 1 && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) { // Copy parameter into a register (just incase it's e.g. a RNG call) and copy into all batches addSynapse << "const " << ccuVarRefs[i].type.resolve(getTypeContext()).getName() << " _" << ccuVarRefs[i].name << "Val = $(" << (1 + ccuVars.size() + i) << ");" << std::endl; - addSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; + addSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; { CodeStream::Scope b(addSynapse); addSynapse << "$(_" << ccuVarRefs[i].name << ")[(b * $(_syn_stride)) + newIdx] = _" << ccuVarRefs[i].name << "Val;" << std::endl; @@ -274,10 +274,10 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Loop through any other dependent variables for (size_t i = 0; i < dependentVars.size(); i++) { // If model is batched and this dependent variable is duplicated - if ((modelMerged.getModel().getBatchSize() > 1) && dependentVars.at(i).isDuplicated()) + if (batchSize > 1 && dependentVars.at(i).isDuplicated()) { // Loop through all batches and zero - addSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; + addSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; { CodeStream::Scope b(addSynapse); addSynapse << "$(_dependent_var_" << i << ")[(b * $(_syn_stride)) + newIdx] = 0;" << std::endl; @@ -319,11 +319,10 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Loop through variable references for (size_t i = 0; i < ccuVarRefs.size(); i++) { // If model is batched and this variable is duplicated - if ((modelMerged.getModel().getBatchSize() > 1) - && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) + if (batchSize > 1 && getArchetype().getVarReferences().at(ccuVarRefs[i].name).isDuplicated()) { // Loop through all batches and copy custom connectivity update variable references from end of row over synapse to be deleted - removeSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; + removeSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; { CodeStream::Scope b(addSynapse); removeSynapse << "$(_" << ccuVarRefs[i].name << ")[(b * $(_syn_stride)) + $(id_syn)] = "; @@ -339,9 +338,9 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back // Loop through any other dependent variables for (size_t i = 0; i < dependentVars.size(); i++) { // If model is batched and this dependent variable is duplicated - if ((modelMerged.getModel().getBatchSize() > 1) && dependentVars.at(i).isDuplicated()) { + if (batchSize > 1 && dependentVars.at(i).isDuplicated()) { // Loop through all batches and copy dependent variable from end of row over synapse to be deleted - removeSynapse << "for(int b = 0; b < " << modelMerged.getModel().getBatchSize() << "; b++)"; + removeSynapse << "for(int b = 0; b < " << batchSize << "; b++)"; { CodeStream::Scope b(removeSynapse); removeSynapse << 
"$(_dependent_var_" << i << ")[(b * $(_syn_stride)) + $(id_syn)] = "; @@ -382,7 +381,7 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back addTypes(env, getArchetype().getCustomConnectivityUpdateModel()->getVarRefs(), errorHandler); addTypes(env, getArchetype().getCustomConnectivityUpdateModel()->getPostVarRefs(), errorHandler); }, - [&backend, &modelMerged, &removeSynapseStream, this](auto &env, auto generateBody) + [batchSize, &backend, &removeSynapseStream, this](auto &env, auto generateBody) { EnvironmentGroupMergedField bodyEnv(env, *this); bodyEnv.print("for(int j = 0; j < $(_row_length)[$(id_pre)]; j++)"); @@ -399,9 +398,9 @@ void CustomConnectivityUpdateGroupMerged::generateUpdate(const BackendBase &back bodyEnv.addVars(backend.getDeviceVarPrefix(), "$(id_post)"); // Add postsynaptic and synapse variable references, only exposing those that aren't batched - addPrivateVarRefAccess(bodyEnv, modelMerged.getModel().getBatchSize(), "$(id_syn)"); + addPrivateVarRefAccess(bodyEnv, batchSize, "$(id_syn)"); addPrivateVarRefAccess( - bodyEnv, modelMerged.getModel().getBatchSize(), + bodyEnv, batchSize, [](VarAccessMode, const Models::VarReference &varRef) { if(varRef.getDelayNeuronGroup() != nullptr) { @@ -436,7 +435,7 @@ bool CustomConnectivityUpdateGroupMerged::isDerivedParamHeterogeneous(const std: //---------------------------------------------------------------------------- const std::string CustomConnectivityHostUpdateGroupMerged::name = "CustomConnectivityHostUpdate"; //---------------------------------------------------------------------------- -void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged&) +void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env) { CodeStream::Scope b(env.getStream()); diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 7d4ff13ecb..b2a405395c 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -741,7 +741,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, [&backend, &modelMerged, &synapseConnectivityHostInit](auto &sg) { EnvironmentExternal env(synapseConnectivityHostInit); - sg.generateInit(backend, env, modelMerged); + sg.generateInit(backend, env); }); // Loop through merged synapse connectivity host initialisation groups diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index aad8351b46..5f3cca3600 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -195,11 +195,11 @@ void genInitWUVarCode(const BackendBase &backend, EnvironmentExternalBase &env, // GeNN::CodeGenerator::NeuronInitGroupMerged::CurrentSource //---------------------------------------------------------------------------- void NeuronInitGroupMerged::CurrentSource::generate(const BackendBase &backend, EnvironmentExternalBase &env, - NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) + NeuronInitGroupMerged &ng, unsigned int batchSize) { genInitNeuronVarCode( backend, env, *this, ng, "CS" + std::to_string(getIndex()), - "num_neurons", 0, modelMerged.getModel().getBatchSize()); + "num_neurons", 0, batchSize); } @@ -207,7 +207,7 @@ void NeuronInitGroupMerged::CurrentSource::generate(const BackendBase &backend, // 
GeNN::CodeGenerator::NeuronInitGroupMerged::InSynPSM //---------------------------------------------------------------------------- void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, EnvironmentExternalBase &env, - NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) + NeuronInitGroupMerged &ng, unsigned int batchSize) { const std::string fieldSuffix = "InSyn" + std::to_string(getIndex()); @@ -218,11 +218,10 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir groupEnv.addField(getScalarType().createPointer(), "_out_post", "outPost" + fieldSuffix, [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); backend.genVariableInit(groupEnv, "num_neurons", "id", - [&modelMerged, this] (EnvironmentExternalBase &varEnv) + [batchSize, this] (EnvironmentExternalBase &varEnv) { genVariableFill(varEnv, "_out_post", writePreciseLiteral(0.0, getScalarType()), - "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE, - modelMerged.getModel().getBatchSize()); + "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE, batchSize); }); @@ -232,12 +231,11 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir groupEnv.addField(getScalarType().createPointer(), "_den_delay", "denDelay" + fieldSuffix, [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "denDelay" + g.getFusedPSVarSuffix(); }); backend.genVariableInit(groupEnv, "num_neurons", "id", - [&modelMerged, this](EnvironmentExternalBase &varEnv) + [batchSize, this](EnvironmentExternalBase &varEnv) { genVariableFill(varEnv, "_den_delay", writePreciseLiteral(0.0, getScalarType()), "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE, - modelMerged.getModel().getBatchSize(), - true, getArchetype().getMaxDendriticDelayTimesteps()); + batchSize, true, getArchetype().getMaxDendriticDelayTimesteps()); }); // Add field for dendritic delay pointer and zero @@ -251,14 +249,14 @@ void NeuronInitGroupMerged::InSynPSM::generate(const BackendBase &backend, Envir } genInitNeuronVarCode( - backend, groupEnv, *this, ng, fieldSuffix, "num_neurons", 0, modelMerged.getModel().getBatchSize()); + backend, groupEnv, *this, ng, fieldSuffix, "num_neurons", 0, batchSize); } //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynPreOutput //---------------------------------------------------------------------------- void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, EnvironmentExternalBase &env, - NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) + NeuronInitGroupMerged &ng, unsigned int batchSize) { // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this, ng); @@ -267,11 +265,10 @@ void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend groupEnv.addField(getScalarType().createPointer(), "_out_pre", "outPreOutSyn" + std::to_string(getIndex()), [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); backend.genVariableInit(env, "num_neurons", "id", - [&modelMerged, this] (EnvironmentExternalBase &varEnv) + [batchSize, this] (EnvironmentExternalBase &varEnv) { genVariableFill(varEnv, "_out_pre", writePreciseLiteral(0.0, getScalarType()), - "id", "$(num_neurons)", VarAccessDuplication::DUPLICATE, - modelMerged.getModel().getBatchSize()); + "id", "$(num_neurons)", 
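// genVariableFill() with VarAccessDuplication::DUPLICATE zeroes one copy of the
// variable per batch; conceptually the generated fill is equivalent to
// (hypothetical names):
//     for(unsigned int b = 0; b < batchSize; b++) {
//         outPre[(b * numNeurons) + id] = 0.0;
//     }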
VarAccessDuplication::DUPLICATE, batchSize); }); } @@ -279,20 +276,20 @@ void NeuronInitGroupMerged::OutSynPreOutput::generate(const BackendBase &backend // GeNN::CodeGenerator::NeuronInitGroupMerged::InSynWUMPostVars //---------------------------------------------------------------------------- void NeuronInitGroupMerged::InSynWUMPostVars::generate(const BackendBase &backend, EnvironmentExternalBase &env, - NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) + NeuronInitGroupMerged &ng, unsigned int batchSize) { genInitNeuronVarCode( - backend, env, *this, ng, "InSynWUMPost" + std::to_string(getIndex()), "num_neurons", 0, modelMerged.getModel().getBatchSize()); + backend, env, *this, ng, "InSynWUMPost" + std::to_string(getIndex()), "num_neurons", 0, batchSize); } //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronInitGroupMerged::OutSynWUMPreVars //---------------------------------------------------------------------------- void NeuronInitGroupMerged::OutSynWUMPreVars::generate(const BackendBase &backend, EnvironmentExternalBase &env, - NeuronInitGroupMerged &ng, const ModelSpecMerged &modelMerged) + NeuronInitGroupMerged &ng, unsigned int batchSize) { genInitNeuronVarCode( - backend, env, *this, ng, "OutSynWUMPre" + std::to_string(getIndex()), "num_neurons", 0, modelMerged.getModel().getBatchSize()); + backend, env, *this, ng, "OutSynWUMPre" + std::to_string(getIndex()), "num_neurons", 0, batchSize); } //---------------------------------------------------------------------------- @@ -362,39 +359,37 @@ boost::uuids::detail::sha1::digest_type NeuronInitGroupMerged::getHashDigest() c return hash.get_digest(); } //---------------------------------------------------------------------------- -void NeuronInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void NeuronInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { - const auto &model = modelMerged.getModel(); - // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this); // Initialise spike counts - genInitSpikeCount(backend, groupEnv, false, model.getBatchSize()); - genInitSpikeCount(backend, groupEnv, true, model.getBatchSize()); + genInitSpikeCount(backend, groupEnv, false, batchSize); + genInitSpikeCount(backend, groupEnv, true, batchSize); // Initialise spikes - genInitSpikes(backend, groupEnv, false, model.getBatchSize()); - genInitSpikes(backend, groupEnv, true, model.getBatchSize()); + genInitSpikes(backend, groupEnv, false, batchSize); + genInitSpikes(backend, groupEnv, true, batchSize); // Initialize spike times if(getArchetype().isSpikeTimeRequired()) { - genInitSpikeTime(backend, groupEnv, "sT", model.getBatchSize()); + genInitSpikeTime(backend, groupEnv, "sT", batchSize); } // Initialize previous spike times if(getArchetype().isPrevSpikeTimeRequired()) { - genInitSpikeTime( backend, groupEnv, "prevST", model.getBatchSize()); + genInitSpikeTime( backend, groupEnv, "prevST", batchSize); } // Initialize spike-like-event times if(getArchetype().isSpikeEventTimeRequired()) { - genInitSpikeTime(backend, groupEnv, "seT", model.getBatchSize()); + genInitSpikeTime(backend, groupEnv, "seT", batchSize); } // Initialize previous spike-like-event times if(getArchetype().isPrevSpikeEventTimeRequired()) { - genInitSpikeTime(backend, groupEnv, "prevSET", model.getBatchSize()); + genInitSpikeTime(backend, groupEnv, 
"prevSET", batchSize); } // If neuron group requires delays @@ -410,23 +405,23 @@ void NeuronInitGroupMerged::generateInit(const BackendBase &backend, Environment } // Initialise neuron variables - genInitNeuronVarCode(backend, groupEnv, *this, "", "num_neurons", 0, modelMerged.getModel().getBatchSize()); + genInitNeuronVarCode(backend, groupEnv, *this, "", "num_neurons", 0, batchSize); // Generate initialisation code for child groups for (auto &cs : m_MergedCurrentSourceGroups) { - cs.generate(backend, groupEnv, *this, modelMerged); + cs.generate(backend, groupEnv, *this, batchSize); } for(auto &sg : m_MergedInSynPSMGroups) { - sg.generate(backend, groupEnv, *this, modelMerged); + sg.generate(backend, groupEnv, *this, batchSize); } for (auto &sg : m_MergedOutSynPreOutputGroups) { - sg.generate(backend, groupEnv, *this, modelMerged); + sg.generate(backend, groupEnv, *this, batchSize); } for (auto &sg : m_MergedOutSynWUMPreVarGroups) { - sg.generate(backend, groupEnv, *this, modelMerged); + sg.generate(backend, groupEnv, *this, batchSize); } for (auto &sg : m_MergedInSynWUMPostVarGroups) { - sg.generate(backend, groupEnv, *this, modelMerged); + sg.generate(backend, groupEnv, *this, batchSize); } } //-------------------------------------------------------------------------- @@ -530,14 +525,14 @@ boost::uuids::detail::sha1::digest_type SynapseInitGroupMerged::getHashDigest() return hash.get_digest(); } //---------------------------------------------------------------------------- -void SynapseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void SynapseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this); // If model is batched and has kernel weights const bool kernel = (getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL); - if (kernel && modelMerged.getModel().getBatchSize() > 1) { + if (kernel && batchSize > 1) { // Loop through kernel dimensions and multiply together to calculate batch stride std::ostringstream batchStrideInit; batchStrideInit << "const unsigned int batchStride = "; @@ -564,7 +559,7 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, Environmen // Generate initialisation code const std::string stride = kernel ? 
"$(_batch_stride)" : "$(num_pre) * $(_row_stride)"; - genInitWUVarCode(backend, groupEnv, *this, stride, modelMerged.getModel().getBatchSize(), + genInitWUVarCode(backend, groupEnv, *this, stride, batchSize, [&backend, kernel, this](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { if (kernel) { @@ -605,10 +600,10 @@ boost::uuids::detail::sha1::digest_type SynapseSparseInitGroupMerged::getHashDig return hash.get_digest(); } //---------------------------------------------------------------------------- -void SynapseSparseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void SynapseSparseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { // Create environment for group - genInitWUVarCode(backend, env, *this, "$(num_pre) * $(_row_stride)", modelMerged.getModel().getBatchSize(), + genInitWUVarCode(backend, env, *this, "$(num_pre) * $(_row_stride)", batchSize, [&backend](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { backend.genSparseSynapseVariableRowInit(varInitEnv, handler); @@ -657,7 +652,7 @@ void SynapseConnectivityInitGroupMerged::generateSparseColumnInit(const BackendB genInitConnectivity(backend, env, false); } //---------------------------------------------------------------------------- -void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this); @@ -668,7 +663,7 @@ void SynapseConnectivityInitGroupMerged::generateKernelInit(const BackendBase &b {groupEnv.addInitialiser("const unsigned int kernelInd = " + getKernelIndex(*this) + ";")}); // Initialise single (hence empty lambda function) synapse variable - genInitWUVarCode(backend, groupEnv, *this, "$(num_pre) * $(_row_stride)", modelMerged.getModel().getBatchSize(), + genInitWUVarCode(backend, groupEnv, *this, "$(num_pre) * $(_row_stride)", batchSize, [](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { handler(varInitEnv); @@ -730,13 +725,13 @@ void SynapseConnectivityInitGroupMerged::genInitConnectivity(const BackendBase & //---------------------------------------------------------------------------- const std::string SynapseConnectivityHostInitGroupMerged::name = "SynapseConnectivityHostInit"; //------------------------------------------------------------------------- -void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void SynapseConnectivityHostInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env) { // Add standard library to environment EnvironmentLibrary envStdLib(env, StandardLibrary::getMathsFunctions()); // Add host RNG functions to environment - EnvironmentLibrary envRandom(envStdLib, StandardLibrary::getHostRNGFunctions(modelMerged.getModel().getPrecision())); + EnvironmentLibrary envRandom(envStdLib, StandardLibrary::getHostRNGFunctions(getScalarType())); // Add standard host assert function to environment EnvironmentExternal envAssert(envRandom); @@ -870,11 +865,11 @@ boost::uuids::detail::sha1::digest_type CustomUpdateInitGroupMerged::getHashDige return 
hash.get_digest(); } // ---------------------------------------------------------------------------- -void CustomUpdateInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void CustomUpdateInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { // Initialise custom update variables genInitNeuronVarCode(backend, env, *this, "", "size", 1, - getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1); + getArchetype().isBatched() ? batchSize : 1); } // ---------------------------------------------------------------------------- @@ -918,7 +913,7 @@ boost::uuids::detail::sha1::digest_type CustomWUUpdateInitGroupMerged::getHashDi return hash.get_digest(); } // ---------------------------------------------------------------------------- -void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { EnvironmentGroupMergedField groupEnv(env, *this); @@ -933,7 +928,7 @@ void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, Env } } - if(modelMerged.getModel().getBatchSize() > 1) { + if(batchSize > 1) { // Loop through kernel dimensions and multiply together to calculate batch stride std::ostringstream batchStrideInit; batchStrideInit << "const unsigned int batchStride = "; @@ -969,7 +964,7 @@ void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, Env // Loop through rows const std::string stride = kernel ? "$(_batch_stride)" : "$(num_pre) * $(_row_stride)"; genInitWUVarCode( - backend, groupEnv, *this, stride, getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1, + backend, groupEnv, *this, stride, getArchetype().isBatched() ? batchSize : 1, [&backend, kernel, this](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { if (kernel) { @@ -1020,7 +1015,7 @@ boost::uuids::detail::sha1::digest_type CustomWUUpdateSparseInitGroupMerged::get return hash.get_digest(); } // ---------------------------------------------------------------------------- -void CustomWUUpdateSparseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void CustomWUUpdateSparseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this); @@ -1047,7 +1042,7 @@ void CustomWUUpdateSparseInitGroupMerged::generateInit(const BackendBase &backen });*/ genInitWUVarCode(backend, groupEnv, *this, "$(num_pre) * $(_row_stride)", - getArchetype().isBatched() ? modelMerged.getModel().getBatchSize() : 1, + getArchetype().isBatched() ? 
batchSize : 1, [&backend](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { return backend.genSparseSynapseVariableRowInit(varInitEnv, handler); @@ -1078,7 +1073,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePreInitGroupMerg return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomConnectivityUpdatePreInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void CustomConnectivityUpdatePreInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this); @@ -1091,7 +1086,7 @@ void CustomConnectivityUpdatePreInitGroupMerged::generateInit(const BackendBase }); // Initialise presynaptic custom connectivity update variables - genInitNeuronVarCode(backend, groupEnv, *this, "", "size", 0, modelMerged.getModel().getBatchSize()); + genInitNeuronVarCode(backend, groupEnv, *this, "", "size", 0, batchSize); } // ---------------------------------------------------------------------------- @@ -1118,7 +1113,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePostInitGroupMer return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomConnectivityUpdatePostInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void CustomConnectivityUpdatePostInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this); @@ -1131,7 +1126,7 @@ void CustomConnectivityUpdatePostInitGroupMerged::generateInit(const BackendBase }); // Initialise presynaptic custom connectivity update variables - genInitNeuronVarCode(backend, groupEnv, *this, "", "size", 0, modelMerged.getModel().getBatchSize()); + genInitNeuronVarCode(backend, groupEnv, *this, "", "size", 0, batchSize); } // ---------------------------------------------------------------------------- @@ -1168,7 +1163,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdateSparseInitGroupM return hash.get_digest(); } //---------------------------------------------------------------------------- -void CustomConnectivityUpdateSparseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged&) +void CustomConnectivityUpdateSparseInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int) { // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this); diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index d07c366ad8..e2a8b5a0a5 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -18,7 +18,7 @@ using namespace GeNN::Transpiler; // GeNN::CodeGenerator::NeuronUpdateGroupMerged::CurrentSource //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend, EnvironmentExternalBase &env, - NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) + NeuronUpdateGroupMerged &ng, unsigned int batchSize) { const std::string fieldSuffix = 
"CS" + std::to_string(getIndex()); const auto *cm = getArchetype().getCurrentSourceModel(); @@ -34,15 +34,15 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend csEnv.addExtraGlobalParams(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", fieldSuffix); // Define inject current function - csEnv.add(Type::ResolvedType::createFunction(Type::Void, {modelMerged.getModel().getPrecision()}), + csEnv.add(Type::ResolvedType::createFunction(Type::Void, {getScalarType()}), "injectCurrent", "$(Isyn) += $(0)"); // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [&modelMerged, &ng](const std::string&, VarAccessDuplication d) + [batchSize, &ng](const std::string&, VarAccessDuplication d) { - return ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, "$(id)"); + return ng.getVarIndex(batchSize, d, "$(id)"); }); // Pretty print code back to environment @@ -70,7 +70,7 @@ bool NeuronUpdateGroupMerged::CurrentSource::isDerivedParamHeterogeneous( const // GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynPSM //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, EnvironmentExternalBase &env, - NeuronUpdateGroupMerged &ng, const ModelSpecMerged &modelMerged) + NeuronUpdateGroupMerged &ng, unsigned int batchSize) { const std::string fieldSuffix = "InSyn" + std::to_string(getIndex()); const auto *psm = getArchetype().getPSModel(); @@ -83,7 +83,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPost" + g.getFusedPSVarSuffix(); }); // Read into local variable - const std::string idx = ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "$(id)"); + const std::string idx = ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)"); psmEnv.getStream() << "// postsynaptic model " << getIndex() << std::endl; psmEnv.printLine(getScalarType().getName() + " linSyn = $(_out_post)[" + idx + "];"); @@ -111,17 +111,17 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env psmEnv.addExtraGlobalParams(psm->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", fieldSuffix); // **TODO** naming convention - psmEnv.add(modelMerged.getModel().getPrecision(), "inSyn", "linSyn"); + psmEnv.add(getScalarType(), "inSyn", "linSyn"); // Allow synapse group's PS output var to override what Isyn points to - psmEnv.add(modelMerged.getModel().getPrecision(), "Isyn", getArchetype().getPSTargetVar()); + psmEnv.add(getScalarType(), "Isyn", getArchetype().getPSTargetVar()); // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", - [&modelMerged, &ng](const std::string&, VarAccessDuplication d) + [batchSize, &ng](const std::string&, VarAccessDuplication d) { - return ng.getVarIndex(modelMerged.getModel().getBatchSize(), d, "$(id)"); + return ng.getVarIndex(batchSize, d, "$(id)"); }); // Pretty print code back to environment @@ -132,7 +132,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env prettyPrintStatements(getArchetype().getPSDecayCodeTokens(), getTypeContext(), varEnv, 
decayErrorHandler); // Write back linSyn - varEnv.printLine("$(_out_post)[" + ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "$(id)") + "] = linSyn;"); + varEnv.printLine("$(_out_post)[" + ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)") + "] = linSyn;"); } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynPSM::updateHash(boost::uuids::detail::sha1 &hash) const @@ -154,19 +154,19 @@ bool NeuronUpdateGroupMerged::InSynPSM::isDerivedParamHeterogeneous( const std:: //---------------------------------------------------------------------------- // GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynPreOutput //---------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged) +void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backend, EnvironmentExternalBase &env, + NeuronUpdateGroupMerged &ng, unsigned int batchSize) { const std::string fieldSuffix = "OutSyn" + std::to_string(getIndex()); // Create new environment to add out syn fields to neuron update group EnvironmentGroupMergedField outSynEnv(env, *this, ng); - outSynEnv.addField(modelMerged.getModel().getPrecision().createPointer(), "_out_pre", "outPre" + fieldSuffix, + outSynEnv.addField(getScalarType().createPointer(), "_out_pre", "outPre" + fieldSuffix, [&backend](const auto &g, size_t) { return backend.getDeviceVarPrefix() + "outPre" + g.getFusedPreOutputSuffix(); }); // Add reverse insyn variable to - const std::string idx = ng.getVarIndex(modelMerged.getModel().getBatchSize(), VarAccessDuplication::DUPLICATE, "$(id)"); + const std::string idx = ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)"); outSynEnv.printLine(getArchetype().getPreTargetVar() + " += $(_out_pre)[" + idx + "];"); // Zero it again @@ -177,13 +177,11 @@ void NeuronUpdateGroupMerged::OutSynPreOutput::generate(const BackendBase &backe // GeNN::CodeGenerator::NeuronUpdateGroupMerged::InSynWUMPostCode //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &backend, EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) + unsigned int batchSize, bool dynamicsNotSpike) { const std::string fieldSuffix = "InSynWUMPost" + std::to_string(getIndex()); const auto *wum = getArchetype().getWUModel(); - const unsigned int batchSize = modelMerged.getModel().getBatchSize(); - // If there are any statements to execute here const auto &tokens = dynamicsNotSpike ? 
getArchetype().getWUPostDynamicsCodeTokens() : getArchetype().getWUPostSpikeCodeTokens(); if(!Utils::areTokensEmpty(tokens)) { @@ -229,7 +227,7 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged) + unsigned int batchSize) { // If this group has a delay and no postsynaptic dynamics (which will already perform this copying) const std::string suffix = "InSynWUMPost" + std::to_string(getIndex()); @@ -237,7 +235,6 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::genCopyDelayedVars(EnvironmentEx // Loop through variables and copy between read and write delay slots for(const auto &v : getArchetype().getWUModel()->getPostVars()) { if(v.access & VarAccessMode::READ_WRITE) { - const unsigned int batchSize = modelMerged.getModel().getBatchSize(); env.print("$(" + v.name + ")[" + ng.getWriteVarIndex(true, batchSize, getVarAccessDuplication(v.access), "$(id)") + "] = "); env.printLine("$(" + v.name + ")[" + ng.getReadVarIndex(true, batchSize, getVarAccessDuplication(v.access), "$(id)") + "];"); } @@ -265,11 +262,10 @@ bool NeuronUpdateGroupMerged::InSynWUMPostCode::isDerivedParamHeterogeneous( con // GeNN::CodeGenerator::NeuronUpdateGroupMerged::OutSynWUMPreCode //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &backend, EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged, bool dynamicsNotSpike) + unsigned int batchSize, bool dynamicsNotSpike) { const std::string fieldSuffix = "OutSynWUMPre" + std::to_string(getIndex()); const auto *wum = getArchetype().getWUModel(); - const unsigned int batchSize = modelMerged.getModel().getBatchSize(); // If there are any statements to execute here const auto &tokens = dynamicsNotSpike ? 
getArchetype().getWUPreDynamicsCodeTokens() : getArchetype().getWUPreSpikeCodeTokens(); @@ -316,7 +312,7 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back } //---------------------------------------------------------------------------- void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(EnvironmentExternalBase &env, const NeuronUpdateGroupMerged &ng, - const ModelSpecMerged &modelMerged) + unsigned int batchSize) { // If this group has a delay and no presynaptic dynamics (which will already perform this copying) const std::string suffix = "OutSynWUMPre" + std::to_string(getIndex()); @@ -324,7 +320,6 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::genCopyDelayedVars(EnvironmentEx // Loop through variables and copy between read and write delay slots for(const auto &v : getArchetype().getWUModel()->getPreVars()) { if(v.access & VarAccessMode::READ_WRITE) { - const unsigned int batchSize = modelMerged.getModel().getBatchSize(); env.print("$(" + v.name + ")[" + ng.getWriteVarIndex(true, batchSize, getVarAccessDuplication(v.access), "$(id)") + "] = "); env.printLine("$(" + v.name + ")[" + ng.getReadVarIndex(true, batchSize, getVarAccessDuplication(v.access), "$(id)") + "];"); } @@ -459,18 +454,16 @@ boost::uuids::detail::sha1::digest_type NeuronUpdateGroupMerged::getHashDigest() return hash.get_digest(); } //-------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, +void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize, BackendBase::GroupHandlerEnv genEmitTrueSpike, BackendBase::GroupHandlerEnv genEmitSpikeLikeEvent) { - const ModelSpecInternal &model = modelMerged.getModel(); - const unsigned int batchSize = model.getBatchSize(); const NeuronModels::Base *nm = getArchetype().getNeuronModel(); EnvironmentGroupMergedField neuronEnv(env, *this); // Add default input variable - neuronEnv.add(modelMerged.getModel().getPrecision(), "Isyn", "Isyn", + neuronEnv.add(getScalarType(), "Isyn", "Isyn", {neuronEnv.addInitialiser(getScalarType().getName() + " Isyn = 0;")}); // **NOTE** arbitrary code in param value to be deprecated @@ -486,7 +479,7 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E neuronEnv.addExtraGlobalParams(nm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); // Substitute spike times - const std::string timePrecision = modelMerged.getModel().getTimePrecision().getName(); + const std::string timePrecision = getTimeType().getName(); const std::string spikeTimeReadIndex = getReadVarIndex(getArchetype().isDelayRequired(), batchSize, VarAccessDuplication::DUPLICATE, "$(id)"); neuronEnv.add(getTimeType().addConst(), "sT", "lsT", {neuronEnv.addInitialiser("const " + timePrecision + " lsT = $(_spk_time)[" + spikeTimeReadIndex + "];")}); @@ -516,19 +509,19 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Loop through incoming synapse groups for(auto &sg : m_MergedInSynPSMGroups) { CodeStream::Scope b(neuronVarEnv.getStream()); - sg.generate(backend, neuronVarEnv, *this, modelMerged); + sg.generate(backend, neuronVarEnv, *this, batchSize); } // Loop through outgoing synapse groups with presynaptic output for (auto &sg : m_MergedOutSynPreOutputGroups) { CodeStream::Scope b(neuronVarEnv.getStream()); - sg.generate(backend, 
neuronVarEnv, *this, modelMerged); + sg.generate(backend, neuronVarEnv, *this, batchSize); } // Loop through all of neuron group's current sources for (auto &cs : m_MergedCurrentSourceGroups) { CodeStream::Scope b(neuronVarEnv.getStream()); - cs.generate(backend, neuronVarEnv, *this, modelMerged); + cs.generate(backend, neuronVarEnv, *this, batchSize); } @@ -564,13 +557,13 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Generate var update for outgoing synaptic populations with presynaptic update code for (auto &sg : m_MergedOutSynWUMPreCodeGroups) { CodeStream::Scope b(neuronVarEnv.getStream()); - sg.generate(backend, neuronVarEnv, *this, modelMerged, true); + sg.generate(backend, neuronVarEnv, *this, batchSize, true); } // Generate var update for incoming synaptic populations with postsynaptic code for (auto &sg : m_MergedInSynWUMPostCodeGroups) { CodeStream::Scope b(neuronVarEnv.getStream()); - sg.generate(backend, neuronVarEnv, *this, modelMerged, true); + sg.generate(backend, neuronVarEnv, *this, batchSize, true); } // look for spike type events first. @@ -713,30 +706,30 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Loop through outgoing synapse groups with some sort of presynaptic code for (auto &sg : m_MergedOutSynWUMPreCodeGroups) { - sg.genCopyDelayedVars(neuronVarEnv, *this, modelMerged); + sg.genCopyDelayedVars(neuronVarEnv, *this, batchSize); } // Loop through incoming synapse groups with some sort of presynaptic code for (auto &sg : m_MergedInSynWUMPostCodeGroups) { - sg.genCopyDelayedVars(neuronVarEnv, *this, modelMerged); + sg.genCopyDelayedVars(neuronVarEnv, *this, batchSize); } } } } } //-------------------------------------------------------------------------- -void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void NeuronUpdateGroupMerged::generateWUVarUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { // Generate var update for outgoing synaptic populations with presynaptic update code for (auto &sg : m_MergedOutSynWUMPreCodeGroups) { CodeStream::Scope b(env.getStream()); - sg.generate(backend, env, *this, modelMerged, false); + sg.generate(backend, env, *this, batchSize, false); } // Generate var update for incoming synaptic populations with postsynaptic code for (auto &sg : m_MergedInSynWUMPostCodeGroups) { CodeStream::Scope b(env.getStream()); - sg.generate(backend, env, *this, modelMerged, false); + sg.generate(backend, env, *this, batchSize, false); } } //-------------------------------------------------------------------------- diff --git a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc index 8c9e5d685d..100646f81b 100644 --- a/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc +++ b/src/genn/genn/code_generator/presynapticUpdateStrategySIMT.cc @@ -10,8 +10,7 @@ #include "code_generator/backendSIMT.h" #include "code_generator/codeGenUtils.h" #include "code_generator/codeStream.h" -#include "code_generator/groupMerged.h" -#include "code_generator/modelSpecMerged.h" +#include "code_generator/synapseUpdateGroupMerged.h" //---------------------------------------------------------------------------- // Anonymous namespace @@ -79,12 +78,10 @@ void PreSpan::genPreamble(EnvironmentExternalBase&, PresynapticUpdateGroupMerged { } 
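The hunks above and below all apply one mechanical change: ``genUpdate``-style methods stop taking ``const ModelSpecMerged &modelMerged``, which they only ever dereferenced for ``getModel().getBatchSize()`` or ``getModel().getPrecision()``, and instead take a plain ``unsigned int batchSize``, reading the precision from the group itself via ``getScalarType()``. A minimal sketch of this dependency-narrowing pattern, using hypothetical stand-in types rather than the real GeNN classes:

#include <iostream>

// Stand-ins for ModelSpecInternal/ModelSpecMerged, for illustration only
struct Model
{
    unsigned int getBatchSize() const { return 4; }
};
struct ModelMerged
{
    const Model &getModel() const { return m_Model; }
    Model m_Model;
};

// Before: the generator depended on the whole merged model
// just to read a single integer out of it
void genUpdateOld(const ModelMerged &merged)
{
    const unsigned int batchSize = merged.getModel().getBatchSize();
    std::cout << "batch size " << batchSize << std::endl;
}

// After: the caller extracts batchSize once at the top level and
// threads the plain value through, so the strategy code no longer
// needs to see ModelSpecMerged at all
void genUpdateNew(unsigned int batchSize)
{
    std::cout << "batch size " << batchSize << std::endl;
}

int main()
{
    ModelMerged merged;
    genUpdateOld(merged);                            // old, coupled call
    genUpdateNew(merged.getModel().getBatchSize());  // new, decoupled call
    return 0;
}

Narrowing the parameter this way is also what lets this file drop its includes of ``groupMerged.h`` and ``modelSpecMerged.h`` in favour of just ``synapseUpdateGroupMerged.h``, as the hunk at the top of the file shows.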
//---------------------------------------------------------------------------- -void PreSpan::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const +void PreSpan::genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize, bool trueSpike) const { // Get suffix based on type of events - const ModelSpecInternal &model = modelMerged.getModel(); - const unsigned int batchSize = model.getBatchSize(); const std::string eventSuffix = trueSpike ? "" : "_evnt"; const size_t numThreadsPerSpike = sg.getArchetype().getNumThreadsPerSpike(); @@ -152,16 +149,16 @@ void PreSpan::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &mod synEnv.add(Type::Uint32.addConst(), "id_syn", "synAddress"); synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", - backend.getAtomic(model.getPrecision()) + "(&$(_den_delay)[" + sg.getPostDenDelayIndex(batchSize, "ipost", "$(1)") + "], $(0))"); + backend.getAtomic(sg.getScalarType()) + "(&$(_den_delay)[" + sg.getPostDenDelayIndex(batchSize, "ipost", "$(1)") + "], $(0))"); synEnv.add(Type::AddToPost, "addToPost", - backend.getAtomic(model.getPrecision()) + "(&$(_out_post)[" + sg.getPostISynIndex(batchSize, "ipost") + "], $(0))"); + backend.getAtomic(sg.getScalarType()) + "(&$(_out_post)[" + sg.getPostISynIndex(batchSize, "ipost") + "], $(0))"); synEnv.add(Type::AddToPre, "addToPre", "lrevInSyn += $(0)"); if(trueSpike) { - sg.generateSpikeUpdate(backend, synEnv, modelMerged); + sg.generateSpikeUpdate(backend, synEnv, batchSize); } else { - sg.generateSpikeEventUpdate(backend, synEnv, modelMerged); + sg.generateSpikeEventUpdate(backend, synEnv, batchSize); } } @@ -173,14 +170,13 @@ void PreSpan::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &mod // Should this be in the Postamble? if(sg.getArchetype().isPresynapticOutputRequired()) { // write lrevInSyn to global memory if not 0 - env.getStream() << "if(lrevInSyn != 0.0) " << backend.getAtomic(model.getPrecision()) + "(&" + env["_out_pre"] + "[" + sg.getPreISynIndex(batchSize, "preInd") + "], lrevInSyn);" << std::endl; + env.getStream() << "if(lrevInSyn != 0.0) " << backend.getAtomic(sg.getScalarType()) + "(&" + env["_out_pre"] + "[" + sg.getPreISynIndex(batchSize, "preInd") + "], lrevInSyn);" << std::endl; } } } //---------------------------------------------------------------------------- -void PreSpan::genPostamble(EnvironmentExternalBase &, const ModelSpecMerged&, - PresynapticUpdateGroupMerged&, const BackendSIMT&) const +void PreSpan::genPostamble(EnvironmentExternalBase &, PresynapticUpdateGroupMerged&, const BackendSIMT&, unsigned int) const { } @@ -239,12 +235,10 @@ size_t PostSpan::getSharedMemoryPerThread(const PresynapticUpdateGroupMerged &sg return isSmallSharedMemoryPop(sg, backend) ? 1 : 0; } //---------------------------------------------------------------------------- -void PostSpan::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const +void PostSpan::genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize, bool trueSpike) const { // Get suffix based on type of events - const ModelSpecInternal &model = modelMerged.getModel(); - const unsigned int batchSize = model.getBatchSize(); const std::string eventSuffix = trueSpike ? 
"" : "_evnt"; env.printLine("const unsigned int numSpikes = $(_src_spk_cnt" + eventSuffix + ")[" + sg.getPreSlot(batchSize) + "];"); @@ -331,7 +325,7 @@ void PostSpan::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &mo // If dendritic delay is required, always use atomic operation to update dendritic delay buffer synEnv.add(Type::AddToPostDenDelay, "addToPostDelay", - backend.getAtomic(model.getPrecision()) + "(&$(_den_delay)[" + sg.getPostDenDelayIndex(batchSize, "$(id_post)", "$(1)") + "], $(0))"); + backend.getAtomic(sg.getScalarType()) + "(&$(_den_delay)[" + sg.getPostDenDelayIndex(batchSize, "$(id_post)", "$(1)") + "], $(0))"); // If we should accumulate in register, add parameter to register if(shouldAccumulateInRegister(sg)) { @@ -345,19 +339,19 @@ void PostSpan::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &mo // Otherwise, use global memory atomic else { synEnv.add(Type::AddToPost, "addToPost", - backend.getAtomic(model.getPrecision()) + "(&$(_out_post)[" + sg.getPostISynIndex(batchSize, "$(id_post)") + "], $(0))"); + backend.getAtomic(sg.getScalarType()) + "(&$(_out_post)[" + sg.getPostISynIndex(batchSize, "$(id_post)") + "], $(0))"); } if(sg.getArchetype().isPresynapticOutputRequired()) { synEnv.add(Type::AddToPre, "addToPre", - backend.getAtomic(model.getPrecision()) + "(&$(_out_pre)([" + sg.getPreISynIndex(batchSize, "$(id_pre)") + "], $(0))"); + backend.getAtomic(sg.getScalarType()) + "(&$(_out_pre)([" + sg.getPreISynIndex(batchSize, "$(id_pre)") + "], $(0))"); } if(trueSpike) { - sg.generateSpikeUpdate(backend, synEnv, modelMerged); + sg.generateSpikeUpdate(backend, synEnv, batchSize); } else { - sg.generateSpikeEventUpdate(backend, synEnv, modelMerged); + sg.generateSpikeEventUpdate(backend, synEnv, batchSize); } if(sg.getArchetype().getMatrixType() & SynapseMatrixConnectivity::SPARSE) { @@ -375,12 +369,10 @@ void PostSpan::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &mo } } //---------------------------------------------------------------------------- -void PostSpan::genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const +void PostSpan::genPostamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize) const { // If we should accumulate output directly into register - const ModelSpecInternal &model = modelMerged.getModel(); - const unsigned int batchSize = model.getBatchSize(); if(shouldAccumulateInRegister(sg)) { env.getStream() << "// only do this for existing neurons" << std::endl; env.print("if ($(id) < $(num_post))"); @@ -388,7 +380,7 @@ void PostSpan::genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged CodeStream::Scope b(env.getStream()); const std::string inSyn = printSubs("$(_out_post)[" + sg.getPostISynIndex(batchSize, "$(id)") + "]", env); if(sg.getArchetype().isPSModelFused()) { - env.getStream() << backend.getAtomic(model.getPrecision()) << "(&" << inSyn << ", linSyn);" << std::endl; + env.getStream() << backend.getAtomic(sg.getScalarType()) << "(&" << inSyn << ", linSyn);" << std::endl; } else { env.getStream() << inSyn << " += linSyn;" << std::endl; @@ -402,7 +394,7 @@ void PostSpan::genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged { CodeGenerator::CodeStream::Scope b(env.getStream()); const std::string inSyn = printSubs("$(_out_post)[" + sg.getPostISynIndex(batchSize, backend.getThreadID()) + "]", env); - env.getStream() << 
backend.getAtomic(model.getPrecision()) << "(&" << inSyn << "], shLg[" << backend.getThreadID() << "]); " << std::endl; + env.getStream() << backend.getAtomic(sg.getScalarType()) << "(&" << inSyn << "], shLg[" << backend.getThreadID() << "]); " << std::endl; } } } @@ -449,12 +441,10 @@ void PreSpanProcedural::genPreamble(EnvironmentExternalBase&, PresynapticUpdateG { } //---------------------------------------------------------------------------- -void PreSpanProcedural::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT&, bool trueSpike) const +void PreSpanProcedural::genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT&, unsigned int batchSize, bool trueSpike) const { // Get suffix based on type of events - const ModelSpecInternal &model = modelMerged.getModel(); - const unsigned int batchSize = model.getBatchSize(); const std::string eventSuffix = trueSpike ? "" : "_evnt"; const size_t numThreadsPerSpike = sg.getArchetype().getNumThreadsPerSpike(); @@ -602,8 +592,8 @@ void PreSpanProcedural::genUpdate(EnvironmentExternalBase &env, const ModelSpecM } } //---------------------------------------------------------------------------- -void PreSpanProcedural::genPostamble(EnvironmentExternalBase&, const ModelSpecMerged&, - PresynapticUpdateGroupMerged&, const BackendSIMT&) const +void PreSpanProcedural::genPostamble(EnvironmentExternalBase&, PresynapticUpdateGroupMerged&, + const BackendSIMT&, unsigned int) const { } @@ -650,11 +640,10 @@ size_t PostSpanBitmask::getSharedMemoryPerThread(const PresynapticUpdateGroupMer return 32; } //---------------------------------------------------------------------------- -void PostSpanBitmask::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const +void PostSpanBitmask::genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize, bool trueSpike) const { // Get suffix based on type of events - const unsigned int batchSize = modelMerged.getModel().getBatchSize(); const std::string eventSuffix = trueSpike ? 
"" : "_evnt"; // Get blocksize @@ -736,13 +725,13 @@ void PostSpanBitmask::genUpdate(EnvironmentExternalBase &env, const ModelSpecMer synEnv.add(Type::AddToPost, "addToPost", "shLg[(ibit * " + std::to_string(blockSize) + ") + " + backend.getThreadID() + "] += $(0)"); synEnv.add(Type::AddToPre, "addToPre", - backend.getAtomic(modelMerged.getModel().getPrecision()) + "(&$(_out_pre)[" + sg.getPreISynIndex(batchSize, "$(id_pre)") + "], $(0))"); + backend.getAtomic(sg.getScalarType()) + "(&$(_out_pre)[" + sg.getPreISynIndex(batchSize, "$(id_pre)") + "], $(0))"); if(trueSpike) { - sg.generateSpikeUpdate(backend, synEnv, modelMerged); + sg.generateSpikeUpdate(backend, synEnv, batchSize); } else { - sg.generateSpikeEventUpdate(backend, synEnv, modelMerged); + sg.generateSpikeEventUpdate(backend, synEnv, batchSize); } synEnv.getStream() << "ibit++;" << std::endl; @@ -757,8 +746,8 @@ void PostSpanBitmask::genUpdate(EnvironmentExternalBase &env, const ModelSpecMer } } //---------------------------------------------------------------------------- -void PostSpanBitmask::genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const +void PostSpanBitmask::genPostamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize) const { backend.genSharedMemBarrier(env.getStream()); const size_t blockSize = backend.getKernelBlockSize(KernelPresynapticUpdate); @@ -773,9 +762,9 @@ void PostSpanBitmask::genPostamble(EnvironmentExternalBase &env, const ModelSpec env.print("for(;shIdx < endShIdx && glbIdx < $(num_post); shIdx++, glbIdx += 32)"); { CodeStream::Scope b(env.getStream()); - const std::string inSyn = "$(_out_post)[" + sg.getPostISynIndex(modelMerged.getModel().getBatchSize(), "glbIdx") +"]"; + const std::string inSyn = "$(_out_post)[" + sg.getPostISynIndex(batchSize, "glbIdx") +"]"; if(sg.getArchetype().isPSModelFused()) { - env.printLine(backend.getAtomic(modelMerged.getModel().getPrecision()) + "(&" + inSyn + ", shLg[shIdx]);"); + env.printLine(backend.getAtomic(sg.getScalarType()) + "(&" + inSyn + ", shLg[shIdx]);"); } else { env.printLine(inSyn + " += shLg[shIdx];"); @@ -821,8 +810,8 @@ size_t PostSpanToeplitz::getSharedMemoryPerThread(const PresynapticUpdateGroupMe return isSmallSharedMemoryPop(sg, backend) ? 
1 : 0; } //---------------------------------------------------------------------------- -void PostSpanToeplitz::genUpdate(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend, bool trueSpike) const +void PostSpanToeplitz::genUpdate(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize, bool trueSpike) const { assert(false); /*const auto &connectInit = sg.getArchetype().getToeplitzConnectivityInitialiser(); @@ -963,8 +952,8 @@ void PostSpanToeplitz::genUpdate(EnvironmentExternalBase &env, const ModelSpecMe }*/ } //---------------------------------------------------------------------------- -void PostSpanToeplitz::genPostamble(EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged, - PresynapticUpdateGroupMerged &sg, const BackendSIMT &backend) const +void PostSpanToeplitz::genPostamble(EnvironmentExternalBase &env, PresynapticUpdateGroupMerged &sg, + const BackendSIMT &backend, unsigned int batchSize) const { // If we should accumulate into shared memory if(isSmallSharedMemoryPop(sg, backend)) { @@ -972,8 +961,8 @@ void PostSpanToeplitz::genPostamble(EnvironmentExternalBase &env, const ModelSpe env.print("if(" + backend.getThreadID() + " < $(num_post))"); { CodeGenerator::CodeStream::Scope b(env.getStream()); - const std::string idx = sg.getPostISynIndex(modelMerged.getModel().getBatchSize(), backend.getThreadID()); - env.printLine(backend.getAtomic(modelMerged.getModel().getPrecision()) + "(&$(_out_post)[" + idx + "], shLg[" + backend.getThreadID() + "]);"); + const std::string idx = sg.getPostISynIndex(batchSize, backend.getThreadID()); + env.printLine(backend.getAtomic(sg.getScalarType()) + "(&$(_out_post)[" + idx + "], shLg[" + backend.getThreadID() + "]);"); } } } diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index dcc160831a..f550e4eec6 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -16,10 +16,8 @@ namespace { template void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBase &env, const std::vector &tokens, const std::string &errorContext, - G &sg, const ModelSpecMerged &modelMerged, bool backendSupportsNamespace) + G &sg, unsigned int batchSize, bool backendSupportsNamespace) { - const ModelSpecInternal &model = modelMerged.getModel(); - const unsigned int batchSize = model.getBatchSize(); const auto *wu = sg.getArchetype().getWUModel(); EnvironmentGroupMergedField synEnv(env, sg); @@ -365,7 +363,7 @@ boost::uuids::detail::sha1::digest_type SynapseGroupMergedBase::getHashDigest() //---------------------------------------------------------------------------- const std::string PresynapticUpdateGroupMerged::name = "PresynapticUpdate"; //---------------------------------------------------------------------------- -void PresynapticUpdateGroupMerged::generateSpikeEventThreshold(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void PresynapticUpdateGroupMerged::generateSpikeEventThreshold(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { EnvironmentGroupMergedField synEnv(env, *this); @@ -394,16 +392,16 @@ void PresynapticUpdateGroupMerged::generateSpikeEventThreshold(const BackendBase prettyPrintStatements(getArchetype().getWUEventThresholdCodeTokens(), getTypeContext(), 
synEnv, errorHandler); } //---------------------------------------------------------------------------- -void PresynapticUpdateGroupMerged::generateSpikeEventUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void PresynapticUpdateGroupMerged::generateSpikeEventUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { applySynapseSubstitutions(backend, env, getArchetype().getWUEventCodeTokens(), "event code", - *this, modelMerged, backend.supportsNamespace()); + *this, batchSize, backend.supportsNamespace()); } //---------------------------------------------------------------------------- -void PresynapticUpdateGroupMerged::generateSpikeUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void PresynapticUpdateGroupMerged::generateSpikeUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { applySynapseSubstitutions(backend, env, getArchetype().getWUSimCodeTokens(), "sim code", - *this, modelMerged, backend.supportsNamespace()); + *this, batchSize, backend.supportsNamespace()); } //---------------------------------------------------------------------------- void PresynapticUpdateGroupMerged::generateProceduralConnectivity(const BackendBase&, EnvironmentExternalBase &env) @@ -449,14 +447,14 @@ void PresynapticUpdateGroupMerged::generateToeplitzConnectivity(const BackendBas //---------------------------------------------------------------------------- const std::string PostsynapticUpdateGroupMerged::name = "PostsynapticUpdate"; //---------------------------------------------------------------------------- -void PostsynapticUpdateGroupMerged::generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void PostsynapticUpdateGroupMerged::generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { /*if (!wum->getLearnPostSupportCode().empty() && backend.supportsNamespace()) { os << "using namespace " << modelMerged.getPostsynapticUpdateSupportCodeNamespace(wum->getLearnPostSupportCode()) << ";" << std::endl; }*/ applySynapseSubstitutions(backend, env, getArchetype().getWUPostLearnCodeTokens(), "learn post code", - *this, modelMerged, backend.supportsNamespace()); + *this, batchSize, backend.supportsNamespace()); } //---------------------------------------------------------------------------- @@ -464,14 +462,14 @@ void PostsynapticUpdateGroupMerged::generateSynapseUpdate(const BackendBase &bac //---------------------------------------------------------------------------- const std::string SynapseDynamicsGroupMerged::name = "SynapseDynamics"; //---------------------------------------------------------------------------- -void SynapseDynamicsGroupMerged::generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, const ModelSpecMerged &modelMerged) +void SynapseDynamicsGroupMerged::generateSynapseUpdate(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { /*if (!wum->getSynapseDynamicsSuppportCode().empty() && backend.supportsNamespace()) { os << "using namespace " << modelMerged.getSynapseDynamicsSupportCodeNamespace(wum->getSynapseDynamicsSuppportCode()) << ";" << std::endl; }*/ applySynapseSubstitutions(backend, env, getArchetype().getWUSynapseDynamicsCodeTokens(), "synapse dynamics", - *this, modelMerged, backend.supportsNamespace()); + *this, batchSize, 
backend.supportsNamespace()); } From 76a1bcb0005deec442ba68809c392dc3f29fc653 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 17:17:40 +0100 Subject: [PATCH 417/725] ``GeNNModel.unload`` should clear out all simulation state from ``GeNNModel`` and unload shared library model # Conflicts: # pygenn/genn_groups.py # pygenn/genn_model.py --- pygenn/genn_groups.py | 56 ++++++++++++++++++++++-- pygenn/genn_model.py | 26 ++++++++++- userproject/include/sharedLibraryModel.h | 45 +++++++++++++++---- 3 files changed, 114 insertions(+), 13 deletions(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index b5ded4dfb2..a75226c38d 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -275,6 +275,25 @@ def _load_var_init_egps(self, var_dict=None): for var_name, var_data in iteritems(var_dict): self._load_egp(var_data.extra_global_params, var_name) + def _unload_vars(self, var_dict=None): + # If no variable dictionary is specified, use standard one + if var_dict is None: + var_dict = self.vars + + # Loop through variables and clear views + for v in itervalues(var_dict): + v.view = None + for e in itervalues(v.extra_global_params): + e.view = None + + def _unload_egps(self, egp_dict=None): + # If no EGP dictionary is specified, use standard one + if egp_dict is None: + egp_dict = self.extra_global_params + + # Loop through extra global params and clear views + for e in itervalues(egp_dict): + e.view = None class NeuronGroupMixin(GroupMixin): @@ -287,8 +306,10 @@ def _init_group(self, model, var_space): model -- pygenn.genn_model.GeNNModel this neuron group is part of """ super(NeuronGroupMixin, self)._init_group(model) - self.spike_que_ptr = [0] - + self.spike_que_ptr = None + self._spike_recording_data = None + self._spike_event_recording_data = None + self.vars, self.extra_global_params = prepare_model( self.neuron_model, self, var_space) @@ -339,6 +360,14 @@ def load(self, num_recording_timesteps): # Load neuron extra global params self._load_egp() + def unload(self): + self.spike_que_ptr = None + self._spike_recording_data = None + self._spike_event_recording_data = None + + self._unload_vars() + self._unload_egps() + def load_init_egps(self): # Load any egps used for variable initialisation self._load_var_init_egps() @@ -730,7 +759,20 @@ def _connectivity_initialiser_provided(self): snippet = self.sparse_connectivity_initialiser.snippet return (len(snippet.get_row_build_code()) > 0 or len(snippet.get_col_build_code()) > 0) - + + def unload(self): + self._ind = None + self._row_lengths = None + self.in_syn = None + + self._unload_vars() + self._unload_vars(self.pre_vars) + self._unload_vars(self.post_vars) + self._unload_vars(self.psm_vars) + self._unload_egps() + self._unload_egps(self.psm_extra_global_params) + self._unload_egps(self.connectivity_extra_global_params) + def _init_wum_var(self, var_data, num_copies): # If initialisation is required if var_data.init_required: @@ -794,6 +836,10 @@ def load_init_egps(self): # Load any egps used for variable initialisation self._load_var_init_egps() + def unload(self): + self._unload_vars() + self._unload_egps() + class CustomUpdateMixin(GroupMixin): """Class representing a custom update""" def _init_group(self, model, var_space): @@ -855,6 +901,10 @@ def load_init_egps(self): # Load any egps used for variable initialisation self._load_var_init_egps() + def unload(self): + self._unload_vars() + self._unload_egps() + @property def _custom_wu_update(self): return isinstance(self, CustomUpdateWU) diff --git 
a/pygenn/genn_model.py b/pygenn/genn_model.py index ab81fb589b..5a76b8fa09 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -576,7 +576,7 @@ def load(self, path_to_model="./", num_recording_timesteps=None): # Allocate recording buffers self._slm.allocate_recording_buffers(num_recording_timesteps) - # Loop through synapse populations and load any + # Loop through neuron populations and load any # extra global parameters required for initialization for pop_data in itervalues(self.neuron_populations): pop_data.load_init_egps() @@ -620,6 +620,30 @@ def load(self, path_to_model="./", num_recording_timesteps=None): self._loaded = True self._built = True + + def unload(self): + # Loop through custom updates and unload + for cu_data in itervalues(self.custom_updates): + cu_data.unload() + + # Loop through current sources and unload + for src_data in itervalues(self.current_sources): + src_data.unload() + + # Loop through synapse populations and unload + for pop_data in itervalues(self.synapse_populations): + pop_data.unload() + + # Loop through neuron populations and unload + for pop_data in itervalues(self.neuron_populations): + pop_data.unload() + + # Close shared library model + self._slm.close() + + # Clear loaded flag + self._loaded = False + def step_time(self): """Make one simulation step""" if not self._loaded: diff --git a/userproject/include/sharedLibraryModel.h b/userproject/include/sharedLibraryModel.h index 6e7a9fe256..d4b9a992b3 100644 --- a/userproject/include/sharedLibraryModel.h +++ b/userproject/include/sharedLibraryModel.h @@ -49,15 +49,8 @@ class SharedLibraryModel virtual ~SharedLibraryModel() { - // Close model library if loaded successfully - if(m_Library) { - freeMem(); -#ifdef _WIN32 - FreeLibrary(m_Library); -#else - dlclose(m_Library); -#endif - } + // Close model library + close(); } //---------------------------------------------------------------------------- @@ -110,6 +103,40 @@ class SharedLibraryModel } } + void close() + { + if(m_Library) { + freeMem(); +#ifdef _WIN32 + FreeLibrary(m_Library); +#else + dlclose(m_Library); +#endif + m_Library = nullptr; + } + + // Null all pointers + m_AllocateMem = nullptr; + m_AllocateRecordingBuffers = nullptr; + m_FreeMem = nullptr; + m_GetFreeDeviceMemBytes = nullptr; + m_Initialize = nullptr; + m_InitializeSparse = nullptr; + m_StepTime = nullptr; + m_PullRecordingBuffersFromDevice = nullptr; + m_NCCLGenerateUniqueID = nullptr; + m_NCCLGetUniqueID = nullptr; + m_NCCLInitCommunicator = nullptr; + m_NCCLUniqueIDBytes = nullptr; + m_T = nullptr; + m_Timestep = nullptr; + + // Empty all dictionaries + m_PopulationVars.clear(); + m_PopulationEPGs.clear(); + m_CustomUpdates.clear(); + } + void allocateExtraGlobalParam(const std::string &popName, const std::string &egpName, unsigned int count) { // Get EGP functions and check allocate exists From 6031c99ef42db6aabdd59bd9d4effab3d28d354c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Sun, 19 Mar 2023 14:26:22 +0000 Subject: [PATCH 418/725] basic boilerplate of ``Models::Base::EGPRef``, ``Models::EGPReference``, ``Models::EGPReferenceContainerBase`` and additions to ``CustomUpdateModel::Base`` # Conflicts: # include/genn/genn/customUpdateModels.h # include/genn/genn/models.h # src/genn/genn/customUpdateModels.cc # src/genn/genn/models.cc --- include/genn/genn/customUpdateModels.h | 3 ++ include/genn/genn/models.h | 54 ++++++++++++++++++++++++++ src/genn/genn/customUpdateModels.cc | 3 ++ src/genn/genn/models.cc | 45 +++++++++++++++++++++ 4 files changed, 105 
insertions(+) diff --git a/include/genn/genn/customUpdateModels.h b/include/genn/genn/customUpdateModels.h index dbfa6d51fe..e13671aa09 100644 --- a/include/genn/genn/customUpdateModels.h +++ b/include/genn/genn/customUpdateModels.h @@ -26,6 +26,9 @@ class GENN_EXPORT Base : public Models::Base //! Gets names and types (as strings) of model variable references virtual VarRefVec getVarRefs() const{ return {}; } + //! Gets names and types (as strings) of model extra global parameter references + virtual EGPRefVec getEGPRefs() const { return {}; } + //! Gets the code that performs the custom update virtual std::string getUpdateCode() const{ return ""; } diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index e9bfef8420..8b95aae311 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -82,11 +82,28 @@ class GENN_EXPORT Base : public Snippet::Base VarAccessMode access; }; + struct EGPRef + { + EGPRef(const std::string &n, const std::string &t) : name(n), type(t) + {} + EGPRef() : EGPRef("", "") + {} + + bool operator == (const EGPRef &other) const + { + return ((name == other.name) && (type == other.type)); + } + + std::string name; + std::string type; + }; + //---------------------------------------------------------------------------- // Typedefines //---------------------------------------------------------------------------- typedef std::vector VarVec; typedef std::vector VarRefVec; + typedef std::vector EGPRefVec; //---------------------------------------------------------------------------- // Declared virtuals @@ -270,11 +287,48 @@ class GENN_EXPORT WUVarReference : public VarReferenceBase const GetTargetNameFn m_GetTransposeTargetName; }; +//---------------------------------------------------------------------------- +// Models::EGPReference +//---------------------------------------------------------------------------- +class GENN_EXPORT EGPReference +{ +public: + //------------------------------------------------------------------------ + // Public API + //------------------------------------------------------------------------ + const Models::Base::EGP &getEGP() const { return m_EGP; } + size_t getEGPIndex() const { return m_EGPIndex; } + std::string getTargetName() const { return m_TargetName; } + + //------------------------------------------------------------------------ + // Static API + //------------------------------------------------------------------------ + static EGPReference createEGPRef(const NeuronGroup *ng, const std::string &egpName); + static EGPReference createEGPRef(const CurrentSource *cs, const std::string &egpName); + static EGPReference createEGPRef(const CustomUpdate *cu, const std::string &egpName); + static EGPReference createEGPRef(const CustomUpdateWU *cu, const std::string &egpName); + static EGPReference createPSMEGPRef(const SynapseGroup *sg, const std::string &egpName); + static EGPReference createWUEGPRef(const SynapseGroup *sg, const std::string &egpName); + +private: + EGPReference(size_t egpIndex, const Models::Base::EGPVec &egpVec, + const std::string &targetName) + : m_EGPIndex(egpIndex), m_EGP(egpVec.at(egpIndex)), m_TargetName(targetName) + {} + //------------------------------------------------------------------------ + // Members + //------------------------------------------------------------------------ + size_t m_EGPIndex; + Models::Base::EGP m_EGP; + std::string m_TargetName; +}; + //---------------------------------------------------------------------------- // updateHash overrides 
//---------------------------------------------------------------------------- GENN_EXPORT void updateHash(const Base::Var &v, boost::uuids::detail::sha1 &hash); GENN_EXPORT void updateHash(const Base::VarRef &v, boost::uuids::detail::sha1 &hash); +GENN_EXPORT void updateHash(const Base::EGPRef &e, boost::uuids::detail::sha1 &hash); GENN_EXPORT void updateHash(const VarReference &v, boost::uuids::detail::sha1 &hash); GENN_EXPORT void updateHash(const WUVarReference &v, boost::uuids::detail::sha1 &hash); diff --git a/src/genn/genn/customUpdateModels.cc b/src/genn/genn/customUpdateModels.cc index f450a7e898..edce380a6f 100644 --- a/src/genn/genn/customUpdateModels.cc +++ b/src/genn/genn/customUpdateModels.cc @@ -21,6 +21,7 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const Utils::updateHash(getUpdateCode(), hash); Utils::updateHash(getVarRefs(), hash); + Utils::updateHash(getEGPRefs(), hash); return hash.get_digest(); } //---------------------------------------------------------------------------- @@ -37,6 +38,7 @@ void Base::validate(const std::unordered_map ¶mValues, // Validate variable reference initialisers Utils::validateInitialisers(varRefs, varRefTargets, "Variable reference", description); + Utils::validateVecNames(getEGPRefs(), "Extra global parameter reference"); } //---------------------------------------------------------------------------- void Base::validate(const std::unordered_map ¶mValues, @@ -52,5 +54,6 @@ void Base::validate(const std::unordered_map ¶mValues, // Validate variable reference initialisers Utils::validateInitialisers(varRefs, varRefTargets, "Variable reference", description); + Utils::validateVecNames(getEGPRefs(), "Extra global parameter reference"); } } // namespace GeNN::CustomUpdateModels \ No newline at end of file diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 9f7ced3a0f..00ba94aa8c 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -205,6 +205,45 @@ SynapseGroup *WUVarReference::getTransposeSynapseGroup() const return m_TransposeSG; } +//---------------------------------------------------------------------------- +// EGPReference +//---------------------------------------------------------------------------- +EGPReference EGPReference::createEGPRef(const NeuronGroup *ng, const std::string &egpName) +{ + const auto *nm = ng->getNeuronModel(); + return EGPReference(nm->getExtraGlobalParamIndex(egpName), nm->getExtraGlobalParams(), ng->getName()); +} +//---------------------------------------------------------------------------- +EGPReference EGPReference::createEGPRef(const CurrentSource *cs, const std::string &egpName) +{ + const auto *cm = cs->getCurrentSourceModel(); + return EGPReference(cm->getExtraGlobalParamIndex(egpName), cm->getExtraGlobalParams(), cs->getName()); +} +//---------------------------------------------------------------------------- +EGPReference EGPReference::createEGPRef(const CustomUpdate *cu, const std::string &egpName) +{ + const auto *cm = cu->getCustomUpdateModel(); + return EGPReference(cm->getExtraGlobalParamIndex(egpName), cm->getExtraGlobalParams(), cu->getName()); +} +//---------------------------------------------------------------------------- +EGPReference EGPReference::createEGPRef(const CustomUpdateWU *cu, const std::string &egpName) +{ + const auto *cm = cu->getCustomUpdateModel(); + return EGPReference(cm->getExtraGlobalParamIndex(egpName), cm->getExtraGlobalParams(), cu->getName()); +} 
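[Editor's note: a minimal usage sketch, not part of the patch, to make the factory API above concrete. It assumes a GeNN::ModelSpec that already contains a neuron population "Pop" whose neuron model declares an extra global parameter named "input", and that ModelSpec::findNeuronGroup() is available; all names here are placeholders.]

// Hedged sketch: constructing an EGP reference from user code
#include <cassert>
#include "modelSpec.h"

void egpRefSketch(GeNN::ModelSpec &model)
{
    // "Pop" is assumed to have been added with a neuron model that
    // declares an extra global parameter called "input"
    GeNN::NeuronGroup *pop = model.findNeuronGroup("Pop");

    // The factory resolves the EGP by name; an unknown name throws
    // rather than failing silently
    const auto ref = GeNN::Models::EGPReference::createEGPRef(pop, "input");

    // The reference captures the owning population and the EGP it aliases
    assert(ref.getTargetName() == "Pop");
    assert(ref.getEGP().name == "input");
}
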
+//---------------------------------------------------------------------------- +EGPReference EGPReference::createPSMEGPRef(const SynapseGroup *sg, const std::string &egpName) +{ + const auto *psm = sg->getPSModel(); + return EGPReference(psm->getExtraGlobalParamIndex(egpName), psm->getExtraGlobalParams(), sg->getName()); +} +//---------------------------------------------------------------------------- +EGPReference EGPReference::createWUEGPRef(const SynapseGroup *sg, const std::string &egpName) +{ + const auto *wum = sg->getWUModel(); + return EGPReference(wum->getExtraGlobalParamIndex(egpName), wum->getExtraGlobalParams(), sg->getName()); +} + //---------------------------------------------------------------------------- // Free functions //---------------------------------------------------------------------------- @@ -222,6 +261,12 @@ void updateHash(const Base::VarRef &v, boost::uuids::detail::sha1 &hash) Utils::updateHash(v.access, hash); } //---------------------------------------------------------------------------- +void updateHash(const Base::EGPRef &e, boost::uuids::detail::sha1 &hash) +{ + Utils::updateHash(e.name, hash); + Type::updateHash(e.type, hash); +} +//---------------------------------------------------------------------------- void updateHash(const VarReference &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.getTargetName(), hash); From ce0fcc6ae71bfb91859442ce39fb9abf8a104506 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 20 Mar 2023 09:47:32 +0000 Subject: [PATCH 419/725] remaining functionality for EGP references # Conflicts: # include/genn/genn/code_generator/groupMerged.h # include/genn/genn/code_generator/modelSpecMerged.h # include/genn/genn/customUpdate.h # include/genn/genn/customUpdateInternal.h # include/genn/genn/modelSpec.h # include/genn/genn/models.h # src/genn/genn/code_generator/customUpdateGroupMerged.cc # src/genn/genn/customUpdate.cc # src/genn/genn/customUpdateModels.cc # src/genn/genn/models.cc --- .../genn/genn/code_generator/environment.h | 19 ++++++++++++++++ .../genn/code_generator/modelSpecMerged.h | 2 ++ include/genn/genn/customUpdate.h | 10 ++++++--- include/genn/genn/customUpdateInternal.h | 8 +++---- include/genn/genn/customUpdateModels.h | 6 ++--- include/genn/genn/modelSpec.h | 13 ++++++----- include/genn/genn/models.h | 9 ++++---- .../code_generator/customUpdateGroupMerged.cc | 4 ++++ src/genn/genn/customUpdate.cc | 22 ++++++++++++++----- src/genn/genn/customUpdateModels.cc | 8 +++---- src/genn/genn/modelSpec.cc | 8 +++---- src/genn/genn/models.cc | 6 +++++ 12 files changed, 81 insertions(+), 34 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 28cf578b71..9bc5cfaaf6 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -503,6 +503,25 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBasegetGroup().getTypeContext()); + assert(!resolvedType.isPointer()); + const auto pointerType = resolvedType.createPointer(); + addField(pointerType, e.name, + pointerType, e.name + fieldSuffix, + [arrayPrefix, e](const auto &g, size_t) + { + const auto egpRef = g.getEGPReferences().at(e.name); + return arrayPrefix + egpRef.getEGP().name + egpRef.getTargetName(); + }, + "", GroupMergedFieldType::DYNAMIC); + } + } + template void addConnectInitParams(const std::string &fieldSuffix, GetConnectivityFn getConnectivity, IsHeterogeneousFn isHeterogeneous) diff --git 
a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index de391d4b76..0be08cc808 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -386,6 +386,8 @@ class GENN_EXPORT ModelSpecMerged const auto &g = mergedGroups.back().getGroups()[groupIndex]; // Add reference to this group's variable to data structure + // **NOTE** this works fine with EGP references because the function to + // get their value will just return the name of the referenced EGP assert(std::get<0>(f).isPointer()); m_MergedEGPs[std::get<2>(f)(g, groupIndex)].emplace( std::piecewise_construct, diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index 81c49d4524..18a61b548d 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -41,6 +41,8 @@ class GENN_EXPORT CustomUpdateBase const std::unordered_map &getParams() const{ return m_Params; } const std::unordered_map &getVarInitialisers() const{ return m_VarInitialisers; } + const std::unordered_map &getEGPReferences() const{ return m_EGPReferences; } + //! Get variable location for custom update model state variable VarLocation getVarLocation(const std::string &varName) const; @@ -53,7 +55,7 @@ class GENN_EXPORT CustomUpdateBase protected: CustomUpdateBase(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, - VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation); + const std::unordered_map &egpReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation); //------------------------------------------------------------------------ // Protected methods @@ -150,6 +152,8 @@ class GENN_EXPORT CustomUpdateBase std::unordered_map m_DerivedParams; std::unordered_map m_VarInitialisers; + std::unordered_map m_EGPReferences; + //! 
Location of individual state variables std::vector m_VarLocation; @@ -231,7 +235,7 @@ class GENN_EXPORT CustomUpdate : public CustomUpdateBase CustomUpdate(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, - VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation); + const std::unordered_map &egpReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation); //------------------------------------------------------------------------ // Protected methods @@ -282,7 +286,7 @@ class GENN_EXPORT CustomUpdateWU : public CustomUpdateBase CustomUpdateWU(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, - VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation); + const std::unordered_map &egpReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation); //------------------------------------------------------------------------ // Protected methods diff --git a/include/genn/genn/customUpdateInternal.h b/include/genn/genn/customUpdateInternal.h index fbb08a0d43..e3ff62b263 100644 --- a/include/genn/genn/customUpdateInternal.h +++ b/include/genn/genn/customUpdateInternal.h @@ -17,9 +17,9 @@ class CustomUpdateInternal : public CustomUpdate CustomUpdateInternal(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, - VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) + const std::unordered_map &egpReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) : CustomUpdate(name, updateGroupName, customUpdateModel, params, varInitialisers, varReferences, - defaultVarLocation, defaultExtraGlobalParamLocation) + egpReferences, defaultVarLocation, defaultExtraGlobalParamLocation) { } @@ -75,9 +75,9 @@ class CustomUpdateWUInternal : public CustomUpdateWU CustomUpdateWUInternal(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, - VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) + const std::unordered_map &egpReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) : CustomUpdateWU(name, updateGroupName, customUpdateModel, params, varInitialisers, varReferences, - defaultVarLocation, defaultExtraGlobalParamLocation) + egpReferences, defaultVarLocation, defaultExtraGlobalParamLocation) { getSynapseGroup()->addCustomUpdateReference(this); } diff --git a/include/genn/genn/customUpdateModels.h b/include/genn/genn/customUpdateModels.h index e13671aa09..27fb31f229 100644 --- a/include/genn/genn/customUpdateModels.h +++ b/include/genn/genn/customUpdateModels.h @@ -23,11 +23,11 @@ class GENN_EXPORT Base : public Models::Base //---------------------------------------------------------------------------- // Declared virtuals //---------------------------------------------------------------------------- - //! 
Gets names and types (as strings) of model variable references + //! Gets names and types of model variable references virtual VarRefVec getVarRefs() const{ return {}; } - //! Gets names and types (as strings) of model extra global parameter references - virtual EGPRefVec getEGPRefs() const { return {}; } + //! Gets names and types of model extra global parameter references + virtual EGPRefVec getExtraGlobalParamRefs() const { return {}; } //! Gets the code that performs the custom update virtual std::string getUpdateCode() const{ return ""; } diff --git a/include/genn/genn/modelSpec.h b/include/genn/genn/modelSpec.h index 6b7ea6a78d..910e0283c5 100644 --- a/include/genn/genn/modelSpec.h +++ b/include/genn/genn/modelSpec.h @@ -47,6 +47,7 @@ using ParamValues = std::unordered_map; using VarValues = std::unordered_map; using VarReferences = std::unordered_map; using WUVarReferences = std::unordered_map; +using EGPReferences = std::unordered_map; //! Floating point precision to use for "scalar" type variables models enum class ScalarPrecision @@ -552,7 +553,7 @@ class GENN_EXPORT ModelSpec \return pointer to newly created CustomUpdateBase */ CustomUpdate *addCustomUpdate(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *model, const ParamValues &paramValues, const VarValues &varInitialisers, - const VarReferences &varReferences); + const VarReferences &varReferences, const EGPReferences &egpReferences = {}); //! Adds a new custom update with references to weight update model variable to the //! model using a custom update model managed by the user @@ -565,7 +566,7 @@ class GENN_EXPORT ModelSpec \return pointer to newly created CustomUpdateBase */ CustomUpdateWU *addCustomUpdate(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *model, const ParamValues &paramValues, const VarValues &varInitialisers, - const WUVarReferences &varReferences); + const WUVarReferences &varReferences, const EGPReferences &egpReferences = {}); //! Adds a new custom update to the model using a singleton custom update model //! created using standard DECLARE_CUSTOM_UPDATE_MODEL and IMPLEMENT_MODEL macros @@ -579,10 +580,10 @@ class GENN_EXPORT ModelSpec template CustomUpdate *addCustomUpdate(const std::string &name, const std::string &updateGroupName, const ParamValues &paramValues, const VarValues &varInitialisers, - const VarReferences &varReferences) + const VarReferences &varReferences, const EGPReferences &egpReferences = {}) { return addCustomUpdate(name, updateGroupName, CustomUpdateModel::getInstance(), - paramValues, varInitialisers, varReferences); + paramValues, varInitialisers, varReferences, egpReferences); } @@ -598,10 +599,10 @@ class GENN_EXPORT ModelSpec template CustomUpdateWU *addCustomUpdate(const std::string &name, const std::string &updateGroupName, const ParamValues &paramValues, const VarValues &varInitialisers, - const WUVarReferences &varReferences) + const WUVarReferences &varReferences, const EGPReferences &egpReferences = {}) { return addCustomUpdate(name, updateGroupName, CustomUpdateModel::getInstance(), - paramValues, varInitialisers, varReferences); + paramValues, varInitialisers, varReferences, egpReferences); }
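[Editor's note: a hypothetical caller for the new overloads above, not part of the patch. "Decode", "pop" and the names "in"/"out"/"input" are invented; DecodeModel stands for any CustomUpdateModels::Base subclass declared with the usual macros, whose getVarRefs() declares a reference "in" and whose getExtraGlobalParamRefs() declares a pointer-typed reference "out".]

// Hedged sketch: threading an EGP reference through addCustomUpdate()
GeNN::VarReferences varRefs{{"in", createVarRef(pop, "V")}};
GeNN::EGPReferences egpRefs{{"out", GeNN::Models::EGPReference::createEGPRef(pop, "input")}};
model.addCustomUpdate<DecodeModel>("Decode", "CustomUpdate",
                                   {}, {},   // no parameters or state variables
                                   varRefs, egpRefs);

//!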
Adds a new custom connectivity update attached to synapse group and potentially with synaptic, presynaptic and diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 8b95aae311..98ae7d4a54 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -84,18 +84,18 @@ class GENN_EXPORT Base : public Snippet::Base struct EGPRef { - EGPRef(const std::string &n, const std::string &t) : name(n), type(t) + EGPRef(const std::string &n, const Type::ResolvedType &t) : name(n), type(t) {} - EGPRef() : EGPRef("", "") + EGPRef(const std::string &n, const std::string &t) : name(n), type(t) {} bool operator == (const EGPRef &other) const { - return ((name == other.name) && (type == other.type)); + return (std::tie(name, type) == std::tie(other.name, other.type)); } std::string name; - std::string type; + Type::UnresolvedType type; }; //---------------------------------------------------------------------------- @@ -331,6 +331,7 @@ GENN_EXPORT void updateHash(const Base::VarRef &v, boost::uuids::detail::sha1 &h GENN_EXPORT void updateHash(const Base::EGPRef &e, boost::uuids::detail::sha1 &hash); GENN_EXPORT void updateHash(const VarReference &v, boost::uuids::detail::sha1 &hash); GENN_EXPORT void updateHash(const WUVarReference &v, boost::uuids::detail::sha1 &hash); +GENN_EXPORT void updateHash(const EGPReference &v, boost::uuids::detail::sha1 &hash); //! Helper function to check if variable reference types match those specified in model template diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc index 3b33c4cb0e..008ddeceb8 100644 --- a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -39,6 +39,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateGroupMerged::getHashDigest() updateHash([](const auto &cg) { return cg.getParams(); }, hash); updateHash([](const auto &cg) { return cg.getDerivedParams(); }, hash); updateHash([](const auto &cg) { return cg.getVarReferences(); }, hash); + updateHash([](const auto &cg) { return cg.getEGPReferences(); }, hash); return hash.get_digest(); } @@ -63,6 +64,7 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E cuEnv.addParams(cm->getParamNames(), "", &CustomUpdateInternal::getParams, &CustomUpdateGroupMerged::isParamHeterogeneous); cuEnv.addDerivedParams(cm->getDerivedParams(), "", &CustomUpdateInternal::getDerivedParams, &CustomUpdateGroupMerged::isDerivedParamHeterogeneous); cuEnv.addExtraGlobalParams(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); + cuEnv.addExtraGlobalParamRefs(cm->getExtraGlobalParamRefs(), backend.getDeviceVarPrefix()); // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( @@ -168,6 +170,7 @@ boost::uuids::detail::sha1::digest_type CustomUpdateWUGroupMergedBase::getHashDi updateHash([](const auto &cg) { return cg.getParams(); }, hash); updateHash([](const auto &cg) { return cg.getDerivedParams(); }, hash); updateHash([](const auto &cg) { return cg.getVarReferences(); }, hash); + updateHash([](const auto &cg) { return cg.getEGPReferences(); }, hash); return hash.get_digest(); } @@ -247,6 +250,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdateBase(const BackendBase & cuEnv.addParams(cm->getParamNames(), "", &CustomUpdateInternal::getParams, &CustomUpdateWUGroupMergedBase::isParamHeterogeneous); cuEnv.addDerivedParams(cm->getDerivedParams(), "", 
&CustomUpdateInternal::getDerivedParams, &CustomUpdateWUGroupMergedBase::isDerivedParamHeterogeneous); cuEnv.addExtraGlobalParams(cm->getExtraGlobalParams(), backend.getDeviceVarPrefix()); + cuEnv.addExtraGlobalParamRefs(cm->getExtraGlobalParamRefs(), backend.getDeviceVarPrefix()); // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( diff --git a/src/genn/genn/customUpdate.cc b/src/genn/genn/customUpdate.cc index cd79e3c72b..6ce87cd735 100644 --- a/src/genn/genn/customUpdate.cc +++ b/src/genn/genn/customUpdate.cc @@ -34,9 +34,9 @@ bool CustomUpdateBase::isVarInitRequired() const //---------------------------------------------------------------------------- CustomUpdateBase::CustomUpdateBase(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, - VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) + const std::unordered_map &egpReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) : m_Name(name), m_UpdateGroupName(updateGroupName), m_CustomUpdateModel(customUpdateModel), m_Params(params), - m_VarInitialisers(varInitialisers), m_VarLocation(varInitialisers.size(), defaultVarLocation), + m_VarInitialisers(varInitialisers), m_EGPReferences(egpReferences), m_VarLocation(varInitialisers.size(), defaultVarLocation), m_ExtraGlobalParamLocation(customUpdateModel->getExtraGlobalParams().size(), defaultExtraGlobalParamLocation), m_Batched(false) { @@ -44,6 +44,16 @@ CustomUpdateBase::CustomUpdateBase(const std::string &name, const std::string &u Utils::validatePopName(name, "Custom update"); Utils::validatePopName(updateGroupName, "Custom update group name"); + // Loop through all extra global parameter references + for (const auto &modelEGPRef : getCustomUpdateModel()->getExtraGlobalParamRefs()) { + const auto egpRef = egpReferences.at(modelEGPRef.name); + + // Check types of extra global parameter references against those specified in model + // **THINK** due to GeNN's current string-based type system this is rather conservative + if (egpRef.getEGP().type != modelEGPRef.type) { + throw std::runtime_error("Incompatible type for extra global parameter reference '" + modelEGPRef.name + "'"); + } + } // Scan custom update model code string m_UpdateCodeTokens = Utils::scanCode(getCustomUpdateModel()->getUpdateCode(), "Custom update '" + getName() + "' update code"); @@ -110,8 +120,8 @@ boost::uuids::detail::sha1::digest_type CustomUpdateBase::getVarLocationHashDige CustomUpdate::CustomUpdate(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, - VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) - : CustomUpdateBase(name, updateGroupName, customUpdateModel, params, varInitialisers, defaultVarLocation, defaultExtraGlobalParamLocation), + const std::unordered_map &egpReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) + : CustomUpdateBase(name, updateGroupName, customUpdateModel, params, varInitialisers, egpReferences, defaultVarLocation, defaultExtraGlobalParamLocation), m_VarReferences(varReferences), m_Size(varReferences.empty() ? 
0 : varReferences.begin()->second.getSize()), m_DelayNeuronGroup(nullptr), m_PerNeuron(false) { // Validate parameters, variables and variable references @@ -229,8 +239,8 @@ boost::uuids::detail::sha1::digest_type CustomUpdate::getInitHashDigest() const CustomUpdateWU::CustomUpdateWU(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *customUpdateModel, const std::unordered_map ¶ms, const std::unordered_map &varInitialisers, const std::unordered_map &varReferences, - VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) -: CustomUpdateBase(name, updateGroupName, customUpdateModel, params, varInitialisers, defaultVarLocation, defaultExtraGlobalParamLocation), + const std::unordered_map &egpReferences, VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation) +: CustomUpdateBase(name, updateGroupName, customUpdateModel, params, varInitialisers, egpReferences, defaultVarLocation, defaultExtraGlobalParamLocation), m_VarReferences(varReferences), m_SynapseGroup(m_VarReferences.empty() ? nullptr : static_cast(m_VarReferences.begin()->second.getSynapseGroup())) { // Validate parameters, variables and variable references diff --git a/src/genn/genn/customUpdateModels.cc b/src/genn/genn/customUpdateModels.cc index edce380a6f..e2c05ea599 100644 --- a/src/genn/genn/customUpdateModels.cc +++ b/src/genn/genn/customUpdateModels.cc @@ -21,7 +21,7 @@ boost::uuids::detail::sha1::digest_type Base::getHashDigest() const Utils::updateHash(getUpdateCode(), hash); Utils::updateHash(getVarRefs(), hash); - Utils::updateHash(getEGPRefs(), hash); + Utils::updateHash(getExtraGlobalParamRefs(), hash); return hash.get_digest(); } //---------------------------------------------------------------------------- @@ -34,11 +34,11 @@ void Base::validate(const std::unordered_map ¶mValues, Models::Base::validate(paramValues, varValues, description); const auto varRefs = getVarRefs(); - Utils::validateVecNames(getVarRefs(), "Variable reference"); + Utils::validateVecNames(varRefs, "Variable reference"); // Validate variable reference initialisers Utils::validateInitialisers(varRefs, varRefTargets, "Variable reference", description); - Utils::validateVecNames(getEGPRefs(), "Extra global parameter reference"); + Utils::validateVecNames(getExtraGlobalParamRefs(), "Extra global parameter reference"); } //---------------------------------------------------------------------------- void Base::validate(const std::unordered_map ¶mValues, @@ -54,6 +54,6 @@ void Base::validate(const std::unordered_map ¶mValues, // Validate variable reference initialisers Utils::validateInitialisers(varRefs, varRefTargets, "Variable reference", description); - Utils::validateVecNames(getEGPRefs(), "Extra global parameter reference"); + Utils::validateVecNames(getExtraGlobalParamRefs(), "Extra global parameter reference"); } } // namespace GeNN::CustomUpdateModels \ No newline at end of file diff --git a/src/genn/genn/modelSpec.cc b/src/genn/genn/modelSpec.cc index 3f304ce6b9..5392beb4c8 100644 --- a/src/genn/genn/modelSpec.cc +++ b/src/genn/genn/modelSpec.cc @@ -149,13 +149,13 @@ CurrentSource *ModelSpec::addCurrentSource(const std::string ¤tSourceName, // --------------------------------------------------------------------------- CustomUpdate *ModelSpec::addCustomUpdate(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *model, const ParamValues ¶mValues, const VarValues &varInitialisers, - const VarReferences &varReferences) + const 
VarReferences &varReferences, const EGPReferences &egpReferences) { // Add neuron group to map auto result = m_CustomUpdates.emplace(std::piecewise_construct, std::forward_as_tuple(name), std::forward_as_tuple(name, updateGroupName, model, - paramValues, varInitialisers, varReferences, + paramValues, varInitialisers, varReferences, egpReferences, m_DefaultVarLocation, m_DefaultExtraGlobalParamLocation)); if(!result.second) { @@ -194,13 +194,13 @@ CustomConnectivityUpdate *ModelSpec::addCustomConnectivityUpdate(const std::stri // --------------------------------------------------------------------------- CustomUpdateWU *ModelSpec::addCustomUpdate(const std::string &name, const std::string &updateGroupName, const CustomUpdateModels::Base *model, const ParamValues ¶mValues, const VarValues &varInitialisers, - const WUVarReferences &varReferences) + const WUVarReferences &varReferences, const EGPReferences &egpReferences) { // Add neuron group to map auto result = m_CustomWUUpdates.emplace(std::piecewise_construct, std::forward_as_tuple(name), std::forward_as_tuple(name, updateGroupName, model, - paramValues, varInitialisers, varReferences, + paramValues, varInitialisers, varReferences, egpReferences, m_DefaultVarLocation, m_DefaultExtraGlobalParamLocation)); if(!result.second) { diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index 00ba94aa8c..fea7767f21 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -283,4 +283,10 @@ void updateHash(const WUVarReference &v, boost::uuids::detail::sha1 &hash) Utils::updateHash(v.getTransposeVarIndex(), hash); } } +//---------------------------------------------------------------------------- +void Models::updateHash(const EGPReference &v, boost::uuids::detail::sha1 &hash) +{ + Utils::updateHash(v.getTargetName(), hash); + Utils::updateHash(v.getEGPIndex(), hash); +} } // namespace GeNN::Models From 725b0d71553b4cb5da22d537240d8c6b0ce88ed4 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 20 Mar 2023 11:20:19 +0000 Subject: [PATCH 420/725] free functions for creating EGP references --- include/genn/genn/modelSpec.h | 36 +++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/include/genn/genn/modelSpec.h b/include/genn/genn/modelSpec.h index 910e0283c5..be80b84ced 100644 --- a/include/genn/genn/modelSpec.h +++ b/include/genn/genn/modelSpec.h @@ -206,6 +206,42 @@ inline Models::WUVarReference createWUVarRef(CustomConnectivityUpdate *cu, const return Models::WUVarReference(cu, varName); } +//! Creates a reference to a neuron group extra global parameter +inline Models::EGPReference createEGPRef(const NeuronGroup *ng, const std::string &egpName) +{ + return Models::EGPReference::createEGPRef(ng, egpName); +} + +//! Creates a reference to a current source extra global parameter +inline Models::EGPReference createEGPRef(const CurrentSource *cs, const std::string &egpName) +{ + return Models::EGPReference::createEGPRef(cs, egpName); +} + +//! Creates a reference to a custom update extra global parameter +inline Models::EGPReference createEGPRef(const CustomUpdate *cu, const std::string &egpName) +{ + return Models::EGPReference::createEGPRef(cu, egpName); +} + +//! Creates a reference to a custom weight update extra global parameter +inline Models::EGPReference createEGPRef(const CustomUpdateWU *cu, const std::string &egpName) +{ + return Models::EGPReference::createEGPRef(cu, egpName); +} + +//! 
Creates a reference to a postsynaptic model extra global parameter +inline Models::EGPReference createPSMEGPRef(const SynapseGroup *sg, const std::string &egpName) +{ + return Models::EGPReference::createPSMEGPRef(sg, egpName); +} + +//! Creates a reference to a weight update model extra global parameter +inline Models::EGPReference createWUEGPRef(const SynapseGroup *sg, const std::string &egpName) +{ + return Models::EGPReference::createWUEGPRef(sg, egpName); +} + //---------------------------------------------------------------------------- // GeNN::ModelSpec //---------------------------------------------------------------------------- From d695e9835c4f27dad04f641ea90e63d166469bc9 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 30 May 2023 17:00:57 +0100 Subject: [PATCH 421/725] unit test for EGP references --- tests/unit/customUpdate.cc | 74 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 2 deletions(-) diff --git a/tests/unit/customUpdate.cc b/tests/unit/customUpdate.cc index 78a32f33ef..a3b11193eb 100644 --- a/tests/unit/customUpdate.cc +++ b/tests/unit/customUpdate.cc @@ -78,7 +78,7 @@ class Sum3 : public CustomUpdateModels::Base SET_UPDATE_CODE("$(sum) = $(scale) * ($(a) + $(b));\n"); SET_VARS({{"sum", "scalar"}, {"scale", "scalar", VarAccess::READ_ONLY_SHARED_NEURON}}); - SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_ONLY}, + SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_WRITE}, {"b", "scalar", VarAccessMode::READ_ONLY}}); }; IMPLEMENT_SNIPPET(Sum3); @@ -89,11 +89,47 @@ class Copy : public CustomUpdateModels::Base SET_UPDATE_CODE("a = b;\n"); - SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_ONLY}, + SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_WRITE}, {"b", "scalar", VarAccessMode::READ_ONLY}}); }; IMPLEMENT_SNIPPET(Copy); +class EGPScale : public CustomUpdateModels::Base +{ + DECLARE_CUSTOM_UPDATE_MODEL(EGPScale, 0, 0, 2); + + SET_UPDATE_CODE("$(a) = $(b) * $(c)[$(id)];\n"); + + SET_EXTRA_GLOBAL_PARAMS({{"c", "scalar*"}}); + SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_WRITE}, + {"b", "scalar", VarAccessMode::READ_ONLY}}); +}; +IMPLEMENT_MODEL(EGPScale); + +class EGPRefScale : public CustomUpdateModels::Base +{ + DECLARE_CUSTOM_UPDATE_MODEL_EGP_REF(EGPRefScale, 0, 0, 2, 1); + + SET_UPDATE_CODE("$(a) = $(b) * $(c)[$(id)];\n"); + + SET_EXTRA_GLOBAL_PARAM_REFS({{"c", "scalar*"}}); + SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_WRITE}, + {"b", "scalar", VarAccessMode::READ_ONLY}}); +}; +IMPLEMENT_MODEL(EGPRefScale); + +class EGPRefScaleInt : public CustomUpdateModels::Base +{ + DECLARE_CUSTOM_UPDATE_MODEL_EGP_REF(EGPRefScaleInt, 0, 0, 2, 1); + + SET_UPDATE_CODE("$(a) = $(b) * $(c)[$(id)];\n"); + + SET_EXTRA_GLOBAL_PARAM_REFS({{"c", "int*"}}); + SET_VAR_REFS({{"a", "scalar", VarAccessMode::READ_WRITE}, + {"b", "scalar", VarAccessMode::READ_ONLY}}); +}; +IMPLEMENT_MODEL(EGPRefScaleInt); + class Cont : public WeightUpdateModels::Base { public: @@ -283,6 +319,40 @@ TEST(CustomUpdates, VarReferenceTypeChecks) model.finalise(); } //-------------------------------------------------------------------------- +TEST(CustomUpdates, EGPReferenceTypeChecks) +{ + ModelSpecInternal model; + + // Add three neuron group to model + NeuronModels::Izhikevich::ParamValues paramVals(0.02, 0.2, -65.0, 8.0); + NeuronModels::Izhikevich::VarValues varVals(0.0, 0.0); + auto *pop1 = model.addNeuronPopulation("Pop1", 10, paramVals, varVals); + auto *pop2 = model.addNeuronPopulation("Pop2", 10, paramVals, varVals); + auto *pop3 = 
model.addNeuronPopulation("Pop3", 10, paramVals, varVals); + + // Add scaling custom update with EGP + EGPScale::VarReferences scaleVarReferences1(createVarRef(pop1, "V"), createVarRef(pop1, "U")); + auto *egpScale = model.addCustomUpdate("Scale1", "CustomUpdate", + {}, {}, scaleVarReferences1); + + // Add scaling custom update with EGP ref sharing "c" EGP + EGPRefScale::VarReferences scaleVarReferences2(createVarRef(pop2, "V"), createVarRef(pop2, "U")); + EGPRefScale::EGPReferences scaleEGPReferences2(createEGPRef(egpScale, "c")); + model.addCustomUpdate("Scale2", "CustomUpdate", + EGPRefScale::ParamValues{}, EGPRefScale::VarValues{}, scaleVarReferences2, scaleEGPReferences2); + try { + // Add scaling custom update with EGP ref sharing "c" EGP + EGPRefScaleInt::VarReferences scaleVarReferences3(createVarRef(pop3, "V"), createVarRef(pop3, "U")); + EGPRefScaleInt::EGPReferences scaleEGPReferences3(createEGPRef(egpScale, "c")); + model.addCustomUpdate("Scale3", "CustomUpdate", + EGPRefScaleInt::ParamValues{}, EGPRefScaleInt::VarValues{}, scaleVarReferences3, scaleEGPReferences3); + FAIL(); + } + catch(const std::runtime_error &) { + } + + model.finalize(); +} TEST(CustomUpdates, VarSizeChecks) { ModelSpecInternal model; From 1bfb414fa95052af73adea878fe8f1a0b12d181b Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 30 May 2023 17:04:10 +0100 Subject: [PATCH 422/725] fixed error calculation in a couple of feature tests (before blindly copy-pasting) --- tests/features/extra_global_cs_param/test.cc | 4 ++-- tests/features/extra_global_psm_param/test.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/features/extra_global_cs_param/test.cc b/tests/features/extra_global_cs_param/test.cc index 60b41cb18d..05fddb3bb1 100644 --- a/tests/features/extra_global_cs_param/test.cc +++ b/tests/features/extra_global_cs_param/test.cc @@ -29,10 +29,10 @@ TEST_F(SimulationTest, ExtraGlobalCSParams) for(int j = 0; j < 10; j++) { if(j == i) { - error = fabs(xpop[j] - 1.0f); + error += fabs(xpop[j] - 1.0f); } else { - error = fabs(xpop[j]); + error += fabs(xpop[j]); } } } diff --git a/tests/features/extra_global_psm_param/test.cc b/tests/features/extra_global_psm_param/test.cc index f9bd8da1b7..008b369a09 100644 --- a/tests/features/extra_global_psm_param/test.cc +++ b/tests/features/extra_global_psm_param/test.cc @@ -29,10 +29,10 @@ TEST_F(SimulationTest, ExtraGlobalPSMParams) for(int j = 0; j < 10; j++) { if(i > 1 && j == (i - 1)) { - error = fabs(xpost[j] - ((i - 1) * DT)); + error += fabs(xpost[j] - ((i - 1) * DT)); } else { - error = fabs(xpost[j]); + error += fabs(xpost[j]); } } } From cc40d9af1b0c826b6088a3cb0f525c3650109e8f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 30 May 2023 17:04:26 +0100 Subject: [PATCH 423/725] feature test for extra global param refs --- .../extra_global_param_ref.sln | 30 ++++++++ .../extra_global_param_ref.vcxproj | 63 +++++++++++++++++ .../features/extra_global_param_ref/model.cc | 68 +++++++++++++++++++ .../extra_global_param_ref/runner_guid.txt | 1 + tests/features/extra_global_param_ref/test.cc | 43 ++++++++++++ 5 files changed, 205 insertions(+) create mode 100644 tests/features/extra_global_param_ref/extra_global_param_ref.sln create mode 100644 tests/features/extra_global_param_ref/extra_global_param_ref.vcxproj create mode 100644 tests/features/extra_global_param_ref/model.cc create mode 100644 tests/features/extra_global_param_ref/runner_guid.txt create mode 100644 tests/features/extra_global_param_ref/test.cc diff --git 
a/tests/features/extra_global_param_ref/extra_global_param_ref.sln b/tests/features/extra_global_param_ref/extra_global_param_ref.sln new file mode 100644 index 0000000000..84e14479be --- /dev/null +++ b/tests/features/extra_global_param_ref/extra_global_param_ref.sln @@ -0,0 +1,30 @@ +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio 2013 +VisualStudioVersion = 12.0.30501.0 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "extra_global_param_ref", "extra_global_param_ref.vcxproj", "{26223262-6B56-40FA-8DAC-8BCAB3AF7F3A}" + ProjectSection(ProjectDependencies) = postProject + {9062DB83-2CB8-4D2E-BF12-0FEF606D9571} = {9062DB83-2CB8-4D2E-BF12-0FEF606D9571} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runner", "extra_global_param_ref_CODE\runner.vcxproj", "{9062DB83-2CB8-4D2E-BF12-0FEF606D9571}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|x64 = Debug|x64 + Release|x64 = Release|x64 + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {26223262-6B56-40FA-8DAC-8BCAB3AF7F3A}.Debug|x64.ActiveCfg = Debug|x64 + {26223262-6B56-40FA-8DAC-8BCAB3AF7F3A}.Debug|x64.Build.0 = Debug|x64 + {26223262-6B56-40FA-8DAC-8BCAB3AF7F3A}.Release|x64.ActiveCfg = Release|x64 + {26223262-6B56-40FA-8DAC-8BCAB3AF7F3A}.Release|x64.Build.0 = Release|x64 + {9062DB83-2CB8-4D2E-BF12-0FEF606D9571}.Debug|x64.ActiveCfg = Debug|x64 + {9062DB83-2CB8-4D2E-BF12-0FEF606D9571}.Debug|x64.Build.0 = Debug|x64 + {9062DB83-2CB8-4D2E-BF12-0FEF606D9571}.Release|x64.ActiveCfg = Release|x64 + {9062DB83-2CB8-4D2E-BF12-0FEF606D9571}.Release|x64.Build.0 = Release|x64 + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection +EndGlobal diff --git a/tests/features/extra_global_param_ref/extra_global_param_ref.vcxproj b/tests/features/extra_global_param_ref/extra_global_param_ref.vcxproj new file mode 100644 index 0000000000..91c54be056 --- /dev/null +++ b/tests/features/extra_global_param_ref/extra_global_param_ref.vcxproj @@ -0,0 +1,63 @@ + + + + + Debug + x64 + + + Release + x64 + + + + {26223262-6B56-40FA-8DAC-8BCAB3AF7F3A} + + + + + + + + + Application + true + $(DefaultPlatformToolset) + true + MultiByte + + + + + + + + + + ./ + $(Platform)\$(Configuration)\ + test + + + + Level3 + MaxSpeed + Disabled + true + true + true + extra_global_param_ref_CODE;$(GTEST_DIR);$(GTEST_DIR)/include + _SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions) + + + true + true + true + runner_Release.lib;%(AdditionalDependencies) + runner_Debug.lib;%(AdditionalDependencies) + + + + + + diff --git a/tests/features/extra_global_param_ref/model.cc b/tests/features/extra_global_param_ref/model.cc new file mode 100644 index 0000000000..a697d5a05f --- /dev/null +++ b/tests/features/extra_global_param_ref/model.cc @@ -0,0 +1,68 @@ +//-------------------------------------------------------------------------- +/*! \file extra_global_param_ref/model.cc + +\brief model definition file that is part of the feature testing +suite of minimal models with known analytic outcomes that are used for continuous integration testing. 
+*/ +//-------------------------------------------------------------------------- + + +#include "modelSpec.h" + +//---------------------------------------------------------------------------- +// Neuron +//---------------------------------------------------------------------------- +class Neuron : public NeuronModels::Base +{ +public: + DECLARE_MODEL(Neuron, 0, 1); + + SET_SIM_CODE("$(x) = $(e)[$(id)];\n"); + + SET_VARS({{"x", "scalar"}}); + SET_EXTRA_GLOBAL_PARAMS({{"e", "scalar*"}}); +}; +IMPLEMENT_MODEL(Neuron); + +//---------------------------------------------------------------------------- +// CU +//---------------------------------------------------------------------------- +class CU : public CustomUpdateModels::Base +{ +public: + DECLARE_CUSTOM_UPDATE_MODEL_EGP_REF(CU, 0, 0, 1, 1); + SET_UPDATE_CODE("if($(id) == (int)round(fmod($(t), 10.0))) {\n" + " $(e)[$(id)] = 1.0;\n" + "}\n" + "else {\n" + " $(e)[$(id)] = 0.0;\n" + "}"); + SET_VAR_REFS({{"v", "scalar"}}) + SET_EXTRA_GLOBAL_PARAM_REFS({{"e", "scalar*"}}); +}; +IMPLEMENT_MODEL(CU); + + +void modelDefinition(ModelSpec &model) +{ +#ifdef CL_HPP_TARGET_OPENCL_VERSION + if(std::getenv("OPENCL_DEVICE") != nullptr) { + GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; + GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); + } + if(std::getenv("OPENCL_PLATFORM") != nullptr) { + GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); + } +#endif + model.setDT(0.1); + model.setName("extra_global_param_ref"); + + auto *pop = model.addNeuronPopulation("pop", 10, {}, Neuron::VarValues(0.0)); + + CU::VarReferences cuVarRefs(createVarRef(pop, "x")); + CU::EGPReferences cuEGPRefs(createEGPRef(pop, "e")); + model.addCustomUpdate("CU", "CustomUpdate", + CU::ParamValues{}, CU::VarValues{}, cuVarRefs, cuEGPRefs); + + model.setPrecision(GENN_FLOAT); +} diff --git a/tests/features/extra_global_param_ref/runner_guid.txt b/tests/features/extra_global_param_ref/runner_guid.txt new file mode 100644 index 0000000000..4b3c78edc8 --- /dev/null +++ b/tests/features/extra_global_param_ref/runner_guid.txt @@ -0,0 +1 @@ +9062DB83-2CB8-4D2E-BF12-0FEF606D9571 diff --git a/tests/features/extra_global_param_ref/test.cc b/tests/features/extra_global_param_ref/test.cc new file mode 100644 index 0000000000..bf4ce6cdc9 --- /dev/null +++ b/tests/features/extra_global_param_ref/test.cc @@ -0,0 +1,43 @@ +//-------------------------------------------------------------------------- +/*! \file extra_global_param_ref/test.cc + +\brief Main test code that is part of the feature testing +suite of minimal models with known analytic outcomes that are used for continuous integration testing. 
+*/ +//-------------------------------------------------------------------------- +// Standard C include +#include + +// Google test includes +#include "gtest/gtest.h" + +// Auto-generated simulation code includes +#include "extra_global_param_ref_CODE/definitions.h" + +// **NOTE** base-class for simulation tests must be +// included after auto-generated globals are included +#include "../../utils/simulation_test.h" + + + +TEST_F(SimulationTest, ExtraGlobalParamRef) +{ + allocateepop(10); + scalar error = 0.0; + for(int i = 0; i < 100; i++) { + updateCustomUpdate(); + StepGeNN(); + + for(int j = 0; j < 10; j++) { + if(j == (int)round(i * DT)) { + error += fabs(xpop[j] - 1.0f); + } + else { + error += fabs(xpop[j]); + } + } + } + + // Check total error is less than some tolerance + EXPECT_LT(error, 1e-6); +} From 9e987a22e707a05e03924215faeeab2a33c54856 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 30 May 2023 17:07:29 +0100 Subject: [PATCH 424/725] Linux makefile --- tests/features/extra_global_param_ref/Makefile | 1 + 1 file changed, 1 insertion(+) create mode 120000 tests/features/extra_global_param_ref/Makefile diff --git a/tests/features/extra_global_param_ref/Makefile b/tests/features/extra_global_param_ref/Makefile new file mode 120000 index 0000000000..1302b13ca5 --- /dev/null +++ b/tests/features/extra_global_param_ref/Makefile @@ -0,0 +1 @@ +../../utils/Makefile \ No newline at end of file From c677ef57ba368163829138b7be7175edb83db1b8 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 18:07:09 +0100 Subject: [PATCH 425/725] PyGeNN integration # Conflicts: # pygenn/genn_groups.py # pygenn/genn_model.py # pygenn/model_preprocessor.py --- pygenn/genn_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index 5a76b8fa09..7764cc4d13 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -469,7 +469,7 @@ class derived from return c_source def add_custom_update(self, cu_name, group_name, custom_update_model, - param_space, var_space, var_ref_space): + param_space, var_space, var_ref_space, egp_ref_space={}): """Add a custom update to the GeNN model Args: cu_name -- name of the new custom update group_name -- name of custom update group this update belongs to custom_update_model -- type of the CustomUpdateModel class as string or instance of CustomUpdateModel class derived from CustomUpdateModel class var_ref_space -- dict with variable references for the CustomUpdateModel class + egp_ref_space -- dict with extra global parameter references + for the CustomUpdateModel class """ if self._built: raise Exception("GeNN model already built") @@ -501,7 +503,7 @@ class derived from # Use superclass to add population c_update = super(GeNNModel, self).add_custom_update( cu_name, group_name, custom_update_model, - param_space, var_init, var_ref_space) + param_space, var_init, var_ref_space, egp_ref_space) # Setup back-reference, store group in dictionary and return c_update._init_group(self, var_space) From d750ca299444e8fc3f5038f1737900ebaf16f5b0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 18:09:13 +0100 Subject: [PATCH 426/725] fixed pygenn typos --- pygenn/genn_model.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index 7764cc4d13..e9949dca86 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -61,7 +61,7 @@ from .genn import (generate_code, init_logging, CurrentSource, CurrentSourceModelBase, CustomUpdate, CustomUpdateModelBase, CustomUpdateWU, DerivedParam, - EGP, InitSparseConnectivitySnippetBase, + EGP, EGPRef, InitSparseConnectivitySnippetBase,
InitToeplitzConnectivitySnippetBase, InitVarSnippetBase, ModelSpecInternal, NeuronGroup, NeuronModelBase, ParamVal, PlogSeverity, PostsynapticModelBase, @@ -1107,7 +1107,8 @@ def create_custom_custom_update_class(class_name, param_names=None, derived_params=None, var_refs=None, update_code=None, - extra_global_params=None): + extra_global_params=None, + extra_global_param_refs=None): """This helper function creates a custom CustomUpdate class. See also: create_custom_neuron_class @@ -1117,7 +1118,7 @@ def create_custom_custom_update_class(class_name, param_names=None, create_custom_sparse_connect_init_snippet_class Args: - class_name -- name of the new class + class_name -- name of the new class Keyword args: param_names -- list of strings with param names of the model @@ -1131,6 +1132,9 @@ def create_custom_custom_update_class(class_name, param_names=None, update_code -- string with the custom update code extra_global_params -- list of pairs of strings with names and types of additional parameters + extra_global_param_refs -- list of pairs of strings with names and types of + extra global parameter references + """ body = {} @@ -1140,6 +1144,10 @@ def create_custom_custom_update_class(class_name, param_names=None, if var_refs is not None: body["get_var_refs"] = lambda self: [VarRef(*v) for v in var_refs] + if extra_global_param_refs is not None: + body["get_extra_global_param_refs"] =\ + lambda self: [EGPRef(*e) for e in extra_global_param_refs] + return create_custom_model_class( class_name, CustomUpdateModelBase, param_names, var_name_types, derived_params, extra_global_params, body) From 92e195e828737863d9ad5163b208989145e0b07c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 18 Jul 2023 18:11:23 +0100 Subject: [PATCH 427/725] GCC fix --- src/genn/genn/models.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/models.cc b/src/genn/genn/models.cc index fea7767f21..812b02a2b6 100644 --- a/src/genn/genn/models.cc +++ b/src/genn/genn/models.cc @@ -284,7 +284,7 @@ void updateHash(const WUVarReference &v, boost::uuids::detail::sha1 &hash) } } //---------------------------------------------------------------------------- -void Models::updateHash(const EGPReference &v, boost::uuids::detail::sha1 &hash) +void updateHash(const EGPReference &v, boost::uuids::detail::sha1 &hash) { Utils::updateHash(v.getTargetName(), hash); Utils::updateHash(v.getEGPIndex(), hash); } From 817d745e92ca142e44c2405b1fdc06468e4d4499 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 19 Jul 2023 13:48:05 +0100 Subject: [PATCH 428/725] fixed a few more warnings --- src/genn/genn/transpiler/prettyPrinter.cc | 4 ++-- src/genn/genn/transpiler/typeChecker.cc | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index 65bd572130..1ece7b70f1 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -85,7 +85,7 @@ class EnvironmentCallArgument : public EnvironmentBase //--------------------------------------------------------------------------- // EnvironmentBase virtuals //--------------------------------------------------------------------------- - virtual std::string define(const std::string &name) final + virtual std::string define(const std::string&) final { throw std::runtime_error("Cannot declare variable in call environment"); } @@ -526,4 +526,4 @@ void GeNN::Transpiler::PrettyPrinter::print(const
Expression::ExpressionPtr &exp { EnvironmentInternal internalEnvironment(environment); Visitor visitor(expression, internalEnvironment, context, resolvedTypes); -} \ No newline at end of file +} diff --git a/src/genn/genn/transpiler/typeChecker.cc b/src/genn/genn/transpiler/typeChecker.cc index d84eb24ee4..3f3b99cca4 100644 --- a/src/genn/genn/transpiler/typeChecker.cc +++ b/src/genn/genn/transpiler/typeChecker.cc @@ -55,7 +55,7 @@ bool checkForConstRemoval(const Type::ResolvedType &rightType, const Type::Resol return std::visit( Utils::Overload{ // If both are value types - [](const Type::ResolvedType::Value &rightValue, const Type::ResolvedType::Value &leftValue) + [](const Type::ResolvedType::Value&, const Type::ResolvedType::Value&) { return true; }, @@ -113,7 +113,7 @@ bool checkImplicitConversion(const Type::ResolvedType &rightType, const Type::Re } }, // Otherwise, if left is pointer and right is numeric, - [op](const Type::ResolvedType::Value &rightValue, const Type::ResolvedType::Pointer &leftPointer) + [op](const Type::ResolvedType::Value &rightValue, const Type::ResolvedType::Pointer&) { assert(rightValue.numeric); if (op == Token::Type::PLUS_EQUAL || op == Token::Type::MINUS_EQUAL) { @@ -293,8 +293,9 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } }, // Otherwise, if both operands are pointers + // **TODO** don't pointer types need to be the same? [&binary, &leftType, &rightType, opType, this] - (const Type::ResolvedType::Pointer &rightPointer, const Type::ResolvedType::Pointer &leftPointer) -> std::optional + (const Type::ResolvedType::Pointer&, const Type::ResolvedType::Pointer&) -> std::optional { // If operator is minus and pointer types match if (opType == Token::Type::MINUS && leftType == rightType) { @@ -307,7 +308,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor }, // Otherwise, if right is numeric and left is pointer [&binary, &leftType, &rightType, opType, this] - (const Type::ResolvedType::Value &rightValue, const Type::ResolvedType::Pointer &leftPointer) -> std::optional + (const Type::ResolvedType::Value &rightValue, const Type::ResolvedType::Pointer&) -> std::optional { // If operator is valid and numeric type is integer // P + n or P - n @@ -320,7 +321,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor }, // Otherwise, if right is pointer and left is numeric [&binary, &rightType, opType, this] - (const Type::ResolvedType::Pointer &rightPointer, const Type::ResolvedType::Value &leftValue) -> std::optional + (const Type::ResolvedType::Pointer&, const Type::ResolvedType::Value &leftValue) -> std::optional { // n + P if (opType == Token::Type::PLUS && leftValue.numeric->isIntegral) { @@ -543,7 +544,7 @@ class Visitor : public Expression::Visitor, public Statement::Visitor const auto argConversionRank = std::visit( Utils::Overload{ // If types are numeric, any cast goes - [c, a](const Type::ResolvedType::Value &cValue, const Type::ResolvedType::Value &aValue) -> std::optional + [c, a](const Type::ResolvedType::Value &cValue, const Type::ResolvedType::Value&) -> std::optional { // If names are identical, match is exact // **TODO** we don't care about qualifiers From 5bf2aae8539d94193a97d900433896c082c01b85 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 19 Jul 2023 16:30:24 +0100 Subject: [PATCH 429/725] start fixing up PyGeNN * Updated enums and function signatures * Included support for EGP * Started trying to figure out a way of exposing types (opaquely) --- 
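[Editor's note, not part of this commit: the diff below collapses the old *_GLOBALG/*_INDIVIDUALG SynapseMatrixType values and binds __and__ operators so Python code can decompose the new ones. A hedged C++ sketch of the decomposition those bindings wrap, assuming the operator& overloads declared in GeNN's synapseMatrixType.h:]

#include "synapseMatrixType.h"

// SynapseMatrixType combines a connectivity and a weight bitmask, so masked
// ANDs recover each component of a collapsed value such as SPARSE
bool usesSparseConnectivity(GeNN::SynapseMatrixType type)
{
    return (type & GeNN::SynapseMatrixConnectivity::SPARSE);
}

bool usesIndividualWeights(GeNN::SynapseMatrixType type)
{
    return (type & GeNN::SynapseMatrixWeight::INDIVIDUAL);
}
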
include/genn/genn/modelSpec.h | 16 ------- pygenn/__init__.py | 7 +-- pygenn/genn_model.py | 10 ++-- pygenn/src/genn.cc | 87 +++++++++++++++++------------------ pygenn/src/type.cc | 31 +++++++++++++ setup.py | 3 ++ 6 files changed, 85 insertions(+), 69 deletions(-) create mode 100644 pygenn/src/type.cc diff --git a/include/genn/genn/modelSpec.h b/include/genn/genn/modelSpec.h index be80b84ced..f3db7d372c 100644 --- a/include/genn/genn/modelSpec.h +++ b/include/genn/genn/modelSpec.h @@ -49,22 +49,6 @@ using VarReferences = std::unordered_map; using WUVarReferences = std::unordered_map; using EGPReferences = std::unordered_map; -//! Floating point precision to use for "scalar" type variables models -enum class ScalarPrecision -{ - FLOAT, - DOUBLE, - LONG_DOUBLE, -}; - -//! Precision to use for variables which store time -enum class TimePrecision -{ - DEFAULT, //!< Time uses default model precision - FLOAT, //!< Time uses single precision - not suitable for long simulations - DOUBLE, //!< Time uses double precision - may reduce performance -}; - //! Initialise a variable using an initialisation snippet /*! \tparam S type of variable initialisation snippet (derived from InitVarSnippet::Base). \param params parameters for snippet wrapped in ParamValues object. diff --git a/pygenn/__init__.py b/pygenn/__init__.py index 058023d38c..0b19b543ff 100644 --- a/pygenn/__init__.py +++ b/pygenn/__init__.py @@ -3,9 +3,10 @@ # pygenn interface from .genn import (create_var_ref, create_psm_var_ref, create_wu_pre_var_ref, - create_wu_post_var_ref, create_wu_var_ref, PlogSeverity, - ScalarPrecision, SpanType, SynapseMatrixType, TimePrecision, - VarAccess, VarAccessMode, VarLocation) + create_wu_post_var_ref, create_wu_var_ref, create_egp_ref, + create_psm_egp_ref, create_wu_egp_ref, PlogSeverity, + SpanType, SynapseMatrixType, VarAccess, + VarAccessMode, VarLocation) from .genn_model import (GeNNModel, init_sparse_connectivity, init_toeplitz_connectivity, init_var) diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index e9949dca86..2d11bfa752 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -65,9 +65,9 @@ InitToeplitzConnectivitySnippetBase, InitVarSnippetBase, ModelSpecInternal, NeuronGroup, NeuronModelBase, ParamVal, PlogSeverity, PostsynapticModelBase, - ScalarPrecision, SparseConnectivityInit, SynapseGroup, - SynapseMatrixType, TimePrecision, ToeplitzConnectivityInit, - Var, VarInit, VarLocation, VarRef, WeightUpdateModelBase) + SparseConnectivityInit, SynapseGroup, SynapseMatrixType, + ToeplitzConnectivityInit, Var, VarInit, VarLocation, + VarRef, WeightUpdateModelBase) from .shared_library_model import (SharedLibraryModelDouble, SharedLibraryModelFloat) @@ -128,8 +128,6 @@ backend_modules[b] = m -GeNNType = namedtuple("GeNNType", ["np_dtype", "assign_ext_ptr_array", "assign_ext_ptr_single"]) - class GeNNModel(ModelSpecInternal): """GeNNModel class This class helps to define, build and run a GeNN model from python @@ -147,7 +145,7 @@ def __init__(self, precision="float", model_name="GeNNModel", or "long double"). defaults to float. model_name -- string name of the model. Defaults to "GeNNModel". backend -- string specifying name of backend module to use - Defaults to None to pick 'best' backend for your system + Defaults to one to pick 'best' backend for your system time_precision -- string time precision as string ("float", "double" or "long double"). defaults to float. 
genn_log_level -- Log level for GeNN diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc index 7615690dce..3f4daa791d 100644 --- a/pygenn/src/genn.cc +++ b/pygenn/src/genn.cc @@ -69,9 +69,7 @@ class PyInitSparseConnectivitySnippetBase : public PySnippet; - Logging::init(gennLevel, codeGeneratorLevel, consoleAppender, consoleAppender); + Logging::init(gennLevel, codeGeneratorLevel, transpilerLevel, + consoleAppender, consoleAppender, consoleAppender); } } @@ -242,22 +241,18 @@ PYBIND11_MODULE(genn, m) .value("TOEPLITZ", SynapseMatrixConnectivity::TOEPLITZ); pybind11::enum_(m, "SynapseMatrixWeight") - .value("GLOBAL", SynapseMatrixWeight::GLOBAL) .value("INDIVIDUAL", SynapseMatrixWeight::INDIVIDUAL) .value("PROCEDURAL", SynapseMatrixWeight::PROCEDURAL) .value("KERNEL", SynapseMatrixWeight::KERNEL); pybind11::enum_(m, "SynapseMatrixType") - .value("DENSE_GLOBALG", SynapseMatrixType::DENSE_GLOBALG) - .value("DENSE_INDIVIDUALG", SynapseMatrixType::DENSE_INDIVIDUALG) + .value("DENSE", SynapseMatrixType::DENSE) .value("DENSE_PROCEDURALG", SynapseMatrixType::DENSE_PROCEDURALG) - .value("BITMASK_GLOBALG", SynapseMatrixType::BITMASK_GLOBALG) - .value("SPARSE_GLOBALG", SynapseMatrixType::SPARSE_GLOBALG) - .value("SPARSE_INDIVIDUALG", SynapseMatrixType::SPARSE_INDIVIDUALG) - .value("PROCEDURAL_GLOBALG", SynapseMatrixType::PROCEDURAL_GLOBALG) - .value("PROCEDURAL_PROCEDURALG", SynapseMatrixType::PROCEDURAL_PROCEDURALG) + .value("BITMASK", SynapseMatrixType::BITMASK) + .value("SPARSE", SynapseMatrixType::SPARSE) + .value("PROCEDURAL", SynapseMatrixType::PROCEDURAL) .value("PROCEDURAL_KERNELG", SynapseMatrixType::PROCEDURAL_KERNELG) - .value("TOEPLITZ_KERNELG", SynapseMatrixType::TOEPLITZ_KERNELG) + .value("TOEPLITZ", SynapseMatrixType::TOEPLITZ) .def("__and__", [](SynapseMatrixType a, SynapseMatrixConnectivity b){ return a & b; }, pybind11::is_operator()) @@ -284,15 +279,19 @@ PYBIND11_MODULE(genn, m) //! Flags defining how variables should be duplicated across multiple batches pybind11::enum_(m, "VarAccessDuplication") .value("DUPLICATE", VarAccessDuplication::DUPLICATE) - .value("SHARED", VarAccessDuplication::SHARED); + .value("SHARED", VarAccessDuplication::SHARED) + .value("SHARED_NEURON", VarAccessDuplication::SHARED_NEURON); //! Supported combinations of VarAccessMode and VarAccessDuplication pybind11::enum_(m, "VarAccess") .value("READ_WRITE", VarAccess::READ_WRITE) .value("READ_ONLY", VarAccess::READ_ONLY) + .value("READ_ONLY_SHARED_NEURON", VarAccess::READ_ONLY_SHARED_NEURON) .value("READ_ONLY_DUPLICATE", VarAccess::READ_ONLY_DUPLICATE) .value("REDUCE_BATCH_SUM", VarAccess::REDUCE_BATCH_SUM) .value("REDUCE_BATCH_MAX", VarAccess::REDUCE_BATCH_MAX) + .value("REDUCE_NEURON_SUM", VarAccess::REDUCE_NEURON_SUM) + .value("REDUCE_NEURON_MAX", VarAccess::REDUCE_NEURON_MAX) .def("__and__", [](VarAccess a, VarAccessModeAttribute b){ return a & b; }, pybind11::is_operator()) @@ -316,33 +315,27 @@ PYBIND11_MODULE(genn, m) pybind11::enum_(m, "SpanType") .value("POSTSYNAPTIC", SynapseGroup::SpanType::POSTSYNAPTIC) .value("PRESYNAPTIC", SynapseGroup::SpanType::PRESYNAPTIC); - - //! Precision to use for scalar type variables - pybind11::enum_(m, "ScalarPrecision") - .value("FLOAT", ScalarPrecision::FLOAT) - .value("DOUBLE", ScalarPrecision::DOUBLE) - .value("LONG_DOUBLE", ScalarPrecision::LONG_DOUBLE); - - //! 
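
The pybind11::is_operator() lambdas above exist so Python can query these flag-style enums by masking, exactly as the C++ side does. A sketch of the intended query, assuming the C++ operator& overloads declared in synapseMatrixType.h and varAccess.h:

// C++: operator& masks out the connectivity bits of a combined matrix type
const auto conn = SynapseMatrixType::SPARSE & SynapseMatrixConnectivity::SPARSE;

// Python equivalent, via the __and__ binding registered above:
//   (SynapseMatrixType.SPARSE & SynapseMatrixConnectivity.SPARSE) != 0
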
Precision to use for variables which store time - pybind11::enum_(m, "TimePrecision") - .value("DEFAULT", TimePrecision::DEFAULT) - .value("FLOAT", TimePrecision::FLOAT) - .value("DOUBLE", TimePrecision::DOUBLE); - + //------------------------------------------------------------------------ // Free functions //------------------------------------------------------------------------ m.def("generate_code", &generateCode, pybind11::return_value_policy::move); m.def("init_logging", &initLogging); - m.def("create_var_ref", pybind11::overload_cast(&createVarRef), pybind11::return_value_policy::move); - m.def("create_var_ref", pybind11::overload_cast(&createVarRef), pybind11::return_value_policy::move); - m.def("create_var_ref", pybind11::overload_cast(&createVarRef), pybind11::return_value_policy::move); + m.def("create_var_ref", pybind11::overload_cast(&createVarRef), pybind11::return_value_policy::move); + m.def("create_var_ref", pybind11::overload_cast(&createVarRef), pybind11::return_value_policy::move); + m.def("create_var_ref", pybind11::overload_cast(&createVarRef), pybind11::return_value_policy::move); m.def("create_psm_var_ref", &createPSMVarRef, pybind11::return_value_policy::move); m.def("create_wu_pre_var_ref", &createWUPreVarRef, pybind11::return_value_policy::move); m.def("create_wu_post_var_ref", &createWUPostVarRef, pybind11::return_value_policy::move); - m.def("create_wu_var_ref", pybind11::overload_cast(&createWUVarRef), + m.def("create_wu_var_ref", pybind11::overload_cast(&createWUVarRef), "sg"_a, "var_name"_a, "transpose_sg"_a = nullptr, "transpose_var_name"_a = "", pybind11::return_value_policy::move); - m.def("create_wu_var_ref", pybind11::overload_cast(&createWUVarRef), pybind11::return_value_policy::move); + m.def("create_wu_var_ref", pybind11::overload_cast(&createWUVarRef), pybind11::return_value_policy::move); + m.def("create_egp_ref", pybind11::overload_cast(&createEGPRef), pybind11::return_value_policy::move); + m.def("create_egp_ref", pybind11::overload_cast(&createEGPRef), pybind11::return_value_policy::move); + m.def("create_egp_ref", pybind11::overload_cast(&createEGPRef), pybind11::return_value_policy::move); + m.def("create_egp_ref", pybind11::overload_cast(&createEGPRef), pybind11::return_value_policy::move); + m.def("create_psm_egp_ref", pybind11::overload_cast(&createPSMEGPRef), pybind11::return_value_policy::move); + m.def("create_wu_egp_ref", pybind11::overload_cast(&createWUEGPRef), pybind11::return_value_policy::move); //------------------------------------------------------------------------ // genn.ModelSpec @@ -380,12 +373,12 @@ PYBIND11_MODULE(genn, m) .def("add_custom_update", static_cast(&ModelSpecInternal::addCustomUpdate), + const ParamValues&, const VarValues&, const VarReferences&, const EGPReferences&)>(&ModelSpecInternal::addCustomUpdate), pybind11::return_value_policy::reference) .def("add_custom_update", static_cast(&ModelSpecInternal::addCustomUpdate), + const ParamValues&, const VarValues&, const WUVarReferences&, const EGPReferences&)>(&ModelSpecInternal::addCustomUpdate), pybind11::return_value_policy::reference) .def("add_neuron_population", static_cast(&ModelSpecInternal::addSynapsePopulation), pybind11::return_value_policy::reference) - .def("finalize", &ModelSpecInternal::finalize); + .def("finalise", &ModelSpecInternal::finalise); //------------------------------------------------------------------------ // genn.CurrentSource @@ -573,14 +566,12 @@ PYBIND11_MODULE(genn, m) pybind11::class_(m, "InitSparseConnectivitySnippetBase") 
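
Almost every m.def above leans on pybind11::overload_cast, which selects between same-named C++ overloads by argument list so they can share one Python name. A free-standing sketch of the idiom (scale is illustrative, not a GeNN function):

#include <pybind11/pybind11.h>

int scale(int x) { return x * 2; }
double scale(double x) { return x * 2.0; }

PYBIND11_MODULE(example, m)
{
    // Both C++ overloads are bound under a single Python name;
    // pybind11 dispatches on the runtime argument type
    m.def("scale", pybind11::overload_cast<int>(&scale));
    m.def("scale", pybind11::overload_cast<double>(&scale));
}
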
.def(pybind11::init<>()) .def("get_row_build_code", &InitSparseConnectivitySnippet::Base::getRowBuildCode) - .def("get_row_build_state_vars", &InitSparseConnectivitySnippet::Base::getRowBuildStateVars) .def("get_col_build_code", &InitSparseConnectivitySnippet::Base::getColBuildCode) - .def("get_col_build_state_vars", &InitSparseConnectivitySnippet::Base::getColBuildStateVars) .def("get_host_init_code", &InitSparseConnectivitySnippet::Base::getHostInitCode) .def("get_calc_max_row_length_func", &InitSparseConnectivitySnippet::Base::getCalcMaxRowLengthFunc) .def("get_calc_max_col_length_func", &InitSparseConnectivitySnippet::Base::getCalcMaxColLengthFunc) .def("get_calc_kernel_size_func", &InitSparseConnectivitySnippet::Base::getCalcKernelSizeFunc); - + //------------------------------------------------------------------------ // genn.InitToeplitzConnectivitySnippetBase //------------------------------------------------------------------------ @@ -590,7 +581,7 @@ PYBIND11_MODULE(genn, m) .def("get_diagonal_build_state_vars", &InitToeplitzConnectivitySnippet::Base::getDiagonalBuildStateVars) .def("get_calc_max_row_length_func", &InitToeplitzConnectivitySnippet::Base::getCalcMaxRowLengthFunc) .def("get_calc_kernel_size_func", &InitToeplitzConnectivitySnippet::Base::getCalcKernelSizeFunc); - + //------------------------------------------------------------------------ // genn.InitVarSnippetBaseBase //------------------------------------------------------------------------ @@ -598,7 +589,7 @@ PYBIND11_MODULE(genn, m) .def(pybind11::init<>()) .def("get_code", &InitVarSnippet::Base::getCode); - + //------------------------------------------------------------------------ // genn.Var //------------------------------------------------------------------------ @@ -608,7 +599,7 @@ PYBIND11_MODULE(genn, m) .def_readonly("name", &Models::Base::Var::name) .def_readonly("type", &Models::Base::Var::type) .def_readonly("access", &Models::Base::Var::access); - + //------------------------------------------------------------------------ // genn.VarRef //------------------------------------------------------------------------ @@ -618,7 +609,15 @@ PYBIND11_MODULE(genn, m) .def_readonly("name", &Models::Base::VarRef::name) .def_readonly("type", &Models::Base::VarRef::type) .def_readonly("access", &Models::Base::VarRef::access); - + + //------------------------------------------------------------------------ + // genn.EGPRef + //------------------------------------------------------------------------ + pybind11::class_(m, "EGPRef") + .def(pybind11::init()) + .def_readonly("name", &Models::Base::EGPRef::name) + .def_readonly("type", &Models::Base::EGPRef::type); + //------------------------------------------------------------------------ // genn.ModelBase //------------------------------------------------------------------------ @@ -709,10 +708,10 @@ PYBIND11_MODULE(genn, m) //------------------------------------------------------------------------ // genn.VarInit //------------------------------------------------------------------------ - pybind11::class_(m, "VarInit") + pybind11::class_(m, "VarInit") .def(pybind11::init&>()) .def(pybind11::init()) - .def_property_readonly("snippet", &Models::VarInit::getSnippet, pybind11::return_value_policy::reference); + .def_property_readonly("snippet", &InitVarSnippet::Init::getSnippet, pybind11::return_value_policy::reference); //------------------------------------------------------------------------ // genn.WUVarReference diff --git a/pygenn/src/type.cc b/pygenn/src/type.cc 
new file mode 100644 index 0000000000..19e627e1d0 --- /dev/null +++ b/pygenn/src/type.cc @@ -0,0 +1,31 @@ +// PyBind11 includes +#include + +// GeNN includes +#include "type.h" + +using namespace GeNN::Type; + +//---------------------------------------------------------------------------- +// type +//---------------------------------------------------------------------------- +PYBIND11_MODULE(type, m) +{ + //------------------------------------------------------------------------ + // Attributes + //------------------------------------------------------------------------ + m.attr("Bool") = pybind11::cast(Bool); +} +/*Bool = CREATE_NUMERIC(bool, 0, ""); +inline static const ResolvedType Int8 = CREATE_NUMERIC(int8_t, 10, ""); +inline static const ResolvedType Int16 = CREATE_NUMERIC(int16_t, 20, ""); +inline static const ResolvedType Int32 = CREATE_NUMERIC(int32_t, 30, ""); +inline static const ResolvedType Int64 = CREATE_NUMERIC(int64_t, 40, ""); + +inline static const ResolvedType Uint8 = CREATE_NUMERIC(uint8_t, 10, "u"); +inline static const ResolvedType Uint16 = CREATE_NUMERIC(uint16_t, 20, "u"); +inline static const ResolvedType Uint32 = CREATE_NUMERIC(uint32_t, 30, "u"); +inline static const ResolvedType Uint64 = CREATE_NUMERIC(uint64_t, 40, "u"); + +inline static const ResolvedType Float = CREATE_NUMERIC(float, 50, "f"); +inline static const ResolvedType Double*/ diff --git a/setup.py b/setup.py index 30a49d7ab2..4f31a875a0 100644 --- a/setup.py +++ b/setup.py @@ -128,6 +128,9 @@ Pybind11Extension("genn", [os.path.join(pygenn_src, "genn.cc")], **genn_extension_kwargs), + Pybind11Extension("type", + [os.path.join(pygenn_src, "type.cc")], + **genn_extension_kwargs), Pybind11Extension("init_sparse_connectivity_snippets", [os.path.join(pygenn_src, "initSparseConnectivitySnippets.cc")], **genn_extension_kwargs), From e351dc45d369ef758d221349bb8e003412a97011 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 20 Jul 2023 09:28:08 +0100 Subject: [PATCH 430/725] fixed Windows DLL linker errors --- include/genn/genn/code_generator/environment.h | 5 +++-- include/genn/genn/code_generator/lazyString.h | 5 ++++- include/genn/genn/code_generator/modelSpecMerged.h | 2 ++ include/genn/genn/code_generator/standardLibrary.h | 7 +++++-- include/genn/genn/initSparseConnectivitySnippet.h | 2 +- include/genn/genn/transpiler/prettyPrinter.h | 3 ++- include/genn/genn/transpiler/token.h | 6 +++--- include/genn/genn/transpiler/typeChecker.h | 3 ++- include/genn/genn/type.h | 4 ++-- 9 files changed, 24 insertions(+), 13 deletions(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 9bc5cfaaf6..bd312d3662 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -6,6 +6,7 @@ #include // GeNN includes +#include "gennExport.h" #include "gennUtils.h" #include "varAccess.h" #include "type.h" @@ -33,7 +34,7 @@ struct Token; //! Base class for external environments i.e. 
those defines OUTSIDE of transpiled code by code generator namespace GeNN::CodeGenerator { -class EnvironmentExternalBase : public Transpiler::PrettyPrinter::EnvironmentBase, public Transpiler::TypeChecker::EnvironmentBase +class GENN_EXPORT EnvironmentExternalBase : public Transpiler::PrettyPrinter::EnvironmentBase, public Transpiler::TypeChecker::EnvironmentBase { public: explicit EnvironmentExternalBase(EnvironmentExternalBase &enclosing) @@ -88,7 +89,7 @@ class EnvironmentExternalBase : public Transpiler::PrettyPrinter::EnvironmentBas //---------------------------------------------------------------------------- // GeNN::CodeGenerator::EnvironmentLibrary //---------------------------------------------------------------------------- -class EnvironmentLibrary : public EnvironmentExternalBase +class GENN_EXPORT EnvironmentLibrary : public EnvironmentExternalBase { public: using Library = std::unordered_multimap>; diff --git a/include/genn/genn/code_generator/lazyString.h b/include/genn/genn/code_generator/lazyString.h index f826491f23..653ebc820b 100644 --- a/include/genn/genn/code_generator/lazyString.h +++ b/include/genn/genn/code_generator/lazyString.h @@ -5,6 +5,9 @@ #include #include +// GeNN includes +#include "gennExport.h" + // Forward declarations namespace GeNN::CodeGenerator { @@ -17,7 +20,7 @@ class EnvironmentExternalBase; //! Lazily-evaluated string class - constructed from a format string containing $(XX) references to variables in environment namespace GeNN::CodeGenerator { -class LazyString +class GENN_EXPORT LazyString { public: LazyString(const std::string &format, EnvironmentExternalBase &env); diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 0be08cc808..757cbdf20d 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -38,6 +38,8 @@ class GENN_EXPORT ModelSpecMerged m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode"), m_TypeContext{{"scalar", model.getPrecision()}, {"timepoint", model.getTimePrecision()}} { } + ModelSpecMerged(const ModelSpecMerged&) = delete; + ModelSpecMerged &operator=(const ModelSpecMerged &) = delete; //-------------------------------------------------------------------------- // CodeGenerator::ModelSpecMerged::EGPField diff --git a/include/genn/genn/code_generator/standardLibrary.h b/include/genn/genn/code_generator/standardLibrary.h index f1f6cfd7c9..4f94be1dc2 100644 --- a/include/genn/genn/code_generator/standardLibrary.h +++ b/include/genn/genn/code_generator/standardLibrary.h @@ -1,5 +1,8 @@ #pragma once +// GeNN includes +#include "gennExport.h" + // Code generator includes #include "code_generator/environment.h" @@ -9,8 +12,8 @@ namespace GeNN::CodeGenerator::StandardLibrary { //! Get standard maths functions -const EnvironmentLibrary::Library &getMathsFunctions(); +GENN_EXPORT const EnvironmentLibrary::Library &getMathsFunctions(); //! 
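
All of these additions are instances of one pattern: any class or free function that crosses the genn DLL boundary needs GENN_EXPORT. The macro itself lives in gennExport.h, which this patch does not touch; a sketch of the standard idiom it follows (the GENN_BUILDING and GENN_STATIC guard names here are assumptions, not GeNN's exact spelling):

#if defined(_WIN32) && !defined(GENN_STATIC)
    #ifdef GENN_BUILDING
        #define GENN_EXPORT __declspec(dllexport)   // compiling genn.dll itself
    #else
        #define GENN_EXPORT __declspec(dllimport)   // compiling a client
    #endif
#else
    #define GENN_EXPORT                             // no-op on other platforms
#endif

Omitting the macro on a symbol used by the backends or by PyGeNN is precisely what produces the unresolved-external linker errors this commit fixes.
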
Get std::random based host RNG functions -const EnvironmentLibrary::Library &getHostRNGFunctions(const Type::ResolvedType &precision); +GENN_EXPORT const EnvironmentLibrary::Library &getHostRNGFunctions(const Type::ResolvedType &precision); } // namespace GeNN::CodeGenerator::StandardLibrary diff --git a/include/genn/genn/initSparseConnectivitySnippet.h b/include/genn/genn/initSparseConnectivitySnippet.h index 96c890189f..591a61e657 100644 --- a/include/genn/genn/initSparseConnectivitySnippet.h +++ b/include/genn/genn/initSparseConnectivitySnippet.h @@ -68,7 +68,7 @@ class GENN_EXPORT Base : public Snippet::Base //---------------------------------------------------------------------------- // Init //---------------------------------------------------------------------------- -class Init : public Snippet::Init +class GENN_EXPORT Init : public Snippet::Init { public: Init(const Base *snippet, const std::unordered_map ¶ms); diff --git a/include/genn/genn/transpiler/prettyPrinter.h b/include/genn/genn/transpiler/prettyPrinter.h index bd3a2a8346..d9c470d11f 100644 --- a/include/genn/genn/transpiler/prettyPrinter.h +++ b/include/genn/genn/transpiler/prettyPrinter.h @@ -5,6 +5,7 @@ #include // GeNN includes +#include "gennExport.h" #include "type.h" // Transpiler includes @@ -22,7 +23,7 @@ class CodeStream; //--------------------------------------------------------------------------- namespace GeNN::Transpiler::PrettyPrinter { -class EnvironmentBase +class GENN_EXPORT EnvironmentBase { public: //------------------------------------------------------------------------ diff --git a/include/genn/genn/transpiler/token.h b/include/genn/genn/transpiler/token.h index 9d9e6c5c6d..923e29eaa0 100644 --- a/include/genn/genn/transpiler/token.h +++ b/include/genn/genn/transpiler/token.h @@ -54,8 +54,8 @@ struct Token { } - const Type type; - const std::string lexeme; - const size_t line; + Type type; + std::string lexeme; + size_t line; }; } // namespace GeNN::Transpiler diff --git a/include/genn/genn/transpiler/typeChecker.h b/include/genn/genn/transpiler/typeChecker.h index 80d00c1b78..748ad77188 100644 --- a/include/genn/genn/transpiler/typeChecker.h +++ b/include/genn/genn/transpiler/typeChecker.h @@ -8,6 +8,7 @@ #include // GeNN includes +#include "gennExport.h" #include "type.h" // Transpiler includes @@ -36,7 +37,7 @@ class TypeCheckError : public std::runtime_error //--------------------------------------------------------------------------- // GeNN::Transpiler::TypeChecker::EnvironmentBase //--------------------------------------------------------------------------- -class EnvironmentBase +class GENN_EXPORT EnvironmentBase { public: //------------------------------------------------------------------------ diff --git a/include/genn/genn/type.h b/include/genn/genn/type.h index 134ca0f033..b0a30d0ce5 100644 --- a/include/genn/genn/type.h +++ b/include/genn/genn/type.h @@ -47,7 +47,7 @@ inline Qualifier operator | (Qualifier a, Qualifier b) //---------------------------------------------------------------------------- // GeNN::Type::ResolvedType //---------------------------------------------------------------------------- -struct ResolvedType +struct GENN_EXPORT ResolvedType { //------------------------------------------------------------------------ // Numeric @@ -289,7 +289,7 @@ typedef std::unordered_map TypeContext; //---------------------------------------------------------------------------- // UnresolvedType //---------------------------------------------------------------------------- -struct 
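
The const-to-mutable change in Token is most likely about assignability rather than linkage: a struct with const data members gets an implicitly deleted copy assignment operator, which rules it out of many standard-container operations. A minimal sketch of the failure mode:

#include <vector>

struct ConstToken { const int type; };
struct Token      { int type; };

int main()
{
    std::vector<Token> tokens{{1}, {2}};
    tokens[0] = tokens[1];          // fine: Token is copy-assignable

    std::vector<ConstToken> ct{{1}, {2}};
    // ct[0] = ct[1];               // error: copy assignment deleted
    // ct.erase(ct.begin());        // error: erase requires assignable elements
}
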
UnresolvedType +struct GENN_EXPORT UnresolvedType { UnresolvedType(const ResolvedType &type) : detail(type) From 9c1e956bf08051692f822847e12227dab50861af Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 20 Jul 2023 09:40:42 +0100 Subject: [PATCH 431/725] fix a few more Windows linker errors --- include/genn/genn/code_generator/initGroupMerged.h | 2 +- include/genn/genn/initToeplitzConnectivitySnippet.h | 2 +- include/genn/genn/initVarSnippet.h | 2 +- include/genn/genn/models.h | 1 + include/genn/genn/snippet.h | 2 +- 5 files changed, 5 insertions(+), 4 deletions(-) diff --git a/include/genn/genn/code_generator/initGroupMerged.h b/include/genn/genn/code_generator/initGroupMerged.h index d26a74fea7..7c638e178d 100644 --- a/include/genn/genn/code_generator/initGroupMerged.h +++ b/include/genn/genn/code_generator/initGroupMerged.h @@ -15,7 +15,7 @@ namespace GeNN::CodeGenerator { template -class GENN_EXPORT InitGroupMergedBase : public B +class InitGroupMergedBase : public B { public: using B::B; diff --git a/include/genn/genn/initToeplitzConnectivitySnippet.h b/include/genn/genn/initToeplitzConnectivitySnippet.h index f23516732f..d50225d4fe 100644 --- a/include/genn/genn/initToeplitzConnectivitySnippet.h +++ b/include/genn/genn/initToeplitzConnectivitySnippet.h @@ -61,7 +61,7 @@ class GENN_EXPORT Base : public Snippet::Base //---------------------------------------------------------------------------- // Init //---------------------------------------------------------------------------- -class Init : public Snippet::Init +class GENN_EXPORT Init : public Snippet::Init { public: Init(const Base *snippet, const std::unordered_map ¶ms); diff --git a/include/genn/genn/initVarSnippet.h b/include/genn/genn/initVarSnippet.h index bb411aa21d..d2d76e2ae5 100644 --- a/include/genn/genn/initVarSnippet.h +++ b/include/genn/genn/initVarSnippet.h @@ -42,7 +42,7 @@ class GENN_EXPORT Base : public Snippet::Base //! Class used to bind together everything required to initialise a variable: //! 1. A pointer to a variable initialisation snippet //! 2. The parameters required to control the variable initialisation snippet -class Init : public Snippet::Init +class GENN_EXPORT Init : public Snippet::Init { public: Init(const Base *snippet, const std::unordered_map ¶ms); diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h index 98ae7d4a54..ef6c9dff0c 100644 --- a/include/genn/genn/models.h +++ b/include/genn/genn/models.h @@ -248,6 +248,7 @@ class GENN_EXPORT WUVarReference : public VarReferenceBase bool operator < (const WUVarReference &other) const { + //**TODO** could be expressed in terms of tuple < const bool hasTranspose = (getTransposeSynapseGroup() != nullptr); const bool otherHasTranspose = (other.getTransposeSynapseGroup() != nullptr); if (hasTranspose && otherHasTranspose) { diff --git a/include/genn/genn/snippet.h b/include/genn/genn/snippet.h index 36bf8a051b..520fbe6205 100644 --- a/include/genn/genn/snippet.h +++ b/include/genn/genn/snippet.h @@ -71,7 +71,7 @@ class GENN_EXPORT Base }; //! 
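
On the **TODO** added to WUVarReference::operator< : std::tuple already provides lexicographic comparison, so the hand-rolled logic could collapse to something like the sketch below. The accessors are placeholders for whatever key fields the comparison actually orders by (synapse group pointers and variable names):

#include <tuple>

bool operator < (const WUVarReference &other) const
{
    // make_tuple copies the key fields; tuple's operator< then compares
    // them lexicographically, element by element
    return std::make_tuple(getSynapseGroup(), getVarName()) <
           std::make_tuple(other.getSynapseGroup(), other.getVarName());
}
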
Additional input variables, row state variables and other things have a name, a type and an initial value - struct ParamVal + struct GENN_EXPORT ParamVal { ParamVal(const std::string &n, const Type::ResolvedType &t, const std::string &v); ParamVal(const std::string &n, const Type::ResolvedType &t, double v); From e1c18e50e28c99017602e84a1af931b6e96655bb Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 20 Jul 2023 09:59:07 +0100 Subject: [PATCH 432/725] ``ModelSpec::setTimePrecision`` and ``ModelSpec::setPrecision`` variants taking string --- include/genn/genn/gennUtils.h | 47 +++++++++++++---------------------- include/genn/genn/modelSpec.h | 10 ++++++-- src/genn/genn/gennUtils.cc | 22 ++++++++++++++++ src/genn/genn/modelSpec.cc | 17 +++++++++++++ src/genn/genn/type.cc | 16 +----------- 5 files changed, 65 insertions(+), 47 deletions(-) diff --git a/include/genn/genn/gennUtils.h b/include/genn/genn/gennUtils.h index 22ba4bbc98..b2f6b9d03a 100644 --- a/include/genn/genn/gennUtils.h +++ b/include/genn/genn/gennUtils.h @@ -31,43 +31,36 @@ //-------------------------------------------------------------------------- namespace GeNN::Utils { +//! Helper to scan a multi-line code string, giving meaningful errors with the specified context string GENN_EXPORT std::vector scanCode(const std::string &code, const std::string &errorContext); +//! Helper to scan a type specifier string e.g "unsigned int" and parse it into a resolved type +GENN_EXPORT Type::ResolvedType parseNumericType(const std::string &type, const Type::TypeContext &typeContext); + +//! Is this sequence of tokens empty? +/*! For ease of parsing and as an extra check that we have scanned SOMETHING, + empty token sequences should have a single EOF token */ GENN_EXPORT bool areTokensEmpty(const std::vector &tokens); -//-------------------------------------------------------------------------- -//! \brief Does the code string contain any functions requiring random number generator -//-------------------------------------------------------------------------- +//! Checks whether the sequence of token references a given identifier GENN_EXPORT bool isIdentifierReferenced(const std::string &identifierName, const std::vector &tokens); -//-------------------------------------------------------------------------- -//! \brief Does the code string contain any functions requiring random number generator -//-------------------------------------------------------------------------- +//! Checks whether the sequence of token includes an RNG function identifier GENN_EXPORT bool isRNGRequired(const std::vector &tokens); -//-------------------------------------------------------------------------- -//! \brief Does the model with the vectors of variable initialisers and modes require an RNG for the specified init location i.e. host or device -//-------------------------------------------------------------------------- +//! Checks whether any of the variable initialisers in the vector require an RNG for initialisation GENN_EXPORT bool isRNGRequired(const std::unordered_map &varInitialisers); -//-------------------------------------------------------------------------- -//! \brief Is the variable name valid? GeNN variable names must obey C variable naming rules -//-------------------------------------------------------------------------- +//! Checks variable name is valid? 
GeNN variable names must obey C variable naming rules GENN_EXPORT void validateVarName(const std::string &name, const std::string &description); -//-------------------------------------------------------------------------- -//! \brief Is the population name valid? GeNN population names obey C variable naming rules but can start with a number -//-------------------------------------------------------------------------- +//! Checks whether population name is valid? GeNN population names obey C variable naming rules but can start with a number GENN_EXPORT void validatePopName(const std::string &name, const std::string &description); -//-------------------------------------------------------------------------- -//! \brief Are all the parameter names in vector valid? GeNN variables and population names must obey C variable naming rules -//-------------------------------------------------------------------------- +//! Checks that all the parameter names in vector valid? GeNN variables and population names must obey C variable naming rules GENN_EXPORT void validateParamNames(const std::vector ¶mNames); -//-------------------------------------------------------------------------- -//! \brief Are initialisers provided for all of the the item names in the vector? -//-------------------------------------------------------------------------- +//! Checks that initialisers provided for all of the the item names in the vector? template void validateInitialisers(const std::vector &vec, const std::unordered_map &values, const std::string &type, const std::string description) @@ -86,9 +79,7 @@ void validateInitialisers(const std::vector &vec, const std::unordered_map void validateVecNames(const std::vector &vec, const std::string &description) { @@ -97,9 +88,7 @@ void validateVecNames(const std::vector &vec, const std::string &description) } } -//-------------------------------------------------------------------------- -//! \brief This function writes a floating point value to a stream -setting the precision so no digits are lost -//-------------------------------------------------------------------------- +//! Write a floating point value to a stream - setting the precision so no digits are lost template::value>::type * = nullptr> void writePreciseString(std::ostream &os, T value, int maxDigits10 = std::numeric_limits::max_digits10) { @@ -124,9 +113,7 @@ void writePreciseString(std::ostream &os, T value, int maxDigits10 = std::numeri os << std::setprecision(previousPrecision); } -//-------------------------------------------------------------------------- -//! \brief This function writes a floating point value to a string - setting the precision so no digits are lost -//-------------------------------------------------------------------------- +//! Write a floating point value to a string - setting the precision so no digits are lost template::value>::type * = nullptr> inline std::string writePreciseString(T value, int maxDigits10 = std::numeric_limits::max_digits10) { diff --git a/include/genn/genn/modelSpec.h b/include/genn/genn/modelSpec.h index f3db7d372c..4e495793b3 100644 --- a/include/genn/genn/modelSpec.h +++ b/include/genn/genn/modelSpec.h @@ -252,12 +252,18 @@ class GENN_EXPORT ModelSpec //! Method to set the neuronal network model name void setName(const std::string &name){ m_Name = name; } - //! Set numerical precision for floating point + //! Set numerical precision for scalar type void setPrecision(const Type::ResolvedType &precision); - //! Set numerical precision for time + //! 
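
The round-trip guarantee behind writePreciseString: printing a float with std::numeric_limits<T>::max_digits10 significant digits ensures the decimal text parses back to the identical value. A self-contained sketch:

#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>

int main()
{
    const float value = 0.1f;

    std::ostringstream os;
    os << std::setprecision(std::numeric_limits<float>::max_digits10) << value;

    // Prints "0.100000001": nine significant digits, enough that
    // std::stof(os.str()) == value holds exactly
    std::cout << os.str() << std::endl;
}
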
Set numerical precision for floating point + void setPrecision(const std::string &precision); + + //! Set numerical precision for time type void setTimePrecision(const Type::ResolvedType &timePrecision); + //! Set numerical precision for time type + void setTimePrecision(const std::string &timePrecision); + //! Set the integration step size of the model void setDT(double dt){ m_DT = dt; } diff --git a/src/genn/genn/gennUtils.cc b/src/genn/genn/gennUtils.cc index b75ea836e1..2cc15bc8d4 100644 --- a/src/genn/genn/gennUtils.cc +++ b/src/genn/genn/gennUtils.cc @@ -9,6 +9,7 @@ // GeNN transpiler includes #include "transpiler/errorHandler.h" +#include "transpiler/parser.h" #include "transpiler/scanner.h" //-------------------------------------------------------------------------- @@ -65,6 +66,8 @@ namespace GeNN::Utils { std::vector scanCode(const std::string &code, const std::string &errorContext) { + using namespace Transpiler; + // Upgrade code string const std::string upgradedCode = upgradeCodeString(code); @@ -77,6 +80,25 @@ std::vector scanCode(const std::string &code, const std::stri return tokens; } //-------------------------------------------------------------------------- +Type::ResolvedType parseNumericType(const std::string &type, const Type::TypeContext &typeContext) +{ + using namespace Transpiler; + + // Scan type + SingleLineErrorHandler errorHandler; + const auto tokens = Scanner::scanSource(type, errorHandler); + + // Parse type numeric type + const auto resolvedType = Parser::parseNumericType(tokens, typeContext, errorHandler); + + // If an error was encountered while scanning or parsing, throw exception + if (errorHandler.hasError()) { + throw std::runtime_error("Error parsing type '" + std::string{type} + "'"); + } + + return resolvedType; +} +//-------------------------------------------------------------------------- bool areTokensEmpty(const std::vector &tokens) { // For easy parsing, there should always be at least one token diff --git a/src/genn/genn/modelSpec.cc b/src/genn/genn/modelSpec.cc index 5392beb4c8..8adf81ca74 100644 --- a/src/genn/genn/modelSpec.cc +++ b/src/genn/genn/modelSpec.cc @@ -29,6 +29,9 @@ // GeNN code generator includes #include "code_generator/codeGenUtils.h" +// GeNN transpiler includes +#include "transpiler/parser.h" + // --------------------------------------------------------------------------- // GeNN::ModelSpec // --------------------------------------------------------------------------- @@ -59,6 +62,13 @@ void ModelSpec::setPrecision(const Type::ResolvedType &precision) } } // --------------------------------------------------------------------------- +void ModelSpec::setPrecision(const std::string &precision) +{ + // Parse type string and set precision + // **NOTE** no type context as that would be circular! + setPrecision(Utils::parseNumericType(precision, {})); +} +// --------------------------------------------------------------------------- void ModelSpec::setTimePrecision(const Type::ResolvedType &timePrecision) { if (!timePrecision.isNumeric()) { @@ -72,6 +82,13 @@ void ModelSpec::setTimePrecision(const Type::ResolvedType &timePrecision) } } // --------------------------------------------------------------------------- +void ModelSpec::setTimePrecision(const std::string &timePrecision) +{ + // Parse type string and set time precision + // **NOTE** no type context as that would be circular! 
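
What the two new string overloads buy user code, sketched against the declarations added above (model configuration only, nothing else assumed):

ModelSpec model;

// Existing overload: pass a resolved type directly
model.setPrecision(Type::Float);

// New overload: parses a C-style type specifier at runtime (for example a
// string arriving from Python) via Utils::parseNumericType
model.setTimePrecision("double");
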
+ setTimePrecision(Utils::parseNumericType(timePrecision, {})); +} +// --------------------------------------------------------------------------- unsigned int ModelSpec::getNumNeurons() const { // Return sum of local neuron group sizes diff --git a/src/genn/genn/type.cc b/src/genn/genn/type.cc index 507e259c30..5a6e5efb7e 100644 --- a/src/genn/genn/type.cc +++ b/src/genn/genn/type.cc @@ -94,21 +94,7 @@ ResolvedType UnresolvedType::resolve(const TypeContext &typeContext) const }, [&typeContext](const std::string &name) { - using namespace Transpiler; - - // Scan type - SingleLineErrorHandler errorHandler; - const auto tokens = Scanner::scanSource(name, errorHandler); - - // Parse type numeric type - const auto type = Parser::parseNumericType(tokens, typeContext, errorHandler); - - // If an error was encountered while scanning or parsing, throw exception - if (errorHandler.hasError()) { - throw std::runtime_error("Error parsing type '" + std::string{name} + "'"); - } - - return type; + return Utils::parseNumericType(name, typeContext); }}, detail); } From 734f844021e6130376a19e530ca7cf27c9db3653 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 24 Jul 2023 10:56:26 +0100 Subject: [PATCH 433/725] work on integrating type into Python --- include/genn/genn/modelSpec.h | 10 ++----- pygenn/genn_model.py | 54 ++++++++++++++++------------------- pygenn/src/genn.cc | 20 +++++++++++-- pygenn/src/type.cc | 31 -------------------- pygenn/src/types.cc | 33 +++++++++++++++++++++ setup.py | 4 +-- src/genn/genn/modelSpec.cc | 36 +++++++++-------------- 7 files changed, 93 insertions(+), 95 deletions(-) delete mode 100644 pygenn/src/type.cc create mode 100644 pygenn/src/types.cc diff --git a/include/genn/genn/modelSpec.h b/include/genn/genn/modelSpec.h index 4e495793b3..7c89dc7c2a 100644 --- a/include/genn/genn/modelSpec.h +++ b/include/genn/genn/modelSpec.h @@ -253,16 +253,10 @@ class GENN_EXPORT ModelSpec void setName(const std::string &name){ m_Name = name; } //! Set numerical precision for scalar type - void setPrecision(const Type::ResolvedType &precision); - - //! Set numerical precision for floating point - void setPrecision(const std::string &precision); + void setPrecision(const Type::UnresolvedType &precision); //! Set numerical precision for time type - void setTimePrecision(const Type::ResolvedType &timePrecision); - - //! Set numerical precision for time type - void setTimePrecision(const std::string &timePrecision); + void setTimePrecision(const Type::UnresolvedType &timePrecision); //! Set the integration step size of the model void setDT(double dt){ m_DT = dt; } diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index 2d11bfa752..6a76efa5da 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -77,7 +77,7 @@ from . 
import (current_source_models, custom_update_models, init_sparse_connectivity_snippets, init_toeplitz_connectivity_snippets, init_var_snippets, - neuron_models, postsynaptic_models, weight_update_models) + neuron_models, postsynaptic_models, types, weight_update_models) # Dynamically add Python mixin to wrapped class CurrentSource.__bases__ += (CurrentSourceMixin,) @@ -137,52 +137,45 @@ def __init__(self, precision="float", model_name="GeNNModel", backend=None, time_precision=None, genn_log_level=PlogSeverity.WARNING, code_gen_log_level=PlogSeverity.WARNING, + transpiler_log_level=PlogSeverity.WARNING, backend_log_level=PlogSeverity.WARNING, **preference_kwargs): """Init GeNNModel Keyword args: - precision -- string precision as string ("float", "double" - or "long double"). defaults to float. - model_name -- string name of the model. Defaults to "GeNNModel". - backend -- string specifying name of backend module to use - Defaults to one to pick 'best' backend for your system - time_precision -- string time precision as string ("float", "double" - or "long double"). defaults to float. - genn_log_level -- Log level for GeNN - code_gen_log_level -- Log level for GeNN code-generator - backend_log_level -- Log level for backend - preference_kwargs -- Additional keyword arguments to set in backend preferences structure + precision -- string precision as string ("float" or "double"). + model_name -- string name of the model. Defaults to "GeNNModel". + backend -- string specifying name of backend module to use + Defaults to one to pick 'best' backend for your system + time_precision -- string time precision as string ("float" or "double") + genn_log_level -- Log level for GeNN + code_gen_log_level -- Log level for GeNN code-generator + transpiler_log_level -- Log level for GeNN transpiler + backend_log_level -- Log level for backend + preference_kwargs -- Additional keyword arguments to set in backend preferences structure """ # Superclass super(GeNNModel, self).__init__() + + # Set precision + self.precision = precision # Based on time precision, create correct type # of SLM class and determine GeNN time type # **NOTE** all SLM uses its template parameter for is time variable - time_precision = precision if time_precision is None else time_precision - if time_precision == "float": + self.time_precision = (precision if time_precision is None + else time_precision) + print(self.time_precision, types.Float) + if self.time_precision == types.Float: self._slm = SharedLibraryModelFloat() - self.time_precision = TimePrecision.FLOAT - elif time_precision == "double": + elif self.time_precision == types.Double: self._slm = SharedLibraryModelDouble() - self.time_precision = TimePrecision.DOUBLE else: raise ValueError( "Supported time precisions are float and double, " "but '{1}' was given".format(self._time_precision)) - # Set scalar type from precision - if precision == "float": - self.precision = ScalarPrecision.FLOAT - elif precision == "double": - self.precision = ScalarPrecision.DOUBLE - else: - raise ValueError( - "Supported precisions are float and double, " - "but '{1}' was given".format(precision)) - # Initialise GeNN logging - init_logging(genn_log_level, code_gen_log_level) + init_logging(genn_log_level, code_gen_log_level, transpiler_log_level) self._built = False self._loaded = False @@ -199,7 +192,8 @@ def __init__(self, precision="float", model_name="GeNNModel", self.current_sources = {} self.custom_updates = {} - # Build dictionary containing conversions between GeNN C++ types and numpy 
types + # Build dictionary containing conversions + # between GeNN C++ types and numpy types self.genn_types = { "float": np.float32, "double": np.float64, @@ -220,7 +214,7 @@ def __init__(self, precision="float", model_name="GeNNModel", "bool": np.bool8} # Add "scalar" type to genn_types - pointing at float or double as appropriate - if precision == "float": + if self.precision == types.Float: self.genn_types["scalar"] = self.genn_types["float"] else: self.genn_types["scalar"] = self.genn_types["double"] diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc index 3f4daa791d..7019d62afc 100644 --- a/pygenn/src/genn.cc +++ b/pygenn/src/genn.cc @@ -347,7 +347,7 @@ PYBIND11_MODULE(genn, m) //-------------------------------------------------------------------- .def_property("name", &ModelSpecInternal::getName, &ModelSpecInternal::setName) .def_property("precision", &ModelSpecInternal::getPrecision, &ModelSpecInternal::setPrecision) - .def_property("time_precision", &ModelSpecInternal::getTimePrecision, &ModelSpecInternal::setTimePrecision) + .def_property("time_precision", &ModelSpecInternal::getTimePrecision, &ModelSpecInternal::setTimePrecision)) .def_property("dt", &ModelSpecInternal::getDT, &ModelSpecInternal::setDT) .def_property("timing_enabled", &ModelSpecInternal::isTimingEnabled, &ModelSpecInternal::setTiming) .def_property("batch_size", &ModelSpecInternal::getBatchSize, &ModelSpecInternal::setBatchSize) @@ -526,6 +526,22 @@ PYBIND11_MODULE(genn, m) .def("get_ps_var_location", &SynapseGroup::getPSVarLocation) .def("set_ps_var_location", &SynapseGroup::setPSVarLocation); + //------------------------------------------------------------------------ + // genn.ResolvedType + //------------------------------------------------------------------------ + pybind11::class_(m, "ResolvedType") + .def("__eq__", + [](const Type::ResolvedType &a, Type::ResolvedType b) { return a == b; }); + + //------------------------------------------------------------------------ + // genn.UnresolvedType + //------------------------------------------------------------------------ + pybind11::class_(m, "UnresolvedType") + .def(pybind11::init()) + .def(pybind11::init()) + .def("__eq__", + [](const Type::UnresolvedType &a, Type::UnresolvedType b) { return a == b; }); + //------------------------------------------------------------------------ // genn.DerivedParam //------------------------------------------------------------------------ @@ -583,7 +599,7 @@ PYBIND11_MODULE(genn, m) .def("get_calc_kernel_size_func", &InitToeplitzConnectivitySnippet::Base::getCalcKernelSizeFunc); //------------------------------------------------------------------------ - // genn.InitVarSnippetBaseBase + // genn.InitVarSnippetBase //------------------------------------------------------------------------ pybind11::class_(m, "InitVarSnippetBase") .def(pybind11::init<>()) diff --git a/pygenn/src/type.cc b/pygenn/src/type.cc deleted file mode 100644 index 19e627e1d0..0000000000 --- a/pygenn/src/type.cc +++ /dev/null @@ -1,31 +0,0 @@ -// PyBind11 includes -#include - -// GeNN includes -#include "type.h" - -using namespace GeNN::Type; - -//---------------------------------------------------------------------------- -// type -//---------------------------------------------------------------------------- -PYBIND11_MODULE(type, m) -{ - //------------------------------------------------------------------------ - // Attributes - //------------------------------------------------------------------------ - m.attr("Bool") = pybind11::cast(Bool); -} 
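
A possible follow-on to the UnresolvedType binding above (an assumption about direction, not something this patch does): because UnresolvedType has init overloads from both std::string and ResolvedType, pybind11 can be told to perform those conversions implicitly, letting Python pass plain strings or resolved types wherever an UnresolvedType is expected:

pybind11::implicitly_convertible<std::string, Type::UnresolvedType>();
pybind11::implicitly_convertible<Type::ResolvedType, Type::UnresolvedType>();
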
-/*Bool = CREATE_NUMERIC(bool, 0, ""); -inline static const ResolvedType Int8 = CREATE_NUMERIC(int8_t, 10, ""); -inline static const ResolvedType Int16 = CREATE_NUMERIC(int16_t, 20, ""); -inline static const ResolvedType Int32 = CREATE_NUMERIC(int32_t, 30, ""); -inline static const ResolvedType Int64 = CREATE_NUMERIC(int64_t, 40, ""); - -inline static const ResolvedType Uint8 = CREATE_NUMERIC(uint8_t, 10, "u"); -inline static const ResolvedType Uint16 = CREATE_NUMERIC(uint16_t, 20, "u"); -inline static const ResolvedType Uint32 = CREATE_NUMERIC(uint32_t, 30, "u"); -inline static const ResolvedType Uint64 = CREATE_NUMERIC(uint64_t, 40, "u"); - -inline static const ResolvedType Float = CREATE_NUMERIC(float, 50, "f"); -inline static const ResolvedType Double*/ diff --git a/pygenn/src/types.cc b/pygenn/src/types.cc new file mode 100644 index 0000000000..df63ae22b1 --- /dev/null +++ b/pygenn/src/types.cc @@ -0,0 +1,33 @@ +// PyBind11 includes +#include + +// GeNN includes +#include "type.h" + +using namespace GeNN::Type; + +//---------------------------------------------------------------------------- +// types +//---------------------------------------------------------------------------- +PYBIND11_MODULE(types, m) +{ + pybind11::module_::import("pygenn.genn"); + + //------------------------------------------------------------------------ + // Attributes + //------------------------------------------------------------------------ + m.attr("Bool") = pybind11::cast(Bool); + + m.attr("Int8") = pybind11::cast(Int8); + m.attr("Int16") = pybind11::cast(Int16); + m.attr("Int32") = pybind11::cast(Int32); + m.attr("Int64") = pybind11::cast(Int64); + + m.attr("Uint8") = pybind11::cast(Uint8); + m.attr("Uit16") = pybind11::cast(Uint16); + m.attr("Uint32") = pybind11::cast(Uint32); + m.attr("Uint64") = pybind11::cast(Uint64); + + m.attr("Float") = pybind11::cast(Float); + m.attr("Double") = pybind11::cast(Double); +} \ No newline at end of file diff --git a/setup.py b/setup.py index 4f31a875a0..73e257a39d 100644 --- a/setup.py +++ b/setup.py @@ -128,8 +128,8 @@ Pybind11Extension("genn", [os.path.join(pygenn_src, "genn.cc")], **genn_extension_kwargs), - Pybind11Extension("type", - [os.path.join(pygenn_src, "type.cc")], + Pybind11Extension("types", + [os.path.join(pygenn_src, "types.cc")], **genn_extension_kwargs), Pybind11Extension("init_sparse_connectivity_snippets", [os.path.join(pygenn_src, "initSparseConnectivitySnippets.cc")], diff --git a/src/genn/genn/modelSpec.cc b/src/genn/genn/modelSpec.cc index 8adf81ca74..a867634d21 100644 --- a/src/genn/genn/modelSpec.cc +++ b/src/genn/genn/modelSpec.cc @@ -49,46 +49,38 @@ ModelSpec::~ModelSpec() { } // --------------------------------------------------------------------------- -void ModelSpec::setPrecision(const Type::ResolvedType &precision) +void ModelSpec::setPrecision(const Type::UnresolvedType &precision) { - if (!precision.isNumeric()) { + // Resolve type + // **NOTE** no type context as that would be circular! 
+ const auto resolved = precision.resolve({}); + if (!resolved.isNumeric()) { throw std::runtime_error("Only numeric types can be used for precision"); } else { - if (precision.getNumeric().isIntegral) { + if (resolved.getNumeric().isIntegral) { throw std::runtime_error("Only floating point types can be used for precision"); } - m_Precision = precision; + m_Precision = resolved; } } // --------------------------------------------------------------------------- -void ModelSpec::setPrecision(const std::string &precision) -{ - // Parse type string and set precision - // **NOTE** no type context as that would be circular! - setPrecision(Utils::parseNumericType(precision, {})); -} -// --------------------------------------------------------------------------- -void ModelSpec::setTimePrecision(const Type::ResolvedType &timePrecision) +void ModelSpec::setTimePrecision(const Type::UnresolvedType &timePrecision) { - if (!timePrecision.isNumeric()) { + // Resolve type + // **NOTE** no type context as that would be circular! + const auto resolved = timePrecision.resolve({}); + if (!resolved.isNumeric()) { throw std::runtime_error("Only numeric types can be used for timeprecision"); } else { - if (timePrecision.getNumeric().isIntegral) { + if (resolved.getNumeric().isIntegral) { throw std::runtime_error("Only floating point types can be used for time precision"); } - m_TimePrecision = timePrecision; + m_TimePrecision = resolved; } } // --------------------------------------------------------------------------- -void ModelSpec::setTimePrecision(const std::string &timePrecision) -{ - // Parse type string and set time precision - // **NOTE** no type context as that would be circular! - setTimePrecision(Utils::parseNumericType(timePrecision, {})); -} -// --------------------------------------------------------------------------- unsigned int ModelSpec::getNumNeurons() const { // Return sum of local neuron group sizes From e26197d5428e7cb953a7c92d9aff928d3cdaf958 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 24 Jul 2023 12:30:31 +0100 Subject: [PATCH 434/725] fixed weird flag format bug --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 73e257a39d..18bcf8f1b2 100644 --- a/setup.py +++ b/setup.py @@ -68,7 +68,7 @@ # If this is Windows, turn off warnings about dll-interface being required # for stuff to be used by clients and prevent windows.h exporting TOO many awful macros if WIN: - extension_kwargs["extra_compile_args"].extend(["/wd\"4251\"", "-DWIN32_LEAN_AND_MEAN", "-DNOMINMAX"]) + extension_kwargs["extra_compile_args"].extend(["/wd4251", "-DWIN32_LEAN_AND_MEAN", "-DNOMINMAX"]) # Extend these kwargs for extensions which link against GeNN genn_extension_kwargs = deepcopy(extension_kwargs) From d56eaa5133a82a906e8a31d27bcd63c3cf3a79e0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 24 Jul 2023 12:39:25 +0100 Subject: [PATCH 435/725] fixed stray "inSyn" in runner generation --- src/genn/genn/code_generator/generateRunner.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index b2a405395c..6649617c66 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -1427,11 +1427,11 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, if(!s.second.isPSModelFused()) { // Add code to push and pull inSyn genVarPushPullScope(definitionsFunc, 
runnerPushFunc, runnerPullFunc, s.second.getInSynLocation(), - backend.getPreferences().automaticCopy, "inSyn" + s.second.getName(), synapseGroupStatePushPullFunctions, + backend.getPreferences().automaticCopy, "outPost" + s.second.getName(), synapseGroupStatePushPullFunctions, [&]() { backend.genVariablePushPull(runnerPushFunc, runnerPullFunc, - model.getPrecision(), "inSyn" + s.second.getName(), + model.getPrecision(), "outPost" + s.second.getName(), s.second.getInSynLocation(), true, s.second.getTrgNeuronGroup()->getNumNeurons() * batchSize); }); From d943368d496fdac76ca97975caf5c880e2d5d9b5 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 24 Jul 2023 12:39:43 +0100 Subject: [PATCH 436/725] added new staticPulseConstantWeight model to PyGeNN --- pygenn/src/weightUpdateModels.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/pygenn/src/weightUpdateModels.cc b/pygenn/src/weightUpdateModels.cc index e823e879d8..093140c7bc 100644 --- a/pygenn/src/weightUpdateModels.cc +++ b/pygenn/src/weightUpdateModels.cc @@ -29,6 +29,7 @@ PYBIND11_MODULE(weight_update_models, m) // **THINK** with some cunning, standard macros could maybe populate // an array with instance pointers that we could loop over m.def("StaticPulse", &getBaseInstance, pybind11::return_value_policy::reference); + m.def("StaticPulseConstantWeight", &getBaseInstance, pybind11::return_value_policy::reference); m.def("StaticPulseDendriticDelay", &getBaseInstance, pybind11::return_value_policy::reference); m.def("StaticGraded", &getBaseInstance, pybind11::return_value_policy::reference); m.def("PiecewiseSTDP", &getBaseInstance, pybind11::return_value_policy::reference); From 9a3f45ab32f5680cbfd74a4fe3aa74472cf424bb Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 24 Jul 2023 12:39:58 +0100 Subject: [PATCH 437/725] fixed typo --- pygenn/src/genn.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc index 7019d62afc..f185e34d9d 100644 --- a/pygenn/src/genn.cc +++ b/pygenn/src/genn.cc @@ -347,7 +347,7 @@ PYBIND11_MODULE(genn, m) //-------------------------------------------------------------------- .def_property("name", &ModelSpecInternal::getName, &ModelSpecInternal::setName) .def_property("precision", &ModelSpecInternal::getPrecision, &ModelSpecInternal::setPrecision) - .def_property("time_precision", &ModelSpecInternal::getTimePrecision, &ModelSpecInternal::setTimePrecision)) + .def_property("time_precision", &ModelSpecInternal::getTimePrecision, &ModelSpecInternal::setTimePrecision) .def_property("dt", &ModelSpecInternal::getDT, &ModelSpecInternal::setDT) .def_property("timing_enabled", &ModelSpecInternal::isTimingEnabled, &ModelSpecInternal::setTiming) .def_property("batch_size", &ModelSpecInternal::getBatchSize, &ModelSpecInternal::setBatchSize) From e3705fa5c58936684350e51322d8cb82ca4b4b58 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 09:23:40 +0100 Subject: [PATCH 438/725] turn precisions into UnresolvedType at correct point --- pygenn/genn_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index 6a76efa5da..89beebf6cb 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -66,8 +66,8 @@ ModelSpecInternal, NeuronGroup, NeuronModelBase, ParamVal, PlogSeverity, PostsynapticModelBase, SparseConnectivityInit, SynapseGroup, SynapseMatrixType, - ToeplitzConnectivityInit, Var, VarInit, VarLocation, - VarRef, WeightUpdateModelBase) + 
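
The new StaticPulseConstantWeight line reuses the getBaseInstance helper already used by its neighbours. Its definition sits outside this hunk; it is presumably a small function template along these lines (a sketch, not the verbatim helper):

// Return the singleton instance of a built-in model, upcast to the
// abstract base so every model shares one Python-facing type
template<typename T>
const WeightUpdateModels::Base *getBaseInstance()
{
    return T::getInstance();
}
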
ToeplitzConnectivityInit, UnresolvedType, Var, VarInit, + VarLocation, VarRef, WeightUpdateModelBase) from .shared_library_model import (SharedLibraryModelDouble, SharedLibraryModelFloat) @@ -157,14 +157,14 @@ def __init__(self, precision="float", model_name="GeNNModel", super(GeNNModel, self).__init__() # Set precision - self.precision = precision + self.precision = UnresolvedType(precision) # Based on time precision, create correct type # of SLM class and determine GeNN time type # **NOTE** all SLM uses its template parameter for is time variable - self.time_precision = (precision if time_precision is None - else time_precision) - print(self.time_precision, types.Float) + self.time_precision = UnresolvedType(self.precision + if time_precision is None + else time_precision) if self.time_precision == types.Float: self._slm = SharedLibraryModelFloat() elif self.time_precision == types.Double: @@ -521,7 +521,7 @@ def build(self, path_to_model="./", force_rebuild=False): share_path = path.join(path.split(__file__)[0], "share") # Finalize model - self.finalize() + self.finalise() # Create suitable preferences object for backend preferences = self._backend_module.Preferences() From 880d095e603fef0f4d9dd8ee181851a7bf21f3b0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 10:16:47 +0100 Subject: [PATCH 439/725] type context is provided by ModelSpec rather than ModelSpecMerged --- .../genn/genn/code_generator/modelSpecMerged.h | 18 ++++++------------ include/genn/genn/modelSpec.h | 4 +++- include/genn/genn/modelSpecInternal.h | 1 + src/genn/genn/code_generator/generateRunner.cc | 12 ++++++------ src/genn/genn/modelSpec.cc | 9 +++------ 5 files changed, 19 insertions(+), 25 deletions(-) diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h index 757cbdf20d..d1921f2357 100644 --- a/include/genn/genn/code_generator/modelSpecMerged.h +++ b/include/genn/genn/code_generator/modelSpecMerged.h @@ -35,7 +35,7 @@ class GENN_EXPORT ModelSpecMerged ModelSpecMerged(const ModelSpecInternal &model) : m_Model(model), m_NeuronUpdateSupportCode("NeuronUpdateSupportCode"), m_PostsynapticDynamicsSupportCode("PostsynapticDynamicsSupportCode"), m_PresynapticUpdateSupportCode("PresynapticUpdateSupportCode"), m_PostsynapticUpdateSupportCode("PostsynapticUpdateSupportCode"), - m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode"), m_TypeContext{{"scalar", model.getPrecision()}, {"timepoint", model.getTimePrecision()}} + m_SynapseDynamicsSupportCode("SynapseDynamicsSupportCode") { } ModelSpecMerged(const ModelSpecMerged&) = delete; @@ -92,9 +92,6 @@ class GENN_EXPORT ModelSpecMerged //-------------------------------------------------------------------------- //! Get underlying, unmerged model const ModelSpecInternal &getModel() const{ return m_Model; } - - //! Get type context used to resolve all types used in model - const Type::TypeContext &getTypeContext() const{ return m_TypeContext; } //! 
Get merged neuron groups which require updating const std::vector &getMergedNeuronUpdateGroups() const{ return m_MergedNeuronUpdateGroups; } @@ -350,7 +347,7 @@ class GENN_EXPORT ModelSpecMerged } template - void createMergedGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + void createMergedGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, const std::vector> &unmergedGroups, std::vector &mergedGroups, D getHashDigest, GenMergedGroupFn generateGroup, bool host = false) { @@ -371,7 +368,7 @@ class GENN_EXPORT ModelSpecMerged size_t i = 0; for(const auto &p : protoMergedGroups) { // Construct new merged group object - mergedGroups.emplace_back(i, m_TypeContext, p.second); + mergedGroups.emplace_back(i, m_Model.getTypeContext(), p.second); // Call generate function generateGroup(mergedGroups.back()); @@ -404,7 +401,7 @@ class GENN_EXPORT ModelSpecMerged } template - void createMergedGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, + void createMergedGroups(const BackendBase &backend, BackendBase::MemorySpaces &memorySpaces, const std::map &groups, std::vector &mergedGroups, F filter, D getHashDigest, G generateGroup, bool host = false) { @@ -417,8 +414,8 @@ class GENN_EXPORT ModelSpecMerged } // Merge filtered vector - createMergedGroups(backend, memorySpaces, unmergedGroups, mergedGroups, - getHashDigest, generateGroup, host); + createMergedGroups(backend, memorySpaces, unmergedGroups, + mergedGroups, getHashDigest, generateGroup, host); } //-------------------------------------------------------------------------- @@ -519,8 +516,5 @@ class GENN_EXPORT ModelSpecMerged //! Map containing mapping of original extra global param names to their locations within merged groups MergedEGPMap m_MergedEGPs; - - //! Type context used to resolve all types used in model - Type::TypeContext m_TypeContext; }; } // namespace GeNN::CodeGenerator diff --git a/include/genn/genn/modelSpec.h b/include/genn/genn/modelSpec.h index 7c89dc7c2a..0362ed1ce8 100644 --- a/include/genn/genn/modelSpec.h +++ b/include/genn/genn/modelSpec.h @@ -696,7 +696,7 @@ class GENN_EXPORT ModelSpec //! Get hash digest used for detecting changes boost::uuids::detail::sha1::digest_type getHashDigest() const; - Type::TypeContext getTypeContext() const; + const Type::TypeContext &getTypeContext() const{ return m_TypeContext; } //! Get std::map containing local named NeuronGroup objects in model const std::map &getNeuronGroups() const{ return m_LocalNeuronGroups; } @@ -754,6 +754,8 @@ class GENN_EXPORT ModelSpec //! Type of floating point variables (float, double, ...; default: float) Type::ResolvedType m_Precision; + Type::TypeContext m_TypeContext; + //! 
Type of floating point variables used to store time std::optional m_TimePrecision; diff --git a/include/genn/genn/modelSpecInternal.h b/include/genn/genn/modelSpecInternal.h index 3ae2aa8832..abeecfcc3a 100644 --- a/include/genn/genn/modelSpecInternal.h +++ b/include/genn/genn/modelSpecInternal.h @@ -27,5 +27,6 @@ class ModelSpecInternal : public ModelSpec using ModelSpec::zeroCopyInUse; using ModelSpec::isRecordingInUse; using ModelSpec::getHashDigest; + using ModelSpec::getTypeContext; }; } // namespace GeNN diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 6649617c66..63678609fd 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -273,7 +273,7 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase & CodeStream &extraGlobalParam, const Type::UnresolvedType &type, const std::string &name, bool apiRequired, VarLocation loc) { // Resolved type - const auto resolvedType = type.resolve(modelMerged.getTypeContext()); + const auto resolvedType = type.resolve(modelMerged.getModel().getTypeContext()); // Generate variables backend.genVariableDefinition(definitionsVar, definitionsInternalVar, resolvedType, name, loc); @@ -414,7 +414,7 @@ void genRunnerVars(const ModelSpecMerged &modelMerged, const BackendBase &backen for(const auto &var : varAdaptor.getDefs()) { const auto *varInitSnippet = varAdaptor.getInitialisers().at(var.name).getSnippet(); const bool autoInitialized = !varInitSnippet->getCode().empty(); - const auto resolvedType = var.type.resolve(modelMerged.getTypeContext()); + const auto resolvedType = var.type.resolve(modelMerged.getModel().getTypeContext()); genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, resolvedType, var.name + group.getName(), varAdaptor.getLoc(var.name), autoInitialized, getSizeFn(group, var), mem, statePushPullFunctions); @@ -438,7 +438,7 @@ void genRunnerFusedVars(const ModelSpecMerged &modelMerged, const BackendBase &b // Loop through variables const V varAdaptor(group); for(const auto &var : varAdaptor.getDefs()) { - const auto resolvedType = var.type.resolve(modelMerged.getTypeContext()); + const auto resolvedType = var.type.resolve(modelMerged.getModel().getTypeContext()); backend.genArray(definitionsVar, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, resolvedType, var.name + varAdaptor.getNameSuffix(), varAdaptor.getLoc(var.name), getSizeFn(group, var), mem); @@ -462,7 +462,7 @@ void genRunnerFusedVarPushPull(const ModelSpecMerged &modelMerged, const Backend const V varAdaptor(group); for(const auto &var : varAdaptor.getDefs()) { const bool autoInitialized = !varAdaptor.getInitialisers().at(var.name).getSnippet()->getCode().empty(); - const auto resolvedType = var.type.resolve(modelMerged.getTypeContext()); + const auto resolvedType = var.type.resolve(modelMerged.getModel().getTypeContext()); genVarPushPullScope(definitionsFunc, runnerPushFunc, runnerPullFunc, varAdaptor.getLoc(var.name), backend.getPreferences().automaticCopy, var.name + group.getName(), groupStatePushPullFunctions, [&]() @@ -1082,7 +1082,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, const unsigned int numElements = getNumVarElements(var.access, n.second.getNumNeurons()); const size_t count = n.second.isVarQueueRequired(var.name) ? 
numCopies * numElements * n.second.getNumDelaySlots() : numCopies * n.second.getNumNeurons(); const bool autoInitialized = !varInitSnippet->getCode().empty(); - const auto resolvedType = var.type.resolve(modelMerged.getTypeContext()); + const auto resolvedType = var.type.resolve(modelMerged.getModel().getTypeContext()); genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, runnerPushFunc, runnerPullFunc, resolvedType, var.name + n.first, n.second.getVarLocation(var.name), autoInitialized, count, mem, neuronStatePushPullFunctions); @@ -1395,7 +1395,7 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, for(const auto &wuVar : wu->getVars()) { const auto *varInitSnippet = s.second.getWUVarInitialisers().at(wuVar.name).getSnippet(); const bool autoInitialized = !varInitSnippet->getCode().empty(); - const auto resolvedType = wuVar.type.resolve(modelMerged.getTypeContext()); + const auto resolvedType = wuVar.type.resolve(modelMerged.getModel().getTypeContext()); if(individualWeights) { const size_t size = (size_t)s.second.getSrcNeuronGroup()->getNumNeurons() * (size_t)backend.getSynapticMatrixRowStride(s.second); genVariable(backend, definitionsVar, definitionsFunc, definitionsInternalVar, runnerVarDecl, runnerVarAlloc, runnerVarFree, diff --git a/src/genn/genn/modelSpec.cc b/src/genn/genn/modelSpec.cc index a867634d21..f23d3878f5 100644 --- a/src/genn/genn/modelSpec.cc +++ b/src/genn/genn/modelSpec.cc @@ -222,8 +222,10 @@ CustomUpdateWU *ModelSpec::addCustomUpdate(const std::string &name, const std::s // --------------------------------------------------------------------------- void ModelSpec::finalise() { + // Build type context + m_TypeContext = {{"scalar", getPrecision()}, {"timepoint", getTimePrecision()}}; + // Finalise neuron groups - const auto typeContext = getTypeContext(); for(auto &n : m_LocalNeuronGroups) { n.second.finalise(m_DT); } @@ -363,11 +365,6 @@ boost::uuids::detail::sha1::digest_type ModelSpec::getHashDigest() const return hash.get_digest(); } // --------------------------------------------------------------------------- -Type::TypeContext ModelSpec::getTypeContext() const -{ - return Type::TypeContext{{"scalar", getPrecision()}, {"timepoint", getTimePrecision()}}; -} -// --------------------------------------------------------------------------- NeuronGroupInternal *ModelSpec::findNeuronGroupInternal(const std::string &name) { // If a matching local neuron group is found, return it From 831702f9b5e09933c631fa6e505ab4de15d49a9c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 10:16:57 +0100 Subject: [PATCH 440/725] fixed typo in err types --- pygenn/src/types.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygenn/src/types.cc b/pygenn/src/types.cc index df63ae22b1..b323982406 100644 --- a/pygenn/src/types.cc +++ b/pygenn/src/types.cc @@ -24,7 +24,7 @@ PYBIND11_MODULE(types, m) m.attr("Int64") = pybind11::cast(Int64); m.attr("Uint8") = pybind11::cast(Uint8); - m.attr("Uit16") = pybind11::cast(Uint16); + m.attr("Uint16") = pybind11::cast(Uint16); m.attr("Uint32") = pybind11::cast(Uint32); m.attr("Uint64") = pybind11::cast(Uint64); From 96991cfc8d5d2656fc4695130ea6aa32245929f2 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 10:17:39 +0100 Subject: [PATCH 441/725] Exposed some more type stuff * UnresolvedType.resolve * Type context * __hash__ for ResolvedType --- pygenn/src/genn.cc | 15 +++++++++++++++ 1 file 
changed, 15 insertions(+) diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc index f185e34d9d..3b8e444946 100644 --- a/pygenn/src/genn.cc +++ b/pygenn/src/genn.cc @@ -361,6 +361,7 @@ PYBIND11_MODULE(genn, m) .def_property_readonly("num_neurons", &ModelSpecInternal::getNumNeurons) .def_property_readonly("recording_in_use", &ModelSpecInternal::isRecordingInUse) + .def_property_readonly("type_context", &ModelSpecInternal::getTypeContext) //-------------------------------------------------------------------- // Methods @@ -530,6 +531,19 @@ PYBIND11_MODULE(genn, m) // genn.ResolvedType //------------------------------------------------------------------------ pybind11::class_(m, "ResolvedType") + .def("__hash__", + [](const Type::ResolvedType &a) + { + // Calculate hash digest + boost::uuids::detail::sha1 shaHash; + Type::updateHash(a, shaHash); + const auto shaDigest = shaHash.get_digest(); + + // Return size-t worth of hash + size_t hash; + memcpy(&hash, &shaDigest[0], sizeof(size_t)); + return hash; + }) .def("__eq__", [](const Type::ResolvedType &a, Type::ResolvedType b) { return a == b; }); @@ -539,6 +553,7 @@ PYBIND11_MODULE(genn, m) pybind11::class_(m, "UnresolvedType") .def(pybind11::init()) .def(pybind11::init()) + .def("resolve", &Type::UnresolvedType::resolve) .def("__eq__", [](const Type::UnresolvedType &a, Type::UnresolvedType b) { return a == b; }); From 15728cb650f908680690059c40e67f7ab40bd6ab Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 10:18:00 +0100 Subject: [PATCH 442/725] PyGeNN works! --- pygenn/genn_groups.py | 74 ++++++++++++------------------------ pygenn/genn_model.py | 34 ++++++----------- pygenn/model_preprocessor.py | 14 +------ 3 files changed, 36 insertions(+), 86 deletions(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index a75226c38d..c73363d260 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -13,7 +13,7 @@ from weakref import proxy import numpy as np -from . import neuron_models +from . import neuron_models, types from .genn import (CustomUpdateWU, SynapseMatrixConnectivity, SynapseMatrixWeight, VarAccessDuplication, VarLocation) from .model_preprocessor import prepare_model, ExtraGlobalParameter, Variable @@ -90,11 +90,7 @@ def _assign_ext_ptr_array(self, var_name, var_size, var_type): Args: var_name -- string a fully qualified name of the variable to assign var_size -- int the size of the variable - var_type -- string type of the variable. The supported types are - char, unsigned char, short, unsigned short, int, - unsigned int, long, unsigned long, long long, - unsigned long long, float, double, long double - and scalar. + var_type -- ResolvedType object Returns numpy array of type var_type @@ -103,10 +99,7 @@ def _assign_ext_ptr_array(self, var_name, var_size, var_type): internal_var_name = var_name + self.name - if var_type == "scalar": - var_type = self._model.precision - - # Get numpy data type corresponding to type string + # Get numpy data type corresponding to type dtype = self._model.genn_types[var_type] # Calculate bytes @@ -123,11 +116,7 @@ def _assign_ext_ptr_single(self, var_name, var_type): Args: var_name -- string a fully qualified name of the variable to assign - var_type -- string type of the variable. The supported types are - char, unsigned char, short, unsigned short, int, - unsigned int, long, unsigned long, long long, - unsigned long long, float, double, long double - and scalar. 
+ var_type -- ResolvedType object Returns numpy array of type var_type @@ -136,10 +125,7 @@ def _assign_ext_ptr_single(self, var_name, var_type): internal_var_name = var_name + self.name - if var_type == "scalar": - var_type = self._model.precision - - # Get numpy data type corresponding to type string + # Get numpy data type corresponding to type dtype = self._model.genn_types[var_type] # Get dtype view of array memoryview @@ -162,11 +148,6 @@ def _push_extra_global_param_to_device(self, egp_name, egp_dict=None): # Retrieve EGP from dictionary egp = egp_dict[egp_name] - # If EGP is scalar, give error - if egp.is_scalar: - raise Exception("Only pointer-type extra global parameters " - "need to be pushed") - self._model._slm.push_extra_global_param_to_device(self.name, egp_name, len(egp.values)) @@ -184,11 +165,6 @@ def _pull_extra_global_param_from_device(self, egp_name, egp_dict=None): # Retrieve EGP from dictionary egp = egp_dict[egp_name] - # If EGP is scalar, give error - if egp.is_scalar: - raise Exception("Only pointer-type extra global parameters " - "need to be pulled") - self._model._slm.pull_extra_global_param_from_device(self.name, egp_name, len(egp.values)) @@ -222,8 +198,9 @@ def _load_vars(self, vars, size=None, var_dict=None, get_location_fn=None): else size) # Get view + resolved_type = var_data.type.resolve(self._model.type_context) var_data.view = self._assign_ext_ptr_array(v.name, var_size * num_copies, - var_data.type) + resolved_type) # If there is more than one copy, reshape view to 2D if num_copies > 1: @@ -243,13 +220,8 @@ def _load_egp(self, egp_dict=None, egp_suffix=""): # Loop through extra global params for egp_name, egp_data in iteritems(egp_dict): - if egp_data.is_scalar: - # Assign view - egp_data.view = self._assign_ext_ptr_single(egp_name + egp_suffix, - egp_data.type) - # Copy values - egp_data.view[:] = egp_data.values - elif egp_data.values is not None: + resolved_type = egp_data.type.resolve(self._model.type_context) + if egp_data.values is not None: # Allocate memory self._model._slm.allocate_extra_global_param( self.name, egp_name + egp_suffix, len(egp_data.values)) @@ -257,7 +229,7 @@ def _load_egp(self, egp_dict=None, egp_suffix=""): # Assign view egp_data.view = self._assign_ext_ptr_array(egp_name + egp_suffix, len(egp_data.values), - egp_data.type) + resolved_type) # Copy values egp_data.view[:] = egp_data.values @@ -338,7 +310,7 @@ def load(self, num_recording_timesteps): # Assign pointer to recording data self._spike_recording_data = self._assign_ext_ptr_array( - "recordSpk", recording_words, "uint32_t") + "recordSpk", recording_words, types.Uint32) # If spike-event recording is enabled if self.spike_event_recording_enabled: @@ -348,7 +320,7 @@ def load(self, num_recording_timesteps): # Assign pointer to recording data self._spike_event_recording_data = self._assign_ext_ptr_array( - "recordSpkEvent", recording_words, "uint32_t") + "recordSpkEvent", recording_words, types.Uint32) if self.num_delay_slots > 1: self.spike_que_ptr = self._model._slm.assign_external_pointer_single_ui( @@ -602,11 +574,11 @@ def push_connectivity_to_device(self): def pull_in_syn_from_device(self): """Pull synaptic input current from device""" - self.pull_var_from_device("inSyn") + self.pull_var_from_device("outPost") def push_in_syn_to_device(self): """Push synaptic input current to device""" - self.push_var_to_device("inSyn") + self.push_var_to_device("outPost") def pull_psm_extra_global_param_from_device(self, egp_name): """Wrapper around 
GeNNModel.pull_extra_global_param_from_device @@ -640,7 +612,7 @@ def load(self): self._sparse_ind_type) row_length = self._assign_ext_ptr_array("rowLength", self.src.size, - "unsigned int") + types.Uint32) # add pointers to the object self._ind = ind self._row_lengths = row_length
@@ -687,9 +659,10 @@ def load(self): num_copies = (1 if (v.access & VarAccessDuplication.SHARED) != 0 else self._model.batch_size) # Get view + resolved_type = var_data.type.resolve(self._model.type_context) var_data.view = self._assign_ext_ptr_array( v.name, self.weight_update_var_size * num_copies, - var_data.type) + resolved_type) # If there is more than one copy, reshape view to 2D if num_copies > 1:
@@ -726,13 +699,13 @@ def load(self): # If it's inSyn is accessible on the host if self.in_syn_location & VarLocation.HOST: # Get view - self.in_syn = self._assign_ext_ptr_array( - "inSyn", self.trg.size * self._model.batch_size, - "scalar") + self.out_post = self._assign_ext_ptr_array( + "outPost", self.trg.size * self._model.batch_size, + self._model.precision) # Reshape to expose batches - self.in_syn = np.reshape(self.in_syn, (self._model.batch_size, - self.trg.size)) + self.out_post = np.reshape(self.out_post, (self._model.batch_size, + self.trg.size)) # Load extra global parameters self._load_egp()
@@ -876,8 +849,9 @@ def load(self): # Get view size = self._synapse_group.weight_update_var_size * num_copies + resolved_type = var_data.type.resolve(self._model.type_context) var_data.view = self._assign_ext_ptr_array( - v.name, size, var_data.type) + v.name, size, resolved_type) # If there is more than one copy, reshape view to 2D if num_copies > 1:
diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index 89beebf6cb..3af651d37f 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py
@@ -195,29 +195,17 @@ def __init__(self, precision="float", model_name="GeNNModel", # Build dictionary containing conversions # between GeNN C++ types and numpy types self.genn_types = { - "float": np.float32, - "double": np.float64, - "int": np.int32, - "unsigned int": np.uint32, - "short": np.int16, - "unsigned short": np.uint16, - "char": np.int8, - "unsigned char": np.uint8, - "uint64_t": np.uint64, - "int64_t": np.int64, - "uint32_t": np.uint32, - "int32_t": np.int32, - "uint16_t": np.uint16, - "int16_t": np.int16, - "uint8_t": np.uint8, - "int8_t": np.int8, - "bool": np.bool8} - - # Add "scalar" type to genn_types - pointing at float or double as appropriate - if self.precision == types.Float: - self.genn_types["scalar"] = self.genn_types["float"] - else: - self.genn_types["scalar"] = self.genn_types["double"] + types.Float: np.float32, + types.Double: np.float64, + types.Int64: np.int64, + types.Uint64: np.uint64, + types.Int32: np.int32, + types.Uint32: np.uint32, + types.Int16: np.int16, + types.Uint16: np.uint16, + types.Int8: np.int8, + types.Uint8: np.uint8, + types.Bool: np.bool8} @property def backend_name(self):
diff --git a/pygenn/model_preprocessor.py b/pygenn/model_preprocessor.py index 0b81921c2f..6ee1fa0627 100644 --- a/pygenn/model_preprocessor.py +++ b/pygenn/model_preprocessor.py
@@ -172,13 +172,7 @@ def __init__(self, variable_name, variable_type, group, values=None): Keyword args: values -- iterable """ - if variable_type[-1] == "*": - self.is_scalar = False - self.type = variable_type[:-1] - else: - self.is_scalar = True - self.type = variable_type - + self.type = variable_type self.group = group if type(group) in ProxyTypes else proxy(group) self.name = variable_name self.view = None @@ -192,12
+186,6 @@ def set_values(self, values): """ if values is None: self.values = None - elif self.is_scalar: - if isinstance(values, Number): - self.values = values - else: - raise ValueError("scalar extra global variables can only be " - "initialised with a number") else: # Try and iterate values try: From d03eabed4054b10caa119c1140c2b4b7c1af186d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 12:43:35 +0100 Subject: [PATCH 443/725] removed some overloaded getVarLocation methods --- include/genn/genn/customUpdate.h | 3 --- include/genn/genn/neuronGroup.h | 3 --- 2 files changed, 6 deletions(-) diff --git a/include/genn/genn/customUpdate.h b/include/genn/genn/customUpdate.h index 18a61b548d..e7d6e42fdd 100644 --- a/include/genn/genn/customUpdate.h +++ b/include/genn/genn/customUpdate.h @@ -46,9 +46,6 @@ class GENN_EXPORT CustomUpdateBase //! Get variable location for custom update model state variable VarLocation getVarLocation(const std::string &varName) const; - //! Get variable location for custom update model state variable - VarLocation getVarLocation(size_t index) const{ return m_VarLocation.at(index); } - //! Is var init code required for any variables in this custom update group's custom update model? bool isVarInitRequired() const; diff --git a/include/genn/genn/neuronGroup.h b/include/genn/genn/neuronGroup.h index 547310df93..9e7306254b 100644 --- a/include/genn/genn/neuronGroup.h +++ b/include/genn/genn/neuronGroup.h @@ -159,9 +159,6 @@ class GENN_EXPORT NeuronGroup //! Get location of neuron model state variable by name VarLocation getVarLocation(const std::string &varName) const; - //! Get location of neuron model state variable by index - VarLocation getVarLocation(size_t index) const{ return m_VarLocation.at(index); } - //! Get location of neuron model extra global parameter by name /*! 
This is only used by extra global parameters which are pointers*/ VarLocation getExtraGlobalParamLocation(const std::string ¶mName) const; From 83ca88e4bf9b6d00eec680ebe8831d548dc762f1 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 12:45:51 +0100 Subject: [PATCH 444/725] rough go at hooking up custom connectivity updates to PyGeNN --- pygenn/__init__.py | 10 +- pygenn/genn_groups.py | 65 ++- pygenn/genn_model.py | 424 +++++++++++++------ pygenn/src/customConnectivityUpdateModels.cc | 31 ++ pygenn/src/genn.cc | 83 +++- setup.py | 3 + 6 files changed, 478 insertions(+), 138 deletions(-) create mode 100644 pygenn/src/customConnectivityUpdateModels.cc diff --git a/pygenn/__init__.py b/pygenn/__init__.py index 0b19b543ff..092997cdae 100644 --- a/pygenn/__init__.py +++ b/pygenn/__init__.py @@ -7,7 +7,15 @@ create_psm_egp_ref, create_wu_egp_ref, PlogSeverity, SpanType, SynapseMatrixType, VarAccess, VarAccessMode, VarLocation) -from .genn_model import (GeNNModel, init_sparse_connectivity, +from .genn_model import (GeNNModel, create_neuron_model, + create_postsynaptic_model, + create_weight_update_model, + create_current_source_model, + create_custom_update_model, + create_custom_connectivity_update_model, + create_init_var_snippet, + create_sparse_connect_init_snippet, + init_sparse_connectivity, init_toeplitz_connectivity, init_var) if sys.version_info >= (3, 8): diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index c73363d260..00d49b3ef2 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -816,7 +816,7 @@ def unload(self): class CustomUpdateMixin(GroupMixin): """Class representing a custom update""" def _init_group(self, model, var_space): - """Init NeuronGroup + """Init CustomUpdate Args: name -- string name of the group @@ -891,3 +891,66 @@ def _synapse_group(self): # Return Python synapse group reference from # first (arbitrarily) variable reference return next(itervalues(self.var_references)).synapse_group + +class CustomConnectivityUpdateMixin(GroupMixin): + """Class representing a custom connectivity update""" + def _init_group(self, model, var_space, pre_var_space, post_var_space): + """Init CustomConnectivityUpdateGroup + + Args: + name -- string name of the group + model -- pygenn.genn_model.GeNNModel this neuron group is part of + """ + super(CustomConnectivityUpdateMixin, self)._init_group(model) + self.vars, self.extra_global_params = prepare_model( + self.model, self, var_space) + self.pre_vars = {vnt.name: Variable(vnt.name, vnt.type, + pre_var_space[vnt.name], self) + for vnt in self.model.get_pre_vars()} + self.post_vars = {vnt.name: Variable(vnt.name, vnt.type, + post_var_space[vnt.name], self) + for vnt in self.model.get_post_vars()} + + def load(self): + # Loop through state variables + for v in self.model.get_vars(): + # Get corresponding data from dictionary + var_data = self.vars[v.name] + + # If variable is located on host + var_loc = self.get_var_location(v.name) + if var_loc & VarLocation.HOST: + # Get view + size = self._synapse_group.weight_update_var_size + resolved_type = var_data.type.resolve(self._model.type_context) + var_data.view = self._assign_ext_ptr_array( + v.name, size, resolved_type) + + # Initialise variable if necessary + self._synapse_group._init_wum_var(var_data, 1) + + # Load any var initialisation egps associated with this variable + self._load_egp(var_data.extra_global_params, v.name) + + # Load pre and postsynaptic variables + self._load_vars(self.model.get_pre_vars(), self.src.size, + self.pre_vars, 
self.get_pre_var_location) + self._load_vars(self.model.get_post_vars(), self.trg.size, + self.post_vars, self.get_post_var_location) + + # Load custom update extra global parameters + self._load_egp() + + def load_init_egps(self): + # Load any egps used for variable initialisation + self._load_var_init_egps() + + # Load any egps used for pre and postsynaptic variable initialisation + self._load_var_init_egps(self.pre_vars) + self._load_var_init_egps(self.post_vars) + + def unload(self): + self._unload_vars() + self._unload_vars(self.pre_vars) + self._unload_vars(self.post_vars) + self._unload_egps()
diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index 3af651d37f..4761b02a35 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py
@@ -59,7 +59,8 @@ # pygenn imports from .genn import (generate_code, init_logging, CurrentSource, - CurrentSourceModelBase, CustomUpdate, + CurrentSourceModelBase, CustomConnectivityUpdate, + CustomConnectivityUpdateModelBase, CustomUpdate, CustomUpdateModelBase, CustomUpdateWU, DerivedParam, EGP, EGPRef, InitSparseConnectivitySnippetBase, InitToeplitzConnectivitySnippetBase, InitVarSnippetBase,
@@ -71,16 +72,18 @@ from .shared_library_model import (SharedLibraryModelDouble, SharedLibraryModelFloat) -from .genn_groups import (CurrentSourceMixin, CustomUpdateMixin, - NeuronGroupMixin, SynapseGroupMixin) +from .genn_groups import (CurrentSourceMixin, CustomConnectivityUpdateMixin, + CustomUpdateMixin, NeuronGroupMixin, + SynapseGroupMixin) from .model_preprocessor import get_snippet, get_var_init -from . import (current_source_models, custom_update_models, - init_sparse_connectivity_snippets, +from . import (current_source_models, custom_connectivity_update_models, + custom_update_models, init_sparse_connectivity_snippets, init_toeplitz_connectivity_snippets, init_var_snippets, neuron_models, postsynaptic_models, types, weight_update_models) # Dynamically add Python mixin to wrapped class CurrentSource.__bases__ += (CurrentSourceMixin,) +CustomConnectivityUpdate.__bases__ += (CustomConnectivityUpdateMixin,) CustomUpdate.__bases__ += (CustomUpdateMixin,) CustomUpdateWU.__bases__ += (CustomUpdateMixin,) NeuronGroup.__bases__ += (NeuronGroupMixin,)
@@ -190,6 +193,7 @@ def __init__(self, precision="float", model_name="GeNNModel", self.neuron_populations = {} self.synapse_populations = {} self.current_sources = {} + self.custom_connectivity_updates = {} self.custom_updates = {} # Build dictionary containing conversions
@@ -489,6 +493,69 @@ class derived from c_update._init_group(self, var_space) self.custom_updates[cu_name] = c_update return c_update + + def add_custom_connectivity_update(self, cu_name, group_name, syn_group, + custom_conn_update_model, + param_space, var_space, pre_var_space, + post_var_space, var_ref_space, + pre_var_ref_space, post_var_ref_space): + """Add a custom connectivity update to the GeNN model + + Args: + cu_name -- name of the new custom connectivity update + group_name -- name of group this custom connectivity update + should be performed within + syn_group -- synapse group this custom connectivity + update should be attached to (either + name or SynapseGroup object) + custom_conn_update_model -- type of the CustomConnectivityUpdateModel + class as string or instance of + CustomConnectivityUpdateModel class + derived from + ``pygenn.genn_wrapper.CustomConnectivityUpdateModel.Custom`` + (see also pygenn.genn_model.create_custom_connectivity_update_model) + param_space -- dict with param values for the
CustomConnectivityUpdateModel class + var_space -- dict with initial variable values for the + CustomConnectivityUpdateModel class + pre_var_space -- dict with initial presynaptic variable + values for the + CustomConnectivityUpdateModel class + post_var_space -- dict with initial postsynaptic variable + values for the + CustomConnectivityUpdateModel class + var_ref_space -- dict with variable references for the + CustomConnectivityUpdateModel class + pre_var_ref_space -- dict with presynaptic variable + references for the + CustomConnectivityUpdateModel class + post_var_ref_space -- dict with postsynaptic variable + references for the + CustomConnectivityUpdateModel class + """ + if self._built: + raise Exception("GeNN model already built") + + # Resolve custom update model + custom_connectivity_update_model = get_snippet( + custom_conn_update_model, CustomConnectivityUpdateModelBase, + custom_connectivity_update_models) + + # Extract parts of var_space which should be initialised by GeNN + var_init = get_var_init(var_space) + pre_var_init = get_var_init(pre_var_space) + post_var_init = get_var_init(post_var_space) + + # Use superclass to add population + c_update = super(GeNNModel, self).add_custom_connectivity_update( + cu_name, group_name, custom_connectivity_update_model, + param_space, var_init, pre_var_init, post_var_init, + var_ref_space, pre_var_ref_space, post_var_ref_space) + + # Setup back-reference, store group in dictionary and return + c_update._init_group(self, var_space, pre_var_space, post_var_space) + self.custom_connectivity_updates[cu_name] = c_update + return c_update def build(self, path_to_model="./", force_rebuild=False): """Finalize and build a GeNN model @@ -572,6 +639,10 @@ def load(self, path_to_model="./", num_recording_timesteps=None): for src_data in itervalues(self.current_sources): src_data.load_init_egps() + # Loop through custom connectivity updates + for cu_data in itervalues(self.custom_connectivity_updates): + cu_data.load_init_egps() + # Loop through custom updates for cu_data in itervalues(self.custom_updates): cu_data.load_init_egps() @@ -591,6 +662,10 @@ def load(self, path_to_model="./", num_recording_timesteps=None): for src_data in itervalues(self.current_sources): src_data.load() + # Loop through custom connectivity updates + for cu_data in itervalues(self.custom_connectivity_updates): + cu_data.load() + # Loop through custom updates for cu_data in itervalues(self.custom_updates): cu_data.load() @@ -607,7 +682,11 @@ def unload(self): # Loop through custom updates and unload for cu_data in itervalues(self.custom_updates): cu_data.unload() - + + # Loop through custom connectivity updates and unload + for cu_data in itervalues(self.custom_connectivity_updates): + cu_data.unload() + # Loop through current sources and unload for src_data in itervalues(self.current_sources): src_data.unload() @@ -655,7 +734,8 @@ def pull_recording_buffers_from_device(self): def end(self): """Free memory""" for group in [self.neuron_populations, self.synapse_populations, - self.current_sources, custom_updates]: + self.current_sources, self.custom_connectivity_updates, + self.custom_updates]: for g_name, g_dat in iteritems(group): for egp_name, egp_dat in iteritems(g_dat.extra_global_params): # if auto allocation is not enabled, let the user care @@ -768,20 +848,20 @@ class as string or instance of class init_toeplitz_connectivity_snippets) return InitToeplitzConnectivitySnippet(init_toeplitz_connect_snippet, param_space) -def create_custom_neuron_class(class_name, 
param_names=None, - var_name_types=None, derived_params=None, - sim_code=None, threshold_condition_code=None, - reset_code=None, support_code=None, - extra_global_params=None, - additional_input_vars=None, - is_auto_refractory_required=None): +def create_neuron_model(class_name, param_names=None, + var_name_types=None, derived_params=None, + sim_code=None, threshold_condition_code=None, + reset_code=None, support_code=None, + extra_global_params=None, + additional_input_vars=None, + is_auto_refractory_required=None): """This helper function creates a custom NeuronModel class. See also: - create_custom_postsynaptic_class - create_custom_weight_update_class - create_custom_current_source_class - create_custom_init_var_snippet_class - create_custom_sparse_connect_init_snippet_class + create_postsynaptic_model + create_weight_update_model + create_current_source_model + create_init_var_snippet + create_sparse_connect_init_snippet Args: class_name -- name of the new class @@ -831,22 +911,22 @@ def create_custom_neuron_class(class_name, param_names=None, body["is_auto_refractory_required"] = \ lambda self: is_auto_refractory_required - return create_custom_model_class( - class_name, NeuronModelBase, param_names, - var_name_types, derived_params, extra_global_params, body) + return create_model(class_name, NeuronModelBase, param_names, + var_name_types, derived_params, + extra_global_params, body) -def create_custom_postsynaptic_class(class_name, param_names=None, - var_name_types=None, derived_params=None, - decay_code=None, apply_input_code=None, - support_code=None, extra_global_params=None): +def create_postsynaptic_model(class_name, param_names=None, + var_name_types=None, derived_params=None, + decay_code=None, apply_input_code=None, + support_code=None, extra_global_params=None): """This helper function creates a custom PostsynapticModel class. 
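# ----------------------------------------------------------------------------
# [Editorial aside - not part of the patch. A minimal sketch of the renamed
# neuron-model helper in use, based only on the keyword arguments visible in
# the diff above; the code strings are placeholders, since the exact GeNN
# code-string syntax is still being reworked at this point in the series.]
from pygenn import create_neuron_model

toy_neuron = create_neuron_model(
    "ToyNeuron",
    param_names=["Vthresh"],
    var_name_types=[("V", "scalar")],
    sim_code="V += Isyn;",                    # placeholder update code
    threshold_condition_code="V >= Vthresh",  # placeholder threshold
    reset_code="V = 0.0;")                    # placeholder reset
# ----------------------------------------------------------------------------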
See also: - create_custom_neuron_class - create_custom_weight_update_class - create_custom_current_source_class - create_custom_init_var_snippet_class - create_custom_sparse_connect_init_snippet_class + create_neuron_model + create_weight_update_model + create_current_source_model + create_init_var_snippet + create_sparse_connect_init_snippet Args: class_name -- name of the new class @@ -875,41 +955,40 @@ def create_custom_postsynaptic_class(class_name, param_names=None, if support_code is not None: body["get_support_code"] = lambda self: dedent(support_code) - return create_custom_model_class( - class_name, PostsynapticModelBase, param_names, - var_name_types, derived_params, extra_global_params, body) - - -def create_custom_weight_update_class(class_name, param_names=None, - var_name_types=None, - pre_var_name_types=None, - post_var_name_types=None, - derived_params=None, sim_code=None, - event_code=None, learn_post_code=None, - synapse_dynamics_code=None, - event_threshold_condition_code=None, - pre_spike_code=None, - post_spike_code=None, - pre_dynamics_code=None, - post_dynamics_code=None, - sim_support_code=None, - learn_post_support_code=None, - synapse_dynamics_suppport_code=None, - extra_global_params=None, - is_pre_spike_time_required=None, - is_post_spike_time_required=None, - is_pre_spike_event_time_required=None, - is_prev_pre_spike_time_required=None, - is_prev_post_spike_time_required=None, - is_prev_pre_spike_event_time_required=None, - custom_body=None): + return create_model(class_name, PostsynapticModelBase, param_names, + var_name_types, derived_params, + extra_global_params, body) + + +def create_weight_update_model(class_name, param_names=None, + var_name_types=None, + pre_var_name_types=None, + post_var_name_types=None, + derived_params=None, sim_code=None, + event_code=None, learn_post_code=None, + synapse_dynamics_code=None, + event_threshold_condition_code=None, + pre_spike_code=None, + post_spike_code=None, + pre_dynamics_code=None, + post_dynamics_code=None, + sim_support_code=None, + learn_post_support_code=None, + synapse_dynamics_suppport_code=None, + extra_global_params=None, + is_pre_spike_time_required=None, + is_post_spike_time_required=None, + is_pre_spike_event_time_required=None, + is_prev_pre_spike_time_required=None, + is_prev_post_spike_time_required=None, + is_prev_pre_spike_event_time_required=None): """This helper function creates a custom WeightUpdateModel class. 
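# ----------------------------------------------------------------------------
# [Editorial aside - not part of the patch. The weight update helper renamed
# above, sketched under the same caveats; "addToInSyn(g)" merely stands in
# for whatever per-spike output call the code strings actually use here.]
from pygenn import create_weight_update_model

static_pulse = create_weight_update_model(
    "MyStaticPulse",
    var_name_types=[("g", "scalar")],
    sim_code="addToInSyn(g);")  # placeholder per-spike code
# ----------------------------------------------------------------------------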
See also: - create_custom_neuron_class - create_custom_postsynaptic_class - create_custom_current_source_class - create_custom_init_var_snippet_class - create_custom_sparse_connect_init_snippet_class + create_neuron_model + create_postsynaptic_model + create_current_source_model + create_init_var_snippet + create_sparse_connect_init_snippet Args: class_name -- name of the new class @@ -1037,23 +1116,23 @@ def create_custom_weight_update_class(class_name, param_names=None, body["is_prev_pre_spike_event_time_required"] = \ lambda self: is_prev_pre_spike_event_time_required - return create_custom_model_class( - class_name, WeightUpdateModelBase, param_names, - var_name_types, derived_params, extra_global_params, body) + return create_model(class_name, WeightUpdateModelBase, param_names, + var_name_types, derived_params, + extra_global_params, body) -def create_custom_current_source_class(class_name, param_names=None, - var_name_types=None, - derived_params=None, - injection_code=None, - extra_global_params=None): +def create_current_source_model(class_name, param_names=None, + var_name_types=None, + derived_params=None, + injection_code=None, + extra_global_params=None): """This helper function creates a custom NeuronModel class. See also: - create_custom_neuron_class - create_custom_weight_update_class - create_custom_current_source_class - create_custom_init_var_snippet_class - create_custom_sparse_connect_init_snippet_class + create_neuron_model + create_weight_update_model + create_current_source_model + create_init_var_snippet + create_sparse_connect_init_snippet Args: class_name -- name of the new class @@ -1077,25 +1156,24 @@ def create_custom_current_source_class(class_name, param_names=None, if injection_code is not None: body["get_injection_code"] = lambda self: dedent(injection_code) - return create_custom_model_class( - class_name, CurrentSourceModelBase, param_names, - var_name_types, derived_params, CurrentSourceModels, body) + return create_model(class_name, CurrentSourceModelBase, param_names, + var_name_types, derived_params, body) -def create_custom_custom_update_class(class_name, param_names=None, - var_name_types=None, - derived_params=None, - var_refs=None, - update_code=None, - extra_global_params=None, - extra_global_param_refs=None,): +def create_custom_update_model(class_name, param_names=None, + var_name_types=None, + derived_params=None, + var_refs=None, + update_code=None, + extra_global_params=None, + extra_global_param_refs=None): """This helper function creates a custom CustomUpdate class. See also: - create_custom_neuron_class - create_custom_weight_update_class - create_custom_current_source_class - create_custom_init_var_snippet_class - create_custom_sparse_connect_init_snippet_class + create_neuron_model + create_weight_update_model + create_current_source_model + create_init_var_snippet + create_sparse_connect_init Args: class_name -- name of the new class @@ -1132,20 +1210,101 @@ def create_custom_custom_update_class(class_name, param_names=None, class_name, CustomUpdateModelBase, param_names, var_name_types, derived_params, extra_global_params, body) +def create_custom_connectivity_update_model(class_name, + param_names=None, + var_name_types=None, + pre_var_name_types=None, + post_var_name_types=None, + derived_params=None, + var_refs=None, + pre_var_refs=None, + post_var_refs=None, + row_update_code=None, + host_update_code=None, + extra_global_params=None): + """This helper function creates a custom CustomConnectivityUpdate class. 
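# ----------------------------------------------------------------------------
# [Editorial aside - not part of the patch. The custom update helper renamed
# above, sketched with a single variable reference and a placeholder update;
# var_refs tuples follow the (name, type) form described in the docstring.]
from pygenn import create_custom_update_model

reset_v = create_custom_update_model(
    "ResetV",
    var_refs=[("V", "scalar")],  # reference to a variable in another group
    update_code="V = 0.0;")      # placeholder update code
# ----------------------------------------------------------------------------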
+ See also: + create_neuron_model + create_weight_update_model + create_current_source_model + create_init_var_snippet + create_init_var_snippet + create_sparse_connect_init + + Args: + class_name -- name of the new class + + Keyword args: + param_names -- list of strings with param names of the model + var_name_types -- list of tuples of strings with variable names and + types of the variable + pre_var_name_types -- list of tuples of strings with variable names and + types of the variable + var_name_types -- list of tuples of strings with variable names and + types of the variable + derived_params -- list of tuples, where the first member is string + with name of the derived parameter and the second + should be a functor returned by create_dpf_class + var_refs -- list of tuples of strings with variable names and + types of variabled variable + update_code -- string with the current injection code + extra_global_params -- list of pairs of strings with names and types of + additional parameters + """ + body = {} + + if row_update_code is not None: + body["get_row_update_code"] = lambda self: dedent(row_update_code) + + if host_update_code is not None: + body["get_host_update_code"] = lambda self: dedent(host_update_code) + + if pre_var_name_types is not None: + body["get_pre_vars"] = \ + lambda self: VarVector([Var(*vn) + for vn in pre_var_name_types]) + + if post_var_name_types is not None: + body["get_post_vars"] = \ + lambda self: VarVector([Var(*vn) + for vn in post_var_name_types]) + + if var_refs is not None: + body["get_var_refs"] = \ + lambda self: VarRefVector([VarRef(*v) + for v in var_refs]) + + if pre_var_refs is not None: + body["get_pre_var_refs"] = \ + lambda self: VarRefVector([VarRef(*v) + for v in pre_var_refs]) + + if post_var_refs is not None: + body["get_post_var_refs"] = \ + lambda self: VarRefVector([VarRef(*v) + for v in post_var_refs]) + + if custom_body is not None: + body.update(custom_body) + + return create_model(class_name, CustomConnectivityUpdateModelBase, + param_names, var_name_types, derived_params, + extra_global_params, body) -def create_custom_model_class(class_name, base, param_names, var_name_types, - derived_params, extra_global_params, custom_body): + +def create_model(class_name, base, param_names, var_name_types, + derived_params, extra_global_params, custom_body): """This helper function completes a custom model class creation. This part is common for all model classes and is nearly useless on its own unless you specify custom_body. See also: - create_custom_neuron_class - create_custom_weight_update_class - create_custom_postsynaptic_class - create_custom_current_source_class - create_custom_init_var_snippet_class - create_custom_sparse_connect_init_snippet_class + create_neuron_model + create_weight_update_model + create_postsynaptic_model + create_current_source_model + create_init_var_snippet + create_sparse_connect_init_snippet Args: class_name -- name of the new class @@ -1191,17 +1350,17 @@ def ctor(self): return type(class_name, (base,), body)() -def create_custom_init_var_snippet_class(class_name, param_names=None, - derived_params=None, - var_init_code=None, - extra_global_params=None): +def create_init_var_snippet(class_name, param_names=None, + derived_params=None, + var_init_code=None, + extra_global_params=None): """This helper function creates a custom InitVarSnippet class. 
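# ----------------------------------------------------------------------------
# [Editorial aside - not part of the patch. The new custom connectivity
# update helper defined above, sketched with hypothetical names; the row
# update code string is a placeholder, as its semantics are not shown in
# this patch.]
from pygenn import create_custom_connectivity_update_model

remove_synapse = create_custom_connectivity_update_model(
    "RemoveSynapse",
    var_name_types=[("a", "scalar")],
    row_update_code="// placeholder row update code")
# ----------------------------------------------------------------------------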
See also: - create_custom_neuron_class - create_custom_weight_update_class - create_custom_postsynaptic_class - create_custom_current_source_class - create_custom_sparse_connect_init_snippet_class + create_neuron_model + create_weight_update_model + create_postsynaptic_model + create_current_source_model + create_sparse_connect_init_snippet Args: class_name -- name of the new class
@@ -1221,30 +1380,30 @@ def create_custom_init_var_snippet_class(class_name, param_names=None, if var_init_code is not None: body["get_code"] = lambda self: dedent(var_init_code) - return create_custom_model_class( - class_name, genn_wrapper.InitVarSnippet.Custom, param_names, - None, derived_params, extra_global_params, body) - - -def create_custom_sparse_connect_init_snippet_class(class_name, - param_names=None, - derived_params=None, - row_build_code=None, - row_build_state_vars=None, - col_build_code=None, - col_build_state_vars=None, - calc_max_row_len_func=None, - calc_max_col_len_func=None, - calc_kernel_size_func=None, - extra_global_params=None): + return create_model(class_name, InitVarSnippetBase, + param_names, None, derived_params, + extra_global_params, body) + + +def create_sparse_connect_init_snippet(class_name, + param_names=None, + derived_params=None, + row_build_code=None, + row_build_state_vars=None, + col_build_code=None, + col_build_state_vars=None, + calc_max_row_len_func=None, + calc_max_col_len_func=None, + calc_kernel_size_func=None, + extra_global_params=None): """This helper function creates a custom InitSparseConnectivitySnippet class. See also: - create_custom_neuron_class - create_custom_weight_update_class - create_custom_postsynaptic_class - create_custom_current_source_class - create_custom_init_var_snippet_class + create_neuron_model + create_weight_update_model + create_postsynaptic_model + create_current_source_model + create_init_var_snippet Args: class_name -- name of the new class
@@ -1304,9 +1463,8 @@ def create_custom_sparse_connect_init_snippet_class(class_name, body["get_calc_kernel_size_func"] = \ lambda self: make_cksf(calc_kernel_size_func) - return create_custom_model_class( - class_name, genn_wrapper.InitSparseConnectivitySnippet.Custom, param_names, - None, derived_params, extra_global_params, body) + return create_model(class_name, InitSparseConnectivitySnippetBase, param_names, + None, derived_params, extra_global_params, body) @deprecated("this wrapper is now unnecessary - use callables directly") def create_dpf_class(dp_func):
diff --git a/pygenn/src/customConnectivityUpdateModels.cc b/pygenn/src/customConnectivityUpdateModels.cc new file mode 100644 index 0000000000..189bdbb58e --- /dev/null +++ b/pygenn/src/customConnectivityUpdateModels.cc
@@ -0,0 +1,31 @@ +// PyBind11 includes +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +// GeNN includes +#include "customConnectivityUpdateModels.h" + +using namespace GeNN::CustomConnectivityUpdateModels; + +namespace +{ +template<typename T> +const Base *getBaseInstance() +{ + return static_cast<const Base*>(T::getInstance()); +} +} + +//---------------------------------------------------------------------------- +// custom_connectivity_update_models +//---------------------------------------------------------------------------- +PYBIND11_MODULE(custom_connectivity_update_models, m) +{ + pybind11::module_::import("pygenn.genn"); + + //------------------------------------------------------------------------ + // Free functions + //------------------------------------------------------------------------ + // **THINK** with some cunning, standard macros could maybe
populate + // an array with instance pointers that we could loop over +} diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc index 3b8e444946..0be90e5b0e 100644 --- a/pygenn/src/genn.cc +++ b/pygenn/src/genn.cc @@ -10,6 +10,8 @@ // GeNN includes #include "currentSource.h" #include "currentSourceModels.h" +#include "customConnectivityUpdate.h" +#include "customConnectivityUpdateModels.h" #include "customUpdate.h" #include "customUpdateModels.h" #include "initSparseConnectivitySnippet.h" @@ -112,6 +114,25 @@ class PyCurrentSourceModelBase : public PyModel virtual std::string getInjectionCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_injection_code", getInjectionCode); } }; +//---------------------------------------------------------------------------- +// PyCustomConnectivityUpdateModelBase +//---------------------------------------------------------------------------- +// 'Trampoline' class for custom connectivity update models +class PyCustomConnectivityUpdateModelBase : public PyModel +{ + using Base = CustomConnectivityUpdateModels::Base; +public: + virtual VarVec getPreVars() const override { PYBIND11_OVERRIDE_NAME(Models::Base::VarVec, Base, "get_pre_vars", getPreVars); } + virtual VarVec getPostVars() const override { PYBIND11_OVERRIDE_NAME(Models::Base::VarVec, Base, "get_post_vars", getPostVars); } + + virtual VarRefVec getVarRefs() const override { PYBIND11_OVERRIDE_NAME(VarRefVec, Base, "get_var_refs", getVarRefs); } + virtual VarRefVec getPreVarRefs() const override { PYBIND11_OVERRIDE_NAME(VarRefVec, Base, "get_pre_var_refs", getPreVarRefs); } + virtual VarRefVec getPostVarRefs() const override { PYBIND11_OVERRIDE_NAME(VarRefVec, Base, "get_post_var_refs", getPostVarRefs); } + + virtual std::string getRowUpdateCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_row_update_code", getRowUpdateCode); } + virtual std::string getHostUpdateCode() const override { PYBIND11_OVERRIDE_NAME(std::string, Base, "get_host_update_code", getHostUpdateCode); } +}; + //---------------------------------------------------------------------------- // PyCustomUpdateModelBase //---------------------------------------------------------------------------- @@ -371,6 +392,12 @@ PYBIND11_MODULE(genn, m) const std::string&, const CurrentSourceModels::Base*, const std::string&, const ParamValues&, const VarValues&)>(&ModelSpecInternal::addCurrentSource), pybind11::return_value_policy::reference) + .def("add_custom_connectivity_update", + static_cast(&ModelSpecInternal::addCustomConnectivityUpdate), + pybind11::return_value_policy::reference) .def("add_custom_update", static_cast(&CurrentSource::getVarLocation, pybind11::const_)); + //------------------------------------------------------------------------ + // genn.CustomConnectivityUpdate + //------------------------------------------------------------------------ + pybind11::class_(m, "CustomConnectivityUpdate", pybind11::dynamic_attr()) + //-------------------------------------------------------------------- + // Properties + //-------------------------------------------------------------------- + .def_property_readonly("name", &CustomConnectivityUpdate::getName) + .def_property_readonly("update_group_name", &CustomConnectivityUpdate::getUpdateGroupName) + .def_property_readonly("model", &CustomConnectivityUpdate::getCustomConnectivityUpdateModel, pybind11::return_value_policy::reference) + .def_property_readonly("params", &CustomConnectivityUpdate::getParams) + + .def_property_readonly("var_initialisers", 
&CustomConnectivityUpdate::getVarInitialisers) + .def_property_readonly("pre_var_initialisers", &CustomConnectivityUpdate::getPreVarInitialisers) + .def_property_readonly("post_var_initialisers", &CustomConnectivityUpdate::getPostVarInitialisers) + + .def_property_readonly("var_references", &CustomConnectivityUpdate::getVarReferences) + .def_property_readonly("pre_var_references", &CustomConnectivityUpdate::getPreVarReferences) + .def_property_readonly("post_var_references", &CustomConnectivityUpdate::getPostVarReferences) + + // **NOTE** we use the 'publicist' pattern to expose some protected properties + .def_property_readonly("_synapse_group", &CustomConnectivityUpdateInternal::getSynapseGroup) + + //-------------------------------------------------------------------- + // Methods + //-------------------------------------------------------------------- + .def("set_var_location", &CustomConnectivityUpdate::setVarLocation) + .def("set_pre_var_location", &CustomConnectivityUpdate::setPreVarLocation) + .def("set_post_var_location", &CustomConnectivityUpdate::setPostVarLocation) + .def("get_var_location", &CustomConnectivityUpdate::getVarLocation) + .def("get_pre_var_location", &CustomConnectivityUpdate::getPreVarLocation) + .def("get_post_var_location", &CustomConnectivityUpdate::getPostVarLocation); + + //------------------------------------------------------------------------ // genn.CustomUpdateBase //------------------------------------------------------------------------ @@ -438,7 +499,7 @@ PYBIND11_MODULE(genn, m) // Methods //-------------------------------------------------------------------- .def("set_var_location", &CustomUpdateBase::setVarLocation) - .def("get_var_location", pybind11::overload_cast(&CustomUpdateBase::getVarLocation, pybind11::const_)); + .def("get_var_location", &CustomUpdateBase::getVarLocation); //------------------------------------------------------------------------ // genn.CustomUpdate @@ -448,7 +509,7 @@ PYBIND11_MODULE(genn, m) .def_property_readonly("var_references", &CustomUpdate::getVarReferences); //------------------------------------------------------------------------ - // genn.CustomUpdate + // genn.CustomUpdateWU //------------------------------------------------------------------------ pybind11::class_(m, "CustomUpdateWU", pybind11::dynamic_attr()) .def_property_readonly("var_references", &CustomUpdateWU::getVarReferences); @@ -509,7 +570,7 @@ PYBIND11_MODULE(genn, m) .def_property("num_threads_per_spike",&SynapseGroup::getNumThreadsPerSpike, &SynapseGroup::setNumThreadsPerSpike) .def_property("back_prop_delay_steps",&SynapseGroup::getBackPropDelaySteps, &SynapseGroup::setBackPropDelaySteps) .def_property("narrow_sparse_ind_enabled",nullptr, &SynapseGroup::setNarrowSparseIndEnabled) - // **NOTE** we use the 'publicist' pattern to expose some protected properties + // **NOTE** we use the 'publicist' pattern to expose some protected properties .def_property_readonly("_ps_model_fused", &SynapseGroupInternal::isPSModelFused) .def_property_readonly("_wu_pre_model_fused", &SynapseGroupInternal::isWUPreModelFused) .def_property_readonly("_wu_post_model_fused", &SynapseGroupInternal::isWUPostModelFused) @@ -662,6 +723,22 @@ PYBIND11_MODULE(genn, m) .def(pybind11::init<>()) .def("get_injection_code", &CurrentSourceModels::Base::getInjectionCode); + + //------------------------------------------------------------------------ + // genn.CustomConnectivityUpdateModelBase + //------------------------------------------------------------------------ + 
pybind11::class_(m, "CustomConnectivityUpdateModelBase") + .def(pybind11::init<>()) + + .def("get_pre_vars", &CustomConnectivityUpdateModels::Base::getPreVars) + .def("get_post_vars", &CustomConnectivityUpdateModels::Base::getPostVars) + + .def("get_var_refs", &CustomConnectivityUpdateModels::Base::getVarRefs) + .def("get_pre_var_refs", &CustomConnectivityUpdateModels::Base::getPreVarRefs) + .def("get_post_var_refs", &CustomConnectivityUpdateModels::Base::getPostVarRefs) + + .def("get_row_update_code", &CustomConnectivityUpdateModels::Base::getRowUpdateCode) + .def("get_host_update_code", &CustomConnectivityUpdateModels::Base::getHostUpdateCode); //------------------------------------------------------------------------ // genn.CustomUpdateModelBase diff --git a/setup.py b/setup.py index 18bcf8f1b2..51a8fbbbbe 100644 --- a/setup.py +++ b/setup.py @@ -143,6 +143,9 @@ Pybind11Extension("current_source_models", [os.path.join(pygenn_src, "currentSourceModels.cc")], **genn_extension_kwargs), + Pybind11Extension("custom_connectivity_update_models", + [os.path.join(pygenn_src, "customConnectivityUpdateModels.cc")], + **genn_extension_kwargs), Pybind11Extension("custom_update_models", [os.path.join(pygenn_src, "customUpdateModels.cc")], **genn_extension_kwargs), From 354f33d6c743c9d1d79318789baa88d7e80ee3da Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 15:09:54 +0100 Subject: [PATCH 445/725] fixed a nasty bug involving ISyn --- src/genn/genn/code_generator/neuronUpdateGroupMerged.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index e2a8b5a0a5..71819b71f4 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -114,7 +114,9 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env psmEnv.add(getScalarType(), "inSyn", "linSyn"); // Allow synapse group's PS output var to override what Isyn points to - psmEnv.add(getScalarType(), "Isyn", getArchetype().getPSTargetVar()); + if(getArchetype().getPSTargetVar() != "Isyn") { + psmEnv.add(getScalarType(), "Isyn", "$(" + getArchetype().getPSTargetVar() + ")"); + } // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( From d7ca4582cc9f58b85f8cd9ebacb77aaab0b7f81a Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 17:23:42 +0100 Subject: [PATCH 446/725] fixed a lot of small things and implemented first feature test --- include/genn/genn/synapseGroup.h | 1 + pygenn/genn_groups.py | 10 +- pygenn/genn_model.py | 31 ++--- pygenn/src/genn.cc | 3 - setup.py | 23 +++- .../backends/single_threaded_cpu/backend.cc | 16 +-- src/genn/genn/code_generator/backendBase.cc | 16 +++ .../customConnectivityUpdateGroupMerged.cc | 14 ++- tests/features/test_sim_rng.py | 112 ++++++++++++++++++ 9 files changed, 180 insertions(+), 46 deletions(-) create mode 100644 tests/features/test_sim_rng.py diff --git a/include/genn/genn/synapseGroup.h b/include/genn/genn/synapseGroup.h index 443e7e6257..2a89f796d3 100644 --- a/include/genn/genn/synapseGroup.h +++ b/include/genn/genn/synapseGroup.h @@ -172,6 +172,7 @@ class GENN_EXPORT SynapseGroup //! Get location of weight update model per-synapse state variable by name VarLocation getWUVarLocation(const std::string &var) const; + //! 
Get location of weight update model presynaptic state variable by name VarLocation getWUPreVarLocation(const std::string &var) const; diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index 00d49b3ef2..76f4ad02c9 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -894,7 +894,8 @@ def _synapse_group(self): class CustomConnectivityUpdateMixin(GroupMixin): """Class representing a custom connectivity update""" - def _init_group(self, model, var_space, pre_var_space, post_var_space): + def _init_group(self, model, var_space, pre_var_space, + post_var_space, synapse_group): """Init CustomConnectivityUpdateGroup Args: @@ -902,6 +903,7 @@ def _init_group(self, model, var_space, pre_var_space, post_var_space): model -- pygenn.genn_model.GeNNModel this neuron group is part of """ super(CustomConnectivityUpdateMixin, self)._init_group(model) + self.synapse_group = synapse_group self.vars, self.extra_global_params = prepare_model( self.model, self, var_space) self.pre_vars = {vnt.name: Variable(vnt.name, vnt.type, @@ -933,9 +935,11 @@ def load(self): self._load_egp(var_data.extra_global_params, v.name) # Load pre and postsynaptic variables - self._load_vars(self.model.get_pre_vars(), self.src.size, + self._load_vars(self.model.get_pre_vars(), + self.synapse_group.src.size, self.pre_vars, self.get_pre_var_location) - self._load_vars(self.model.get_post_vars(), self.trg.size, + self._load_vars(self.model.get_post_vars(), + self.synapse_group.trg.size, self.post_vars, self.get_post_var_location) # Load custom update extra global parameters diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index 4761b02a35..966adada8c 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -547,13 +547,15 @@ class as string or instance of post_var_init = get_var_init(post_var_space) # Use superclass to add population + syn_group = self._validate_synapse_group(syn_group, "syn_group") c_update = super(GeNNModel, self).add_custom_connectivity_update( - cu_name, group_name, custom_connectivity_update_model, + cu_name, group_name, syn_group.name, custom_connectivity_update_model, param_space, var_init, pre_var_init, post_var_init, var_ref_space, pre_var_ref_space, post_var_ref_space) # Setup back-reference, store group in dictionary and return - c_update._init_group(self, var_space, pre_var_space, post_var_space) + c_update._init_group(self, var_space, pre_var_space, + post_var_space, syn_group) self.custom_connectivity_updates[cu_name] = c_update return c_update @@ -1148,16 +1150,14 @@ def create_current_source_model(class_name, param_names=None, extra_global_params -- list of pairs of strings with names and types of additional parameters """ - if not isinstance(custom_body, dict) and custom_body is not None: - raise ValueError("custom_body must be an instance of dict or None") - body = {} if injection_code is not None: body["get_injection_code"] = lambda self: dedent(injection_code) return create_model(class_name, CurrentSourceModelBase, param_names, - var_name_types, derived_params, body) + var_name_types, derived_params, + extra_global_params, body) def create_custom_update_model(class_name, param_names=None, @@ -1261,31 +1261,22 @@ def create_custom_connectivity_update_model(class_name, if pre_var_name_types is not None: body["get_pre_vars"] = \ - lambda self: VarVector([Var(*vn) - for vn in pre_var_name_types]) + lambda self: [Var(*vn) for vn in pre_var_name_types] if post_var_name_types is not None: body["get_post_vars"] = \ - lambda self: VarVector([Var(*vn) - for vn in 
post_var_name_types]) + lambda self: [Var(*vn) for vn in post_var_name_types] if var_refs is not None: - body["get_var_refs"] = \ - lambda self: VarRefVector([VarRef(*v) - for v in var_refs]) + body["get_var_refs"] = lambda self: [VarRef(*v) for v in var_refs] if pre_var_refs is not None: body["get_pre_var_refs"] = \ - lambda self: VarRefVector([VarRef(*v) - for v in pre_var_refs]) + lambda self: [VarRef(*v) for v in pre_var_refs] if post_var_refs is not None: body["get_post_var_refs"] = \ - lambda self: VarRefVector([VarRef(*v) - for v in post_var_refs]) - - if custom_body is not None: - body.update(custom_body) + lambda self: [VarRef(*v) for v in post_var_refs] return create_model(class_name, CustomConnectivityUpdateModelBase, param_names, var_name_types, derived_params, diff --git a/pygenn/src/genn.cc b/pygenn/src/genn.cc index 0be90e5b0e..71c3a7e102 100644 --- a/pygenn/src/genn.cc +++ b/pygenn/src/genn.cc @@ -467,9 +467,6 @@ PYBIND11_MODULE(genn, m) .def_property_readonly("var_references", &CustomConnectivityUpdate::getVarReferences) .def_property_readonly("pre_var_references", &CustomConnectivityUpdate::getPreVarReferences) .def_property_readonly("post_var_references", &CustomConnectivityUpdate::getPostVarReferences) - - // **NOTE** we use the 'publicist' pattern to expose some protected properties - .def_property_readonly("_synapse_group", &CustomConnectivityUpdateInternal::getSynapseGroup) //-------------------------------------------------------------------- // Methods diff --git a/setup.py b/setup.py index 51a8fbbbbe..4e258dd117 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ import os import sys from copy import deepcopy -from platform import system +from platform import system, uname from shutil import copytree, rmtree from pybind11.setup_helpers import Pybind11Extension, build_ext, WIN, MACOS from setuptools import find_packages, setup @@ -26,6 +26,12 @@ # **NOTE** Pybind11Extension provides WIN and MAC LINUX = system() == "Linux" +# Are we on WSL? 
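+# (the kernel release string mentions Microsoft on WSL; platform.uname()
+# only became a named tuple in Python 3.3, hence the indexed fallback below)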
+if sys.version_info < (3, 3): + WSL = "microsoft" in uname()[2] +else: + WSL = "microsoft" in uname().release + # Determine correct suffix for GeNN libraries if WIN: genn_lib_suffix = "_Debug_DLL" if debug_build else "_Release_DLL" @@ -87,20 +93,25 @@ # If CUDA was found, add backend configuration if cuda_installed: # Get CUDA library directory + cuda_library_dirs = [] if MACOS: - cuda_library_dir = os.path.join(cuda_path, "lib") + cuda_library_dirs.append(os.path.join(cuda_path, "lib")) elif WIN: - cuda_library_dir = os.path.join(cuda_path, "lib", "x64") + cuda_library_dirs.append(os.path.join(cuda_path, "lib", "x64")) else: - cuda_library_dir = os.path.join(cuda_path, "lib64") + cuda_library_dirs.append(os.path.join(cuda_path, "lib64")) + + # If we're running on WSL, add additional library path so libcuda can be found + if WSL: + cuda_library_dirs.append("/usr/lib/wsl/lib") # Add backend # **NOTE** on Mac OS X, a)runtime_library_dirs doesn't work b)setting rpath is required to find CUDA backends.append(("cuda", "cuda", {"libraries": ["cuda", "cudart"], "include_dirs": [os.path.join(cuda_path, "include")], - "library_dirs": [cuda_library_dir], - "extra_link_args": ["-Wl,-rpath," + cuda_library_dir] if MACOS else []})) + "library_dirs": cuda_library_dirs, + "extra_link_args": ["-Wl,-rpath," + cuda_library_dirs[0]] if MACOS else []})) # If OpenCL was found, add backend configuration if opencl_installed: diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 99e8a9ad16..9978dc0233 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -577,7 +577,6 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, c); - buildStandardEnvironment(groupEnv); if (c.getArchetype().isNeuronReduction()) { @@ -724,21 +723,22 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back // Get reference to group funcEnv.getStream() << "const auto *group = &mergedCustomConnectivityUpdateGroup" << c.getIndex() << "[g]; " << std::endl; - // Create matching environment - EnvironmentGroupMergedField groupEnv(funcEnv, c); + // Add host RNG functions + EnvironmentLibrary rngEnv(funcEnv, StandardLibrary::getHostRNGFunctions(c.getScalarType())); + // Create matching environment + EnvironmentGroupMergedField groupEnv(rngEnv, c); buildStandardEnvironment(groupEnv); - + // Loop through presynaptic neurons - funcEnv.getStream() << "for(unsigned int i = 0; i < " << funcEnv["num_pre"] << "; i++)"; + groupEnv.print("for(unsigned int i = 0; i < $(num_pre); i++)"); { - CodeStream::Scope b(funcEnv.getStream()); + CodeStream::Scope b(groupEnv.getStream()); // Configure substitutions groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); - assert(false); - //c.generateUpdate(*this, cuEnv, model.getBatchSize()); + c.generateUpdate(*this, groupEnv, 1); } } }); diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 8950739d6a..689deae254 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -380,6 +380,22 @@ void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const { + // Add fields for number of pre and postsynaptic neurons + env.addField(Type::Uint32.addConst(), "num_pre", + Type::Uint32, "numSrcNeurons", + [](const auto &cg, size_t) + { + const 
SynapseGroupInternal *sgInternal = static_cast<const SynapseGroupInternal*>(cg.getSynapseGroup());
+                     return std::to_string(sgInternal->getSrcNeuronGroup()->getNumNeurons());
+                 });
+    env.addField(Type::Uint32.addConst(), "num_post",
+                 Type::Uint32, "numTrgNeurons",
+                 [](const auto &cg, size_t)
+                 {
+                     const SynapseGroupInternal *sgInternal = static_cast<const SynapseGroupInternal*>(cg.getSynapseGroup());
+                     return std::to_string(sgInternal->getTrgNeuronGroup()->getNumNeurons());
+                 });
+
     // If there are delays on presynaptic variable references
     if(env.getGroup().getArchetype().getPreDelayNeuronGroup() != nullptr) {
         env.add(Type::Uint32.addConst(), "_pre_delay_offset", "preDelayOffset",
diff --git a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc
index 7e2cadb186..76b84c7746 100644
--- a/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc
+++ b/src/genn/genn/code_generator/customConnectivityUpdateGroupMerged.cc
@@ -5,6 +5,7 @@
 
 // GeNN code generator includes
 #include "code_generator/modelSpecMerged.h"
+#include "code_generator/standardLibrary.h"
 
 // GeNN transpiler includes
 #include "transpiler/errorHandler.h"
@@ -437,18 +438,19 @@ const std::string CustomConnectivityHostUpdateGroupMerged::name = "CustomConnect
 //----------------------------------------------------------------------------
 void CustomConnectivityHostUpdateGroupMerged::generateUpdate(const BackendBase &backend, EnvironmentExternalBase &env)
 {
-    CodeStream::Scope b(env.getStream());
+    EnvironmentLibrary rngEnv(env, StandardLibrary::getHostRNGFunctions(getScalarType()));
+    CodeStream::Scope b(rngEnv.getStream());
 
-    env.getStream() << "// merged custom connectivity host update group " << getIndex() << std::endl;
-    env.getStream() << "for(unsigned int g = 0; g < " << getGroups().size() << "; g++)";
+    rngEnv.getStream() << "// merged custom connectivity host update group " << getIndex() << std::endl;
+    rngEnv.getStream() << "for(unsigned int g = 0; g < " << getGroups().size() << "; g++)";
     {
-        CodeStream::Scope b(env.getStream());
+        CodeStream::Scope b(rngEnv.getStream());
 
         // Get reference to group
-        env.getStream() << "const auto *group = &mergedCustomConnectivityHostUpdateGroup" << getIndex() << "[g]; " << std::endl;
+        rngEnv.getStream() << "const auto *group = &mergedCustomConnectivityHostUpdateGroup" << getIndex() << "[g]; " << std::endl;
 
         // Create matching environment
-        EnvironmentGroupMergedField<CustomConnectivityHostUpdateGroupMerged> groupEnv(env, *this);
+        EnvironmentGroupMergedField<CustomConnectivityHostUpdateGroupMerged> groupEnv(rngEnv, *this);
 
         // Add fields for number of pre and postsynaptic neurons
         groupEnv.addField(Type::Uint32.addConst(), "num_pre",
diff --git a/tests/features/test_sim_rng.py b/tests/features/test_sim_rng.py
new file mode 100644
index 0000000000..66c63a2d8b
--- /dev/null
+++ b/tests/features/test_sim_rng.py
@@ -0,0 +1,112 @@
+import numpy as np
+import pytest
+from pygenn import types
+from scipy import stats
+
+from pygenn import GeNNModel
+
+from pygenn import (create_current_source_model,
+                    create_custom_connectivity_update_model,
+                    create_neuron_model,
+                    init_sparse_connectivity)
+
+neuron_model = create_neuron_model(
+    "neuron",
+    sim_code=
+    """
+    uniform = gennrand_uniform();
+    normal = gennrand_normal();
+    """,
+    var_name_types=[("uniform", "scalar"), ("normal", "scalar")])
+
+current_source_model = create_current_source_model(
+    "current_source",
+    injection_code=
+    """
+    uniform = gennrand_uniform();
+    normal = gennrand_normal();
+    injectCurrent(0.0);
+    """,
+    var_name_types=[("uniform", "scalar"), ("normal", "scalar")])
+
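+# NOTE: the model below exercises both RNG streams: row_update_code draws
+# from the per-row device RNG while host_update_code draws from the host RNG
+# and pushes the results to the device by hand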
+custom_connectivity_update_model = create_custom_connectivity_update_model(
+    "custom_connectivity_update",
+    row_update_code=
+    """
+    preUniform = gennrand_uniform();
+    preNormal = gennrand_normal();
+    """,
+    host_update_code=
+    """
+    for(int i = 0; i < num_post; i++) {
+        postUniform[i] = gennrand_uniform();
+        postNormal[i] = gennrand_normal();
+    }
+    pushpostUniformToDevice();
+    pushpostNormalToDevice();
+    """,
+    pre_var_name_types=[("preUniform", "scalar"), ("preNormal", "scalar")],
+    post_var_name_types=[("postUniform", "scalar"), ("postNormal", "scalar")])
+
+@pytest.mark.parametrize("backend", ["single_threaded_cpu", "cuda"])
+@pytest.mark.parametrize("precision", [types.Double, types.Float])
+def test_sim_rng(backend, precision):
+    model = GeNNModel(precision, "test_sim_rng", backend=backend)
+
+    # Add neuron and current source populations
+    var_init = {"uniform": 0.0, "normal": 0.0}
+    n_pop = model.add_neuron_population("Neurons", 1000, neuron_model,
+                                        {}, var_init)
+    cs_pop = model.add_current_source("CurrentSource",
+                                      current_source_model, n_pop,
+                                      {}, var_init)
+
+    # Add second neuron and synapse population to hang custom connectivity update off
+    post_n_pop = model.add_neuron_population("PostNeurons", 1000, "SpikeSource",
+                                             {}, {})
+    s_pop = model.add_synapse_population("Synapses", "SPARSE", 0,
+                                         n_pop, post_n_pop,
+                                         "StaticPulseConstantWeight", {"g": 1.0}, {}, {}, {},
+                                         "DeltaCurr", {}, {},
+                                         init_sparse_connectivity("FixedProbability", {"prob": 0.1}))
+
+    # Add custom connectivity update
+    ccu = model.add_custom_connectivity_update(
+        "CustomConnectivityUpdate", "Connectivity", s_pop,
+        custom_connectivity_update_model, {}, {}, {"preUniform": 0.0, "preNormal": 0.0},
+        {"postUniform": 0.0, "postNormal": 0.0}, {}, {}, {})
+
+    # Build model and load
+    model.build()
+    model.load()
+
+    # Run for 1000 timesteps
+    samples = [
+        (n_pop, "uniform", n_pop.vars, stats.uniform.cdf, np.empty((1000, 1000))),
+        (n_pop, "normal", n_pop.vars, stats.norm.cdf, np.empty((1000, 1000))),
+        (cs_pop, "uniform", cs_pop.vars, stats.uniform.cdf, np.empty((1000, 1000))),
+        (cs_pop, "normal", cs_pop.vars, stats.norm.cdf, np.empty((1000, 1000))),
+        (ccu, "preUniform", ccu.pre_vars, stats.uniform.cdf, np.empty((1000, 1000))),
+        (ccu, "preNormal", ccu.pre_vars, stats.norm.cdf, np.empty((1000, 1000))),
+        (ccu, "postUniform", ccu.post_vars, stats.uniform.cdf, np.empty((1000, 1000))),
+        (ccu, "postNormal", ccu.post_vars, stats.norm.cdf, np.empty((1000, 1000)))]
+    while model.timestep < 1000:
+        model.step_time()
+        model.custom_update("Connectivity")
+
+        # Loop through samples
+        for pop, var_name, vars, _, data in samples:
+            pop.pull_var_from_device(var_name)
+
+            # Copy data into array
+            data[model.timestep - 1,:] = vars[var_name].view[:]
+
+    # Check all p-values exceed the 95% confidence threshold
+    for pop, var_name, _, cdf, data in samples:
+        p = stats.kstest(data.flatten(), cdf).pvalue
+        if p < 0.05:
+            assert False, f"{pop.name} '{var_name}' simulation fails KS test (p={p})"
+
+
+if __name__ == '__main__':
+    test_sim_rng("single_threaded_cpu", types.Double)
\ No newline at end of file

From bc765491aab8d43743036560f5a8046bf4f3dd89 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 25 Jul 2023 17:46:23 +0100
Subject: [PATCH 447/725] renamed feature test for consistency

---
 tests/features/{test_sim_rng.py => test_rng_sim.py} | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
 rename tests/features/{test_sim_rng.py => test_rng_sim.py} (97%)

diff --git a/tests/features/test_sim_rng.py 
b/tests/features/test_rng_sim.py similarity index 97% rename from tests/features/test_sim_rng.py rename to tests/features/test_rng_sim.py index 66c63a2d8b..9863f7a559 100644 --- a/tests/features/test_sim_rng.py +++ b/tests/features/test_rng_sim.py @@ -50,8 +50,8 @@ @pytest.mark.parametrize("backend", ["single_threaded_cpu", "cuda"]) @pytest.mark.parametrize("precision", [types.Double, types.Float]) -def test_sim_rng(backend, precision): - model = GeNNModel(precision, "test_sim_rng", backend=backend) +def test_rng_sim(backend, precision): + model = GeNNModel(precision, "test_rng_sim", backend=backend) # Add neuron and current source populations var_init = {"uniform": 0.0, "normal": 0.0} From a6ce01ce9040f248de06e788b55b65ea3aea4056 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 18:17:22 +0100 Subject: [PATCH 448/725] fixed a couple of typos --- pygenn/genn_model.py | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index 966adada8c..f111d99522 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -1380,9 +1380,7 @@ def create_sparse_connect_init_snippet(class_name, param_names=None, derived_params=None, row_build_code=None, - row_build_state_vars=None, - rol_build_code=None, - col_build_state_vars=None, + col_build_code=None, calc_max_row_len_func=None, calc_max_col_len_func=None, calc_kernel_size_func=None, @@ -1406,13 +1404,8 @@ def create_sparse_connect_init_snippet(class_name, second MUST be an instance of the class which inherits from pygenn.genn_wrapper.DerivedParamFunc row_build_code -- string with row building initialization code - row_build_state_vars -- list of tuples of state variables, their types - and their initial values to use across - row building loop col_build_code -- string with column building initialization code - col_build_state_vars -- list of tuples of state variables, their types - and their initial values to use across - column building loop + calc_max_row_len_func -- instance of class inheriting from CalcMaxLengthFunc used to calculate maximum row length of synaptic matrix @@ -1430,19 +1423,9 @@ def create_sparse_connect_init_snippet(class_name, if row_build_code is not None: body["get_row_build_code"] = lambda self: dedent(row_build_code) - if row_build_state_vars is not None: - body["get_row_build_state_vars"] = \ - lambda self: ParamValVector([ParamVal(r[0], r[1], r[2]) - for r in row_build_state_vars]) - if col_build_code is not None: body["get_col_build_code"] = lambda self: dedent(col_build_code) - if col_build_state_vars is not None: - body["get_col_build_state_vars"] = \ - lambda self: ParamValVector([ParamVal(r[0], r[1], r[2]) - for r in col_build_state_vars]) - if calc_max_row_len_func is not None: body["get_calc_max_row_length_func"] = \ lambda self: make_cmlf(calc_max_row_len_func) @@ -1450,6 +1433,7 @@ def create_sparse_connect_init_snippet(class_name, if calc_max_col_len_func is not None: body["get_calc_max_col_length_func"] = \ lambda self: make_cmlf(calc_max_col_len_func) + if calc_kernel_size_func is not None: body["get_calc_kernel_size_func"] = \ lambda self: make_cksf(calc_kernel_size_func) From 8f278d9b874a2e9ed32a0e5edd8846cb3f676d74 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 18:17:31 +0100 Subject: [PATCH 449/725] started work on spike propagation test --- tests/features/test_spike_propagation.py | 42 ++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 
tests/features/test_spike_propagation.py diff --git a/tests/features/test_spike_propagation.py b/tests/features/test_spike_propagation.py new file mode 100644 index 0000000000..d597737947 --- /dev/null +++ b/tests/features/test_spike_propagation.py @@ -0,0 +1,42 @@ +import numpy as np +import pytest +from pygenn import types + +from pygenn import GeNNModel + +from pygenn.genn import VarAccess +from pygenn import (create_neuron_model, + create_sparse_connect_init_snippet, + init_sparse_connectivity) + +decoder_model = create_sparse_connect_init_snippet( + "decoder", + row_build_code= + """ + for(unsigned int j = 0; j < num_post; j++) { + const unsigned int jValue = (1 << j); + if(((id_pre + 1) & jValue) != 0) { + addSynapse(j); + } + } + """) + +pre_cont_neuron_model = create_neuron_model( + "pre_cont_neuron", + var_name_types=[("x", "scalar", VarAccess.READ_ONLY)]) + +post_neuron_model = create_neuron_model( + "post_neuron", + sim_code= + """ + x= Isyn; + """, + var_name_types=[("x", "scalar")]) + +@pytest.mark.parametrize("backend", ["single_threaded_cpu", "cuda"]) +@pytest.mark.parametrize("precision", [types.Double, types.Float]) +def test_spike_propagation(backend, precision): + model = GeNNModel(precision, "test_spike_propagation", backend=backend) + +if __name__ == '__main__': + test_spike_propagation("single_threaded_cpu", types.Float) \ No newline at end of file From aa340cc72256b6a3ddc070565173aaa1dd646efc Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 20:31:24 +0100 Subject: [PATCH 450/725] fixed issue with Variable.set_values and ExtraGlobalParam.set_values --- pygenn/model_preprocessor.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pygenn/model_preprocessor.py b/pygenn/model_preprocessor.py index 6ee1fa0627..7a4db95cc5 100644 --- a/pygenn/model_preprocessor.py +++ b/pygenn/model_preprocessor.py @@ -148,8 +148,7 @@ def set_values(self, values): # they must be loaded at simulate time try: iter(values) - self.values = np.asarray( - values, dtype=self.group._model.genn_types[self.type]) + self.values = np.asarray(values) self.init_required = True self.extra_global_params = {} # Otherwise - they can be initialised on device as a scalar @@ -190,8 +189,7 @@ def set_values(self, values): # Try and iterate values try: iter(values) - self.values = np.asarray( - values, dtype=self.group._model.genn_types[self.type]) + self.values = np.asarray(values) # Otherwise give an error except TypeError: raise ValueError("extra global variables can only be " From a413a6c5d8ff7fe86ed6d29c72c016606776f295 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 20:31:37 +0100 Subject: [PATCH 451/725] first spike propagation test --- tests/features/test_spike_propagation.py | 78 +++++++++++++++++++++++- 1 file changed, 75 insertions(+), 3 deletions(-) diff --git a/tests/features/test_spike_propagation.py b/tests/features/test_spike_propagation.py index d597737947..f2ec975aa9 100644 --- a/tests/features/test_spike_propagation.py +++ b/tests/features/test_spike_propagation.py @@ -29,14 +29,86 @@ "post_neuron", sim_code= """ - x= Isyn; + x = Isyn; """, var_name_types=[("x", "scalar")]) +# decode_matrix_conn_gen_globalg_ragged, decode_matrix_conn_gen_globalg_bitmask, +# decode_matrix_conn_gen_globalg_bitmask_optimised, decode_matrix_conn_gen_individualg_ragged @pytest.mark.parametrize("backend", ["single_threaded_cpu", "cuda"]) @pytest.mark.parametrize("precision", [types.Double, types.Float]) -def test_spike_propagation(backend, precision): 
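+# The decoder snippet connects presynaptic neuron n to exactly those
+# postsynaptic neurons whose place value is set in (n + 1); for example,
+# presynaptic neuron 5 -> 6 = 0b0110 -> postsynaptic neurons 1 and 2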
+def test_spike_propagation_snippet(backend, precision): model = GeNNModel(precision, "test_spike_propagation", backend=backend) + model.dt = 1.0 + + # Create spike source array to generate one-hot pattern to decode + ss_pop = model.add_neuron_population("SpikeSource", 16, "SpikeSourceArray", + {}, {"startSpike": np.arange(16), "endSpike": np.arange(1, 17)}) + ss_pop.extra_global_params["spikeTimes"].set_values(np.arange(16.0)) + + # Create one output neuron pop with constant weight sparse decoder population + sparse_constant_weight_n_pop = model.add_neuron_population( + "PostSparseConstantWeightNeuron", 4, post_neuron_model, + {}, {"x": 0.0}) + model.add_synapse_population( + "SparseConstantWeightSynapse", "SPARSE", 0, + ss_pop, sparse_constant_weight_n_pop, + "StaticPulseConstantWeight", {"g": 1.0}, {}, {}, {}, + "DeltaCurr", {}, {}, + init_sparse_connectivity(decoder_model, {})) + + # Create one output neuron pop with sparse decoder population + sparse_n_pop = model.add_neuron_population( + "PostSparseNeuron", 4, post_neuron_model, + {}, {"x": 0.0}) + model.add_synapse_population( + "SparseSynapse", "SPARSE", 0, + ss_pop, sparse_n_pop, + "StaticPulse", {}, {"g": 1.0}, {}, {}, + "DeltaCurr", {}, {}, + init_sparse_connectivity(decoder_model, {})) + + # Create one output neuron pop with bitmask decoder population + bitmask_n_pop = model.add_neuron_population( + "PostBitmaskNeuron", 4, post_neuron_model, + {}, {"x": 0.0}) + model.add_synapse_population( + "BitmaskSynapse", "SPARSE", 0, + ss_pop, bitmask_n_pop, + "StaticPulseConstantWeight", {"g": 1.0}, {}, {}, {}, + "DeltaCurr", {}, {}, + init_sparse_connectivity(decoder_model, {})) + + # Build model and load + model.build() + model.load() + + # Simulate 16 timesteps + output_place_values = 2 ** np.arange(4) + output_populations = [sparse_constant_weight_n_pop, + sparse_n_pop, bitmask_n_pop] + while model.timestep < 16: + model.step_time() + + # Loop through output populations + for pop in output_populations: + # Pull state variable + pop.pull_var_from_device("x") + + # Convert to binary mask + output_binary = np.isclose(np.ones(4), pop.vars["x"].view) + + # Sum up active place values + output_value = np.sum(output_place_values[output_binary]) + if output_value != (model.timestep - 1): + assert False, f"{pop.name} decoding incorrect ({output_value} rather than {model.timestep - 1})" + print("BEEP") + +@pytest.mark.parametrize("backend", ["single_threaded_cpu", "cuda"]) +@pytest.mark.parametrize("precision", [types.Double, types.Float]) +def test_cont_propagation_snippet(backend, precision): + model = GeNNModel(precision, "test_cont_propagation", backend=backend) + if __name__ == '__main__': - test_spike_propagation("single_threaded_cpu", types.Float) \ No newline at end of file + test_spike_propagation_snippet("single_threaded_cpu", types.Float) \ No newline at end of file From bc2c7cd2d34582bd9c303f3e71fae9368233a5bf Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 20:46:37 +0100 Subject: [PATCH 452/725] fixed pretty printing of scalar literals --- src/genn/genn/transpiler/prettyPrinter.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index 1ece7b70f1..2bad4b250d 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -236,6 +236,10 @@ class Visitor : public Expression::Visitor, public Statement::Visitor else if (literal.getValue().type == Token::Type::UINT32_NUMBER) { 
m_Environment.get().getStream() << "u";
         }
+        // Otherwise, if literal is a scalar, return literal suffix of scalar type from context
+        else if (literal.getValue().type == Token::Type::SCALAR_NUMBER) {
+            m_Environment.get().getStream() << m_Context.at("scalar").getNumeric().literalSuffix;
+        }
     }
 
     virtual void visit(const Expression::Logical &logical) final

From 4bf1bb377c8e5807ffccb0d58a71cf6dcec1cb30 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 25 Jul 2023 20:47:06 +0100
Subject: [PATCH 453/725] renamed ``create_init_var_snippet`` to
 ``create_var_init_snippet`` for consistency

---
 pygenn/__init__.py   |  2 +-
 pygenn/genn_model.py | 19 +++++++++----------
 2 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/pygenn/__init__.py b/pygenn/__init__.py
index 092997cdae..47fdeabd26 100644
--- a/pygenn/__init__.py
+++ b/pygenn/__init__.py
@@ -13,7 +13,7 @@
                     create_current_source_model,
                     create_custom_update_model,
                     create_custom_connectivity_update_model,
-                    create_init_var_snippet,
+                    create_var_init_snippet,
                     create_sparse_connect_init_snippet,
                     init_sparse_connectivity, init_toeplitz_connectivity,
                     init_var)
diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py
index f111d99522..ef3f2d93e0 100644
--- a/pygenn/genn_model.py
+++ b/pygenn/genn_model.py
@@ -862,7 +862,7 @@ def create_neuron_model(class_name, param_names=None,
     create_postsynaptic_model
     create_weight_update_model
     create_current_source_model
-    create_init_var_snippet
+    create_var_init_snippet
     create_sparse_connect_init_snippet
 
     Args:
@@ -927,7 +927,7 @@ def create_postsynaptic_model(class_name, param_names=None,
     create_neuron_model
     create_weight_update_model
     create_current_source_model
-    create_init_var_snippet
+    create_var_init_snippet
     create_sparse_connect_init_snippet
 
     Args:
@@ -989,7 +989,7 @@ def create_weight_update_model(class_name, param_names=None,
     create_neuron_model
     create_postsynaptic_model
     create_current_source_model
-    create_init_var_snippet
+    create_var_init_snippet
     create_sparse_connect_init_snippet
 
     Args:
@@ -1133,7 +1133,7 @@ def create_current_source_model(class_name, param_names=None,
     create_neuron_model
     create_weight_update_model
     create_current_source_model
-    create_init_var_snippet
+    create_var_init_snippet
     create_sparse_connect_init_snippet
 
     Args:
@@ -1172,7 +1172,7 @@ def create_custom_update_model(class_name, param_names=None,
     create_neuron_model
     create_weight_update_model
     create_current_source_model
-    create_init_var_snippet
+    create_var_init_snippet
     create_sparse_connect_init
 
     Args:
@@ -1227,8 +1227,7 @@ def create_custom_connectivity_update_model(class_name,
     create_neuron_model
     create_weight_update_model
     create_current_source_model
-    create_init_var_snippet
-    create_init_var_snippet
+    create_var_init_snippet
     create_sparse_connect_init
 
     Args:
@@ -1294,7 +1293,7 @@ def create_model(class_name, base, param_names, var_name_types,
     create_weight_update_model
     create_postsynaptic_model
     create_current_source_model
-    create_init_var_snippet
+    create_var_init_snippet
     create_sparse_connect_init_snippet
 
     Args:
@@ -1341,7 +1340,7 @@ def ctor(self):
 
     return type(class_name, (base,), body)()
 
 
-def create_init_var_snippet(class_name, param_names=None,
+def create_var_init_snippet(class_name, param_names=None,
                             derived_params=None,
                             var_init_code=None,
                             extra_global_params=None):
@@ -1392,7 +1391,7 @@ def create_sparse_connect_init_snippet(class_name,
     create_weight_update_model
     create_postsynaptic_model
     create_current_source_model
-    create_init_var_snippet
+    create_var_init_snippet
 
     Args:
     class_name -- name of the new class

From f2fd2ab6b80aa5d30a2aa2050463c0b722c69f95 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 25 
From f2fd2ab6b80aa5d30a2aa2050463c0b722c69f95 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 20:54:47 +0100 Subject: [PATCH 454/725] test_spike_propagation_snippet test complete --- tests/features/test_spike_propagation.py | 37 ++++++++++++++---------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/tests/features/test_spike_propagation.py b/tests/features/test_spike_propagation.py index f2ec975aa9..9c58dc260b 100644 --- a/tests/features/test_spike_propagation.py +++ b/tests/features/test_spike_propagation.py @@ -7,7 +7,8 @@ from pygenn.genn import VarAccess from pygenn import (create_neuron_model, create_sparse_connect_init_snippet, - init_sparse_connectivity) + create_var_init_snippet, + init_sparse_connectivity, init_var) decoder_model = create_sparse_connect_init_snippet( "decoder", @@ -21,9 +22,13 @@ } """) -pre_cont_neuron_model = create_neuron_model( - "pre_cont_neuron", - var_name_types=[("x", "scalar", VarAccess.READ_ONLY)]) +decoder_dense_model = create_var_init_snippet( + "decoder_dense", + var_init_code= + """ + const unsigned int jValue = (1 << id_post); + value = (((id_pre + 1) & jValue) != 0) ? 1.0 : 0.0; + """) post_neuron_model = create_neuron_model( "post_neuron", @@ -33,12 +38,11 @@ """, var_name_types=[("x", "scalar")]) -# decode_matrix_conn_gen_globalg_ragged, decode_matrix_conn_gen_globalg_bitmask, -# decode_matrix_conn_gen_globalg_bitmask_optimised, decode_matrix_conn_gen_individualg_ragged + @pytest.mark.parametrize("backend", ["single_threaded_cpu", "cuda"]) @pytest.mark.parametrize("precision", [types.Double, types.Float]) def test_spike_propagation_snippet(backend, precision): - model = GeNNModel(precision, "test_spike_propagation", backend=backend) + model = GeNNModel(precision, "test_spike_propagation_snippet", backend=backend) model.dt = 1.0 # Create spike source array to generate one-hot pattern to decode @@ -79,6 +83,16 @@ def test_spike_propagation_snippet(backend, precision): "DeltaCurr", {}, {}, init_sparse_connectivity(decoder_model, {})) + # Create one output neuron pop with bitmask decoder population + dense_n_pop = model.add_neuron_population( + "PostDenseNeuron", 4, post_neuron_model, + {}, {"x": 0.0}) + model.add_synapse_population( + "PostDenseSynapse", "DENSE", 0, + ss_pop, dense_n_pop, + "StaticPulse", {}, {"g": init_var(decoder_dense_model, {})}, {}, {}, + "DeltaCurr", {}, {}) + # Build model and load model.build() model.load() @@ -86,7 +100,7 @@ def test_spike_propagation_snippet(backend, precision): # Simulate 16 timesteps output_place_values = 2 ** np.arange(4) output_populations = [sparse_constant_weight_n_pop, - sparse_n_pop, bitmask_n_pop] + sparse_n_pop, bitmask_n_pop, dense_n_pop] while model.timestep < 16: model.step_time() @@ -102,13 +116,6 @@ def test_spike_propagation_snippet(backend, precision): output_value = np.sum(output_place_values[output_binary]) if output_value != (model.timestep - 1): assert False, f"{pop.name} decoding incorrect ({output_value} rather than {model.timestep - 1})" - print("BEEP") - -@pytest.mark.parametrize("backend", ["single_threaded_cpu", "cuda"]) -@pytest.mark.parametrize("precision", [types.Double, types.Float]) -def test_cont_propagation_snippet(backend, precision): - model = GeNNModel(precision, "test_cont_propagation", backend=backend) - if __name__ == '__main__': test_spike_propagation_snippet("single_threaded_cpu", types.Float) \ No newline at end of file From af98c139b8637815f141bd9a1c8b6466a104e69d Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 
Jul 2023 21:10:50 +0100 Subject: [PATCH 455/725] fixed some bugs in SynapseGroup --- pygenn/genn_groups.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index 76f4ad02c9..5ba0b126a7 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -468,7 +468,7 @@ def get_var_values(self, var_name): elif self.matrix_type & SynapseMatrixConnectivity.KERNEL: return np.copy(var_view) elif self.matrix_type & SynapseMatrixConnectivity.SPARSE: - max_rl = self.max_row_length + max_rl = self.max_connections row_ls = self._row_lengths if self._connectivity_initialiser_provided else self.row_lengths # Create range containing the index where each row starts in ind @@ -557,7 +557,7 @@ def get_sparse_post_inds(self): # the _ind array view still has some non-valid data so we remove them # with the row_lengths return np.hstack([ - self._ind[i * self.max_row_length: (i * self.max_row_length) + r] + self._ind[i * self.max_connections: (i * self.max_connections) + r] for i, r in enumerate(self._row_lengths)]) else: @@ -566,7 +566,7 @@ def get_sparse_post_inds(self): def pull_connectivity_from_device(self): """Wrapper around GeNNModel.pull_connectivity_from_device""" - self._model._slmpull_connectivity_from_device(self.name) + self._model._slm.pull_connectivity_from_device(self.name) def push_connectivity_to_device(self): """Wrapper around GeNNModel.push_connectivity_to_device""" @@ -624,7 +624,7 @@ def load(self): # Create (x)range containing the index where each row starts in ind row_start_idx = xrange(0, self.weight_update_var_size, - self.max_row_length) + self.max_connections) # Loop through ragged matrix rows syn = 0 @@ -762,7 +762,7 @@ def _init_wum_var(self, var_data, num_copies): # Create (x)range containing the index # where each row starts in ind row_start_idx = xrange(0, self.weight_update_var_size, - self.max_row_length) + self.max_connections) # Loop through ragged matrix rows syn = 0 From 1c4a429302934cfbbfa93088bc133b7da93a8401 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 25 Jul 2023 21:12:19 +0100 Subject: [PATCH 456/725] connectivity init test --- tests/features/test_connect_init.py | 58 +++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 tests/features/test_connect_init.py diff --git a/tests/features/test_connect_init.py b/tests/features/test_connect_init.py new file mode 100644 index 0000000000..088d16f9d5 --- /dev/null +++ b/tests/features/test_connect_init.py @@ -0,0 +1,58 @@ +import numpy as np +import pytest +from pygenn import types + +from pygenn import GeNNModel + +from pygenn import init_sparse_connectivity + +@pytest.mark.parametrize("backend", ["single_threaded_cpu", "cuda"]) +@pytest.mark.parametrize("precision", [types.Double, types.Float]) +def test_connect_init(backend, precision): + model = GeNNModel(precision, "test_connect_init", backend=backend) + model.narrow_sparse_ind_enabled = True + + # Create pre and postsynaptic neuron populations + pre_pop = model.add_neuron_population("Pre", 100, "SpikeSource", {}, {}) + post_pop = model.add_neuron_population("Post", 100, "SpikeSource", {}, {}) + + # Add synapse populations with different types of built-in connectivity + fixed_number_total_s_pop = model.add_synapse_population( + "FixedNumberTotal", "SPARSE", 0, + pre_pop, post_pop, + "StaticPulseConstantWeight", {"g": 1.0}, {}, {}, {}, + "DeltaCurr", {}, {}, + init_sparse_connectivity("FixedNumberTotalWithReplacement", {"total": 1000})) + + 
fixed_number_pre_s_pop = model.add_synapse_population( + "FixedNumberPre", "SPARSE", 0, + pre_pop, post_pop, + "StaticPulseConstantWeight", {"g": 1.0}, {}, {}, {}, + "DeltaCurr", {}, {}, + init_sparse_connectivity("FixedNumberPreWithReplacement", {"colLength": 10})) + + fixed_number_post_s_pop = model.add_synapse_population( + "FixedNumberPost", "SPARSE", 0, + pre_pop, post_pop, + "StaticPulseConstantWeight", {"g": 1.0}, {}, {}, {}, + "DeltaCurr", {}, {}, + init_sparse_connectivity("FixedNumberPostWithReplacement", {"rowLength": 10})) + + # Build and load model + model.build() + model.load() + + # Pull connectivity + fixed_number_total_s_pop.pull_connectivity_from_device() + fixed_number_pre_s_pop.pull_connectivity_from_device() + fixed_number_post_s_pop.pull_connectivity_from_device() + + # Check connectivity + assert np.all(np.bincount(fixed_number_post_s_pop.get_sparse_pre_inds()) == 10) + assert np.all(np.bincount(fixed_number_pre_s_pop.get_sparse_post_inds()) == 10) + assert len(fixed_number_total_s_pop.get_sparse_pre_inds()) == 1000 + + # **TODO** we could also build a histogram of postsynaptic neurons and check that they are approximately uniformly distributed + +if __name__ == '__main__': + test_connect_init("single_threaded_cpu", types.Float) \ No newline at end of file From da6392117bf7949424ebcb77c07b92e0468477b2 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 26 Jul 2023 12:41:21 +0100 Subject: [PATCH 457/725] delete now-reimplemented tests --- tests/features/connect_init/Makefile | 14 ---- tests/features/connect_init/connect_init.sln | 30 ------- .../connect_init/connect_init.vcxproj | 63 -------------- tests/features/connect_init/model.cc | 72 ---------------- tests/features/connect_init/runner_guid.txt | 1 - tests/features/connect_init/test.cc | 72 ---------------- .../current_source_rng_normal/Makefile | 1 - .../current_source_rng_normal.sln | 30 ------- .../current_source_rng_normal.vcxproj | 63 -------------- .../current_source_rng_normal/model.cc | 52 ------------ .../current_source_rng_normal/runner_guid.txt | 1 - .../current_source_rng_normal/test.cc | 45 ---------- .../custom_connectivity_update_rng/Makefile | 1 - .../custom_connectivity_update_rng/model.cc | 84 ------------------- .../custom_connectivity_update_rng/test.cc | 55 ------------ .../Makefile | 1 - ...decode_matrix_conn_gen_globalg_bitmask.sln | 30 ------- ...de_matrix_conn_gen_globalg_bitmask.vcxproj | 63 -------------- .../model.cc | 80 ------------------ .../runner_guid.txt | 1 - .../test.cc | 38 --------- .../Makefile | 1 - .../decode_matrix_conn_gen_globalg_ragged.sln | 30 ------- ...ode_matrix_conn_gen_globalg_ragged.vcxproj | 63 -------------- .../model.cc | 80 ------------------ .../runner_guid.txt | 1 - .../test.cc | 38 --------- .../Makefile | 14 ---- ...code_matrix_conn_gen_individualg_dense.sln | 30 ------- ..._matrix_conn_gen_individualg_dense.vcxproj | 63 -------------- .../model.cc | 68 --------------- .../runner_guid.txt | 1 - .../test.cc | 31 ------- .../Makefile | 1 - ...ode_matrix_conn_gen_individualg_ragged.sln | 30 ------- ...matrix_conn_gen_individualg_ragged.vcxproj | 63 -------------- .../model.cc | 80 ------------------ .../runner_guid.txt | 1 - .../test.cc | 38 --------- tests/features/neuron_rng_normal/Makefile | 1 - tests/features/neuron_rng_normal/model.cc | 45 ---------- .../neuron_rng_normal/neuron_rng_normal.sln | 30 ------- .../neuron_rng_normal.vcxproj | 63 -------------- .../neuron_rng_normal/runner_guid.txt | 1 - tests/features/neuron_rng_normal/test.cc | 45 
---------- tests/features/neuron_rng_uniform/Makefile | 1 - tests/features/neuron_rng_uniform/model.cc | 45 ---------- .../neuron_rng_uniform/neuron_rng_uniform.sln | 30 ------- .../neuron_rng_uniform.vcxproj | 63 -------------- .../neuron_rng_uniform/runner_guid.txt | 1 - tests/features/neuron_rng_uniform/test.cc | 45 ---------- 51 files changed, 1800 deletions(-) delete mode 100644 tests/features/connect_init/Makefile delete mode 100644 tests/features/connect_init/connect_init.sln delete mode 100644 tests/features/connect_init/connect_init.vcxproj delete mode 100644 tests/features/connect_init/model.cc delete mode 100644 tests/features/connect_init/runner_guid.txt delete mode 100644 tests/features/connect_init/test.cc delete mode 120000 tests/features/current_source_rng_normal/Makefile delete mode 100644 tests/features/current_source_rng_normal/current_source_rng_normal.sln delete mode 100644 tests/features/current_source_rng_normal/current_source_rng_normal.vcxproj delete mode 100644 tests/features/current_source_rng_normal/model.cc delete mode 100644 tests/features/current_source_rng_normal/runner_guid.txt delete mode 100644 tests/features/current_source_rng_normal/test.cc delete mode 120000 tests/features/custom_connectivity_update_rng/Makefile delete mode 100644 tests/features/custom_connectivity_update_rng/model.cc delete mode 100644 tests/features/custom_connectivity_update_rng/test.cc delete mode 120000 tests/features/decode_matrix_conn_gen_globalg_bitmask/Makefile delete mode 100644 tests/features/decode_matrix_conn_gen_globalg_bitmask/decode_matrix_conn_gen_globalg_bitmask.sln delete mode 100644 tests/features/decode_matrix_conn_gen_globalg_bitmask/decode_matrix_conn_gen_globalg_bitmask.vcxproj delete mode 100644 tests/features/decode_matrix_conn_gen_globalg_bitmask/model.cc delete mode 100644 tests/features/decode_matrix_conn_gen_globalg_bitmask/runner_guid.txt delete mode 100644 tests/features/decode_matrix_conn_gen_globalg_bitmask/test.cc delete mode 120000 tests/features/decode_matrix_conn_gen_globalg_ragged/Makefile delete mode 100644 tests/features/decode_matrix_conn_gen_globalg_ragged/decode_matrix_conn_gen_globalg_ragged.sln delete mode 100644 tests/features/decode_matrix_conn_gen_globalg_ragged/decode_matrix_conn_gen_globalg_ragged.vcxproj delete mode 100644 tests/features/decode_matrix_conn_gen_globalg_ragged/model.cc delete mode 100644 tests/features/decode_matrix_conn_gen_globalg_ragged/runner_guid.txt delete mode 100644 tests/features/decode_matrix_conn_gen_globalg_ragged/test.cc delete mode 100644 tests/features/decode_matrix_conn_gen_individualg_dense/Makefile delete mode 100644 tests/features/decode_matrix_conn_gen_individualg_dense/decode_matrix_conn_gen_individualg_dense.sln delete mode 100644 tests/features/decode_matrix_conn_gen_individualg_dense/decode_matrix_conn_gen_individualg_dense.vcxproj delete mode 100644 tests/features/decode_matrix_conn_gen_individualg_dense/model.cc delete mode 100644 tests/features/decode_matrix_conn_gen_individualg_dense/runner_guid.txt delete mode 100644 tests/features/decode_matrix_conn_gen_individualg_dense/test.cc delete mode 120000 tests/features/decode_matrix_conn_gen_individualg_ragged/Makefile delete mode 100644 tests/features/decode_matrix_conn_gen_individualg_ragged/decode_matrix_conn_gen_individualg_ragged.sln delete mode 100644 tests/features/decode_matrix_conn_gen_individualg_ragged/decode_matrix_conn_gen_individualg_ragged.vcxproj delete mode 100644 tests/features/decode_matrix_conn_gen_individualg_ragged/model.cc 
delete mode 100644 tests/features/decode_matrix_conn_gen_individualg_ragged/runner_guid.txt delete mode 100644 tests/features/decode_matrix_conn_gen_individualg_ragged/test.cc delete mode 120000 tests/features/neuron_rng_normal/Makefile delete mode 100644 tests/features/neuron_rng_normal/model.cc delete mode 100644 tests/features/neuron_rng_normal/neuron_rng_normal.sln delete mode 100644 tests/features/neuron_rng_normal/neuron_rng_normal.vcxproj delete mode 100644 tests/features/neuron_rng_normal/runner_guid.txt delete mode 100644 tests/features/neuron_rng_normal/test.cc delete mode 120000 tests/features/neuron_rng_uniform/Makefile delete mode 100644 tests/features/neuron_rng_uniform/model.cc delete mode 100644 tests/features/neuron_rng_uniform/neuron_rng_uniform.sln delete mode 100644 tests/features/neuron_rng_uniform/neuron_rng_uniform.vcxproj delete mode 100644 tests/features/neuron_rng_uniform/runner_guid.txt delete mode 100644 tests/features/neuron_rng_uniform/test.cc diff --git a/tests/features/connect_init/Makefile b/tests/features/connect_init/Makefile deleted file mode 100644 index 7a76a5c854..0000000000 --- a/tests/features/connect_init/Makefile +++ /dev/null @@ -1,14 +0,0 @@ -CXXFLAGS +=-std=c++11 -Wall -Wpedantic -Wextra -I $(GTEST_DIR) -isystem $(GTEST_DIR)/include - -.PHONY: all clean generated_code - -all: test - -test: test.cc generated_code - $(CXX) $(CXXFLAGS) test.cc $(GTEST_DIR)/src/gtest-all.cc $(GTEST_DIR)/src/gtest_main.cc -o test -L$(SIM_CODE) -pthread -lrunner -Wl,-rpath $(SIM_CODE) - -generated_code: - $(MAKE) -C $(SIM_CODE) - -clean: - @rm -f test $(SIM_CODE)/librunner.so $(SIM_CODE)/*.o $(SIM_CODE)/*.d default.profraw diff --git a/tests/features/connect_init/connect_init.sln b/tests/features/connect_init/connect_init.sln deleted file mode 100644 index b1e1c4c921..0000000000 --- a/tests/features/connect_init/connect_init.sln +++ /dev/null @@ -1,30 +0,0 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 2013 -VisualStudioVersion = 12.0.30501.0 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "connect_init", "connect_init.vcxproj", "{D0CC74A2-4924-45EF-BEC3-655AAF288ED6}" - ProjectSection(ProjectDependencies) = postProject - {2CE209EC-77CC-4F55-8A11-9CB657C42CB0} = {2CE209EC-77CC-4F55-8A11-9CB657C42CB0} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runner", "connect_init_CODE\runner.vcxproj", "{2CE209EC-77CC-4F55-8A11-9CB657C42CB0}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {D0CC74A2-4924-45EF-BEC3-655AAF288ED6}.Debug|x64.ActiveCfg = Debug|x64 - {D0CC74A2-4924-45EF-BEC3-655AAF288ED6}.Debug|x64.Build.0 = Debug|x64 - {D0CC74A2-4924-45EF-BEC3-655AAF288ED6}.Release|x64.ActiveCfg = Release|x64 - {D0CC74A2-4924-45EF-BEC3-655AAF288ED6}.Release|x64.Build.0 = Release|x64 - {2CE209EC-77CC-4F55-8A11-9CB657C42CB0}.Debug|x64.ActiveCfg = Debug|x64 - {2CE209EC-77CC-4F55-8A11-9CB657C42CB0}.Debug|x64.Build.0 = Debug|x64 - {2CE209EC-77CC-4F55-8A11-9CB657C42CB0}.Release|x64.ActiveCfg = Release|x64 - {2CE209EC-77CC-4F55-8A11-9CB657C42CB0}.Release|x64.Build.0 = Release|x64 - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection -EndGlobal diff --git a/tests/features/connect_init/connect_init.vcxproj 
b/tests/features/connect_init/connect_init.vcxproj deleted file mode 100644 index ec79d1d1ec..0000000000 --- a/tests/features/connect_init/connect_init.vcxproj +++ /dev/null @@ -1,63 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {D0CC74A2-4924-45EF-BEC3-655AAF288ED6} - - - - - - - - - Application - true - $(DefaultPlatformToolset) - true - MultiByte - - - - - - - - - - ./ - $(Platform)\$(Configuration)\ - test - - - - Level3 - MaxSpeed - Disabled - true - true - true - connect_init_CODE;$(GTEST_DIR);$(GTEST_DIR)/include - _SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;_MBCS;%(PreprocessorDefinitions) - - - true - true - true - runner_Release.lib;%(AdditionalDependencies) - runner_Debug.lib;%(AdditionalDependencies) - - - - - - diff --git a/tests/features/connect_init/model.cc b/tests/features/connect_init/model.cc deleted file mode 100644 index 47f8d0efcd..0000000000 --- a/tests/features/connect_init/model.cc +++ /dev/null @@ -1,72 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file connect_init/model.cc - -\brief model definition file that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -#include "modelSpec.h" - - -void modelDefinition(ModelSpec &model) -{ -#ifdef CL_HPP_TARGET_OPENCL_VERSION - if(std::getenv("OPENCL_DEVICE") != nullptr) { - GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; - GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); - } - if(std::getenv("OPENCL_PLATFORM") != nullptr) { - GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); - } -#endif - model.setDT(0.1); - model.setName("connect_init"); - model.setDefaultNarrowSparseIndEnabled(true); - - NeuronModels::LIF::ParamValues lifParams( - 0.25, // 0 - C - 10.0, // 1 - TauM - -65.0, // 2 - Vrest - -65.0, // 3 - Vreset - -50.0, // 4 - Vthresh - 0.0, // 5 - Ioffset - 2.0); // 6 - TauRefrac - NeuronModels::LIF::VarValues lifInit( - -65.0, // 0 - V - 0.0); // 1 - RefracTime - - WeightUpdateModels::StaticPulse::VarValues staticSynapseInit(0.1); - - InitSparseConnectivitySnippet::FixedNumberTotalWithReplacement::ParamValues fixedNumTotalParams(1000); - InitSparseConnectivitySnippet::FixedNumberPostWithReplacement::ParamValues fixedNumPostParams(10); - InitSparseConnectivitySnippet::FixedNumberPreWithReplacement::ParamValues fixedNumPreParams(10); - - model.addNeuronPopulation("SpikeSource", 100, {}, {}); - model.addNeuronPopulation("LIF", 100, lifParams, lifInit); - - // Fixed number total connectivity - model.addSynapsePopulation( - "FixedNumberTotal", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, - "SpikeSource", "LIF", - {}, staticSynapseInit, {}, {}, - {}, {}, - initConnectivity(fixedNumTotalParams)); - - // Fixed number post connectivity - model.addSynapsePopulation( - "FixedNumberPost", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, - "SpikeSource", "LIF", - {}, staticSynapseInit, {}, {}, - {}, {}, - initConnectivity(fixedNumPostParams)); - - // Fixed number pre connectivity - model.addSynapsePopulation( - "FixedNumberPre", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, - "SpikeSource", "LIF", - {}, staticSynapseInit, {}, {}, - {}, {}, - initConnectivity(fixedNumPreParams)); -} diff --git a/tests/features/connect_init/runner_guid.txt b/tests/features/connect_init/runner_guid.txt deleted file mode 100644 index 4a0c4cf99c..0000000000 
--- a/tests/features/connect_init/runner_guid.txt +++ /dev/null @@ -1 +0,0 @@ -2CE209EC-77CC-4F55-8A11-9CB657C42CB0 diff --git a/tests/features/connect_init/test.cc b/tests/features/connect_init/test.cc deleted file mode 100644 index 91c4c78a59..0000000000 --- a/tests/features/connect_init/test.cc +++ /dev/null @@ -1,72 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file connect_init/test.cc - -\brief Main test code that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- -// Standard C++ includes -#include -#include -#include - -// Google test includes -#include "gtest/gtest.h" - -// Auto-generated simulation code includess -#include "connect_init_CODE/definitions.h" - -// **NOTE** base-class for simulation tests must be -// included after auto-generated globals are includes -#include "../../utils/simulation_test.h" - -//---------------------------------------------------------------------------- -// Macros -//---------------------------------------------------------------------------- -#define CALC_ROW_LENGTH(NAME, HISTOGRAM) calcHistogram(rowLength##NAME, ind##NAME, maxRowLength##NAME, HISTOGRAM) - -//---------------------------------------------------------------------------- -// SimTest -//---------------------------------------------------------------------------- -class SimTest : public SimulationTest -{ -}; - -template -void calcHistogram(const unsigned int *rowLength, const I *ind, - unsigned int maxRowLength, std::array &histogram) -{ - // Loop through rows - for(unsigned int i = 0; i < N; i++) { - // Loop through synapses - for(unsigned int j = 0; j < rowLength[i]; j++) { - // Increment histogram bin - EXPECT_LT(ind[j], N); - histogram[ind[j]]++; - } - - // Advance to next row - ind += maxRowLength; - } -} - -TEST_F(SimTest, ConnectInit) -{ - // Pull connectivity back to host - pullFixedNumberTotalConnectivityFromDevice(); - pullFixedNumberPostConnectivityFromDevice(); - pullFixedNumberPreConnectivityFromDevice(); - - // Test that connectivity has required properties - EXPECT_EQ(std::accumulate(&rowLengthFixedNumberTotal[0], &rowLengthFixedNumberTotal[100], 0u), 1000); - EXPECT_TRUE(std::all_of(&rowLengthFixedNumberPost[0], &rowLengthFixedNumberPost[100], - [](unsigned int rowLength) { return rowLength == 10; })); - - std::array fixedNumPreHist{}; - CALC_ROW_LENGTH(FixedNumberPre, fixedNumPreHist); - EXPECT_TRUE(std::all_of(fixedNumPreHist.cbegin(), fixedNumPreHist.cend(), - [](unsigned int colLength) { return colLength == 10; })); - - // **TODO** we could also build a histogram of postsynaptic neurons and check that they are approximately uniformly distributed -} - diff --git a/tests/features/current_source_rng_normal/Makefile b/tests/features/current_source_rng_normal/Makefile deleted file mode 120000 index 1302b13ca5..0000000000 --- a/tests/features/current_source_rng_normal/Makefile +++ /dev/null @@ -1 +0,0 @@ -../../utils/Makefile \ No newline at end of file diff --git a/tests/features/current_source_rng_normal/current_source_rng_normal.sln b/tests/features/current_source_rng_normal/current_source_rng_normal.sln deleted file mode 100644 index 18bd101b0d..0000000000 --- a/tests/features/current_source_rng_normal/current_source_rng_normal.sln +++ /dev/null @@ -1,30 +0,0 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 2013 
-VisualStudioVersion = 12.0.30501.0 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "current_source_rng_normal", "current_source_rng_normal.vcxproj", "{739BEB3B-5C6A-42A7-A21B-C007D0FF4069}" - ProjectSection(ProjectDependencies) = postProject - {87BACC48-B452-41D9-898B-20A5866EEBCD} = {87BACC48-B452-41D9-898B-20A5866EEBCD} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runner", "current_source_rng_normal_CODE\runner.vcxproj", "{87BACC48-B452-41D9-898B-20A5866EEBCD}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {739BEB3B-5C6A-42A7-A21B-C007D0FF4069}.Debug|x64.ActiveCfg = Debug|x64 - {739BEB3B-5C6A-42A7-A21B-C007D0FF4069}.Debug|x64.Build.0 = Debug|x64 - {739BEB3B-5C6A-42A7-A21B-C007D0FF4069}.Release|x64.ActiveCfg = Release|x64 - {739BEB3B-5C6A-42A7-A21B-C007D0FF4069}.Release|x64.Build.0 = Release|x64 - {87BACC48-B452-41D9-898B-20A5866EEBCD}.Debug|x64.ActiveCfg = Debug|x64 - {87BACC48-B452-41D9-898B-20A5866EEBCD}.Debug|x64.Build.0 = Debug|x64 - {87BACC48-B452-41D9-898B-20A5866EEBCD}.Release|x64.ActiveCfg = Release|x64 - {87BACC48-B452-41D9-898B-20A5866EEBCD}.Release|x64.Build.0 = Release|x64 - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection -EndGlobal diff --git a/tests/features/current_source_rng_normal/current_source_rng_normal.vcxproj b/tests/features/current_source_rng_normal/current_source_rng_normal.vcxproj deleted file mode 100644 index f22a430385..0000000000 --- a/tests/features/current_source_rng_normal/current_source_rng_normal.vcxproj +++ /dev/null @@ -1,63 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {739BEB3B-5C6A-42A7-A21B-C007D0FF4069} - - - - - - - - - Application - true - $(DefaultPlatformToolset) - true - MultiByte - - - - - - - - - - ./ - $(Platform)\$(Configuration)\ - test - - - - Level3 - MaxSpeed - Disabled - true - true - true - current_source_rng_normal_CODE;$(GTEST_DIR);$(GTEST_DIR)/include - _SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions) - - - true - true - true - runner_Release.lib;%(AdditionalDependencies) - runner_Debug.lib;%(AdditionalDependencies) - - - - - - diff --git a/tests/features/current_source_rng_normal/model.cc b/tests/features/current_source_rng_normal/model.cc deleted file mode 100644 index f75d2c96ad..0000000000 --- a/tests/features/current_source_rng_normal/model.cc +++ /dev/null @@ -1,52 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file current_source_rng_normal/model.cc - -\brief model definition file that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. 
-*/ -//-------------------------------------------------------------------------- - - -#include "modelSpec.h" - -//---------------------------------------------------------------------------- -// Neuron -//---------------------------------------------------------------------------- -class Neuron : public NeuronModels::Base -{ -public: - DECLARE_MODEL(Neuron, 0, 1); - - SET_SIM_CODE("$(x)= $(Isyn);\n"); - - SET_VARS({{"x", "scalar"}}); -}; -IMPLEMENT_MODEL(Neuron); - - -void modelDefinition(ModelSpec &model) -{ -#ifdef CL_HPP_TARGET_OPENCL_VERSION - if(std::getenv("OPENCL_DEVICE") != nullptr) { - GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; - GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); - } - if(std::getenv("OPENCL_PLATFORM") != nullptr) { - GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); - } -#endif - CurrentSourceModels::GaussianNoise::ParamValues paramVals( - 0.0, // 2 - mean - 1.0); // 3 - standard deviation - - model.setDT(0.1); - model.setName("current_source_rng_normal"); - - model.addNeuronPopulation("Pop", 1000, {}, Neuron::VarValues(0.0)); - - model.addCurrentSource("CurrentSource", - "Pop", - paramVals, {}); - - model.setPrecision(GENN_FLOAT); -} diff --git a/tests/features/current_source_rng_normal/runner_guid.txt b/tests/features/current_source_rng_normal/runner_guid.txt deleted file mode 100644 index e949a4c871..0000000000 --- a/tests/features/current_source_rng_normal/runner_guid.txt +++ /dev/null @@ -1 +0,0 @@ -87BACC48-B452-41D9-898B-20A5866EEBCD diff --git a/tests/features/current_source_rng_normal/test.cc b/tests/features/current_source_rng_normal/test.cc deleted file mode 100644 index 6f66b04df0..0000000000 --- a/tests/features/current_source_rng_normal/test.cc +++ /dev/null @@ -1,45 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file current_source_rng_normal/test.cc - -\brief Main test code that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. 
-*/ -//-------------------------------------------------------------------------- - - -// Google test includes -#include "gtest/gtest.h" - -// Auto-generated simulation code includess -#include "current_source_rng_normal_CODE/definitions.h" - -// **NOTE** base-class for simulation tests must be -// included after auto-generated globals are includes -#include "../../utils/simulation_test_samples.h" -#include "../../utils/stats.h" - -//---------------------------------------------------------------------------- -// SimTest -//---------------------------------------------------------------------------- -class SimTest : public SimulationTestSamples -{ -public: - //---------------------------------------------------------------------------- - // SimulationTestHistogram virtuals - //---------------------------------------------------------------------------- - virtual double Test(std::vector &samples) const - { - // Perform Kolmogorov-Smirnov test - double d; - double prob; - std::tie(d, prob) = Stats::kolmogorovSmirnovTest(samples, Stats::normalCDF); - - return prob; - } -}; - -TEST_F(SimTest, CurrentSourceRngNormal) -{ - // Check p value passes 95% confidence interval - EXPECT_GT(Simulate(), 0.05); -} diff --git a/tests/features/custom_connectivity_update_rng/Makefile b/tests/features/custom_connectivity_update_rng/Makefile deleted file mode 120000 index 1302b13ca5..0000000000 --- a/tests/features/custom_connectivity_update_rng/Makefile +++ /dev/null @@ -1 +0,0 @@ -../../utils/Makefile \ No newline at end of file diff --git a/tests/features/custom_connectivity_update_rng/model.cc b/tests/features/custom_connectivity_update_rng/model.cc deleted file mode 100644 index 09709402a5..0000000000 --- a/tests/features/custom_connectivity_update_rng/model.cc +++ /dev/null @@ -1,84 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file custom_connectivity_update_rng/model.cc - -\brief model definition file that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. 
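The Kolmogorov-Smirnov comparison these tests delegate to Stats::kolmogorovSmirnovTest can be sketched as below. This is an assumed minimal form: normalCDF and ksStatistic are illustrative names, and the conversion of the statistic D into the p-value the tests compare against 0.05 is omitted.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Standard normal CDF, written via the complementary error function
    double normalCDF(double x)
    {
        return 0.5 * std::erfc(-x / std::sqrt(2.0));
    }

    // Kolmogorov-Smirnov statistic D = sup |F_empirical(x) - F(x)|
    double ksStatistic(std::vector<double> samples, double (*cdf)(double))
    {
        std::sort(samples.begin(), samples.end());
        const double n = static_cast<double>(samples.size());
        double d = 0.0;
        for(std::size_t i = 0; i < samples.size(); i++) {
            const double f = cdf(samples[i]);
            // The empirical CDF steps from i/n to (i+1)/n at samples[i]
            d = std::max({d, f - (i / n), ((i + 1) / n) - f});
        }
        return d;
    }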
-*/ -//-------------------------------------------------------------------------- - - -#include "modelSpec.h" - -class TestNeuron : public NeuronModels::Base -{ -public: - DECLARE_MODEL(TestNeuron, 0, 1); - - SET_VARS({{"V","scalar"}}); -}; -IMPLEMENT_MODEL(TestNeuron); - -class RNGTest : public CustomConnectivityUpdateModels::Base -{ -public: - DECLARE_CUSTOM_CONNECTIVITY_UPDATE_MODEL(RNGTest, 0, 0, 0, 0, 0, 0, 0); - - SET_EXTRA_GLOBAL_PARAMS({{"Output", "scalar*"}}); - SET_ROW_UPDATE_CODE( - "for(int j = 0; j < 1000; j++) {\n" - " $(Output)[$(id_pre) + (j * $(num_pre))] = $(gennrand_uniform);\n" - "}\n"); - -}; -IMPLEMENT_MODEL(RNGTest); - -class HostRNGTest : public CustomConnectivityUpdateModels::Base -{ -public: - DECLARE_CUSTOM_CONNECTIVITY_UPDATE_MODEL(HostRNGTest, 2, 0, 1, 0, 0, 0, 0); - - SET_PARAM_NAMES({"min", "max"}); - SET_PRE_VARS({{"Output", "scalar"}}); - SET_HOST_UPDATE_CODE( - "std::uniform_real_distribution dist($(min), $(max));\n" - "for(int i = 0; i < $(num_pre); i++){\n" - " $(Output)[i] = dist($(rng));\n" - "}\n" - "$(pushOutputToDevice);\n"); -}; -IMPLEMENT_MODEL(HostRNGTest); - -void modelDefinition(ModelSpec &model) -{ -#ifdef CL_HPP_TARGET_OPENCL_VERSION - if(std::getenv("OPENCL_DEVICE") != nullptr) { - GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; - GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); - } - if(std::getenv("OPENCL_PLATFORM") != nullptr) { - GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); - } -#endif - model.setDT(1.0); - model.setName("custom_connectivity_update_rng"); - - model.addNeuronPopulation("SpikeSource", 1000, {}, {}); - model.addNeuronPopulation("Neuron", 1000, {}, {0.0}); - - model.addSynapsePopulation( - "Syn1", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, "SpikeSource", "Neuron", - {}, {1.0}, - {}, {}, - initConnectivity({0.1})); - - model.addCustomConnectivityUpdate( - "RNGTest", "RNGTest", "Syn1", - {}, {}, {}, {}, - {}, {}, {}); - model.addCustomConnectivityUpdate( - "HostRNGTest", "RNGTest", "Syn1", - {0.0, 1.0}, {}, {0.0}, {}, - {}, {}, {}); - - model.setPrecision(GENN_FLOAT); -} diff --git a/tests/features/custom_connectivity_update_rng/test.cc b/tests/features/custom_connectivity_update_rng/test.cc deleted file mode 100644 index 4e99556566..0000000000 --- a/tests/features/custom_connectivity_update_rng/test.cc +++ /dev/null @@ -1,55 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file custom_connectivity_update_rng/test.cc - -\brief Main test code that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. 
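The HostRNGTest update above runs on the host, so it can use the C++ standard library directly; stripped of GeNN's code substitution it amounts to the sketch below, where a std::mt19937 stands in for $(rng) and hostUniformUpdate is a hypothetical helper, not generated code.

    #include <random>
    #include <vector>

    // Fill one value per presynaptic neuron with a uniform draw from [min, max)
    std::vector<float> hostUniformUpdate(std::size_t numPre, float min, float max,
                                         std::mt19937 &rng)
    {
        std::uniform_real_distribution<float> dist(min, max);
        std::vector<float> output(numPre);
        for(auto &o : output) {
            o = dist(rng);
        }
        return output;    // the real update then calls $(pushOutputToDevice)
    }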
-*/ -//-------------------------------------------------------------------------- - - -// Google test includes -#include "gtest/gtest.h" - -// Auto-generated simulation code includess -#include "custom_connectivity_update_rng_CODE/definitions.h" - -// **NOTE** base-class for simulation tests must be -// included after auto-generated globals are includes -#include "../../utils/simulation_test.h" -#include "../../utils/stats.h" - -//---------------------------------------------------------------------------- -// SimTest -//---------------------------------------------------------------------------- -class SimTest : public SimulationTest -{ - virtual void Init() final - { - allocateOutputRNGTest(1000 * 1000); - } -}; - - -TEST_F(SimTest, CustomConnectivityUpdate) -{ - // Launch custom update to generate a bunch of random numbers - updateRNGTest(); - - pullOutputRNGTestFromDevice(1000 * 1000); - - // Perform Kolmogorov-Smirnov test on contents of extra global parameter populated on device - double d; - double prob; - std::vector samplesDevice(&OutputRNGTest[0], &OutputRNGTest[1000 * 1000]); - std::tie(d, prob) = Stats::kolmogorovSmirnovTest(samplesDevice, Stats::uniformCDF); - - // Check p value passes 95% confidence interval - EXPECT_GT(prob, 0.05); - - // Perform Kolmogorov-Smirnov test on contents of presynaptic variable populated on host - std::vector samplesHost(&OutputHostRNGTest[0], &OutputHostRNGTest[1000]); - std::tie(d, prob) = Stats::kolmogorovSmirnovTest(samplesHost, Stats::uniformCDF); - - // Check p value passes 95% confidence interval - EXPECT_GT(prob, 0.05); -} diff --git a/tests/features/decode_matrix_conn_gen_globalg_bitmask/Makefile b/tests/features/decode_matrix_conn_gen_globalg_bitmask/Makefile deleted file mode 120000 index 1302b13ca5..0000000000 --- a/tests/features/decode_matrix_conn_gen_globalg_bitmask/Makefile +++ /dev/null @@ -1 +0,0 @@ -../../utils/Makefile \ No newline at end of file diff --git a/tests/features/decode_matrix_conn_gen_globalg_bitmask/decode_matrix_conn_gen_globalg_bitmask.sln b/tests/features/decode_matrix_conn_gen_globalg_bitmask/decode_matrix_conn_gen_globalg_bitmask.sln deleted file mode 100644 index 4c1de1a04b..0000000000 --- a/tests/features/decode_matrix_conn_gen_globalg_bitmask/decode_matrix_conn_gen_globalg_bitmask.sln +++ /dev/null @@ -1,30 +0,0 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 2013 -VisualStudioVersion = 12.0.30501.0 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "decode_matrix_conn_gen_globalg_bitmask", "decode_matrix_conn_gen_globalg_bitmask.vcxproj", "{8E4BF1F5-E608-4C0A-8916-7954FFE79896}" - ProjectSection(ProjectDependencies) = postProject - {E200B114-4923-42C6-BF7E-684B7765F910} = {E200B114-4923-42C6-BF7E-684B7765F910} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runner", "decode_matrix_conn_gen_globalg_bitmask_CODE\runner.vcxproj", "{E200B114-4923-42C6-BF7E-684B7765F910}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {8E4BF1F5-E608-4C0A-8916-7954FFE79896}.Debug|x64.ActiveCfg = Debug|x64 - {8E4BF1F5-E608-4C0A-8916-7954FFE79896}.Debug|x64.Build.0 = Debug|x64 - {8E4BF1F5-E608-4C0A-8916-7954FFE79896}.Release|x64.ActiveCfg = Release|x64 - {8E4BF1F5-E608-4C0A-8916-7954FFE79896}.Release|x64.Build.0 = Release|x64 - 
{E200B114-4923-42C6-BF7E-684B7765F910}.Debug|x64.ActiveCfg = Debug|x64 - {E200B114-4923-42C6-BF7E-684B7765F910}.Debug|x64.Build.0 = Debug|x64 - {E200B114-4923-42C6-BF7E-684B7765F910}.Release|x64.ActiveCfg = Release|x64 - {E200B114-4923-42C6-BF7E-684B7765F910}.Release|x64.Build.0 = Release|x64 - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection -EndGlobal diff --git a/tests/features/decode_matrix_conn_gen_globalg_bitmask/decode_matrix_conn_gen_globalg_bitmask.vcxproj b/tests/features/decode_matrix_conn_gen_globalg_bitmask/decode_matrix_conn_gen_globalg_bitmask.vcxproj deleted file mode 100644 index 15fe2177cb..0000000000 --- a/tests/features/decode_matrix_conn_gen_globalg_bitmask/decode_matrix_conn_gen_globalg_bitmask.vcxproj +++ /dev/null @@ -1,63 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {8E4BF1F5-E608-4C0A-8916-7954FFE79896} - - - - - - - - - Application - true - $(DefaultPlatformToolset) - true - MultiByte - - - - - - - - - - ./ - $(Platform)\$(Configuration)\ - test - - - - Level3 - MaxSpeed - Disabled - true - true - true - decode_matrix_conn_gen_globalg_bitmask_CODE;$(GTEST_DIR);$(GTEST_DIR)/include - _SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions) - - - true - true - true - runner_Release.lib;%(AdditionalDependencies) - runner_Debug.lib;%(AdditionalDependencies) - - - - - - diff --git a/tests/features/decode_matrix_conn_gen_globalg_bitmask/model.cc b/tests/features/decode_matrix_conn_gen_globalg_bitmask/model.cc deleted file mode 100644 index 3cecf9a713..0000000000 --- a/tests/features/decode_matrix_conn_gen_globalg_bitmask/model.cc +++ /dev/null @@ -1,80 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_conn_gen_globalg_bitmask/model.cc - -\brief model definition file that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. 
-*/ -//-------------------------------------------------------------------------- - - -#include "modelSpec.h" - -//---------------------------------------------------------------------------- -// Decoder -//---------------------------------------------------------------------------- -class Decoder : public InitSparseConnectivitySnippet::Base -{ -public: - DECLARE_SNIPPET(Decoder, 0); - - SET_ROW_BUILD_CODE( - "if(j < $(num_post)) {\n" - " const unsigned int jValue = (1 << j);\n" - " if((($(id_pre) + 1) & jValue) != 0)\n" - " {\n" - " $(addSynapse, j);\n" - " }\n" - "}\n" - "else {\n" - " $(endRow);\n" - "}\n" - "j++;\n"); - SET_ROW_BUILD_STATE_VARS({{"j", "unsigned int", 0}}); -}; -IMPLEMENT_SNIPPET(Decoder); - -//---------------------------------------------------------------------------- -// Neuron -//---------------------------------------------------------------------------- -class Neuron : public NeuronModels::Base -{ -public: - DECLARE_MODEL(Neuron, 0, 1); - - SET_SIM_CODE("$(x)= $(Isyn);\n"); - - SET_VARS({{"x", "scalar"}}); -}; - -IMPLEMENT_MODEL(Neuron); - - -void modelDefinition(ModelSpec &model) -{ -#ifdef CL_HPP_TARGET_OPENCL_VERSION - if(std::getenv("OPENCL_DEVICE") != nullptr) { - GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; - GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); - } - if(std::getenv("OPENCL_PLATFORM") != nullptr) { - GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); - } -#endif - model.setDT(0.1); - model.setName("decode_matrix_conn_gen_globalg_bitmask"); - - // Static synapse parameters - WeightUpdateModels::StaticPulse::VarValues staticSynapseInit(1.0); // 0 - Wij (nA) - - model.addNeuronPopulation("Pre", 10, {}, {}); - model.addNeuronPopulation("Post", 4, {}, Neuron::VarValues(0.0)); - - - model.addSynapsePopulation( - "Syn", SynapseMatrixType::BITMASK_GLOBALG, NO_DELAY, "Pre", "Post", - {}, staticSynapseInit, - {}, {}, - initConnectivity({})); - - model.setPrecision(GENN_FLOAT); -} diff --git a/tests/features/decode_matrix_conn_gen_globalg_bitmask/runner_guid.txt b/tests/features/decode_matrix_conn_gen_globalg_bitmask/runner_guid.txt deleted file mode 100644 index 83db778cc9..0000000000 --- a/tests/features/decode_matrix_conn_gen_globalg_bitmask/runner_guid.txt +++ /dev/null @@ -1 +0,0 @@ -E200B114-4923-42C6-BF7E-684B7765F910 diff --git a/tests/features/decode_matrix_conn_gen_globalg_bitmask/test.cc b/tests/features/decode_matrix_conn_gen_globalg_bitmask/test.cc deleted file mode 100644 index 27b82063a7..0000000000 --- a/tests/features/decode_matrix_conn_gen_globalg_bitmask/test.cc +++ /dev/null @@ -1,38 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_conn_gen_globalg_bitmask/test.cc - -\brief Main test code that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. 
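The Decoder snippet used by this and the following decode_matrix_conn_gen_* models implements a single rule: presynaptic neuron i connects to postsynaptic neuron j exactly when bit j of (i + 1) is set, so each row spells out i + 1 in binary. A plain-C++ restatement of that rule (buildDecoder is a hypothetical helper, not GeNN API):

    #include <vector>

    std::vector<std::vector<unsigned int>> buildDecoder(unsigned int numPre,
                                                        unsigned int numPost)
    {
        std::vector<std::vector<unsigned int>> rows(numPre);
        for(unsigned int i = 0; i < numPre; i++) {
            for(unsigned int j = 0; j < numPost; j++) {
                // Equivalent of the snippet's $(addSynapse, j)
                if(((i + 1) & (1u << j)) != 0) {
                    rows[i].push_back(j);
                }
            }
        }
        return rows;
    }

With numPre = 10 and numPost = 4, the input each postsynaptic neuron receives recovers the binary digits of the firing presynaptic index, which is the property SimulationTestDecoderMatrix relies on.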
-*/ -//-------------------------------------------------------------------------- - - -// Google test includes -#include "gtest/gtest.h" - -// Auto-generated simulation code includess -#include "decode_matrix_conn_gen_globalg_bitmask_CODE/definitions.h" - -// **NOTE** base-class for simulation tests must be -// included after auto-generated globals are includes -#include "../../utils/simulation_test_decoder_matrix.h" - -//---------------------------------------------------------------------------- -// SimTest -//---------------------------------------------------------------------------- -class SimTest : public SimulationTestDecoderMatrix -{ -public: - //---------------------------------------------------------------------------- - // SimulationTest virtuals - //---------------------------------------------------------------------------- - virtual void Init() - { - } -}; - -TEST_F(SimTest, DecodeMatrixConnGenGlobalgBitmask) -{ - // Check total error is less than some tolerance - EXPECT_TRUE(Simulate()); -} diff --git a/tests/features/decode_matrix_conn_gen_globalg_ragged/Makefile b/tests/features/decode_matrix_conn_gen_globalg_ragged/Makefile deleted file mode 120000 index 1302b13ca5..0000000000 --- a/tests/features/decode_matrix_conn_gen_globalg_ragged/Makefile +++ /dev/null @@ -1 +0,0 @@ -../../utils/Makefile \ No newline at end of file diff --git a/tests/features/decode_matrix_conn_gen_globalg_ragged/decode_matrix_conn_gen_globalg_ragged.sln b/tests/features/decode_matrix_conn_gen_globalg_ragged/decode_matrix_conn_gen_globalg_ragged.sln deleted file mode 100644 index e092643a9e..0000000000 --- a/tests/features/decode_matrix_conn_gen_globalg_ragged/decode_matrix_conn_gen_globalg_ragged.sln +++ /dev/null @@ -1,30 +0,0 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 2013 -VisualStudioVersion = 12.0.30501.0 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "decode_matrix_conn_gen_globalg_ragged", "decode_matrix_conn_gen_globalg_ragged.vcxproj", "{01EB0748-92F7-4FD9-9FB5-84906E60AD6B}" - ProjectSection(ProjectDependencies) = postProject - {F69A9DDA-EFE6-4DAE-B256-6CAB64E743BA} = {F69A9DDA-EFE6-4DAE-B256-6CAB64E743BA} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runner", "decode_matrix_conn_gen_globalg_ragged_CODE\runner.vcxproj", "{F69A9DDA-EFE6-4DAE-B256-6CAB64E743BA}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {01EB0748-92F7-4FD9-9FB5-84906E60AD6B}.Debug|x64.ActiveCfg = Debug|x64 - {01EB0748-92F7-4FD9-9FB5-84906E60AD6B}.Debug|x64.Build.0 = Debug|x64 - {01EB0748-92F7-4FD9-9FB5-84906E60AD6B}.Release|x64.ActiveCfg = Release|x64 - {01EB0748-92F7-4FD9-9FB5-84906E60AD6B}.Release|x64.Build.0 = Release|x64 - {F69A9DDA-EFE6-4DAE-B256-6CAB64E743BA}.Debug|x64.ActiveCfg = Debug|x64 - {F69A9DDA-EFE6-4DAE-B256-6CAB64E743BA}.Debug|x64.Build.0 = Debug|x64 - {F69A9DDA-EFE6-4DAE-B256-6CAB64E743BA}.Release|x64.ActiveCfg = Release|x64 - {F69A9DDA-EFE6-4DAE-B256-6CAB64E743BA}.Release|x64.Build.0 = Release|x64 - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection -EndGlobal diff --git a/tests/features/decode_matrix_conn_gen_globalg_ragged/decode_matrix_conn_gen_globalg_ragged.vcxproj 
b/tests/features/decode_matrix_conn_gen_globalg_ragged/decode_matrix_conn_gen_globalg_ragged.vcxproj deleted file mode 100644 index 617fa4a42a..0000000000 --- a/tests/features/decode_matrix_conn_gen_globalg_ragged/decode_matrix_conn_gen_globalg_ragged.vcxproj +++ /dev/null @@ -1,63 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {01EB0748-92F7-4FD9-9FB5-84906E60AD6B} - - - - - - - - - Application - true - $(DefaultPlatformToolset) - true - MultiByte - - - - - - - - - - ./ - $(Platform)\$(Configuration)\ - test - - - - Level3 - MaxSpeed - Disabled - true - true - true - decode_matrix_conn_gen_globalg_ragged_CODE;$(GTEST_DIR);$(GTEST_DIR)/include - _SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions) - - - true - true - true - runner_Release.lib;%(AdditionalDependencies) - runner_Debug.lib;%(AdditionalDependencies) - - - - - - diff --git a/tests/features/decode_matrix_conn_gen_globalg_ragged/model.cc b/tests/features/decode_matrix_conn_gen_globalg_ragged/model.cc deleted file mode 100644 index 473685d8df..0000000000 --- a/tests/features/decode_matrix_conn_gen_globalg_ragged/model.cc +++ /dev/null @@ -1,80 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_conn_gen_globalg_ragged/model.cc - -\brief model definition file that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -#include "modelSpec.h" - -//---------------------------------------------------------------------------- -// Decoder -//---------------------------------------------------------------------------- -class Decoder : public InitSparseConnectivitySnippet::Base -{ -public: - DECLARE_SNIPPET(Decoder, 0); - - SET_ROW_BUILD_CODE( - "if(j < $(num_post)) {\n" - " const unsigned int jValue = (1 << j);\n" - " if((($(id_pre) + 1) & jValue) != 0)\n" - " {\n" - " $(addSynapse, j);\n" - " }\n" - "}\n" - "else {\n" - " $(endRow);\n" - "}\n" - "j++;\n"); - SET_ROW_BUILD_STATE_VARS({{"j", "unsigned int", 0}}); -}; -IMPLEMENT_SNIPPET(Decoder); - -//---------------------------------------------------------------------------- -// Neuron -//---------------------------------------------------------------------------- -class Neuron : public NeuronModels::Base -{ -public: - DECLARE_MODEL(Neuron, 0, 1); - - SET_SIM_CODE("$(x)= $(Isyn);\n"); - - SET_VARS({{"x", "scalar"}}); -}; - -IMPLEMENT_MODEL(Neuron); - - -void modelDefinition(ModelSpec &model) -{ -#ifdef CL_HPP_TARGET_OPENCL_VERSION - if(std::getenv("OPENCL_DEVICE") != nullptr) { - GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; - GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); - } - if(std::getenv("OPENCL_PLATFORM") != nullptr) { - GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); - } -#endif - model.setDT(0.1); - model.setName("decode_matrix_conn_gen_globalg_ragged"); - - // Static synapse parameters - WeightUpdateModels::StaticPulse::VarValues staticSynapseInit(1.0); // 0 - Wij (nA) - - model.addNeuronPopulation("Pre", 10, {}, {}); - model.addNeuronPopulation("Post", 4, {}, Neuron::VarValues(0.0)); - - - model.addSynapsePopulation( - "Syn", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, "Pre", "Post", - {}, staticSynapseInit, - {}, {}, - initConnectivity({})); - - model.setPrecision(GENN_FLOAT); -} diff --git 
a/tests/features/decode_matrix_conn_gen_globalg_ragged/runner_guid.txt b/tests/features/decode_matrix_conn_gen_globalg_ragged/runner_guid.txt deleted file mode 100644 index f1cc98614b..0000000000 --- a/tests/features/decode_matrix_conn_gen_globalg_ragged/runner_guid.txt +++ /dev/null @@ -1 +0,0 @@ -F69A9DDA-EFE6-4DAE-B256-6CAB64E743BA diff --git a/tests/features/decode_matrix_conn_gen_globalg_ragged/test.cc b/tests/features/decode_matrix_conn_gen_globalg_ragged/test.cc deleted file mode 100644 index fc742ee20d..0000000000 --- a/tests/features/decode_matrix_conn_gen_globalg_ragged/test.cc +++ /dev/null @@ -1,38 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_conn_gen_globalg_ragged/test.cc - -\brief Main test code that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -// Google test includes -#include "gtest/gtest.h" - -// Auto-generated simulation code includess -#include "decode_matrix_conn_gen_globalg_ragged_CODE/definitions.h" - -// **NOTE** base-class for simulation tests must be -// included after auto-generated globals are includes -#include "../../utils/simulation_test_decoder_matrix.h" - -//---------------------------------------------------------------------------- -// SimTest -//---------------------------------------------------------------------------- -class SimTest : public SimulationTestDecoderMatrix -{ -public: - //---------------------------------------------------------------------------- - // SimulationTest virtuals - //---------------------------------------------------------------------------- - virtual void Init() - { - } -}; - -TEST_F(SimTest, DecodeMatrixConnGenGlobalgRagged) -{ - // Check total error is less than some tolerance - EXPECT_TRUE(Simulate()); -} diff --git a/tests/features/decode_matrix_conn_gen_individualg_dense/Makefile b/tests/features/decode_matrix_conn_gen_individualg_dense/Makefile deleted file mode 100644 index 7a76a5c854..0000000000 --- a/tests/features/decode_matrix_conn_gen_individualg_dense/Makefile +++ /dev/null @@ -1,14 +0,0 @@ -CXXFLAGS +=-std=c++11 -Wall -Wpedantic -Wextra -I $(GTEST_DIR) -isystem $(GTEST_DIR)/include - -.PHONY: all clean generated_code - -all: test - -test: test.cc generated_code - $(CXX) $(CXXFLAGS) test.cc $(GTEST_DIR)/src/gtest-all.cc $(GTEST_DIR)/src/gtest_main.cc -o test -L$(SIM_CODE) -pthread -lrunner -Wl,-rpath $(SIM_CODE) - -generated_code: - $(MAKE) -C $(SIM_CODE) - -clean: - @rm -f test $(SIM_CODE)/librunner.so $(SIM_CODE)/*.o $(SIM_CODE)/*.d default.profraw diff --git a/tests/features/decode_matrix_conn_gen_individualg_dense/decode_matrix_conn_gen_individualg_dense.sln b/tests/features/decode_matrix_conn_gen_individualg_dense/decode_matrix_conn_gen_individualg_dense.sln deleted file mode 100644 index 89ecf654cd..0000000000 --- a/tests/features/decode_matrix_conn_gen_individualg_dense/decode_matrix_conn_gen_individualg_dense.sln +++ /dev/null @@ -1,30 +0,0 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 2013 -VisualStudioVersion = 12.0.30501.0 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "decode_matrix_conn_gen_individualg_dense", "decode_matrix_conn_gen_individualg_dense.vcxproj", "{AD02B67D-F180-4D64-B20D-3CE8DADC07C2}" - ProjectSection(ProjectDependencies) = postProject - 
{A35FE766-F7D8-47C8-A463-62E676D5DAE} = {A35FE766-F7D8-47C8-A463-62E676D5DAE} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runner", "decode_matrix_conn_gen_individualg_dense_CODE\runner.vcxproj", "{A35FE766-F7D8-47C8-A463-62E676D5DAE}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {AD02B67D-F180-4D64-B20D-3CE8DADC07C2}.Debug|x64.ActiveCfg = Debug|x64 - {AD02B67D-F180-4D64-B20D-3CE8DADC07C2}.Debug|x64.Build.0 = Debug|x64 - {AD02B67D-F180-4D64-B20D-3CE8DADC07C2}.Release|x64.ActiveCfg = Release|x64 - {AD02B67D-F180-4D64-B20D-3CE8DADC07C2}.Release|x64.Build.0 = Release|x64 - {A35FE766-F7D8-47C8-A463-62E676D5DAE}.Debug|x64.ActiveCfg = Debug|x64 - {A35FE766-F7D8-47C8-A463-62E676D5DAE}.Debug|x64.Build.0 = Debug|x64 - {A35FE766-F7D8-47C8-A463-62E676D5DAE}.Release|x64.ActiveCfg = Release|x64 - {A35FE766-F7D8-47C8-A463-62E676D5DAE}.Release|x64.Build.0 = Release|x64 - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection -EndGlobal diff --git a/tests/features/decode_matrix_conn_gen_individualg_dense/decode_matrix_conn_gen_individualg_dense.vcxproj b/tests/features/decode_matrix_conn_gen_individualg_dense/decode_matrix_conn_gen_individualg_dense.vcxproj deleted file mode 100644 index a834e77d43..0000000000 --- a/tests/features/decode_matrix_conn_gen_individualg_dense/decode_matrix_conn_gen_individualg_dense.vcxproj +++ /dev/null @@ -1,63 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {AD02B67D-F180-4D64-B20D-3CE8DADC07C2} - - - - - - - - - Application - true - $(DefaultPlatformToolset) - true - MultiByte - - - - - - - - - - ./ - $(Platform)\$(Configuration)\ - test - - - - Level3 - MaxSpeed - Disabled - true - true - true - decode_matrix_conn_gen_individualg_dense_CODE;$(GTEST_DIR);$(GTEST_DIR)/include - _SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions) - - - true - true - true - runner_Release.lib;%(AdditionalDependencies) - runner_Debug.lib;%(AdditionalDependencies) - - - - - - diff --git a/tests/features/decode_matrix_conn_gen_individualg_dense/model.cc b/tests/features/decode_matrix_conn_gen_individualg_dense/model.cc deleted file mode 100644 index 54c3e3bf82..0000000000 --- a/tests/features/decode_matrix_conn_gen_individualg_dense/model.cc +++ /dev/null @@ -1,68 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_individualg_dense/model.cc - -\brief model definition file that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -#include "modelSpec.h" - -//---------------------------------------------------------------------------- -// Decoder -//---------------------------------------------------------------------------- -class Decoder : public InitVarSnippet::Base -{ -public: - DECLARE_SNIPPET(Decoder, 0); - - SET_CODE( - "const unsigned int j_value = (1 << $(id_post));\n" - "$(value) = ((($(id_pre) + 1) & j_value) != 0) ? 
1.0f : 0.0f;\n") -}; -IMPLEMENT_SNIPPET(Decoder); - -//---------------------------------------------------------------------------- -// Neuron -//---------------------------------------------------------------------------- -class Neuron : public NeuronModels::Base -{ -public: - DECLARE_MODEL(Neuron, 0, 1); - - SET_SIM_CODE("$(x)= $(Isyn);\n"); - - SET_VARS({{"x", "scalar"}}); -}; - -IMPLEMENT_MODEL(Neuron); - - -void modelDefinition(ModelSpec &model) -{ -#ifdef CL_HPP_TARGET_OPENCL_VERSION - if(std::getenv("OPENCL_DEVICE") != nullptr) { - GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; - GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); - } - if(std::getenv("OPENCL_PLATFORM") != nullptr) { - GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); - } -#endif - model.setDT(0.1); - model.setName("decode_matrix_conn_gen_individualg_dense"); - - // Static synapse parameters - WeightUpdateModels::StaticPulse::VarValues staticSynapseInit(initVar()); // 0 - Wij (nA) - - model.addNeuronPopulation("Pre", 10, {}, {}); - model.addNeuronPopulation("Post", 4, {}, Neuron::VarValues(0.0)); - - model.addSynapsePopulation( - "Syn", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, "Pre", "Post", - {}, staticSynapseInit, - {}, {}); - - model.setPrecision(GENN_FLOAT); -} diff --git a/tests/features/decode_matrix_conn_gen_individualg_dense/runner_guid.txt b/tests/features/decode_matrix_conn_gen_individualg_dense/runner_guid.txt deleted file mode 100644 index 214592df5b..0000000000 --- a/tests/features/decode_matrix_conn_gen_individualg_dense/runner_guid.txt +++ /dev/null @@ -1 +0,0 @@ -A35FE766-F7D8-47C8-A463-62E676D5DAE diff --git a/tests/features/decode_matrix_conn_gen_individualg_dense/test.cc b/tests/features/decode_matrix_conn_gen_individualg_dense/test.cc deleted file mode 100644 index 746695393d..0000000000 --- a/tests/features/decode_matrix_conn_gen_individualg_dense/test.cc +++ /dev/null @@ -1,31 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_conn_genn_individualg_dense/test.cc - -\brief Main test code that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. 
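The dense variant above encodes the same bit rule through per-element variable initialisation rather than row building: SET_CODE assigns weight 1.0f where bit id_post of (id_pre + 1) is set and 0.0f elsewhere. As a sketch of the resulting weight matrix (buildDenseDecoderWeights is a hypothetical helper; row-major layout assumed):

    #include <vector>

    std::vector<float> buildDenseDecoderWeights(unsigned int numPre, unsigned int numPost)
    {
        std::vector<float> g(numPre * numPost);
        for(unsigned int i = 0; i < numPre; i++) {
            for(unsigned int j = 0; j < numPost; j++) {
                g[(i * numPost) + j] = (((i + 1) & (1u << j)) != 0) ? 1.0f : 0.0f;
            }
        }
        return g;
    }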
-*/ -//-------------------------------------------------------------------------- - - -// Google test includes -#include "gtest/gtest.h" - -// Auto-generated simulation code includess -#include "decode_matrix_conn_gen_individualg_dense_CODE/definitions.h" - -// **NOTE** base-class for simulation tests must be -// included after auto-generated globals are includes -#include "../../utils/simulation_test_decoder_matrix.h" - -//---------------------------------------------------------------------------- -// SimTest -//---------------------------------------------------------------------------- -class SimTest : public SimulationTestDecoderMatrix -{ -}; - -TEST_F(SimTest, DecodeMatrixConnGenIndividualgDense) -{ - // Check total error is less than some tolerance - EXPECT_TRUE(Simulate()); -} diff --git a/tests/features/decode_matrix_conn_gen_individualg_ragged/Makefile b/tests/features/decode_matrix_conn_gen_individualg_ragged/Makefile deleted file mode 120000 index 1302b13ca5..0000000000 --- a/tests/features/decode_matrix_conn_gen_individualg_ragged/Makefile +++ /dev/null @@ -1 +0,0 @@ -../../utils/Makefile \ No newline at end of file diff --git a/tests/features/decode_matrix_conn_gen_individualg_ragged/decode_matrix_conn_gen_individualg_ragged.sln b/tests/features/decode_matrix_conn_gen_individualg_ragged/decode_matrix_conn_gen_individualg_ragged.sln deleted file mode 100644 index 0076561ab1..0000000000 --- a/tests/features/decode_matrix_conn_gen_individualg_ragged/decode_matrix_conn_gen_individualg_ragged.sln +++ /dev/null @@ -1,30 +0,0 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 2013 -VisualStudioVersion = 12.0.30501.0 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "decode_matrix_conn_gen_individualg_ragged", "decode_matrix_conn_gen_individualg_ragged.vcxproj", "{C0CE9EA5-219D-4D1E-AAAB-31782C0F8446}" - ProjectSection(ProjectDependencies) = postProject - {84F5E3C1-D8A1-482F-BDE9-0D2AEFCEA1B1} = {84F5E3C1-D8A1-482F-BDE9-0D2AEFCEA1B1} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runner", "decode_matrix_conn_gen_individualg_ragged_CODE\runner.vcxproj", "{84F5E3C1-D8A1-482F-BDE9-0D2AEFCEA1B1}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {C0CE9EA5-219D-4D1E-AAAB-31782C0F8446}.Debug|x64.ActiveCfg = Debug|x64 - {C0CE9EA5-219D-4D1E-AAAB-31782C0F8446}.Debug|x64.Build.0 = Debug|x64 - {C0CE9EA5-219D-4D1E-AAAB-31782C0F8446}.Release|x64.ActiveCfg = Release|x64 - {C0CE9EA5-219D-4D1E-AAAB-31782C0F8446}.Release|x64.Build.0 = Release|x64 - {84F5E3C1-D8A1-482F-BDE9-0D2AEFCEA1B1}.Debug|x64.ActiveCfg = Debug|x64 - {84F5E3C1-D8A1-482F-BDE9-0D2AEFCEA1B1}.Debug|x64.Build.0 = Debug|x64 - {84F5E3C1-D8A1-482F-BDE9-0D2AEFCEA1B1}.Release|x64.ActiveCfg = Release|x64 - {84F5E3C1-D8A1-482F-BDE9-0D2AEFCEA1B1}.Release|x64.Build.0 = Release|x64 - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection -EndGlobal diff --git a/tests/features/decode_matrix_conn_gen_individualg_ragged/decode_matrix_conn_gen_individualg_ragged.vcxproj b/tests/features/decode_matrix_conn_gen_individualg_ragged/decode_matrix_conn_gen_individualg_ragged.vcxproj deleted file mode 100644 index 1c3e60f6aa..0000000000 --- 
a/tests/features/decode_matrix_conn_gen_individualg_ragged/decode_matrix_conn_gen_individualg_ragged.vcxproj +++ /dev/null @@ -1,63 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {C0CE9EA5-219D-4D1E-AAAB-31782C0F8446} - - - - - - - - - Application - true - $(DefaultPlatformToolset) - true - MultiByte - - - - - - - - - - ./ - $(Platform)\$(Configuration)\ - test - - - - Level3 - MaxSpeed - Disabled - true - true - true - decode_matrix_conn_gen_individualg_ragged_CODE;$(GTEST_DIR);$(GTEST_DIR)/include - _SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions) - - - true - true - true - runner_Release.lib;%(AdditionalDependencies) - runner_Debug.lib;%(AdditionalDependencies) - - - - - - diff --git a/tests/features/decode_matrix_conn_gen_individualg_ragged/model.cc b/tests/features/decode_matrix_conn_gen_individualg_ragged/model.cc deleted file mode 100644 index 6d252294c5..0000000000 --- a/tests/features/decode_matrix_conn_gen_individualg_ragged/model.cc +++ /dev/null @@ -1,80 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_conn_gen_individualg_ragged/model.cc - -\brief model definition file that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -#include "modelSpec.h" - -//---------------------------------------------------------------------------- -// Decoder -//---------------------------------------------------------------------------- -class Decoder : public InitSparseConnectivitySnippet::Base -{ -public: - DECLARE_SNIPPET(Decoder, 0); - - SET_ROW_BUILD_CODE( - "if(j < $(num_post)) {\n" - " const unsigned int jValue = (1 << j);\n" - " if((($(id_pre) + 1) & jValue) != 0)\n" - " {\n" - " $(addSynapse, j);\n" - " }\n" - "}\n" - "else {\n" - " $(endRow);\n" - "}\n" - "j++;\n"); - SET_ROW_BUILD_STATE_VARS({{"j", "unsigned int", 0}}); -}; -IMPLEMENT_SNIPPET(Decoder); - -//---------------------------------------------------------------------------- -// Neuron -//---------------------------------------------------------------------------- -class Neuron : public NeuronModels::Base -{ -public: - DECLARE_MODEL(Neuron, 0, 1); - - SET_SIM_CODE("$(x)= $(Isyn);\n"); - - SET_VARS({{"x", "scalar"}}); -}; - -IMPLEMENT_MODEL(Neuron); - - -void modelDefinition(ModelSpec &model) -{ -#ifdef CL_HPP_TARGET_OPENCL_VERSION - if(std::getenv("OPENCL_DEVICE") != nullptr) { - GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; - GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); - } - if(std::getenv("OPENCL_PLATFORM") != nullptr) { - GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); - } -#endif - model.setDT(0.1); - model.setName("decode_matrix_conn_gen_individualg_ragged"); - - // Static synapse parameters - WeightUpdateModels::StaticPulse::VarValues staticSynapseInit(1.0); // 0 - Wij (nA) - - model.addNeuronPopulation("Pre", 10, {}, {}); - model.addNeuronPopulation("Post", 4, {}, Neuron::VarValues(0.0)); - - - model.addSynapsePopulation( - "Syn", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, "Pre", "Post", - {}, staticSynapseInit, - {}, {}, - initConnectivity({})); - - model.setPrecision(GENN_FLOAT); -} diff --git a/tests/features/decode_matrix_conn_gen_individualg_ragged/runner_guid.txt b/tests/features/decode_matrix_conn_gen_individualg_ragged/runner_guid.txt deleted file mode 100644 index 
18ed41b505..0000000000 --- a/tests/features/decode_matrix_conn_gen_individualg_ragged/runner_guid.txt +++ /dev/null @@ -1 +0,0 @@ -84F5E3C1-D8A1-482F-BDE9-0D2AEFCEA1B1 diff --git a/tests/features/decode_matrix_conn_gen_individualg_ragged/test.cc b/tests/features/decode_matrix_conn_gen_individualg_ragged/test.cc deleted file mode 100644 index d343a27abc..0000000000 --- a/tests/features/decode_matrix_conn_gen_individualg_ragged/test.cc +++ /dev/null @@ -1,38 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_conn_gen_individualg_ragged/test.cc - -\brief Main test code that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -// Google test includes -#include "gtest/gtest.h" - -// Auto-generated simulation code includess -#include "decode_matrix_conn_gen_individualg_ragged_CODE/definitions.h" - -// **NOTE** base-class for simulation tests must be -// included after auto-generated globals are includes -#include "../../utils/simulation_test_decoder_matrix.h" - -//---------------------------------------------------------------------------- -// SimTest -//---------------------------------------------------------------------------- -class SimTest : public SimulationTestDecoderMatrix -{ -public: - //---------------------------------------------------------------------------- - // SimulationTest virtuals - //---------------------------------------------------------------------------- - virtual void Init() - { - } -}; - -TEST_F(SimTest, DecodeMatrixConnGenIndividualgRagged) -{ - // Check total error is less than some tolerance - EXPECT_TRUE(Simulate()); -} diff --git a/tests/features/neuron_rng_normal/Makefile b/tests/features/neuron_rng_normal/Makefile deleted file mode 120000 index 1302b13ca5..0000000000 --- a/tests/features/neuron_rng_normal/Makefile +++ /dev/null @@ -1 +0,0 @@ -../../utils/Makefile \ No newline at end of file diff --git a/tests/features/neuron_rng_normal/model.cc b/tests/features/neuron_rng_normal/model.cc deleted file mode 100644 index 8476acfbc0..0000000000 --- a/tests/features/neuron_rng_normal/model.cc +++ /dev/null @@ -1,45 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file neuron_rng_normal/model.cc - -\brief model definition file that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. 
-*/ -//-------------------------------------------------------------------------- - - -#include "modelSpec.h" - -//---------------------------------------------------------------------------- -// Neuron -//---------------------------------------------------------------------------- -class Neuron : public NeuronModels::Base -{ -public: - DECLARE_MODEL(Neuron, 0, 1); - - SET_SIM_CODE("$(x)= $(gennrand_normal);\n"); - - SET_VARS({{"x", "scalar"}}); -}; - -IMPLEMENT_MODEL(Neuron); - - -void modelDefinition(ModelSpec &model) -{ -#ifdef CL_HPP_TARGET_OPENCL_VERSION - if(std::getenv("OPENCL_DEVICE") != nullptr) { - GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; - GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); - } - if(std::getenv("OPENCL_PLATFORM") != nullptr) { - GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); - } -#endif - model.setDT(0.1); - model.setName("neuron_rng_normal"); - - model.addNeuronPopulation("Pop", 1000, {}, Neuron::VarValues(0.0)); - - model.setPrecision(GENN_FLOAT); -} diff --git a/tests/features/neuron_rng_normal/neuron_rng_normal.sln b/tests/features/neuron_rng_normal/neuron_rng_normal.sln deleted file mode 100644 index b5be8a8f4b..0000000000 --- a/tests/features/neuron_rng_normal/neuron_rng_normal.sln +++ /dev/null @@ -1,30 +0,0 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 2013 -VisualStudioVersion = 12.0.30501.0 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "neuron_rng_normal", "neuron_rng_normal.vcxproj", "{141E75D3-7BED-45A9-8D17-96C91A1D2807}" - ProjectSection(ProjectDependencies) = postProject - {E6A38F0F-5278-4AF1-B4A0-AD5BC7D4AD5B} = {E6A38F0F-5278-4AF1-B4A0-AD5BC7D4AD5B} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runner", "neuron_rng_normal_CODE\runner.vcxproj", "{E6A38F0F-5278-4AF1-B4A0-AD5BC7D4AD5B}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {141E75D3-7BED-45A9-8D17-96C91A1D2807}.Debug|x64.ActiveCfg = Debug|x64 - {141E75D3-7BED-45A9-8D17-96C91A1D2807}.Debug|x64.Build.0 = Debug|x64 - {141E75D3-7BED-45A9-8D17-96C91A1D2807}.Release|x64.ActiveCfg = Release|x64 - {141E75D3-7BED-45A9-8D17-96C91A1D2807}.Release|x64.Build.0 = Release|x64 - {E6A38F0F-5278-4AF1-B4A0-AD5BC7D4AD5B}.Debug|x64.ActiveCfg = Debug|x64 - {E6A38F0F-5278-4AF1-B4A0-AD5BC7D4AD5B}.Debug|x64.Build.0 = Debug|x64 - {E6A38F0F-5278-4AF1-B4A0-AD5BC7D4AD5B}.Release|x64.ActiveCfg = Release|x64 - {E6A38F0F-5278-4AF1-B4A0-AD5BC7D4AD5B}.Release|x64.Build.0 = Release|x64 - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection -EndGlobal diff --git a/tests/features/neuron_rng_normal/neuron_rng_normal.vcxproj b/tests/features/neuron_rng_normal/neuron_rng_normal.vcxproj deleted file mode 100644 index 47869fec4b..0000000000 --- a/tests/features/neuron_rng_normal/neuron_rng_normal.vcxproj +++ /dev/null @@ -1,63 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {141E75D3-7BED-45A9-8D17-96C91A1D2807} - - - - - - - - - Application - true - $(DefaultPlatformToolset) - true - MultiByte - - - - - - - - - - ./ - $(Platform)\$(Configuration)\ - test - - - - Level3 - MaxSpeed - Disabled - true - true - true - _SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions) - 
neuron_rng_normal_CODE;$(GTEST_DIR);$(GTEST_DIR)/include - - - true - true - true - runner_Release.lib;%(AdditionalDependencies) - runner_Debug.lib;%(AdditionalDependencies) - - - - - - diff --git a/tests/features/neuron_rng_normal/runner_guid.txt b/tests/features/neuron_rng_normal/runner_guid.txt deleted file mode 100644 index a4fafbe128..0000000000 --- a/tests/features/neuron_rng_normal/runner_guid.txt +++ /dev/null @@ -1 +0,0 @@ -E6A38F0F-5278-4AF1-B4A0-AD5BC7D4AD5B diff --git a/tests/features/neuron_rng_normal/test.cc b/tests/features/neuron_rng_normal/test.cc deleted file mode 100644 index d2373596da..0000000000 --- a/tests/features/neuron_rng_normal/test.cc +++ /dev/null @@ -1,45 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file neuron_rng_normal/test.cc - -\brief Main test code that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -// Google test includes -#include "gtest/gtest.h" - -// Auto-generated simulation code includess -#include "neuron_rng_normal_CODE/definitions.h" - -// **NOTE** base-class for simulation tests must be -// included after auto-generated globals are includes -#include "../../utils/simulation_test_samples.h" -#include "../../utils/stats.h" - -//---------------------------------------------------------------------------- -// SimTest -//---------------------------------------------------------------------------- -class SimTest : public SimulationTestSamples -{ -public: - //---------------------------------------------------------------------------- - // SimulationTestHistogram virtuals - //---------------------------------------------------------------------------- - virtual double Test(std::vector &samples) const - { - // Perform Kolmogorov-Smirnov test - double d; - double prob; - std::tie(d, prob) = Stats::kolmogorovSmirnovTest(samples, Stats::normalCDF); - - return prob; - } -}; - -TEST_F(SimTest, NeuronRngNormal) -{ - // Check p value passes 95% confidence interval - EXPECT_GT(Simulate(), 0.05); -} diff --git a/tests/features/neuron_rng_uniform/Makefile b/tests/features/neuron_rng_uniform/Makefile deleted file mode 120000 index 1302b13ca5..0000000000 --- a/tests/features/neuron_rng_uniform/Makefile +++ /dev/null @@ -1 +0,0 @@ -../../utils/Makefile \ No newline at end of file diff --git a/tests/features/neuron_rng_uniform/model.cc b/tests/features/neuron_rng_uniform/model.cc deleted file mode 100644 index dd05ba8261..0000000000 --- a/tests/features/neuron_rng_uniform/model.cc +++ /dev/null @@ -1,45 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file neuron_rng_uniform/model.cc - -\brief model definition file that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. 
-*/ -//-------------------------------------------------------------------------- - - -#include "modelSpec.h" - -//---------------------------------------------------------------------------- -// Neuron -//---------------------------------------------------------------------------- -class Neuron : public NeuronModels::Base -{ -public: - DECLARE_MODEL(Neuron, 0, 1); - - SET_SIM_CODE("$(x)= $(gennrand_uniform);\n"); - - SET_VARS({{"x", "scalar"}}); -}; - -IMPLEMENT_MODEL(Neuron); - - -void modelDefinition(ModelSpec &model) -{ -#ifdef CL_HPP_TARGET_OPENCL_VERSION - if(std::getenv("OPENCL_DEVICE") != nullptr) { - GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; - GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); - } - if(std::getenv("OPENCL_PLATFORM") != nullptr) { - GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); - } -#endif - model.setDT(0.1); - model.setName("neuron_rng_uniform"); - - model.addNeuronPopulation("Pop", 1000, {}, Neuron::VarValues(0.0)); - - model.setPrecision(GENN_FLOAT); -} diff --git a/tests/features/neuron_rng_uniform/neuron_rng_uniform.sln b/tests/features/neuron_rng_uniform/neuron_rng_uniform.sln deleted file mode 100644 index e500f01951..0000000000 --- a/tests/features/neuron_rng_uniform/neuron_rng_uniform.sln +++ /dev/null @@ -1,30 +0,0 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 2013 -VisualStudioVersion = 12.0.30501.0 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "neuron_rng_uniform", "neuron_rng_uniform.vcxproj", "{C3BC6B87-6C1F-4998-ADCD-55FEFE984D0D}" - ProjectSection(ProjectDependencies) = postProject - {BB5B38F9-7D23-481E-B448-C168878BB5FE} = {BB5B38F9-7D23-481E-B448-C168878BB5FE} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runner", "neuron_rng_uniform_CODE\runner.vcxproj", "{BB5B38F9-7D23-481E-B448-C168878BB5FE}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {C3BC6B87-6C1F-4998-ADCD-55FEFE984D0D}.Debug|x64.ActiveCfg = Debug|x64 - {C3BC6B87-6C1F-4998-ADCD-55FEFE984D0D}.Debug|x64.Build.0 = Debug|x64 - {C3BC6B87-6C1F-4998-ADCD-55FEFE984D0D}.Release|x64.ActiveCfg = Release|x64 - {C3BC6B87-6C1F-4998-ADCD-55FEFE984D0D}.Release|x64.Build.0 = Release|x64 - {BB5B38F9-7D23-481E-B448-C168878BB5FE}.Debug|x64.ActiveCfg = Debug|x64 - {BB5B38F9-7D23-481E-B448-C168878BB5FE}.Debug|x64.Build.0 = Debug|x64 - {BB5B38F9-7D23-481E-B448-C168878BB5FE}.Release|x64.ActiveCfg = Release|x64 - {BB5B38F9-7D23-481E-B448-C168878BB5FE}.Release|x64.Build.0 = Release|x64 - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection -EndGlobal diff --git a/tests/features/neuron_rng_uniform/neuron_rng_uniform.vcxproj b/tests/features/neuron_rng_uniform/neuron_rng_uniform.vcxproj deleted file mode 100644 index 20f3c92e52..0000000000 --- a/tests/features/neuron_rng_uniform/neuron_rng_uniform.vcxproj +++ /dev/null @@ -1,63 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {C3BC6B87-6C1F-4998-ADCD-55FEFE984D0D} - - - - - - - - - Application - true - $(DefaultPlatformToolset) - true - MultiByte - - - - - - - - - - ./ - $(Platform)\$(Configuration)\ - test - - - - Level3 - MaxSpeed - Disabled - true - true - true - neuron_rng_uniform_CODE;$(GTEST_DIR);$(GTEST_DIR)/include - 
_SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions) - - - true - true - true - runner_Release.lib;%(AdditionalDependencies) - runner_Debug.lib;%(AdditionalDependencies) - - - - - - diff --git a/tests/features/neuron_rng_uniform/runner_guid.txt b/tests/features/neuron_rng_uniform/runner_guid.txt deleted file mode 100644 index aa01467413..0000000000 --- a/tests/features/neuron_rng_uniform/runner_guid.txt +++ /dev/null @@ -1 +0,0 @@ -BB5B38F9-7D23-481E-B448-C168878BB5FE diff --git a/tests/features/neuron_rng_uniform/test.cc b/tests/features/neuron_rng_uniform/test.cc deleted file mode 100644 index 39adb7bf17..0000000000 --- a/tests/features/neuron_rng_uniform/test.cc +++ /dev/null @@ -1,45 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file neuron_rng_uniform/test.cc - -\brief Main test code that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -// Google test includes -#include "gtest/gtest.h" - -// Auto-generated simulation code includess -#include "neuron_rng_uniform_CODE/definitions.h" - -// **NOTE** base-class for simulation tests must be -// included after auto-generated globals are includes -#include "../../utils/simulation_test_samples.h" -#include "../../utils/stats.h" - -//---------------------------------------------------------------------------- -// SimTest -//---------------------------------------------------------------------------- -class SimTest : public SimulationTestSamples -{ -public: - //---------------------------------------------------------------------------- - // SimulationTestHistogram virtuals - //---------------------------------------------------------------------------- - virtual double Test(std::vector &samples) const - { - // Perform Kolmogorov-Smirnov test - double d; - double prob; - std::tie(d, prob) = Stats::kolmogorovSmirnovTest(samples, Stats::uniformCDF); - - return prob; - } -}; - -TEST_F(SimTest, NeuronRngUniform) -{ - // Check p value passes 95% confidence interval - EXPECT_GT(Simulate(), 0.05); -} From 0e3eb2ad6a0bb2ce5c5e06ab466f72c20973f3bf Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 26 Jul 2023 13:00:31 +0100 Subject: [PATCH 458/725] removed some more now-reimplemented tests --- .../decode_matrix_globalg_ragged/Makefile | 1 - .../decode_matrix_globalg_ragged.sln | 30 --------- .../decode_matrix_globalg_ragged.vcxproj | 63 ------------------- .../decode_matrix_globalg_ragged/model.cc | 55 ---------------- .../runner_guid.txt | 1 - .../decode_matrix_globalg_ragged/test.cc | 56 ----------------- .../decode_matrix_individualg_dense/Makefile | 1 - .../decode_matrix_individualg_dense.sln | 30 --------- .../decode_matrix_individualg_dense.vcxproj | 63 ------------------- .../decode_matrix_individualg_dense/model.cc | 55 ---------------- .../runner_guid.txt | 1 - .../decode_matrix_individualg_dense/test.cc | 53 ---------------- .../decode_matrix_individualg_ragged/Makefile | 1 - .../decode_matrix_individualg_ragged.sln | 30 --------- .../decode_matrix_individualg_ragged.vcxproj | 63 ------------------- .../decode_matrix_individualg_ragged/model.cc | 55 ---------------- .../runner_guid.txt | 1 - .../decode_matrix_individualg_ragged/test.cc | 56 ----------------- 18 files changed, 615 deletions(-) delete mode 120000 tests/features/decode_matrix_globalg_ragged/Makefile delete mode 100644 
tests/features/decode_matrix_globalg_ragged/decode_matrix_globalg_ragged.sln delete mode 100644 tests/features/decode_matrix_globalg_ragged/decode_matrix_globalg_ragged.vcxproj delete mode 100644 tests/features/decode_matrix_globalg_ragged/model.cc delete mode 100644 tests/features/decode_matrix_globalg_ragged/runner_guid.txt delete mode 100644 tests/features/decode_matrix_globalg_ragged/test.cc delete mode 120000 tests/features/decode_matrix_individualg_dense/Makefile delete mode 100644 tests/features/decode_matrix_individualg_dense/decode_matrix_individualg_dense.sln delete mode 100644 tests/features/decode_matrix_individualg_dense/decode_matrix_individualg_dense.vcxproj delete mode 100644 tests/features/decode_matrix_individualg_dense/model.cc delete mode 100644 tests/features/decode_matrix_individualg_dense/runner_guid.txt delete mode 100644 tests/features/decode_matrix_individualg_dense/test.cc delete mode 120000 tests/features/decode_matrix_individualg_ragged/Makefile delete mode 100644 tests/features/decode_matrix_individualg_ragged/decode_matrix_individualg_ragged.sln delete mode 100644 tests/features/decode_matrix_individualg_ragged/decode_matrix_individualg_ragged.vcxproj delete mode 100644 tests/features/decode_matrix_individualg_ragged/model.cc delete mode 100644 tests/features/decode_matrix_individualg_ragged/runner_guid.txt delete mode 100644 tests/features/decode_matrix_individualg_ragged/test.cc diff --git a/tests/features/decode_matrix_globalg_ragged/Makefile b/tests/features/decode_matrix_globalg_ragged/Makefile deleted file mode 120000 index 1302b13ca5..0000000000 --- a/tests/features/decode_matrix_globalg_ragged/Makefile +++ /dev/null @@ -1 +0,0 @@ -../../utils/Makefile \ No newline at end of file diff --git a/tests/features/decode_matrix_globalg_ragged/decode_matrix_globalg_ragged.sln b/tests/features/decode_matrix_globalg_ragged/decode_matrix_globalg_ragged.sln deleted file mode 100644 index 473e467592..0000000000 --- a/tests/features/decode_matrix_globalg_ragged/decode_matrix_globalg_ragged.sln +++ /dev/null @@ -1,30 +0,0 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 2013 -VisualStudioVersion = 12.0.30501.0 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "decode_matrix_globalg_ragged", "decode_matrix_globalg_ragged.vcxproj", "{9836F58C-CC56-4AE3-81F1-2CE331309947}" - ProjectSection(ProjectDependencies) = postProject - {708CA18E-F2A5-48F9-9E8C-0AAF55A35BC2} = {708CA18E-F2A5-48F9-9E8C-0AAF55A35BC2} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runner", "decode_matrix_globalg_ragged_CODE\runner.vcxproj", "{708CA18E-F2A5-48F9-9E8C-0AAF55A35BC2}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {9836F58C-CC56-4AE3-81F1-2CE331309947}.Debug|x64.ActiveCfg = Debug|x64 - {9836F58C-CC56-4AE3-81F1-2CE331309947}.Debug|x64.Build.0 = Debug|x64 - {9836F58C-CC56-4AE3-81F1-2CE331309947}.Release|x64.ActiveCfg = Release|x64 - {9836F58C-CC56-4AE3-81F1-2CE331309947}.Release|x64.Build.0 = Release|x64 - {708CA18E-F2A5-48F9-9E8C-0AAF55A35BC2}.Debug|x64.ActiveCfg = Debug|x64 - {708CA18E-F2A5-48F9-9E8C-0AAF55A35BC2}.Debug|x64.Build.0 = Debug|x64 - {708CA18E-F2A5-48F9-9E8C-0AAF55A35BC2}.Release|x64.ActiveCfg = Release|x64 - {708CA18E-F2A5-48F9-9E8C-0AAF55A35BC2}.Release|x64.Build.0 = Release|x64 - 
EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection -EndGlobal diff --git a/tests/features/decode_matrix_globalg_ragged/decode_matrix_globalg_ragged.vcxproj b/tests/features/decode_matrix_globalg_ragged/decode_matrix_globalg_ragged.vcxproj deleted file mode 100644 index 3c95ebe7d4..0000000000 --- a/tests/features/decode_matrix_globalg_ragged/decode_matrix_globalg_ragged.vcxproj +++ /dev/null @@ -1,63 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {9836F58C-CC56-4AE3-81F1-2CE331309947} - - - - - - - - - Application - true - $(DefaultPlatformToolset) - true - MultiByte - - - - - - - - - - ./ - $(Platform)\$(Configuration)\ - test - - - - Level3 - MaxSpeed - Disabled - true - true - true - decode_matrix_globalg_ragged_CODE;$(GTEST_DIR);$(GTEST_DIR)/include - _SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions) - - - true - true - true - runner_Release.lib;%(AdditionalDependencies) - runner_Debug.lib;%(AdditionalDependencies) - - - - - - diff --git a/tests/features/decode_matrix_globalg_ragged/model.cc b/tests/features/decode_matrix_globalg_ragged/model.cc deleted file mode 100644 index 909e4dc896..0000000000 --- a/tests/features/decode_matrix_globalg_ragged/model.cc +++ /dev/null @@ -1,55 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_globalg_ragged/model.cc - -\brief model definition file that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -#include "modelSpec.h" - -//---------------------------------------------------------------------------- -// Neuron -//---------------------------------------------------------------------------- -class Neuron : public NeuronModels::Base -{ -public: - DECLARE_MODEL(Neuron, 0, 1); - - SET_SIM_CODE("$(x)= $(Isyn);\n"); - - SET_VARS({{"x", "scalar"}}); -}; - -IMPLEMENT_MODEL(Neuron); - - -void modelDefinition(ModelSpec &model) -{ -#ifdef CL_HPP_TARGET_OPENCL_VERSION - if(std::getenv("OPENCL_DEVICE") != nullptr) { - GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; - GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); - } - if(std::getenv("OPENCL_PLATFORM") != nullptr) { - GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); - } -#endif - model.setDT(0.1); - model.setName("decode_matrix_globalg_ragged"); - - // Static synapse parameters - WeightUpdateModels::StaticPulse::VarValues staticSynapseInit(1.0); // 0 - Wij (nA) - - model.addNeuronPopulation("Pre", 10, {}, {}); - model.addNeuronPopulation("Post", 4, {}, Neuron::VarValues(0.0)); - - - model.addSynapsePopulation( - "Syn", SynapseMatrixType::SPARSE_GLOBALG, NO_DELAY, "Pre", "Post", - {}, staticSynapseInit, - {}, {}); - - model.setPrecision(GENN_FLOAT); -} diff --git a/tests/features/decode_matrix_globalg_ragged/runner_guid.txt b/tests/features/decode_matrix_globalg_ragged/runner_guid.txt deleted file mode 100644 index 70dfca4740..0000000000 --- a/tests/features/decode_matrix_globalg_ragged/runner_guid.txt +++ /dev/null @@ -1 +0,0 @@ -708CA18E-F2A5-48F9-9E8C-0AAF55A35BC2 diff --git a/tests/features/decode_matrix_globalg_ragged/test.cc b/tests/features/decode_matrix_globalg_ragged/test.cc deleted file mode 100644 index 68360f1918..0000000000 --- a/tests/features/decode_matrix_globalg_ragged/test.cc +++ /dev/null @@ -1,56 +0,0 
@@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_globalg_ragged/test.cc - -\brief Main test code that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -// Google test includes -#include "gtest/gtest.h" - -// Auto-generated simulation code includess -#include "decode_matrix_globalg_ragged_CODE/definitions.h" - -// **NOTE** base-class for simulation tests must be -// included after auto-generated globals are includes -#include "../../utils/simulation_test_decoder_matrix.h" - -//---------------------------------------------------------------------------- -// SimTest -//---------------------------------------------------------------------------- -class SimTest : public SimulationTestDecoderMatrix -{ -public: - //---------------------------------------------------------------------------- - // SimulationTest virtuals - //---------------------------------------------------------------------------- - virtual void Init() - { - // Loop through presynaptic neurons - for(unsigned int i = 0; i < 10; i++) - { - // Initially zero row length - rowLengthSyn[i] = 0; - for(unsigned int j = 0; j < 4; j++) - { - // Get value this post synaptic neuron represents - const unsigned int j_value = (1 << j); - - // If this postsynaptic neuron should be connected, add index - if(((i + 1) & j_value) != 0) - { - const unsigned int idx = (i * 4) + rowLengthSyn[i]++; - indSyn[idx] = j; - } - } - } - } -}; - -TEST_F(SimTest, DecodeMatrixGlobalgRagged) -{ - // Check total error is less than some tolerance - EXPECT_TRUE(Simulate()); -} diff --git a/tests/features/decode_matrix_individualg_dense/Makefile b/tests/features/decode_matrix_individualg_dense/Makefile deleted file mode 120000 index 1302b13ca5..0000000000 --- a/tests/features/decode_matrix_individualg_dense/Makefile +++ /dev/null @@ -1 +0,0 @@ -../../utils/Makefile \ No newline at end of file diff --git a/tests/features/decode_matrix_individualg_dense/decode_matrix_individualg_dense.sln b/tests/features/decode_matrix_individualg_dense/decode_matrix_individualg_dense.sln deleted file mode 100644 index 552cc4afd9..0000000000 --- a/tests/features/decode_matrix_individualg_dense/decode_matrix_individualg_dense.sln +++ /dev/null @@ -1,30 +0,0 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 2013 -VisualStudioVersion = 12.0.30501.0 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "decode_matrix_individualg_dense", "decode_matrix_individualg_dense.vcxproj", "{194B1043-4F18-4FF1-AEA7-B7CC13578155}" - ProjectSection(ProjectDependencies) = postProject - {08C1AFC5-045B-43E9-910C-9E3FF168F84C} = {08C1AFC5-045B-43E9-910C-9E3FF168F84C} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runner", "decode_matrix_individualg_dense_CODE\runner.vcxproj", "{08C1AFC5-045B-43E9-910C-9E3FF168F84C}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {194B1043-4F18-4FF1-AEA7-B7CC13578155}.Debug|x64.ActiveCfg = Debug|x64 - {194B1043-4F18-4FF1-AEA7-B7CC13578155}.Debug|x64.Build.0 = Debug|x64 - {194B1043-4F18-4FF1-AEA7-B7CC13578155}.Release|x64.ActiveCfg = Release|x64 - 
{194B1043-4F18-4FF1-AEA7-B7CC13578155}.Release|x64.Build.0 = Release|x64 - {08C1AFC5-045B-43E9-910C-9E3FF168F84C}.Debug|x64.ActiveCfg = Debug|x64 - {08C1AFC5-045B-43E9-910C-9E3FF168F84C}.Debug|x64.Build.0 = Debug|x64 - {08C1AFC5-045B-43E9-910C-9E3FF168F84C}.Release|x64.ActiveCfg = Release|x64 - {08C1AFC5-045B-43E9-910C-9E3FF168F84C}.Release|x64.Build.0 = Release|x64 - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection -EndGlobal diff --git a/tests/features/decode_matrix_individualg_dense/decode_matrix_individualg_dense.vcxproj b/tests/features/decode_matrix_individualg_dense/decode_matrix_individualg_dense.vcxproj deleted file mode 100644 index ea301a804e..0000000000 --- a/tests/features/decode_matrix_individualg_dense/decode_matrix_individualg_dense.vcxproj +++ /dev/null @@ -1,63 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {194B1043-4F18-4FF1-AEA7-B7CC13578155} - - - - - - - - - Application - true - $(DefaultPlatformToolset) - true - MultiByte - - - - - - - - - - ./ - $(Platform)\$(Configuration)\ - test - - - - Level3 - MaxSpeed - Disabled - true - true - true - decode_matrix_individualg_dense_CODE;$(GTEST_DIR);$(GTEST_DIR)/include - _SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions) - - - true - true - true - runner_Release.lib;%(AdditionalDependencies) - runner_Debug.lib;%(AdditionalDependencies) - - - - - - diff --git a/tests/features/decode_matrix_individualg_dense/model.cc b/tests/features/decode_matrix_individualg_dense/model.cc deleted file mode 100644 index 8be95017a9..0000000000 --- a/tests/features/decode_matrix_individualg_dense/model.cc +++ /dev/null @@ -1,55 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_individualg_dense/model.cc - -\brief model definition file that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. 
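Every decode-matrix feature test removed in this patch series shares the same analytic scheme: the spike source fires presynaptic neuron i once, and neuron i is connected to postsynaptic "bit" neuron j exactly when bit j of (i + 1) is set, so the current injected by a one-hot spike reconstructs the binary number i + 1 across the four output neurons. A minimal NumPy-only sketch of that rule (the variable names here are illustrative, not taken from the codebase):

    import numpy as np

    num_pre, num_bits = 10, 4
    pre_inds, post_inds = [], []
    for i in range(num_pre):
        for j in range(num_bits):
            # connect pre neuron i to bit neuron j when bit j of (i + 1) is set
            if (i + 1) & (1 << j):
                pre_inds.append(i)
                post_inds.append(j)

    # equivalent dense 10x4 weight matrix: 1.0 where connected, 0.0 elsewhere
    g = np.zeros((num_pre, num_bits))
    g[pre_inds, post_inds] = 1.0

    # a spike from neuron i activates exactly the bit neurons encoding i + 1
    place_values = 2 ** np.arange(num_bits)
    for i in range(num_pre):
        assert int(g[i] @ place_values) == i + 1

This is the same construction the reimplemented Python suite uses to build its pre_inds/post_inds lists and the dense matrix it flattens for the manually-defined DENSE population.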
-*/ -//-------------------------------------------------------------------------- - - -#include "modelSpec.h" - -//---------------------------------------------------------------------------- -// Neuron -//---------------------------------------------------------------------------- -class Neuron : public NeuronModels::Base -{ -public: - DECLARE_MODEL(Neuron, 0, 1); - - SET_SIM_CODE("$(x)= $(Isyn);\n"); - - SET_VARS({{"x", "scalar"}}); -}; - -IMPLEMENT_MODEL(Neuron); - - -void modelDefinition(ModelSpec &model) -{ -#ifdef CL_HPP_TARGET_OPENCL_VERSION - if(std::getenv("OPENCL_DEVICE") != nullptr) { - GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; - GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); - } - if(std::getenv("OPENCL_PLATFORM") != nullptr) { - GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); - } -#endif - model.setDT(0.1); - model.setName("decode_matrix_individualg_dense"); - - // Static synapse parameters - WeightUpdateModels::StaticPulse::VarValues staticSynapseInit(uninitialisedVar()); // 0 - Wij (nA) - - model.addNeuronPopulation("Pre", 10, {}, {}); - model.addNeuronPopulation("Post", 4, {}, Neuron::VarValues(0.0)); - - - model.addSynapsePopulation( - "Syn", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, "Pre", "Post", - {}, staticSynapseInit, - {}, {}); - - model.setPrecision(GENN_FLOAT); -} diff --git a/tests/features/decode_matrix_individualg_dense/runner_guid.txt b/tests/features/decode_matrix_individualg_dense/runner_guid.txt deleted file mode 100644 index bdfffc717e..0000000000 --- a/tests/features/decode_matrix_individualg_dense/runner_guid.txt +++ /dev/null @@ -1 +0,0 @@ -08C1AFC5-045B-43E9-910C-9E3FF168F84C diff --git a/tests/features/decode_matrix_individualg_dense/test.cc b/tests/features/decode_matrix_individualg_dense/test.cc deleted file mode 100644 index a85133f742..0000000000 --- a/tests/features/decode_matrix_individualg_dense/test.cc +++ /dev/null @@ -1,53 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_individualg_dense/test.cc - -\brief Main test code that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. 
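The dense variant whose test follows fills a flattened, row-major weight array (gSyn[c++] set to 1.0f or 0.0f per bit), while the ragged variants on either side of it populate GeNN's row-length/column-index sparse format by hand via the idx = (i * 4) + rowLengthSyn[i]++; indSyn[idx] = j; pattern. A rough Python equivalent of that ragged fill, assuming a row stride equal to the postsynaptic population size as in these 10-to-4 decoders:

    import numpy as np

    num_pre, num_post = 10, 4
    row_stride = num_post  # maximum row length; equals num_post in these tests
    row_length = np.zeros(num_pre, dtype=np.uint32)
    ind = np.zeros(num_pre * row_stride, dtype=np.uint32)

    for i in range(num_pre):
        for j in range(num_post):
            if (i + 1) & (1 << j):
                # append column j to row i of the ragged matrix
                ind[(i * row_stride) + row_length[i]] = j
                row_length[i] += 1

    # row i holds exactly the set bits of (i + 1)
    for i in range(num_pre):
        row = ind[i * row_stride:(i * row_stride) + row_length[i]]
        assert sum(1 << int(j) for j in row) == i + 1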
-*/ -//-------------------------------------------------------------------------- - - -// Google test includes -#include "gtest/gtest.h" - -// Auto-generated simulation code includess -#include "decode_matrix_individualg_dense_CODE/definitions.h" - -// **NOTE** base-class for simulation tests must be -// included after auto-generated globals are includes -#include "../../utils/simulation_test_decoder_matrix.h" - -//---------------------------------------------------------------------------- -// SimTest -//---------------------------------------------------------------------------- -class SimTest : public SimulationTestDecoderMatrix -{ -public: - //---------------------------------------------------------------------------- - // SimulationTest virtuals - //---------------------------------------------------------------------------- - virtual void Init() - { - // Loop through presynaptic neurons - unsigned int c = 0; - for(unsigned int i = 0; i < 10; i++) - { - // Set start index for this presynaptic neuron's weight matrix row - for(unsigned int j = 0; j < 4; j++) - { - // Get value this post synaptic neuron represents - const unsigned int j_value = (1 << j); - - // If this postsynaptic neuron should be connected, add 1.0 otherwise 0.0 - gSyn[c++] = (((i + 1) & j_value) != 0) ? 1.0f : 0.0f; - - } - } - } -}; - -TEST_F(SimTest, DecodeMatrixIndividualgDense) -{ - // Check total error is less than some tolerance - EXPECT_TRUE(Simulate()); -} diff --git a/tests/features/decode_matrix_individualg_ragged/Makefile b/tests/features/decode_matrix_individualg_ragged/Makefile deleted file mode 120000 index 1302b13ca5..0000000000 --- a/tests/features/decode_matrix_individualg_ragged/Makefile +++ /dev/null @@ -1 +0,0 @@ -../../utils/Makefile \ No newline at end of file diff --git a/tests/features/decode_matrix_individualg_ragged/decode_matrix_individualg_ragged.sln b/tests/features/decode_matrix_individualg_ragged/decode_matrix_individualg_ragged.sln deleted file mode 100644 index ddf8567a22..0000000000 --- a/tests/features/decode_matrix_individualg_ragged/decode_matrix_individualg_ragged.sln +++ /dev/null @@ -1,30 +0,0 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 2013 -VisualStudioVersion = 12.0.30501.0 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "decode_matrix_individualg_ragged", "decode_matrix_individualg_ragged.vcxproj", "{BD304043-4F4B-411B-9AC5-FD09087DEDFB}" - ProjectSection(ProjectDependencies) = postProject - {1D5EB74A-AAF2-4FA0-8858-B88A33159B4F} = {1D5EB74A-AAF2-4FA0-8858-B88A33159B4F} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runner", "decode_matrix_individualg_ragged_CODE\runner.vcxproj", "{1D5EB74A-AAF2-4FA0-8858-B88A33159B4F}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {BD304043-4F4B-411B-9AC5-FD09087DEDFB}.Debug|x64.ActiveCfg = Debug|x64 - {BD304043-4F4B-411B-9AC5-FD09087DEDFB}.Debug|x64.Build.0 = Debug|x64 - {BD304043-4F4B-411B-9AC5-FD09087DEDFB}.Release|x64.ActiveCfg = Release|x64 - {BD304043-4F4B-411B-9AC5-FD09087DEDFB}.Release|x64.Build.0 = Release|x64 - {1D5EB74A-AAF2-4FA0-8858-B88A33159B4F}.Debug|x64.ActiveCfg = Debug|x64 - {1D5EB74A-AAF2-4FA0-8858-B88A33159B4F}.Debug|x64.Build.0 = Debug|x64 - {1D5EB74A-AAF2-4FA0-8858-B88A33159B4F}.Release|x64.ActiveCfg = Release|x64 - 
{1D5EB74A-AAF2-4FA0-8858-B88A33159B4F}.Release|x64.Build.0 = Release|x64 - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection -EndGlobal diff --git a/tests/features/decode_matrix_individualg_ragged/decode_matrix_individualg_ragged.vcxproj b/tests/features/decode_matrix_individualg_ragged/decode_matrix_individualg_ragged.vcxproj deleted file mode 100644 index e06f683c31..0000000000 --- a/tests/features/decode_matrix_individualg_ragged/decode_matrix_individualg_ragged.vcxproj +++ /dev/null @@ -1,63 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {BD304043-4F4B-411B-9AC5-FD09087DEDFB} - - - - - - - - - Application - true - $(DefaultPlatformToolset) - true - MultiByte - - - - - - - - - - ./ - $(Platform)\$(Configuration)\ - test - - - - Level3 - MaxSpeed - Disabled - true - true - true - decode_matrix_individualg_ragged_CODE;$(GTEST_DIR);$(GTEST_DIR)/include - _SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions) - - - true - true - true - runner_Release.lib;%(AdditionalDependencies) - runner_Debug.lib;%(AdditionalDependencies) - - - - - - diff --git a/tests/features/decode_matrix_individualg_ragged/model.cc b/tests/features/decode_matrix_individualg_ragged/model.cc deleted file mode 100644 index f317e102f5..0000000000 --- a/tests/features/decode_matrix_individualg_ragged/model.cc +++ /dev/null @@ -1,55 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_individualg_ragged/model.cc - -\brief model definition file that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -#include "modelSpec.h" - -//---------------------------------------------------------------------------- -// Neuron -//---------------------------------------------------------------------------- -class Neuron : public NeuronModels::Base -{ -public: - DECLARE_MODEL(Neuron, 0, 1); - - SET_SIM_CODE("$(x)= $(Isyn);\n"); - - SET_VARS({{"x", "scalar"}}); -}; - -IMPLEMENT_MODEL(Neuron); - - -void modelDefinition(ModelSpec &model) -{ -#ifdef CL_HPP_TARGET_OPENCL_VERSION - if(std::getenv("OPENCL_DEVICE") != nullptr) { - GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; - GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); - } - if(std::getenv("OPENCL_PLATFORM") != nullptr) { - GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); - } -#endif - model.setDT(0.1); - model.setName("decode_matrix_individualg_ragged"); - - // Static synapse parameters - WeightUpdateModels::StaticPulse::VarValues staticSynapseInit(1.0); // 0 - Wij (nA) - - model.addNeuronPopulation("Pre", 10, {}, {}); - model.addNeuronPopulation("Post", 4, {}, Neuron::VarValues(0.0)); - - - model.addSynapsePopulation( - "Syn", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, "Pre", "Post", - {}, staticSynapseInit, - {}, {}); - - model.setPrecision(GENN_FLOAT); -} diff --git a/tests/features/decode_matrix_individualg_ragged/runner_guid.txt b/tests/features/decode_matrix_individualg_ragged/runner_guid.txt deleted file mode 100644 index 31853a8cdf..0000000000 --- a/tests/features/decode_matrix_individualg_ragged/runner_guid.txt +++ /dev/null @@ -1 +0,0 @@ -1D5EB74A-AAF2-4FA0-8858-B88A33159B4F diff --git a/tests/features/decode_matrix_individualg_ragged/test.cc 
b/tests/features/decode_matrix_individualg_ragged/test.cc deleted file mode 100644 index bb0d08f756..0000000000 --- a/tests/features/decode_matrix_individualg_ragged/test.cc +++ /dev/null @@ -1,56 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_individualg_ragged/test.cc - -\brief Main test code that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -// Google test includes -#include "gtest/gtest.h" - -// Auto-generated simulation code includess -#include "decode_matrix_individualg_ragged_CODE/definitions.h" - -// **NOTE** base-class for simulation tests must be -// included after auto-generated globals are includes -#include "../../utils/simulation_test_decoder_matrix.h" - -//---------------------------------------------------------------------------- -// SimTest -//---------------------------------------------------------------------------- -class SimTest : public SimulationTestDecoderMatrix -{ -public: - //---------------------------------------------------------------------------- - // SimulationTest virtuals - //---------------------------------------------------------------------------- - virtual void Init() - { - // Loop through presynaptic neurons - for(unsigned int i = 0; i < 10; i++) - { - // Initially zero row length - rowLengthSyn[i] = 0; - for(unsigned int j = 0; j < 4; j++) - { - // Get value this post synaptic neuron represents - const unsigned int j_value = (1 << j); - - // If this postsynaptic neuron should be connected, add index - if(((i + 1) & j_value) != 0) - { - const unsigned int idx = (i * 4) + rowLengthSyn[i]++; - indSyn[idx] = j; - } - } - } - } -}; - -TEST_F(SimTest, DecodeMatrixIndividualgRagged) -{ - // Check total error is less than some tolerance - EXPECT_TRUE(Simulate()); -} From 7818be56dc306735046c9ead98fd347a0f61af6c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 26 Jul 2023 13:02:11 +0100 Subject: [PATCH 459/725] fixed small typo in SynapseGroup --- pygenn/genn_groups.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index 5ba0b126a7..1d0b6f3f44 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -752,8 +752,8 @@ def _init_wum_var(self, var_data, num_copies): # If connectivity is dense, # copy variables directly into view # **NOTE** we assume order is row-major - if ((self.matrix_type & SynapseMatrixWeight.DENSE) or - (self.matrix_type & SynapseMatrixWeight.KERNEL)): + if ((self.matrix_type & SynapseMatrixConnectivity.DENSE) or + (self.matrix_type & SynapseMatrixWeight.KERNEL)): var_data.view[:] = var_data.values elif (self.matrix_type & SynapseMatrixConnectivity.SPARSE): # Sort variable to match GeNN order From 9681a14f97e58b162fcf55464736133bd168e84b Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 26 Jul 2023 13:02:41 +0100 Subject: [PATCH 460/725] reimplemented some of spike propagation via manually-defined connectivity tests --- tests/features/test_spike_propagation.py | 73 ++++++++++++++++++++---- 1 file changed, 62 insertions(+), 11 deletions(-) diff --git a/tests/features/test_spike_propagation.py b/tests/features/test_spike_propagation.py index 9c58dc260b..20ea2bacbd 100644 --- a/tests/features/test_spike_propagation.py +++ b/tests/features/test_spike_propagation.py @@ -38,11 +38,10 @@ """, var_name_types=[("x", 
"scalar")]) - @pytest.mark.parametrize("backend", ["single_threaded_cpu", "cuda"]) @pytest.mark.parametrize("precision", [types.Double, types.Float]) -def test_spike_propagation_snippet(backend, precision): - model = GeNNModel(precision, "test_spike_propagation_snippet", backend=backend) +def test_spike_propagation(backend, precision): + model = GeNNModel(precision, "test_spike_propagation", backend=backend) model.dt = 1.0 # Create spike source array to generate one-hot pattern to decode @@ -50,6 +49,22 @@ def test_spike_propagation_snippet(backend, precision): {}, {"startSpike": np.arange(16), "endSpike": np.arange(1, 17)}) ss_pop.extra_global_params["spikeTimes"].set_values(np.arange(16.0)) + # Build sparse connectivity + pre_inds = [] + post_inds = [] + for i in range(16): + for j in range(4): + j_value = 1 << j + if ((i + 1) & j_value) != 0: + pre_inds.append(i) + post_inds.append(j) + pre_inds = np.asarray(pre_inds) + post_inds = np.asarray(post_inds) + + # Use to build dense matrix + dense = np.zeros((16, 4)) + dense[pre_inds,post_inds] = 1.0 + # Create one output neuron pop with constant weight sparse decoder population sparse_constant_weight_n_pop = model.add_neuron_population( "PostSparseConstantWeightNeuron", 4, post_neuron_model, @@ -61,6 +76,18 @@ def test_spike_propagation_snippet(backend, precision): "DeltaCurr", {}, {}, init_sparse_connectivity(decoder_model, {})) + # Create one output neuron pop with constant weight sparse decoder population + manual_sparse_constant_weight_n_pop = model.add_neuron_population( + "ManualPostSparseConstantWeightNeuron", 4, post_neuron_model, + {}, {"x": 0.0}) + manual_sparse_constant_weight_s_pop = model.add_synapse_population( + "ManualSparseConstantWeightSynapse", "SPARSE", 0, + ss_pop, manual_sparse_constant_weight_n_pop, + "StaticPulseConstantWeight", {"g": 1.0}, {}, {}, {}, + "DeltaCurr", {}, {}) + manual_sparse_constant_weight_s_pop.set_sparse_connections(pre_inds, + post_inds) + # Create one output neuron pop with sparse decoder population sparse_n_pop = model.add_neuron_population( "PostSparseNeuron", 4, post_neuron_model, @@ -71,7 +98,19 @@ def test_spike_propagation_snippet(backend, precision): "StaticPulse", {}, {"g": 1.0}, {}, {}, "DeltaCurr", {}, {}, init_sparse_connectivity(decoder_model, {})) - + + # Create one output neuron pop with sparse decoder population + manual_sparse_n_pop = model.add_neuron_population( + "ManualPostSparseNeuron", 4, post_neuron_model, + {}, {"x": 0.0}) + manual_sparse_s_pop = model.add_synapse_population( + "ManualSparseSynapse", "SPARSE", 0, + ss_pop, manual_sparse_n_pop, + "StaticPulse", {}, {"g": 1.0}, {}, {}, + "DeltaCurr", {}, {}, + init_sparse_connectivity(decoder_model, {})) + manual_sparse_s_pop.set_sparse_connections(pre_inds, post_inds) + # Create one output neuron pop with bitmask decoder population bitmask_n_pop = model.add_neuron_population( "PostBitmaskNeuron", 4, post_neuron_model, @@ -82,7 +121,7 @@ def test_spike_propagation_snippet(backend, precision): "StaticPulseConstantWeight", {"g": 1.0}, {}, {}, {}, "DeltaCurr", {}, {}, init_sparse_connectivity(decoder_model, {})) - + # Create one output neuron pop with bitmask decoder population dense_n_pop = model.add_neuron_population( "PostDenseNeuron", 4, post_neuron_model, @@ -92,18 +131,30 @@ def test_spike_propagation_snippet(backend, precision): ss_pop, dense_n_pop, "StaticPulse", {}, {"g": init_var(decoder_dense_model, {})}, {}, {}, "DeltaCurr", {}, {}) - + + # Create one output neuron pop with bitmask decoder population + 
manual_dense_n_pop = model.add_neuron_population( + "ManualPostDenseNeuron", 4, post_neuron_model, + {}, {"x": 0.0}) + model.add_synapse_population( + "ManualPostDenseSynapse", "DENSE", 0, + ss_pop, manual_dense_n_pop, + "StaticPulse", {}, {"g": dense.flatten()}, {}, {}, + "DeltaCurr", {}, {}) + # Build model and load model.build() model.load() # Simulate 16 timesteps output_place_values = 2 ** np.arange(4) - output_populations = [sparse_constant_weight_n_pop, - sparse_n_pop, bitmask_n_pop, dense_n_pop] + output_populations = [sparse_constant_weight_n_pop, + manual_sparse_constant_weight_n_pop, + sparse_n_pop, manual_sparse_n_pop, + bitmask_n_pop, dense_n_pop, manual_dense_n_pop] while model.timestep < 16: model.step_time() - + # Loop through output populations for pop in output_populations: # Pull state variable @@ -111,11 +162,11 @@ def test_spike_propagation_snippet(backend, precision): # Convert to binary mask output_binary = np.isclose(np.ones(4), pop.vars["x"].view) - + # Sum up active place values output_value = np.sum(output_place_values[output_binary]) if output_value != (model.timestep - 1): assert False, f"{pop.name} decoding incorrect ({output_value} rather than {model.timestep - 1})" if __name__ == '__main__': - test_spike_propagation_snippet("single_threaded_cpu", types.Float) \ No newline at end of file + test_spike_propagation("single_threaded_cpu", types.Float) From 0d2eda1ca58619fe6b80afdd10b0d2a783e51b89 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 26 Jul 2023 13:02:53 +0100 Subject: [PATCH 461/725] fixed typo in RNG sim test --- tests/features/test_rng_sim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/features/test_rng_sim.py b/tests/features/test_rng_sim.py index 9863f7a559..1694d899f7 100644 --- a/tests/features/test_rng_sim.py +++ b/tests/features/test_rng_sim.py @@ -109,4 +109,4 @@ def test_rng_sim(backend, precision): if __name__ == '__main__': - test_sim_rng("single_threaded_cpu", types.Double) \ No newline at end of file + test_rng_sim("single_threaded_cpu", types.Double) From c30850b05fdcf872572b3c6769316793f839b62c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 26 Jul 2023 14:28:23 +0100 Subject: [PATCH 462/725] * added spike propogation with dendritic delay test * removed now-reimplemented tests --- .../Makefile | 1 - ...ode_matrix_den_delay_individualg_dense.sln | 30 --------- ...matrix_den_delay_individualg_dense.vcxproj | 63 ------------------- .../model.cc | 57 ----------------- .../runner_guid.txt | 1 - .../test.cc | 44 ------------- .../Makefile | 1 - ...de_matrix_den_delay_individualg_ragged.sln | 30 --------- ...atrix_den_delay_individualg_ragged.vcxproj | 63 ------------------- .../model.cc | 59 ----------------- .../runner_guid.txt | 1 - .../test.cc | 48 -------------- tests/features/test_spike_propagation.py | 60 +++++++++++++++++- 13 files changed, 57 insertions(+), 401 deletions(-) delete mode 120000 tests/features/decode_matrix_den_delay_individualg_dense/Makefile delete mode 100644 tests/features/decode_matrix_den_delay_individualg_dense/decode_matrix_den_delay_individualg_dense.sln delete mode 100644 tests/features/decode_matrix_den_delay_individualg_dense/decode_matrix_den_delay_individualg_dense.vcxproj delete mode 100644 tests/features/decode_matrix_den_delay_individualg_dense/model.cc delete mode 100644 tests/features/decode_matrix_den_delay_individualg_dense/runner_guid.txt delete mode 100644 tests/features/decode_matrix_den_delay_individualg_dense/test.cc delete mode 120000 
tests/features/decode_matrix_den_delay_individualg_ragged/Makefile delete mode 100644 tests/features/decode_matrix_den_delay_individualg_ragged/decode_matrix_den_delay_individualg_ragged.sln delete mode 100644 tests/features/decode_matrix_den_delay_individualg_ragged/decode_matrix_den_delay_individualg_ragged.vcxproj delete mode 100644 tests/features/decode_matrix_den_delay_individualg_ragged/model.cc delete mode 100644 tests/features/decode_matrix_den_delay_individualg_ragged/runner_guid.txt delete mode 100644 tests/features/decode_matrix_den_delay_individualg_ragged/test.cc diff --git a/tests/features/decode_matrix_den_delay_individualg_dense/Makefile b/tests/features/decode_matrix_den_delay_individualg_dense/Makefile deleted file mode 120000 index 1302b13ca5..0000000000 --- a/tests/features/decode_matrix_den_delay_individualg_dense/Makefile +++ /dev/null @@ -1 +0,0 @@ -../../utils/Makefile \ No newline at end of file diff --git a/tests/features/decode_matrix_den_delay_individualg_dense/decode_matrix_den_delay_individualg_dense.sln b/tests/features/decode_matrix_den_delay_individualg_dense/decode_matrix_den_delay_individualg_dense.sln deleted file mode 100644 index fb16f2595c..0000000000 --- a/tests/features/decode_matrix_den_delay_individualg_dense/decode_matrix_den_delay_individualg_dense.sln +++ /dev/null @@ -1,30 +0,0 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 2013 -VisualStudioVersion = 12.0.30501.0 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "decode_matrix_den_delay_individualg_dense", "decode_matrix_den_delay_individualg_dense.vcxproj", "{B92DCD91-C356-4DD9-8F68-26D890348FDD}" - ProjectSection(ProjectDependencies) = postProject - {B2A10F28-9FF5-4C26-998D-FFE625A59A91} = {B2A10F28-9FF5-4C26-998D-FFE625A59A91} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runner", "decode_matrix_den_delay_individualg_dense_CODE\runner.vcxproj", "{B2A10F28-9FF5-4C26-998D-FFE625A59A91}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {B92DCD91-C356-4DD9-8F68-26D890348FDD}.Debug|x64.ActiveCfg = Debug|x64 - {B92DCD91-C356-4DD9-8F68-26D890348FDD}.Debug|x64.Build.0 = Debug|x64 - {B92DCD91-C356-4DD9-8F68-26D890348FDD}.Release|x64.ActiveCfg = Release|x64 - {B92DCD91-C356-4DD9-8F68-26D890348FDD}.Release|x64.Build.0 = Release|x64 - {B2A10F28-9FF5-4C26-998D-FFE625A59A91}.Debug|x64.ActiveCfg = Debug|x64 - {B2A10F28-9FF5-4C26-998D-FFE625A59A91}.Debug|x64.Build.0 = Debug|x64 - {B2A10F28-9FF5-4C26-998D-FFE625A59A91}.Release|x64.ActiveCfg = Release|x64 - {B2A10F28-9FF5-4C26-998D-FFE625A59A91}.Release|x64.Build.0 = Release|x64 - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection -EndGlobal diff --git a/tests/features/decode_matrix_den_delay_individualg_dense/decode_matrix_den_delay_individualg_dense.vcxproj b/tests/features/decode_matrix_den_delay_individualg_dense/decode_matrix_den_delay_individualg_dense.vcxproj deleted file mode 100644 index 201da24096..0000000000 --- a/tests/features/decode_matrix_den_delay_individualg_dense/decode_matrix_den_delay_individualg_dense.vcxproj +++ /dev/null @@ -1,63 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {B92DCD91-C356-4DD9-8F68-26D890348FDD} - - - - - - - - - Application - true - $(DefaultPlatformToolset) - 
true - MultiByte - - - - - - - - - - ./ - $(Platform)\$(Configuration)\ - test - - - - Level3 - MaxSpeed - Disabled - true - true - true - decode_matrix_den_delay_individualg_dense_CODE;$(GTEST_DIR);$(GTEST_DIR)/include;$(GTEST_DIR);$(GTEST_DIR)/include - _SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions) - - - true - true - true - runner_Release.lib;%(AdditionalDependencies) - runner_Debug.lib;%(AdditionalDependencies) - - - - - - diff --git a/tests/features/decode_matrix_den_delay_individualg_dense/model.cc b/tests/features/decode_matrix_den_delay_individualg_dense/model.cc deleted file mode 100644 index 50cfe66ad3..0000000000 --- a/tests/features/decode_matrix_den_delay_individualg_dense/model.cc +++ /dev/null @@ -1,57 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_den_delay_individualg_dense/model.cc - -\brief model definition file that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -#include "modelSpec.h" - -//---------------------------------------------------------------------------- -// Neuron -//---------------------------------------------------------------------------- -class Neuron : public NeuronModels::Base -{ -public: - DECLARE_MODEL(Neuron, 0, 1); - - SET_SIM_CODE("$(x)= $(Isyn);\n"); - - SET_VARS({{"x", "scalar"}}); -}; - -IMPLEMENT_MODEL(Neuron); - - -void modelDefinition(ModelSpec &model) -{ -#ifdef CL_HPP_TARGET_OPENCL_VERSION - if(std::getenv("OPENCL_DEVICE") != nullptr) { - GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; - GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); - } - if(std::getenv("OPENCL_PLATFORM") != nullptr) { - GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); - } -#endif - model.setDT(1.0); - model.setName("decode_matrix_den_delay_individualg_dense"); - - // Static synapse parameters - WeightUpdateModels::StaticPulseDendriticDelay::VarValues staticSynapseInit( - 1.0, // 0 - Wij (nA) - uninitialisedVar()); // 1 - Dij (timestep) - - model.addNeuronPopulation("Pre", 10, {}, {}); - model.addNeuronPopulation("Post", 1, {}, Neuron::VarValues(0.0)); - - auto *syn = model.addSynapsePopulation( - "Syn", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, "Pre", "Post", - {}, staticSynapseInit, - {}, {}); - syn->setMaxDendriticDelayTimesteps(10); - - model.setPrecision(GENN_FLOAT); -} diff --git a/tests/features/decode_matrix_den_delay_individualg_dense/runner_guid.txt b/tests/features/decode_matrix_den_delay_individualg_dense/runner_guid.txt deleted file mode 100644 index bb8bd5bb24..0000000000 --- a/tests/features/decode_matrix_den_delay_individualg_dense/runner_guid.txt +++ /dev/null @@ -1 +0,0 @@ -B2A10F28-9FF5-4C26-998D-FFE625A59A91 diff --git a/tests/features/decode_matrix_den_delay_individualg_dense/test.cc b/tests/features/decode_matrix_den_delay_individualg_dense/test.cc deleted file mode 100644 index 47f5777fbf..0000000000 --- a/tests/features/decode_matrix_den_delay_individualg_dense/test.cc +++ /dev/null @@ -1,44 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_den_delay_individualg_dense/test.cc - -\brief Main test code that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. 
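The dendritic-delay decoders deleted here have an equally simple analytic outcome, which the test_spike_propagation_den_delay test added at the end of this patch checks in Python: presynaptic neuron i fires once at timestep i and its synapse onto the single output neuron carries a dendritic delay of (9 - i) timesteps (dSyn[i] = 9 - i in the Init() methods below), so every spike is queued to land in the same timestep and the output should see one summed input of 10.0, at model.timestep == 11 in the Python version, and 0.0 at every other time. A sketch of the timing argument, with the exact delivery offset left out:

    import numpy as np

    spike_timestep = np.arange(10)          # neuron i spikes in timestep i
    dendritic_delay = np.arange(9, -1, -1)  # delay of (9 - i) timesteps
    arrival = spike_timestep + dendritic_delay

    assert np.all(arrival == 9)  # every input is scheduled for the same step
    # with unit weights the output therefore integrates 1.0 * 10 all at once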
-*/ -//-------------------------------------------------------------------------- - - -// Google test includes -#include "gtest/gtest.h" - -// Auto-generated simulation code includess -#include "decode_matrix_den_delay_individualg_dense_CODE/definitions.h" - -// **NOTE** base-class for simulation tests must be -// included after auto-generated globals are includes -#include "../../utils/simulation_test_den_delay_decoder_matrix.h" - -//---------------------------------------------------------------------------- -// SimTest -//---------------------------------------------------------------------------- -class SimTest : public SimulationTestDecoderDenDelayMatrix -{ -public: - //---------------------------------------------------------------------------- - // SimulationTest virtuals - //---------------------------------------------------------------------------- - virtual void Init() - { - // Loop through presynaptic neurons - for(unsigned int i = 0; i < 10; i++) - { - // Connect row to output neuron with weight of one and dendritic delay of (9 - i) - dSyn[i] = (uint8_t)(9 - i); - } - } -}; - -TEST_F(SimTest, DecodeMatrixDenDelayIndividualgDense) -{ - // Check total error is less than some tolerance - EXPECT_TRUE(Simulate()); -} diff --git a/tests/features/decode_matrix_den_delay_individualg_ragged/Makefile b/tests/features/decode_matrix_den_delay_individualg_ragged/Makefile deleted file mode 120000 index 1302b13ca5..0000000000 --- a/tests/features/decode_matrix_den_delay_individualg_ragged/Makefile +++ /dev/null @@ -1 +0,0 @@ -../../utils/Makefile \ No newline at end of file diff --git a/tests/features/decode_matrix_den_delay_individualg_ragged/decode_matrix_den_delay_individualg_ragged.sln b/tests/features/decode_matrix_den_delay_individualg_ragged/decode_matrix_den_delay_individualg_ragged.sln deleted file mode 100644 index 7c22db8b65..0000000000 --- a/tests/features/decode_matrix_den_delay_individualg_ragged/decode_matrix_den_delay_individualg_ragged.sln +++ /dev/null @@ -1,30 +0,0 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 2013 -VisualStudioVersion = 12.0.30501.0 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "decode_matrix_den_delay_individualg_ragged", "decode_matrix_den_delay_individualg_ragged.vcxproj", "{D2CAF075-C02B-4BA6-9735-8541B3A31193}" - ProjectSection(ProjectDependencies) = postProject - {919516EA-6149-4382-AF87-7919AEB07C97} = {919516EA-6149-4382-AF87-7919AEB07C97} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runner", "decode_matrix_den_delay_individualg_ragged_CODE\runner.vcxproj", "{919516EA-6149-4382-AF87-7919AEB07C97}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {D2CAF075-C02B-4BA6-9735-8541B3A31193}.Debug|x64.ActiveCfg = Debug|x64 - {D2CAF075-C02B-4BA6-9735-8541B3A31193}.Debug|x64.Build.0 = Debug|x64 - {D2CAF075-C02B-4BA6-9735-8541B3A31193}.Release|x64.ActiveCfg = Release|x64 - {D2CAF075-C02B-4BA6-9735-8541B3A31193}.Release|x64.Build.0 = Release|x64 - {919516EA-6149-4382-AF87-7919AEB07C97}.Debug|x64.ActiveCfg = Debug|x64 - {919516EA-6149-4382-AF87-7919AEB07C97}.Debug|x64.Build.0 = Debug|x64 - {919516EA-6149-4382-AF87-7919AEB07C97}.Release|x64.ActiveCfg = Release|x64 - {919516EA-6149-4382-AF87-7919AEB07C97}.Release|x64.Build.0 = Release|x64 - EndGlobalSection - 
GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection -EndGlobal diff --git a/tests/features/decode_matrix_den_delay_individualg_ragged/decode_matrix_den_delay_individualg_ragged.vcxproj b/tests/features/decode_matrix_den_delay_individualg_ragged/decode_matrix_den_delay_individualg_ragged.vcxproj deleted file mode 100644 index 1791cd4933..0000000000 --- a/tests/features/decode_matrix_den_delay_individualg_ragged/decode_matrix_den_delay_individualg_ragged.vcxproj +++ /dev/null @@ -1,63 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {D2CAF075-C02B-4BA6-9735-8541B3A31193} - - - - - - - - - Application - true - $(DefaultPlatformToolset) - true - MultiByte - - - - - - - - - - ./ - $(Platform)\$(Configuration)\ - test - - - - Level3 - MaxSpeed - Disabled - true - true - true - decode_matrix_den_delay_individualg_ragged_CODE;$(GTEST_DIR);$(GTEST_DIR)/include - _SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;%(PreprocessorDefinitions) - - - true - true - true - runner_Release.lib;%(AdditionalDependencies) - runner_Debug.lib;%(AdditionalDependencies) - - - - - - diff --git a/tests/features/decode_matrix_den_delay_individualg_ragged/model.cc b/tests/features/decode_matrix_den_delay_individualg_ragged/model.cc deleted file mode 100644 index eabbc20067..0000000000 --- a/tests/features/decode_matrix_den_delay_individualg_ragged/model.cc +++ /dev/null @@ -1,59 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_den_delay_individualg_ragged/model.cc - -\brief model definition file that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -#include "modelSpec.h" - -//---------------------------------------------------------------------------- -// Neuron -//---------------------------------------------------------------------------- -class Neuron : public NeuronModels::Base -{ -public: - DECLARE_MODEL(Neuron, 0, 1); - - SET_SIM_CODE("$(x)= $(Isyn);\n"); - - SET_VARS({{"x", "scalar"}}); -}; - -IMPLEMENT_MODEL(Neuron); - - -void modelDefinition(ModelSpec &model) -{ -#ifdef CL_HPP_TARGET_OPENCL_VERSION - if(std::getenv("OPENCL_DEVICE") != nullptr) { - GENN_PREFERENCES.deviceSelectMethod = DeviceSelect::MANUAL; - GENN_PREFERENCES.manualDeviceID = std::atoi(std::getenv("OPENCL_DEVICE")); - } - if(std::getenv("OPENCL_PLATFORM") != nullptr) { - GENN_PREFERENCES.manualPlatformID = std::atoi(std::getenv("OPENCL_PLATFORM")); - } -#endif - model.setDT(1.0); - model.setName("decode_matrix_den_delay_individualg_ragged"); - - // Static synapse parameters - WeightUpdateModels::StaticPulseDendriticDelay::VarValues staticSynapseInit( - 1.0, // 0 - Wij (nA) - uninitialisedVar()); // 1 - Dij (timestep) - - model.addNeuronPopulation("Pre", 10, {}, {}); - model.addNeuronPopulation("Post", 1, {}, Neuron::VarValues(0.0)); - - - auto *syn = model.addSynapsePopulation( - "Syn", SynapseMatrixType::SPARSE_INDIVIDUALG, NO_DELAY, "Pre", "Post", - {}, staticSynapseInit, - {}, {}); - syn->setMaxDendriticDelayTimesteps(10); - syn->setMaxConnections(1); - - model.setPrecision(GENN_FLOAT); -} diff --git a/tests/features/decode_matrix_den_delay_individualg_ragged/runner_guid.txt b/tests/features/decode_matrix_den_delay_individualg_ragged/runner_guid.txt deleted file mode 100644 index 038430a75d..0000000000 --- 
a/tests/features/decode_matrix_den_delay_individualg_ragged/runner_guid.txt +++ /dev/null @@ -1 +0,0 @@ -919516EA-6149-4382-AF87-7919AEB07C97 diff --git a/tests/features/decode_matrix_den_delay_individualg_ragged/test.cc b/tests/features/decode_matrix_den_delay_individualg_ragged/test.cc deleted file mode 100644 index 35ff233346..0000000000 --- a/tests/features/decode_matrix_den_delay_individualg_ragged/test.cc +++ /dev/null @@ -1,48 +0,0 @@ -//-------------------------------------------------------------------------- -/*! \file decode_matrix_den_delay_individualg_ragged/test.cc - -\brief Main test code that is part of the feature testing -suite of minimal models with known analytic outcomes that are used for continuous integration testing. -*/ -//-------------------------------------------------------------------------- - - -// Google test includes -#include "gtest/gtest.h" - -// Auto-generated simulation code includess -#include "decode_matrix_den_delay_individualg_ragged_CODE/definitions.h" - -// **NOTE** base-class for simulation tests must be -// included after auto-generated globals are includes -#include "../../utils/simulation_test_den_delay_decoder_matrix.h" - -//---------------------------------------------------------------------------- -// SimTest -//---------------------------------------------------------------------------- -class SimTest : public SimulationTestDecoderDenDelayMatrix -{ -public: - //---------------------------------------------------------------------------- - // SimulationTest virtuals - //---------------------------------------------------------------------------- - virtual void Init() - { - // Loop through presynaptic neurons - for(unsigned int i = 0; i < 10; i++) - { - // Set rowlength to 1 - rowLengthSyn[i] = 1; - - // Connect row to output neuron with weight of one and dendritic delay of (9 - i) - indSyn[i] = 0; - dSyn[i] = (uint8_t)(9 - i); - } - } -}; - -TEST_F(SimTest, DecodeMatrixDenDelayIndividualgRagged) -{ - // Check total error is less than some tolerance - EXPECT_TRUE(Simulate()); -} diff --git a/tests/features/test_spike_propagation.py b/tests/features/test_spike_propagation.py index 20ea2bacbd..4b7bc7a440 100644 --- a/tests/features/test_spike_propagation.py +++ b/tests/features/test_spike_propagation.py @@ -122,7 +122,7 @@ def test_spike_propagation(backend, precision): "DeltaCurr", {}, {}, init_sparse_connectivity(decoder_model, {})) - # Create one output neuron pop with bitmask decoder population + # Create one output neuron pop with dense decoder population dense_n_pop = model.add_neuron_population( "PostDenseNeuron", 4, post_neuron_model, {}, {"x": 0.0}) @@ -132,7 +132,7 @@ def test_spike_propagation(backend, precision): "StaticPulse", {}, {"g": init_var(decoder_dense_model, {})}, {}, {}, "DeltaCurr", {}, {}) - # Create one output neuron pop with bitmask decoder population + # Create one output neuron pop with dense decoder population manual_dense_n_pop = model.add_neuron_population( "ManualPostDenseNeuron", 4, post_neuron_model, {}, {"x": 0.0}) @@ -168,5 +168,59 @@ def test_spike_propagation(backend, precision): if output_value != (model.timestep - 1): assert False, f"{pop.name} decoding incorrect ({output_value} rather than {model.timestep - 1})" +@pytest.mark.parametrize("backend", ["single_threaded_cpu", "cuda"]) +@pytest.mark.parametrize("precision", [types.Double, types.Float]) +def test_spike_propagation_den_delay(backend, precision): + model = GeNNModel(precision, "test_spike_propagation_den_delay", backend=backend) + model.dt = 
1.0 + + # Create spike source array to generate one-hot pattern to decode + ss_pop = model.add_neuron_population("SpikeSource", 10, "SpikeSourceArray", + {}, {"startSpike": np.arange(10), "endSpike": np.arange(1, 11)}) + ss_pop.extra_global_params["spikeTimes"].set_values(np.arange(10.0)) + + # Create one output neuron pop with dense decoder population + delay = np.arange(9, -1, -1) + dense_n_pop = model.add_neuron_population( + "PostDenseNeuron", 1, post_neuron_model, + {}, {"x": 0.0}) + dense_s_pop = model.add_synapse_population( + "PostDenseSynapse", "DENSE", 0, + ss_pop, dense_n_pop, + "StaticPulseDendriticDelay", {}, {"g": 1.0, "d": delay}, {}, {}, + "DeltaCurr", {}, {}) + dense_s_pop.max_dendritic_delay_timesteps = 10 + + # Create one output neuron pop with sparse decoder population + sparse_n_pop = model.add_neuron_population( + "PostSparseNeuron", 1, post_neuron_model, + {}, {"x": 0.0}) + sparse_s_pop = model.add_synapse_population( + "PostSparseSynapse", "SPARSE", 0, + ss_pop, sparse_n_pop, + "StaticPulseDendriticDelay", {}, {"g": 1.0, "d": delay}, {}, {}, + "DeltaCurr", {}, {}) + sparse_s_pop.max_dendritic_delay_timesteps = 10 + sparse_s_pop.set_sparse_connections(np.arange(10), np.zeros(10, dtype=int)) + + # Build model and load + model.build() + model.load() + + # Simulate for 11 timesteps + output_populations = [dense_n_pop, sparse_n_pop] + while model.timestep < 11: + model.step_time() + + # Loop through output populations + correct = 10.0 if model.timestep == 11 else 0.0 + for pop in output_populations: + # Pull state variable + pop.pull_var_from_device("x") + + # If not close to correct value, error + if not np.isclose(pop.vars["x"].view[0], correct): + assert False, f"{pop.name} decoding incorrect ({pop.vars['x'].view[0]} rather than {correct})" + if __name__ == '__main__': - test_spike_propagation("single_threaded_cpu", types.Float) + test_spike_propagation_den_delay("single_threaded_cpu", types.Float) From 58e8dd2c2fd838e61f76c71671daf6e1c6761b02 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 26 Jul 2023 15:03:42 +0100 Subject: [PATCH 463/725] pretty printing logic was preventing e.g. 
CPU backend from substituting empty string for push functions --- src/genn/genn/transpiler/prettyPrinter.cc | 42 +++++++++-------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/src/genn/genn/transpiler/prettyPrinter.cc b/src/genn/genn/transpiler/prettyPrinter.cc index 2bad4b250d..8ac0367d7a 100644 --- a/src/genn/genn/transpiler/prettyPrinter.cc +++ b/src/genn/genn/transpiler/prettyPrinter.cc @@ -291,33 +291,25 @@ class Visitor : public Expression::Visitor, public Statement::Visitor } while(found != std::string::npos); } - // If all arguments haven't been substituted - if (i != m_CallArguments.top().size()) { - // If function is variadic - if (type.getFunction().variadic) { - // If variadic placeholder is found - const std::string variadicPlaceholder = "$(@)"; - const size_t found = name.find(variadicPlaceholder); - if (found != std::string::npos) { - // Concatenate together all remaining arguments - std::ostringstream variadicArgumentsStream; - std::copy(m_CallArguments.top().cbegin() + i, m_CallArguments.top().cend(), - std::ostream_iterator(variadicArgumentsStream, ", ")); - - // Replace variadic placeholder with all remaining arguments (after trimming trailing ", ") - std::string variadicArguments = variadicArgumentsStream.str(); - name.replace(found, variadicPlaceholder.length(), - variadicArguments.substr(0, variadicArguments.length() - 2)); - } - else { - throw std::runtime_error("Variadic function template for '" + variable.getName().lexeme + "' (" + name + ") has " - "insufficient placeholders for " + std::to_string(m_CallArguments.top().size()) + " argument call and no variadic placeholder '$(@)'"); - } + // If function is variadic + if (type.getFunction().variadic) { + // If variadic placeholder is found + const std::string variadicPlaceholder = "$(@)"; + const size_t found = name.find(variadicPlaceholder); + if (found != std::string::npos) { + // Concatenate together all remaining arguments + std::ostringstream variadicArgumentsStream; + std::copy(m_CallArguments.top().cbegin() + i, m_CallArguments.top().cend(), + std::ostream_iterator(variadicArgumentsStream, ", ")); + + // Replace variadic placeholder with all remaining arguments (after trimming trailing ", ") + std::string variadicArguments = variadicArgumentsStream.str(); + name.replace(found, variadicPlaceholder.length(), + variadicArguments.substr(0, variadicArguments.length() - 2)); } - // Otherwise, give error else { - throw std::runtime_error("Function template for '" + variable.getName().lexeme + "' (" + name + ") has " - "insufficient placeholders for " + std::to_string(m_CallArguments.top().size()) + " argument call"); + throw std::runtime_error("Variadic function template for '" + variable.getName().lexeme + "' (" + name + ") has " + "insufficient placeholders for " + std::to_string(m_CallArguments.top().size()) + " argument call and no variadic placeholder '$(@)'"); } } } From 89bd4566734dbb51dba14d42882981bd92f58b14 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 26 Jul 2023 15:12:44 +0100 Subject: [PATCH 464/725] fixed extra parenthesis in column-wise connectivity building --- src/genn/genn/code_generator/backendSIMT.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 7a95ab230f..9ce305eb82 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -1586,7 +1586,7 @@ void 
BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer kernelInit << "const unsigned int idx = ($(id_pre) * $(_row_stride)) + $(_row_length)[$(id)];" << std::endl; } else { - kernelInit << "const unsigned int idx = (($(0)) * $(_row_stride))) + $(_row_length)[$(0)];" << std::endl; + kernelInit << "const unsigned int idx = (($(0)) * $(_row_stride)) + $(_row_length)[$(0)];" << std::endl; } } From 6abaeb9f27e94be32809cee37479e9743e35f8df Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 26 Jul 2023 15:47:28 +0100 Subject: [PATCH 465/725] RNG fixes for custom update * In CUDA backend, generate custom connectivity update in environment with RNG functions * In SIMT backend, point _rng to population RNG --- src/genn/backends/cuda/backend.cc | 5 ++++- src/genn/genn/code_generator/backendSIMT.cc | 17 +++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 21eeda887d..df40918d63 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -683,7 +683,10 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Custom connectivity updates" << std::endl; - genCustomConnectivityUpdateKernel(funcEnv, modelMerged, memorySpaces, g, idCustomUpdateStart); + + // Add RNG functions to environment and generate kernel + EnvironmentLibrary rngEnv(funcEnv, getRNGFunctions(model.getPrecision())); + genCustomConnectivityUpdateKernel(rngEnv, modelMerged, memorySpaces, g, idCustomUpdateStart); } } diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 9ce305eb82..0a1c8f5972 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -1335,18 +1335,19 @@ void BackendSIMT::genCustomConnectivityUpdateKernel(EnvironmentExternalBase &env // Configure substitutions groupEnv.add(Type::Uint32.addConst(), "id_pre", "$(id)"); - // Copy global RNG stream to local and use pointer to this for rng - const std::string rng = printSubs("$(_rng)[$(id)]", groupEnv); - if(Utils::isRNGRequired(cg.getArchetype().getRowUpdateCodeTokens())) { - groupEnv.add(Type::Void, "rng", genPopulationRNGPreamble(groupEnv.getStream(), rng)); - } + // Add population RNG field + groupEnv.addField(getPopulationRNGType().createPointer(), "_rng", "rng", + [this](const auto &g, size_t) { return getDeviceVarPrefix() + "rng" + g.getName(); }, + "$(id)"); + // **TODO** for OCL do genPopulationRNGPreamble(os, popSubs, "$(id)") in initialiser + cg.generateUpdate(*this, groupEnv, modelMerged.getModel().getBatchSize()); // Copy local stream back to local - if(Utils::isRNGRequired(cg.getArchetype().getRowUpdateCodeTokens())) { + /*if(Utils::isRNGRequired(cg.getArchetype().getRowUpdateCodeTokens())) { genPopulationRNGPostamble(groupEnv.getStream(), rng); - } + }*/ } }); } @@ -1744,7 +1745,7 @@ void BackendSIMT::genInitializeSparseKernel(EnvironmentExternalBase &env, ModelS groupEnv.add(Type::Void, "_rng", genGlobalRNGSkipAhead(groupEnv.getStream(), std::to_string(numInitializeThreads) + " + id")); } - + // Generate sparse synapse variable initialisation code genSparseSynapseVarInit( groupEnv, batchSize, cg, true, From a4b5d465d35e4a727527228514175baa0aa78e41 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Wed, 26 Jul 2023 09:30:08 +0100 
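The chi-squared connectivity tests this commit introduces go a step further than counting synapses per row or column: they ask whether the sampled targets are spread uniformly. scipy.stats.chisquare with its default equal expected frequencies is a goodness-of-fit test against exactly that uniform null, so the added assertions amount to the following sketch, with synthetic indices standing in for the connectivity pulled from the device:

    import numpy as np
    from scipy import stats

    rng = np.random.default_rng(1234)  # arbitrary seed, for illustration only
    # stand-in for get_sparse_post_inds(): 100 rows, 10 uniform targets each
    post_inds = rng.integers(0, 100, size=100 * 10)

    counts = np.bincount(post_inds, minlength=100)
    p = stats.chisquare(counts).pvalue
    # the test treats p > 0.05 as consistent with uniform targeting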
Subject: [PATCH 466/725] add some chi-squared tests to determine how random connectivity is --- tests/features/test_connect_init.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/features/test_connect_init.py b/tests/features/test_connect_init.py index 088d16f9d5..917aa160c7 100644 --- a/tests/features/test_connect_init.py +++ b/tests/features/test_connect_init.py @@ -1,6 +1,7 @@ import numpy as np import pytest from pygenn import types +from scipy import stats from pygenn import GeNNModel @@ -17,12 +18,12 @@ def test_connect_init(backend, precision): post_pop = model.add_neuron_population("Post", 100, "SpikeSource", {}, {}) # Add synapse populations with different types of built-in connectivity - fixed_number_total_s_pop = model.add_synapse_population( - "FixedNumberTotal", "SPARSE", 0, - pre_pop, post_pop, - "StaticPulseConstantWeight", {"g": 1.0}, {}, {}, {}, - "DeltaCurr", {}, {}, - init_sparse_connectivity("FixedNumberTotalWithReplacement", {"total": 1000})) + #fixed_number_total_s_pop = model.add_synapse_population( + # "FixedNumberTotal", "SPARSE", 0, + # pre_pop, post_pop, + # "StaticPulseConstantWeight", {"g": 1.0}, {}, {}, {}, + # "DeltaCurr", {}, {}, + # init_sparse_connectivity("FixedNumberTotalWithReplacement", {"total": 1000})) fixed_number_pre_s_pop = model.add_synapse_population( "FixedNumberPre", "SPARSE", 0, @@ -43,16 +44,19 @@ def test_connect_init(backend, precision): model.load() # Pull connectivity - fixed_number_total_s_pop.pull_connectivity_from_device() + #fixed_number_total_s_pop.pull_connectivity_from_device() fixed_number_pre_s_pop.pull_connectivity_from_device() fixed_number_post_s_pop.pull_connectivity_from_device() # Check connectivity assert np.all(np.bincount(fixed_number_post_s_pop.get_sparse_pre_inds()) == 10) assert np.all(np.bincount(fixed_number_pre_s_pop.get_sparse_post_inds()) == 10) - assert len(fixed_number_total_s_pop.get_sparse_pre_inds()) == 1000 + #assert len(fixed_number_total_s_pop.get_sparse_pre_inds()) == 1000 + + # Check neurons are uniformly distributed within each row/column + assert stats.chisquare(np.bincount(fixed_number_post_s_pop.get_sparse_post_inds(), minlength=100)).pvalue > 0.05 + assert stats.chisquare(np.bincount(fixed_number_pre_s_pop.get_sparse_pre_inds(), minlength=100)).pvalue > 0.05 - # **TODO** we could also build a histogram of postsynaptic neurons and check that they are approximately uniformly distributed if __name__ == '__main__': test_connect_init("single_threaded_cpu", types.Float) \ No newline at end of file From 4bc45133bc4bff10ce2161c30cb61461586a8370 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 27 Jul 2023 12:00:05 +0100 Subject: [PATCH 467/725] WIP fixing of initialisation environment bugs --- .../genn/genn/code_generator/backendBase.h | 7 + .../genn/genn/code_generator/backendSIMT.h | 4 +- .../backends/single_threaded_cpu/backend.cc | 24 ++- src/genn/genn/code_generator/backendBase.cc | 168 ++++++++++++++---- src/genn/genn/code_generator/backendSIMT.cc | 40 +++-- .../genn/code_generator/initGroupMerged.cc | 91 +--------- 6 files changed, 192 insertions(+), 142 deletions(-) diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index ed578e78be..2cc31f5093 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -59,6 +59,8 @@ class CustomUpdateInitGroupMerged; class CustomWUUpdateInitGroupMerged; class 
CustomWUUpdateSparseInitGroupMerged; class SynapseConnectivityInitGroupMerged; +class CustomConnectivityUpdatePreInitGroupMerged; +class CustomConnectivityUpdatePostInitGroupMerged; class SynapseInitGroupMerged; class SynapseSparseInitGroupMerged; } @@ -481,10 +483,15 @@ class GENN_EXPORT BackendBase void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; + void buildStandardEnvironment(EnvironmentGroupMergedField &env) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; void buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const; diff --git a/include/genn/genn/code_generator/backendSIMT.h b/include/genn/genn/code_generator/backendSIMT.h index 5a766133fe..571e5ed8d4 100644 --- a/include/genn/genn/code_generator/backendSIMT.h +++ b/include/genn/genn/code_generator/backendSIMT.h @@ -340,7 +340,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase void genSynapseVarInit(EnvironmentExternalBase &env, unsigned int batchSize, G &g, bool initRNGRequired, bool kernel, size_t kernelDimensions) const { - env.getStream() << "if(" << env["id"] << " < "; + env.print("if($(id) < "); // If synapse group has kernel weights, check ID against product of kernel dimensions if (kernel) { @@ -356,7 +356,7 @@ class GENN_EXPORT BackendSIMT : public BackendBase } // Otherwise, against number of postsynaptic neurons else { - env.getStream() << env["num_post"]; + env.print("$(num_post)"); } env.getStream() << ")"; { diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 9978dc0233..19d47344c5 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -644,6 +644,7 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back // Create matching environment EnvironmentGroupMergedField groupEnv(funcEnv, c); + buildStandardEnvironment(groupEnv); // **TODO** add fields const SynapseGroupInternal *sg = c.getArchetype().getSynapseGroup(); @@ -660,16 +661,16 @@ void Backend::genCustomUpdate(CodeStream &os, ModelSpecMerged &modelMerged, Back } else { // Loop through presynaptic neurons - groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["num_pre"] << "; i++)"; + groupEnv.print("for(unsigned int i = 0; i < $(num_pre); i++)"); { // If this synapse group has sparse connectivity, loop through length of this row CodeStream::Scope b(groupEnv.getStream()); if (sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) { - groupEnv.getStream() << "for(unsigned int s = 0; s < " << groupEnv["_row_length"] << "[i]; s++)"; + groupEnv.print("for(unsigned int s = 0; s < $(_row_length)[i]; s++)"); } // Otherwise, if it's dense, loop through each 
postsynaptic neuron else if (sg->getMatrixType() & SynapseMatrixConnectivity::DENSE) { - groupEnv.getStream() << "for (unsigned int j = 0; j < " << groupEnv["size"] << "; j++)"; + groupEnv.print("for (unsigned int j = 0; j < $(num_post); j++)"); } else { throw std::runtime_error("Only DENSE and SPARSE format connectivity can be used for custom updates"); @@ -904,7 +905,10 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedCustomUpdateInitGroup" << c.getIndex() << "[g]; " << std::endl; - c.generateInit(*this, funcEnv, 1); + + EnvironmentGroupMergedField groupEnv(funcEnv, c); + buildStandardEnvironment(groupEnv); + c.generateInit(*this, groupEnv, 1); } }); @@ -922,7 +926,10 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedCustomConnectivityUpdatePreInitGroup" << c.getIndex() << "[g]; " << std::endl; - c.generateInit(*this, funcEnv, 1); + + EnvironmentGroupMergedField groupEnv(funcEnv, c); + buildStandardEnvironment(groupEnv); + c.generateInit(*this, groupEnv, 1); } }); @@ -940,6 +947,8 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedCustomConnectivityUpdatePostInitGroup" << c.getIndex() << "[g]; " << std::endl; + EnvironmentGroupMergedField groupEnv(funcEnv, c); + buildStandardEnvironment(groupEnv); c.generateInit(*this, funcEnv, 1); } }); @@ -958,7 +967,10 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedCustomWUUpdateInitGroup" << c.getIndex() << "[g]; " << std::endl; - c.generateInit(*this, funcEnv, 1); + + EnvironmentGroupMergedField groupEnv(funcEnv, c); + buildStandardEnvironment(groupEnv); + c.generateInit(*this, groupEnv, 1); } }); diff --git a/src/genn/genn/code_generator/backendBase.cc b/src/genn/genn/code_generator/backendBase.cc index 689deae254..7273c4d018 100644 --- a/src/genn/genn/code_generator/backendBase.cc +++ b/src/genn/genn/code_generator/backendBase.cc @@ -5,6 +5,7 @@ #include "logging.h" // GeNN code generator includes +#include "code_generator/codeGenUtils.h" #include "code_generator/groupMerged.h" #include "code_generator/customConnectivityUpdateGroupMerged.h" #include "code_generator/customUpdateGroupMerged.h" @@ -279,7 +280,101 @@ void buildStandardSynapseEnvironment(const BackendBase &backend, EnvironmentGrou } } } +//-------------------------------------------------------------------------- +template +void buildStandardCustomUpdateEnvironment(const BackendBase &backend, EnvironmentGroupMergedField &env) +{ + // Add size field + env.addField(Type::Uint32, "size", "size", + [](const auto &c, size_t) { return std::to_string(c.getSize()); }); + + // If batching is enabled, calculate batch offset + if(env.getGroup().getArchetype().isBatched()) { + env.add(Type::Uint32.addConst(), "_batch_offset", "batchOffset", + {env.addInitialiser("const unsigned int batchOffset = $(size) * batch;")}); + } + + // If axonal delays are required + if(env.getGroup().getArchetype().getDelayNeuronGroup() != nullptr) { + // Add spike queue pointer field + env.addField(Type::Uint32.createPointer(), "_spk_que_ptr", "spkQuePtr", + [&backend](const auto &cg, size_t) + { + return backend.getScalarAddressPrefix() + "spkQuePtr" + 
cg.getDelayNeuronGroup()->getName();
+                     });
+
+        // We should read from delay slot pointed to by spkQuePtr
+        env.add(Type::Uint32.addConst(), "_delay_slot", "delaySlot",
+                {env.addInitialiser("const unsigned int delaySlot = * $(_spk_que_ptr);")});
+        env.add(Type::Uint32.addConst(), "_delay_offset", "delayOffset",
+                {env.addInitialiser("const unsigned int delayOffset = $(_delay_slot) * $(size);")});
+
+        // If batching is also enabled, calculate offset including delay and batch
+        if(env.getGroup().getArchetype().isBatched()) {
+            const std::string numDelaySlotsStr = std::to_string(env.getGroup().getArchetype().getDelayNeuronGroup()->getNumDelaySlots());
+            env.add(Type::Uint32.addConst(), "_batch_delay_slot", "batchDelaySlot",
+                    {env.addInitialiser("const unsigned int batchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_delay_slot);")});
+
+            // Calculate current batch offset
+            env.add(Type::Uint32.addConst(), "_batch_delay_offset", "batchDelayOffset",
+                    {env.addInitialiser("const unsigned int batchDelayOffset = $(_batch_offset) * " + numDelaySlotsStr + ";")});
+        }
+    }
+}
+//--------------------------------------------------------------------------
+template<typename G>
+void buildStandardCustomUpdateWUEnvironment(const BackendBase &backend, EnvironmentGroupMergedField<G> &env)
+{
+    // If synapse group has kernel
+    const auto &kernelSize = env.getGroup().getArchetype().getKernelSize();
+    if(!kernelSize.empty()) {
+        if(env.getGroup().getArchetype().isBatched()) {
+            // Loop through kernel dimensions and multiply together
+            std::ostringstream kernBatchOffsetInit;
+            kernBatchOffsetInit << "const unsigned int kernBatchOffset = ";
+            for(size_t i = 0; i < kernelSize.size(); i++) {
+                kernBatchOffsetInit << getKernelSize(env.getGroup(), i) << " * ";
+            }
+
+            // And finally by batch
+            kernBatchOffsetInit << "$(batch);" << std::endl;
+
+            env.add(Type::Uint32.addConst(), "_kern_batch_offset", "kernBatchOffset",
+                    {env.addInitialiser(kernBatchOffsetInit.str())});
+        }
+    }
+
+    // Synapse group fields
+    env.addField(Type::Uint32.addConst(), "num_pre",
+                 Type::Uint32, "numSrcNeurons",
+                 [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); });
+    env.addField(Type::Uint32.addConst(), "num_post",
+                 Type::Uint32, "numTrgNeurons",
+                 [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); });
+    env.addField(Type::Uint32, "_row_stride", "rowStride",
+                 [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); });
+
+    // Connectivity fields
+    auto *sg = env.getGroup().getArchetype().getSynapseGroup();
+    if(sg->getMatrixType() & SynapseMatrixConnectivity::SPARSE) {
+        env.addField(Type::Uint32.createPointer(), "_row_length", "rowLength",
+                     [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "rowLength" + cg.getSynapseGroup()->getName(); });
+        env.addField(sg->getSparseIndType().createPointer(), "_ind", "ind",
+                     [&backend](const auto &cg, size_t) { return backend.getDeviceVarPrefix() + "ind" + cg.getSynapseGroup()->getName(); });
+    }
+    else if(sg->getMatrixType() & SynapseMatrixWeight::KERNEL) {
+        // Loop through kernel size dimensions
+        for (size_t d = 0; d < sg->getKernelSize().size(); d++) {
+            // If this dimension has a heterogeneous size, add it to struct
+            if (isKernelSizeHeterogeneous(env.getGroup(), d)) {
+                env.addField(Type::Uint32.addConst(), "_kernel_size_" + std::to_string(d), "kernelSize" + std::to_string(d),
+                             [d](const auto &g, 
size_t) { return std::to_string(g.getSynapseGroup()->getKernelSize().at(d)); }); + } + } + } } +} // Anonymous namespace + //-------------------------------------------------------------------------- // GeNN::CodeGenerator::BackendBase //-------------------------------------------------------------------------- @@ -340,42 +435,12 @@ void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const { - // Add size field - env.addField(Type::Uint32, "size", "size", - [](const auto &c, size_t) { return std::to_string(c.getSize()); }); - - // If batching is enabled, calculate batch offset - if(env.getGroup().getArchetype().isBatched()) { - env.add(Type::Uint32.addConst(), "_batch_offset", "batchOffset", - {env.addInitialiser("const unsigned int batchOffset = $(size) * batch;")}); - } - - // If axonal delays are required - if(env.getGroup().getArchetype().getDelayNeuronGroup() != nullptr) { - // Add spike queue pointer field - env.addField(Type::Uint32.createPointer(), "_spk_que_ptr", "spkQuePtr", - [this](const auto &cg, size_t) - { - return getScalarAddressPrefix() + "spkQuePtr" + cg.getDelayNeuronGroup()->getName(); - }); - - // We should read from delay slot pointed to be spkQuePtr - env.add(Type::Uint32.addConst(), "_delay_slot", "delaySlot", - {env.addInitialiser("const unsigned int delaySlot = * $(_spk_que_ptr);")}); - env.add(Type::Uint32.addConst(), "_delay_offset", "delayOffset", - {env.addInitialiser("const unsigned int delayOffset = $(_delay_slot) * $(size);")}); - - // If batching is also enabled, calculate offset including delay and batch - if(env.getGroup().getArchetype().isBatched()) { - const std::string numDelaySlotsStr = std::to_string(env.getGroup().getArchetype().getDelayNeuronGroup()->getNumDelaySlots()); - env.add(Type::Uint32.addConst(), "_batch_delay_slot", "batchDelaySlot", - {env.addInitialiser("const unsigned int batchDelaySlot = (batch * " + numDelaySlotsStr + ") + $(_delay_slot);")}); - - // Calculate current batch offset - env.add(Type::Uint32.addConst(), "_batch_delay_offset", "batchDelayOffset", - {env.addInitialiser("const unsigned int batchDelayOffset = $(_batch_offset) * " + numDelaySlotsStr + ";")}); - } - } + buildStandardCustomUpdateEnvironment(*this, env); +} +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const +{ + buildStandardCustomUpdateWUEnvironment(*this, env); } //----------------------------------------------------------------------- void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const @@ -419,6 +484,37 @@ void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const +{ + buildStandardCustomUpdateEnvironment(*this, env); +} +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const +{ + buildStandardCustomUpdateWUEnvironment(*this, env); +} +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const +{ + env.addField(Type::Uint32.addConst(), "size", + Type::Uint32, "size", + [](const auto &c, size_t) + { + return std::to_string(c.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); + }); + +} +//----------------------------------------------------------------------- +void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env) const +{ + 
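+    // For postsynaptic custom connectivity update initialisation, "size" is
+    // the number of neurons in the target population of the underlying synapse group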
env.addField(Type::Uint32.addConst(), "size", + Type::Uint32, "size", + [](const auto &c, size_t) + { + return std::to_string(c.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); + }); +} +//----------------------------------------------------------------------- void BackendBase::buildStandardEnvironment(EnvironmentGroupMergedField &env, unsigned int batchSize) const { buildStandardSynapseEnvironment(*this, env, batchSize); diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 0a1c8f5972..56c5c45dd0 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -1416,7 +1416,9 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer [this](const SynapseGroupInternal &sg) { return padKernelSize(getNumInitThreads(sg), KernelInitialize); }, [batchSize, this](EnvironmentExternalBase &env, SynapseInitGroupMerged &sg) { - genSynapseVarInit(env, batchSize, sg, sg.getArchetype().isWUInitRNGRequired(), + EnvironmentGroupMergedField groupEnv(env, sg); + buildStandardEnvironment(groupEnv, 1); + genSynapseVarInit(groupEnv, batchSize, sg, sg.getArchetype().isWUInitRNGRequired(), (sg.getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL), sg.getArchetype().getKernelSize().size()); }); @@ -1429,11 +1431,13 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer [this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelInitialize); }, [batchSize, this](EnvironmentExternalBase &env, CustomUpdateInitGroupMerged &cg) { - env.getStream() << "// only do this for existing variables" << std::endl; - env.print("if($(id) < $(size))"); + EnvironmentGroupMergedField groupEnv(env, cg); + buildStandardEnvironment(groupEnv); + + groupEnv.getStream() << "// only do this for existing variables" << std::endl; + groupEnv.print("if($(id) < $(size))"); { - CodeStream::Scope b(env.getStream()); - EnvironmentGroupMergedField groupEnv(env, cg); + CodeStream::Scope b(groupEnv.getStream()); // If this custom update requires an RNG for initialisation, // make copy of global phillox RNG and skip ahead by thread id @@ -1454,8 +1458,10 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer [this](const CustomUpdateWUInternal &cg) { return padKernelSize(getNumInitThreads(cg), KernelInitialize); }, [batchSize, this](EnvironmentExternalBase &env, CustomWUUpdateInitGroupMerged &cg) { + EnvironmentGroupMergedField groupEnv(env, cg); + buildStandardEnvironment(groupEnv); const SynapseGroup *sg = cg.getArchetype().getSynapseGroup(); - genSynapseVarInit(env, batchSize, cg, cg.getArchetype().isInitRNGRequired(), + genSynapseVarInit(groupEnv, batchSize, cg, cg.getArchetype().isInitRNGRequired(), (sg->getMatrixType() & SynapseMatrixWeight::KERNEL), sg->getKernelSize().size()); }); env.getStream() << std::endl; @@ -1467,11 +1473,14 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons(), KernelInitialize); }, [batchSize, this](EnvironmentExternalBase &env, CustomConnectivityUpdatePreInitGroupMerged &cg) { - env.getStream() << "// only do this for existing variables" << std::endl; - env.print("if($(id) < $(size))"); + // Create environment + EnvironmentGroupMergedField groupEnv(env, cg); + buildStandardEnvironment(groupEnv); + + groupEnv.getStream() << "// only do this 
for existing variables" << std::endl; + groupEnv.print("if($(id) < $(size))"); { - CodeStream::Scope b(env.getStream()); - EnvironmentGroupMergedField groupEnv(env, cg); + CodeStream::Scope b(groupEnv.getStream()); // If population RNGs are initialised on device and this custom connectivity update // required one, initialise single RNG using GLOBAL thread id for sequence @@ -1504,11 +1513,14 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer [this](const CustomConnectivityUpdateInternal &cg) { return padKernelSize(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons(), KernelInitialize); }, [batchSize, this](EnvironmentExternalBase &env, CustomConnectivityUpdatePostInitGroupMerged &cg) { - env.getStream() << "// only do this for existing variables" << std::endl; - env.print("if($(id) < $(size))"); + // Create environment + EnvironmentGroupMergedField groupEnv(env, cg); + buildStandardEnvironment(groupEnv); + + groupEnv.getStream() << "// only do this for existing variables" << std::endl; + groupEnv.print("if($(id) < $(size))"); { - CodeStream::Scope b(env.getStream()); - EnvironmentGroupMergedField groupEnv(env, cg); + CodeStream::Scope b(groupEnv.getStream()); // If population RNGs are initialised on device and this custom connectivity update // required one, initialise single RNG using GLOBAL thread id for sequence diff --git a/src/genn/genn/code_generator/initGroupMerged.cc b/src/genn/genn/code_generator/initGroupMerged.cc index 5f3cca3600..ec7fadea00 100644 --- a/src/genn/genn/code_generator/initGroupMerged.cc +++ b/src/genn/genn/code_generator/initGroupMerged.cc @@ -530,27 +530,8 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, Environmen // Create environment for group EnvironmentGroupMergedField groupEnv(env, *this); - // If model is batched and has kernel weights - const bool kernel = (getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL); - if (kernel && batchSize > 1) { - // Loop through kernel dimensions and multiply together to calculate batch stride - std::ostringstream batchStrideInit; - batchStrideInit << "const unsigned int batchStride = "; - const auto &kernelSize = getArchetype().getKernelSize(); - for (size_t i = 0; i < kernelSize.size(); i++) { - batchStrideInit << getKernelSize(*this, i); - - if (i != (kernelSize.size() - 1)) { - batchStrideInit << " * "; - } - } - batchStrideInit << ";" << std::endl; - groupEnv.add(Type::Uint32.addConst(), "_batch_stride", "batchStride", - {groupEnv.addInitialiser(batchStrideInit.str())}); - } - - // If we're using non-kernel weights, generate loop over source neurons + const bool kernel = (getArchetype().getMatrixType() & SynapseMatrixWeight::KERNEL); if (!kernel) { groupEnv.print("for(unsigned int i = 0; i < $(num_pre); i++)"); groupEnv.getStream() << CodeStream::OB(1); @@ -558,7 +539,7 @@ void SynapseInitGroupMerged::generateInit(const BackendBase &backend, Environmen } // Generate initialisation code - const std::string stride = kernel ? "$(_batch_stride)" : "$(num_pre) * $(_row_stride)"; + const std::string stride = kernel ? 
"$(_kern_batch_offset)" : "$(num_pre) * $(_row_stride)"; genInitWUVarCode(backend, groupEnv, *this, stride, batchSize, [&backend, kernel, this](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) { @@ -918,51 +899,14 @@ void CustomWUUpdateInitGroupMerged::generateInit(const BackendBase &backend, Env EnvironmentGroupMergedField groupEnv(env, *this); const bool kernel = (getArchetype().getSynapseGroup()->getMatrixType() & SynapseMatrixWeight::KERNEL); - if(kernel) { - // Loop through kernel size dimensions - for (size_t d = 0; d < getArchetype().getSynapseGroup()->getKernelSize().size(); d++) { - // If this dimension has a heterogeneous size, add it to struct - if (isKernelSizeHeterogeneous(*this, d)) { - groupEnv.addField(Type::Uint32, "_kernel_size_" + std::to_string(d), "kernelSize" + std::to_string(d), - [d](const auto &g, size_t) { return std::to_string(g.getSynapseGroup()->getKernelSize().at(d)); }); - } - } - - if(batchSize > 1) { - // Loop through kernel dimensions and multiply together to calculate batch stride - std::ostringstream batchStrideInit; - batchStrideInit << "const unsigned int batchStride = "; - const auto &kernelSize = getArchetype().getSynapseGroup()->getKernelSize(); - for (size_t i = 0; i < kernelSize.size(); i++) { - batchStrideInit << getKernelSize(*this, i); - - if (i != (kernelSize.size() - 1)) { - batchStrideInit << " * "; - } - } - batchStrideInit << ";" << std::endl; - groupEnv.add(Type::Uint32.addConst(), "_batch_stride", "batchStride", - {groupEnv.addInitialiser(batchStrideInit.str())}); - } - } - else { - groupEnv.addField(Type::Uint32.addConst(), "num_pre", - Type::Uint32, "numSrcNeurons", - [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); }); - groupEnv.addField(Type::Uint32.addConst(), "num_post", - Type::Uint32, "numTrgNeurons", - [](const auto &cg, size_t) { return std::to_string(cg.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); }); - groupEnv.addField(Type::Uint32, "_row_stride", "rowStride", - [&backend](const auto &cg, size_t) { return std::to_string(backend.getSynapticMatrixRowStride(*cg.getSynapseGroup())); }); - - - groupEnv.getStream() << "for(unsigned int i = 0; i < " << groupEnv["num_pre"] << "; i++)"; + if(!kernel) { + groupEnv.print("for(unsigned int i = 0; i < $(num_pre); i++)"); groupEnv.getStream() << CodeStream::OB(3); groupEnv.add(Type::Uint32.addConst(), "id_pre", "i"); } // Loop through rows - const std::string stride = kernel ? "$(_batch_stride)" : "$(num_pre) * $(_row_stride)"; + const std::string stride = kernel ? "$(_kern_batch_offset)" : "$(num_pre) * $(_row_stride)"; genInitWUVarCode( backend, groupEnv, *this, stride, getArchetype().isBatched() ? 
batchSize : 1, [&backend, kernel, this](EnvironmentExternalBase &varInitEnv, BackendBase::HandlerEnv handler) @@ -1075,18 +1019,7 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePreInitGroupMerg //---------------------------------------------------------------------------- void CustomConnectivityUpdatePreInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { - // Create environment for group - EnvironmentGroupMergedField groupEnv(env, *this); - - groupEnv.addField(Type::Uint32.addConst(), "size", - Type::Uint32, "size", - [](const auto &c, size_t) - { - return std::to_string(c.getSynapseGroup()->getSrcNeuronGroup()->getNumNeurons()); - }); - - // Initialise presynaptic custom connectivity update variables - genInitNeuronVarCode(backend, groupEnv, *this, "", "size", 0, batchSize); + genInitNeuronVarCode(backend, env, *this, "", "size", 0, batchSize); } // ---------------------------------------------------------------------------- @@ -1115,18 +1048,8 @@ boost::uuids::detail::sha1::digest_type CustomConnectivityUpdatePostInitGroupMer //---------------------------------------------------------------------------- void CustomConnectivityUpdatePostInitGroupMerged::generateInit(const BackendBase &backend, EnvironmentExternalBase &env, unsigned int batchSize) { - // Create environment for group - EnvironmentGroupMergedField groupEnv(env, *this); - - groupEnv.addField(Type::Uint32.addConst(), "size", - Type::Uint32, "size", - [](const auto &c, size_t) - { - return std::to_string(c.getSynapseGroup()->getTrgNeuronGroup()->getNumNeurons()); - }); - // Initialise presynaptic custom connectivity update variables - genInitNeuronVarCode(backend, groupEnv, *this, "", "size", 0, batchSize); + genInitNeuronVarCode(backend, env, *this, "", "size", 0, batchSize); } // ---------------------------------------------------------------------------- From 48c85c943c83b16dd5bdade0e630ccd53e72816f Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 27 Jul 2023 12:18:41 +0100 Subject: [PATCH 468/725] fixed some more bugs --- src/genn/backends/single_threaded_cpu/backend.cc | 14 +++++++------- src/genn/genn/code_generator/backendSIMT.cc | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 19d47344c5..1c1defa25e 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -838,7 +838,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Begin environment with RNG library and standard library EnvironmentLibrary rngEnv(init, StandardLibrary::getHostRNGFunctions(modelMerged.getModel().getPrecision())); EnvironmentLibrary initEnv(rngEnv, StandardLibrary::getMathsFunctions()); - + initEnv.getStream() << "void initialize()"; { @@ -869,7 +869,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: n.generateInit(*this, groupEnv, 1); } }); - + funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Synapse groups" << std::endl; modelMerged.genMergedSynapseInitGroups( @@ -884,9 +884,9 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedSynapseInitGroup" << s.getIndex() << "[g]; " << std::endl; + 
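+        // Wrapping the merged group in a standard environment exposes fields
+        // such as $(num_pre) and $(_row_stride) to the generated init code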
EnvironmentGroupMergedField groupEnv(funcEnv, s); buildStandardEnvironment(groupEnv, 1); - s.generateInit(*this, groupEnv, 1); } }); @@ -905,13 +905,13 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: // Get reference to group funcEnv.getStream() << "const auto *group = &mergedCustomUpdateInitGroup" << c.getIndex() << "[g]; " << std::endl; - + EnvironmentGroupMergedField groupEnv(funcEnv, c); buildStandardEnvironment(groupEnv); c.generateInit(*this, groupEnv, 1); } }); - + funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Custom connectivity presynaptic update groups" << std::endl; modelMerged.genMergedCustomConnectivityUpdatePreInitGroups( @@ -932,7 +932,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: c.generateInit(*this, groupEnv, 1); } }); - + funcEnv.getStream() << "// ------------------------------------------------------------------------" << std::endl; funcEnv.getStream() << "// Custom connectivity postsynaptic update groups" << std::endl; modelMerged.genMergedCustomConnectivityUpdatePostInitGroups( @@ -949,7 +949,7 @@ void Backend::genInit(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase: funcEnv.getStream() << "const auto *group = &mergedCustomConnectivityUpdatePostInitGroup" << c.getIndex() << "[g]; " << std::endl; EnvironmentGroupMergedField groupEnv(funcEnv, c); buildStandardEnvironment(groupEnv); - c.generateInit(*this, funcEnv, 1); + c.generateInit(*this, groupEnv, 1); } }); diff --git a/src/genn/genn/code_generator/backendSIMT.cc b/src/genn/genn/code_generator/backendSIMT.cc index 56c5c45dd0..9abed10b3c 100644 --- a/src/genn/genn/code_generator/backendSIMT.cc +++ b/src/genn/genn/code_generator/backendSIMT.cc @@ -1490,7 +1490,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer rngInitEnv.addField(getPopulationRNGType().createPointer(), "_rng", "rng", [this](const auto &g, size_t) { return getDeviceVarPrefix() + "rng" + g.getName(); }); - genPopulationRNGInit(rngInitEnv.getStream(), printSubs("$(_rng)[$(id)]", groupEnv), + genPopulationRNGInit(rngInitEnv.getStream(), printSubs("$(_rng)[$(id)]", rngInitEnv), "deviceRNGSeed", "id"); } @@ -1522,7 +1522,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer { CodeStream::Scope b(groupEnv.getStream()); - // If population RNGs are initialised on device and this custom connectivity update + // If population RNGs are initialised on device and this custom connectivity update // required one, initialise single RNG using GLOBAL thread id for sequence if(isPopulationRNGInitialisedOnDevice() && Utils::isRNGRequired(cg.getArchetype().getRowUpdateCodeTokens())) { // Add field for RNG @@ -1530,7 +1530,7 @@ void BackendSIMT::genInitializeKernel(EnvironmentExternalBase &env, ModelSpecMer rngInitEnv.addField(getPopulationRNGType().createPointer(), "_rng", "rng", [this](const auto &g, size_t) { return getDeviceVarPrefix() + "rng" + g.getName(); }); - genPopulationRNGInit(rngInitEnv.getStream(), printSubs("$(_rng)[$(id)]", groupEnv), + genPopulationRNGInit(rngInitEnv.getStream(), printSubs("$(_rng)[$(id)]", rngInitEnv), "deviceRNGSeed", "id"); } From 47ae1beaee42e32850fc6eca8e2d999a29947633 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 27 Jul 2023 12:33:51 +0100 Subject: [PATCH 469/725] standard distributions needed whenever backends need a global host RNG, not only in single-threaded CPU --- 
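With this change the generated runner owns both the host RNG and the standard distributions, so their construction cost is paid once per run instead of on every sample. A minimal sketch of the resulting pattern, reusing the global names the runner emits (the hostUniform() wrapper is hypothetical, added purely for illustration):

#include <random>

// Persistent generator and distribution; the runner seeds hostRNG once at
// allocation time, either from std::random_device or from the model seed
std::mt19937 hostRNG;
std::uniform_real_distribution<float> standardUniformDistribution(0.0f, 1.0f);

float hostUniform()
{
    // Sampling reuses the persistent distribution rather than rebuilding it
    return standardUniformDistribution(hostRNG);
}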
.../backends/single_threaded_cpu/backend.cc | 18 ----- .../genn/code_generator/generateRunner.cc | 69 ++++++++++--------- 2 files changed, 37 insertions(+), 50 deletions(-) diff --git a/src/genn/backends/single_threaded_cpu/backend.cc b/src/genn/backends/single_threaded_cpu/backend.cc index 1c1defa25e..c1427f19f6 100644 --- a/src/genn/backends/single_threaded_cpu/backend.cc +++ b/src/genn/backends/single_threaded_cpu/backend.cc @@ -1314,14 +1314,6 @@ void Backend::genDefinitionsPreamble(CodeStream &os, const ModelSpecMerged &mode os << "#include " << std::endl; os << "#include " << std::endl; os << "#include " << std::endl; - - // If a global RNG is required, define standard host distributions as recreating them each call is slow - if(isGlobalHostRNGRequired(model)) { - os << "EXPORT_VAR " << "std::uniform_real_distribution<" << model.getPrecision().getName() << "> standardUniformDistribution;" << std::endl; - os << "EXPORT_VAR " << "std::normal_distribution<" << model.getPrecision().getName() << "> standardNormalDistribution;" << std::endl; - os << "EXPORT_VAR " << "std::exponential_distribution<" << model.getPrecision().getName() << "> standardExponentialDistribution;" << std::endl; - os << std::endl; - } } //-------------------------------------------------------------------------- void Backend::genDefinitionsInternalPreamble(CodeStream &os, const ModelSpecMerged &) const @@ -1361,16 +1353,6 @@ void Backend::genDefinitionsInternalPreamble(CodeStream &os, const ModelSpecMerg //-------------------------------------------------------------------------- void Backend::genRunnerPreamble(CodeStream &os, const ModelSpecMerged &modelMerged, const MemAlloc&) const { - const ModelSpecInternal &model = modelMerged.getModel(); - - // If a global RNG is required, implement standard host distributions as recreating them each call is slow - if(isGlobalHostRNGRequired(model)) { - os << "std::uniform_real_distribution<" << model.getPrecision().getName() << "> standardUniformDistribution(" << writePreciseLiteral(0.0, model.getPrecision()) << ", " << writePreciseLiteral(1.0, model.getPrecision()) << ");" << std::endl; - os << "std::normal_distribution<" << model.getPrecision().getName() << "> standardNormalDistribution(" << writePreciseLiteral(0.0, model.getPrecision()) << ", " << writePreciseLiteral(1.0, model.getPrecision()) << ");" << std::endl; - os << "std::exponential_distribution<" << model.getPrecision().getName() << "> standardExponentialDistribution(" << writePreciseLiteral(1.0, model.getPrecision()) << ");" << std::endl; - os << std::endl; - } - os << std::endl; } //-------------------------------------------------------------------------- void Backend::genAllocateMemPreamble(CodeStream&, const ModelSpecMerged&, const MemAlloc&) const diff --git a/src/genn/genn/code_generator/generateRunner.cc b/src/genn/genn/code_generator/generateRunner.cc index 63678609fd..1680a28607 100644 --- a/src/genn/genn/code_generator/generateRunner.cc +++ b/src/genn/genn/code_generator/generateRunner.cc @@ -371,37 +371,6 @@ void genExtraGlobalParam(const ModelSpecMerged &modelMerged, const BackendBase & } } //------------------------------------------------------------------------- -void genGlobalHostRNG(CodeStream &definitionsVar, CodeStream &runnerVarDecl, - CodeStream &runnerVarAlloc, unsigned int seed, MemAlloc &mem) -{ - definitionsVar << "EXPORT_VAR " << "std::mt19937 hostRNG;" << std::endl; - runnerVarDecl << "std::mt19937 hostRNG;" << std::endl; - - // If no seed is specified, use system randomness to generate 
seed sequence - CodeStream::Scope b(runnerVarAlloc); - if(seed == 0) { - runnerVarAlloc << "uint32_t seedData[std::mt19937::state_size];" << std::endl; - runnerVarAlloc << "std::random_device seedSource;" << std::endl; - runnerVarAlloc << "for(int i = 0; i < std::mt19937::state_size; i++)"; - { - CodeStream::Scope b(runnerVarAlloc); - runnerVarAlloc << "seedData[i] = seedSource();" << std::endl; - } - runnerVarAlloc << "std::seed_seq seeds(std::begin(seedData), std::end(seedData));" << std::endl; - } - // Otherwise, create a seed sequence from model seed - // **NOTE** this is a terrible idea see http://www.pcg-random.org/posts/cpp-seeding-surprises.html - else { - runnerVarAlloc << "std::seed_seq seeds{" << seed << "};" << std::endl; - } - - // Seed RNG from seed sequence - runnerVarAlloc << "hostRNG.seed(seeds);" << std::endl; - - // Add size of Mersenne Twister to memory tracker - mem += MemAlloc::host(sizeof(std::mt19937)); -} -//------------------------------------------------------------------------- template void genRunnerVars(const ModelSpecMerged &modelMerged, const BackendBase &backend, CodeStream &definitionsVar, CodeStream &definitionsFunc, CodeStream &definitionsInternalVar, @@ -630,7 +599,43 @@ MemAlloc GeNN::CodeGenerator::generateRunner(const filesystem::path &outputPath, } // If backend required a global host RNG to simulate (or initialize) this model, generate a standard Mersenne Twister if(backend.isGlobalHostRNGRequired(model)) { - genGlobalHostRNG(definitionsVar, runnerVarDecl, runnerVarAlloc, model.getSeed(), mem); + // Define standard RNG + definitionsVar << "EXPORT_VAR " << "std::mt19937 hostRNG;" << std::endl; + runnerVarDecl << "std::mt19937 hostRNG;" << std::endl; + + // Define standard host distributions as recreating them each call is slow + definitionsVar << "EXPORT_VAR " << "std::uniform_real_distribution<" << model.getPrecision().getName() << "> standardUniformDistribution;" << std::endl; + definitionsVar << "EXPORT_VAR " << "std::normal_distribution<" << model.getPrecision().getName() << "> standardNormalDistribution;" << std::endl; + definitionsVar << "EXPORT_VAR " << "std::exponential_distribution<" << model.getPrecision().getName() << "> standardExponentialDistribution;" << std::endl; + definitionsVar << std::endl; + runnerVarDecl << "std::uniform_real_distribution<" << model.getPrecision().getName() << "> standardUniformDistribution(" << writePreciseLiteral(0.0, model.getPrecision()) << ", " << writePreciseLiteral(1.0, model.getPrecision()) << ");" << std::endl; + runnerVarDecl << "std::normal_distribution<" << model.getPrecision().getName() << "> standardNormalDistribution(" << writePreciseLiteral(0.0, model.getPrecision()) << ", " << writePreciseLiteral(1.0, model.getPrecision()) << ");" << std::endl; + runnerVarDecl << "std::exponential_distribution<" << model.getPrecision().getName() << "> standardExponentialDistribution(" << writePreciseLiteral(1.0, model.getPrecision()) << ");" << std::endl; + runnerVarDecl << std::endl; + + // If no seed is specified, use system randomness to generate seed sequence + CodeStream::Scope b(runnerVarAlloc); + if(model.getSeed() == 0) { + runnerVarAlloc << "uint32_t seedData[std::mt19937::state_size];" << std::endl; + runnerVarAlloc << "std::random_device seedSource;" << std::endl; + runnerVarAlloc << "for(int i = 0; i < std::mt19937::state_size; i++)"; + { + CodeStream::Scope b(runnerVarAlloc); + runnerVarAlloc << "seedData[i] = seedSource();" << std::endl; + } + runnerVarAlloc << "std::seed_seq 
seeds(std::begin(seedData), std::end(seedData));" << std::endl; + } + // Otherwise, create a seed sequence from model seed + // **NOTE** this is a terrible idea see http://www.pcg-random.org/posts/cpp-seeding-surprises.html + else { + runnerVarAlloc << "std::seed_seq seeds{" << model.getSeed() << "};" << std::endl; + } + + // Seed RNG from seed sequence + runnerVarAlloc << "hostRNG.seed(seeds);" << std::endl; + + // Add size of Mersenne Twister to memory tracker + mem += MemAlloc::host(sizeof(std::mt19937)); } allVarStreams << std::endl; From 3dd6da9251996f35b353efb3f6fcfdc70a9b7e61 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 27 Jul 2023 13:43:43 +0100 Subject: [PATCH 470/725] fixed typo --- include/genn/backends/cuda/backend.h | 4 ++-- include/genn/backends/single_threaded_cpu/backend.h | 4 ++-- include/genn/genn/code_generator/backendBase.h | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/genn/backends/cuda/backend.h b/include/genn/backends/cuda/backend.h index 9b9d7ca9a9..d06b2db60a 100644 --- a/include/genn/backends/cuda/backend.h +++ b/include/genn/backends/cuda/backend.h @@ -238,12 +238,12 @@ class BACKEND_EXPORT Backend : public BackendSIMT const Type::ResolvedType &type, const std::string &name, VarLocation loc, unsigned int batchSize) const final; - //! Generate code for pushing a variable with a size known at tuntime to the 'device' + //! Generate code for pushing a variable with a size known at runtime to the 'device' virtual void genVariableDynamicPush(CodeStream &os, const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; - //! Generate code for pushing a variable with a size known at tuntime to the 'device' + //! Generate code for pushing a variable with a size known at runtime to the 'device' virtual void genLazyVariableDynamicPush(CodeStream &os, const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName) const final; diff --git a/include/genn/backends/single_threaded_cpu/backend.h b/include/genn/backends/single_threaded_cpu/backend.h index e5f5c4cc86..361a159b30 100644 --- a/include/genn/backends/single_threaded_cpu/backend.h +++ b/include/genn/backends/single_threaded_cpu/backend.h @@ -109,12 +109,12 @@ class BACKEND_EXPORT Backend : public BackendBase const Type::ResolvedType &type, const std::string &name, VarLocation loc, unsigned int batchSize) const final; - //! Generate code for pushing a variable with a size known at tuntime to the 'device' + //! Generate code for pushing a variable with a size known at runtime to the 'device' virtual void genVariableDynamicPush(CodeStream &os, const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const final; - //! Generate code for pushing a variable with a size known at tuntime to the 'device' + //! 
Generate code for pushing a variable with a size known at runtime to the 'device' virtual void genLazyVariableDynamicPush(CodeStream &os, const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName) const final; diff --git a/include/genn/genn/code_generator/backendBase.h b/include/genn/genn/code_generator/backendBase.h index 2cc31f5093..96b177dff8 100644 --- a/include/genn/genn/code_generator/backendBase.h +++ b/include/genn/genn/code_generator/backendBase.h @@ -306,12 +306,12 @@ class GENN_EXPORT BackendBase const Type::ResolvedType &type, const std::string &name, VarLocation loc, unsigned int batchSize) const = 0; - //! Generate code for pushing a variable with a size known at tuntime to the 'device' + //! Generate code for pushing a variable with a size known at runtime to the 'device' virtual void genVariableDynamicPush(CodeStream &os, const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName = "count", const std::string &prefix = "") const = 0; - //! Generate code for pushing a variable with a size known at tuntime to the 'device' + //! Generate code for pushing a variable with a size known at runtime to the 'device' virtual void genLazyVariableDynamicPush(CodeStream &os, const Type::ResolvedType &type, const std::string &name, VarLocation loc, const std::string &countVarName) const = 0; From 84f2cdd755bb337dda2c65eddc559549f15fb301 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 27 Jul 2023 13:56:02 +0100 Subject: [PATCH 471/725] removed random spurious new calls! --- src/genn/backends/cuda/backend.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index df40918d63..b822b2ab91 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -1757,7 +1757,6 @@ void Backend::genVariableDynamicPush(CodeStream &os, os << ", " << countVarName << " * sizeof(" << type.getPointer().valueType->getName() << "), cudaMemcpyHostToDevice));" << std::endl; } else { - os << prefix << name << " = new " << type.getName() << "[" << countVarName << "];" << std::endl; os << "CHECK_CUDA_ERRORS(cudaMemcpy(" << prefix << "d_" << name; os << ", " << prefix << name; os << ", " << countVarName << " * sizeof(" << type.getName() << "), cudaMemcpyHostToDevice));" << std::endl; @@ -1775,7 +1774,6 @@ void Backend::genLazyVariableDynamicPush(CodeStream &os, os << countVarName << " * sizeof(" << type.getPointer().valueType->getName() << "), cudaMemcpyHostToDevice));" << std::endl; } else { - os << "$(d_" << name << ") = new " << type.getName() << "[" << countVarName << "];" << std::endl; os << "CHECK_CUDA_ERRORS(cudaMemcpy($(_d_" << name << "), $(_" << name << "), "; os << countVarName << " * sizeof(" << type.getName() << "), cudaMemcpyHostToDevice));" << std::endl; } From c2d050c4b0a932cb23bdb0d8fe3e897c3b4b90d0 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 27 Jul 2023 14:23:15 +0100 Subject: [PATCH 472/725] fixed lots of small issues with custom connectivity updates --- .../customConnectivityUpdateGroupMerged.h | 22 +++--- include/genn/genn/type.h | 1 + src/genn/genn/code_generator/backendSIMT.cc | 16 +--- .../customConnectivityUpdateGroupMerged.cc | 73 ++++++++++--------- .../genn/code_generator/initGroupMerged.cc | 7 +- 5 files changed, 57 insertions(+), 62 deletions(-) diff --git a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h 
b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h index 64bc3451f7..ca6c7aa396 100644 --- a/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h +++ b/include/genn/genn/code_generator/customConnectivityUpdateGroupMerged.h @@ -187,16 +187,20 @@ class GENN_EXPORT CustomConnectivityHostUpdateGroupMerged : public GroupMerged(arrayPrefix, [&indexSuffix](VarAccess, const std::string &) { return indexSuffix; }, - fieldSuffix); + fieldSuffix, readOnly); } template - void addVarRefs(const std::string &arrayPrefix, GetVarRefIndexFn getIndexFn, const std::string &fieldSuffix = "") + void addVarRefs(const std::string &arrayPrefix, GetVarRefIndexFn getIndexFn, + const std::string &fieldSuffix = "", bool readOnly = false) { // Loop through variable references const A archetypeAdaptor(this->getGroup().getArchetype()); for(const auto &v : archetypeAdaptor.getDefs()) { // If variable access is read-only, qualify type with const const auto resolvedType = v.type.resolve(this->getGroup().getTypeContext()); - const auto qualifiedType = (v.access & VarAccessModeAttribute::READ_ONLY) ? resolvedType.addConst() : resolvedType; + const auto qualifiedType = (readOnly || (v.access & VarAccessModeAttribute::READ_ONLY)) ? resolvedType.addConst() : resolvedType; addField(qualifiedType, v.name, resolvedType.createPointer(), v.name + fieldSuffix, [arrayPrefix, v](const auto &g, size_t) @@ -676,7 +679,8 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase - void addVarRefs(const std::string &arrayPrefix, const std::string &indexSuffix, const std::string &fieldSuffix = "") + void addVarRefs(const std::string &arrayPrefix, const std::string &indexSuffix, + const std::string &fieldSuffix = "", bool readOnly = false) { addVarRefs(arrayPrefix, [&indexSuffix](VarAccess a, auto &) { return indexSuffix; }, fieldSuffix); diff --git a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc index fd669aa381..a1577e0096 100644 --- a/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/synapseUpdateGroupMerged.cc @@ -33,13 +33,13 @@ void applySynapseSubstitutions(const BackendBase &backend, EnvironmentExternalBa [&sg, batchSize](VarAccess a, const std::string&) { return sg.getPreWUVarIndex(batchSize, getVarAccessDuplication(a), "$(id_pre)"); - }); + }, "", true); synEnv.template addVars( backend.getDeviceVarPrefix(), [&sg, batchSize](VarAccess a, const std::string&) { return sg.getPostWUVarIndex(batchSize, getVarAccessDuplication(a), "$(id_post)"); - }); + }, "", true); // If this synapse group has a kernel From b9e36c89922da46c0c756184a4e62c92f939d459 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 15 Aug 2023 11:11:07 +0100 Subject: [PATCH 711/725] fixed typo in pytest --- tests/features/test_wu_vars.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/features/test_wu_vars.py b/tests/features/test_wu_vars.py index f13b04023c..deeb6ea2c1 100644 --- a/tests/features/test_wu_vars.py +++ b/tests/features/test_wu_vars.py @@ -209,7 +209,7 @@ def test_wu_var(backend, precision, fuse, delay): if not np.allclose(delayed_time, w_value): assert False, f"{s.name} var has wrong value ({w_value} rather than {delayed_time})" -pytest.mark.parametrize("backend", ["single_threaded_cpu", "cuda"]) +@pytest.mark.parametrize("backend", ["single_threaded_cpu", "cuda"]) @pytest.mark.parametrize("precision", [types.Double, types.Float]) 
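# A bare pytest.mark.parametrize(...) call (without the leading "@") builds
# the mark object but never applies it, so the test was not parametrised
# over backends until the "@" was added above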
@pytest.mark.parametrize("fuse", [True, False]) @pytest.mark.parametrize("delay", [0, 20]) @@ -351,4 +351,4 @@ def test_wu_var_cont(backend, precision, fuse, delay): assert False, f"{s.name} var has wrong value ({w_value} rather than {delayed_time})" if __name__ == '__main__': - test_wu_var_cont("cuda", types.Float, True, 20) + test_wu_var_cont("cuda", types.Float, True, 0) From c1ddd0553ec4ed7174c97c92ea9fb45e682625e2 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 15 Aug 2023 13:08:11 +0100 Subject: [PATCH 712/725] isTrueSpikeRequired needed in neuron hash --- src/genn/genn/neuronGroup.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/genn/genn/neuronGroup.cc b/src/genn/genn/neuronGroup.cc index 664fe0b441..b133294d3b 100644 --- a/src/genn/genn/neuronGroup.cc +++ b/src/genn/genn/neuronGroup.cc @@ -520,6 +520,7 @@ boost::uuids::detail::sha1::digest_type NeuronGroup::getHashDigest() const Utils::updateHash(isPrevSpikeTimeRequired(), hash); //Utils::updateHash(getSpikeEventCondition(), hash); **FIXME** Utils::updateHash(isSpikeEventRequired(), hash); + Utils::updateHash(isTrueSpikeRequired(), hash); Utils::updateHash(isSpikeRecordingEnabled(), hash); Utils::updateHash(isSpikeEventRecordingEnabled(), hash); Utils::updateHash(getNumDelaySlots(), hash); @@ -549,6 +550,7 @@ boost::uuids::detail::sha1::digest_type NeuronGroup::getInitHashDigest() const Utils::updateHash(isSpikeTimeRequired(), hash); Utils::updateHash(isPrevSpikeTimeRequired(), hash); Utils::updateHash(isSpikeEventRequired(), hash); + Utils::updateHash(isTrueSpikeRequired(), hash); Utils::updateHash(isSimRNGRequired(), hash); Utils::updateHash(getNumDelaySlots(), hash); Utils::updateHash(m_VarQueueRequired, hash); From b4b297536eecfc5a0b25af0e25ee51341bf6bc6b Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 15 Aug 2023 15:31:46 +0100 Subject: [PATCH 713/725] fixed test - time_min is different depending on precision --- tests/features/test_spike_times.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/features/test_spike_times.py b/tests/features/test_spike_times.py index 355a62686c..d37d5207b0 100644 --- a/tests/features/test_spike_times.py +++ b/tests/features/test_spike_times.py @@ -65,7 +65,8 @@ def test_spike_times(backend, precision): {}, {}) # Add synapse models testing various ways of reading presynaptic WU vars - float_min = np.finfo(np.float32).min + np_scalar = np.float32 if precision == types.Float else np.float64 + float_min = np.finfo(np_scalar).min s_pre_pop = model.add_synapse_population( "PreSynapses", "SPARSE", 20, pre_n_pop, post_n_pop, @@ -113,4 +114,4 @@ def test_spike_times(backend, precision): if __name__ == '__main__': - test_spike_times("cuda", types.Float) \ No newline at end of file + test_spike_times("single_threaded_cpu", types.Double) \ No newline at end of file From 2af369b8a6574d24eb3d9707cfa6d33fe97f1657 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 15 Aug 2023 15:34:20 +0100 Subject: [PATCH 714/725] build pygenn with coverage support --- Jenkinsfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 89c1a46c49..a8a87bcb2c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -203,10 +203,10 @@ for(b = 0; b < builderNodes.size(); b++) { buildStep("Installing PyGeNN (${NODE_NAME})") { dir("genn") { - // Build dynamic LibGeNN + // Build dynamic LibGeNN with coverage support echo "Building LibGeNN"; def commandsLibGeNN = """ - make DYNAMIC=1 LIBRARY_DIRECTORY=`pwd`/pygenn 1>> 
"${outputFilename}" 2>&1 + make DYNAMIC=1 COVERAGE=1 LIBRARY_DIRECTORY=`pwd`/pygenn 1>> "${outputFilename}" 2>&1 """; def statusLibGeNN = sh script:commandsLibGeNN, returnStatus:true; if (statusLibGeNN != 0) { @@ -218,7 +218,7 @@ for(b = 0; b < builderNodes.size(); b++) { echo "Building and installing PyGeNN"; def commandsPyGeNN = """ . ${WORKSPACE}/venv/bin/activate - pip install -e . 1>> "${outputFilename}" 2>&1 + pip install --install-option="--coverage" --editable . 1>> "${outputFilename}" 2>&1 """; def statusPyGeNN = sh script:commandsPyGeNN, returnStatus:true; if (statusPyGeNN != 0) { From 3ad62eb32fe1f02c798e9047e339aef9db339132 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 15 Aug 2023 15:42:35 +0100 Subject: [PATCH 715/725] bloody pip --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index a8a87bcb2c..f4b01c7653 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -218,7 +218,7 @@ for(b = 0; b < builderNodes.size(); b++) { echo "Building and installing PyGeNN"; def commandsPyGeNN = """ . ${WORKSPACE}/venv/bin/activate - pip install --install-option="--coverage" --editable . 1>> "${outputFilename}" 2>&1 + pip install --editable . --install-option="--coverage" 1>> "${outputFilename}" 2>&1 """; def statusPyGeNN = sh script:commandsPyGeNN, returnStatus:true; if (statusPyGeNN != 0) { From 5ee79deb83c28d177ea6901f6b2d8925d1dbd286 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 15 Aug 2023 15:52:20 +0100 Subject: [PATCH 716/725] give up on pip entirely --- Jenkinsfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index f4b01c7653..111afee8e6 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -214,11 +214,10 @@ for(b = 0; b < builderNodes.size(); b++) { } // Build PyGeNN module - // **NOTE** we have to install echo "Building and installing PyGeNN"; def commandsPyGeNN = """ . ${WORKSPACE}/venv/bin/activate - pip install --editable . 
--install-option="--coverage" 1>> "${outputFilename}" 2>&1 + python setup.py develop --coverage 1>> "${outputFilename}" 2>&1 """; def statusPyGeNN = sh script:commandsPyGeNN, returnStatus:true; if (statusPyGeNN != 0) { From 087d31736a142d9feb05ae8388bae428e89a5004 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 15 Aug 2023 16:32:48 +0100 Subject: [PATCH 717/725] add coverage compiler flags if PyGeNN is built with coverage --- setup.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/setup.py b/setup.py index ff58e810c0..8bbfe6032a 100644 --- a/setup.py +++ b/setup.py @@ -88,6 +88,12 @@ if WIN: extension_kwargs["extra_compile_args"].extend(["/wd4251", "-DWIN32_LEAN_AND_MEAN", "-DNOMINMAX"]) +if coverage_build: + if LINUX: + extension_kwargs["extra_compile_args"].extend(["--coverage"]) + elif MAC + extension_kwargs["extra_compile_args"].extend(["-fprofile-instr-generate -fcoverage-mapping"]) + # Extend these kwargs for extensions which link against GeNN genn_extension_kwargs = deepcopy(extension_kwargs) genn_extension_kwargs["include_dirs"].extend([genn_include, genn_third_party_include]) From 07aeeafea2a29c2823bd82688ca695e07c2d9c71 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 15 Aug 2023 16:34:00 +0100 Subject: [PATCH 718/725] missing colon --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 8bbfe6032a..31a1c30329 100644 --- a/setup.py +++ b/setup.py @@ -91,7 +91,7 @@ if coverage_build: if LINUX: extension_kwargs["extra_compile_args"].extend(["--coverage"]) - elif MAC + elif MAC: extension_kwargs["extra_compile_args"].extend(["-fprofile-instr-generate -fcoverage-mapping"]) # Extend these kwargs for extensions which link against GeNN From a31da6e27402c4b2d45010dbde737f115b6fa6a2 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 15 Aug 2023 16:45:24 +0100 Subject: [PATCH 719/725] coverage linker arguments --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 31a1c30329..df3cc826bd 100644 --- a/setup.py +++ b/setup.py @@ -91,8 +91,9 @@ if coverage_build: if LINUX: extension_kwargs["extra_compile_args"].extend(["--coverage"]) + extension_kwargs["extra_link_args"].extend(["--coverage"]) elif MAC: - extension_kwargs["extra_compile_args"].extend(["-fprofile-instr-generate -fcoverage-mapping"]) + extension_kwargs["extra_compile_args"].extend(["-fprofile-instr-generate", "-fcoverage-mapping"]) # Extend these kwargs for extensions which link against GeNN genn_extension_kwargs = deepcopy(extension_kwargs) From b08fd1d88931cb8fed5b4efb2fa22cba1833c052 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Tue, 15 Aug 2023 16:58:18 +0100 Subject: [PATCH 720/725] new bash script for gathering combined feature and unit test coverage --- tests/gather_coverage.sh | 86 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100755 tests/gather_coverage.sh diff --git a/tests/gather_coverage.sh b/tests/gather_coverage.sh new file mode 100755 index 0000000000..c22a4c4433 --- /dev/null +++ b/tests/gather_coverage.sh @@ -0,0 +1,86 @@ +#!/bin/bash +# By default no flags are passed to genn-buildmodel.sh +REPORT=0 + +# Parse command line arguments +OPTIND=1 +while getopts "r" opt; do + case "$opt" in + r) REPORT=1 + ;; + esac +done + +# Find this script i.e. 
tests directory and hence GeNN itself +TESTS_DIR=$(dirname "$0") +GENN_PATH=$TESTS_DIR/../ + +# Clean GeNN library and build a version of the single-threaded CPU backend with coverage calculation built +cd $GENN_PATH + +if [[ "$(uname)" = "Darwin" ]]; then + # Loop through features and build list of raw profile output files + for f in features/* ; do + if [[ -f "$f/default.profraw" && -f "$f/generator_coverage" ]]; then + LLVM_PROFRAW_FILES+="$f/default.profraw " + + if [ -z "$LLVM_TEST_EXECUTABLES" ]; then + LLVM_TEST_EXECUTABLES+="$f/generator_coverage " + else + LLVM_TEST_EXECUTABLES+="-object $f/generator_coverage " + fi + fi + done + + # Add unit tests profiling data to lists + if [[ -f "unit/default.profraw" && -f "unit/test_coverage" ]]; then + LLVM_PROFRAW_FILES+="unit/default.profraw " + + if [ -z "$LLVM_TEST_EXECUTABLES" ]; then + LLVM_TEST_EXECUTABLES+="unit/test_coverage " + else + LLVM_TEST_EXECUTABLES+="-object unit/test_coverage " + fi + fi + + # Merge coverage + xcrun llvm-profdata merge -sparse $LLVM_PROFRAW_FILES -o coverage.profdata + + # 'Show' text based coverage + xcrun llvm-cov show $LLVM_TEST_EXECUTABLES -instr-profile=coverage.profdata > coverage_$NODE_NAME.txt +else + # Loop through all object directories with coverage data + for OBJ_DIR in obj_coverage*/ ; do + # Use lcov to capture libgenn coverage + OBJ_NAME=$(basename $OBJ_DIR) + lcov --directory ${OBJ_DIR}/genn/genn --base-directory src/genn/genn --capture -rc lcov_branch_coverage=1 --output-file genn_${OBJ_NAME}.txt + + # Add tracefile to list of tracefile arguments to pass to lcov + LCOV_TRACEFILE_ARGS+=" --add-tracefile genn_${OBJ_NAME}.txt" + + # Loop through directories in which there might be coverage for backends + for BACKEND_OBJ_DIR in ${OBJ_DIR}/genn/backends/*/ ; do + # Get corresponding module name + MODULE=$(basename $BACKEND_OBJ_DIR) + + # Use lcov to capture all coverage for this module + lcov --directory $BACKEND_OBJ_DIR --base-directory src/genn/backends/$MODULE/ --capture -rc lcov_branch_coverage=1 --output-file ${MODULE}_${OBJ_NAME}.txt + + # Add tracefile to list of tracefile arguments to pass to lcov + LCOV_TRACEFILE_ARGS+=" --add-tracefile ${MODULE}_${OBJ_NAME}.txt" + done + done + + # Combine all tracefiles together + lcov $LCOV_TRACEFILE_ARGS --output-file coverage_$NODE_NAME.txt + + # Strip system libraries from output + lcov --remove coverage_$NODE_NAME.txt "/usr/*" --output-file coverage_$NODE_NAME.txt +fi + +if [ $REPORT -eq 1 ]; then + echo "Generating HTML coverage report..." 
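+    # genhtml consumes the combined lcov tracefile written above (Linux path);
+    # --branch-coverage keeps the branch statistics captured with lcov_branch_coverage=1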
+
+    # Generate browseable HTML
+    genhtml coverage_$NODE_NAME.txt --branch-coverage --output-directory ./code_coverage_report/
+fi

From cbf088bfa5e0f67997c4ccccaec528553c13d9a7 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 15 Aug 2023 17:04:28 +0100
Subject: [PATCH 721/725] run gather coverage from jenkins and upload both
 Python and C++ coverage

---
 Jenkinsfile | 32 ++++++++++++++++++++++----------
 1 file changed, 22 insertions(+), 10 deletions(-)

diff --git a/Jenkinsfile b/Jenkinsfile
index 111afee8e6..63a4fd5ca2 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -135,6 +135,7 @@
 
             def outputFilename = "${WORKSPACE}/msg_${NODE_NAME}.txt";
             def coveragePython = "${WORKSPACE}/coverage_python_${NODE_NAME}.xml";
+            def coverageCPP = "${WORKSPACE}/genn/coverage_${NODE_NAME}.txt";
             buildStep("Running unit tests (" + env.NODE_NAME + ")") {
                 // Run automatic tests
                 dir("genn") {
@@ -248,22 +249,33 @@
                 }
             }
 
-            /*buildStep("Uploading coverage (${NODE_NAME})") {
+            buildStep("Uploading coverage (${NODE_NAME})") {
                 dir("genn/tests") {
                     if(isUnix()) {
-                        // If Python coverage was emitted
-                        if(fileExists(coveragePython)) {
-                            // Upload to code cov
-                            withCredentials([string(credentialsId: "codecov_token_genn", variable: "CODECOV_TOKEN")]) {
-                                sh 'curl -s https://codecov.io/bash | bash -s - -n ' + env.NODE_NAME + ' -f ' + uniqueCoverage + ' -t $CODECOV_TOKEN';
+                        // Run script to gather together GCOV coverage from unit and feature tests
+                        sh './gather_coverage.sh'
+
+                        // Upload to code cov
+                        withCredentials([string(credentialsId: "codecov_token_genn", variable: "CODECOV_TOKEN")]) {
+                            // Upload Python coverage if it was produced
+                            if(fileExists(coveragePython)) {
+                                sh 'curl -s https://codecov.io/bash | bash -s - -n ' + env.NODE_NAME + ' -f ' + coveragePython + ' -t $CODECOV_TOKEN';
+                            }
+                            else {
+                                echo coveragePython + " doesn't exist!";
+                            }
+
+                            // Upload CPP coverage if it was produced
+                            if(fileExists(coverageCPP)) {
+                                sh 'curl -s https://codecov.io/bash | bash -s - -n ' + env.NODE_NAME + ' -f ' + coverageCPP + ' -t $CODECOV_TOKEN';
+                            }
+                            else {
+                                echo coverageCPP + " doesn't exist!";
                             }
-                        }
-                        else {
-                            echo uniqueCoverage + " doesn't exist!";
                         }
                     }
                 }
-            }*/
+            }
From 45c5b69f2928d58fe7f5da789574b857931dc562 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Tue, 15 Aug 2023 17:39:42 +0100
Subject: [PATCH 722/725] * tee compile output to separate file for warnings
 * move outputs into genn so the archive and warnings steps can hopefully
 find them

---
 Jenkinsfile | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/Jenkinsfile b/Jenkinsfile
index 63a4fd5ca2..17c3de7b1b 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -133,8 +133,9 @@
                 }
             }
 
-            def outputFilename = "${WORKSPACE}/msg_${NODE_NAME}.txt";
-            def coveragePython = "${WORKSPACE}/coverage_python_${NODE_NAME}.xml";
+            def outputFilename = "${WORKSPACE}/genn/output_${NODE_NAME}.txt";
+            def compileOutputFilename = "${WORKSPACE}/genn/compile_${NODE_NAME}.txt";
+            def coveragePython = "${WORKSPACE}/genn/coverage_python_${NODE_NAME}.xml";
             def coverageCPP = "${WORKSPACE}/genn/coverage_${NODE_NAME}.txt";
             buildStep("Running unit tests (" + env.NODE_NAME + ")") {
                 // Run automatic tests
@@ -207,7 +208,7 @@
                     // Build dynamic LibGeNN with coverage support
                     echo "Building LibGeNN";
                     def commandsLibGeNN = """
-                        make DYNAMIC=1 COVERAGE=1 LIBRARY_DIRECTORY=`pwd`/pygenn 1>> "${outputFilename}" 2>&1
+                        make DYNAMIC=1 COVERAGE=1 LIBRARY_DIRECTORY=`pwd`/pygenn 2>&1 | tee -a "${compileOutputFilename}" >> "${outputFilename}"
                     """;
                     def statusLibGeNN = sh script:commandsLibGeNN, returnStatus:true;
                     if (statusLibGeNN != 0) {
@@ -218,7 +219,7 @@
                     echo "Building and installing PyGeNN";
                     def commandsPyGeNN = """
                         . ${WORKSPACE}/venv/bin/activate
-                        python setup.py develop --coverage 1>> "${outputFilename}" 2>&1
+                        python setup.py develop --coverage 2>&1 | tee -a "${compileOutputFilename}" >> "${outputFilename}"
                     """;
                     def statusPyGeNN = sh script:commandsPyGeNN, returnStatus:true;
                     if (statusPyGeNN != 0) {
@@ -365,17 +366,19 @@
 
             buildStep("Archiving output (${NODE_NAME})") {
                 dir("genn") {
-                    archive outputFilename;
+                    def outputPattern = "output_" + env.NODE_NAME + ".txt";
+                    archive outputPattern;
 
                     // Run 'next-generation' warning plugin on results
+                    def compilePattern = "compile_" + env.NODE_NAME + ".txt";
                     if("mac" in nodeLabel) {
-                        recordIssues enabledForFailure: true, tool: clang(pattern: outputFilename);
+                        recordIssues enabledForFailure: true, tool: clang(pattern: compilePattern);
                     }
                     else if("windows" in nodeLabel){
-                        recordIssues enabledForFailure: true, tool: msBuild(pattern: outputFilename);
+                        recordIssues enabledForFailure: true, tool: msBuild(pattern: compilePattern);
                     }
                     else {
-                        recordIssues enabledForFailure: true, tool: gcc4(pattern: outputFilename);
+                        recordIssues enabledForFailure: true, tool: gcc4(pattern: compilePattern);
                     }
                 }
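The next patch, PATCH 723, replaces the blanket alwaysCopy constructor flag of EnvironmentLocalCacheBase with an optional per-variable callback, so that only variables which genuinely need it (those living in delay queues) are unconditionally written back. The following is a minimal, self-contained sketch of that design change; the names writeBack and ShouldAlwaysCopyFn here are illustrative, not GeNN's actual API:

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Optional predicate deciding, per variable, whether it must be written
// back even when the generated code never referenced it
using ShouldAlwaysCopyFn = std::function<bool(const std::string&)>;

// Write-back loop: before the patch this took a single "bool alwaysCopy"
// and copied everything; afterwards the decision is made per variable
void writeBack(const std::vector<std::string> &vars, bool referenced,
               ShouldAlwaysCopyFn shouldAlwaysCopy = ShouldAlwaysCopyFn())
{
    for(const auto &v : vars) {
        // An empty callback means "never force a copy"
        const bool alwaysCopy = shouldAlwaysCopy && shouldAlwaysCopy(v);
        if(alwaysCopy || referenced) {
            std::cout << "copying " << v << std::endl;
        }
    }
}

int main()
{
    // Only "V" is treated as delayed, so only it is forced to copy
    writeBack({"V", "U"}, false,
              [](const std::string &name){ return name == "V"; });
    return 0;
}

Keeping the default an empty std::function preserves call sites that never force copies, which is also why PATCH 725 below has to guard the call.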
From 46dfb23aef54ff8d7283121ffdba8e640a9a0d30 Mon Sep 17 00:00:00 2001
From: neworderofjamie
Date: Wed, 16 Aug 2023 11:19:07 +0100
Subject: [PATCH 723/725] Always copy logic in ``EnvironmentLocalCacheBase``
 was overly pessimistic - only delayed variables should always be copied

---
 .../genn/genn/code_generator/environment.h    | 38 ++++++++++++++-----
 .../code_generator/customUpdateGroupMerged.cc |  8 ++--
 .../code_generator/neuronUpdateGroupMerged.cc | 28 ++++++++++----
 3 files changed, 53 insertions(+), 21 deletions(-)

diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h
index 5062fae4db..88c3e82164 100644
--- a/include/genn/genn/code_generator/environment.h
+++ b/include/genn/genn/code_generator/environment.h
@@ -704,18 +704,27 @@ class VarCachePolicy
 public:
     using GroupInternal = typename G::GroupInternal;
     using GetIndexFn = std::function<std::string(const std::string&, VarAccessDuplication)>;
+    using ShouldAlwaysCopyFn = std::function<bool(const std::string&, VarAccessDuplication)>;
 
-    VarCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex)
-    :   m_GetReadIndex(getReadIndex), m_GetWriteIndex(getWriteIndex)
+    VarCachePolicy(GetIndexFn getReadIndex, GetIndexFn getWriteIndex,
+                   ShouldAlwaysCopyFn shouldAlwaysCopy = ShouldAlwaysCopyFn())
+    :   m_GetReadIndex(getReadIndex), m_GetWriteIndex(getWriteIndex),
+        m_ShouldAlwaysCopy(shouldAlwaysCopy)
     {}
 
-    VarCachePolicy(GetIndexFn getIndex)
-    :   m_GetReadIndex(getIndex), m_GetWriteIndex(getIndex)
+    VarCachePolicy(GetIndexFn getIndex, ShouldAlwaysCopyFn shouldAlwaysCopy = ShouldAlwaysCopyFn())
+    :   m_GetReadIndex(getIndex), m_GetWriteIndex(getIndex),
+        m_ShouldAlwaysCopy(shouldAlwaysCopy)
     {}
 
     //------------------------------------------------------------------------
     // Public API
     //------------------------------------------------------------------------
+    bool shouldAlwaysCopy(G&, const Models::Base::Var &var) const
+    {
+        return m_ShouldAlwaysCopy(var.name, getVarAccessDuplication(var.access));
+    }
+
     std::string getReadIndex(G&, const Models::Base::Var &var) const
     {
         return m_GetReadIndex(var.name, getVarAccessDuplication(var.access));
@@ -737,6 +746,7 @@ class VarCachePolicy
     //------------------------------------------------------------------------
     GetIndexFn m_GetReadIndex;
     GetIndexFn m_GetWriteIndex;
+    ShouldAlwaysCopyFn m_ShouldAlwaysCopy;
 };
 
 //------------------------------------------------------------------------
@@ -761,6 +771,13 @@ class VarRefCachePolicy
     //------------------------------------------------------------------------
     // Public API
     //------------------------------------------------------------------------
+    bool shouldAlwaysCopy(G&, const Models::Base::VarRef &var) const
+    {
+        // **NOTE** something else is managing the actual variables
+        // and is therefore responsible for copying between delay slots etc
+        return false;
+    }
+
     std::string getReadIndex(G &g, const Models::Base::VarRef &var) const
     {
         return m_GetReadIndex(var.name, A(g.getArchetype()).getInitialisers().at(var.name));
@@ -798,11 +815,11 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P
 public:
     template<typename... PolicyArgs>
     EnvironmentLocalCacheBase(G &group, F &fieldGroup, const Type::TypeContext &context, EnvironmentExternalBase &enclosing,
-                              const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix, bool alwaysCopy,
+                              const std::string &arrayPrefix, const std::string &fieldSuffix, const std::string &localPrefix,
                               PolicyArgs&&... policyArgs)
     :   EnvironmentExternalBase(enclosing), P(std::forward<PolicyArgs>(policyArgs)...), m_Group(group), m_FieldGroup(fieldGroup), m_Context(context), m_Contents(m_ContentsStream), m_ArrayPrefix(arrayPrefix),
-        m_FieldSuffix(fieldSuffix), m_LocalPrefix(localPrefix), m_AlwaysCopy(alwaysCopy)
+        m_FieldSuffix(fieldSuffix), m_LocalPrefix(localPrefix)
     {
         // Copy variables into variables referenced, alongside boolean
         const auto defs = A(m_Group.get().getArchetype()).getDefs();
@@ -820,7 +837,11 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P
         const auto varDefs = archetypeAdapter.getDefs();
         std::vector referencedDefs;
         std::copy_if(varDefs.cbegin(), varDefs.cend(), std::back_inserter(referencedDefs),
-                     [this](const auto &v){ return m_AlwaysCopy || m_VariablesReferenced.at(v.name).first; });
+                     [this](const auto &v)
+                     {
+                         const bool alwaysCopy = this->shouldAlwaysCopy(m_Group.get(), v);
+                         return (alwaysCopy || m_VariablesReferenced.at(v.name).first);
+                     });
 
         // Loop through referenced definitions
         for(const auto &v : referencedDefs) {
@@ -855,7 +876,7 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P
         // Loop through referenced definitions again
         for(const auto &v : referencedDefs) {
             // If we should always copy variable or variable is read-write
-            if(m_AlwaysCopy || v.access & VarAccessMode::READ_WRITE) {
+            if(this->shouldAlwaysCopy(m_Group.get(), v) || v.access & VarAccessMode::READ_WRITE) {
                 getContextStream() << "group->" << v.name << m_FieldSuffix << "[" << printSubs(this->getWriteIndex(m_Group.get(), v), *this) << "]";
                 getContextStream() << " = _" << m_LocalPrefix << v.name << ";" << std::endl;
             }
@@ -921,7 +942,6 @@ class EnvironmentLocalCacheBase : public EnvironmentExternalBase, public P
     std::string m_ArrayPrefix;
     std::string m_FieldSuffix;
    std::string m_LocalPrefix;
-    bool m_AlwaysCopy;
 
     std::unordered_map> m_VariablesReferenced;
 };

diff --git a/src/genn/genn/code_generator/customUpdateGroupMerged.cc b/src/genn/genn/code_generator/customUpdateGroupMerged.cc
index e734d53866..901d8be7aa 100644
---
a/src/genn/genn/code_generator/customUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/customUpdateGroupMerged.cc @@ -59,7 +59,7 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( - *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", false, + *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", [this, &cuEnv](const std::string&, VarAccessDuplication d) { return getVarIndex(d, "$(id)"); @@ -67,7 +67,7 @@ void CustomUpdateGroupMerged::generateCustomUpdate(const BackendBase &backend, E // Create an environment which caches variable references in local variables if they are accessed EnvironmentLocalVarRefCache varRefEnv( - *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", false, + *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", [this, &varEnv](const std::string&, const Models::VarReference &v) { return getVarRefIndex(v.getDelayNeuronGroup() != nullptr, @@ -184,7 +184,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( - *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", false, + *this, *this, getTypeContext(), cuEnv, backend.getDeviceVarPrefix(), "", "l", [this, &cuEnv](const std::string&, VarAccessDuplication d) { return getVarIndex(d, "$(id_syn)"); @@ -192,7 +192,7 @@ void CustomUpdateWUGroupMergedBase::generateCustomUpdate(const BackendBase &back // Create an environment which caches variable references in local variables if they are accessed EnvironmentLocalVarRefCache varRefEnv( - *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", false, + *this, *this, getTypeContext(), varEnv, backend.getDeviceVarPrefix(), "", "l", [this, &varEnv](const std::string&, const Models::WUVarReference &v) { return getVarRefIndex(getVarAccessDuplication(v.getVar().access), diff --git a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc index 1bc37254e2..a871b43515 100644 --- a/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc +++ b/src/genn/genn/code_generator/neuronUpdateGroupMerged.cc @@ -39,7 +39,7 @@ void NeuronUpdateGroupMerged::CurrentSource::generate(const BackendBase &backend // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( - *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", false, + *this, ng, getTypeContext(), csEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [batchSize, &ng](const std::string&, VarAccessDuplication d) { return ng.getVarIndex(batchSize, d, "$(id)"); @@ -120,7 +120,7 @@ void NeuronUpdateGroupMerged::InSynPSM::generate(const BackendBase &backend, Env // Create an environment which caches variables in local variables if they are accessed EnvironmentLocalVarCache varEnv( - *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", false, + *this, ng, getTypeContext(), psmEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [batchSize, &ng](const std::string&, VarAccessDuplication d) { return ng.getVarIndex(batchSize, d, "$(id)"); @@ -198,10 +198,10 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const 
BackendBase &back synEnv.addExtraGlobalParams(wum->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", fieldSuffix); // Create an environment which caches variables in local variables if they are accessed - // **NOTE** always copy variables here as this is when they are copied between delay slots + // **NOTE** always copy variables if synapse group is delayed const bool delayed = (getArchetype().getBackPropDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varEnv( - *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", true, + *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) { return ng.getReadVarIndex(delayed, batchSize, d, "$(id)"); @@ -209,6 +209,10 @@ void NeuronUpdateGroupMerged::InSynWUMPostCode::generate(const BackendBase &back [batchSize, delayed, &synEnv, &ng](const std::string&, VarAccessDuplication d) { return ng.getWriteVarIndex(delayed, batchSize, d, "$(id)"); + }, + [delayed](const std::string&, VarAccessDuplication) + { + return delayed; }); /*neuronSubstitutionsInSynapticCode(varEnv, &ng.getArchetype(), "", "_post", "", "", "", dynamicsNotSpike, @@ -285,10 +289,10 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back synEnv.addExtraGlobalParams(wum->getExtraGlobalParams(), backend.getDeviceVarPrefix(), "", fieldSuffix); // Create an environment which caches variables in local variables if they are accessed - // **NOTE** always copy variables here as this is when they are copied between delay slots + // **NOTE** always copy variables if synapse group is delayed const bool delayed = (getArchetype().getDelaySteps() != NO_DELAY); EnvironmentLocalVarCache varEnv( - *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", true, + *this, ng, getTypeContext(), synEnv, backend.getDeviceVarPrefix(), fieldSuffix, "l", [batchSize, delayed, &ng](const std::string&, VarAccessDuplication d) { return ng.getReadVarIndex(delayed, batchSize, d, "$(id)"); @@ -296,6 +300,10 @@ void NeuronUpdateGroupMerged::OutSynWUMPreCode::generate(const BackendBase &back [batchSize, delayed, &ng](const std::string&, VarAccessDuplication d) { return ng.getWriteVarIndex(delayed, batchSize, d, "$(id)"); + }, + [delayed](const std::string&, VarAccessDuplication) + { + return delayed; }); /*neuronSubstitutionsInSynapticCode(subs, &ng.getArchetype(), "", "_pre", "", "", "", dynamicsNotSpike, @@ -498,9 +506,9 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E // Create an environment which caches variables in local variables if they are accessed // **NOTE** we do this right at the top so that local copies can be used by child groups - // **NOTE** always copy variables here as this is when they are copied between delay slots + // **NOTE** always copy variables if variable is delayed EnvironmentLocalVarCache neuronVarEnv( - *this, *this, getTypeContext(), neuronEnv, backend.getDeviceVarPrefix(), "", "l", true, + *this, *this, getTypeContext(), neuronEnv, backend.getDeviceVarPrefix(), "", "l", [batchSize, &neuronEnv, this](const std::string &varName, VarAccessDuplication d) { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); @@ -510,6 +518,10 @@ void NeuronUpdateGroupMerged::generateNeuronUpdate(const BackendBase &backend, E { const bool delayed = (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); 
return getWriteVarIndex(delayed, batchSize, d, "$(id)") ; + }, + [this](const std::string &varName, VarAccessDuplication) + { + return (getArchetype().isVarQueueRequired(varName) && getArchetype().isDelayRequired()); }); From 3870b19a81c7930cf607a66c5b4bd32b425302b8 Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Thu, 17 Aug 2023 10:14:03 +0100 Subject: [PATCH 724/725] fixed failing test --- tests/unit/synapseGroup.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/synapseGroup.cc b/tests/unit/synapseGroup.cc index c09845904f..1be3cdd278 100644 --- a/tests/unit/synapseGroup.cc +++ b/tests/unit/synapseGroup.cc @@ -687,7 +687,7 @@ TEST(SynapseGroup, InitCompareWUDifferentVars) // Check that only synaptic weight initialistion parameters are heterogeneous ASSERT_FALSE(modelSpecMerged.getMergedSynapseConnectivityInitGroups().at(0).isSparseConnectivityInitParamHeterogeneous("prob")); - ASSERT_FALSE(modelSpecMerged.getMergedSynapseConnectivityInitGroups().at(0).isSparseConnectivityInitDerivedParamHeterogeneous("prob")); + ASSERT_FALSE(modelSpecMerged.getMergedSynapseConnectivityInitGroups().at(0).isSparseConnectivityInitDerivedParamHeterogeneous("probLogRecip")); ASSERT_TRUE(modelSpecMerged.getMergedSynapseSparseInitGroups().at(0).isVarInitParamHeterogeneous("g", "constant")); } From ca12c317e377353f7f36456ef8dae4ecad7d805c Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 1 Sep 2023 15:00:32 +0100 Subject: [PATCH 725/725] fix bug in environment # Conflicts: # include/genn/genn/code_generator/environment.h --- include/genn/genn/code_generator/environment.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/include/genn/genn/code_generator/environment.h b/include/genn/genn/code_generator/environment.h index 88c3e82164..3b380ed126 100644 --- a/include/genn/genn/code_generator/environment.h +++ b/include/genn/genn/code_generator/environment.h @@ -722,7 +722,12 @@ class VarCachePolicy //------------------------------------------------------------------------ bool shouldAlwaysCopy(G&, const Models::Base::Var &var) const { - return m_ShouldAlwaysCopy(var.name, getVarAccessDuplication(var.access)); + if(m_ShouldAlwaysCopy) { + return m_ShouldAlwaysCopy(var.name, getVarAccessDuplication(var.access)); + } + else { + return false; + } } std::string getReadIndex(G&, const Models::Base::Var &var) const
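PATCH 725's guard is required because both VarCachePolicy constructors default the callback to an empty ShouldAlwaysCopyFn, and invoking an empty std::function throws std::bad_function_call. Below is a self-contained illustration of the failure mode and the fix; this is illustrative code, not taken from GeNN:

#include <functional>
#include <iostream>
#include <string>

int main()
{
    // Default-constructed, i.e. empty - this mirrors the ShouldAlwaysCopyFn()
    // default argument in the VarCachePolicy constructors
    std::function<bool(const std::string&)> shouldAlwaysCopy;

    // Unguarded call - the bug: calling an empty std::function throws
    try {
        shouldAlwaysCopy("V");
    }
    catch(const std::bad_function_call&) {
        std::cout << "calling an empty std::function throws" << std::endl;
    }

    // Guarded call - the fix: an empty std::function converts to false,
    // so the callback is only invoked when one was actually supplied
    const bool copy = shouldAlwaysCopy ? shouldAlwaysCopy("V") : false;
    std::cout << "guarded result: " << copy << std::endl;
    return 0;
}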