xref: /libCEED/backends/hip-gen/ceed-hip-gen-operator.c (revision 9330daecb0fc008043eec1b94c46ef7aecbb00cd)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-source/hip/hip-types.h>
11 #include <stddef.h>
12 
13 #include "../hip/ceed-hip-compile.h"
14 #include "ceed-hip-gen-operator-build.h"
15 #include "ceed-hip-gen.h"
16 
17 //------------------------------------------------------------------------------
18 // Destroy operator
19 //------------------------------------------------------------------------------
20 static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
21   CeedOperator_Hip_gen *impl;
22   CeedCallBackend(CeedOperatorGetData(op, &impl));
23   CeedCallBackend(CeedFree(&impl));
24   return CEED_ERROR_SUCCESS;
25 }
26 
27 //------------------------------------------------------------------------------
28 // Apply and add to output
29 //------------------------------------------------------------------------------
30 static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
31   Ceed ceed;
32   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
33   CeedOperator_Hip_gen *data;
34   CeedCallBackend(CeedOperatorGetData(op, &data));
35   CeedQFunction          qf;
36   CeedQFunction_Hip_gen *qf_data;
37   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
38   CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
39   CeedInt num_elem, num_input_fields, num_output_fields;
40   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
41   CeedOperatorField *op_input_fields, *op_output_fields;
42   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
43   CeedQFunctionField *qf_input_fields, *qf_output_fields;
44   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
45   CeedEvalMode eval_mode;
46   CeedVector   vec, output_vecs[CEED_FIELD_MAX] = {NULL};
47 
48   // Creation of the operator
49   CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op));
50 
51   // Input vectors
52   for (CeedInt i = 0; i < num_input_fields; i++) {
53     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
54     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
55       data->fields.inputs[i] = NULL;
56     } else {
57       // Get input vector
58       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
59       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
60       CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
61     }
62   }
63 
64   // Output vectors
65   for (CeedInt i = 0; i < num_output_fields; i++) {
66     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
67     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
68       data->fields.outputs[i] = NULL;
69     } else {
70       // Get output vector
71       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
72       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
73       output_vecs[i] = vec;
74       // Check for multiple output modes
75       CeedInt index = -1;
76       for (CeedInt j = 0; j < i; j++) {
77         if (vec == output_vecs[j]) {
78           index = j;
79           break;
80         }
81       }
82       if (index == -1) {
83         CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i]));
84       } else {
85         data->fields.outputs[i] = data->fields.outputs[index];
86       }
87     }
88   }
89 
90   // Get context data
91   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));
92 
93   // Apply operator
94   void         *opargs[]  = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W};
95   const CeedInt dim       = data->dim;
96   const CeedInt Q_1d      = data->Q_1d;
97   const CeedInt P_1d      = data->max_P_1d;
98   const CeedInt thread_1d = CeedIntMax(Q_1d, P_1d);
99   CeedInt       block_sizes[3];
100   CeedCallBackend(BlockGridCalculate_Hip_gen(dim, num_elem, P_1d, Q_1d, block_sizes));
101   if (dim == 1) {
102     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
103     CeedInt sharedMem = block_sizes[2] * thread_1d * sizeof(CeedScalar);
104     CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
105   } else if (dim == 2) {
106     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
107     CeedInt sharedMem = block_sizes[2] * thread_1d * thread_1d * sizeof(CeedScalar);
108     CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
109   } else if (dim == 3) {
110     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
111     CeedInt sharedMem = block_sizes[2] * thread_1d * thread_1d * sizeof(CeedScalar);
112     CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
113   }
114 
115   // Restore input arrays
116   for (CeedInt i = 0; i < num_input_fields; i++) {
117     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
118     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
119     } else {
120       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
121       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
122       CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
123     }
124   }
125 
126   // Restore output arrays
127   for (CeedInt i = 0; i < num_output_fields; i++) {
128     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
129     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
130     } else {
131       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
132       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
133       // Check for multiple output modes
134       CeedInt index = -1;
135       for (CeedInt j = 0; j < i; j++) {
136         if (vec == output_vecs[j]) {
137           index = j;
138           break;
139         }
140       }
141       if (index == -1) {
142         CeedCallBackend(CeedVectorRestoreArray(vec, &data->fields.outputs[i]));
143       }
144     }
145   }
146 
147   // Restore context data
148   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
149 
150   return CEED_ERROR_SUCCESS;
151 }
152 
153 //------------------------------------------------------------------------------
154 // Create operator
155 //------------------------------------------------------------------------------
156 int CeedOperatorCreate_Hip_gen(CeedOperator op) {
157   Ceed ceed;
158   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
159   CeedOperator_Hip_gen *impl;
160 
161   CeedCallBackend(CeedCalloc(1, &impl));
162   CeedCallBackend(CeedOperatorSetData(op, impl));
163 
164   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip_gen));
165   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip_gen));
166   return CEED_ERROR_SUCCESS;
167 }
168 
169 //------------------------------------------------------------------------------
170