xref: /libCEED/backends/hip-gen/ceed-hip-gen-operator.c (revision f1c40530bc6f02bc6ca847af718b8b41cc3fd3ca)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include <stddef.h>
11 #include "ceed-hip-gen.h"
12 #include "ceed-hip-gen-operator-build.h"
13 #include "../hip/ceed-hip-compile.h"
14 
15 //------------------------------------------------------------------------------
16 // Destroy operator
17 //------------------------------------------------------------------------------
18 static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
19   int ierr;
20   CeedOperator_Hip_gen *impl;
21   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
22   ierr = CeedFree(&impl); CeedChkBackend(ierr);
23   return CEED_ERROR_SUCCESS;
24 }
25 
26 //------------------------------------------------------------------------------
27 // Apply and add to output
28 //------------------------------------------------------------------------------
29 static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec,
30                                         CeedVector output_vec, CeedRequest *request) {
31   int ierr;
32   Ceed ceed;
33   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
34   CeedOperator_Hip_gen *data;
35   ierr = CeedOperatorGetData(op, &data); CeedChkBackend(ierr);
36   CeedQFunction qf;
37   CeedQFunction_Hip_gen *qf_data;
38   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
39   ierr = CeedQFunctionGetData(qf, &qf_data); CeedChkBackend(ierr);
40   CeedInt num_elem, num_input_fields, num_output_fields;
41   ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr);
42   CeedOperatorField *op_input_fields, *op_output_fields;
43   ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields,
44                                &num_output_fields, &op_output_fields);
45   CeedChkBackend(ierr);
46   CeedQFunctionField *qf_input_fields, *qf_output_fields;
47   ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL,
48                                 &qf_output_fields);
49   CeedChkBackend(ierr);
50   CeedEvalMode eval_mode;
51   CeedVector vec, output_vecs[CEED_FIELD_MAX] = {};
52 
53   //Creation of the operator
54   ierr = CeedHipGenOperatorBuild(op); CeedChkBackend(ierr);
55 
56   // Input vectors
57   for (CeedInt i = 0; i < num_input_fields; i++) {
58     ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode);
59     CeedChkBackend(ierr);
60     if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
61       data->fields.inputs[i] = NULL;
62     } else {
63       // Get input vector
64       ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
65       CeedChkBackend(ierr);
66       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
67       ierr = CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]);
68       CeedChkBackend(ierr);
69     }
70   }
71 
72   // Output vectors
73   for (CeedInt i = 0; i < num_output_fields; i++) {
74     ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode);
75     CeedChkBackend(ierr);
76     if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
77       data->fields.outputs[i] = NULL;
78     } else {
79       // Get output vector
80       ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec);
81       CeedChkBackend(ierr);
82       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
83       output_vecs[i] = vec;
84       // Check for multiple output modes
85       CeedInt index = -1;
86       for (CeedInt j = 0; j < i; j++) {
87         if (vec == output_vecs[j]) {
88           index = j;
89           break;
90         }
91       }
92       if (index == -1) {
93         ierr = CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i]);
94         CeedChkBackend(ierr);
95       } else {
96         data->fields.outputs[i] = data->fields.outputs[index];
97       }
98     }
99   }
100 
101   // Get context data
102   ierr = CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c);
103   CeedChkBackend(ierr);
104 
105   // Apply operator
106   void *opargs[] = {(void *) &num_elem, &qf_data->d_c, &data->indices,
107                     &data->fields, &data->B, &data->G, &data->W
108                    };
109   const CeedInt dim = data->dim;
110   const CeedInt Q_1d = data->Q_1d;
111   const CeedInt P_1d = data->max_P_1d;
112   const CeedInt thread_1d = CeedIntMax(Q_1d, P_1d);
113   CeedInt block_sizes[3];
114   ierr = BlockGridCalculate_Hip_gen(dim, num_elem, P_1d, Q_1d, block_sizes);
115   CeedChkBackend(ierr);
116   if (dim==1) {
117     CeedInt grid = num_elem/block_sizes[2] + ( (
118                      num_elem/block_sizes[2]*block_sizes[2]<num_elem)
119                    ? 1 : 0 );
120     CeedInt sharedMem = block_sizes[2]*thread_1d*sizeof(CeedScalar);
121     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0],
122                                      block_sizes[1],
123                                      block_sizes[2], sharedMem, opargs);
124   } else if (dim==2) {
125     CeedInt grid = num_elem/block_sizes[2] + ( (
126                      num_elem/block_sizes[2]*block_sizes[2]<num_elem)
127                    ? 1 : 0 );
128     CeedInt sharedMem = block_sizes[2]*thread_1d*thread_1d*sizeof(CeedScalar);
129     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0],
130                                      block_sizes[1],
131                                      block_sizes[2], sharedMem, opargs);
132   } else if (dim==3) {
133     CeedInt grid = num_elem/block_sizes[2] + ( (
134                      num_elem/block_sizes[2]*block_sizes[2]<num_elem)
135                    ? 1 : 0 );
136     CeedInt sharedMem = block_sizes[2]*thread_1d*thread_1d*sizeof(CeedScalar);
137     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0],
138                                      block_sizes[1],
139                                      block_sizes[2], sharedMem, opargs);
140   }
141   CeedChkBackend(ierr);
142 
143   // Restore input arrays
144   for (CeedInt i = 0; i < num_input_fields; i++) {
145     ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode);
146     CeedChkBackend(ierr);
147     if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
148     } else {
149       ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
150       CeedChkBackend(ierr);
151       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
152       ierr = CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]);
153       CeedChkBackend(ierr);
154     }
155   }
156 
157   // Restore output arrays
158   for (CeedInt i = 0; i < num_output_fields; i++) {
159     ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode);
160     CeedChkBackend(ierr);
161     if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
162     } else {
163       ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec);
164       CeedChkBackend(ierr);
165       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
166       // Check for multiple output modes
167       CeedInt index = -1;
168       for (CeedInt j = 0; j < i; j++) {
169         if (vec == output_vecs[j]) {
170           index = j;
171           break;
172         }
173       }
174       if (index == -1) {
175         ierr = CeedVectorRestoreArray(vec, &data->fields.outputs[i]);
176         CeedChkBackend(ierr);
177       }
178     }
179   }
180 
181   // Restore context data
182   ierr = CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c);
183   CeedChkBackend(ierr);
184 
185   return CEED_ERROR_SUCCESS;
186 }
187 
188 //------------------------------------------------------------------------------
189 // Create operator
190 //------------------------------------------------------------------------------
191 int CeedOperatorCreate_Hip_gen(CeedOperator op) {
192   int ierr;
193   Ceed ceed;
194   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
195   CeedOperator_Hip_gen *impl;
196 
197   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
198   ierr = CeedOperatorSetData(op, impl); CeedChkBackend(ierr);
199 
200   ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd",
201                                 CeedOperatorApplyAdd_Hip_gen); CeedChkBackend(ierr);
202   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
203                                 CeedOperatorDestroy_Hip_gen); CeedChkBackend(ierr);
204   return CEED_ERROR_SUCCESS;
205 }
206 //------------------------------------------------------------------------------
207