xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-qfunction-load.cpp (revision d92fedf5b7546cf2fc50391dbcfb657a2e1f0a3b)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed/ceed.h>
18 #include <ceed/backend.h>
19 #include <iostream>
20 #include <sstream>
21 #include <string.h>
22 #include "ceed-cuda-ref.h"
23 #include "../cuda/ceed-cuda-compile.h"
24 
25 static const char *qReadWrite = QUOTE(
26 template <int SIZE>
27 //------------------------------------------------------------------------------
28 // Read from quadrature points
29 //------------------------------------------------------------------------------
30 inline __device__ void readQuads(const CeedInt quad, const CeedInt nquads, const CeedScalar* d_u, CeedScalar* r_u) {
31   for(CeedInt comp = 0; comp < SIZE; ++comp) {
32     r_u[comp] = d_u[quad + nquads * comp];
33   }
34 }
35 
36 //------------------------------------------------------------------------------
37 // Write at quadrature points
38 //------------------------------------------------------------------------------
39 template <int SIZE>
40 inline __device__ void writeQuads(const CeedInt quad, const CeedInt nquads, const CeedScalar* r_v, CeedScalar* d_v) {
41   for(CeedInt comp = 0; comp < SIZE; ++comp) {
42     d_v[quad + nquads * comp] = r_v[comp];
43   }
44 }
45 );
46 
47 //------------------------------------------------------------------------------
48 // Build QFunction kernel
49 //------------------------------------------------------------------------------
50 extern "C" int CeedCudaBuildQFunction(CeedQFunction qf) {
51   CeedInt ierr;
52   using std::ostringstream;
53   using std::string;
54   Ceed ceed;
55   CeedQFunctionGetCeed(qf, &ceed);
56   CeedQFunction_Cuda *data;
57   ierr = CeedQFunctionGetData(qf, (void **)&data); CeedChkBackend(ierr);
58   // QFunction is built
59   if (data->qFunction)
60     return CEED_ERROR_SUCCESS;
61   if (!data->qFunctionSource)
62     return CeedError(ceed, CEED_ERROR_BACKEND, "No QFunction source or CUfunction provided.");
63 
64   // QFunction kernel generation
65   CeedInt numinputfields, numoutputfields, size;
66   CeedQFunctionField *qfinputfields, *qfoutputfields;
67   ierr = CeedQFunctionGetFields(qf, &numinputfields, &qfinputfields, &numoutputfields, &qfoutputfields);
68   CeedChkBackend(ierr);
69 
70   // Build strings for final kernel
71   string qFunction(data->qFunctionSource);
72   string qReadWriteS(qReadWrite);
73   ostringstream code;
74   string qFunctionName(data->qFunctionName);
75   string kernelName;
76   kernelName = "CeedKernel_Cuda_ref_" + qFunctionName;
77 
78   // Defintions
79   code << "\n#define CEED_QFUNCTION(name) inline __device__ int name\n";
80   code << "#define CEED_QFUNCTION_HELPER inline __device__\n";
81   code << "#define CeedPragmaSIMD\n";
82   code << "#define CEED_ERROR_SUCCESS 0\n";
83   code << "#define CEED_Q_VLA 1\n\n";
84   code << "typedef struct { const CeedScalar* inputs[16]; CeedScalar* outputs[16]; } Fields_Cuda;\n";
85   code << qReadWriteS;
86   code << qFunction;
87   code << "extern \"C\" __global__ void " << kernelName << "(void *ctx, CeedInt Q, Fields_Cuda fields) {\n";
88 
89   // Inputs
90   for (CeedInt i = 0; i < numinputfields; i++) {
91     code << "// Input field "<<i<<"\n";
92     ierr = CeedQFunctionFieldGetSize(qfinputfields[i], &size); CeedChkBackend(ierr);
93     code << "  const CeedInt size_in_"<<i<<" = "<<size<<";\n";
94     code << "  CeedScalar r_q"<<i<<"[size_in_"<<i<<"];\n";
95   }
96 
97   // Outputs
98   for (CeedInt i = 0; i < numoutputfields; i++) {
99     code << "// Output field "<<i<<"\n";
100     ierr = CeedQFunctionFieldGetSize(qfoutputfields[i], &size); CeedChkBackend(ierr);
101     code << "  const CeedInt size_out_"<<i<<" = "<<size<<";\n";
102     code << "  CeedScalar r_qq"<<i<<"[size_out_"<<i<<"];\n";
103   }
104 
105   // Setup input/output arrays
106   code << "  const CeedScalar* in["<<numinputfields<<"];\n";
107   for (CeedInt i = 0; i < numinputfields; i++) {
108     code << "    in["<<i<<"] = r_q"<<i<<";\n";
109   }
110   code << "  CeedScalar* out["<<numoutputfields<<"];\n";
111   for (CeedInt i = 0; i < numoutputfields; i++) {
112     code << "    out["<<i<<"] = r_qq"<<i<<";\n";
113   }
114 
115   // Loop over quadrature points
116   code << "  for (CeedInt q = blockIdx.x * blockDim.x + threadIdx.x; q < Q; q += blockDim.x * gridDim.x) {\n";
117 
118   // Load inputs
119   for (CeedInt i = 0; i < numinputfields; i++) {
120     code << "// Input field "<<i<<"\n";
121     code << "  readQuads<size_in_"<<i<<">(q, Q, fields.inputs["<<i<<"], r_q"<<i<<");\n";
122   }
123   // QFunction
124   code << "// QFunction\n";
125   code << "    "<<qFunctionName<<"(ctx, 1, in, out);\n";
126 
127   // Write outputs
128   for (CeedInt i = 0; i < numoutputfields; i++) {
129     code << "// Output field "<<i<<"\n";
130     code << "  writeQuads<size_out_"<<i<<">(q, Q, r_qq"<<i<<", fields.outputs["<<i<<"]);\n";
131   }
132   code << "  }\n";
133   code << "}\n";
134 
135   // View kernel for debugging
136   CeedDebug(ceed, code.str().c_str());
137 
138   // Compile kernel
139   ierr = CeedCompileCuda(ceed, code.str().c_str(), &data->module, 0);
140   CeedChkBackend(ierr);
141   ierr = CeedGetKernelCuda(ceed, data->module, kernelName.c_str(), &data->qFunction);
142   CeedChkBackend(ierr);
143 
144   // Cleanup
145   ierr = CeedFree(&data->qFunctionSource); CeedChkBackend(ierr);
146   return CEED_ERROR_SUCCESS;
147 }
148 //------------------------------------------------------------------------------
149