diff --git a/NONCOMMERCIAL.txt b/NONCOMMERCIAL.txt new file mode 100644 index 0000000..86498a0 --- /dev/null +++ b/NONCOMMERCIAL.txt @@ -0,0 +1,146 @@ +NON-COMMERCIAL SOFTWARE LICENSE FOR THE GELIB SOFTWARE LIBRARY + +Copyright (c) 2021- Imre (Risi) Kondor. All rights reserved. + + +DEFINITIONS + +"Program" means a copy of the GELIB software library or parts of the GELIB software library explicitly +marked in the source code as distributed under this Noncommercial Software License. + +"Copyright holder" means the author of GELIB, Imre Kondor, who retains the copyright to Program. + +"Work based on the Program" means either the Program or any derivative work under copyright law: that is +to say, a work containing the Program or a portion of it, either verbatim or with modifications and/or +translated into another language. (Hereinafter, translation is included without limitation in the term +"modification".) + +"Using the Program" means any act of creating executables that contain or directly use libraries that +are part of the Program, running any part of the Program or any tools that are part of the Program, or +creating works based on the Program. + +Each licensee is addressed as "you". + + +TERMS AND CONDITIONS FOR USE, COPYING, DISTRIBUTION AND MODIFICATION + +1. This License grants you permission to use the Program free of charge for any noncommercial purpose, +including teaching and research at universities, colleges and other educational institutions, research +at non-profit research institutions, and personal non-profit purposes. + +2. 
This License does NOT grant permission to use the Program for commercial purposes, including but not +restricted to (a) bundling or integrating the Program with any hardware product or any other software for +transfer, sale or license to a third party (even if distributing the Program on separate media and not +charging for the Program); (b) providing customers with a link to the Program or a copy of the Program +for use with hardware or another program purchased by that customer; or (c) use in connection with the +performance of services for which you are compensated (d) use in connection with research and development +activities in the service of developing commercial products or obtaining patents for derived products +such as pharmaceuticals; (e) other forms of indirect commercial use, such as on a website that accepts +advertising money for content. + +3. You may copy and distribute verbatim copies of the Program's source code as you receive it, in any +medium, provided that you retain the copyright notice on each file of the source code and conspicuously +and appropriately include a copy of this License and Disclaimer of Warranty with the Program in a file +named LICENSE.TXT. + +4. You may modify your copy or copies of the Program or any portion of it, thus forming a work based on +the Program, and copy and distribute such modifications or work under the terms of Section 2 above, +provided that: + +a) You cause the modified files to carry prominent notices stating that you changed the files and the +date of any change. + +b) You cause any work that you distribute or publish, that in whole or in part contains or is derived +from the Program or any part thereof, to be licensed as a whole at no charge to all third parties under +the terms of this License. + +c) You retain the original copyright notice on each file of this Program's source code and conspicuously +include a copy of this License and Disclaimer of Warranty under the terms described in Section 3. 
+ +These requirements apply to the modified work as a whole. If identifiable sections of that work are not +derived from the Program, and can be reasonably considered independent and separate works in themselves, +then this License, and its terms, do not apply to those sections when you distribute them as separate +works. But when you distribute the same sections as part of a whole which is a work based on the Program, +the distribution of the whole must be on the terms of this License, whose regulations for other licensees +extend to the entire whole, and thus to each and every part regardless of who wrote it. (If the same, +independent sections are distributed as part of a package that is otherwise reliant on, or is based on +the Program, then the distribution of the whole package, including but not restricted to the independent +section, must be on the unmodified terms of this License, regardless of who the author of the included +sections was.) + +Thus, it is not the intent of this section to claim rights or contest your rights to work written entirely +by you; rather, the intent is to exercise the right to control the distribution of derivative or collective +works based or reliant on the Program. + +In addition, mere aggregation of another work not based on the Program with the Program (or with a work +based on the Program) on a volume of storage or distribution medium does not bring the other work under +the scope of this License. + +5. You may copy and distribute the Program (or a work based on it, under Section 3) in object code or +executable form under the terms of Sections 3 and 4 above provided that you also accompany it with the +complete corresponding machine-readable source code under the terms of Sections 3 and 4, as well as the +License and Disclaimer of Warranty, under the terms of Section 3. 
+ +If distribution of executable or object code is made by offering access to copy from a designated place, +then offering equivalent access to copy the source code from the same place counts as distribution of the +source code, even though third parties are not compelled to copy the source along with the object code. + +6. You may not copy, modify, sublicense, or distribute the Program except as expressly provided under this +License. Any attempt otherwise to copy, modify, sublicense or distribute the Program is void, and will +automatically terminate your rights under this License. However, parties who have received copies, or rights, +from you under this License will not have their licenses terminated so long as such parties remain in full +compliance. + +7. You are not required to accept this License, since you have not signed it. Nothing else grants you +permission to modify or distribute the Program or its derivative works; law prohibits these actions if you +do not accept this License. Therefore, by modifying or distributing the Program (or any work based on the +Program), you indicate your acceptance of this License and all its terms and conditions for copying, +distributing or modifying the Program or works based on it, to do so. + +8. Each time you redistribute the Program (or any work based on the Program), the recipient automatically +receives a license from the original licensor to copy, distribute or modify the Program subject to these +terms and conditions. You may not impose any further restrictions on the recipients to exercise the +rights granted herein. You are not responsible for enforcing compliance by third parties to this License. + +9. 
If, as a consequence of a court judgment or allegation of patent infringement or for any other reason +(not limited to patent issues), conditions are imposed on you (whether by court order, agreement or otherwise) +that contradict the conditions of this License, they do not excuse you from the conditions of this License. +If you cannot distribute so as to satisfy simultaneously your obligations under this License and any other +pertinent obligations, then as a consequence you may not distribute the Program at all. For example, if a +patent license would not permit royalty-free redistribution of the Program by all those who receive copies +directly or indirectly through you, then the only way you could satisfy both it and this License would be to +refrain entirely from distribution of the Program. If any portion of this section is held invalid or +unenforceable under any particular circumstance, the balance of the section is intended to apply and the +section as a whole is intended to apply in other circumstances. + +10. If the distribution and/or use of the Program are restricted in certain countries either by patents or +by copyrighted interfaces, the original copyright holder who places the Program under this License may add an +explicit geographical distribution limitation excluding those countries, so that distribution is permitted +only in or among countries not thus excluded. In such case, this License incorporates the limitation as if +written in the body of this License. + +11. Copyright holder retains the right to grant broader rights to the Program to individuals or to commercial +entities on a case by case basis, possibly for a fee. + + +DISCLAIMER OF WARRANTY + +12. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +INCLUDING, BUT NOT LIMITED TO, ANY IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE +PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + +13. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED ON IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY +OTHER PARTY WHO MAY MODIFY AND/OR REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, +INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO +USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +OR PROFITS; BUSINESS INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + + + + + + + + diff --git a/common.txt b/common.txt index 8d003bb..22313c3 100644 --- a/common.txt +++ b/common.txt @@ -29,7 +29,7 @@ SO2DIR=$(ROOTDIR)/objects/SO2 SO3DIR=$(ROOTDIR)/objects/SO3 SO3NDIR=$(ROOTDIR)/objects/SO3n SO3CDIR=$(ROOTDIR)/objects/SO3c -GELIB_CUDADIR=$(ROOTDIR)/../GElib-cuda/cuda +GELIB_CUDADIR=$(ROOTDIR)/cuda # COMBINATORIALDIR=$(ROOTDIR)/objects/combinatorial # GROUPSDIR=$(ROOTDIR)/objects/groups diff --git a/cuda/GElib_base.cu b/cuda/GElib_base.cu new file mode 100644 index 0000000..206f998 --- /dev/null +++ b/cuda/GElib_base.cu @@ -0,0 +1,19 @@ +/* + * This file is part of GElib, a C++/CUDA library for group equivariant + * tensor operations. + * + * Copyright (c) 2023, Imre Risi Kondor + * + * This source code file is subject to the terms of the noncommercial + * license distributed with GElib in the file NONCOMMERICAL.TXT. Commercial + * use is prohibited. All redistributed versions of this file (in orginal + * or modified form) must retain this copyright notice and must be + * accompanied by a verbatim copy of the license. 
+ * + */ + +#include +#include +#include "GElib_base.hpp" + +__device__ __constant__ unsigned char cg_cmem[CNINE_CONST_MEM_SIZE]; diff --git a/cuda/Generate_SO3part_addCGproduct_kernel_calls.cpp b/cuda/Generate_SO3part_addCGproduct_kernel_calls.cpp new file mode 100644 index 0000000..b60f157 --- /dev/null +++ b/cuda/Generate_SO3part_addCGproduct_kernel_calls.cpp @@ -0,0 +1,54 @@ +/* + * This file is part of GElib, a C++/CUDA library for group + * equivariant tensor operations. + * + * Copyright (c) 2023, Imre Risi Kondor + * + * This source code file is subject to the terms of the noncommercial + * license distributed with GElib in the file NONCOMMERICAL.TXT. Commercial + * use is prohibited. All redistributed versions of this file (in + * original or modified form) must retain this copyright notice and + * must be accompanied by a verbatim copy of the license. + * + */ + +#include "GElib_base.cpp" +#include "GElibSession.hpp" +#include + +using namespace cnine; +using namespace GElib; + +const int maxl1=2; +const int maxl=4; + + +int main(int argc, char** arg){ + + ofstream ofs("SO3part_addCGproduct_explicit_calls.inc"); + + ofs<<" switch(l1){\n"; + for(int l1=0; l1<=maxl1; l1++){ + ofs<<" case "<" + <<"<<>>(r,x,y); break;"< + +extern GElib::SO3_CGbank SO3_cgbank; + +using namespace cnine; +using namespace GElib; + +const int maxl1=2; +const int maxl=4; + + +int main(int argc, char** arg){ + + ofstream ofs("SO3part_addCGproduct_subkernels.inc"); + + for(int l1=0; l1<=maxl1; l1++){ + + for(int l2=0; l2<=maxl1; l2++){ + + for(int l=std::abs(l1-l2); l<=l1+l2 && l<=maxl; l++){ + auto& C=SO3_cgbank.getf(CGindex(l1,l2,l)); + + ofs<<"__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_"< +#include +#include "GElib_base.hpp" + +__device__ __constant__ unsigned char cg_cmem[CNINE_CONST_MEM_SIZE]; +#define _SO3CG_CUDA_CONCAT + +//#include "SO3partA_CGproduct.cu" +//#include "SO3partA_DiagCGproduct.cu" + +#include "SO3partB_addCGproduct.cu" +#include 
"SO3partB_addCGproduct_back0.cu" +#include "SO3partB_addCGproduct_back1.cu" + +#include "SO3partB_addDiagCGproduct.cu" +#include "SO3partB_addDiagCGproduct_back0.cu" +#include "SO3partB_addDiagCGproduct_back1.cu" + +#include "SO3Fpart_addFproduct.cu" +#include "SO3Fpart_addFproduct_back0.cu" +#include "SO3Fpart_addFproduct_back1.cu" + +#include "SO3part_addCGtransform.cu" + diff --git a/cuda/SO3Fpart_addFproduct.cu b/cuda/SO3Fpart_addFproduct.cu new file mode 100644 index 0000000..66c24fe --- /dev/null +++ b/cuda/SO3Fpart_addFproduct.cu @@ -0,0 +1,316 @@ +/* + * This file is part of GElib, a C++/CUDA library for group equivariant + * tensor operations. + * + * Copyright (c) 2023, Imre Risi Kondor + * + * This source code file is subject to the terms of the noncommercial + * license distributed with GElib in the file NONCOMMERICAL.TXT. Commercial + * use is prohibited. All redistributed versions of this file (in orginal + * or modified form) must retain this copyright notice and must be + * accompanied by a verbatim copy of the license. 
+ * + */ + +#ifndef _SO3Fpart_addFproduct_cu +#define _SO3Fpart_addFproduct_cu + +#include +#include +//#include +//#include + +#include "SO3_CGbank.hpp" +#include "Ctensor2_view.hpp" +#include "Ctensor3_view.hpp" + +//__device__ __constant__ unsigned char cg_cmem[32276]; + +extern GElib::SO3_CGbank SO3_cgbank; + + + + +__device__ int loadg3(const cnine::Ctensor3_view& x, float* dest, const int b, const int t){ + int I=x.n1; + int J=x.n2; + int s1=x.s1; + int s2=x.s2; + int offs=I*J; //((I*J-1)/32+1)*32; + float* destc=dest+offs; + float* source=x.arr+x.s0*b; + float* sourcec=x.arrc+x.s0*b; + if(t(cg_cmem)+Cptr; + const int b=blockIdx.x; + const int t=threadIdx.x; + + int l1=(x.n1-1)/2; + int l2=(y.n1-1)/2; + int l=(r.n1-1)/2; + int xn=x.n2; + int yn=y.n2; + int rn=r.n2; + + float* cptr; + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + + float* xpr=reinterpret_cast(_shared); + float* xpi=xpr+x.n1*x.n2; + loadg3(x,xpr,b,t); + + float* ypr=xpr+((2*xn*xn-1)/32+1)*32; + float* ypi=ypr+y.n1*y.n2; + if(conj==0) loadg3(y,ypr,b,t); + else loadg3c(y,ypr,b,t); + + float* rpr=ypr+((2*yn*yn-1)/32+1)*32; + float* rpi=rpr+r.n1*r.n2; + loadg3(r,rpr,b,t); + + __syncthreads(); + + if(t=0 && il2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=cptr[(m1+l1)*yn+m2+l2]; + const float y_r=ypr[yn*(m2+l2)]; + const float y_i=ypi[yn*(m2+l2)]; + //_rpr[rn*(m1+m2+l)]+=c0*c*(x_r*y_r-x_i*y_i); + //_rpi[rn*(m1+m2+l)]+=c0*c*(x_r*y_i+x_i*y_r); + atomicAdd(_rpr+rn*(m1+m2+l),c0*c*(x_r*y_r-x_i*y_i)); + atomicAdd(_rpi+rn*(m1+m2+l),c0*c*(x_r*y_i+x_i*y_r)); + } + + } + } + } + + __syncthreads(); + + saveg3(r,rpr,b,t); + +} + + +__global__ void SO3Fpart_addFproduct_large_kernel(const cnine::Ctensor3_view r, const cnine::Ctensor3_view x, + const cnine::Ctensor3_view y, const int Cptr, float* cptr_global, const int conj){ + + extern __shared__ unsigned char _shared[]; + //const float* C_ptr=reinterpret_cast(cg_cmem)+Cptr; + const int b=blockIdx.x; + const int 
t=threadIdx.x; + + int l1=(x.n1-1)/2; + int l2=(y.n1-1)/2; + int l=(r.n1-1)/2; + int xn=x.n2; + int yn=y.n2; + int rn=r.n2; + + float* cptr; + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + + float* xpr=reinterpret_cast(_shared); + float* xpi=xpr+x.n1*x.n2; + loadg3(x,xpr,b,t); + + float* ypr=xpr+((2*xn*xn-1)/32+1)*32; + float* ypi=ypr+y.n1*y.n2; + if(conj==0) loadg3(y,ypr,b,t); + else loadg3c(y,ypr,b,t); + + float* rpr=ypr+((2*yn*yn-1)/32+1)*32; + float* rpi=rpr+r.n1*r.n2; + loadg3(r,rpr,b,t); + + //int tn=xn*yn; + //float* tpr=rpr+((2*rn*rn-1)/32+1)*32; + //float* tpi=tpr+tn*rn; + + __syncthreads(); + + if(tl2) upper=l2; + + for(int m2=lower; m2<=upper; m2++){ + float c=cptr[(m1+l1)*yn+m2+l2]; + const float y_r=_ypr[yn*(m2)]; + const float y_i=_ypi[yn*(m2)]; + _rpr[rn*(m1+m2)]+=c0*c*(x_r*y_r-x_i*y_i); + _rpi[rn*(m1+m2)]+=c0*c*(x_r*y_i+x_i*y_r); + //atomicAdd(_rpr+rn*(m1+m2+l),c0*c*(x_r*y_r-x_i*y_i)); + //atomicAdd(_rpi+rn*(m1+m2+l),c0*c*(x_r*y_i+x_i*y_r)); + } + + } + + } + } + + __syncthreads(); + + saveg3(r,rpr,b,t); + +} + + +namespace GElib{ + + + void SO3Fpart_addFproduct_cu(const cnine::Ctensor3_view& r, const cnine::Ctensor3_view& x, const cnine::Ctensor3_view& y, + const int conj, const int method, const cudaStream_t& stream){ + + const int xl=(x.n1-1)/2; + const int yl=(y.n1-1)/2; + const int l=(r.n1-1)/2; + const int b=r.n0; + + float* cptr=nullptr; + int Cptr=SO3_cgbank.getfC(xl,yl,l)/4; + if(Cptr<0) cptr=SO3_cgbank.getf(CGindex(xl,yl,l),r.dev).arrg; + int clines=cnine::roundup(x.n1*y.n1,32)/32; + + int nlines=cnine::roundup(x.n1*x.n2*2,32)/32+ + cnine::roundup(y.n1*y.n2*2,32)/32+ + cnine::roundup(r.n1*r.n2*2,32)/32; + + if(nlines<=384){ + + if(method==0){ + + SO3Fpart_addFproduct_kernel<<>> + (r,x,y,Cptr,cptr,conj); + return; + + }else{ + + SO3Fpart_addFproduct_large_kernel<<>> + (r,x,y,Cptr,cptr,conj); + return; + + } + } + + cout<<"error"< +#include +//#include +//#include + +#include "SO3_CGbank.hpp" +#include 
"Ctensor3_view.hpp" + + +extern GElib::SO3_CGbank SO3_cgbank; + + +__device__ int loadg4(const cnine::Ctensor3_view& x, float* dest, const int b, const int t){ + int I=x.n1; + int J=x.n2; + int s1=x.s1; + int s2=x.s2; + int offs=I*J; //((I*J-1)/32+1)*32; + float* destc=dest+offs; + float* source=x.arr+x.s0*b; + float* sourcec=x.arrc+x.s0*b; + if(t=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + + float* xpr=reinterpret_cast(_shared); + float* xpi=xpr+loadg4(x,xpr,b,t); + + float* ypr=xpr+((2*xn*xn-1)/32+1)*32; + float* ypi; + if(conj==0) ypi=ypr+loadg4(y,ypr,b,t); + else ypi=ypr+loadg4c(y,ypr,b,t); + + float* rpr=ypr+((2*yn*yn-1)/32+1)*32; + float* rpi=rpr+loadg4(r,rpr,b,t); + + __syncthreads(); + + if(t=0 && il2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=cptr[(m1+l1)*yn+m2+l2]; + const float y_r=ypr[yn*(m2+l2)]; + const float y_i=ypi[yn*(m2+l2)]; + const float g_r=_rpr[rn*(m1+m2+l)]; + const float g_i=_rpi[rn*(m1+m2+l)]; + //_xpr[xn*(m1+l1)]+=c0*c*(g_r*y_r+g_i*y_i); + //_xpi[xn*(m1+l1)]+=c0*c*(-g_r*y_i+g_i*y_r); + atomicAdd(_xpr+xn*(m1+l1),c0*c*(g_r*y_r+g_i*y_i)); + atomicAdd(_xpi+xn*(m1+l1),c0*c*(-g_r*y_i+g_i*y_r)); + } + + } + } + } + + __syncthreads(); + + saveg4(x,xpr,b,t); + +} + + +__global__ void SO3Fpart_addFproduct_back0_large_kernel(const cnine::Ctensor3_view r, const cnine::Ctensor3_view x, + const cnine::Ctensor3_view y, const int Cptr, float* cptr_global, const int conj){ + + extern __shared__ unsigned char _shared[]; + const int b=blockIdx.x; + const int t=threadIdx.x; + + int l1=(x.n1-1)/2; + int l2=(y.n1-1)/2; + int l=(r.n1-1)/2; + int xn=x.n2; + int yn=y.n2; + int rn=r.n2; + + float* cptr; + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + + float* xpr=reinterpret_cast(_shared); + float* xpi=xpr+loadg4(x,xpr,b,t); + + float* ypr=xpr+((2*xn*xn-1)/32+1)*32; + float* ypi; + if(conj==0) ypi=ypr+loadg4(y,ypr,b,t); + else ypi=ypr+loadg4c(y,ypr,b,t); + + float* rpr=ypr+((2*yn*yn-1)/32+1)*32; + 
float* rpi=rpr+loadg4(r,rpr,b,t); + + __syncthreads(); + + if(t=0 && il2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=cptr[(m1+l1)*yn+m2+l2]; + const float y_r=ypr[yn*(m2+l2)]; + const float y_i=ypi[yn*(m2+l2)]; + const float g_r=_rpr[rn*(m1+m2+l)]; + const float g_i=_rpi[rn*(m1+m2+l)]; + _xpr[xn*(m1+l1)]+=c0*c*(g_r*y_r+g_i*y_i); + _xpi[xn*(m1+l1)]+=c0*c*(-g_r*y_i+g_i*y_r); + //atomicAdd(_xpr+xn*(m1+l1),c0*c*(g_r*y_r+g_i*y_i)); + //atomicAdd(_xpi+xn*(m1+l1),c0*c*(-g_r*y_i+g_i*y_r)); + } + } + + } + } + } + + __syncthreads(); + + saveg4(x,xpr,b,t); + +} + + +namespace GElib{ + + + void SO3Fpart_addFproduct_back0_cu(const cnine::Ctensor3_view& x, const cnine::Ctensor3_view& r, const cnine::Ctensor3_view& y, + const int conj, const int method, const cudaStream_t& stream){ + + const int xl=(x.n1-1)/2; + const int yl=(y.n1-1)/2; + const int l=(r.n1-1)/2; + + const int b=r.n0; + assert(x.n0==b); + assert(y.n0==b); + + float* cptr=nullptr; + int Cptr=SO3_cgbank.getfC(xl,yl,l)/4; + if(Cptr<0) cptr=SO3_cgbank.getf(CGindex(xl,yl,l),r.dev).arrg; + int clines=cnine::roundup(x.n1*y.n1,32)/32; + + int nlines=cnine::roundup(x.n1*x.n2*2,32)/32+ + cnine::roundup(y.n1*y.n2*2,32)/32+ + cnine::roundup(r.n1*r.n2*2,32)/32; + + + if(nlines<=384){ + + if(method==0){ + + SO3Fpart_addFproduct_back0_kernel<<>> + (r,x,y,Cptr,cptr,conj); + + }else{ + + SO3Fpart_addFproduct_back0_large_kernel<<>> + (r,x,y,Cptr,cptr,conj); + + } + + }else{ + cout<<"error"< +#include +//#include +//#include + +#include "SO3_CGbank.hpp" +#include "Ctensor3_view.hpp" + + +extern GElib::SO3_CGbank SO3_cgbank; + + +__device__ int loadg5(const cnine::Ctensor3_view& x, float* dest, const int b, const int t){ + int I=x.n1; + int J=x.n2; + int s1=x.s1; + int s2=x.s2; + int offs=I*J; //((I*J-1)/32+1)*32; + float* destc=dest+offs; + float* source=x.arr+x.s0*b; + float* sourcec=x.arrc+x.s0*b; + if(t=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + + float* xpr=reinterpret_cast(_shared); + 
float* xpi=xpr+loadg5(x,xpr,b,t); + + float* ypr=xpr+((2*xn*xn-1)/32+1)*32; + float* ypi; + if(conj==0) ypi=ypr+loadg5(y,ypr,b,t); + else ypi=ypr+loadg5c(y,ypr,b,t); + + float* rpr=ypr+((2*yn*yn-1)/32+1)*32; + float* rpi=rpr+loadg5(r,rpr,b,t); + + __syncthreads(); + + if(t=0 && il2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=cptr[(m1+l1)*yn+m2+l2]; + const float g_r=_rpr[rn*(m1+m2+l)]; + const float g_i=_rpi[rn*(m1+m2+l)]; + //_ypr[yn*(m2+l2)]+=c*(g_r*x_r+g_i*x_i); + //_ypi[yn*(m2+l2)]+=c*(-g_r*x_i+g_i*x_r); + atomicAdd(_ypr+yn*(m2+l2),c0*c*(g_r*x_r+g_i*x_i)); + atomicAdd(_ypi+yn*(m2+l2),c0*c*(-g_r*x_i+g_i*x_r)); + } + + } + } + } + + __syncthreads(); + + if(conj==0) saveg5(y,ypr,b,t); + else saveg5c(y,ypr,b,t); + +} + + +__global__ void SO3Fpart_addFproduct_back1_large_kernel(const cnine::Ctensor3_view r, const cnine::Ctensor3_view x, + const cnine::Ctensor3_view y, const int Cptr, float* cptr_global, const int conj){ + + extern __shared__ unsigned char _shared[]; + const int b=blockIdx.x; + const int t=threadIdx.x; + + int l1=(x.n1-1)/2; + int l2=(y.n1-1)/2; + int l=(r.n1-1)/2; + int xn=x.n2; + int yn=y.n2; + int rn=r.n2; + + float* cptr; + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + + float* xpr=reinterpret_cast(_shared); + float* xpi=xpr+loadg5(x,xpr,b,t); + + float* ypr=xpr+((2*xn*xn-1)/32+1)*32; + float* ypi; + if(conj==0) ypi=ypr+loadg5(y,ypr,b,t); + else ypi=ypr+loadg5c(y,ypr,b,t); + + float* rpr=ypr+((2*yn*yn-1)/32+1)*32; + float* rpi=rpr+loadg5(r,rpr,b,t); + + __syncthreads(); + + if(t=0 && il2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=cptr[(m1+l1)*yn+m2+l2]; + const float g_r=_rpr[rn*(m1+m2+l)]; + const float g_i=_rpi[rn*(m1+m2+l)]; + _ypr[yn*(m2+l2)]+=c*(g_r*x_r+g_i*x_i); + _ypi[yn*(m2+l2)]+=c*(-g_r*x_i+g_i*x_r); + //atomicAdd(_ypr+yn*(m2+l2),c0*c*(g_r*x_r+g_i*x_i)); + //atomicAdd(_ypi+yn*(m2+l2),c0*c*(-g_r*x_i+g_i*x_r)); + } + } + } + + } + } + + __syncthreads(); + + if(conj==0) 
saveg5(y,ypr,b,t); + else saveg5c(y,ypr,b,t); + +} + + + +namespace GElib{ + + + void SO3Fpart_addFproduct_back1_cu(const cnine::Ctensor3_view& y, const cnine::Ctensor3_view& r, const cnine::Ctensor3_view& x, + const int conj, const int method, const cudaStream_t& stream){ + + const int xl=(x.n1-1)/2; + const int yl=(y.n1-1)/2; + const int l=(r.n1-1)/2; + + const int b=r.n0; + assert(x.n0==b); + assert(y.n0==b); + + float* cptr=nullptr; + int Cptr=SO3_cgbank.getfC(xl,yl,l)/4; + if(Cptr<0) cptr=SO3_cgbank.getf(CGindex(xl,yl,l),r.dev).arrg; + int clines=cnine::roundup(x.n1*y.n1,32)/32; + + int nlines=cnine::roundup(x.n1*x.n2*2,32)/32+ + cnine::roundup(y.n1*y.n2*2,32)/32+ + cnine::roundup(r.n1*r.n2*2,32)/32; + + + if(nlines<=384){ + + if(method==0){ + + SO3Fpart_addFproduct_back1_kernel<<>> + (r,x,y,Cptr,cptr,conj); + + }else{ + + SO3Fpart_addFproduct_back1_large_kernel<<>> + (r,x,y,Cptr,cptr,conj); + + } + }else{ + cout<<"error"< +#include +#include +#include + +//__device__ __constant__ unsigned char cg_cmem[32276]; + + +#include "SO3partA.hpp" +#include "SO3partArrayA.hpp" +#include "SO3_CGbank.hpp" + +#include "CellwiseBinaryCmap.hpp" +#include "BroadcastBinaryCmap.hpp" +#include "InnerCmap.hpp" +#include "OuterCmap.hpp" +#include "MVprodCmap.hpp" +#include "VMprodCmap.hpp" +//#include "convolve1_cmap.hpp" +#include "Convolve2Cmap.hpp" + +extern GElib::SO3_CGbank SO3_cgbank; + + +__device__ void SO3part_load_lines(float* dest, const float* source, const int nlines, const int t){ + if(t<32){ + for(int i=0; i +__global__ void SO3partA_CGproduct_kernel(float* rarr, float* rarrc, float* xarr, float* xarrc, + float* yarr, float* yarrc, const int rstride, const int xstride, const int ystride, const IMAP cmap, + const int xn, const int yn, const int rn, const int l1, const int l2, const int l, + const int _offs, const int nch, const int Cptr, const int mode=0){ + + extern __shared__ unsigned char _shared[]; + float* shared=reinterpret_cast(_shared); + + const float* 
C_ptr=reinterpret_cast(cg_cmem)+Cptr; + const int t=threadIdx.x; + + const int r=2*l+1; + const int r1=2*l1+1; + const int r2=2*l2+1; + + const int xwidth=xn*nch; + const int ywidth=yn*nch; + const int rwidth=xn*yn*nch; + const int global_rwidth=rn*nch; + + const int rlines=((r*rwidth-1)/32+1); + const int xlines=((r1*xwidth-1)/32+1); + const int ylines=((r2*ywidth-1)/32+1); + + const int rptr=0; + const int xptr=rptr+rlines*64; + const int yptr=xptr+xlines*64; + + int rix,xix,yix; + int nsum; + int lst; + + + if(mode<2){ + auto T=cmap(blockIdx.x,blockIdx.y,blockIdx.z); + rix=thrust::get<0>(T); + xix=thrust::get<1>(T); + yix=thrust::get<2>(T); + nsum=1; + //if(t==0) printf("foop1\n"); + }else{ + rix=cmap.target(blockIdx.x); + nsum=cmap.n_accum(blockIdx.x); + lst=cmap.lst_ptr(blockIdx.x); + } + + if(mode==1){ + if(t<32){ + for(int i=0; i<2*rlines; i++) + shared[rptr+i*32+t]=0; + } + }else{ + if(t(T); + yix=thrust::get<1>(T); + } + + SO3part_load_lines(shared+xptr,xarr+xix*xstride,xlines,t); + SO3part_load_lines(shared+xptr+xlines*32,xarrc+xix*xstride,xlines,t); + SO3part_load_lines(shared+yptr,yarr+yix*ystride,ylines,t); + SO3part_load_lines(shared+yptr+ylines*32,yarrc+yix*ystride,ylines,t); + + //if(t==0) printf("foop3\n"); + + __syncthreads(); + + const int rpr=rptr+t; + const int rpi=rpr+rlines*32; + + const int xcol=t/yn; + const int xpr=xptr+xcol; + const int xpi=xpr+xlines*32; + + const int ycol=t%ywidth; + const int ypr=yptr+ycol; + const int ypi=ypr+ylines*32; + + + if(tl2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*r2+m2+l2]; + const float y_r=shared[ypr+ywidth*(m2+l2)]; + const float y_i=shared[ypi+ywidth*(m2+l2)]; + shared[rpr+rwidth*(m1+m2+l)]+=c*(x_r*y_r-x_i*y_i); + shared[rpi+rwidth*(m1+m2+l)]+=c*(x_r*y_i+x_i*y_r); + } + } + } + + //if(t==0) printf("foop4\n"); + + __syncthreads(); + } + + //if(t==0) printf("fooq\n"); + + if(t +__global__ void SO3partA_CGproduct_kernel_L(float* rarr, float* rarrc, float* xarr, float* xarrc, 
+ float* yarr, float* yarrc, const int rstride, const int xstride, const int ystride, const IMAP cmap, + const int xn, const int yn, const int rn, const int l1, const int l2, const int l, + const int _offs, const int nch, const int Cptr, const int mode=0){ + + extern __shared__ unsigned char _shared[]; + float* shared=reinterpret_cast(_shared); + + const float* C_ptr=reinterpret_cast(cg_cmem)+Cptr; + const int t=threadIdx.x; + + const int r=2*l+1; + const int r1=2*l1+1; + const int r2=2*l2+1; + + const int xwidth=xn*nch; + const int ywidth=yn*nch; + const int rwidth=xn*nch; + const int global_rwidth=rn*nch; + + const int rlines=((r*rwidth-1)/32+1); + const int xlines=((r1*xwidth-1)/32+1); + const int ylines=((r2*1-1)/32+1); + + const int rptr=0; + const int xptr=rptr+rlines*64; + const int yptr=xptr+xlines*64; + + int rix,xix,yix; + int nsum; + int lst; + + if(mode<2){ + auto T=cmap(blockIdx.x,blockIdx.y,blockIdx.z); + rix=thrust::get<0>(T); + xix=thrust::get<1>(T); + yix=thrust::get<2>(T); + nsum=1; + }else{ + rix=cmap.target(blockIdx.x); + nsum=cmap.n_accum(blockIdx.x); + lst=cmap.lst_ptr(blockIdx.x); + } + + + for(int s=0; s(T); + yix=thrust::get<1>(T); + } + + SO3part_load_lines(shared+xptr,xarr+xix*xstride,xlines,t); + SO3part_load_lines(shared+xptr+xlines*32,xarrc+xix*xstride,xlines,t); + + + for(int ycol=0; ycoll2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*r2+m2+l2]; + const float y_r=shared[ypr+1*(m2+l2)]; + const float y_i=shared[ypi+1*(m2+l2)]; + shared[rpr+rwidth*(m1+m2+l)]+=c*(x_r*y_r-x_i*y_i); + shared[rpi+rwidth*(m1+m2+l)]+=c*(x_r*y_i+x_i*y_r); + } + } + } + + //if(t==0) printf("foop4\n"); + + __syncthreads(); + + //if(t==0) printf("fooq\n"); + + if(t +__global__ void SO3partA_CGproduct_back0_kernel(float* xarr, float* xarrc, float* garr, float* garrc, + float* yarr, float* yarrc, const int xstride, const int ystride, const int gstride, const IMAP cmap, + const int xn, const int yn, const int gn, const int l1, const int 
l2, const int l, + const int _offs, const int nch, const int Cptr, const int mode=0){ + + extern __shared__ unsigned char _shared[]; + float* shared=reinterpret_cast(_shared); + + const float* C_ptr=reinterpret_cast(cg_cmem)+Cptr; + const int t=threadIdx.x; + + const int rg=2*l+1; + const int rx=2*l1+1; + const int ry=2*l2+1; + + const int xwidth=xn*nch; + const int ywidth=yn*nch; + const int gwidth=xn*yn*nch; + const int global_gwidth=gn*nch; + + const int glines=((rg*gwidth-1)/32+1); + const int xlines=((rx*xwidth-1)/32+1); + const int ylines=((ry*ywidth-1)/32+1); + + const int xptr=0; + const int gptr=xptr+xlines*64; + const int yptr=gptr+glines*64; + + int gix,xix,yix; + int nsum; + int lst; + + if(mode<2){ + auto T=cmap(blockIdx.x,blockIdx.y,blockIdx.z); + xix=thrust::get<0>(T); + gix=thrust::get<1>(T); + yix=thrust::get<2>(T); + nsum=1; + }else{ + xix=cmap.target(blockIdx.x); + nsum=cmap.n_accum(blockIdx.x); + lst=cmap.lst_ptr(blockIdx.x); + } + + if(mode==1){ + if(t<32){ + for(int i=0; i<2*xlines; i++){ + shared[xptr+i*32+t]=0; + } + } + }else{ + SO3part_load_lines(shared+xptr,xarr+xix*xstride,xlines,t); + SO3part_load_lines(shared+xptr+xlines*32,xarrc+xix*xstride,xlines,t); + } + + for(int s=0; s(T); + yix=thrust::get<1>(T); + } + + // hack: gwidth assumed to be <=32 + for(int i=0; il2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*ry+m2+l2]; + const float y_r=shared[ypr+ywidth*(m2+l2)]; + const float y_i=shared[ypi+ywidth*(m2+l2)]; + const float g_r=shared[gpr+gwidth*(m1+m2+l)]; + const float g_i=shared[gpi+gwidth*(m1+m2+l)]; + shared[xpr+xwidth*(m1+l1)]+=c*(g_r*y_r+g_i*y_i); + shared[xpi+xwidth*(m1+l1)]+=c*(-g_r*y_i+g_i*y_r); + } + } + } + __syncthreads(); + + } + + } + + SO3part_save_lines(shared+xptr,xarr+xix*xstride,xlines,t); + SO3part_save_lines(shared+xptr+xlines*32,xarrc+xix*xstride,xlines,t); + + __syncthreads(); + +} + + +template +__global__ void SO3partA_CGproduct_back0_kernel_big(float* xarr, float* xarrc, float* 
garr, float* garrc, + float* yarr, float* yarrc, const int xstride, const int ystride, const int gstride, const IMAP cmap, + const int xn, const int yn, const int gn, const int l1, const int l2, const int l, + const int _offs, const int nch, const int Cptr, const int mode=0){ + + extern __shared__ unsigned char _shared[]; + float* shared=reinterpret_cast(_shared); + + const float* C_ptr=reinterpret_cast(cg_cmem)+Cptr; + const int t=threadIdx.x; + + //const int rg=2*l+1; + const int rx=2*l1+1; + const int ry=2*l2+1; + + const int xwidth=xn*nch; + const int ywidth=yn*nch; + //const int gwidth=xn*yn*nch; + const int global_gwidth=gn*nch; + + //const int glines=((rg*gwidth-1)/32+1); + const int xlines=((rx*xwidth-1)/32+1); + const int ylines=((ry*ywidth-1)/32+1); + + const int xptr=0; + const int yptr=xptr+xlines*64; + //const int yptr=gptr+glines*64; + + int gix,xix,yix; + int nsum; + int lst; + + if(mode<2){ + auto T=cmap(blockIdx.x,blockIdx.y,blockIdx.z); + xix=thrust::get<0>(T); + gix=thrust::get<1>(T); + yix=thrust::get<2>(T); + nsum=1; + }else{ + xix=cmap.target(blockIdx.x); + nsum=cmap.n_accum(blockIdx.x); + lst=cmap.lst_ptr(blockIdx.x); + } + + if(mode==1){ + if(t<32){ + for(int i=0; i<2*xlines; i++){ + shared[xptr+i*32+t]=0; + } + } + }else{ + SO3part_load_lines(shared+xptr,xarr+xix*xstride,xlines,t); + SO3part_load_lines(shared+xptr+xlines*32,xarrc+xix*xstride,xlines,t); + } + + for(int s=0; s(T); + yix=thrust::get<1>(T); + } + + // hack: gwidth assumed to be <=32 + //for(int i=0; il2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*ry+m2+l2]; + const float y_r=shared[ypr+ywidth*(m2+l2)]; + const float y_i=shared[ypi+ywidth*(m2+l2)]; + //const float g_r=shared[gpr+gwidth*(m1+m2+l)]; + //const float g_i=shared[gpi+gwidth*(m1+m2+l)]; + const float g_r=garr[gix*gstride+_offs+ywidth*t+ycol+(m1+m2+l)*global_gwidth]; + const float g_i=garrc[gix*gstride+_offs+ywidth*t+ycol+(m1+m2+l)*global_gwidth]; + 
shared[xpr+xwidth*(m1+l1)]+=c*(g_r*y_r+g_i*y_i); + shared[xpi+xwidth*(m1+l1)]+=c*(-g_r*y_i+g_i*y_r); + } + } + } + __syncthreads(); + + } + + } + + SO3part_save_lines(shared+xptr,xarr+xix*xstride,xlines,t); + SO3part_save_lines(shared+xptr+xlines*32,xarrc+xix*xstride,xlines,t); + + __syncthreads(); + +} + + +// ---- back1 ------------------------------------------------------------------------------------------------ + + +template +__global__ void SO3partA_CGproduct_back1_kernel(float* yarr, float* yarrc, float* garr, float* garrc, + float* xarr, float* xarrc, const int xstride, const int ystride, const int gstride, const IMAP cmap, + const int xn, const int yn, const int gn, const int l1, const int l2, const int l, + const int _offs, const int nch, const int Cptr, const int mode=0){ + + extern __shared__ unsigned char _shared[]; + float* shared=reinterpret_cast(_shared); + + const float* C_ptr=reinterpret_cast(cg_cmem)+Cptr; + const int t=threadIdx.x; + + const int rg=2*l+1; + const int rx=2*l1+1; + const int ry=2*l2+1; + + const int xwidth=xn*nch; + const int ywidth=yn*nch; + const int gwidth=xn*yn*nch; + const int global_gwidth=gn*nch; + + const int glines=((rg*gwidth-1)/32+1); + const int xlines=((rx*xwidth-1)/32+1); + const int ylines=((ry*ywidth-1)/32+1); + + const int yptr=0; + const int gptr=yptr+ylines*64; + const int xptr=gptr+glines*64; + + int gix,xix,yix; + int nsum; + int lst; + + if(mode<2){ + auto T=cmap(blockIdx.x,blockIdx.y,blockIdx.z); + yix=thrust::get<0>(T); + gix=thrust::get<1>(T); + xix=thrust::get<2>(T); + nsum=1; + }else{ + yix=cmap.target(blockIdx.x); + nsum=cmap.n_accum(blockIdx.x); + lst=cmap.lst_ptr(blockIdx.x); + } + + if(mode==1){ + if(t<32){ + for(int i=0; i<2*ylines; i++) + shared[yptr+i*32+t]=0; + } + }else{ + SO3part_load_lines(shared+yptr,yarr+yix*ystride,ylines,t); + SO3part_load_lines(shared+yptr+ylines*32,yarrc+yix*ystride,ylines,t); + } + + for(int s=0; s(T); + xix=thrust::get<1>(T); + } + + // hack: gwidth assumed to be <=32 
+ for(int i=0; il2) upper=l2; + const float x_r=shared[xpr+xwidth*(m1+l1)]; + const float x_i=shared[xpi+xwidth*(m1+l1)]; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*ry+m2+l2]; + const float g_r=shared[gpr+gwidth*(m1+m2+l)]; + const float g_i=shared[gpi+gwidth*(m1+m2+l)]; + shared[ypr+ywidth*(m2+l2)]+=c*(g_r*x_r+g_i*x_i); + shared[ypi+ywidth*(m2+l2)]+=c*(-g_r*x_i+g_i*x_r); + } + } + } + __syncthreads(); + + } + + } + + SO3part_save_lines(shared+yptr,yarr+yix*ystride,ylines,t); + SO3part_save_lines(shared+yptr+ylines*32,yarrc+yix*ystride,ylines,t); + + __syncthreads(); + +} + + +template +__global__ void SO3partA_CGproduct_back1_kernel_big(float* yarr, float* yarrc, float* garr, float* garrc, + float* xarr, float* xarrc, const int xstride, const int ystride, const int gstride, const IMAP cmap, + const int xn, const int yn, const int gn, const int l1, const int l2, const int l, + const int _offs, const int nch, const int Cptr, const int mode=0){ + + extern __shared__ unsigned char _shared[]; + float* shared=reinterpret_cast(_shared); + + const float* C_ptr=reinterpret_cast(cg_cmem)+Cptr; + const int t=threadIdx.x; + + //const int rg=2*l+1; + const int rx=2*l1+1; + const int ry=2*l2+1; + + const int xwidth=xn*nch; + const int ywidth=yn*nch; + //const int gwidth=xn*yn*nch; + const int global_gwidth=gn*nch; + + //const int glines=((rg*gwidth-1)/32+1); + const int xlines=((rx*xwidth-1)/32+1); + const int ylines=((ry*ywidth-1)/32+1); + + const int yptr=0; + const int xptr=yptr+ylines*64; + //const int xptr=gptr+glines*64; + + int gix,xix,yix; + int nsum; + int lst; + + if(mode<2){ + auto T=cmap(blockIdx.x,blockIdx.y,blockIdx.z); + yix=thrust::get<0>(T); + gix=thrust::get<1>(T); + xix=thrust::get<2>(T); + nsum=1; + }else{ + yix=cmap.target(blockIdx.x); + nsum=cmap.n_accum(blockIdx.x); + lst=cmap.lst_ptr(blockIdx.x); + } + + if(mode==1){ + if(t<32){ + for(int i=0; i<2*ylines; i++) + shared[yptr+i*32+t]=0; + } + }else{ + 
SO3part_load_lines(shared+yptr,yarr+yix*ystride,ylines,t); + SO3part_load_lines(shared+yptr+ylines*32,yarrc+yix*ystride,ylines,t); + } + + for(int s=0; s(T); + xix=thrust::get<1>(T); + } + + // hack: gwidth assumed to be <=32 + //for(int i=0; il2) upper=l2; + const float x_r=shared[xpr+xwidth*(m1+l1)]; + const float x_i=shared[xpi+xwidth*(m1+l1)]; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*ry+m2+l2]; + //const float g_r=shared[gpr+gwidth*(m1+m2+l)]; + //const float g_i=shared[gpi+gwidth*(m1+m2+l)]; + const float g_r=garr[gix*gstride+_offs+ywidth*xcol+t+(m1+m2+l)*global_gwidth]; + const float g_i=garrc[gix*gstride+_offs+ywidth*xcol+t+(m1+m2+l)*global_gwidth]; + shared[ypr+ywidth*(m2+l2)]+=c*(g_r*x_r+g_i*x_i); + shared[ypi+ywidth*(m2+l2)]+=c*(-g_r*x_i+g_i*x_r); + } + } + } + __syncthreads(); + + } + + } + + SO3part_save_lines(shared+yptr,yarr+yix*ystride,ylines,t); + SO3part_save_lines(shared+yptr+ylines*32,yarrc+yix*ystride,ylines,t); + + __syncthreads(); + +} + + +// ----------------------------------------------------------------------------------------------------------- + + +namespace GElib{ + + + template + void SO3partA_CGproduct_cu(const CMAP& map, SO3partArrayA& r, const SO3partArrayA& x, + const SO3partArrayA& y, const cudaStream_t& stream, const int offs, const int mode){ + + const int xl=x.getl(); + const int yl=y.getl(); + const int l=r.getl(); + const int _nch=1; + assert(x.nbu==r.nbu); + assert(y.nbu==r.nbu); + int _nbu=1; if(_nbu<0) _nbu=1; + + int Cptr=SO3_cgbank.getfC(xl,yl,l)/4; + //int nlines=x.cellstride/16+y.cellstride/16+r.cellstride/16; // should be smaller than this! 
+ int nlines=x.cellstride/16+y.cellstride/16+cnine::roundup(x.getn()*y.getn()*_nch*(2*l+1),32)/16; + // nlines/=_nbu; + + cout<<"nlines="<>> + (r.arrg,r.arrgc,x.arrg,x.arrgc,y.arrg,y.arrgc, + r.cellstride,x.cellstride,y.cellstride,map, + x.getn(),y.getn(),r.getn(),xl,yl,l,offs,_nch,Cptr,mode); + + }else{ + + int nlines=x.cellstride/16+cnine::roundup(_nch*(2*yl+1),32)/16+cnine::roundup(x.getn()*_nch*(2*l+1),32)/16; + + cout<<"GElib: large CGproduct"<384){ + cout<<"GElib error: CGproduct too big for shared memory"<>> + (r.arrg,r.arrgc,x.arrg,x.arrgc,y.arrg,y.arrgc, + r.cellstride,x.cellstride,y.cellstride,map, + x.getn(),y.getn(),r.getn(),xl,yl,l,offs,_nch,Cptr,mode); + } + } + + } + + + void SO3partA_CGproduct_cu(SO3partA& r, const SO3partA& x, const SO3partA& y, const int offs, + const cudaStream_t& stream,const int mode){ + + const int xl=x.getl(); + const int yl=y.getl(); + const int l=r.getl(); + const int _nch=1; + assert(x.nbu==r.nbu); + assert(y.nbu==r.nbu); + int _nbu=1; if(_nbu<0) _nbu=1; + cnine::CellwiseBinaryCmap map; + + int Cptr=SO3_cgbank.getfC(xl,yl,l)/4; + int nlines=cnine::roundup(x.memsize,32)/32+cnine::roundup(y.memsize,32)/32+ + cnine::roundup(x.getn()*y.getn()*_nch*(2*l+1),32)/16; + + //cout<<"nlines="<>> + (r.arrg,r.arrgc,x.arrg,x.arrgc,y.arrg,y.arrgc, + 0,0,0,map, + x.getn(),y.getn(),r.getn(),xl,yl,l,offs,_nch,Cptr,mode); + + }else{ + + int nlines=cnine::roundup(x.memsize,32)/32+cnine::roundup(y.memsize,32)/32+ + cnine::roundup(x.getn()*_nch*(2*l+1),32)/16; + + cout<<"GElib: large CGproduct"<384){ + cout<<"GElib error: CGproduct too big for shared memory"<>> + (r.arrg,r.arrgc,x.arrg,x.arrgc,y.arrg,y.arrgc, + 0,0,0,map, + x.getn(),y.getn(),r.getn(),xl,yl,l,offs,_nch,Cptr,mode); + } + } + + } + + + template + void SO3partA_CGproduct_back0_cu(const CMAP& map, SO3partArrayA& x, const SO3partArrayA& g, + const SO3partArrayA& y, const cudaStream_t& stream, const int offs, const int mode){ + + const int xl=x.getl(); + const int yl=y.getl(); + const 
int l=g.getl(); + + int Cptr=SO3_cgbank.getfC(xl,yl,l)/4; + int nlines=x.cellstride/16+y.cellstride/16+g.cellstride/16; + assert(x.nbu==g.nbu); + assert(y.nbu==g.nbu); + + const int _nch=1; + int _nbu=1; if(_nbu<0) _nbu=1; + nlines/=_nbu; + + cout<<"nlines="<>> + (x.arrg,x.arrgc,g.arrg,g.arrgc,y.arrg,y.arrgc, + x.cellstride,y.cellstride,g.cellstride,map, + x.getn(),y.getn(),g.getn(),xl,yl,l,offs,_nch,Cptr,mode); + + }else{ + + int nlines=x.cellstride/16+y.cellstride/16; + + cout<<"GElib: large CGproduct_back0"<384){ + cout<<"GElib error: CGproduct too big for shared memory"<>> + (x.arrg,x.arrgc,g.arrg,g.arrgc,y.arrg,y.arrgc, + x.cellstride,y.cellstride,g.cellstride,map, + x.getn(),y.getn(),g.getn(),xl,yl,l,offs,_nch,Cptr,mode); + } + + } + + } + + + template + void SO3partA_CGproduct_back1_cu(const CMAP& map, SO3partArrayA& y, const SO3partArrayA& g, + const SO3partArrayA& x, const cudaStream_t& stream, const int offs, const int mode){ + + const int xl=x.getl(); + const int yl=y.getl(); + const int l=g.getl(); + + int Cptr=SO3_cgbank.getfC(xl,yl,l)/4; + int nlines=x.cellstride/16+y.cellstride/16+g.cellstride/16; + assert(x.nbu==g.nbu); + assert(y.nbu==g.nbu); + + const int _nch=1; + int _nbu=1; if(_nbu<0) _nbu=1; + nlines/=_nbu; + + cout<<"nlines="<>> + (y.arrg,y.arrgc,g.arrg,g.arrgc,x.arrg,x.arrgc, + x.cellstride,y.cellstride,g.cellstride,map, + x.getn(),y.getn(),g.getn(),xl,yl,l,offs,_nch,Cptr,mode); + + }else{ + + int nlines=x.cellstride/16+y.cellstride/16; + + cout<<"GElib: large CGproduct_back1"<384){ + cout<<"GElib error: CGproduct too big for shared memory"<>> + (y.arrg,y.arrgc,g.arrg,g.arrgc,x.arrg,x.arrgc, + x.cellstride,y.cellstride,g.cellstride,map, + x.getn(),y.getn(),g.getn(),xl,yl,l,offs,_nch,Cptr,mode); + } + + } + + } + + + template void SO3partA_CGproduct_cu(const cnine::CellwiseBinaryCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void 
SO3partA_CGproduct_cu(const cnine::BroadcastBinaryCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_CGproduct_cu(const cnine::OuterCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_CGproduct_cu(const cnine::InnerCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_CGproduct_cu(const cnine::MVprodCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_CGproduct_cu(const cnine::Convolve2Cmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + + + template void SO3partA_CGproduct_back0_cu(const cnine::CellwiseBinaryCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_CGproduct_back0_cu(const cnine::BroadcastBinaryCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_CGproduct_back0_cu(const cnine::OuterCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + + template void SO3partA_CGproduct_back1_cu(const cnine::CellwiseBinaryCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_CGproduct_back1_cu(const cnine::BroadcastBinaryCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_CGproduct_back1_cu(const 
cnine::OuterCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + + + +} + +#endif + + + + + + /* + void SO3partA_CGproduct_cu(SO3partArrayA& r, const SO3partArrayA& x, const SO3partArrayA& y, + const int mode, const cudaStream_t& stream, const int offs){ + + const int xl=x.getl(); + const int yl=y.getl(); + const int l=r.getl(); + + int Cptr=SO3_cgbank.getfC(xl,yl,l)/4; + int nlines=x.cellstride/16+y.cellstride/16+r.cellstride/16; + assert(x.nbu==r.nbu); + assert(y.nbu==r.nbu); + + const int _nch=1; + int _nbu=1; if(_nbu<0) _nbu=1; + nlines/=_nbu; + + if(mode==0){ + dim3 blocks(r.aasize,1,1); + cnine::CellwiseImap imap; + SO3partA_CGproduct_kernel<<>> + (r.arrg,r.arrgc,x.arrg,x.arrgc,y.arrg,y.arrgc, + r.cellstride,x.cellstride,y.cellstride,imap, + x.getn(),y.getn(),r.getn(),xl,yl,l,offs,_nch,Cptr); // + } + + if(mode==1){ + dim3 blocks(x.aasize,y.aasize,1); + cnine::OuterImap imap(r.adims[1]); + SO3partA_CGproduct_kernel<<>> + (r.arrg,r.arrgc,x.arrg,x.arrgc,y.arrg,y.arrgc, + r.cellstride,x.cellstride,y.cellstride,imap, + x.getn(),y.getn(),r.getn(),xl,yl,l,offs,_nch,Cptr); + } + + + } + */ + + + /* + void SO3partA_CGproduct_cu(SO3partArrayA& r, const SO3partArrayA& x, const SO3partArrayA& y, + const int rN, const int xN, const int yN, + const int ris, const int rjs, const int rks, + const int xis, const int xjs, const int xks, + const int yis, const int yjs, const int yks, + const cudaStream_t& stream, const int offs){ + + const int xl=x.getl(); + const int yl=y.getl(); + const int l=r.getl(); + + int Cptr=SO3_cgbank.getfC(xl,yl,l)/4; + int nlines=x.cellstride/16+y.cellstride/16+r.cellstride/16; + assert(x.nbu==r.nbu); + assert(y.nbu==r.nbu); + + const int _nch=1; + int _nbu=1; if(_nbu<0) _nbu=1; + dim3 blocks(rN,xN,yN); + nlines/=_nbu; + + + SO3partA_CGproduct_kernel<<>> + (r.arrg,r.arrgc,x.arrg,x.arrgc,y.arrg,y.arrgc, + ris*r.cellstride,rjs*r.cellstride, rks*r.cellstride, + 
xis*x.cellstride,xjs*x.cellstride, xks*x.cellstride, + yis*y.cellstride,yjs*y.cellstride, yks*y.cellstride, + x.getn(),y.getn(),r.getn(),xl,yl,l,offs,_nch,Cptr); + } + */ + +/* +__global__ void SO3partA_CGproduct_kernel(float* rarr, float* rarrc, float* xarr, float* xarrc, + float* yarr, float* yarrc, + const int ristride, const int xistride, const int yistride, + const int rjstride, const int xjstride, const int yjstride, + const int rkstride, const int xkstride, const int ykstride, + const int xfrags, const int yfrags, const int rfrags, + const int l1, const int l2, const int l, const int _offs, const int nch, const int Cptr){ + + extern __shared__ unsigned char _shared[]; + float* shared=reinterpret_cast(_shared); + + const float* C_ptr=reinterpret_cast(cg_cmem)+Cptr; + + const int iix=blockIdx.x; + const int jix=blockIdx.y; + const int kix=blockIdx.z; + + const int t=threadIdx.x; + + const int r1=2*l1+1; + const int r2=2*l2+1; + const int r=2*l+1; + + const int xwidth=xfrags*nch; + const int ywidth=yfrags*nch; + const int rwidth=xfrags*yfrags*nch; + const int global_rwidth=rfrags*nch; + + int offs=0; + + int xptr=32*offs; + SO3part_load(offs,shared,xarr,xarrc,l1,xwidth,iix*xistride+jix*xjstride+kix*xkstride,t); + + const int yptr=32*offs; + SO3part_load(offs,shared,yarr,yarrc,l2,ywidth,iix*yistride+jix*yjstride+kix*ykstride,t); + + const int rpr=32*offs+t; + const int rpi=rpr+((r*rwidth-1)/32+1)*32; + float* _rptr=rarr+iix*ristride+jix*rjstride+kix*rkstride+_offs; + float* _rptri=rarrc+iix*ristride+jix*rjstride+kix*rkstride+_offs; + + if(tl2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*r2+m2+l2]; + const float y_r=shared[ypr+ywidth*(m2+l2)]; + const float y_i=shared[ypi+ywidth*(m2+l2)]; + shared[rpr+rwidth*(m1+m2+l)]+=c*(x_r*y_r-x_i*y_i); + shared[rpi+rwidth*(m1+m2+l)]+=c*(x_r*y_i+x_i*y_r); + } + } + } + + __syncthreads(); + + if(t +#include +#include +#include + +//__device__ __constant__ unsigned char cg_cmem[32276]; + +#include 
"SO3partArrayA.hpp" +#include "SO3_CGbank.hpp" +#include "SO3partA.hpp" + +#include "CellwiseBinaryCmap.hpp" +#include "BroadcastBinaryCmap.hpp" +#include "InnerCmap.hpp" +#include "OuterCmap.hpp" +#include "MVprodCmap.hpp" +#include "VMprodCmap.hpp" +#include "Convolve2Cmap.hpp" + + +// should move these elsewhere + +__device__ void SO3part_load_lines2(float* dest, const float* source, const int nlines, const int t){ + if(t<32){ + for(int i=0; i +__global__ void SO3partA_DiagCGproduct_kernel(float* rarr, float* rarrc, float* xarr, float* xarrc, + float* yarr, float* yarrc, const int rstride, const int xstride, const int ystride, const IMAP cmap, + const int xn, const int yn, const int rn, const int l1, const int l2, const int l, + const int _offs, const int nch, const int Cptr, const int mode=0){ + + extern __shared__ unsigned char _shared[]; + float* shared=reinterpret_cast(_shared); + + const float* C_ptr=reinterpret_cast(cg_cmem)+Cptr; + const int t=threadIdx.x; + + const int r=2*l+1; + const int r1=2*l1+1; + const int r2=2*l2+1; + + const int xwidth=xn*nch; + const int ywidth=yn*nch; + const int rwidth=xn*nch; + const int global_rwidth=rn*nch; + + const int rlines=((r*rwidth-1)/32+1); + const int xlines=((r1*xwidth-1)/32+1); + const int ylines=((r2*ywidth-1)/32+1); + + const int rptr=0; + const int xptr=rptr+rlines*64; + const int yptr=xptr+xlines*64; + + int rix,xix,yix; + int nsum; + int lst; + + + if(mode<2){ + auto T=cmap(blockIdx.x,blockIdx.y,blockIdx.z); + rix=thrust::get<0>(T); + xix=thrust::get<1>(T); + yix=thrust::get<2>(T); + nsum=1; + //if(t==0) printf("foop1\n"); + }else{ + rix=cmap.target(blockIdx.x); + nsum=cmap.n_accum(blockIdx.x); + lst=cmap.lst_ptr(blockIdx.x); + } + + if(mode==1){ + if(t<32){ + for(int i=0; i<2*rlines; i++) + shared[rptr+i*32+t]=0; + } + }else{ + if(t(T); + yix=thrust::get<1>(T); + } + + SO3part_load_lines2(shared+xptr,xarr+xix*xstride,xlines,t); + SO3part_load_lines2(shared+xptr+xlines*32,xarrc+xix*xstride,xlines,t); + 
SO3part_load_lines2(shared+yptr,yarr+yix*ystride,ylines,t); + SO3part_load_lines2(shared+yptr+ylines*32,yarrc+yix*ystride,ylines,t); + + //if(t==0) printf("foop3\n"); + + __syncthreads(); + + const int rpr=rptr+t; + const int rpi=rpr+rlines*32; + + const int xpr=xptr+t; + const int xpi=xpr+xlines*32; + + const int ypr=yptr+t; + const int ypi=ypr+ylines*32; + + + if(tl2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*r2+m2+l2]; + const float y_r=shared[ypr+ywidth*(m2+l2)]; + const float y_i=shared[ypi+ywidth*(m2+l2)]; + shared[rpr+rwidth*(m1+m2+l)]+=c*(x_r*y_r-x_i*y_i); + shared[rpi+rwidth*(m1+m2+l)]+=c*(x_r*y_i+x_i*y_r); + } + } + } + + //if(t==0) printf("foop4\n"); + + __syncthreads(); + } + + //if(t==0) printf("fooq\n"); + + if(t // TODO +__global__ void SO3partA_DiagCGproduct_kernel_L(float* rarr, float* rarrc, float* xarr, float* xarrc, + float* yarr, float* yarrc, const int rstride, const int xstride, const int ystride, const IMAP cmap, + const int xn, const int yn, const int rn, const int l1, const int l2, const int l, + const int _offs, const int nch, const int Cptr, const int mode=0){ + + extern __shared__ unsigned char _shared[]; + float* shared=reinterpret_cast(_shared); + + const float* C_ptr=reinterpret_cast(cg_cmem)+Cptr; + const int t=threadIdx.x; + + const int r=2*l+1; + const int r1=2*l1+1; + const int r2=2*l2+1; + + const int xwidth=xn*nch; + const int ywidth=yn*nch; + const int rwidth=xn*nch; + const int global_rwidth=rn*nch; + + const int rlines=((r*rwidth-1)/32+1); + const int xlines=((r1*xwidth-1)/32+1); + const int ylines=((r2*1-1)/32+1); + + const int rptr=0; + const int xptr=rptr+rlines*64; + const int yptr=xptr+xlines*64; + + int rix,xix,yix; + int nsum; + int lst; + + if(mode<2){ + auto T=cmap(blockIdx.x,blockIdx.y,blockIdx.z); + rix=thrust::get<0>(T); + xix=thrust::get<1>(T); + yix=thrust::get<2>(T); + nsum=1; + }else{ + rix=cmap.target(blockIdx.x); + nsum=cmap.n_accum(blockIdx.x); + lst=cmap.lst_ptr(blockIdx.x); 
+ } + + + for(int s=0; s(T); + yix=thrust::get<1>(T); + } + + SO3part_load_lines2(shared+xptr,xarr+xix*xstride,xlines,t); + SO3part_load_lines2(shared+xptr+xlines*32,xarrc+xix*xstride,xlines,t); + + + for(int ycol=0; ycoll2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*r2+m2+l2]; + const float y_r=shared[ypr+1*(m2+l2)]; + const float y_i=shared[ypi+1*(m2+l2)]; + shared[rpr+rwidth*(m1+m2+l)]+=c*(x_r*y_r-x_i*y_i); + shared[rpi+rwidth*(m1+m2+l)]+=c*(x_r*y_i+x_i*y_r); + } + } + } + + //if(t==0) printf("foop4\n"); + + __syncthreads(); + + //if(t==0) printf("fooq\n"); + + if(t // TODO +__global__ void SO3partA_DiagCGproduct_back0_kernel(float* xarr, float* xarrc, float* garr, float* garrc, + float* yarr, float* yarrc, const int xstride, const int ystride, const int gstride, const IMAP cmap, + const int xn, const int yn, const int gn, const int l1, const int l2, const int l, + const int _offs, const int nch, const int Cptr, const int mode=0){ + + extern __shared__ unsigned char _shared[]; + float* shared=reinterpret_cast(_shared); + + const float* C_ptr=reinterpret_cast(cg_cmem)+Cptr; + const int t=threadIdx.x; + + const int rg=2*l+1; + const int rx=2*l1+1; + const int ry=2*l2+1; + + const int xwidth=xn*nch; + const int ywidth=yn*nch; + const int gwidth=xn*yn*nch; + const int global_gwidth=gn*nch; + + const int glines=((rg*gwidth-1)/32+1); + const int xlines=((rx*xwidth-1)/32+1); + const int ylines=((ry*ywidth-1)/32+1); + + const int xptr=0; + const int gptr=xptr+xlines*64; + const int yptr=gptr+glines*64; + + int gix,xix,yix; + int nsum; + int lst; + + if(mode<2){ + auto T=cmap(blockIdx.x,blockIdx.y,blockIdx.z); + xix=thrust::get<0>(T); + gix=thrust::get<1>(T); + yix=thrust::get<2>(T); + nsum=1; + }else{ + xix=cmap.target(blockIdx.x); + nsum=cmap.n_accum(blockIdx.x); + lst=cmap.lst_ptr(blockIdx.x); + } + + if(mode==1){ + if(t<32){ + for(int i=0; i<2*xlines; i++){ + shared[xptr+i*32+t]=0; + } + } + }else{ + 
SO3part_load_lines2(shared+xptr,xarr+xix*xstride,xlines,t); + SO3part_load_lines2(shared+xptr+xlines*32,xarrc+xix*xstride,xlines,t); + } + + for(int s=0; s(T); + yix=thrust::get<1>(T); + } + + // hack: gwidth assumed to be <=32 + for(int i=0; il2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*ry+m2+l2]; + const float y_r=shared[ypr+ywidth*(m2+l2)]; + const float y_i=shared[ypi+ywidth*(m2+l2)]; + const float g_r=shared[gpr+gwidth*(m1+m2+l)]; + const float g_i=shared[gpi+gwidth*(m1+m2+l)]; + shared[xpr+xwidth*(m1+l1)]+=c*(g_r*y_r+g_i*y_i); + shared[xpi+xwidth*(m1+l1)]+=c*(-g_r*y_i+g_i*y_r); + } + } + } + __syncthreads(); + + } + + } + + SO3part_save_lines2(shared+xptr,xarr+xix*xstride,xlines,t); + SO3part_save_lines2(shared+xptr+xlines*32,xarrc+xix*xstride,xlines,t); + + __syncthreads(); + +} + + +// ---- back1 ------------------------------------------------------------------------------------------------ + + +template // TODO +__global__ void SO3partA_DiagCGproduct_back1_kernel(float* yarr, float* yarrc, float* garr, float* garrc, + float* xarr, float* xarrc, const int xstride, const int ystride, const int gstride, const IMAP cmap, + const int xn, const int yn, const int gn, const int l1, const int l2, const int l, + const int _offs, const int nch, const int Cptr, const int mode=0){ + + extern __shared__ unsigned char _shared[]; + float* shared=reinterpret_cast(_shared); + + const float* C_ptr=reinterpret_cast(cg_cmem)+Cptr; + const int t=threadIdx.x; + + const int rg=2*l+1; + const int rx=2*l1+1; + const int ry=2*l2+1; + + const int xwidth=xn*nch; + const int ywidth=yn*nch; + const int gwidth=xn*yn*nch; + const int global_gwidth=gn*nch; + + const int glines=((rg*gwidth-1)/32+1); + const int xlines=((rx*xwidth-1)/32+1); + const int ylines=((ry*ywidth-1)/32+1); + + const int yptr=0; + const int gptr=yptr+ylines*64; + const int xptr=gptr+glines*64; + + int gix,xix,yix; + int nsum; + int lst; + + if(mode<2){ + auto 
T=cmap(blockIdx.x,blockIdx.y,blockIdx.z); + yix=thrust::get<0>(T); + gix=thrust::get<1>(T); + xix=thrust::get<2>(T); + nsum=1; + }else{ + yix=cmap.target(blockIdx.x); + nsum=cmap.n_accum(blockIdx.x); + lst=cmap.lst_ptr(blockIdx.x); + } + + if(mode==1){ + if(t<32){ + for(int i=0; i<2*ylines; i++) + shared[yptr+i*32+t]=0; + } + }else{ + SO3part_load_lines2(shared+yptr,yarr+yix*ystride,ylines,t); + SO3part_load_lines2(shared+yptr+ylines*32,yarrc+yix*ystride,ylines,t); + } + + for(int s=0; s(T); + xix=thrust::get<1>(T); + } + + // hack: gwidth assumed to be <=32 + for(int i=0; il2) upper=l2; + const float x_r=shared[xpr+xwidth*(m1+l1)]; + const float x_i=shared[xpi+xwidth*(m1+l1)]; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*ry+m2+l2]; + const float g_r=shared[gpr+gwidth*(m1+m2+l)]; + const float g_i=shared[gpi+gwidth*(m1+m2+l)]; + shared[ypr+ywidth*(m2+l2)]+=c*(g_r*x_r+g_i*x_i); + shared[ypi+ywidth*(m2+l2)]+=c*(-g_r*x_i+g_i*x_r); + } + } + } + __syncthreads(); + + } + + } + + SO3part_save_lines2(shared+yptr,yarr+yix*ystride,ylines,t); + SO3part_save_lines2(shared+yptr+ylines*32,yarrc+yix*ystride,ylines,t); + + __syncthreads(); + +} + + +// ----------------------------------------------------------------------------------------------------------- + + +namespace GElib{ + + + template + void SO3partA_DiagCGproduct_cu(const CMAP& map, SO3partArrayA& r, const SO3partArrayA& x, + const SO3partArrayA& y, const cudaStream_t& stream, const int offs, const int mode){ + + const int xl=x.getl(); + const int yl=y.getl(); + const int l=r.getl(); + const int _nch=1; + assert(x.nbu==r.nbu); + assert(y.nbu==r.nbu); + int _nbu=1; if(_nbu<0) _nbu=1; + + int Cptr=SO3_cgbank.getfC(xl,yl,l)/4; + //int nlines=x.cellstride/16+y.cellstride/16+r.cellstride/16; // should be smaller than this! 
+ int nlines=x.cellstride/16+y.cellstride/16+cnine::roundup(x.getn()*_nch*(2*l+1),32)/16; + // nlines/=_nbu; + + if(nlines<=384){ + + SO3partA_DiagCGproduct_kernel<<>> + (r.arrg,r.arrgc,x.arrg,x.arrgc,y.arrg,y.arrgc, + r.cellstride,x.cellstride,y.cellstride,map, + x.getn(),y.getn(),r.getn(),xl,yl,l,offs,_nch,Cptr,mode); + + }else{ // TODO + + int nlines=x.cellstride/16+cnine::roundup(_nch*(2*yl+1),32)/16+cnine::roundup(x.getn()*_nch*(2*l+1),32)/16; + + if(nlines>384){ + cout<<"GElib error: DiagCGproduct too big for shared memory"<>> + (r.arrg,r.arrgc,x.arrg,x.arrgc,y.arrg,y.arrgc, + r.cellstride,x.cellstride,y.cellstride,map, + x.getn(),y.getn(),r.getn(),xl,yl,l,offs,_nch,Cptr,mode); + } + } + + } + + + template // TODO + void SO3partA_DiagCGproduct_back0_cu(const CMAP& map, SO3partArrayA& x, const SO3partArrayA& g, + const SO3partArrayA& y, const cudaStream_t& stream, const int offs, const int mode){ + + const int xl=x.getl(); + const int yl=y.getl(); + const int l=g.getl(); + + int Cptr=SO3_cgbank.getfC(xl,yl,l)/4; + int nlines=x.cellstride/16+y.cellstride/16+g.cellstride/16; + assert(x.nbu==g.nbu); + assert(y.nbu==g.nbu); + + const int _nch=1; + int _nbu=1; if(_nbu<0) _nbu=1; + nlines/=_nbu; + + SO3partA_DiagCGproduct_back0_kernel<<>> + (x.arrg,x.arrgc,g.arrg,g.arrgc,y.arrg,y.arrgc, + x.cellstride,y.cellstride,g.cellstride,map, + x.getn(),y.getn(),g.getn(),xl,yl,l,offs,_nch,Cptr,mode); + + } + + + template // TODO + void SO3partA_DiagCGproduct_back1_cu(const CMAP& map, SO3partArrayA& y, const SO3partArrayA& g, + const SO3partArrayA& x, const cudaStream_t& stream, const int offs, const int mode){ + + const int xl=x.getl(); + const int yl=y.getl(); + const int l=g.getl(); + + int Cptr=SO3_cgbank.getfC(xl,yl,l)/4; + int nlines=x.cellstride/16+y.cellstride/16+g.cellstride/16; + assert(x.nbu==g.nbu); + assert(y.nbu==g.nbu); + + const int _nch=1; + int _nbu=1; if(_nbu<0) _nbu=1; + nlines/=_nbu; + + SO3partA_DiagCGproduct_back1_kernel<<>> + 
(y.arrg,y.arrgc,g.arrg,g.arrgc,x.arrg,x.arrgc, + x.cellstride,y.cellstride,g.cellstride,map, + x.getn(),y.getn(),g.getn(),xl,yl,l,offs,_nch,Cptr,mode); + + } + + + template void SO3partA_DiagCGproduct_cu(const cnine::CellwiseBinaryCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_DiagCGproduct_cu(const cnine::BroadcastBinaryCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_DiagCGproduct_cu(const cnine::OuterCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_DiagCGproduct_cu(const cnine::InnerCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_DiagCGproduct_cu(const cnine::MVprodCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_DiagCGproduct_cu(const cnine::Convolve2Cmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + + + template void SO3partA_DiagCGproduct_back0_cu(const cnine::CellwiseBinaryCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_DiagCGproduct_back0_cu(const cnine::BroadcastBinaryCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_DiagCGproduct_back0_cu(const cnine::OuterCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + + template void 
SO3partA_DiagCGproduct_back1_cu(const cnine::CellwiseBinaryCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_DiagCGproduct_back1_cu(const cnine::BroadcastBinaryCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + template void SO3partA_DiagCGproduct_back1_cu(const cnine::OuterCmap& map, + SO3partArrayA&, const SO3partArrayA&, const SO3partArrayA&, const cudaStream_t&, const int offs, + const int mode); + + + + +} + +#endif + + + + + diff --git a/cuda/SO3partB_addCGproduct.cu b/cuda/SO3partB_addCGproduct.cu new file mode 100644 index 0000000..706999d --- /dev/null +++ b/cuda/SO3partB_addCGproduct.cu @@ -0,0 +1,280 @@ +/* + * This file is part of GElib, a C++/CUDA library for group equivariant + * tensor operations. + * + * Copyright (c) 2023, Imre Risi Kondor + * + * This source code file is subject to the terms of the noncommercial + * license distributed with GElib in the file NONCOMMERICAL.TXT. Commercial + * use is prohibited. All redistributed versions of this file (in orginal + * or modified form) must retain this copyright notice and must be + * accompanied by a verbatim copy of the license. 
+ * + */ + +#ifndef _SO3partB_addCGproduct_cu +#define _SO3partB_addCGproduct_cu + +#include +#include + +#include "SO3_CGbank.hpp" +#include "GElibConfig.hpp" +#include "Ctensor3_view.hpp" +#include "Ctensor4_view.hpp" +#include "cuda_loaders.cu" + + +extern GElib::SO3_CGbank SO3_cgbank; +extern GElib::GElibConfig* gelib_config; + +#define maxl1_explicit 2 +#define maxl_explicit 4 + +#include "SO3part_addCGproduct_subkernels.inc" + + +__global__ void SO3partB_addCGproduct_tiled_kernel(const cnine::Ctensor3_view r, const cnine::Ctensor4_view_t3 x, + const cnine::Ctensor4_view_t3 y, const int Cptr, float* cptr_global, const bool preloadCG){ + + extern __shared__ unsigned char _shared[]; + const int b=blockIdx.x; + const int t=threadIdx.x; + + int l1=(x.n1-1)/2; + int l2=(y.n1-1)/2; + int l=(r.n1-1)/2; + int L2=y.n1; + + float* cptr; + float* xpr; + if(preloadCG){ + cptr=reinterpret_cast(_shared); + xpr=cptr+((x.n1*y.n1-1)/32+1)*32; + if(Cptr>=0) loadf(cptr,reinterpret_cast(cg_cmem)+Cptr,x.n1*y.n1); + else loadf(cptr,cptr_global,x.n1*y.n1); + }else{ + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + xpr=reinterpret_cast(_shared); + } + + float* xpi=xpr+x.n1*x.n3; + float* ypr=xpr+((2*x.n1*x.n3-1)/32+1)*32; + float* ypi=ypr+y.n1*y.n3; + + int xs1=x.n3; + int ys1=y.n3; + int rs1=r.s1; + int ytot=(y.n2-1)*y.n3+y.last; + + for(int i=0; i +__global__ void SO3part_addCGproduct_explicit(const cnine::Ctensor3_view r, const cnine::Ctensor3_view x, + const cnine::Ctensor3_view y){ + + extern __shared__ unsigned char _shared[]; + const int b=blockIdx.x; + const int t=threadIdx.x; + + //int l1=(x.n1-1)/2; + //int l2=(y.n1-1)/2; + //int l=(r.n1-1)/2; + //int L2=y.n1; + //int L2=y.n1; + + float* xpr=reinterpret_cast(_shared); + float* xpi=xpr+16; //xpr+x.n1; + float* ypr=xpr+32; //xpr+((2*x.n1-1)/32+1)*32; + float* ypi=ypr+y.n1*y.n2; + loadg(y,ypr,b,t); + + for(int i=0; iSO3part_CGkernels_explicit && xl<=maxl1_explicit && yl<=maxl1_explicit && 
l<=maxl_explicit){ + cout<<"Explicit!"<>> + (r,xtiled,ytiled,Cptr,cptr,preloadCG); + return; + } + + cout<<"error"<(_shared); + float* xpi=xpr+loadg(x,xpr,b,t); + + float* ypr=xpr+((2*x.n1*xn-1)/32+1)*32; + float* ypi=ypr+loadg(y,ypr,b,t); + + float* rpr=ypr+((2*y.n1*yn-1)/32+1)*32; + float* rpi=rpr+loadg(r,rpr,b,t); + + float* cptr; + const float C_ptr=reinterpret_cast(cg_cmem)+Cptr; + if(preloadCG){ + cptr=rpr+((2*r.n1*rn-1)/32+1)*32; + loadf(cptr,C_ptr,x.n1*y.n1,t); + }else cptr=C_ptr; + + __syncthreads(); + + if(tl2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*L2+m2+l2]; + const float y_r=ypr[yn*(m2+l2)]; + const float y_i=ypi[yn*(m2+l2)]; + _rpr[rn*(m1+m2+l)]+=c*(x_r*y_r-x_i*y_i); + _rpi[rn*(m1+m2+l)]+=c*(x_r*y_i+x_i*y_r); + } + } + } + + __syncthreads(); + saveg(r,rpr,b,t); +} +*/ diff --git a/cuda/SO3partB_addCGproduct_back0.cu b/cuda/SO3partB_addCGproduct_back0.cu new file mode 100644 index 0000000..f42dc81 --- /dev/null +++ b/cuda/SO3partB_addCGproduct_back0.cu @@ -0,0 +1,323 @@ +/* + * This file is part of GElib, a C++/CUDA library for group equivariant + * tensor operations. + * + * Copyright (c) 2023, Imre Risi Kondor + * + * This source code file is subject to the terms of the noncommercial + * license distributed with GElib in the file NONCOMMERICAL.TXT. Commercial + * use is prohibited. All redistributed versions of this file (in orginal + * or modified form) must retain this copyright notice and must be + * accompanied by a verbatim copy of the license. 
+ * + */ + +#ifndef _SO3partB_addCGproduct_back0_cu +#define _SO3partB_addCGproduct_back0_cu + +#include +#include + +#include "SO3_CGbank.hpp" +#include "Ctensor3_view.hpp" +#include "cuda_loaders.cu" + + +extern GElib::SO3_CGbank SO3_cgbank; + + + + +__global__ void SO3partB_addCGproduct_back0_tiled_kernel(const cnine::Ctensor4_view_t3 x, const cnine::Ctensor3_view r, + const cnine::Ctensor4_view_t3 y, const int Cptr, float* cptr_global, const bool preloadCG){ + + extern __shared__ unsigned char _shared[]; + const int b=blockIdx.x; + const int t=threadIdx.x; + + int l1=(x.n1-1)/2; + int l2=(y.n1-1)/2; + int l=(r.n1-1)/2; + int L2=y.n1; + + float* cptr; + float* xpr; + if(preloadCG){ + cptr=reinterpret_cast(_shared); + xpr=cptr+((x.n1*y.n1-1)/32+1)*32; + if(Cptr>=0) loadf(cptr,reinterpret_cast(cg_cmem)+Cptr,x.n1*y.n1); + else loadf(cptr,cptr_global,x.n1*y.n1); + }else{ + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + xpr=reinterpret_cast(_shared); + } + + float* xpi=xpr+x.n1*x.n3; + float* ypr=xpr+((2*x.n1*x.n3-1)/32+1)*32; + float* ypi=ypr+y.n1*y.n3; + + int xs1=x.n3; + int ys1=y.n3; + int rs1=r.s1; + int ytot=(y.n2-1)*y.n3+y.last; + + + for(int i=0; il2) upper=l2; + float x_r=0; + float x_i=0; + + for(int ycol=0; ycol>> + (xtiled,r,ytiled,Cptr,cptr,preloadCG); + return; + } + + cout<<"error"<l2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*r2+m2+l2]; + const float y_r=shared[ypr+ywidth*(m2+l2)]; + const float y_i=shared[ypi+ywidth*(m2+l2)]; + shared[rpr+rwidth*(m1+m2+l)]+=c*(x_r*y_r-x_i*y_i); + shared[rpi+rwidth*(m1+m2+l)]+=c*(x_r*y_i+x_i*y_r); + } + } + } + */ +/* +__device__ int loadg1(const cnine::Ctensor3_view& x, float* dest, const int b, const int t){ + int I=x.n1; + int J=x.n2; + int s1=x.s1; + int s2=x.s2; + int offs=I*J; + float* destc=dest+offs; + float* source=x.arr+x.s0*b; + float* sourcec=x.arrc+x.s0*b; + if(t(cg_cmem)+Cptr; + const int b=blockIdx.x; + const int t=threadIdx.x; + + int 
l1=(x.n1-1)/2; + int l2=(y.n1-1)/2; + int l=(r.n1-1)/2; + int xn=x.n2; + int yn=y.n2; + int rn=xn*yn; + int L2=y.n1; + + float* xpr=reinterpret_cast(_shared); + float* xpi=xpr+loadg(x,xpr,b,t); + + float* ypr=xpr+((2*x.n1*xn-1)/32+1)*32; + float* ypi=ypr+loadg(y,ypr,b,t); + + float* rpr=ypr+((2*y.n1*yn-1)/32+1)*32; + float* rpi=rpr+loadg(r,rpr,b,t); + + __syncthreads(); + + + float* _xpr=xpr+t; + float* _xpi=xpi+t; + + for(int ycol=0; ycoll2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*L2+m2+l2]; + const float y_r=_ypr[yn*(m2+l2)]; + const float y_i=_ypi[yn*(m2+l2)]; + const float g_r=_rpr[rn*(m1+m2+l)]; + const float g_i=_rpi[rn*(m1+m2+l)]; + _xpr[xn*(m1+l1)]+=c*(g_r*y_r+g_i*y_i); + _xpi[xn*(m1+l1)]+=c*(-g_r*y_i+g_i*y_r); + } + } + } + __syncthreads(); + } + + + __syncthreads(); + + saveg(x,xpr,b,t); + +} +*/ + + /* + if(nlines<=384){ + SO3partB_addCGproduct_back0_kernel<<>> + (xg,rg,y,Cptr); + }else{ + cout<<"error"< +#include + +#include "SO3_CGbank.hpp" +#include "Ctensor3_view.hpp" +#include "cuda_loaders.cu" + + +extern GElib::SO3_CGbank SO3_cgbank; + + + +__global__ void SO3partB_addCGproduct_back1_tiled_kernel(const cnine::Ctensor4_view_t3 y, const cnine::Ctensor3_view r, + const cnine::Ctensor4_view_t3 x, const int Cptr, float* cptr_global, const bool preloadCG){ + + extern __shared__ unsigned char _shared[]; + const int b=blockIdx.x; + const int t=threadIdx.x; + + int l1=(x.n1-1)/2; + int l2=(y.n1-1)/2; + int l=(r.n1-1)/2; + int L2=y.n1; + + float* cptr; + float* xpr; + if(preloadCG){ + cptr=reinterpret_cast(_shared); + xpr=cptr+((x.n1*y.n1-1)/32+1)*32; + if(Cptr>=0) loadf(cptr,reinterpret_cast(cg_cmem)+Cptr,x.n1*y.n1); + else loadf(cptr,cptr_global,x.n1*y.n1); + }else{ + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + xpr=reinterpret_cast(_shared); + } + + float* xpi=xpr+x.n1*x.n3; + float* ypr=xpr+((2*x.n1*x.n3-1)/32+1)*32; + float* ypi=ypr+y.n1*y.n3; + + int xs1=x.n3; + int ys1=y.n3; + int 
rs1=r.s1; + int ytot=(y.n2-1)*y.n3+y.last; + + + for(int j=0; jl1) upper=l1; + float y_r=0; + float y_i=0; + + for(int xcol=0; xcol>> + (ytiled,r,xtiled,Cptr,cptr,preloadCG); + return; + } + + cout<<"error"<>> + (yg,g,x,Cptr); + }else{ + cout<<"error"<(cg_cmem)+Cptr; + const int b=blockIdx.x; + const int t=threadIdx.x; + + int l1=(x.n1-1)/2; + int l2=(y.n1-1)/2; + int l=(r.n1-1)/2; + int xn=x.n2; + int yn=y.n2; + int rn=xn*yn; + int L2=y.n1; + + float* xpr=reinterpret_cast(_shared); + float* xpi=xpr+loadg(x,xpr,b,t); + + float* ypr=xpr+((2*x.n1*xn-1)/32+1)*32; + float* ypi=ypr+loadg(y,ypr,b,t); + + float* rpr=ypr+((2*y.n1*yn-1)/32+1)*32; + float* rpi=rpr+loadg(r,rpr,b,t); + + __syncthreads(); + + + for(int xcol=0; xcoll2) upper=l2; + for(int m2=lower; m2<=upper; m2++){ + float c=C_ptr[(m1+l1)*L2+m2+l2]; + const float g_r=_rpr[rn*(m1+m2+l)]; + const float g_i=_rpi[rn*(m1+m2+l)]; + _ypr[yn*(m2+l2)]+=c*(g_r*x_r+g_i*x_i); + _ypi[yn*(m2+l2)]+=c*(-g_r*x_i+g_i*x_r); + } + } + } + __syncthreads(); + } + + __syncthreads(); + saveg(y,ypr,b,t); + +} +*/ diff --git a/cuda/SO3partB_addCGsquare.cu b/cuda/SO3partB_addCGsquare.cu new file mode 100644 index 0000000..27a65d7 --- /dev/null +++ b/cuda/SO3partB_addCGsquare.cu @@ -0,0 +1,156 @@ +/* + * This file is part of GElib, a C++/CUDA library for group equivariant + * tensor operations. + * + * Copyright (c) 2023, Imre Risi Kondor + * + * This source code file is subject to the terms of the noncommercial + * license distributed with GElib in the file NONCOMMERICAL.TXT. Commercial + * use is prohibited. All redistributed versions of this file (in orginal + * or modified form) must retain this copyright notice and must be + * accompanied by a verbatim copy of the license. 
+ * + */ + +#ifndef _SO3partB_addCGsquare_cu +#define _SO3partB_addCGsquare_cu + +#include +#include + +#include "SO3_CGbank.hpp" +#include "Ctensor3_view.hpp" +#include "Ctensor4_view.hpp" +#include "cuda_loaders.cu" + + +extern GElib::SO3_CGbank SO3_cgbank; + + +__global__ void SO3partB_addCGsquare_tiled_kernel(const cnine::Ctensor3_view r, const cnine::Ctensor4_view_t3 x, + const int Cptr, float* cptr_global, const bool preloadCG){ + + extern __shared__ unsigned char _shared[]; + const int b=blockIdx.x; + const int t=threadIdx.x; + + int l1=(x.n1-1)/2; + int l=(r.n1-1)/2; + int L2=x.n1; + + float* cptr; + float* xpr; + if(preloadCG){ + cptr=reinterpret_cast(_shared); + xpr=cptr+((x.n1*x.n1-1)/32+1)*32; + if(Cptr>=0) loadf(cptr,reinterpret_cast(cg_cmem)+Cptr,x.n1*x.n1); + else loadf(cptr,cptr_global,x.n1*x.n1); + }else{ + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + xpr=reinterpret_cast(_shared); + } + + float* xpi=xpr+x.n1*x.n3; + float* ypr=xpr+((2*x.n1*x.n3-1)/32+1)*32; + float* ypi=ypr+y.n1*y.n3; + + int xs1=x.n3; + int ys1=y.n3; + int rs1=r.s1; + int ytot=(y.n2-1)*y.n3+y.last; + + for(int i=0; i>> + (r,xtiled,Cptr,cptr,preloadCG); + return; + } + + cout<<"error"< +#include + +#include "SO3_CGbank.hpp" +#include "Ctensor3_view.hpp" +#include "Ctensor4_view.hpp" +#include "cuda_loaders.cu" + + +extern GElib::SO3_CGbank SO3_cgbank; + + +__global__ void SO3partB_addCGtransform_kernel(const cnine::Ctensor3_view r, const cnine::Ctensor4_view x, + const int Cptr, float* cptr_global, const bool preloadCG){ + + extern __shared__ unsigned char _shared[]; + const int b=blockIdx.x; + const int t=threadIdx.x; + + int l1=(x.n1-1)/2; + int l2=(x.n2-1)/2; + int l=(r.n1-1)/2; + //int L2=y.n1; + + float* cptr; + float* xpr; + if(preloadCG){ + cptr=reinterpret_cast(_shared); + xpr=cptr+((x.n1*x.n2-1)/32+1)*32; + if(Cptr>=0) loadf(cptr,reinterpret_cast(cg_cmem)+Cptr,x.n1*x.n2); + else loadf(cptr,cptr_global,x.n1*x.n2); + }else{ + if(Cptr>=0) 
cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + xpr=reinterpret_cast(_shared); + } + + + +} + + +namespace GElib{ + + + void SO3partB_addCGtransform_cu(cnine::Ctensor3_view r, const cnine::Ctensor4_view& x, + const int offs, const cudaStream_t& stream){ + + const int xl=(x.n1-1)/2; + const int yl=(x.n2-1)/2; + const int l=(r.n1-1)/2; + const int b=r.n0; + + r.arr+=r.s2*offs; + r.arrc+=r.s2*offs; + r.n2=x.n2; + GELIB_CHECK(x.n2==y.n2,"Diag mismatch."); + //GELIB_CHECK(x.n2*y.n2<=1024,"Number of ouput channels can be at most 1024.") + + float* cptr=nullptr; + int Cptr=SO3_cgbank.getfC(xl,yl,l)/4; + if(Cptr<0) cptr=SO3_cgbank.getf(CGindex(xl,yl,l),r.dev).arrg; + int clines=cnine::roundup(x.n1*y.n1,32)/32; + + const int tilesize=std::min(x.n2,32); + cnine::Ctensor4_view_t3 xtiled(x,tilesize); + cnine::Ctensor4_view_t3 ytiled(y,tilesize); + + int nlines=cnine::roundup(xtiled.n1*tilesize*2,32)/32+ + cnine::roundup(ytiled.n1*tilesize*2,32)/32; + + if(nlines<=384){ + bool preloadCG=(nlines+clines<=384); + //preloadCG=false; + SO3partB_addDiagCGproduct_tiled_kernel<<>> + (r,xtiled,ytiled,Cptr,cptr,preloadCG); + return; + } + + cout<<"error"< +#include + +#include "SO3_CGbank.hpp" +#include "Ctensor3_view.hpp" +#include "Ctensor4_view.hpp" +#include "cuda_loaders.cu" + + +extern GElib::SO3_CGbank SO3_cgbank; +//extern long int opcount; + +// Process ncells number of cells in one call +__global__ void SO3partB_addDiagCGproduct_kernel(const cnine::Ctensor3_view r, const cnine::Ctensor3_view x, + const cnine::Ctensor3_view y, const int Cptr, float* cptr_global, const bool preloadCG, const int ncells){ + + bool loadr=false; // does not work because of striding of r + + extern __shared__ unsigned char _shared[]; + const int b=blockIdx.x; + const int t=threadIdx.x; + const int t0=t/x.n2; // cell selector + const int t1=t%x.n2; // channel selector within cell + const int actual_ncells=min(ncells,r.n0-b*ncells); + + int l1=(x.n1-1)/2; + int l2=(y.n1-1)/2; + int 
l=(r.n1-1)/2; + int L2=y.n1; + + float* cptr; + float* xpr; + if(preloadCG){ + cptr=reinterpret_cast(_shared); + xpr=cptr+((x.n1*y.n1-1)/32+1)*32; + if(Cptr>=0) loadf(cptr,reinterpret_cast(cg_cmem)+Cptr,x.n1*y.n1); + else loadf(cptr,cptr_global,x.n1*y.n1); + }else{ + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + xpr=reinterpret_cast(_shared); + } + + float* xpi=xpr+actual_ncells*x.n1*x.n2; + float* ypr=xpr+((2*actual_ncells*x.n1*x.n2-1)/32+1)*32; + float* ypi=ypr+actual_ncells*y.n1*y.n2; + float* rpr=ypr+((2*actual_ncells*y.n1*y.n2-1)/32+1)*32; + float* rpi=rpr+actual_ncells*r.n1*r.n2; // should be x.n2?? + + int xs1=x.s1/2; + int ys1=y.s1/2; + int rs1=r.s1; // changed! + if(loadr) rs1=r.n2; // changed! + + //if(t==0) printf("%d %d %d\n",r.n1,r.n2,r.s0); + + loadf_strided(xpr,x.arr+b*ncells*x.s0,actual_ncells*x.n1*x.n2,2); + loadf_strided(xpi,x.arrc+b*ncells*x.s0,actual_ncells*x.n1*x.n2,2); + loadf_strided(ypr,y.arr+b*ncells*y.s0,actual_ncells*y.n1*y.n2,2); + loadf_strided(ypi,y.arrc+b*ncells*y.s0,actual_ncells*y.n1*y.n2,2); + if(loadr) + for(int i=0; i(_shared); + xpr=cptr+((x.n1*y.n1-1)/32+1)*32; + if(Cptr>=0) loadf(cptr,reinterpret_cast(cg_cmem)+Cptr,x.n1*y.n1); + else loadf(cptr,cptr_global,x.n1*y.n1); + }else{ + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + xpr=reinterpret_cast(_shared); + } + + float* xpi=xpr+x.n1*x.n3; + float* ypr=xpr+((2*x.n1*x.n3-1)/32+1)*32; + float* ypi=ypr+y.n1*y.n3; + + int xs1=x.n3; + int ys1=y.n3; + int rs1=r.s1; + + assert(x.n2==y.n2); + + for(int i=0; i +#include + +#include "SO3_CGbank.hpp" +#include "Ctensor3_view.hpp" +#include "cuda_loaders.cu" + + +extern GElib::SO3_CGbank SO3_cgbank; + + + + +// Process ncells number of cells in one call +__global__ void SO3partB_addDiagCGproduct_back0_kernel(const cnine::Ctensor3_view x, const cnine::Ctensor3_view r, + const cnine::Ctensor3_view y, const int Cptr, float* cptr_global, const bool preloadCG, const int ncells){ + + 
extern __shared__ unsigned char _shared[]; + const int b=blockIdx.x; + const int t=threadIdx.x; + const int t0=t/x.n2; // cell selector + const int t1=t%x.n2; // channel selector within cell + const int actual_ncells=min(ncells,r.n0-b*ncells); + + int l1=(x.n1-1)/2; + int l2=(y.n1-1)/2; + int l=(r.n1-1)/2; + int L2=y.n1; + + float* cptr; + float* xpr; + if(preloadCG){ + cptr=reinterpret_cast(_shared); + xpr=cptr+((x.n1*y.n1-1)/32+1)*32; + if(Cptr>=0) loadf(cptr,reinterpret_cast(cg_cmem)+Cptr,x.n1*y.n1); + else loadf(cptr,cptr_global,x.n1*y.n1); + }else{ + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + xpr=reinterpret_cast(_shared); + } + + float* xpi=xpr+actual_ncells*x.n1*x.n2; + float* ypr=xpr+((2*actual_ncells*x.n1*x.n2-1)/32+1)*32; + float* ypi=ypr+actual_ncells*y.n1*y.n2; + + int xs1=x.s1/2; + int ys1=y.s1/2; + int rs1=r.s1; + + loadf_strided(xpr,x.arr+b*ncells*x.s0,actual_ncells*x.n1*x.n2,2); + loadf_strided(xpi,x.arrc+b*ncells*x.s0,actual_ncells*x.n1*x.n2,2); + loadf_strided(ypr,y.arr+b*ncells*y.s0,actual_ncells*y.n1*y.n2,2); + loadf_strided(ypi,y.arrc+b*ncells*y.s0,actual_ncells*y.n1*y.n2,2); + __syncthreads(); + + + // this handles both the padding of the number of threads to a multiple of 32 + // and the padding of the number of blocks to a multiple of ncells + if(t0l2) upper=l2; + float x_r=0; + float x_i=0; + + for(int m2=lower; m2<=upper; m2++){ + float c=cptr[(m1+l1)*L2+m2+l2]; + const float y_r=_ypr[ys1*(m2+l2)]; + const float y_i=_ypi[ys1*(m2+l2)]; + const float g_r=_rpr[rs1*(m1+m2+l)]; + const float g_i=_rpi[rs1*(m1+m2+l)]; + x_r+=c*(g_r*y_r+g_i*y_i); + x_i+=c*(-g_r*y_i+g_i*y_r); + } + + _xpr[xs1*(m1+l1)]+=x_r; + _xpi[xs1*(m1+l1)]+=x_i; + } + } + + __syncthreads(); + savef_strided(xpr,x.arr+b*ncells*x.s0,actual_ncells*x.n1*x.n2,2); + savef_strided(xpi,x.arrc+b*ncells*x.s0,actual_ncells*x.n1*x.n2,2); + +} + + +__global__ void SO3partB_addDiagCGproduct_back0_tiled_kernel(const cnine::Ctensor4_view_t3 x, const 
cnine::Ctensor3_view r, + const cnine::Ctensor4_view_t3 y, const int Cptr, float* cptr_global, const bool preloadCG){ + + extern __shared__ unsigned char _shared[]; + const int b=blockIdx.x; + const int t=threadIdx.x; + + int l1=(x.n1-1)/2; + int l2=(y.n1-1)/2; + int l=(r.n1-1)/2; + int L2=y.n1; + + float* cptr; + float* xpr; + if(preloadCG){ + cptr=reinterpret_cast(_shared); + xpr=cptr+((x.n1*y.n1-1)/32+1)*32; + if(Cptr>=0) loadf(cptr,reinterpret_cast(cg_cmem)+Cptr,x.n1*y.n1); + else loadf(cptr,cptr_global,x.n1*y.n1); + }else{ + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + xpr=reinterpret_cast(_shared); + } + + float* xpi=xpr+x.n1*x.n3; + float* ypr=xpr+((2*x.n1*x.n3-1)/32+1)*32; + float* ypi=ypr+y.n1*y.n3; + + int xs1=x.n3; + int ys1=y.n3; + int rs1=r.s1; + assert(x.n2==y.n2); + + for(int i=0; il2) upper=l2; + float x_r=0; + float x_i=0; + + float* _ypr=ypr+t; + float* _ypi=ypi+t; + float* _rpr=r.arr+r.s0*b+r.s2*(i*x.n3+t); + float* _rpi=r.arrc+r.s0*b+r.s2*(i*x.n3+t); + + for(int m2=lower; m2<=upper; m2++){ + float c=cptr[(m1+l1)*L2+m2+l2]; + const float y_r=_ypr[ys1*(m2+l2)]; + const float y_i=_ypi[ys1*(m2+l2)]; + const float g_r=_rpr[rs1*(m1+m2+l)]; + const float g_i=_rpi[rs1*(m1+m2+l)]; + x_r+=c*(g_r*y_r+g_i*y_i); + x_i+=c*(-g_r*y_i+g_i*y_r); + } + + _xpr[xs1*(m1+l1)]+=x_r; + _xpi[xs1*(m1+l1)]+=x_i; + } + + }// end t0 && nlines<=384){ + bool preloadCG=(nlines+clines<=384); + //cout<<"Launching addDiagCGproduct_kernel_back0 with ncells="< +#include + +#include "SO3_CGbank.hpp" +#include "Ctensor3_view.hpp" +#include "cuda_loaders.cu" + + +extern GElib::SO3_CGbank SO3_cgbank; + + + +__global__ void SO3partB_addDiagCGproduct_back1_kernel(const cnine::Ctensor3_view y, const cnine::Ctensor3_view r, + const cnine::Ctensor3_view x, const int Cptr, float* cptr_global, const bool preloadCG, const int ncells){ + + extern __shared__ unsigned char _shared[]; + const int b=blockIdx.x; + const int t=threadIdx.x; + const int t0=t/x.n2; // cell 
selector + const int t1=t%x.n2; // channel selector within cell + const int actual_ncells=min(ncells,r.n0-b*ncells); + + int l1=(x.n1-1)/2; + int l2=(y.n1-1)/2; + int l=(r.n1-1)/2; + int L2=y.n1; + + float* cptr; + float* xpr; + if(preloadCG){ + cptr=reinterpret_cast(_shared); + xpr=cptr+((x.n1*y.n1-1)/32+1)*32; + if(Cptr>=0) loadf(cptr,reinterpret_cast(cg_cmem)+Cptr,x.n1*y.n1); + else loadf(cptr,cptr_global,x.n1*y.n1); + }else{ + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + xpr=reinterpret_cast(_shared); + } + + float* xpi=xpr+actual_ncells*x.n1*x.n2; + float* ypr=xpr+((2*actual_ncells*x.n1*x.n2-1)/32+1)*32; + float* ypi=ypr+actual_ncells*y.n1*y.n2; + + int xs1=x.s1/2; + int ys1=y.s1/2; + int rs1=r.s1; + + loadf_strided(xpr,x.arr+b*ncells*x.s0,actual_ncells*x.n1*x.n2,2); + loadf_strided(xpi,x.arrc+b*ncells*x.s0,actual_ncells*x.n1*x.n2,2); + loadf_strided(ypr,y.arr+b*ncells*y.s0,actual_ncells*y.n1*y.n2,2); + loadf_strided(ypi,y.arrc+b*ncells*y.s0,actual_ncells*y.n1*y.n2,2); + __syncthreads(); + + + // this handles both the padding of the number of threads to a multiple of 32 + // and the padding of the number of blocks to a multiple of ncells + if(t0l1) upper=l1; + float y_r=0; + float y_i=0; + + for(int m1=lower; m1<=upper; m1++){ + float c=cptr[(m1+l1)*L2+m2+l2]; + const float x_r=_xpr[xs1*(m1+l1)]; + const float x_i=_xpi[xs1*(m1+l1)]; + const float g_r=_rpr[rs1*(m1+m2+l)]; + const float g_i=_rpi[rs1*(m1+m2+l)]; + y_r+=c*(g_r*x_r+g_i*x_i); + y_i+=c*(-g_r*x_i+g_i*x_r); + } + + _ypr[ys1*(m2+l2)]+=y_r; + _ypi[ys1*(m2+l2)]+=y_i; + } + } + + __syncthreads(); + savef_strided(ypr,y.arr+b*ncells*y.s0,actual_ncells*y.n1*y.n2,2); + savef_strided(ypi,y.arrc+b*ncells*y.s0,actual_ncells*y.n1*y.n2,2); + +} + + +__global__ void SO3partB_addDiagCGproduct_back1_tiled_kernel(const cnine::Ctensor4_view_t3 y, const cnine::Ctensor3_view r, + const cnine::Ctensor4_view_t3 x, const int Cptr, float* cptr_global, const bool preloadCG){ + + extern __shared__ 
unsigned char _shared[]; + const int b=blockIdx.x; + const int t=threadIdx.x; + + int l1=(x.n1-1)/2; + int l2=(y.n1-1)/2; + int l=(r.n1-1)/2; + int L2=y.n1; + + float* cptr; + float* xpr; + if(preloadCG){ + cptr=reinterpret_cast(_shared); + xpr=cptr+((x.n1*y.n1-1)/32+1)*32; + if(Cptr>=0) loadf(cptr,reinterpret_cast(cg_cmem)+Cptr,x.n1*y.n1); + else loadf(cptr,cptr_global,x.n1*y.n1); + }else{ + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + xpr=reinterpret_cast(_shared); + } + + float* xpi=xpr+x.n1*x.n3; + float* ypr=xpr+((2*x.n1*x.n3-1)/32+1)*32; + float* ypi=ypr+y.n1*y.n3; + + int xs1=x.n3; + int ys1=y.n3; + int rs1=r.s1; + assert(x.n2==y.n2); + + + for(int j=0; jl1) upper=l1; + float y_r=0; + float y_i=0; + + float* _xpr=xpr+t; + float* _xpi=xpi+t; + float* _rpr=r.arr+r.s0*b+r.s2*(j*x.n3+t); + float* _rpi=r.arrc+r.s0*b+r.s2*(j*x.n3+t); + + for(int m1=lower; m1<=upper; m1++){ + float c=cptr[(m1+l1)*L2+m2+l2]; + const float x_r=_xpr[xs1*(m1+l1)]; + const float x_i=_xpi[xs1*(m1+l1)]; + const float g_r=_rpr[rs1*(m1+m2+l)]; + const float g_i=_rpi[rs1*(m1+m2+l)]; + y_r+=c*(g_r*x_r+g_i*x_i); + y_i+=c*(-g_r*x_i+g_i*x_r); + } + + _ypr[ys1*(m2+l2)]+=y_r; + _ypi[ys1*(m2+l2)]+=y_i; + + } + + }// end t0 && nlines<=384){ + bool preloadCG=(nlines+clines<=384); + //cout<<"Launching addDiagCGproduct_kernel_back1 with ncells="<<<>>(r,x,y); break; + } + break; + case 1: + switch(l){ + case 1: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + } + break; + case 2: + switch(l){ + case 2: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + } + break; + } + + break; + case 1: + switch(l2){ + case 0: + switch(l){ + case 1: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + } + break; + case 1: + switch(l){ + case 0: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + case 1: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + case 2: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + } + break; + case 2: + switch(l){ + case 1: 
SO3part_addCGproduct_explicit<<>>(r,x,y); break; + case 2: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + case 3: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + } + break; + } + + break; + case 2: + switch(l2){ + case 0: + switch(l){ + case 2: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + } + break; + case 1: + switch(l){ + case 1: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + case 2: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + case 3: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + } + break; + case 2: + switch(l){ + case 0: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + case 1: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + case 2: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + case 3: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + case 4: SO3part_addCGproduct_explicit<<>>(r,x,y); break; + } + break; + } + + break; + } diff --git a/cuda/SO3part_addCGproduct_subkernels.inc b/cuda/SO3part_addCGproduct_subkernels.inc new file mode 100644 index 0000000..7cc45c9 --- /dev/null +++ b/cuda/SO3part_addCGproduct_subkernels.inc @@ -0,0 +1,609 @@ +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_0_0_0(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (1.000000f)*(xpr[0]*ypr[0*ys]-xpi[0]*ypi[0*ys]); + rpi[0*rs]+= + (1.000000f)*(xpr[0]*ypi[0*ys]+xpi[0]*ypr[0*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_0_1_1(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (1.000000f)*(xpr[0]*ypr[0*ys]-xpi[0]*ypi[0*ys]); + rpi[0*rs]+= + (1.000000f)*(xpr[0]*ypi[0*ys]+xpi[0]*ypr[0*ys]); + rpr[1*rs]+= + (1.000000f)*(xpr[0]*ypr[1*ys]-xpi[0]*ypi[1*ys]); + rpi[1*rs]+= + (1.000000f)*(xpr[0]*ypi[1*ys]+xpi[0]*ypr[1*ys]); + rpr[2*rs]+= + (1.000000f)*(xpr[0]*ypr[2*ys]-xpi[0]*ypi[2*ys]); + rpi[2*rs]+= + 
(1.000000f)*(xpr[0]*ypi[2*ys]+xpi[0]*ypr[2*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_0_2_2(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (1.000000f)*(xpr[0]*ypr[0*ys]-xpi[0]*ypi[0*ys]); + rpi[0*rs]+= + (1.000000f)*(xpr[0]*ypi[0*ys]+xpi[0]*ypr[0*ys]); + rpr[1*rs]+= + (1.000000f)*(xpr[0]*ypr[1*ys]-xpi[0]*ypi[1*ys]); + rpi[1*rs]+= + (1.000000f)*(xpr[0]*ypi[1*ys]+xpi[0]*ypr[1*ys]); + rpr[2*rs]+= + (1.000000f)*(xpr[0]*ypr[2*ys]-xpi[0]*ypi[2*ys]); + rpi[2*rs]+= + (1.000000f)*(xpr[0]*ypi[2*ys]+xpi[0]*ypr[2*ys]); + rpr[3*rs]+= + (1.000000f)*(xpr[0]*ypr[3*ys]-xpi[0]*ypi[3*ys]); + rpi[3*rs]+= + (1.000000f)*(xpr[0]*ypi[3*ys]+xpi[0]*ypr[3*ys]); + rpr[4*rs]+= + (1.000000f)*(xpr[0]*ypr[4*ys]-xpi[0]*ypi[4*ys]); + rpi[4*rs]+= + (1.000000f)*(xpr[0]*ypi[4*ys]+xpi[0]*ypr[4*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_1_0_1(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (1.000000f)*(xpr[0]*ypr[0*ys]-xpi[0]*ypi[0*ys]); + rpi[0*rs]+= + (1.000000f)*(xpr[0]*ypi[0*ys]+xpi[0]*ypr[0*ys]); + rpr[1*rs]+= + (1.000000f)*(xpr[1]*ypr[0*ys]-xpi[1]*ypi[0*ys]); + rpi[1*rs]+= + (1.000000f)*(xpr[1]*ypi[0*ys]+xpi[1]*ypr[0*ys]); + rpr[2*rs]+= + (1.000000f)*(xpr[2]*ypr[0*ys]-xpi[2]*ypi[0*ys]); + rpi[2*rs]+= + (1.000000f)*(xpr[2]*ypi[0*ys]+xpi[2]*ypr[0*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_1_1_0(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (0.577350f)*(xpr[0]*ypr[2*ys]-xpi[0]*ypi[2*ys])+ + (-0.577350f)*(xpr[1]*ypr[1*ys]-xpi[1]*ypi[1*ys])+ + (0.577350f)*(xpr[2]*ypr[0*ys]-xpi[2]*ypi[0*ys]); + rpi[0*rs]+= + (0.577350f)*(xpr[0]*ypi[2*ys]+xpi[0]*ypr[2*ys])+ + (-0.577350f)*(xpr[1]*ypi[1*ys]+xpi[1]*ypr[1*ys])+ + 
(0.577350f)*(xpr[2]*ypi[0*ys]+xpi[2]*ypr[0*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_1_1_1(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (-0.707107f)*(xpr[0]*ypr[1*ys]-xpi[0]*ypi[1*ys])+ + (0.707107f)*(xpr[1]*ypr[0*ys]-xpi[1]*ypi[0*ys]); + rpi[0*rs]+= + (-0.707107f)*(xpr[0]*ypi[1*ys]+xpi[0]*ypr[1*ys])+ + (0.707107f)*(xpr[1]*ypi[0*ys]+xpi[1]*ypr[0*ys]); + rpr[1*rs]+= + (-0.707107f)*(xpr[0]*ypr[2*ys]-xpi[0]*ypi[2*ys])+ + (0.000000f)*(xpr[1]*ypr[1*ys]-xpi[1]*ypi[1*ys])+ + (0.707107f)*(xpr[2]*ypr[0*ys]-xpi[2]*ypi[0*ys]); + rpi[1*rs]+= + (-0.707107f)*(xpr[0]*ypi[2*ys]+xpi[0]*ypr[2*ys])+ + (0.000000f)*(xpr[1]*ypi[1*ys]+xpi[1]*ypr[1*ys])+ + (0.707107f)*(xpr[2]*ypi[0*ys]+xpi[2]*ypr[0*ys]); + rpr[2*rs]+= + (-0.707107f)*(xpr[1]*ypr[2*ys]-xpi[1]*ypi[2*ys])+ + (0.707107f)*(xpr[2]*ypr[1*ys]-xpi[2]*ypi[1*ys]); + rpi[2*rs]+= + (-0.707107f)*(xpr[1]*ypi[2*ys]+xpi[1]*ypr[2*ys])+ + (0.707107f)*(xpr[2]*ypi[1*ys]+xpi[2]*ypr[1*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_1_1_2(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (1.000000f)*(xpr[0]*ypr[0*ys]-xpi[0]*ypi[0*ys]); + rpi[0*rs]+= + (1.000000f)*(xpr[0]*ypi[0*ys]+xpi[0]*ypr[0*ys]); + rpr[1*rs]+= + (0.707107f)*(xpr[0]*ypr[1*ys]-xpi[0]*ypi[1*ys])+ + (0.707107f)*(xpr[1]*ypr[0*ys]-xpi[1]*ypi[0*ys]); + rpi[1*rs]+= + (0.707107f)*(xpr[0]*ypi[1*ys]+xpi[0]*ypr[1*ys])+ + (0.707107f)*(xpr[1]*ypi[0*ys]+xpi[1]*ypr[0*ys]); + rpr[2*rs]+= + (0.408248f)*(xpr[0]*ypr[2*ys]-xpi[0]*ypi[2*ys])+ + (0.816497f)*(xpr[1]*ypr[1*ys]-xpi[1]*ypi[1*ys])+ + (0.408248f)*(xpr[2]*ypr[0*ys]-xpi[2]*ypi[0*ys]); + rpi[2*rs]+= + (0.408248f)*(xpr[0]*ypi[2*ys]+xpi[0]*ypr[2*ys])+ + (0.816497f)*(xpr[1]*ypi[1*ys]+xpi[1]*ypr[1*ys])+ + (0.408248f)*(xpr[2]*ypi[0*ys]+xpi[2]*ypr[0*ys]); + rpr[3*rs]+= + 
(0.707107f)*(xpr[1]*ypr[2*ys]-xpi[1]*ypi[2*ys])+ + (0.707107f)*(xpr[2]*ypr[1*ys]-xpi[2]*ypi[1*ys]); + rpi[3*rs]+= + (0.707107f)*(xpr[1]*ypi[2*ys]+xpi[1]*ypr[2*ys])+ + (0.707107f)*(xpr[2]*ypi[1*ys]+xpi[2]*ypr[1*ys]); + rpr[4*rs]+= + (1.000000f)*(xpr[2]*ypr[2*ys]-xpi[2]*ypi[2*ys]); + rpi[4*rs]+= + (1.000000f)*(xpr[2]*ypi[2*ys]+xpi[2]*ypr[2*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_1_2_1(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (0.316228f)*(xpr[0]*ypr[2*ys]-xpi[0]*ypi[2*ys])+ + (-0.547723f)*(xpr[1]*ypr[1*ys]-xpi[1]*ypi[1*ys])+ + (0.774597f)*(xpr[2]*ypr[0*ys]-xpi[2]*ypi[0*ys]); + rpi[0*rs]+= + (0.316228f)*(xpr[0]*ypi[2*ys]+xpi[0]*ypr[2*ys])+ + (-0.547723f)*(xpr[1]*ypi[1*ys]+xpi[1]*ypr[1*ys])+ + (0.774597f)*(xpr[2]*ypi[0*ys]+xpi[2]*ypr[0*ys]); + rpr[1*rs]+= + (0.547723f)*(xpr[0]*ypr[3*ys]-xpi[0]*ypi[3*ys])+ + (-0.632456f)*(xpr[1]*ypr[2*ys]-xpi[1]*ypi[2*ys])+ + (0.547723f)*(xpr[2]*ypr[1*ys]-xpi[2]*ypi[1*ys]); + rpi[1*rs]+= + (0.547723f)*(xpr[0]*ypi[3*ys]+xpi[0]*ypr[3*ys])+ + (-0.632456f)*(xpr[1]*ypi[2*ys]+xpi[1]*ypr[2*ys])+ + (0.547723f)*(xpr[2]*ypi[1*ys]+xpi[2]*ypr[1*ys]); + rpr[2*rs]+= + (0.774597f)*(xpr[0]*ypr[4*ys]-xpi[0]*ypi[4*ys])+ + (-0.547723f)*(xpr[1]*ypr[3*ys]-xpi[1]*ypi[3*ys])+ + (0.316228f)*(xpr[2]*ypr[2*ys]-xpi[2]*ypi[2*ys]); + rpi[2*rs]+= + (0.774597f)*(xpr[0]*ypi[4*ys]+xpi[0]*ypr[4*ys])+ + (-0.547723f)*(xpr[1]*ypi[3*ys]+xpi[1]*ypr[3*ys])+ + (0.316228f)*(xpr[2]*ypi[2*ys]+xpi[2]*ypr[2*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_1_2_2(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (-0.577350f)*(xpr[0]*ypr[1*ys]-xpi[0]*ypi[1*ys])+ + (0.816497f)*(xpr[1]*ypr[0*ys]-xpi[1]*ypi[0*ys]); + rpi[0*rs]+= + (-0.577350f)*(xpr[0]*ypi[1*ys]+xpi[0]*ypr[1*ys])+ + (0.816497f)*(xpr[1]*ypi[0*ys]+xpi[1]*ypr[0*ys]); 
+ rpr[1*rs]+= + (-0.707107f)*(xpr[0]*ypr[2*ys]-xpi[0]*ypi[2*ys])+ + (0.408248f)*(xpr[1]*ypr[1*ys]-xpi[1]*ypi[1*ys])+ + (0.577350f)*(xpr[2]*ypr[0*ys]-xpi[2]*ypi[0*ys]); + rpi[1*rs]+= + (-0.707107f)*(xpr[0]*ypi[2*ys]+xpi[0]*ypr[2*ys])+ + (0.408248f)*(xpr[1]*ypi[1*ys]+xpi[1]*ypr[1*ys])+ + (0.577350f)*(xpr[2]*ypi[0*ys]+xpi[2]*ypr[0*ys]); + rpr[2*rs]+= + (-0.707107f)*(xpr[0]*ypr[3*ys]-xpi[0]*ypi[3*ys])+ + (0.000000f)*(xpr[1]*ypr[2*ys]-xpi[1]*ypi[2*ys])+ + (0.707107f)*(xpr[2]*ypr[1*ys]-xpi[2]*ypi[1*ys]); + rpi[2*rs]+= + (-0.707107f)*(xpr[0]*ypi[3*ys]+xpi[0]*ypr[3*ys])+ + (0.000000f)*(xpr[1]*ypi[2*ys]+xpi[1]*ypr[2*ys])+ + (0.707107f)*(xpr[2]*ypi[1*ys]+xpi[2]*ypr[1*ys]); + rpr[3*rs]+= + (-0.577350f)*(xpr[0]*ypr[4*ys]-xpi[0]*ypi[4*ys])+ + (-0.408248f)*(xpr[1]*ypr[3*ys]-xpi[1]*ypi[3*ys])+ + (0.707107f)*(xpr[2]*ypr[2*ys]-xpi[2]*ypi[2*ys]); + rpi[3*rs]+= + (-0.577350f)*(xpr[0]*ypi[4*ys]+xpi[0]*ypr[4*ys])+ + (-0.408248f)*(xpr[1]*ypi[3*ys]+xpi[1]*ypr[3*ys])+ + (0.707107f)*(xpr[2]*ypi[2*ys]+xpi[2]*ypr[2*ys]); + rpr[4*rs]+= + (-0.816497f)*(xpr[1]*ypr[4*ys]-xpi[1]*ypi[4*ys])+ + (0.577350f)*(xpr[2]*ypr[3*ys]-xpi[2]*ypi[3*ys]); + rpi[4*rs]+= + (-0.816497f)*(xpr[1]*ypi[4*ys]+xpi[1]*ypr[4*ys])+ + (0.577350f)*(xpr[2]*ypi[3*ys]+xpi[2]*ypr[3*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_1_2_3(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (1.000000f)*(xpr[0]*ypr[0*ys]-xpi[0]*ypi[0*ys]); + rpi[0*rs]+= + (1.000000f)*(xpr[0]*ypi[0*ys]+xpi[0]*ypr[0*ys]); + rpr[1*rs]+= + (0.816497f)*(xpr[0]*ypr[1*ys]-xpi[0]*ypi[1*ys])+ + (0.577350f)*(xpr[1]*ypr[0*ys]-xpi[1]*ypi[0*ys]); + rpi[1*rs]+= + (0.816497f)*(xpr[0]*ypi[1*ys]+xpi[0]*ypr[1*ys])+ + (0.577350f)*(xpr[1]*ypi[0*ys]+xpi[1]*ypr[0*ys]); + rpr[2*rs]+= + (0.632456f)*(xpr[0]*ypr[2*ys]-xpi[0]*ypi[2*ys])+ + (0.730297f)*(xpr[1]*ypr[1*ys]-xpi[1]*ypi[1*ys])+ + (0.258199f)*(xpr[2]*ypr[0*ys]-xpi[2]*ypi[0*ys]); + rpi[2*rs]+= + 
(0.632456f)*(xpr[0]*ypi[2*ys]+xpi[0]*ypr[2*ys])+ + (0.730297f)*(xpr[1]*ypi[1*ys]+xpi[1]*ypr[1*ys])+ + (0.258199f)*(xpr[2]*ypi[0*ys]+xpi[2]*ypr[0*ys]); + rpr[3*rs]+= + (0.447214f)*(xpr[0]*ypr[3*ys]-xpi[0]*ypi[3*ys])+ + (0.774597f)*(xpr[1]*ypr[2*ys]-xpi[1]*ypi[2*ys])+ + (0.447214f)*(xpr[2]*ypr[1*ys]-xpi[2]*ypi[1*ys]); + rpi[3*rs]+= + (0.447214f)*(xpr[0]*ypi[3*ys]+xpi[0]*ypr[3*ys])+ + (0.774597f)*(xpr[1]*ypi[2*ys]+xpi[1]*ypr[2*ys])+ + (0.447214f)*(xpr[2]*ypi[1*ys]+xpi[2]*ypr[1*ys]); + rpr[4*rs]+= + (0.258199f)*(xpr[0]*ypr[4*ys]-xpi[0]*ypi[4*ys])+ + (0.730297f)*(xpr[1]*ypr[3*ys]-xpi[1]*ypi[3*ys])+ + (0.632456f)*(xpr[2]*ypr[2*ys]-xpi[2]*ypi[2*ys]); + rpi[4*rs]+= + (0.258199f)*(xpr[0]*ypi[4*ys]+xpi[0]*ypr[4*ys])+ + (0.730297f)*(xpr[1]*ypi[3*ys]+xpi[1]*ypr[3*ys])+ + (0.632456f)*(xpr[2]*ypi[2*ys]+xpi[2]*ypr[2*ys]); + rpr[5*rs]+= + (0.577350f)*(xpr[1]*ypr[4*ys]-xpi[1]*ypi[4*ys])+ + (0.816497f)*(xpr[2]*ypr[3*ys]-xpi[2]*ypi[3*ys]); + rpi[5*rs]+= + (0.577350f)*(xpr[1]*ypi[4*ys]+xpi[1]*ypr[4*ys])+ + (0.816497f)*(xpr[2]*ypi[3*ys]+xpi[2]*ypr[3*ys]); + rpr[6*rs]+= + (1.000000f)*(xpr[2]*ypr[4*ys]-xpi[2]*ypi[4*ys]); + rpi[6*rs]+= + (1.000000f)*(xpr[2]*ypi[4*ys]+xpi[2]*ypr[4*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_2_0_2(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (1.000000f)*(xpr[0]*ypr[0*ys]-xpi[0]*ypi[0*ys]); + rpi[0*rs]+= + (1.000000f)*(xpr[0]*ypi[0*ys]+xpi[0]*ypr[0*ys]); + rpr[1*rs]+= + (1.000000f)*(xpr[1]*ypr[0*ys]-xpi[1]*ypi[0*ys]); + rpi[1*rs]+= + (1.000000f)*(xpr[1]*ypi[0*ys]+xpi[1]*ypr[0*ys]); + rpr[2*rs]+= + (1.000000f)*(xpr[2]*ypr[0*ys]-xpi[2]*ypi[0*ys]); + rpi[2*rs]+= + (1.000000f)*(xpr[2]*ypi[0*ys]+xpi[2]*ypr[0*ys]); + rpr[3*rs]+= + (1.000000f)*(xpr[3]*ypr[0*ys]-xpi[3]*ypi[0*ys]); + rpi[3*rs]+= + (1.000000f)*(xpr[3]*ypi[0*ys]+xpi[3]*ypr[0*ys]); + rpr[4*rs]+= + (1.000000f)*(xpr[4]*ypr[0*ys]-xpi[4]*ypi[0*ys]); + rpi[4*rs]+= + 
(1.000000f)*(xpr[4]*ypi[0*ys]+xpi[4]*ypr[0*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_2_1_1(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (0.774597f)*(xpr[0]*ypr[2*ys]-xpi[0]*ypi[2*ys])+ + (-0.547723f)*(xpr[1]*ypr[1*ys]-xpi[1]*ypi[1*ys])+ + (0.316228f)*(xpr[2]*ypr[0*ys]-xpi[2]*ypi[0*ys]); + rpi[0*rs]+= + (0.774597f)*(xpr[0]*ypi[2*ys]+xpi[0]*ypr[2*ys])+ + (-0.547723f)*(xpr[1]*ypi[1*ys]+xpi[1]*ypr[1*ys])+ + (0.316228f)*(xpr[2]*ypi[0*ys]+xpi[2]*ypr[0*ys]); + rpr[1*rs]+= + (0.547723f)*(xpr[1]*ypr[2*ys]-xpi[1]*ypi[2*ys])+ + (-0.632456f)*(xpr[2]*ypr[1*ys]-xpi[2]*ypi[1*ys])+ + (0.547723f)*(xpr[3]*ypr[0*ys]-xpi[3]*ypi[0*ys]); + rpi[1*rs]+= + (0.547723f)*(xpr[1]*ypi[2*ys]+xpi[1]*ypr[2*ys])+ + (-0.632456f)*(xpr[2]*ypi[1*ys]+xpi[2]*ypr[1*ys])+ + (0.547723f)*(xpr[3]*ypi[0*ys]+xpi[3]*ypr[0*ys]); + rpr[2*rs]+= + (0.316228f)*(xpr[2]*ypr[2*ys]-xpi[2]*ypi[2*ys])+ + (-0.547723f)*(xpr[3]*ypr[1*ys]-xpi[3]*ypi[1*ys])+ + (0.774597f)*(xpr[4]*ypr[0*ys]-xpi[4]*ypi[0*ys]); + rpi[2*rs]+= + (0.316228f)*(xpr[2]*ypi[2*ys]+xpi[2]*ypr[2*ys])+ + (-0.547723f)*(xpr[3]*ypi[1*ys]+xpi[3]*ypr[1*ys])+ + (0.774597f)*(xpr[4]*ypi[0*ys]+xpi[4]*ypr[0*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_2_1_2(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (-0.816497f)*(xpr[0]*ypr[1*ys]-xpi[0]*ypi[1*ys])+ + (0.577350f)*(xpr[1]*ypr[0*ys]-xpi[1]*ypi[0*ys]); + rpi[0*rs]+= + (-0.816497f)*(xpr[0]*ypi[1*ys]+xpi[0]*ypr[1*ys])+ + (0.577350f)*(xpr[1]*ypi[0*ys]+xpi[1]*ypr[0*ys]); + rpr[1*rs]+= + (-0.577350f)*(xpr[0]*ypr[2*ys]-xpi[0]*ypi[2*ys])+ + (-0.408248f)*(xpr[1]*ypr[1*ys]-xpi[1]*ypi[1*ys])+ + (0.707107f)*(xpr[2]*ypr[0*ys]-xpi[2]*ypi[0*ys]); + rpi[1*rs]+= + (-0.577350f)*(xpr[0]*ypi[2*ys]+xpi[0]*ypr[2*ys])+ + (-0.408248f)*(xpr[1]*ypi[1*ys]+xpi[1]*ypr[1*ys])+ + 
(0.707107f)*(xpr[2]*ypi[0*ys]+xpi[2]*ypr[0*ys]); + rpr[2*rs]+= + (-0.707107f)*(xpr[1]*ypr[2*ys]-xpi[1]*ypi[2*ys])+ + (0.000000f)*(xpr[2]*ypr[1*ys]-xpi[2]*ypi[1*ys])+ + (0.707107f)*(xpr[3]*ypr[0*ys]-xpi[3]*ypi[0*ys]); + rpi[2*rs]+= + (-0.707107f)*(xpr[1]*ypi[2*ys]+xpi[1]*ypr[2*ys])+ + (0.000000f)*(xpr[2]*ypi[1*ys]+xpi[2]*ypr[1*ys])+ + (0.707107f)*(xpr[3]*ypi[0*ys]+xpi[3]*ypr[0*ys]); + rpr[3*rs]+= + (-0.707107f)*(xpr[2]*ypr[2*ys]-xpi[2]*ypi[2*ys])+ + (0.408248f)*(xpr[3]*ypr[1*ys]-xpi[3]*ypi[1*ys])+ + (0.577350f)*(xpr[4]*ypr[0*ys]-xpi[4]*ypi[0*ys]); + rpi[3*rs]+= + (-0.707107f)*(xpr[2]*ypi[2*ys]+xpi[2]*ypr[2*ys])+ + (0.408248f)*(xpr[3]*ypi[1*ys]+xpi[3]*ypr[1*ys])+ + (0.577350f)*(xpr[4]*ypi[0*ys]+xpi[4]*ypr[0*ys]); + rpr[4*rs]+= + (-0.577350f)*(xpr[3]*ypr[2*ys]-xpi[3]*ypi[2*ys])+ + (0.816497f)*(xpr[4]*ypr[1*ys]-xpi[4]*ypi[1*ys]); + rpi[4*rs]+= + (-0.577350f)*(xpr[3]*ypi[2*ys]+xpi[3]*ypr[2*ys])+ + (0.816497f)*(xpr[4]*ypi[1*ys]+xpi[4]*ypr[1*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_2_1_3(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (1.000000f)*(xpr[0]*ypr[0*ys]-xpi[0]*ypi[0*ys]); + rpi[0*rs]+= + (1.000000f)*(xpr[0]*ypi[0*ys]+xpi[0]*ypr[0*ys]); + rpr[1*rs]+= + (0.577350f)*(xpr[0]*ypr[1*ys]-xpi[0]*ypi[1*ys])+ + (0.816497f)*(xpr[1]*ypr[0*ys]-xpi[1]*ypi[0*ys]); + rpi[1*rs]+= + (0.577350f)*(xpr[0]*ypi[1*ys]+xpi[0]*ypr[1*ys])+ + (0.816497f)*(xpr[1]*ypi[0*ys]+xpi[1]*ypr[0*ys]); + rpr[2*rs]+= + (0.258199f)*(xpr[0]*ypr[2*ys]-xpi[0]*ypi[2*ys])+ + (0.730297f)*(xpr[1]*ypr[1*ys]-xpi[1]*ypi[1*ys])+ + (0.632456f)*(xpr[2]*ypr[0*ys]-xpi[2]*ypi[0*ys]); + rpi[2*rs]+= + (0.258199f)*(xpr[0]*ypi[2*ys]+xpi[0]*ypr[2*ys])+ + (0.730297f)*(xpr[1]*ypi[1*ys]+xpi[1]*ypr[1*ys])+ + (0.632456f)*(xpr[2]*ypi[0*ys]+xpi[2]*ypr[0*ys]); + rpr[3*rs]+= + (0.447214f)*(xpr[1]*ypr[2*ys]-xpi[1]*ypi[2*ys])+ + (0.774597f)*(xpr[2]*ypr[1*ys]-xpi[2]*ypi[1*ys])+ + 
(0.447214f)*(xpr[3]*ypr[0*ys]-xpi[3]*ypi[0*ys]); + rpi[3*rs]+= + (0.447214f)*(xpr[1]*ypi[2*ys]+xpi[1]*ypr[2*ys])+ + (0.774597f)*(xpr[2]*ypi[1*ys]+xpi[2]*ypr[1*ys])+ + (0.447214f)*(xpr[3]*ypi[0*ys]+xpi[3]*ypr[0*ys]); + rpr[4*rs]+= + (0.632456f)*(xpr[2]*ypr[2*ys]-xpi[2]*ypi[2*ys])+ + (0.730297f)*(xpr[3]*ypr[1*ys]-xpi[3]*ypi[1*ys])+ + (0.258199f)*(xpr[4]*ypr[0*ys]-xpi[4]*ypi[0*ys]); + rpi[4*rs]+= + (0.632456f)*(xpr[2]*ypi[2*ys]+xpi[2]*ypr[2*ys])+ + (0.730297f)*(xpr[3]*ypi[1*ys]+xpi[3]*ypr[1*ys])+ + (0.258199f)*(xpr[4]*ypi[0*ys]+xpi[4]*ypr[0*ys]); + rpr[5*rs]+= + (0.816497f)*(xpr[3]*ypr[2*ys]-xpi[3]*ypi[2*ys])+ + (0.577350f)*(xpr[4]*ypr[1*ys]-xpi[4]*ypi[1*ys]); + rpi[5*rs]+= + (0.816497f)*(xpr[3]*ypi[2*ys]+xpi[3]*ypr[2*ys])+ + (0.577350f)*(xpr[4]*ypi[1*ys]+xpi[4]*ypr[1*ys]); + rpr[6*rs]+= + (1.000000f)*(xpr[4]*ypr[2*ys]-xpi[4]*ypi[2*ys]); + rpi[6*rs]+= + (1.000000f)*(xpr[4]*ypi[2*ys]+xpi[4]*ypr[2*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_2_2_0(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (0.447214f)*(xpr[0]*ypr[4*ys]-xpi[0]*ypi[4*ys])+ + (-0.447214f)*(xpr[1]*ypr[3*ys]-xpi[1]*ypi[3*ys])+ + (0.447214f)*(xpr[2]*ypr[2*ys]-xpi[2]*ypi[2*ys])+ + (-0.447214f)*(xpr[3]*ypr[1*ys]-xpi[3]*ypi[1*ys])+ + (0.447214f)*(xpr[4]*ypr[0*ys]-xpi[4]*ypi[0*ys]); + rpi[0*rs]+= + (0.447214f)*(xpr[0]*ypi[4*ys]+xpi[0]*ypr[4*ys])+ + (-0.447214f)*(xpr[1]*ypi[3*ys]+xpi[1]*ypr[3*ys])+ + (0.447214f)*(xpr[2]*ypi[2*ys]+xpi[2]*ypr[2*ys])+ + (-0.447214f)*(xpr[3]*ypi[1*ys]+xpi[3]*ypr[1*ys])+ + (0.447214f)*(xpr[4]*ypi[0*ys]+xpi[4]*ypr[0*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_2_2_1(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (-0.447214f)*(xpr[0]*ypr[3*ys]-xpi[0]*ypi[3*ys])+ + (0.547723f)*(xpr[1]*ypr[2*ys]-xpi[1]*ypi[2*ys])+ + 
(-0.547723f)*(xpr[2]*ypr[1*ys]-xpi[2]*ypi[1*ys])+ + (0.447214f)*(xpr[3]*ypr[0*ys]-xpi[3]*ypi[0*ys]); + rpi[0*rs]+= + (-0.447214f)*(xpr[0]*ypi[3*ys]+xpi[0]*ypr[3*ys])+ + (0.547723f)*(xpr[1]*ypi[2*ys]+xpi[1]*ypr[2*ys])+ + (-0.547723f)*(xpr[2]*ypi[1*ys]+xpi[2]*ypr[1*ys])+ + (0.447214f)*(xpr[3]*ypi[0*ys]+xpi[3]*ypr[0*ys]); + rpr[1*rs]+= + (-0.632456f)*(xpr[0]*ypr[4*ys]-xpi[0]*ypi[4*ys])+ + (0.316228f)*(xpr[1]*ypr[3*ys]-xpi[1]*ypi[3*ys])+ + (0.000000f)*(xpr[2]*ypr[2*ys]-xpi[2]*ypi[2*ys])+ + (-0.316228f)*(xpr[3]*ypr[1*ys]-xpi[3]*ypi[1*ys])+ + (0.632456f)*(xpr[4]*ypr[0*ys]-xpi[4]*ypi[0*ys]); + rpi[1*rs]+= + (-0.632456f)*(xpr[0]*ypi[4*ys]+xpi[0]*ypr[4*ys])+ + (0.316228f)*(xpr[1]*ypi[3*ys]+xpi[1]*ypr[3*ys])+ + (0.000000f)*(xpr[2]*ypi[2*ys]+xpi[2]*ypr[2*ys])+ + (-0.316228f)*(xpr[3]*ypi[1*ys]+xpi[3]*ypr[1*ys])+ + (0.632456f)*(xpr[4]*ypi[0*ys]+xpi[4]*ypr[0*ys]); + rpr[2*rs]+= + (-0.447214f)*(xpr[1]*ypr[4*ys]-xpi[1]*ypi[4*ys])+ + (0.547723f)*(xpr[2]*ypr[3*ys]-xpi[2]*ypi[3*ys])+ + (-0.547723f)*(xpr[3]*ypr[2*ys]-xpi[3]*ypi[2*ys])+ + (0.447214f)*(xpr[4]*ypr[1*ys]-xpi[4]*ypi[1*ys]); + rpi[2*rs]+= + (-0.447214f)*(xpr[1]*ypi[4*ys]+xpi[1]*ypr[4*ys])+ + (0.547723f)*(xpr[2]*ypi[3*ys]+xpi[2]*ypr[3*ys])+ + (-0.547723f)*(xpr[3]*ypi[2*ys]+xpi[3]*ypr[2*ys])+ + (0.447214f)*(xpr[4]*ypi[1*ys]+xpi[4]*ypr[1*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_2_2_2(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (0.534522f)*(xpr[0]*ypr[2*ys]-xpi[0]*ypi[2*ys])+ + (-0.654654f)*(xpr[1]*ypr[1*ys]-xpi[1]*ypi[1*ys])+ + (0.534522f)*(xpr[2]*ypr[0*ys]-xpi[2]*ypi[0*ys]); + rpi[0*rs]+= + (0.534522f)*(xpr[0]*ypi[2*ys]+xpi[0]*ypr[2*ys])+ + (-0.654654f)*(xpr[1]*ypi[1*ys]+xpi[1]*ypr[1*ys])+ + (0.534522f)*(xpr[2]*ypi[0*ys]+xpi[2]*ypr[0*ys]); + rpr[1*rs]+= + (0.654654f)*(xpr[0]*ypr[3*ys]-xpi[0]*ypi[3*ys])+ + (-0.267261f)*(xpr[1]*ypr[2*ys]-xpi[1]*ypi[2*ys])+ + 
(-0.267261f)*(xpr[2]*ypr[1*ys]-xpi[2]*ypi[1*ys])+ + (0.654654f)*(xpr[3]*ypr[0*ys]-xpi[3]*ypi[0*ys]); + rpi[1*rs]+= + (0.654654f)*(xpr[0]*ypi[3*ys]+xpi[0]*ypr[3*ys])+ + (-0.267261f)*(xpr[1]*ypi[2*ys]+xpi[1]*ypr[2*ys])+ + (-0.267261f)*(xpr[2]*ypi[1*ys]+xpi[2]*ypr[1*ys])+ + (0.654654f)*(xpr[3]*ypi[0*ys]+xpi[3]*ypr[0*ys]); + rpr[2*rs]+= + (0.534522f)*(xpr[0]*ypr[4*ys]-xpi[0]*ypi[4*ys])+ + (0.267261f)*(xpr[1]*ypr[3*ys]-xpi[1]*ypi[3*ys])+ + (-0.534522f)*(xpr[2]*ypr[2*ys]-xpi[2]*ypi[2*ys])+ + (0.267261f)*(xpr[3]*ypr[1*ys]-xpi[3]*ypi[1*ys])+ + (0.534522f)*(xpr[4]*ypr[0*ys]-xpi[4]*ypi[0*ys]); + rpi[2*rs]+= + (0.534522f)*(xpr[0]*ypi[4*ys]+xpi[0]*ypr[4*ys])+ + (0.267261f)*(xpr[1]*ypi[3*ys]+xpi[1]*ypr[3*ys])+ + (-0.534522f)*(xpr[2]*ypi[2*ys]+xpi[2]*ypr[2*ys])+ + (0.267261f)*(xpr[3]*ypi[1*ys]+xpi[3]*ypr[1*ys])+ + (0.534522f)*(xpr[4]*ypi[0*ys]+xpi[4]*ypr[0*ys]); + rpr[3*rs]+= + (0.654654f)*(xpr[1]*ypr[4*ys]-xpi[1]*ypi[4*ys])+ + (-0.267261f)*(xpr[2]*ypr[3*ys]-xpi[2]*ypi[3*ys])+ + (-0.267261f)*(xpr[3]*ypr[2*ys]-xpi[3]*ypi[2*ys])+ + (0.654654f)*(xpr[4]*ypr[1*ys]-xpi[4]*ypi[1*ys]); + rpi[3*rs]+= + (0.654654f)*(xpr[1]*ypi[4*ys]+xpi[1]*ypr[4*ys])+ + (-0.267261f)*(xpr[2]*ypi[3*ys]+xpi[2]*ypr[3*ys])+ + (-0.267261f)*(xpr[3]*ypi[2*ys]+xpi[3]*ypr[2*ys])+ + (0.654654f)*(xpr[4]*ypi[1*ys]+xpi[4]*ypr[1*ys]); + rpr[4*rs]+= + (0.534522f)*(xpr[2]*ypr[4*ys]-xpi[2]*ypi[4*ys])+ + (-0.654654f)*(xpr[3]*ypr[3*ys]-xpi[3]*ypi[3*ys])+ + (0.534522f)*(xpr[4]*ypr[2*ys]-xpi[4]*ypi[2*ys]); + rpi[4*rs]+= + (0.534522f)*(xpr[2]*ypi[4*ys]+xpi[2]*ypr[4*ys])+ + (-0.654654f)*(xpr[3]*ypi[3*ys]+xpi[3]*ypr[3*ys])+ + (0.534522f)*(xpr[4]*ypi[2*ys]+xpi[4]*ypr[2*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_2_2_3(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (-0.707107f)*(xpr[0]*ypr[1*ys]-xpi[0]*ypi[1*ys])+ + (0.707107f)*(xpr[1]*ypr[0*ys]-xpi[1]*ypi[0*ys]); + rpi[0*rs]+= + 
(-0.707107f)*(xpr[0]*ypi[1*ys]+xpi[0]*ypr[1*ys])+ + (0.707107f)*(xpr[1]*ypi[0*ys]+xpi[1]*ypr[0*ys]); + rpr[1*rs]+= + (-0.707107f)*(xpr[0]*ypr[2*ys]-xpi[0]*ypi[2*ys])+ + (0.000000f)*(xpr[1]*ypr[1*ys]-xpi[1]*ypi[1*ys])+ + (0.707107f)*(xpr[2]*ypr[0*ys]-xpi[2]*ypi[0*ys]); + rpi[1*rs]+= + (-0.707107f)*(xpr[0]*ypi[2*ys]+xpi[0]*ypr[2*ys])+ + (0.000000f)*(xpr[1]*ypi[1*ys]+xpi[1]*ypr[1*ys])+ + (0.707107f)*(xpr[2]*ypi[0*ys]+xpi[2]*ypr[0*ys]); + rpr[2*rs]+= + (-0.547723f)*(xpr[0]*ypr[3*ys]-xpi[0]*ypi[3*ys])+ + (-0.447214f)*(xpr[1]*ypr[2*ys]-xpi[1]*ypi[2*ys])+ + (0.447214f)*(xpr[2]*ypr[1*ys]-xpi[2]*ypi[1*ys])+ + (0.547723f)*(xpr[3]*ypr[0*ys]-xpi[3]*ypi[0*ys]); + rpi[2*rs]+= + (-0.547723f)*(xpr[0]*ypi[3*ys]+xpi[0]*ypr[3*ys])+ + (-0.447214f)*(xpr[1]*ypi[2*ys]+xpi[1]*ypr[2*ys])+ + (0.447214f)*(xpr[2]*ypi[1*ys]+xpi[2]*ypr[1*ys])+ + (0.547723f)*(xpr[3]*ypi[0*ys]+xpi[3]*ypr[0*ys]); + rpr[3*rs]+= + (-0.316228f)*(xpr[0]*ypr[4*ys]-xpi[0]*ypi[4*ys])+ + (-0.632456f)*(xpr[1]*ypr[3*ys]-xpi[1]*ypi[3*ys])+ + (0.000000f)*(xpr[2]*ypr[2*ys]-xpi[2]*ypi[2*ys])+ + (0.632456f)*(xpr[3]*ypr[1*ys]-xpi[3]*ypi[1*ys])+ + (0.316228f)*(xpr[4]*ypr[0*ys]-xpi[4]*ypi[0*ys]); + rpi[3*rs]+= + (-0.316228f)*(xpr[0]*ypi[4*ys]+xpi[0]*ypr[4*ys])+ + (-0.632456f)*(xpr[1]*ypi[3*ys]+xpi[1]*ypr[3*ys])+ + (0.000000f)*(xpr[2]*ypi[2*ys]+xpi[2]*ypr[2*ys])+ + (0.632456f)*(xpr[3]*ypi[1*ys]+xpi[3]*ypr[1*ys])+ + (0.316228f)*(xpr[4]*ypi[0*ys]+xpi[4]*ypr[0*ys]); + rpr[4*rs]+= + (-0.547723f)*(xpr[1]*ypr[4*ys]-xpi[1]*ypi[4*ys])+ + (-0.447214f)*(xpr[2]*ypr[3*ys]-xpi[2]*ypi[3*ys])+ + (0.447214f)*(xpr[3]*ypr[2*ys]-xpi[3]*ypi[2*ys])+ + (0.547723f)*(xpr[4]*ypr[1*ys]-xpi[4]*ypi[1*ys]); + rpi[4*rs]+= + (-0.547723f)*(xpr[1]*ypi[4*ys]+xpi[1]*ypr[4*ys])+ + (-0.447214f)*(xpr[2]*ypi[3*ys]+xpi[2]*ypr[3*ys])+ + (0.447214f)*(xpr[3]*ypi[2*ys]+xpi[3]*ypr[2*ys])+ + (0.547723f)*(xpr[4]*ypi[1*ys]+xpi[4]*ypr[1*ys]); + rpr[5*rs]+= + (-0.707107f)*(xpr[2]*ypr[4*ys]-xpi[2]*ypi[4*ys])+ + (0.000000f)*(xpr[3]*ypr[3*ys]-xpi[3]*ypi[3*ys])+ + 
(0.707107f)*(xpr[4]*ypr[2*ys]-xpi[4]*ypi[2*ys]); + rpi[5*rs]+= + (-0.707107f)*(xpr[2]*ypi[4*ys]+xpi[2]*ypr[4*ys])+ + (0.000000f)*(xpr[3]*ypi[3*ys]+xpi[3]*ypr[3*ys])+ + (0.707107f)*(xpr[4]*ypi[2*ys]+xpi[4]*ypr[2*ys]); + rpr[6*rs]+= + (-0.707107f)*(xpr[3]*ypr[4*ys]-xpi[3]*ypi[4*ys])+ + (0.707107f)*(xpr[4]*ypr[3*ys]-xpi[4]*ypi[3*ys]); + rpi[6*rs]+= + (-0.707107f)*(xpr[3]*ypi[4*ys]+xpi[3]*ypr[4*ys])+ + (0.707107f)*(xpr[4]*ypi[3*ys]+xpi[4]*ypr[3*ys]); +} + +__forceinline__ __device__ void SO3part_addCGproduct_explicit_kernel_2_2_4(const float* xpr, const float* xpi, const float* ypr, const float* ypi, const int ys, float* rpr, float* rpi, const int rs){ + rpr[0*rs]+= + (1.000000f)*(xpr[0]*ypr[0*ys]-xpi[0]*ypi[0*ys]); + rpi[0*rs]+= + (1.000000f)*(xpr[0]*ypi[0*ys]+xpi[0]*ypr[0*ys]); + rpr[1*rs]+= + (0.707107f)*(xpr[0]*ypr[1*ys]-xpi[0]*ypi[1*ys])+ + (0.707107f)*(xpr[1]*ypr[0*ys]-xpi[1]*ypi[0*ys]); + rpi[1*rs]+= + (0.707107f)*(xpr[0]*ypi[1*ys]+xpi[0]*ypr[1*ys])+ + (0.707107f)*(xpr[1]*ypi[0*ys]+xpi[1]*ypr[0*ys]); + rpr[2*rs]+= + (0.462910f)*(xpr[0]*ypr[2*ys]-xpi[0]*ypi[2*ys])+ + (0.755929f)*(xpr[1]*ypr[1*ys]-xpi[1]*ypi[1*ys])+ + (0.462910f)*(xpr[2]*ypr[0*ys]-xpi[2]*ypi[0*ys]); + rpi[2*rs]+= + (0.462910f)*(xpr[0]*ypi[2*ys]+xpi[0]*ypr[2*ys])+ + (0.755929f)*(xpr[1]*ypi[1*ys]+xpi[1]*ypr[1*ys])+ + (0.462910f)*(xpr[2]*ypi[0*ys]+xpi[2]*ypr[0*ys]); + rpr[3*rs]+= + (0.267261f)*(xpr[0]*ypr[3*ys]-xpi[0]*ypi[3*ys])+ + (0.654654f)*(xpr[1]*ypr[2*ys]-xpi[1]*ypi[2*ys])+ + (0.654654f)*(xpr[2]*ypr[1*ys]-xpi[2]*ypi[1*ys])+ + (0.267261f)*(xpr[3]*ypr[0*ys]-xpi[3]*ypi[0*ys]); + rpi[3*rs]+= + (0.267261f)*(xpr[0]*ypi[3*ys]+xpi[0]*ypr[3*ys])+ + (0.654654f)*(xpr[1]*ypi[2*ys]+xpi[1]*ypr[2*ys])+ + (0.654654f)*(xpr[2]*ypi[1*ys]+xpi[2]*ypr[1*ys])+ + (0.267261f)*(xpr[3]*ypi[0*ys]+xpi[3]*ypr[0*ys]); + rpr[4*rs]+= + (0.119523f)*(xpr[0]*ypr[4*ys]-xpi[0]*ypi[4*ys])+ + (0.478091f)*(xpr[1]*ypr[3*ys]-xpi[1]*ypi[3*ys])+ + (0.717137f)*(xpr[2]*ypr[2*ys]-xpi[2]*ypi[2*ys])+ + 
(0.478091f)*(xpr[3]*ypr[1*ys]-xpi[3]*ypi[1*ys])+ + (0.119523f)*(xpr[4]*ypr[0*ys]-xpi[4]*ypi[0*ys]); + rpi[4*rs]+= + (0.119523f)*(xpr[0]*ypi[4*ys]+xpi[0]*ypr[4*ys])+ + (0.478091f)*(xpr[1]*ypi[3*ys]+xpi[1]*ypr[3*ys])+ + (0.717137f)*(xpr[2]*ypi[2*ys]+xpi[2]*ypr[2*ys])+ + (0.478091f)*(xpr[3]*ypi[1*ys]+xpi[3]*ypr[1*ys])+ + (0.119523f)*(xpr[4]*ypi[0*ys]+xpi[4]*ypr[0*ys]); + rpr[5*rs]+= + (0.267261f)*(xpr[1]*ypr[4*ys]-xpi[1]*ypi[4*ys])+ + (0.654654f)*(xpr[2]*ypr[3*ys]-xpi[2]*ypi[3*ys])+ + (0.654654f)*(xpr[3]*ypr[2*ys]-xpi[3]*ypi[2*ys])+ + (0.267261f)*(xpr[4]*ypr[1*ys]-xpi[4]*ypi[1*ys]); + rpi[5*rs]+= + (0.267261f)*(xpr[1]*ypi[4*ys]+xpi[1]*ypr[4*ys])+ + (0.654654f)*(xpr[2]*ypi[3*ys]+xpi[2]*ypr[3*ys])+ + (0.654654f)*(xpr[3]*ypi[2*ys]+xpi[3]*ypr[2*ys])+ + (0.267261f)*(xpr[4]*ypi[1*ys]+xpi[4]*ypr[1*ys]); + rpr[6*rs]+= + (0.462910f)*(xpr[2]*ypr[4*ys]-xpi[2]*ypi[4*ys])+ + (0.755929f)*(xpr[3]*ypr[3*ys]-xpi[3]*ypi[3*ys])+ + (0.462910f)*(xpr[4]*ypr[2*ys]-xpi[4]*ypi[2*ys]); + rpi[6*rs]+= + (0.462910f)*(xpr[2]*ypi[4*ys]+xpi[2]*ypr[4*ys])+ + (0.755929f)*(xpr[3]*ypi[3*ys]+xpi[3]*ypr[3*ys])+ + (0.462910f)*(xpr[4]*ypi[2*ys]+xpi[4]*ypr[2*ys]); + rpr[7*rs]+= + (0.707107f)*(xpr[3]*ypr[4*ys]-xpi[3]*ypi[4*ys])+ + (0.707107f)*(xpr[4]*ypr[3*ys]-xpi[4]*ypi[3*ys]); + rpi[7*rs]+= + (0.707107f)*(xpr[3]*ypi[4*ys]+xpi[3]*ypr[4*ys])+ + (0.707107f)*(xpr[4]*ypi[3*ys]+xpi[4]*ypr[3*ys]); + rpr[8*rs]+= + (1.000000f)*(xpr[4]*ypr[4*ys]-xpi[4]*ypi[4*ys]); + rpi[8*rs]+= + (1.000000f)*(xpr[4]*ypi[4*ys]+xpi[4]*ypr[4*ys]); +} + diff --git a/cuda/SO3part_addCGtransform.cu b/cuda/SO3part_addCGtransform.cu new file mode 100644 index 0000000..b0314e0 --- /dev/null +++ b/cuda/SO3part_addCGtransform.cu @@ -0,0 +1,238 @@ +/* + * This file is part of GElib, a C++/CUDA library for group equivariant + * tensor operations. + * + * Copyright (c) 2023, Imre Risi Kondor + * + * This source code file is subject to the terms of the noncommercial + * license distributed with GElib in the file NONCOMMERICAL.TXT. 
Commercial + * use is prohibited. All redistributed versions of this file (in orginal + * or modified form) must retain this copyright notice and must be + * accompanied by a verbatim copy of the license. + * + */ + +#ifndef _SO3part_addCGtransform_cu +#define _SO3part_addCGtransform_cu + +#include +#include + +#include "SO3_CGbank.hpp" +#include "Ctensor3_view.hpp" +#include "Ctensor4_view.hpp" +#include "cuda_loaders.cu" + + +extern GElib::SO3_CGbank SO3_cgbank; +//extern long int opcount; + +// Process ncells number of cells in one call +__global__ void SO3part_addCGtransform_kernel(const cnine::Ctensor3_view r, const cnine::Ctensor4_view x, + const int Cptr, float* cptr_global, const bool preloadCG, const int ncells){ + + extern __shared__ unsigned char _shared[]; + const int b=blockIdx.x; + const int t=threadIdx.x; + const int t0=t/x.n3; // cell selector + const int t1=t%x.n3; // channel selector within cell + const int actual_ncells=min(ncells,r.n0-b*ncells); + + int l1=(x.n1-1)/2; + int l2=(x.n2-1)/2; + int l=(r.n1-1)/2; + int L2=x.n2; + + float* cptr; + //float* xpr; + if(preloadCG){ + cptr=reinterpret_cast(_shared); + //xpr=cptr+((x.n1*x.n2-1)/32+1)*32; + if(Cptr>=0) loadf(cptr,reinterpret_cast(cg_cmem)+Cptr,x.n1*x.n2); + else loadf(cptr,cptr_global,x.n1*x.n2); + }else{ + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + //xpr=reinterpret_cast(_shared); + } + + //loadf(xpr,x.arr+b*ncells*x.s0,actual_ncells*x.n1*x.n2*x.n3); + __syncthreads(); + + if(t0(_shared); + xpr=cptr+((x.n1*y.n1-1)/32+1)*32; + if(Cptr>=0) loadf(cptr,reinterpret_cast(cg_cmem)+Cptr,x.n1*y.n1); + else loadf(cptr,cptr_global,x.n1*y.n1); + }else{ + if(Cptr>=0) cptr=reinterpret_cast(cg_cmem)+Cptr; + else cptr=cptr_global; + xpr=reinterpret_cast(_shared); + } + + float* xpi=xpr+x.n1*x.n3; + float* ypr=xpr+((2*x.n1*x.n3-1)/32+1)*32; + float* ypi=ypr+y.n1*y.n3; + + int xs1=x.n3; + int ys1=y.n3; + int rs1=r.s1; + + assert(x.n2==y.n2); + + for(int i=0; i0 && 
nlines<=384){ + bool preloadCG=(nlines+clines<=384); + SO3part_addCGtransform_kernel<<>> + (r,x,Cptr,cptr,preloadCG,ncells); + return; + } + } + + + /* + // Otherwise tile the inputs to chunks of width 32 + const int tilesize=std::min(x.n2,32); + cnine::Ctensor4_view_t3 xtiled(x,tilesize); + cnine::Ctensor4_view_t3 ytiled(y,tilesize); + int nlines=cnine::roundup(xtiled.n1*tilesize*2,32)/32+ + cnine::roundup(ytiled.n1*tilesize*2,32)/32; + + if(nlines<=384){ + bool preloadCG=(nlines+clines<=384); + SO3part_addCGtransform_tiled_kernel<<>> + (r,xtiled,ytiled,Cptr,cptr,preloadCG); + return; + } + */ + + GELIB_ERROR("Inputs too large to load in shared memory."); + } + + +} + + +#endif + + + diff --git a/cuda/cuda_loaders.cu b/cuda/cuda_loaders.cu new file mode 100644 index 0000000..b4a75bc --- /dev/null +++ b/cuda/cuda_loaders.cu @@ -0,0 +1,189 @@ +/* + * This file is part of GElib, a C++/CUDA library for group equivariant + * tensor operations. + * + * Copyright (c) 2023, Imre Risi Kondor + * + * This source code file is subject to the terms of the noncommercial + * license distributed with GElib in the file NONCOMMERICAL.TXT. Commercial + * use is prohibited. All redistributed versions of this file (in orginal + * or modified form) must retain this copyright notice and must be + * accompanied by a verbatim copy of the license. + * + */ + +#ifndef _GElib_cuda_loaders +#define _GElib_cuda_loaders + +#include +#include +#include "Ctensor3_view.hpp" +#include "Ctensor4_view.hpp" + +#define tix threadIdx.x + +/* +__forceinline__ __device__ unsigned dynamic_smem_size(){ + unsigned ret; + asm volatile ("mov.u32 %0, %dynamic_smem_size;" : "=r"(ret)); + return ret; +} +*/ + +/* +__forceinline__ __device__ void loadf(float* dest, const float* src, const int n, const int t){ + int nthreads=blockDim.x; + int I=n/nthreads; + for(int i=0; i