xref: /libCEED/rust/libceed-sys/c-src/backends/hip-gen/ceed-hip-gen-operator.c (revision 7d8d0e25636a94a27ff75b3dec09737e24cdb0fe)
1*7d8d0e25Snbeams // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2*7d8d0e25Snbeams // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3*7d8d0e25Snbeams // All Rights reserved. See files LICENSE and NOTICE for details.
4*7d8d0e25Snbeams //
5*7d8d0e25Snbeams // This file is part of CEED, a collection of benchmarks, miniapps, software
6*7d8d0e25Snbeams // libraries and APIs for efficient high-order finite element and spectral
7*7d8d0e25Snbeams // element discretizations for exascale applications. For more information and
8*7d8d0e25Snbeams // source code availability see http://github.com/ceed.
9*7d8d0e25Snbeams //
10*7d8d0e25Snbeams // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11*7d8d0e25Snbeams // a collaborative effort of two U.S. Department of Energy organizations (Office
12*7d8d0e25Snbeams // of Science and the National Nuclear Security Administration) responsible for
13*7d8d0e25Snbeams // the planning and preparation of a capable exascale ecosystem, including
14*7d8d0e25Snbeams // software, applications, hardware, advanced system engineering and early
15*7d8d0e25Snbeams // testbed platforms, in support of the nation's exascale computing imperative.
16*7d8d0e25Snbeams 
17*7d8d0e25Snbeams #include "ceed-hip-gen.h"
18*7d8d0e25Snbeams #include "ceed-hip-gen-operator-build.h"
19*7d8d0e25Snbeams #include "../hip/ceed-hip-compile.h"
20*7d8d0e25Snbeams 
21*7d8d0e25Snbeams //------------------------------------------------------------------------------
22*7d8d0e25Snbeams // Destroy operator
23*7d8d0e25Snbeams //------------------------------------------------------------------------------
24*7d8d0e25Snbeams static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
25*7d8d0e25Snbeams   int ierr;
26*7d8d0e25Snbeams   CeedOperator_Hip_gen *impl;
27*7d8d0e25Snbeams   ierr = CeedOperatorGetData(op, &impl); CeedChk(ierr);
28*7d8d0e25Snbeams   ierr = CeedFree(&impl); CeedChk(ierr);
29*7d8d0e25Snbeams   return 0;
30*7d8d0e25Snbeams }
31*7d8d0e25Snbeams 
32*7d8d0e25Snbeams //------------------------------------------------------------------------------
33*7d8d0e25Snbeams // Apply and add to output
34*7d8d0e25Snbeams //------------------------------------------------------------------------------
35*7d8d0e25Snbeams static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector invec,
36*7d8d0e25Snbeams                                         CeedVector outvec, CeedRequest *request) {
37*7d8d0e25Snbeams   int ierr;
38*7d8d0e25Snbeams   Ceed ceed;
39*7d8d0e25Snbeams   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
40*7d8d0e25Snbeams   CeedOperator_Hip_gen *data;
41*7d8d0e25Snbeams   ierr = CeedOperatorGetData(op, &data); CeedChk(ierr);
42*7d8d0e25Snbeams   CeedQFunction qf;
43*7d8d0e25Snbeams   CeedQFunction_Hip_gen *qf_data;
44*7d8d0e25Snbeams   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
45*7d8d0e25Snbeams   ierr = CeedQFunctionGetData(qf, &qf_data); CeedChk(ierr);
46*7d8d0e25Snbeams   CeedInt nelem, numinputfields, numoutputfields;
47*7d8d0e25Snbeams   ierr = CeedOperatorGetNumElements(op, &nelem); CeedChk(ierr);
48*7d8d0e25Snbeams   ierr = CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
49*7d8d0e25Snbeams   CeedChk(ierr);
50*7d8d0e25Snbeams   CeedOperatorField *opinputfields, *opoutputfields;
51*7d8d0e25Snbeams   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
52*7d8d0e25Snbeams   CeedChk(ierr);
53*7d8d0e25Snbeams   CeedQFunctionField *qfinputfields, *qfoutputfields;
54*7d8d0e25Snbeams   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
55*7d8d0e25Snbeams   CeedChk(ierr);
56*7d8d0e25Snbeams   CeedEvalMode emode;
57*7d8d0e25Snbeams   CeedVector vec, outvecs[16] = {};
58*7d8d0e25Snbeams 
59*7d8d0e25Snbeams   //Creation of the operator
60*7d8d0e25Snbeams   ierr = CeedHipGenOperatorBuild(op); CeedChk(ierr);
61*7d8d0e25Snbeams 
62*7d8d0e25Snbeams   // Input vectors
63*7d8d0e25Snbeams   for (CeedInt i = 0; i < numinputfields; i++) {
64*7d8d0e25Snbeams     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
65*7d8d0e25Snbeams     CeedChk(ierr);
66*7d8d0e25Snbeams     if (emode == CEED_EVAL_WEIGHT) { // Skip
67*7d8d0e25Snbeams       data->fields.in[i] = NULL;
68*7d8d0e25Snbeams     } else {
69*7d8d0e25Snbeams       // Get input vector
70*7d8d0e25Snbeams       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
71*7d8d0e25Snbeams       if (vec == CEED_VECTOR_ACTIVE) vec = invec;
72*7d8d0e25Snbeams       ierr = CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.in[i]);
73*7d8d0e25Snbeams       CeedChk(ierr);
74*7d8d0e25Snbeams     }
75*7d8d0e25Snbeams   }
76*7d8d0e25Snbeams 
77*7d8d0e25Snbeams   // Output vectors
78*7d8d0e25Snbeams   for (CeedInt i = 0; i < numoutputfields; i++) {
79*7d8d0e25Snbeams     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
80*7d8d0e25Snbeams     CeedChk(ierr);
81*7d8d0e25Snbeams     if (emode == CEED_EVAL_WEIGHT) { // Skip
82*7d8d0e25Snbeams       data->fields.out[i] = NULL;
83*7d8d0e25Snbeams     } else {
84*7d8d0e25Snbeams       // Get output vector
85*7d8d0e25Snbeams       ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
86*7d8d0e25Snbeams       if (vec == CEED_VECTOR_ACTIVE) vec = outvec;
87*7d8d0e25Snbeams       outvecs[i] = vec;
88*7d8d0e25Snbeams       // Check for multiple output modes
89*7d8d0e25Snbeams       CeedInt index = -1;
90*7d8d0e25Snbeams       for (CeedInt j = 0; j < i; j++) {
91*7d8d0e25Snbeams         if (vec == outvecs[j]) {
92*7d8d0e25Snbeams           index = j;
93*7d8d0e25Snbeams           break;
94*7d8d0e25Snbeams         }
95*7d8d0e25Snbeams       }
96*7d8d0e25Snbeams       if (index == -1) {
97*7d8d0e25Snbeams         ierr = CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.out[i]);
98*7d8d0e25Snbeams         CeedChk(ierr);
99*7d8d0e25Snbeams       } else {
100*7d8d0e25Snbeams         data->fields.out[i] = data->fields.out[index];
101*7d8d0e25Snbeams       }
102*7d8d0e25Snbeams     }
103*7d8d0e25Snbeams   }
104*7d8d0e25Snbeams 
105*7d8d0e25Snbeams   // Get context data
106*7d8d0e25Snbeams   CeedQFunctionContext ctx;
107*7d8d0e25Snbeams   ierr = CeedQFunctionGetInnerContext(qf, &ctx); CeedChk(ierr);
108*7d8d0e25Snbeams   if (ctx) {
109*7d8d0e25Snbeams     ierr = CeedQFunctionContextGetData(ctx, CEED_MEM_DEVICE, &qf_data->d_c);
110*7d8d0e25Snbeams     CeedChk(ierr);
111*7d8d0e25Snbeams   }
112*7d8d0e25Snbeams 
113*7d8d0e25Snbeams   // Apply operator
114*7d8d0e25Snbeams   void *opargs[] = {(void *) &nelem, &qf_data->d_c, &data->indices,
115*7d8d0e25Snbeams                     &data->fields, &data->B, &data->G, &data->W
116*7d8d0e25Snbeams                    };
117*7d8d0e25Snbeams   const CeedInt dim = data->dim;
118*7d8d0e25Snbeams   const CeedInt Q1d = data->Q1d;
119*7d8d0e25Snbeams   const CeedInt P1d = data->maxP1d;
120*7d8d0e25Snbeams   const CeedInt thread1d = CeedIntMax(Q1d, P1d);
121*7d8d0e25Snbeams   if (dim==1) {
122*7d8d0e25Snbeams     CeedInt elemsPerBlock = 32*thread1d > 256? 256/thread1d : 32;
123*7d8d0e25Snbeams     elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
124*7d8d0e25Snbeams     CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
125*7d8d0e25Snbeams                                            ? 1 : 0 );
126*7d8d0e25Snbeams     CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
127*7d8d0e25Snbeams     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, thread1d, 1,
128*7d8d0e25Snbeams                                      elemsPerBlock, sharedMem, opargs);
129*7d8d0e25Snbeams   } else if (dim==2) {
130*7d8d0e25Snbeams     const CeedInt elemsPerBlock = thread1d<4? 16 : 2;
131*7d8d0e25Snbeams     CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
132*7d8d0e25Snbeams                                            ? 1 : 0 );
133*7d8d0e25Snbeams     CeedInt sharedMem = elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
134*7d8d0e25Snbeams     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, thread1d, thread1d,
135*7d8d0e25Snbeams                                      elemsPerBlock, sharedMem, opargs);
136*7d8d0e25Snbeams   } else if (dim==3) {
137*7d8d0e25Snbeams     const CeedInt elemsPerBlock = thread1d<6? 4 : (thread1d<8? 2 : 1);
138*7d8d0e25Snbeams     CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
139*7d8d0e25Snbeams                                            ? 1 : 0 );
140*7d8d0e25Snbeams     CeedInt sharedMem = elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
141*7d8d0e25Snbeams     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, thread1d, thread1d,
142*7d8d0e25Snbeams                                      elemsPerBlock, sharedMem, opargs);
143*7d8d0e25Snbeams   }
144*7d8d0e25Snbeams   CeedChk(ierr);
145*7d8d0e25Snbeams 
146*7d8d0e25Snbeams   // Restore input arrays
147*7d8d0e25Snbeams   for (CeedInt i = 0; i < numinputfields; i++) {
148*7d8d0e25Snbeams     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
149*7d8d0e25Snbeams     CeedChk(ierr);
150*7d8d0e25Snbeams     if (emode == CEED_EVAL_WEIGHT) { // Skip
151*7d8d0e25Snbeams     } else {
152*7d8d0e25Snbeams       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
153*7d8d0e25Snbeams       if (vec == CEED_VECTOR_ACTIVE) vec = invec;
154*7d8d0e25Snbeams       ierr = CeedVectorRestoreArrayRead(vec, &data->fields.in[i]);
155*7d8d0e25Snbeams       CeedChk(ierr);
156*7d8d0e25Snbeams     }
157*7d8d0e25Snbeams   }
158*7d8d0e25Snbeams 
159*7d8d0e25Snbeams   // Restore output arrays
160*7d8d0e25Snbeams   for (CeedInt i = 0; i < numoutputfields; i++) {
161*7d8d0e25Snbeams     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
162*7d8d0e25Snbeams     CeedChk(ierr);
163*7d8d0e25Snbeams     if (emode == CEED_EVAL_WEIGHT) { // Skip
164*7d8d0e25Snbeams     } else {
165*7d8d0e25Snbeams       ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
166*7d8d0e25Snbeams       if (vec == CEED_VECTOR_ACTIVE) vec = outvec;
167*7d8d0e25Snbeams       // Check for multiple output modes
168*7d8d0e25Snbeams       CeedInt index = -1;
169*7d8d0e25Snbeams       for (CeedInt j = 0; j < i; j++) {
170*7d8d0e25Snbeams         if (vec == outvecs[j]) {
171*7d8d0e25Snbeams           index = j;
172*7d8d0e25Snbeams           break;
173*7d8d0e25Snbeams         }
174*7d8d0e25Snbeams       }
175*7d8d0e25Snbeams       if (index == -1) {
176*7d8d0e25Snbeams         ierr = CeedVectorRestoreArray(vec, &data->fields.out[i]);
177*7d8d0e25Snbeams         CeedChk(ierr);
178*7d8d0e25Snbeams       }
179*7d8d0e25Snbeams     }
180*7d8d0e25Snbeams   }
181*7d8d0e25Snbeams 
182*7d8d0e25Snbeams   // Restore context data
183*7d8d0e25Snbeams   if (ctx) {
184*7d8d0e25Snbeams     ierr = CeedQFunctionContextRestoreData(ctx, &qf_data->d_c);
185*7d8d0e25Snbeams     CeedChk(ierr);
186*7d8d0e25Snbeams   }
187*7d8d0e25Snbeams   return 0;
188*7d8d0e25Snbeams }
189*7d8d0e25Snbeams 
190*7d8d0e25Snbeams //------------------------------------------------------------------------------
191*7d8d0e25Snbeams // Create FDM element inverse not supported
192*7d8d0e25Snbeams //------------------------------------------------------------------------------
193*7d8d0e25Snbeams static int CeedOperatorCreateFDMElementInverse_Hip(CeedOperator op) {
194*7d8d0e25Snbeams   // LCOV_EXCL_START
195*7d8d0e25Snbeams   int ierr;
196*7d8d0e25Snbeams   Ceed ceed;
197*7d8d0e25Snbeams   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
198*7d8d0e25Snbeams   return CeedError(ceed, 1, "Backend does not implement FDM inverse creation");
199*7d8d0e25Snbeams   // LCOV_EXCL_STOP
200*7d8d0e25Snbeams }
201*7d8d0e25Snbeams 
202*7d8d0e25Snbeams //------------------------------------------------------------------------------
203*7d8d0e25Snbeams // Create operator
204*7d8d0e25Snbeams //------------------------------------------------------------------------------
205*7d8d0e25Snbeams int CeedOperatorCreate_Hip_gen(CeedOperator op) {
206*7d8d0e25Snbeams   int ierr;
207*7d8d0e25Snbeams   Ceed ceed;
208*7d8d0e25Snbeams   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
209*7d8d0e25Snbeams   CeedOperator_Hip_gen *impl;
210*7d8d0e25Snbeams 
211*7d8d0e25Snbeams   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
212*7d8d0e25Snbeams   ierr = CeedOperatorSetData(op, impl); CeedChk(ierr);
213*7d8d0e25Snbeams 
214*7d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Operator", op, "CreateFDMElementInverse",
215*7d8d0e25Snbeams                                 CeedOperatorCreateFDMElementInverse_Hip);
216*7d8d0e25Snbeams   CeedChk(ierr);
217*7d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd",
218*7d8d0e25Snbeams                                 CeedOperatorApplyAdd_Hip_gen); CeedChk(ierr);
219*7d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
220*7d8d0e25Snbeams                                 CeedOperatorDestroy_Hip_gen); CeedChk(ierr);
221*7d8d0e25Snbeams   return 0;
222*7d8d0e25Snbeams }
223*7d8d0e25Snbeams //------------------------------------------------------------------------------
224