xref: /petsc/src/dm/dt/fe/impls/opencl/feopencl.c (revision d2522c19e8fa9bca20aaca277941d9a63e71db6a)
1 #include <petsc/private/petscfeimpl.h> /*I "petscfe.h" I*/
2 
3 #if defined(PETSC_HAVE_OPENCL)
4 
5 static PetscErrorCode PetscFEDestroy_OpenCL(PetscFE fem) {
6   PetscFE_OpenCL *ocl = (PetscFE_OpenCL *)fem->data;
7 
8   PetscFunctionBegin;
9   PetscCall(clReleaseCommandQueue(ocl->queue_id));
10   ocl->queue_id = 0;
11   PetscCall(clReleaseContext(ocl->ctx_id));
12   ocl->ctx_id = 0;
13   PetscCall(PetscFree(ocl));
14   PetscFunctionReturn(0);
15 }
16 
/*
  PetscCallSTR - Check the error code of a PetscSNPrintfCount() call that wrote at string_tail,
  then advance string_tail by the number of characters reported in count.

  Must be expanded where string_tail (current write position), end_of_buffer (one past the end
  of the output buffer) and count (the size_t filled in by the call) are in scope.

  The space check is done BEFORE advancing: when the output was truncated, count may exceed the
  remaining space, and advancing first would move string_tail past end_of_buffer, where an
  equality test cannot catch it and the next call would compute a negative (huge size_t) length.
*/
#define PetscCallSTR(err) \
  do { \
    PetscCall(err); \
    PetscCheck(count < (size_t)(end_of_buffer - string_tail), PETSC_COMM_SELF, PETSC_ERR_PLIB, "Buffer overflow"); \
    string_tail += count; \
  } while (0)
