xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-qfunction-load.cpp (revision 437930d19388999b5cc2d76e2fe0d14f58fb41f3)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed/ceed.h>
18 #include <ceed/backend.h>
19 #include <ceed/jit-tools.h>
20 #include <iostream>
21 #include <sstream>
22 #include <string.h>
23 #include "ceed-cuda-ref.h"
24 #include "../cuda/ceed-cuda-compile.h"
25 
26 //------------------------------------------------------------------------------
27 // Build QFunction kernel
28 //------------------------------------------------------------------------------
29 extern "C" int CeedCudaBuildQFunction(CeedQFunction qf) {
30   CeedInt ierr;
31   using std::ostringstream;
32   using std::string;
33   Ceed ceed;
34   CeedQFunctionGetCeed(qf, &ceed);
35   CeedQFunction_Cuda *data;
36   ierr = CeedQFunctionGetData(qf, (void **)&data); CeedChkBackend(ierr);
37 
38   // QFunction is built
39   if (data->QFunction)
40     return CEED_ERROR_SUCCESS;
41 
42   if (!data->qfunction_source)
43     // LCOV_EXCL_START
44     return CeedError(ceed, CEED_ERROR_BACKEND,
45                      "No QFunction source or CUfunction provided.");
46   // LCOV_EXCL_STOP
47 
48   // QFunction kernel generation
49   CeedInt num_input_fields, num_output_fields, size;
50   CeedQFunctionField *input_fields, *output_fields;
51   ierr = CeedQFunctionGetFields(qf, &num_input_fields, &input_fields,
52                                 &num_output_fields, &output_fields);
53   CeedChkBackend(ierr);
54 
55   // Build strings for final kernel
56   char *read_write_kernel_path, *read_write_kernel_source;
57   ierr = CeedPathConcatenate(ceed, __FILE__, "kernels/cuda-ref-qfunction.h",
58                              &read_write_kernel_path); CeedChkBackend(ierr);
59   ierr = CeedLoadSourceToBuffer(ceed, read_write_kernel_path, &read_write_kernel_source);
60   CeedChkBackend(ierr);
61   string qfunction_source(data->qfunction_source);
62   string qfunction_name(data->qfunction_name);
63   string read_write(read_write_kernel_source);
64   string kernel_name = "CeedKernel_Cuda_ref_" + qfunction_name;
65   ostringstream code;
66 
67   // Defintions
68   code << "\n#define CEED_QFUNCTION(name) inline __device__ int name\n";
69   code << "#define CEED_QFUNCTION_HELPER inline __device__\n";
70   code << "#define CeedPragmaSIMD\n";
71   code << "#define CEED_ERROR_SUCCESS 0\n";
72   code << "#define CEED_Q_VLA 1\n\n";
73   code << "typedef struct { const CeedScalar* inputs[16]; CeedScalar* outputs[16]; } Fields_Cuda;\n";
74   code << read_write;
75   code << qfunction_source;
76   code << "extern \"C\" __global__ void " << kernel_name << "(void *ctx, CeedInt Q, Fields_Cuda fields) {\n";
77 
78   // Inputs
79   for (CeedInt i = 0; i < num_input_fields; i++) {
80     code << "  // Input field " << i << "\n";
81     ierr = CeedQFunctionFieldGetSize(input_fields[i], &size); CeedChkBackend(ierr);
82     code << "  const CeedInt size_in_" << i << " = "<<size<<";\n";
83     code << "  CeedScalar r_q" << i << "[size_in_" << i << "];\n";
84   }
85   code << "\n";
86 
87   // Outputs
88   for (CeedInt i = 0; i < num_output_fields; i++) {
89     code << "  // Output field " << i << "\n";
90     ierr = CeedQFunctionFieldGetSize(output_fields[i], &size); CeedChkBackend(ierr);
91     code << "  const CeedInt size_out_" << i << " = " << size << ";\n";
92     code << "  CeedScalar r_qq" << i << "[size_out_" << i << "];\n";
93   }
94   code << "\n";
95 
96   // Setup input/output arrays
97   code << "  const CeedScalar* in[" << num_input_fields << "];\n";
98   for (CeedInt i = 0; i < num_input_fields; i++) {
99     code << "    in[" << i << "] = r_q" << i << ";\n";
100   }
101   code << "  CeedScalar* out[" << num_output_fields << "];\n";
102   for (CeedInt i = 0; i < num_output_fields; i++) {
103     code << "    out[" << i << "] = r_qq" << i << ";\n";
104   }
105   code << "\n";
106 
107   // Loop over quadrature points
108   code << "  for (CeedInt q = blockIdx.x * blockDim.x + threadIdx.x; q < Q; q += blockDim.x * gridDim.x) {\n";
109 
110   // Load inputs
111   for (CeedInt i = 0; i < num_input_fields; i++) {
112     code << "    // Input field " << i << "\n";
113     code << "    readQuads<size_in_" << i << ">(q, Q, fields.inputs[" << i << "], r_q" << i << ");\n";
114   }
115   // QFunction
116   code << "    // QFunction\n";
117   code << "    " << qfunction_name << "(ctx, 1, in, out);\n";
118 
119   // Write outputs
120   for (CeedInt i = 0; i < num_output_fields; i++) {
121     code << "    // Output field " << i << "\n";
122     code << "    writeQuads<size_out_" << i << ">(q, Q, r_qq" << i << ", fields.outputs[" << i << "]);\n";
123   }
124   code << "  }\n";
125   code << "}\n";
126 
127   // View kernel for debugging
128   CeedDebug256(ceed, 1, "Generated QFunction Kernels:\n");
129   CeedDebug(ceed, code.str().c_str());
130 
131   // Compile kernel
132   ierr = CeedCompileCuda(ceed, code.str().c_str(), &data->module, 0);
133   CeedChkBackend(ierr);
134   ierr = CeedGetKernelCuda(ceed, data->module, kernel_name.c_str(), &data->QFunction);
135   CeedChkBackend(ierr);
136 
137   // Cleanup
138   ierr = CeedFree(&data->qfunction_source); CeedChkBackend(ierr);
139   ierr = CeedFree(&read_write_kernel_path); CeedChkBackend(ierr);
140   ierr = CeedFree(&read_write_kernel_source); CeedChkBackend(ierr);
141 
142   return CEED_ERROR_SUCCESS;
143 }
144 //------------------------------------------------------------------------------
145