xref: /libCEED/backends/hip-gen/ceed-hip-gen-operator.c (revision 356036fa84f714fa73ef64c9a80ce2028dde816f)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-source/hip/hip-types.h>
11 #include <stddef.h>
12 
13 #include "../hip/ceed-hip-common.h"
14 #include "../hip/ceed-hip-compile.h"
15 #include "ceed-hip-gen-operator-build.h"
16 #include "ceed-hip-gen.h"
17 
18 //------------------------------------------------------------------------------
19 // Destroy operator
20 //------------------------------------------------------------------------------
21 static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
22   CeedOperator_Hip_gen *impl;
23   CeedCallBackend(CeedOperatorGetData(op, &impl));
24   CeedCallBackend(CeedFree(&impl));
25   return CEED_ERROR_SUCCESS;
26 }
27 
28 //------------------------------------------------------------------------------
29 // Apply and add to output
30 //------------------------------------------------------------------------------
31 static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
32   Ceed ceed;
33   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
34   CeedOperator_Hip_gen *data;
35   CeedCallBackend(CeedOperatorGetData(op, &data));
36   CeedQFunction          qf;
37   CeedQFunction_Hip_gen *qf_data;
38   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
39   CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
40   CeedInt num_elem, num_input_fields, num_output_fields;
41   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
42   CeedOperatorField *op_input_fields, *op_output_fields;
43   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
44   CeedQFunctionField *qf_input_fields, *qf_output_fields;
45   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
46   CeedEvalMode eval_mode;
47   CeedVector   vec, output_vecs[CEED_FIELD_MAX] = {NULL};
48 
49   // Creation of the operator
50   CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op));
51 
52   // Input vectors
53   for (CeedInt i = 0; i < num_input_fields; i++) {
54     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
55     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
56       data->fields.inputs[i] = NULL;
57     } else {
58       // Get input vector
59       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
60       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
61       CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
62     }
63   }
64 
65   // Output vectors
66   for (CeedInt i = 0; i < num_output_fields; i++) {
67     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
68     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
69       data->fields.outputs[i] = NULL;
70     } else {
71       // Get output vector
72       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
73       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
74       output_vecs[i] = vec;
75       // Check for multiple output modes
76       CeedInt index = -1;
77       for (CeedInt j = 0; j < i; j++) {
78         if (vec == output_vecs[j]) {
79           index = j;
80           break;
81         }
82       }
83       if (index == -1) {
84         CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i]));
85       } else {
86         data->fields.outputs[i] = data->fields.outputs[index];
87       }
88     }
89   }
90 
91   // Get context data
92   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));
93 
94   // Apply operator
95   void         *opargs[]  = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W};
96   const CeedInt dim       = data->dim;
97   const CeedInt Q_1d      = data->Q_1d;
98   const CeedInt P_1d      = data->max_P_1d;
99   const CeedInt thread_1d = CeedIntMax(Q_1d, P_1d);
100   CeedInt       block_sizes[3];
101   CeedCallBackend(BlockGridCalculate_Hip_gen(dim, num_elem, P_1d, Q_1d, block_sizes));
102   if (dim == 1) {
103     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
104     CeedInt sharedMem = block_sizes[2] * thread_1d * sizeof(CeedScalar);
105     CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
106   } else if (dim == 2) {
107     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
108     CeedInt sharedMem = block_sizes[2] * thread_1d * thread_1d * sizeof(CeedScalar);
109     CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
110   } else if (dim == 3) {
111     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
112     CeedInt sharedMem = block_sizes[2] * thread_1d * thread_1d * sizeof(CeedScalar);
113     CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
114   }
115 
116   // Restore input arrays
117   for (CeedInt i = 0; i < num_input_fields; i++) {
118     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
119     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
120     } else {
121       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
122       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
123       CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
124     }
125   }
126 
127   // Restore output arrays
128   for (CeedInt i = 0; i < num_output_fields; i++) {
129     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
130     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
131     } else {
132       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
133       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
134       // Check for multiple output modes
135       CeedInt index = -1;
136       for (CeedInt j = 0; j < i; j++) {
137         if (vec == output_vecs[j]) {
138           index = j;
139           break;
140         }
141       }
142       if (index == -1) {
143         CeedCallBackend(CeedVectorRestoreArray(vec, &data->fields.outputs[i]));
144       }
145     }
146   }
147 
148   // Restore context data
149   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
150 
151   return CEED_ERROR_SUCCESS;
152 }
153 
154 //------------------------------------------------------------------------------
155 // Create operator
156 //------------------------------------------------------------------------------
157 int CeedOperatorCreate_Hip_gen(CeedOperator op) {
158   Ceed ceed;
159   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
160   CeedOperator_Hip_gen *impl;
161 
162   CeedCallBackend(CeedCalloc(1, &impl));
163   CeedCallBackend(CeedOperatorSetData(op, impl));
164 
165   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip_gen));
166   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip_gen));
167   return CEED_ERROR_SUCCESS;
168 }
169 
170 //------------------------------------------------------------------------------
171