/* PDE operators the kernel generator can emit; the selected one is read from PetscFE_OpenCL::op */
enum {
  LAPLACIAN,  /* = 0 */
  ELASTICITY  /* = 1 */
};
27 
/* NOTE: This is now broken for vector problems. Must redo loops to respect vector basis elements */
/* dim     Number of spatial dimensions:          2                   */
/* N_b     Number of basis functions:             generated           */
/* N_{bt}  Number of total basis functions:       N_b * N_{comp}      */
/* N_q     Number of quadrature points:           generated           */
/* N_{bs}  Number of block cells:                 LCM(N_b, N_q)       */
/* N_{bst} Number of block cell components:       LCM(N_{bt}, N_q)    */
/* N_{bl}  Number of concurrent blocks:           generated           */
/* N_t     Number of threads:                     N_{bl} * N_{bs}     */
/* N_{cbc} Number of concurrent basis      cells: N_{bl} * N_q        */
/* N_{cqc} Number of concurrent quadrature cells: N_{bl} * N_b        */
/* N_{sbc} Number of serial     basis      cells: N_{bs} / N_q        */
/* N_{sqc} Number of serial     quadrature cells: N_{bs} / N_b        */
/* N_{cb}  Number of serial cell batches:         input               */
/* N_c     Number of total cells:                 N_{cb}*N_{t}/N_{comp} */
/* (Caution: in the C generator below, the variable N_c holds N_{comp}, the number of components.) */
43 static PetscErrorCode PetscFEOpenCLGenerateIntegrationCode(PetscFE fem, char **string_buffer, PetscInt buffer_length, PetscBool useAux, PetscInt N_bl) {
44   PetscFE_OpenCL  *ocl = (PetscFE_OpenCL *)fem->data;
45   PetscQuadrature  q;
46   char            *string_tail   = *string_buffer;
47   char            *end_of_buffer = *string_buffer + buffer_length;
48   char             float_str[] = "float", double_str[] = "double";
49   char            *numeric_str    = &(float_str[0]);
50   PetscInt         op             = ocl->op;
51   PetscBool        useField       = PETSC_FALSE;
52   PetscBool        useFieldDer    = PETSC_TRUE;
53   PetscBool        useFieldAux    = useAux;
54   PetscBool        useFieldDerAux = PETSC_FALSE;
55   PetscBool        useF0          = PETSC_TRUE;
56   PetscBool        useF1          = PETSC_TRUE;
57   const PetscReal *points, *weights;
58   PetscTabulation  T;
59   PetscInt         dim, qNc, N_b, N_c, N_q, N_t, p, d, b, c;
60   size_t           count;
61 
62   PetscFunctionBegin;
63   PetscCall(PetscFEGetSpatialDimension(fem, &dim));
64   PetscCall(PetscFEGetDimension(fem, &N_b));
65   PetscCall(PetscFEGetNumComponents(fem, &N_c));
66   PetscCall(PetscFEGetQuadrature(fem, &q));
67   PetscCall(PetscQuadratureGetData(q, NULL, &qNc, &N_q, &points, &weights));
68   PetscCheck(qNc == 1, PETSC_COMM_SELF, PETSC_ERR_SUP, "Only supports scalar quadrature, not %" PetscInt_FMT " components", qNc);
69   N_t = N_b * N_c * N_q * N_bl;
70   /* Enable device extension for double precision */
71   if (ocl->realType == PETSC_DOUBLE) {
72     PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
73                                     "#if defined(cl_khr_fp64)\n"
74                                     "#  pragma OPENCL EXTENSION cl_khr_fp64: enable\n"
75                                     "#elif defined(cl_amd_fp64)\n"
76                                     "#  pragma OPENCL EXTENSION cl_amd_fp64: enable\n"
77                                     "#endif\n",
78                                     &count));
79     numeric_str = &(double_str[0]);
80   }
81   /* Kernel API */
82   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
83                                   "\n"
84                                   "__kernel void integrateElementQuadrature(int N_cb, __global %s *coefficients, __global %s *coefficientsAux, __global %s *jacobianInverses, __global %s *jacobianDeterminants, __global %s *elemVec)\n"
85                                   "{\n",
86                                   &count, numeric_str, numeric_str, numeric_str, numeric_str, numeric_str));
87   /* Quadrature */
88   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
89                                   "  /* Quadrature points\n"
90                                   "   - (x1,y1,x2,y2,...) */\n"
91                                   "  const %s points[%d] = {\n",
92                                   &count, numeric_str, N_q * dim));
93   for (p = 0; p < N_q; ++p) {
94     for (d = 0; d < dim; ++d) { PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "%g,\n", &count, points[p * dim + d])); }
95   }
96   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "};\n", &count));
97   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
98                                   "  /* Quadrature weights\n"
99                                   "   - (v1,v2,...) */\n"
100                                   "  const %s weights[%d] = {\n",
101                                   &count, numeric_str, N_q));
102   for (p = 0; p < N_q; ++p) { PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "%g,\n", &count, weights[p])); }
103   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "};\n", &count));
104   /* Basis Functions */
105   PetscCall(PetscFEGetCellTabulation(fem, 1, &T));
106   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
107                                   "  /* Nodal basis function evaluations\n"
108                                   "    - basis component is fastest varying, the basis function, then point */\n"
109                                   "  const %s Basis[%d] = {\n",
110                                   &count, numeric_str, N_q * N_b * N_c));
111   for (p = 0; p < N_q; ++p) {
112     for (b = 0; b < N_b; ++b) {
113       for (c = 0; c < N_c; ++c) { PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "%g,\n", &count, T->T[0][(p * N_b + b) * N_c + c])); }
114     }
115   }
116   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "};\n", &count));
117   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
118                                   "\n"
119                                   "  /* Nodal basis function derivative evaluations,\n"
120                                   "      - derivative direction is fastest varying, then basis component, then basis function, then point */\n"
121                                   "  const %s%d BasisDerivatives[%d] = {\n",
122                                   &count, numeric_str, dim, N_q * N_b * N_c));
123   for (p = 0; p < N_q; ++p) {
124     for (b = 0; b < N_b; ++b) {
125       for (c = 0; c < N_c; ++c) {
126         PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "(%s%d)(", &count, numeric_str, dim));
127         for (d = 0; d < dim; ++d) {
128           if (d > 0) {
129             PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, ", %g", &count, T->T[1][((p * N_b + b) * dim + d) * N_c + c]));
130           } else {
131             PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "%g", &count, T->T[1][((p * N_b + b) * dim + d) * N_c + c]));
132           }
133         }
134         PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "),\n", &count));
135       }
136     }
137   }
138   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "};\n", &count));
139   /* Sizes */
140   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
141                                   "  const int dim    = %d;                           // The spatial dimension\n"
142                                   "  const int N_bl   = %d;                           // The number of concurrent blocks\n"
143                                   "  const int N_b    = %d;                           // The number of basis functions\n"
144                                   "  const int N_comp = %d;                           // The number of basis function components\n"
145                                   "  const int N_bt   = N_b*N_comp;                    // The total number of scalar basis functions\n"
146                                   "  const int N_q    = %d;                           // The number of quadrature points\n"
147                                   "  const int N_bst  = N_bt*N_q;                      // The block size, LCM(N_b*N_comp, N_q), Notice that a block is not processed simultaneously\n"
148                                   "  const int N_t    = N_bst*N_bl;                    // The number of threads, N_bst * N_bl\n"
149                                   "  const int N_bc   = N_t/N_comp;                    // The number of cells per batch (N_b*N_q*N_bl)\n"
150                                   "  const int N_sbc  = N_bst / (N_q * N_comp);\n"
151                                   "  const int N_sqc  = N_bst / N_bt;\n"
152                                   "  /*const int N_c    = N_cb * N_bc;*/\n"
153                                   "\n"
154                                   "  /* Calculated indices */\n"
155                                   "  /*const int tidx    = get_local_id(0) + get_local_size(0)*get_local_id(1);*/\n"
156                                   "  const int tidx    = get_local_id(0);\n"
157                                   "  const int blidx   = tidx / N_bst;                  // Block number for this thread\n"
158                                   "  const int bidx    = tidx %% N_bt;                   // Basis function mapped to this thread\n"
159                                   "  const int cidx    = tidx %% N_comp;                 // Basis component mapped to this thread\n"
160                                   "  const int qidx    = tidx %% N_q;                    // Quadrature point mapped to this thread\n"
161                                   "  const int blbidx  = tidx %% N_q + blidx*N_q;        // Cell mapped to this thread in the basis phase\n"
162                                   "  const int blqidx  = tidx %% N_b + blidx*N_b;        // Cell mapped to this thread in the quadrature phase\n"
163                                   "  const int gidx    = get_group_id(1)*get_num_groups(0) + get_group_id(0);\n"
164                                   "  const int Goffset = gidx*N_cb*N_bc;\n",
165                                   &count, dim, N_bl, N_b, N_c, N_q));
166   /* Local memory */
167   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
168                                   "\n"
169                                   "  /* Quadrature data */\n"
170                                   "  %s                w;                   // $w_q$, Quadrature weight at $x_q$\n"
171                                   "  __local %s         phi_i[%d];    //[N_bt*N_q];  // $\\phi_i(x_q)$, Value of the basis function $i$ at $x_q$\n"
172                                   "  __local %s%d       phiDer_i[%d]; //[N_bt*N_q];  // $\\frac{\\partial\\phi_i(x_q)}{\\partial x_d}$, Value of the derivative of basis function $i$ in direction $x_d$ at $x_q$\n"
173                                   "  /* Geometric data */\n"
174                                   "  __local %s        detJ[%d]; //[N_t];           // $|J(x_q)|$, Jacobian determinant at $x_q$\n"
175                                   "  __local %s        invJ[%d];//[N_t*dim*dim];   // $J^{-1}(x_q)$, Jacobian inverse at $x_q$\n",
176                                   &count, numeric_str, numeric_str, N_b * N_c * N_q, numeric_str, dim, N_b * N_c * N_q, numeric_str, N_t, numeric_str, N_t * dim * dim));
177   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
178                                   "  /* FEM data */\n"
179                                   "  __local %s        u_i[%d]; //[N_t*N_bt];       // Coefficients $u_i$ of the field $u|_{\\mathcal{T}} = \\sum_i u_i \\phi_i$\n",
180                                   &count, numeric_str, N_t * N_b * N_c));
181   if (useAux) {
182     PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "  __local %s        a_i[%d]; //[N_t];            // Coefficients $a_i$ of the auxiliary field $a|_{\\mathcal{T}} = \\sum_i a_i \\phi^R_i$\n", &count, numeric_str, N_t));
183   }
184   if (useF0) {
185     PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
186                                     "  /* Intermediate calculations */\n"
187                                     "  __local %s         f_0[%d]; //[N_t*N_sqc];      // $f_0(u(x_q), \\nabla u(x_q)) |J(x_q)| w_q$\n",
188                                     &count, numeric_str, N_t * N_q));
189   }
190   if (useF1) { PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "  __local %s%d       f_1[%d]; //[N_t*N_sqc];      // $f_1(u(x_q), \\nabla u(x_q)) |J(x_q)| w_q$\n", &count, numeric_str, dim, N_t * N_q)); }
191   /* TODO: If using elasticity, put in mu/lambda coefficients */
192   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
193                                   "  /* Output data */\n"
194                                   "  %s                e_i;                 // Coefficient $e_i$ of the residual\n\n",
195                                   &count, numeric_str));
196   /* One-time loads */
197   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
198                                   "  /* These should be generated inline */\n"
199                                   "  /* Load quadrature weights */\n"
200                                   "  w = weights[qidx];\n"
201                                   "  /* Load basis tabulation \\phi_i for this cell */\n"
202                                   "  if (tidx < N_bt*N_q) {\n"
203                                   "    phi_i[tidx]    = Basis[tidx];\n"
204                                   "    phiDer_i[tidx] = BasisDerivatives[tidx];\n"
205                                   "  }\n\n",
206                                   &count));
207   /* Batch loads */
208   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
209                                   "  for (int batch = 0; batch < N_cb; ++batch) {\n"
210                                   "    /* Load geometry */\n"
211                                   "    detJ[tidx] = jacobianDeterminants[Goffset+batch*N_bc+tidx];\n"
212                                   "    for (int n = 0; n < dim*dim; ++n) {\n"
213                                   "      const int offset = n*N_t;\n"
214                                   "      invJ[offset+tidx] = jacobianInverses[(Goffset+batch*N_bc)*dim*dim+offset+tidx];\n"
215                                   "    }\n"
216                                   "    /* Load coefficients u_i for this cell */\n"
217                                   "    for (int n = 0; n < N_bt; ++n) {\n"
218                                   "      const int offset = n*N_t;\n"
219                                   "      u_i[offset+tidx] = coefficients[(Goffset*N_bt)+batch*N_t*N_b+offset+tidx];\n"
220                                   "    }\n",
221                                   &count));
222   if (useAux) {
223     PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
224                                     "    /* Load coefficients a_i for this cell */\n"
225                                     "    /* TODO: This should not be N_t here, it should be N_bc*N_comp_aux */\n"
226                                     "    a_i[tidx] = coefficientsAux[Goffset+batch*N_t+tidx];\n",
227                                     &count));
228   }
229   /* Quadrature phase */
230   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
231                                   "    barrier(CLK_LOCAL_MEM_FENCE);\n"
232                                   "\n"
233                                   "    /* Map coefficients to values at quadrature points */\n"
234                                   "    for (int c = 0; c < N_sqc; ++c) {\n"
235                                   "      const int cell          = c*N_bl*N_b + blqidx;\n"
236                                   "      const int fidx          = (cell*N_q + qidx)*N_comp + cidx;\n",
237                                   &count));
238   if (useField) { PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      %s  u[%d]; //[N_comp];     // $u(x_q)$, Value of the field at $x_q$\n", &count, numeric_str, N_c)); }
239   if (useFieldDer) { PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      %s%d   gradU[%d]; //[N_comp]; // $\\nabla u(x_q)$, Value of the field gradient at $x_q$\n", &count, numeric_str, dim, N_c)); }
240   if (useFieldAux) { PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      %s  a[%d]; //[1];     // $a(x_q)$, Value of the auxiliary fields at $x_q$\n", &count, numeric_str, 1)); }
241   if (useFieldDerAux) { PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      %s%d   gradA[%d]; //[1]; // $\\nabla a(x_q)$, Value of the auxiliary field gradient at $x_q$\n", &count, numeric_str, dim, 1)); }
242   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
243                                   "\n"
244                                   "      for (int comp = 0; comp < N_comp; ++comp) {\n",
245                                   &count));
246   if (useField) PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "        u[comp] = 0.0;\n", &count));
247   if (useFieldDer) {
248     switch (dim) {
249     case 1: PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "        gradU[comp].x = 0.0;\n", &count)); break;
250     case 2: PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "        gradU[comp].x = 0.0; gradU[comp].y = 0.0;\n", &count)); break;
251     case 3: PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "        gradU[comp].x = 0.0; gradU[comp].y = 0.0; gradU[comp].z = 0.0;\n", &count)); break;
252     }
253   }
254   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      }\n", &count));
255   if (useFieldAux) { PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      a[0] = 0.0;\n", &count)); }
256   if (useFieldDerAux) {
257     switch (dim) {
258     case 1: PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      gradA[0].x = 0.0;\n", &count)); break;
259     case 2: PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      gradA[0].x = 0.0; gradA[0].y = 0.0;\n", &count)); break;
260     case 3: PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      gradA[0].x = 0.0; gradA[0].y = 0.0; gradA[0].z = 0.0;\n", &count)); break;
261     }
262   }
263   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
264                                   "      /* Get field and derivatives at this quadrature point */\n"
265                                   "      for (int i = 0; i < N_b; ++i) {\n"
266                                   "        for (int comp = 0; comp < N_comp; ++comp) {\n"
267                                   "          const int b    = i*N_comp+comp;\n"
268                                   "          const int pidx = qidx*N_bt + b;\n"
269                                   "          const int uidx = cell*N_bt + b;\n"
270                                   "          %s%d   realSpaceDer;\n\n",
271                                   &count, numeric_str, dim));
272   if (useField) PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "          u[comp] += u_i[uidx]*phi_i[pidx];\n", &count));
273   if (useFieldDer) {
274     switch (dim) {
275     case 2:
276       PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
277                                       "          realSpaceDer.x = invJ[cell*dim*dim+0*dim+0]*phiDer_i[pidx].x + invJ[cell*dim*dim+1*dim+0]*phiDer_i[pidx].y;\n"
278                                       "          gradU[comp].x += u_i[uidx]*realSpaceDer.x;\n"
279                                       "          realSpaceDer.y = invJ[cell*dim*dim+0*dim+1]*phiDer_i[pidx].x + invJ[cell*dim*dim+1*dim+1]*phiDer_i[pidx].y;\n"
280                                       "          gradU[comp].y += u_i[uidx]*realSpaceDer.y;\n",
281                                       &count));
282       break;
283     case 3:
284       PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
285                                       "          realSpaceDer.x = invJ[cell*dim*dim+0*dim+0]*phiDer_i[pidx].x + invJ[cell*dim*dim+1*dim+0]*phiDer_i[pidx].y + invJ[cell*dim*dim+2*dim+0]*phiDer_i[pidx].z;\n"
286                                       "          gradU[comp].x += u_i[uidx]*realSpaceDer.x;\n"
287                                       "          realSpaceDer.y = invJ[cell*dim*dim+0*dim+1]*phiDer_i[pidx].x + invJ[cell*dim*dim+1*dim+1]*phiDer_i[pidx].y + invJ[cell*dim*dim+2*dim+1]*phiDer_i[pidx].z;\n"
288                                       "          gradU[comp].y += u_i[uidx]*realSpaceDer.y;\n"
289                                       "          realSpaceDer.z = invJ[cell*dim*dim+0*dim+2]*phiDer_i[pidx].x + invJ[cell*dim*dim+1*dim+2]*phiDer_i[pidx].y + invJ[cell*dim*dim+2*dim+2]*phiDer_i[pidx].z;\n"
290                                       "          gradU[comp].z += u_i[uidx]*realSpaceDer.z;\n",
291                                       &count));
292       break;
293     }
294   }
295   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
296                                   "        }\n"
297                                   "      }\n",
298                                   &count));
299   if (useFieldAux) { PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "          a[0] += a_i[cell];\n", &count)); }
300   /* Calculate residual at quadrature points: Should be generated by an weak form egine */
301   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      /* Process values at quadrature points */\n", &count));
302   switch (op) {
303   case LAPLACIAN:
304     if (useF0) { PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      f_0[fidx] = 4.0;\n", &count)); }
305     if (useF1) {
306       if (useAux) PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      f_1[fidx] = a[0]*gradU[cidx];\n", &count));
307       else PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      f_1[fidx] = gradU[cidx];\n", &count));
308     }
309     break;
310   case ELASTICITY:
311     if (useF0) PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      f_0[fidx] = 4.0;\n", &count));
312     if (useF1) {
313       switch (dim) {
314       case 2:
315         PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
316                                         "      switch (cidx) {\n"
317                                         "      case 0:\n"
318                                         "        f_1[fidx].x = lambda*(gradU[0].x + gradU[1].y) + mu*(gradU[0].x + gradU[0].x);\n"
319                                         "        f_1[fidx].y = lambda*(gradU[0].x + gradU[1].y) + mu*(gradU[0].y + gradU[1].x);\n"
320                                         "        break;\n"
321                                         "      case 1:\n"
322                                         "        f_1[fidx].x = lambda*(gradU[0].x + gradU[1].y) + mu*(gradU[1].x + gradU[0].y);\n"
323                                         "        f_1[fidx].y = lambda*(gradU[0].x + gradU[1].y) + mu*(gradU[1].y + gradU[1].y);\n"
324                                         "      }\n",
325                                         &count));
326         break;
327       case 3:
328         PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
329                                         "      switch (cidx) {\n"
330                                         "      case 0:\n"
331                                         "        f_1[fidx].x = lambda*(gradU[0].x + gradU[1].y + gradU[2].z) + mu*(gradU[0].x + gradU[0].x);\n"
332                                         "        f_1[fidx].y = lambda*(gradU[0].x + gradU[1].y + gradU[2].z) + mu*(gradU[0].y + gradU[1].x);\n"
333                                         "        f_1[fidx].z = lambda*(gradU[0].x + gradU[1].y + gradU[2].z) + mu*(gradU[0].z + gradU[2].x);\n"
334                                         "        break;\n"
335                                         "      case 1:\n"
336                                         "        f_1[fidx].x = lambda*(gradU[0].x + gradU[1].y + gradU[2].z) + mu*(gradU[1].x + gradU[0].y);\n"
337                                         "        f_1[fidx].y = lambda*(gradU[0].x + gradU[1].y + gradU[2].z) + mu*(gradU[1].y + gradU[1].y);\n"
338                                         "        f_1[fidx].z = lambda*(gradU[0].x + gradU[1].y + gradU[2].z) + mu*(gradU[1].y + gradU[2].y);\n"
339                                         "        break;\n"
340                                         "      case 2:\n"
341                                         "        f_1[fidx].x = lambda*(gradU[0].x + gradU[1].y + gradU[2].z) + mu*(gradU[2].x + gradU[0].z);\n"
342                                         "        f_1[fidx].y = lambda*(gradU[0].x + gradU[1].y + gradU[2].z) + mu*(gradU[2].y + gradU[1].z);\n"
343                                         "        f_1[fidx].z = lambda*(gradU[0].x + gradU[1].y + gradU[2].z) + mu*(gradU[2].y + gradU[2].z);\n"
344                                         "      }\n",
345                                         &count));
346         break;
347       }
348     }
349     break;
350   default: SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "PDE operator %d is not supported", op);
351   }
352   if (useF0) PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      f_0[fidx] *= detJ[cell]*w;\n", &count));
353   if (useF1) {
354     switch (dim) {
355     case 1: PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      f_1[fidx].x *= detJ[cell]*w;\n", &count)); break;
356     case 2: PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      f_1[fidx].x *= detJ[cell]*w; f_1[fidx].y *= detJ[cell]*w;\n", &count)); break;
357     case 3: PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "      f_1[fidx].x *= detJ[cell]*w; f_1[fidx].y *= detJ[cell]*w; f_1[fidx].z *= detJ[cell]*w;\n", &count)); break;
358     }
359   }
360   /* Thread transpose */
361   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
362                                   "    }\n\n"
363                                   "    /* ==== TRANSPOSE THREADS ==== */\n"
364                                   "    barrier(CLK_LOCAL_MEM_FENCE);\n\n",
365                                   &count));
366   /* Basis phase */
367   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
368                                   "    /* Map values at quadrature points to coefficients */\n"
369                                   "    for (int c = 0; c < N_sbc; ++c) {\n"
370                                   "      const int cell = c*N_bl*N_q + blbidx; /* Cell number in batch */\n"
371                                   "\n"
372                                   "      e_i = 0.0;\n"
373                                   "      for (int q = 0; q < N_q; ++q) {\n"
374                                   "        const int pidx = q*N_bt + bidx;\n"
375                                   "        const int fidx = (cell*N_q + q)*N_comp + cidx;\n"
376                                   "        %s%d   realSpaceDer;\n\n",
377                                   &count, numeric_str, dim));
378 
379   if (useF0) PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail, "        e_i += phi_i[pidx]*f_0[fidx];\n", &count));
380   if (useF1) {
381     switch (dim) {
382     case 2:
383       PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
384                                       "        realSpaceDer.x = invJ[cell*dim*dim+0*dim+0]*phiDer_i[pidx].x + invJ[cell*dim*dim+1*dim+0]*phiDer_i[pidx].y;\n"
385                                       "        e_i           += realSpaceDer.x*f_1[fidx].x;\n"
386                                       "        realSpaceDer.y = invJ[cell*dim*dim+0*dim+1]*phiDer_i[pidx].x + invJ[cell*dim*dim+1*dim+1]*phiDer_i[pidx].y;\n"
387                                       "        e_i           += realSpaceDer.y*f_1[fidx].y;\n",
388                                       &count));
389       break;
390     case 3:
391       PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
392                                       "        realSpaceDer.x = invJ[cell*dim*dim+0*dim+0]*phiDer_i[pidx].x + invJ[cell*dim*dim+1*dim+0]*phiDer_i[pidx].y + invJ[cell*dim*dim+2*dim+0]*phiDer_i[pidx].z;\n"
393                                       "        e_i           += realSpaceDer.x*f_1[fidx].x;\n"
394                                       "        realSpaceDer.y = invJ[cell*dim*dim+0*dim+1]*phiDer_i[pidx].x + invJ[cell*dim*dim+1*dim+1]*phiDer_i[pidx].y + invJ[cell*dim*dim+2*dim+1]*phiDer_i[pidx].z;\n"
395                                       "        e_i           += realSpaceDer.y*f_1[fidx].y;\n"
396                                       "        realSpaceDer.z = invJ[cell*dim*dim+0*dim+2]*phiDer_i[pidx].x + invJ[cell*dim*dim+1*dim+2]*phiDer_i[pidx].y + invJ[cell*dim*dim+2*dim+2]*phiDer_i[pidx].z;\n"
397                                       "        e_i           += realSpaceDer.z*f_1[fidx].z;\n",
398                                       &count));
399       break;
400     }
401   }
402   PetscCallSTR(PetscSNPrintfCount(string_tail, end_of_buffer - string_tail,
403                                   "      }\n"
404                                   "      /* Write element vector for N_{cbc} cells at a time */\n"
405                                   "      elemVec[(Goffset + batch*N_bc + c*N_bl*N_q)*N_bt + tidx] = e_i;\n"
406                                   "    }\n"
407                                   "    /* ==== Could do one write per batch ==== */\n"
408                                   "  }\n"
409                                   "  return;\n"
410                                   "}\n",
411                                   &count));
412   PetscFunctionReturn(0);
413 }
414 
415 static PetscErrorCode PetscFEOpenCLGetIntegrationKernel(PetscFE fem, PetscBool useAux, cl_program *ocl_prog, cl_kernel *ocl_kernel) {
416   PetscFE_OpenCL *ocl = (PetscFE_OpenCL *)fem->data;
417   PetscInt        dim, N_bl;
418   PetscBool       flg;
419   char           *buffer;
420   size_t          len;
421   char            errMsg[8192];
422   cl_int          err;
423 
424   PetscFunctionBegin;
425   PetscCall(PetscFEGetSpatialDimension(fem, &dim));
426   PetscCall(PetscMalloc1(8192, &buffer));
427   PetscCall(PetscFEGetTileSizes(fem, NULL, &N_bl, NULL, NULL));
428   PetscCall(PetscFEOpenCLGenerateIntegrationCode(fem, &buffer, 8192, useAux, N_bl));
429   PetscCall(PetscOptionsHasName(((PetscObject)fem)->options, ((PetscObject)fem)->prefix, "-petscfe_opencl_kernel_print", &flg));
430   if (flg) PetscCall(PetscPrintf(PetscObjectComm((PetscObject)fem), "OpenCL FE Integration Kernel:\n%s\n", buffer));
431   PetscCall(PetscStrlen(buffer, &len));
432   *ocl_prog = clCreateProgramWithSource(ocl->ctx_id, 1, (const char **)&buffer, &len, &err);
433   PetscCall(err);
434   err = clBuildProgram(*ocl_prog, 0, NULL, NULL, NULL, NULL);
435   if (err != CL_SUCCESS) {
436     err = clGetProgramBuildInfo(*ocl_prog, ocl->dev_id, CL_PROGRAM_BUILD_LOG, 8192 * sizeof(char), &errMsg, NULL);
437     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Build failed! Log:\n %s", errMsg);
438   }
439   PetscCall(PetscFree(buffer));
440   *ocl_kernel = clCreateKernel(*ocl_prog, "integrateElementQuadrature", &err);
441   PetscFunctionReturn(0);
442 }
443 
444 static PetscErrorCode PetscFEOpenCLCalculateGrid(PetscFE fem, PetscInt N, PetscInt blockSize, size_t *x, size_t *y, size_t *z) {
445   const PetscInt Nblocks = N / blockSize;
446 
447   PetscFunctionBegin;
448   PetscCheck(!(N % blockSize), PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid block size %d for %d elements", blockSize, N);
449   *z = 1;
450   *y = 1;
451   for (*x = (size_t)(PetscSqrtReal(Nblocks) + 0.5); *x > 0; --*x) {
452     *y = Nblocks / *x;
453     if (*x * *y == (size_t)Nblocks) break;
454   }
455   PetscCheck(*x * *y == (size_t)Nblocks, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Could not find partition for %" PetscInt_FMT " with block size %" PetscInt_FMT, N, blockSize);
456   PetscFunctionReturn(0);
457 }
458 
459 static PetscErrorCode PetscFEOpenCLLogResidual(PetscFE fem, PetscLogDouble time, PetscLogDouble flops) {
460   PetscFE_OpenCL   *ocl = (PetscFE_OpenCL *)fem->data;
461   PetscStageLog     stageLog;
462   PetscEventPerfLog eventLog = NULL;
463   int               stage;
464 
465   PetscFunctionBegin;
466   PetscCall(PetscLogGetStageLog(&stageLog));
467   PetscCall(PetscStageLogGetCurrent(stageLog, &stage));
468   PetscCall(PetscStageLogGetEventPerfLog(stageLog, stage, &eventLog));
469   /* Log performance info */
470   eventLog->eventInfo[ocl->residualEvent].count++;
471   eventLog->eventInfo[ocl->residualEvent].time += time;
472   eventLog->eventInfo[ocl->residualEvent].flops += flops;
473   PetscFunctionReturn(0);
474 }
475 
476 static PetscErrorCode PetscFEIntegrateResidual_OpenCL(PetscDS prob, PetscFormKey key, PetscInt Ne, PetscFEGeom *cgeom, const PetscScalar coefficients[], const PetscScalar coefficients_t[], PetscDS probAux, const PetscScalar coefficientsAux[], PetscReal t, PetscScalar elemVec[]) {
477   /* Nbc = batchSize */
478   PetscFE          fem;
479   PetscFE_OpenCL  *ocl;
480   PetscPointFunc   f0_func;
481   PetscPointFunc   f1_func;
482   PetscQuadrature  q;
483   PetscInt         dim, qNc;
484   PetscInt         N_b;    /* The number of basis functions */
485   PetscInt         N_comp; /* The number of basis function components */
486   PetscInt         N_bt;   /* The total number of scalar basis functions */
487   PetscInt         N_q;    /* The number of quadrature points */
488   PetscInt         N_bst;  /* The block size, LCM(N_bt, N_q), Notice that a block is not process simultaneously */
489   PetscInt         N_t;    /* The number of threads, N_bst * N_bl */
490   PetscInt         N_bl;   /* The number of blocks */
491   PetscInt         N_bc;   /* The batch size, N_bl*N_q*N_b */
492   PetscInt         N_cb;   /* The number of batches */
493   const PetscInt   field = key.field;
494   PetscInt         numFlops, f0Flops = 0, f1Flops = 0;
495   PetscBool        useAux      = probAux ? PETSC_TRUE : PETSC_FALSE;
496   PetscBool        useField    = PETSC_FALSE;
497   PetscBool        useFieldDer = PETSC_TRUE;
498   PetscBool        useF0       = PETSC_TRUE;
499   PetscBool        useF1       = PETSC_TRUE;
500   /* OpenCL variables */
501   cl_program       ocl_prog;
502   cl_kernel        ocl_kernel;
503   cl_event         ocl_ev;   /* The event for tracking kernel execution */
504   cl_ulong         ns_start; /* Nanoseconds counter on GPU at kernel start */
505   cl_ulong         ns_end;   /* Nanoseconds counter on GPU at kernel stop */
506   cl_mem           o_jacobianInverses, o_jacobianDeterminants;
507   cl_mem           o_coefficients, o_coefficientsAux, o_elemVec;
508   float           *f_coeff = NULL, *f_coeffAux = NULL, *f_invJ = NULL, *f_detJ = NULL;
509   double          *d_coeff = NULL, *d_coeffAux = NULL, *d_invJ = NULL, *d_detJ = NULL;
510   PetscReal       *r_invJ = NULL, *r_detJ = NULL;
511   void            *oclCoeff, *oclCoeffAux, *oclInvJ, *oclDetJ;
512   size_t           local_work_size[3], global_work_size[3];
513   size_t           realSize, x, y, z;
514   const PetscReal *points, *weights;
515   int              err;
516 
517   PetscFunctionBegin;
518   PetscCall(PetscDSGetDiscretization(prob, field, (PetscObject *)&fem));
519   ocl = (PetscFE_OpenCL *)fem->data;
520   if (!Ne) {
521     PetscCall(PetscFEOpenCLLogResidual(fem, 0.0, 0.0));
522     PetscFunctionReturn(0);
523   }
524   PetscCall(PetscFEGetSpatialDimension(fem, &dim));
525   PetscCall(PetscFEGetQuadrature(fem, &q));
526   PetscCall(PetscQuadratureGetData(q, NULL, &qNc, &N_q, &points, &weights));
527   PetscCheck(qNc == 1, PETSC_COMM_SELF, PETSC_ERR_SUP, "Only supports scalar quadrature, not %" PetscInt_FMT " components", qNc);
528   PetscCall(PetscFEGetDimension(fem, &N_b));
529   PetscCall(PetscFEGetNumComponents(fem, &N_comp));
530   PetscCall(PetscDSGetResidual(prob, field, &f0_func, &f1_func));
531   PetscCall(PetscFEGetTileSizes(fem, NULL, &N_bl, &N_bc, &N_cb));
532   N_bt  = N_b * N_comp;
533   N_bst = N_bt * N_q;
534   N_t   = N_bst * N_bl;
535   PetscCheck(N_bc * N_comp == N_t, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Number of threads %d should be %d * %d", N_t, N_bc, N_comp);
536   /* Calculate layout */
537   if (Ne % (N_cb * N_bc)) { /* Remainder cells */
538     PetscCall(PetscFEIntegrateResidual_Basic(prob, key, Ne, cgeom, coefficients, coefficients_t, probAux, coefficientsAux, t, elemVec));
539     PetscFunctionReturn(0);
540   }
541   PetscCall(PetscFEOpenCLCalculateGrid(fem, Ne, N_cb * N_bc, &x, &y, &z));
542   local_work_size[0]  = N_bc * N_comp;
543   local_work_size[1]  = 1;
544   local_work_size[2]  = 1;
545   global_work_size[0] = x * local_work_size[0];
546   global_work_size[1] = y * local_work_size[1];
547   global_work_size[2] = z * local_work_size[2];
548   PetscCall(PetscInfo(fem, "GPU layout grid(%zu,%zu,%zu) block(%zu,%zu,%zu) with %d batches\n", x, y, z, local_work_size[0], local_work_size[1], local_work_size[2], N_cb));
549   PetscCall(PetscInfo(fem, " N_t: %d, N_cb: %d\n", N_t, N_cb));
550   /* Generate code */
551   if (probAux) {
552     PetscSpace P;
553     PetscInt   NfAux, order, f;
554 
555     PetscCall(PetscDSGetNumFields(probAux, &NfAux));
556     for (f = 0; f < NfAux; ++f) {
557       PetscFE feAux;
558 
559       PetscCall(PetscDSGetDiscretization(probAux, f, (PetscObject *)&feAux));
560       PetscCall(PetscFEGetBasisSpace(feAux, &P));
561       PetscCall(PetscSpaceGetDegree(P, &order, NULL));
562       PetscCheck(order <= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Can only handle P0 coefficient fields");
563     }
564   }
565   PetscCall(PetscFEOpenCLGetIntegrationKernel(fem, useAux, &ocl_prog, &ocl_kernel));
566   /* Create buffers on the device and send data over */
567   PetscCall(PetscDataTypeGetSize(ocl->realType, &realSize));
568   PetscCheck(cgeom->numPoints <= 1, PETSC_COMM_SELF, PETSC_ERR_SUP, "Only support affine geometry for OpenCL integration right now");
569   if (sizeof(PetscReal) != realSize) {
570     switch (ocl->realType) {
571     case PETSC_FLOAT: {
572       PetscInt c, b, d;
573 
574       PetscCall(PetscMalloc4(Ne * N_bt, &f_coeff, Ne, &f_coeffAux, Ne * dim * dim, &f_invJ, Ne, &f_detJ));
575       for (c = 0; c < Ne; ++c) {
576         f_detJ[c] = (float)cgeom->detJ[c];
577         for (d = 0; d < dim * dim; ++d) { f_invJ[c * dim * dim + d] = (float)cgeom->invJ[c * dim * dim + d]; }
578         for (b = 0; b < N_bt; ++b) { f_coeff[c * N_bt + b] = (float)coefficients[c * N_bt + b]; }
579       }
580       if (coefficientsAux) { /* Assume P0 */
581         for (c = 0; c < Ne; ++c) { f_coeffAux[c] = (float)coefficientsAux[c]; }
582       }
583       oclCoeff = (void *)f_coeff;
584       if (coefficientsAux) {
585         oclCoeffAux = (void *)f_coeffAux;
586       } else {
587         oclCoeffAux = NULL;
588       }
589       oclInvJ = (void *)f_invJ;
590       oclDetJ = (void *)f_detJ;
591     } break;
592     case PETSC_DOUBLE: {
593       PetscInt c, b, d;
594 
595       PetscCall(PetscMalloc4(Ne * N_bt, &d_coeff, Ne, &d_coeffAux, Ne * dim * dim, &d_invJ, Ne, &d_detJ));
596       for (c = 0; c < Ne; ++c) {
597         d_detJ[c] = (double)cgeom->detJ[c];
598         for (d = 0; d < dim * dim; ++d) { d_invJ[c * dim * dim + d] = (double)cgeom->invJ[c * dim * dim + d]; }
599         for (b = 0; b < N_bt; ++b) { d_coeff[c * N_bt + b] = (double)coefficients[c * N_bt + b]; }
600       }
601       if (coefficientsAux) { /* Assume P0 */
602         for (c = 0; c < Ne; ++c) { d_coeffAux[c] = (double)coefficientsAux[c]; }
603       }
604       oclCoeff = (void *)d_coeff;
605       if (coefficientsAux) {
606         oclCoeffAux = (void *)d_coeffAux;
607       } else {
608         oclCoeffAux = NULL;
609       }
610       oclInvJ = (void *)d_invJ;
611       oclDetJ = (void *)d_detJ;
612     } break;
613     default: SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Unsupported PETSc type %d", ocl->realType);
614     }
615   } else {
616     PetscInt c, d;
617 
618     PetscCall(PetscMalloc2(Ne * dim * dim, &r_invJ, Ne, &r_detJ));
619     for (c = 0; c < Ne; ++c) {
620       r_detJ[c] = cgeom->detJ[c];
621       for (d = 0; d < dim * dim; ++d) { r_invJ[c * dim * dim + d] = cgeom->invJ[c * dim * dim + d]; }
622     }
623     oclCoeff    = (void *)coefficients;
624     oclCoeffAux = (void *)coefficientsAux;
625     oclInvJ     = (void *)r_invJ;
626     oclDetJ     = (void *)r_detJ;
627   }
628   o_coefficients = clCreateBuffer(ocl->ctx_id, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, Ne * N_bt * realSize, oclCoeff, &err);
629   if (coefficientsAux) {
630     o_coefficientsAux = clCreateBuffer(ocl->ctx_id, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, Ne * realSize, oclCoeffAux, &err);
631   } else {
632     o_coefficientsAux = clCreateBuffer(ocl->ctx_id, CL_MEM_READ_ONLY, Ne * realSize, oclCoeffAux, &err);
633   }
634   o_jacobianInverses     = clCreateBuffer(ocl->ctx_id, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, Ne * dim * dim * realSize, oclInvJ, &err);
635   o_jacobianDeterminants = clCreateBuffer(ocl->ctx_id, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, Ne * realSize, oclDetJ, &err);
636   o_elemVec              = clCreateBuffer(ocl->ctx_id, CL_MEM_WRITE_ONLY, Ne * N_bt * realSize, NULL, &err);
637   /* Kernel launch */
638   PetscCall(clSetKernelArg(ocl_kernel, 0, sizeof(cl_int), (void *)&N_cb));
639   PetscCall(clSetKernelArg(ocl_kernel, 1, sizeof(cl_mem), (void *)&o_coefficients));
640   PetscCall(clSetKernelArg(ocl_kernel, 2, sizeof(cl_mem), (void *)&o_coefficientsAux));
641   PetscCall(clSetKernelArg(ocl_kernel, 3, sizeof(cl_mem), (void *)&o_jacobianInverses));
642   PetscCall(clSetKernelArg(ocl_kernel, 4, sizeof(cl_mem), (void *)&o_jacobianDeterminants));
643   PetscCall(clSetKernelArg(ocl_kernel, 5, sizeof(cl_mem), (void *)&o_elemVec));
644   PetscCall(clEnqueueNDRangeKernel(ocl->queue_id, ocl_kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &ocl_ev));
645   /* Read data back from device */
646   if (sizeof(PetscReal) != realSize) {
647     switch (ocl->realType) {
648     case PETSC_FLOAT: {
649       float   *elem;
650       PetscInt c, b;
651 
652       PetscCall(PetscFree4(f_coeff, f_coeffAux, f_invJ, f_detJ));
653       PetscCall(PetscMalloc1(Ne * N_bt, &elem));
654       PetscCall(clEnqueueReadBuffer(ocl->queue_id, o_elemVec, CL_TRUE, 0, Ne * N_bt * realSize, elem, 0, NULL, NULL));
655       for (c = 0; c < Ne; ++c) {
656         for (b = 0; b < N_bt; ++b) { elemVec[c * N_bt + b] = (PetscScalar)elem[c * N_bt + b]; }
657       }
658       PetscCall(PetscFree(elem));
659     } break;
660     case PETSC_DOUBLE: {
661       double  *elem;
662       PetscInt c, b;
663 
664       PetscCall(PetscFree4(d_coeff, d_coeffAux, d_invJ, d_detJ));
665       PetscCall(PetscMalloc1(Ne * N_bt, &elem));
666       PetscCall(clEnqueueReadBuffer(ocl->queue_id, o_elemVec, CL_TRUE, 0, Ne * N_bt * realSize, elem, 0, NULL, NULL));
667       for (c = 0; c < Ne; ++c) {
668         for (b = 0; b < N_bt; ++b) { elemVec[c * N_bt + b] = (PetscScalar)elem[c * N_bt + b]; }
669       }
670       PetscCall(PetscFree(elem));
671     } break;
672     default: SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Unsupported PETSc type %d", ocl->realType);
673     }
674   } else {
675     PetscCall(PetscFree2(r_invJ, r_detJ));
676     PetscCall(clEnqueueReadBuffer(ocl->queue_id, o_elemVec, CL_TRUE, 0, Ne * N_bt * realSize, elemVec, 0, NULL, NULL));
677   }
678   /* Log performance */
679   PetscCall(clGetEventProfilingInfo(ocl_ev, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &ns_start, NULL));
680   PetscCall(clGetEventProfilingInfo(ocl_ev, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &ns_end, NULL));
681   f0Flops = 0;
682   switch (ocl->op) {
683   case LAPLACIAN: f1Flops = useAux ? dim : 0; break;
684   case ELASTICITY: f1Flops = 2 * dim * dim; break;
685   }
686   numFlops = Ne * (N_q * (N_b * N_comp * ((useField ? 2 : 0) + (useFieldDer ? 2 * dim * (dim + 1) : 0))
687                           /*+
688        N_ba*N_compa*((useFieldAux ? 2 : 0) + (useFieldDerAux ? 2*dim*(dim + 1) : 0))*/
689                           + N_comp * ((useF0 ? f0Flops + 2 : 0) + (useF1 ? f1Flops + 2 * dim : 0))) +
690                    N_b * ((useF0 ? 2 : 0) + (useF1 ? 2 * dim * (dim + 1) : 0)));
691   PetscCall(PetscFEOpenCLLogResidual(fem, (ns_end - ns_start) * 1.0e-9, numFlops));
692   /* Cleanup */
693   PetscCall(clReleaseMemObject(o_coefficients));
694   PetscCall(clReleaseMemObject(o_coefficientsAux));
695   PetscCall(clReleaseMemObject(o_jacobianInverses));
696   PetscCall(clReleaseMemObject(o_jacobianDeterminants));
697   PetscCall(clReleaseMemObject(o_elemVec));
698   PetscCall(clReleaseKernel(ocl_kernel));
699   PetscCall(clReleaseProgram(ocl_prog));
700   PetscFunctionReturn(0);
701 }
702 
703 PETSC_INTERN PetscErrorCode PetscFESetUp_Basic(PetscFE);
704 PETSC_INTERN PetscErrorCode PetscFECreateTabulation_Basic(PetscFE, PetscInt, const PetscReal[], PetscInt, PetscTabulation);
705 
706 static PetscErrorCode PetscFEInitialize_OpenCL(PetscFE fem) {
707   PetscFunctionBegin;
708   fem->ops->setfromoptions          = NULL;
709   fem->ops->setup                   = PetscFESetUp_Basic;
710   fem->ops->view                    = NULL;
711   fem->ops->destroy                 = PetscFEDestroy_OpenCL;
712   fem->ops->getdimension            = PetscFEGetDimension_Basic;
713   fem->ops->createtabulation        = PetscFECreateTabulation_Basic;
714   fem->ops->integrateresidual       = PetscFEIntegrateResidual_OpenCL;
715   fem->ops->integratebdresidual     = NULL /* PetscFEIntegrateBdResidual_OpenCL */;
716   fem->ops->integratejacobianaction = NULL /* PetscFEIntegrateJacobianAction_OpenCL */;
717   fem->ops->integratejacobian       = PetscFEIntegrateJacobian_Basic;
718   PetscFunctionReturn(0);
719 }
720 
721 /*MC
722   PETSCFEOPENCL = "opencl" - A PetscFE object that integrates using a vectorized OpenCL implementation
723 
724   Level: intermediate
725 
726 .seealso: `PetscFEType`, `PetscFECreate()`, `PetscFESetType()`
727 M*/
728 
729 PETSC_EXTERN PetscErrorCode PetscFECreate_OpenCL(PetscFE fem) {
730   PetscFE_OpenCL *ocl;
731   cl_uint         num_platforms;
732   cl_platform_id  platform_ids[42];
733   cl_uint         num_devices;
734   cl_device_id    device_ids[42];
735   cl_int          err;
736 
737   PetscFunctionBegin;
738   PetscValidHeaderSpecific(fem, PETSCFE_CLASSID, 1);
739   PetscCall(PetscNewLog(fem, &ocl));
740   fem->data = ocl;
741 
742   /* Init Platform */
743   PetscCall(clGetPlatformIDs(42, platform_ids, &num_platforms));
744   PetscCheck(num_platforms, PetscObjectComm((PetscObject)fem), PETSC_ERR_SUP, "No OpenCL platform found.");
745   ocl->pf_id = platform_ids[0];
746   /* Init Device */
747   PetscCall(clGetDeviceIDs(ocl->pf_id, CL_DEVICE_TYPE_ALL, 42, device_ids, &num_devices));
748   PetscCheck(num_devices, PetscObjectComm((PetscObject)fem), PETSC_ERR_SUP, "No OpenCL device found.");
749   ocl->dev_id = device_ids[0];
750   /* Create context with one command queue */
751   ocl->ctx_id = clCreateContext(0, 1, &(ocl->dev_id), NULL, NULL, &err);
752   PetscCall(err);
753   ocl->queue_id = clCreateCommandQueue(ocl->ctx_id, ocl->dev_id, CL_QUEUE_PROFILING_ENABLE, &err);
754   PetscCall(err);
755   /* Types */
756   ocl->realType = PETSC_FLOAT;
757   /* Register events */
758   PetscCall(PetscLogEventRegister("OpenCL FEResidual", PETSCFE_CLASSID, &ocl->residualEvent));
759   /* Equation handling */
760   ocl->op = LAPLACIAN;
761 
762   PetscCall(PetscFEInitialize_OpenCL(fem));
763   PetscFunctionReturn(0);
764 }
765 
766 /*@
767   PetscFEOpenCLSetRealType - Set the scalar type for running on the accelerator
768 
769   Input Parameters:
770 + fem      - The PetscFE
771 - realType - The scalar type
772 
773   Level: developer
774 
775 .seealso: `PetscFEOpenCLGetRealType()`
776 @*/
777 PetscErrorCode PetscFEOpenCLSetRealType(PetscFE fem, PetscDataType realType) {
778   PetscFE_OpenCL *ocl = (PetscFE_OpenCL *)fem->data;
779 
780   PetscFunctionBegin;
781   PetscValidHeaderSpecific(fem, PETSCFE_CLASSID, 1);
782   ocl->realType = realType;
783   PetscFunctionReturn(0);
784 }
785 
786 /*@
787   PetscFEOpenCLGetRealType - Get the scalar type for running on the accelerator
788 
789   Input Parameter:
790 . fem      - The PetscFE
791 
792   Output Parameter:
793 . realType - The scalar type
794 
795   Level: developer
796 
797 .seealso: `PetscFEOpenCLSetRealType()`
798 @*/
799 PetscErrorCode PetscFEOpenCLGetRealType(PetscFE fem, PetscDataType *realType) {
800   PetscFE_OpenCL *ocl = (PetscFE_OpenCL *)fem->data;
801 
802   PetscFunctionBegin;
803   PetscValidHeaderSpecific(fem, PETSCFE_CLASSID, 1);
804   PetscValidPointer(realType, 2);
805   *realType = ocl->realType;
806   PetscFunctionReturn(0);
807 }
808 
809 #endif /* PETSC_HAVE_OPENCL */
810