xref: /libCEED/backends/hip-ref/ceed-hip-ref-qfunction-load.cpp (revision d92fedf5b7546cf2fc50391dbcfb657a2e1f0a3b)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed/ceed.h>
18 #include <ceed/backend.h>
19 #include <iostream>
20 #include <sstream>
21 #include <string.h>
22 #include "ceed-hip-ref.h"
23 #include "../hip/ceed-hip-compile.h"
24 
25 static const char *qReadWrite = QUOTE(
26 template <int SIZE>
27 //------------------------------------------------------------------------------
28 // Read from quadrature points
29 //------------------------------------------------------------------------------
30 inline __device__ void readQuads(const CeedInt quad, const CeedInt nquads, const CeedScalar* d_u, CeedScalar* r_u) {
31   for(CeedInt comp = 0; comp < SIZE; ++comp) {
32     r_u[comp] = d_u[quad + nquads * comp];
33   }
34 }
35 
36 //------------------------------------------------------------------------------
37 // Write at quadrature points
38 //------------------------------------------------------------------------------
39 template <int SIZE>
40 inline __device__ void writeQuads(const CeedInt quad, const CeedInt nquads, const CeedScalar* r_v, CeedScalar* d_v) {
41   for(CeedInt comp = 0; comp < SIZE; ++comp) {
42     d_v[quad + nquads * comp] = r_v[comp];
43   }
44 }
45 );
46 
47 //------------------------------------------------------------------------------
48 // Build QFunction kernel
49 //------------------------------------------------------------------------------
50 extern "C" int CeedHipBuildQFunction(CeedQFunction qf) {
51   CeedInt ierr;
52   using std::ostringstream;
53   using std::string;
54   CeedQFunction_Hip *data;
55   ierr = CeedQFunctionGetData(qf, (void **)&data); CeedChkBackend(ierr);
56   // QFunction is built
57   if (!data->qFunctionSource)
58     return CEED_ERROR_SUCCESS;
59 
60   // QFunction kernel generation
61   CeedInt numinputfields, numoutputfields, size;
62   CeedQFunctionField *qfinputfields, *qfoutputfields;
63   ierr = CeedQFunctionGetFields(qf, &numinputfields, &qfinputfields, &numoutputfields, &qfoutputfields);
64   CeedChkBackend(ierr);
65 
66   // Build strings for final kernel
67   string qFunction(data->qFunctionSource);
68   string qReadWriteS(qReadWrite);
69   ostringstream code;
70   string qFunctionName(data->qFunctionName);
71   string kernelName;
72   kernelName = "CeedKernel_Hip_ref_" + qFunctionName;
73 
74   // Defintions
75   code << "\n#define CEED_QFUNCTION(name) inline __device__ int name\n";
76   code << "#define CEED_QFUNCTION_HELPER inline __device__ __forceinline__\n";
77   code << "#define CeedPragmaSIMD\n";
78   code << "#define CEED_ERROR_SUCCESS 0\n";
79   code << "#define CEED_Q_VLA 1\n\n";
80   code << "typedef struct { const CeedScalar* inputs[16]; CeedScalar* outputs[16]; } Fields_Hip;\n";
81   code << qReadWriteS;
82   code << qFunction;
83   code << "extern \"C\" __global__ void " << kernelName << "(void *ctx, CeedInt Q, Fields_Hip fields) {\n";
84 
85   // Inputs
86   for (CeedInt i = 0; i < numinputfields; i++) {
87     code << "// Input field "<<i<<"\n";
88     ierr = CeedQFunctionFieldGetSize(qfinputfields[i], &size); CeedChkBackend(ierr);
89     code << "  const CeedInt size_in_"<<i<<" = "<<size<<";\n";
90     code << "  CeedScalar r_q"<<i<<"[size_in_"<<i<<"];\n";
91   }
92 
93   // Outputs
94   for (CeedInt i = 0; i < numoutputfields; i++) {
95     code << "// Output field "<<i<<"\n";
96     ierr = CeedQFunctionFieldGetSize(qfoutputfields[i], &size); CeedChkBackend(ierr);
97     code << "  const CeedInt size_out_"<<i<<" = "<<size<<";\n";
98     code << "  CeedScalar r_qq"<<i<<"[size_out_"<<i<<"];\n";
99   }
100 
101   // Setup input/output arrays
102   code << "  const CeedScalar* in["<<numinputfields<<"];\n";
103   for (CeedInt i = 0; i < numinputfields; i++) {
104     code << "    in["<<i<<"] = r_q"<<i<<";\n";
105   }
106   code << "  CeedScalar* out["<<numoutputfields<<"];\n";
107   for (CeedInt i = 0; i < numoutputfields; i++) {
108     code << "    out["<<i<<"] = r_qq"<<i<<";\n";
109   }
110 
111   // Loop over quadrature points
112   code << "  for (CeedInt q = blockIdx.x * blockDim.x + threadIdx.x; q < Q; q += blockDim.x * gridDim.x) {\n";
113 
114   // Load inputs
115   for (CeedInt i = 0; i < numinputfields; i++) {
116     code << "// Input field "<<i<<"\n";
117     code << "  readQuads<size_in_"<<i<<">(q, Q, fields.inputs["<<i<<"], r_q"<<i<<");\n";
118   }
119   // QFunction
120   code << "// QFunction\n";
121   code << "    "<<qFunctionName<<"(ctx, 1, in, out);\n";
122 
123   // Write outputs
124   for (CeedInt i = 0; i < numoutputfields; i++) {
125     code << "// Output field "<<i<<"\n";
126     code << "  writeQuads<size_out_"<<i<<">(q, Q, r_qq"<<i<<", fields.outputs["<<i<<"]);\n";
127   }
128   code << "  }\n";
129   code << "}\n";
130 
131   // View kernel for debugging
132   Ceed ceed;
133   CeedQFunctionGetCeed(qf, &ceed);
134   CeedDebug(ceed, code.str().c_str());
135 
136   // Compile kernel
137   ierr = CeedCompileHip(ceed, code.str().c_str(), &data->module, 0);
138   CeedChkBackend(ierr);
139   ierr = CeedGetKernelHip(ceed, data->module, kernelName.c_str(), &data->qFunction);
140   CeedChkBackend(ierr);
141 
142   // Cleanup
143   ierr = CeedFree(&data->qFunctionSource); CeedChkBackend(ierr);
144   return CEED_ERROR_SUCCESS;
145 }
146 //------------------------------------------------------------------------------
147