1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <assert.h>
12 #include <stdbool.h>
13 #include <string.h>
14 #include <hip/hip_runtime.h>
15
16 #include "../hip/ceed-hip-common.h"
17 #include "../hip/ceed-hip-compile.h"
18 #include "ceed-hip-ref.h"
19
20 //------------------------------------------------------------------------------
21 // Destroy operator
22 //------------------------------------------------------------------------------
// Free every resource held in the HIP reference backend data of `op`.
//
// Released in order: host bookkeeping arrays, input/output E-vectors and Q-vectors, the per-element point coordinate vector,
// QFunction assembly vectors, diagonal assembly data (modules, device buffers, vectors, restrictions), full assembly data
// (module, device buffers), and finally the backend data struct itself.
//
// Returns CEED_ERROR_SUCCESS; errors from libCEED or HIP calls are propagated through CeedCallBackend / CeedCallHip.
static int CeedOperatorDestroy_Hip(CeedOperator op) {
  CeedOperator_Hip *impl;

  CeedCallBackend(CeedOperatorGetData(op, &impl));

  // Apply data
  CeedCallBackend(CeedFree(&impl->num_points));
  CeedCallBackend(CeedFree(&impl->skip_rstr_in));
  CeedCallBackend(CeedFree(&impl->skip_rstr_out));
  CeedCallBackend(CeedFree(&impl->apply_add_basis_out));
  CeedCallBackend(CeedFree(&impl->input_field_order));
  CeedCallBackend(CeedFree(&impl->output_field_order));
  CeedCallBackend(CeedFree(&impl->input_states));

  for (CeedInt i = 0; i < impl->num_inputs; i++) {
    CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_in[i]));
    CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i]));
  }
  CeedCallBackend(CeedFree(&impl->e_vecs_in));
  CeedCallBackend(CeedFree(&impl->q_vecs_in));

  for (CeedInt i = 0; i < impl->num_outputs; i++) {
    CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_out[i]));
    CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i]));
  }
  CeedCallBackend(CeedFree(&impl->e_vecs_out));
  CeedCallBackend(CeedFree(&impl->q_vecs_out));
  CeedCallBackend(CeedVectorDestroy(&impl->point_coords_elem));

  // QFunction assembly data
  for (CeedInt i = 0; i < impl->num_active_in; i++) {
    CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i]));
  }
  CeedCallBackend(CeedFree(&impl->qf_active_in));

  // Diag data
  if (impl->diag) {
    Ceed ceed;

    CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
    // Either module may be unset, so only unload the ones that exist
    if (impl->diag->module) {
      CeedCallHip(ceed, hipModuleUnload(impl->diag->module));
    }
    if (impl->diag->module_point_block) {
      CeedCallHip(ceed, hipModuleUnload(impl->diag->module_point_block));
    }
    CeedCallHip(ceed, hipFree(impl->diag->d_eval_modes_in));
    CeedCallHip(ceed, hipFree(impl->diag->d_eval_modes_out));
    CeedCallHip(ceed, hipFree(impl->diag->d_identity));
    CeedCallHip(ceed, hipFree(impl->diag->d_interp_in));
    CeedCallHip(ceed, hipFree(impl->diag->d_interp_out));
    CeedCallHip(ceed, hipFree(impl->diag->d_grad_in));
    CeedCallHip(ceed, hipFree(impl->diag->d_grad_out));
    CeedCallHip(ceed, hipFree(impl->diag->d_div_in));
    CeedCallHip(ceed, hipFree(impl->diag->d_div_out));
    CeedCallHip(ceed, hipFree(impl->diag->d_curl_in));
    CeedCallHip(ceed, hipFree(impl->diag->d_curl_out));
    CeedCallBackend(CeedDestroy(&ceed));
    CeedCallBackend(CeedVectorDestroy(&impl->diag->elem_diag));
    CeedCallBackend(CeedVectorDestroy(&impl->diag->point_block_elem_diag));
    CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr));
    CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr));
  }
  CeedCallBackend(CeedFree(&impl->diag));

  // Assembly data
  if (impl->asmb) {
    Ceed ceed;

    CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
    // Guarded like the diagonal modules: skip the unload when no assembly module was ever loaded,
    // so the remaining buffers and the backend data are still released
    if (impl->asmb->module) {
      CeedCallHip(ceed, hipModuleUnload(impl->asmb->module));
    }
    CeedCallHip(ceed, hipFree(impl->asmb->d_B_in));
    CeedCallHip(ceed, hipFree(impl->asmb->d_B_out));
    CeedCallBackend(CeedDestroy(&ceed));
  }
  CeedCallBackend(CeedFree(&impl->asmb));

  CeedCallBackend(CeedFree(&impl));
  return CEED_ERROR_SUCCESS;
}
102
103 //------------------------------------------------------------------------------
104 // Setup infields or outfields
105 //------------------------------------------------------------------------------
// Create the E-vectors and Q-vectors for either the input fields or the output fields of an operator.
//
//   qf              - QFunction whose fields supply the eval modes and sizes
//   op              - operator whose fields supply the vectors, restrictions, and bases
//   is_input        - true to process the input fields, false to process the output fields
//   is_at_points    - true to compute quadrature weights with CeedBasisApplyAtPoints, using Q points in every element
//   skip_rstr       - [out] entry j is set to true when field j shares its (vector, restriction) pair with another field
//   apply_add_basis - [out] only written for outputs; entry i is set to true when a lower-index output shares field i's pair
//   e_vecs          - [out] one E-vector per field; NULL where the element restriction can be skipped
//   q_vecs          - [out] one Q-vector per field, of length num_elem * Q * size (num_elem * Q for CEED_EVAL_WEIGHT)
//   num_fields      - number of fields to process
//   Q               - number of quadrature points per element
//   num_elem        - number of elements
//
// Returns CEED_ERROR_SUCCESS; errors from libCEED calls are propagated through CeedCallBackend.
static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, bool *apply_add_basis,
                                       CeedVector *e_vecs, CeedVector *q_vecs, CeedInt num_fields, CeedInt Q, CeedInt num_elem) {
  Ceed                ceed;
  CeedQFunctionField *qf_fields;
  CeedOperatorField  *op_fields;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  if (is_input) {
    CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL));
    CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL));
  } else {
    CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields));
    CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields));
  }

  // Loop over fields
  for (CeedInt i = 0; i < num_fields; i++) {
    bool                is_active = false, is_strided = false, skip_e_vec = false;
    CeedSize            q_size;
    CeedInt             size;
    CeedEvalMode        eval_mode;
    CeedVector          l_vec;
    CeedElemRestriction elem_rstr;

    // Check whether this field can skip the element restriction:
    //   Input CEED_VECTOR_ACTIVE
    //   Output CEED_VECTOR_ACTIVE without CEED_EVAL_NONE
    //   Input CEED_VECTOR_NONE with CEED_EVAL_WEIGHT
    //   Input passive vector with CEED_EVAL_NONE and strided restriction with CEED_STRIDES_BACKEND
    CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &l_vec));
    is_active = l_vec == CEED_VECTOR_ACTIVE;
    CeedCallBackend(CeedVectorDestroy(&l_vec));
    CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr));
    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
    skip_e_vec = (is_input && is_active) || (is_active && eval_mode != CEED_EVAL_NONE) || (eval_mode == CEED_EVAL_WEIGHT);
    if (!skip_e_vec && is_input && !is_active && eval_mode == CEED_EVAL_NONE) {
      CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided));
      if (is_strided) CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &skip_e_vec));
    }
    if (skip_e_vec) {
      e_vecs[i] = NULL;
    } else {
      CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs[i]));
    }
    CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));

    switch (eval_mode) {
      case CEED_EVAL_NONE:
      case CEED_EVAL_INTERP:
      case CEED_EVAL_GRAD:
      case CEED_EVAL_DIV:
      case CEED_EVAL_CURL:
        CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size));
        q_size = (CeedSize)num_elem * (CeedSize)Q * (CeedSize)size;
        CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
        break;
      case CEED_EVAL_WEIGHT: {
        CeedBasis basis;

        // The quadrature weights are computed once here and stored in the Q-vector
        CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
        q_size = (CeedSize)num_elem * (CeedSize)Q;
        CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
        if (is_at_points) {
          CeedInt *num_points = NULL;

          // Heap allocation: num_elem is not bounded, so a stack array sized by it could overflow the stack
          // (and a zero-length variable length array is undefined)
          CeedCallBackend(CeedCalloc(num_elem, &num_points));
          for (CeedInt e = 0; e < num_elem; e++) num_points[e] = Q;
          CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, CEED_VECTOR_NONE,
                                                 q_vecs[i]));
          CeedCallBackend(CeedFree(&num_points));
        } else {
          CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i]));
        }
        CeedCallBackend(CeedBasisDestroy(&basis));
        break;
      }
    }
  }
  // Drop duplicate restrictions
  if (is_input) {
    // Inputs: the first field of each (vector, restriction) pair keeps its restriction, later duplicates share its E-vector
    for (CeedInt i = 0; i < num_fields; i++) {
      CeedVector          vec_i;
      CeedElemRestriction rstr_i;

      CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i));
      CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i));
      for (CeedInt j = i + 1; j < num_fields; j++) {
        CeedVector          vec_j;
        CeedElemRestriction rstr_j;

        CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
        CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
        if (vec_i == vec_j && rstr_i == rstr_j) {
          if (e_vecs[i]) CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j]));
          skip_rstr[j] = true;
        }
        CeedCallBackend(CeedVectorDestroy(&vec_j));
        CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j));
      }
      CeedCallBackend(CeedVectorDestroy(&vec_i));
      CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i));
    }
  } else {
    // Outputs: walk backwards so lower-index duplicates skip their restriction and field i is flagged for add-mode basis application
    for (CeedInt i = num_fields - 1; i >= 0; i--) {
      CeedVector          vec_i;
      CeedElemRestriction rstr_i;

      CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i));
      CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i));
      for (CeedInt j = i - 1; j >= 0; j--) {
        CeedVector          vec_j;
        CeedElemRestriction rstr_j;

        CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
        CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
        if (vec_i == vec_j && rstr_i == rstr_j) {
          if (e_vecs[i]) CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j]));
          skip_rstr[j]       = true;
          apply_add_basis[i] = true;
        }
        CeedCallBackend(CeedVectorDestroy(&vec_j));
        CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j));
      }
      CeedCallBackend(CeedVectorDestroy(&vec_i));
      CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i));
    }
  }
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
234
235 //------------------------------------------------------------------------------
236 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction.
237 //------------------------------------------------------------------------------
// One-time setup of the backend data for a standard operator.
//
// Allocates the per-field bookkeeping arrays, creates the E-vectors and Q-vectors through CeedOperatorSetupFields_Hip,
// computes a processing order that places fields sharing a (vector, restriction) pair next to each other, records the
// longest E-vector length seen, and prepares two work vectors of that length on the Ceed context.
// Returns immediately when the operator is already marked as set up.
static int CeedOperatorSetup_Hip(CeedOperator op) {
  bool                setup_done;
  CeedInt             Q, num_elem, num_input_fields, num_output_fields;
  CeedQFunctionField *qf_input_fields, *qf_output_fields;
  CeedQFunction       qf;
  CeedOperatorField  *op_input_fields, *op_output_fields;
  CeedOperator_Hip   *impl;

  CeedCallBackend(CeedOperatorIsSetupDone(op, &setup_done));
  if (setup_done) return CEED_ERROR_SUCCESS;

  CeedCallBackend(CeedOperatorGetData(op, &impl));
  CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
  CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
  CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
  CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
  CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));

  // Per-field bookkeeping arrays, zero-initialized
  CeedCallBackend(CeedCalloc(num_input_fields, &impl->e_vecs_in));
  CeedCallBackend(CeedCalloc(num_output_fields, &impl->e_vecs_out));
  CeedCallBackend(CeedCalloc(num_input_fields, &impl->skip_rstr_in));
  CeedCallBackend(CeedCalloc(num_output_fields, &impl->skip_rstr_out));
  CeedCallBackend(CeedCalloc(num_output_fields, &impl->apply_add_basis_out));
  CeedCallBackend(CeedCalloc(num_input_fields, &impl->input_field_order));
  CeedCallBackend(CeedCalloc(num_output_fields, &impl->output_field_order));
  CeedCallBackend(CeedCalloc(num_input_fields, &impl->input_states));
  CeedCallBackend(CeedCalloc(num_input_fields, &impl->q_vecs_in));
  CeedCallBackend(CeedCalloc(num_output_fields, &impl->q_vecs_out));
  impl->num_inputs  = num_input_fields;
  impl->num_outputs = num_output_fields;

  // E-vectors and Q-vectors: inputs first, then outputs
  CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, false, impl->skip_rstr_in, NULL, impl->e_vecs_in, impl->q_vecs_in, num_input_fields, Q,
                                              num_elem));
  CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, false, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs_out,
                                              impl->q_vecs_out, num_output_fields, Q, num_elem));

  // Order the fields so that those sharing a (vector, restriction) pair are processed consecutively
  impl->max_active_e_vec_len = 0;
  {
    // Inputs
    bool    placed[CEED_FIELD_MAX];
    CeedInt next_slot = 0;

    for (CeedInt f = 0; f < num_input_fields; f++) placed[f] = false;
    for (CeedInt f = 0; f < num_input_fields; f++) {
      CeedSize            e_vec_len;
      CeedVector          vec_f;
      CeedElemRestriction rstr_f;

      if (placed[f]) continue;
      placed[f]                            = true;
      impl->input_field_order[next_slot++] = f;
      CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[f], &vec_f));
      if (vec_f == CEED_VECTOR_NONE) {
        // CEED_EVAL_WEIGHT field: nothing to match and no E-vector length to record
        CeedCallBackend(CeedVectorDestroy(&vec_f));
        continue;
      }
      CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[f], &rstr_f));
      CeedCallBackend(CeedElemRestrictionGetEVectorSize(rstr_f, &e_vec_len));
      if (e_vec_len > impl->max_active_e_vec_len) impl->max_active_e_vec_len = e_vec_len;
      for (CeedInt g = f + 1; g < num_input_fields; g++) {
        CeedVector          vec_g;
        CeedElemRestriction rstr_g;

        CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[g], &vec_g));
        CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[g], &rstr_g));
        if (rstr_f == rstr_g && vec_f == vec_g) {
          placed[g]                            = true;
          impl->input_field_order[next_slot++] = g;
        }
        CeedCallBackend(CeedVectorDestroy(&vec_g));
        CeedCallBackend(CeedElemRestrictionDestroy(&rstr_g));
      }
      CeedCallBackend(CeedVectorDestroy(&vec_f));
      CeedCallBackend(CeedElemRestrictionDestroy(&rstr_f));
    }
  }
  {
    // Outputs
    bool    placed[CEED_FIELD_MAX];
    CeedInt next_slot = 0;

    for (CeedInt f = 0; f < num_output_fields; f++) placed[f] = false;
    for (CeedInt f = 0; f < num_output_fields; f++) {
      CeedSize            e_vec_len;
      CeedVector          vec_f;
      CeedElemRestriction rstr_f;

      if (placed[f]) continue;
      placed[f]                             = true;
      impl->output_field_order[next_slot++] = f;
      CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[f], &vec_f));
      CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[f], &rstr_f));
      CeedCallBackend(CeedElemRestrictionGetEVectorSize(rstr_f, &e_vec_len));
      if (e_vec_len > impl->max_active_e_vec_len) impl->max_active_e_vec_len = e_vec_len;
      for (CeedInt g = f + 1; g < num_output_fields; g++) {
        CeedVector          vec_g;
        CeedElemRestriction rstr_g;

        CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[g], &vec_g));
        CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[g], &rstr_g));
        if (rstr_f == rstr_g && vec_f == vec_g) {
          placed[g]                             = true;
          impl->output_field_order[next_slot++] = g;
        }
        CeedCallBackend(CeedVectorDestroy(&vec_g));
        CeedCallBackend(CeedElemRestrictionDestroy(&rstr_g));
      }
      CeedCallBackend(CeedVectorDestroy(&vec_f));
      CeedCallBackend(CeedElemRestrictionDestroy(&rstr_f));
    }
  }
  CeedCallBackend(CeedClearWorkVectors(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len));
  {
    // Create two work vectors for diagonal assembly, then hand them straight back
    CeedVector work_a, work_b;

    CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &work_a));
    CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &work_b));
    CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &work_a));
    CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &work_b));
  }
  CeedCallBackend(CeedOperatorSetSetupDone(op));
  CeedCallBackend(CeedQFunctionDestroy(&qf));
  return CEED_ERROR_SUCCESS;
}
369
370 //------------------------------------------------------------------------------
371 // Restrict Operator Inputs
372 //------------------------------------------------------------------------------
// Apply the element restriction for one input field, filling its E-vector from the L-vector.
//
//   op_input_field - operator field supplying the L-vector and element restriction
//   qf_input_field - matching QFunction field (not referenced here)
//   input_field    - index of the field in the backend data arrays
//   in_vec         - vector used as the L-vector when the field is active
//   active_e_vec   - E-vector used for an active field that has no E-vector of its own
//   skip_active    - when true, active fields are left untouched
//   impl           - backend data holding E-vectors, skip flags, and recorded input states
//   request        - request handle forwarded to CeedElemRestrictionApply
//
// A passive field is only restricted again when its vector state differs from the recorded one; an active field is
// restricted on every call. Fields flagged in skip_rstr_in are never restricted here.
static inline int CeedOperatorInputRestrict_Hip(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field,
                                                CeedVector in_vec, CeedVector active_e_vec, const bool skip_active, CeedOperator_Hip *impl,
                                                CeedRequest *request) {
  bool       field_is_active;
  CeedVector src_vec, dst_e_vec = impl->e_vecs_in[input_field];

  // Resolve the source L-vector and destination E-vector
  CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &src_vec));
  field_is_active = (src_vec == CEED_VECTOR_ACTIVE);
  if (field_is_active) {
    if (skip_active) return CEED_ERROR_SUCCESS;
    src_vec = in_vec;
    if (!dst_e_vec) dst_e_vec = active_e_vec;
  }

  // Restrict, if there is an E-vector to fill and this field is not a flagged duplicate
  if (dst_e_vec && !impl->skip_rstr_in[input_field]) {
    uint64_t current_state;

    CeedCallBackend(CeedVectorGetState(src_vec, &current_state));
    if (field_is_active || current_state != impl->input_states[input_field]) {
      CeedElemRestriction rstr;

      CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_field, &rstr));
      CeedCallBackend(CeedElemRestrictionApply(rstr, CEED_NOTRANSPOSE, src_vec, dst_e_vec, request));
      CeedCallBackend(CeedElemRestrictionDestroy(&rstr));
    }
    impl->input_states[input_field] = current_state;
  }

  // Only the reference taken from the operator field is released; in_vec belongs to the caller
  if (!field_is_active) CeedCallBackend(CeedVectorDestroy(&src_vec));
  return CEED_ERROR_SUCCESS;
}
408
409 //------------------------------------------------------------------------------
410 // Input Basis Action
411 //------------------------------------------------------------------------------
// Fill the Q-vector of one input field from its E-vector (or L-vector) according to the field's eval mode.
//
//   op_input_field - operator field supplying the L-vector and basis
//   qf_input_field - QFunction field supplying the eval mode
//   input_field    - index of the field in the backend data arrays
//   in_vec         - vector used as the L-vector when the field is active
//   active_e_vec   - E-vector used for an active field that has no E-vector of its own
//   num_elem       - number of elements passed to the basis
//   skip_active    - when true, active fields are left untouched
//   impl           - backend data holding the E-vectors and Q-vectors
//
// For CEED_EVAL_NONE the Q-vector borrows the device array of the source vector; the read access taken here is not
// released in this function (CeedOperatorInputRestore_Hip takes the array back and restores it).
static inline int CeedOperatorInputBasis_Hip(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field,
                                             CeedVector in_vec, CeedVector active_e_vec, CeedInt num_elem, const bool skip_active,
                                             CeedOperator_Hip *impl) {
  bool         field_is_active;
  CeedEvalMode mode;
  CeedVector   src_l_vec, src_e_vec = impl->e_vecs_in[input_field], dst_q_vec = impl->q_vecs_in[input_field];

  // Resolve the source vectors
  CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &src_l_vec));
  field_is_active = (src_l_vec == CEED_VECTOR_ACTIVE);
  if (field_is_active) {
    if (skip_active) return CEED_ERROR_SUCCESS;
    src_l_vec = in_vec;
    if (!src_e_vec) src_e_vec = active_e_vec;
  }

  // Basis action
  CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_field, &mode));
  switch (mode) {
    case CEED_EVAL_NONE: {
      // No basis: point the Q-vector at the E-vector data, or at the L-vector data when there is no E-vector
      const CeedScalar *device_array;
      CeedVector        array_owner = src_e_vec ? src_e_vec : src_l_vec;

      CeedCallBackend(CeedVectorGetArrayRead(array_owner, CEED_MEM_DEVICE, &device_array));
      CeedCallBackend(CeedVectorSetArray(dst_q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)device_array));
      break;
    }
    case CEED_EVAL_INTERP:
    case CEED_EVAL_GRAD:
    case CEED_EVAL_DIV:
    case CEED_EVAL_CURL: {
      CeedBasis field_basis;

      CeedCallBackend(CeedOperatorFieldGetBasis(op_input_field, &field_basis));
      CeedCallBackend(CeedBasisApply(field_basis, num_elem, CEED_NOTRANSPOSE, mode, src_e_vec, dst_q_vec));
      CeedCallBackend(CeedBasisDestroy(&field_basis));
      break;
    }
    case CEED_EVAL_WEIGHT:
      break;  // No action
  }

  // Only the reference taken from the operator field is released; in_vec belongs to the caller
  if (!field_is_active) CeedCallBackend(CeedVectorDestroy(&src_l_vec));
  return CEED_ERROR_SUCCESS;
}
459
460 //------------------------------------------------------------------------------
461 // Restore Input Vectors
462 //------------------------------------------------------------------------------
// Release the device array that a CEED_EVAL_NONE input field's Q-vector borrowed from its source vector.
//
//   op_input_field - operator field supplying the L-vector
//   qf_input_field - QFunction field supplying the eval mode
//   input_field    - index of the field in the backend data arrays
//   in_vec         - vector used as the L-vector when the field is active
//   active_e_vec   - E-vector used for an active field that has no E-vector of its own
//   skip_active    - when true, active fields are left untouched
//   impl           - backend data holding the E-vectors and Q-vectors
//
// Fields with any other eval mode need no restore; only the vector reference is released for them.
static inline int CeedOperatorInputRestore_Hip(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field,
                                               CeedVector in_vec, CeedVector active_e_vec, const bool skip_active, CeedOperator_Hip *impl) {
  bool         field_is_active;
  CeedEvalMode mode;
  CeedVector   src_l_vec, src_e_vec = impl->e_vecs_in[input_field];

  // Resolve the source vectors
  CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &src_l_vec));
  field_is_active = (src_l_vec == CEED_VECTOR_ACTIVE);
  if (field_is_active) {
    if (skip_active) return CEED_ERROR_SUCCESS;
    src_l_vec = in_vec;
    if (!src_e_vec) src_e_vec = active_e_vec;
  }

  // Take the array back from the Q-vector and return the read access to the vector it came from
  CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_field, &mode));
  if (mode == CEED_EVAL_NONE) {
    const CeedScalar *device_array;
    CeedVector        array_owner = src_e_vec ? src_e_vec : src_l_vec;

    CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_in[input_field], CEED_MEM_DEVICE, (CeedScalar **)&device_array));
    CeedCallBackend(CeedVectorRestoreArrayRead(array_owner, &device_array));
  }

  // Only the reference taken from the operator field is released; in_vec belongs to the caller
  if (!field_is_active) CeedCallBackend(CeedVectorDestroy(&src_l_vec));
  return CEED_ERROR_SUCCESS;
}
493
494 //------------------------------------------------------------------------------
495 // Apply and add to output
496 //------------------------------------------------------------------------------
// Apply the operator to in_vec and pass the result to out_vec through the transpose element restrictions.
//
//   op      - operator to apply
//   in_vec  - vector used for active input fields
//   out_vec - vector used for active output fields
//   request - request handle forwarded to the element restrictions
//
// Steps: run setup if needed, restrict and evaluate the inputs in the stored field order, bind CEED_EVAL_NONE output
// Q-vectors to their E-vector arrays, apply the QFunction over num_elem * Q points, restore the input arrays, then apply
// the transpose basis and restriction for each output in the stored field order.
static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) {
  CeedInt             Q, num_elem, num_input_fields, num_output_fields;
  Ceed                ceed;
  CeedVector          active_e_vec;
  CeedQFunctionField *qf_input_fields, *qf_output_fields;
  CeedQFunction       qf;
  CeedOperatorField  *op_input_fields, *op_output_fields;
  CeedOperator_Hip   *impl;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &impl));
  CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
  CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
  CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
  CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
  CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));

  // Setup
  CeedCallBackend(CeedOperatorSetup_Hip(op));

  // Work vector shared by active fields that have no E-vector of their own
  CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec));

  // Process inputs, following the stored field order
  for (CeedInt pos = 0; pos < num_input_fields; pos++) {
    const CeedInt field = impl->input_field_order[pos];

    CeedCallBackend(CeedOperatorInputRestrict_Hip(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, false, impl, request));
    CeedCallBackend(CeedOperatorInputBasis_Hip(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, num_elem, false, impl));
  }

  // Output pointers, as necessary: CEED_EVAL_NONE outputs write straight into their E-vector
  for (CeedInt f = 0; f < num_output_fields; f++) {
    CeedEvalMode mode;

    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[f], &mode));
    if (mode == CEED_EVAL_NONE) {
      CeedScalar *device_array;

      CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[f], CEED_MEM_DEVICE, &device_array));
      CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[f], CEED_MEM_DEVICE, CEED_USE_POINTER, device_array));
    }
  }

  // Q function
  CeedCallBackend(CeedQFunctionApply(qf, num_elem * Q, impl->q_vecs_in, impl->q_vecs_out));

  // Restore input arrays
  for (CeedInt f = 0; f < num_input_fields; f++) {
    CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[f], qf_input_fields[f], f, in_vec, active_e_vec, false, impl));
  }

  // Output basis and restriction, following the stored field order
  for (CeedInt pos = 0; pos < num_output_fields; pos++) {
    const CeedInt field = impl->output_field_order[pos];
    bool          field_is_active;
    CeedEvalMode  mode;
    CeedVector    dst_l_vec, out_e_vec = impl->e_vecs_out[field], out_q_vec = impl->q_vecs_out[field];

    // Output vector
    CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[field], &dst_l_vec));
    field_is_active = (dst_l_vec == CEED_VECTOR_ACTIVE);
    if (field_is_active) {
      dst_l_vec = out_vec;
      if (!out_e_vec) out_e_vec = active_e_vec;
    }

    // Basis action
    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[field], &mode));
    switch (mode) {
      case CEED_EVAL_NONE: {
        // No basis: take the borrowed array back from the Q-vector and restore the E-vector
        CeedScalar *device_array;

        CeedCallBackend(CeedVectorTakeArray(out_q_vec, CEED_MEM_DEVICE, &device_array));
        CeedCallBackend(CeedVectorRestoreArray(out_e_vec, &device_array));
        break;
      }
      case CEED_EVAL_INTERP:
      case CEED_EVAL_GRAD:
      case CEED_EVAL_DIV:
      case CEED_EVAL_CURL: {
        CeedBasis field_basis;

        CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[field], &field_basis));
        if (impl->apply_add_basis_out[field]) {
          // The E-vector is shared with another output field, so add into it
          CeedCallBackend(CeedBasisApplyAdd(field_basis, num_elem, CEED_TRANSPOSE, mode, out_q_vec, out_e_vec));
        } else {
          CeedCallBackend(CeedBasisApply(field_basis, num_elem, CEED_TRANSPOSE, mode, out_q_vec, out_e_vec));
        }
        CeedCallBackend(CeedBasisDestroy(&field_basis));
        break;
      }
      // LCOV_EXCL_START
      case CEED_EVAL_WEIGHT: {
        return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
        // LCOV_EXCL_STOP
      }
    }

    // Restrict, unless this field is a flagged duplicate
    if (!impl->skip_rstr_out[field]) {
      CeedElemRestriction rstr;

      CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[field], &rstr));
      CeedCallBackend(CeedElemRestrictionApply(rstr, CEED_TRANSPOSE, out_e_vec, dst_l_vec, request));
      CeedCallBackend(CeedElemRestrictionDestroy(&rstr));
    }
    // Only the reference taken from the operator field is released; out_vec belongs to the caller
    if (!field_is_active) CeedCallBackend(CeedVectorDestroy(&dst_l_vec));
  }

  // Return work vector
  CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec));
  CeedCallBackend(CeedDestroy(&ceed));
  CeedCallBackend(CeedQFunctionDestroy(&qf));
  return CEED_ERROR_SUCCESS;
}
616
617 //------------------------------------------------------------------------------
618 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction.
619 //------------------------------------------------------------------------------
// One-time setup of the backend data for an operator evaluated at points.
//
// Records the number of points in each element and the maximum over all elements, allocates the per-field bookkeeping
// arrays, creates the E-vectors and Q-vectors through CeedOperatorSetupFields_Hip (sized with the maximum point count),
// computes a processing order that places fields sharing a (vector, restriction) pair next to each other, records the
// longest E-vector length seen, and prepares two work vectors of that length on the Ceed context.
// Returns immediately when the operator is already marked as set up.
static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) {
  bool                setup_done;
  CeedInt             max_num_points = -1, num_elem, num_input_fields, num_output_fields;
  CeedQFunctionField *qf_input_fields, *qf_output_fields;
  CeedQFunction       qf;
  CeedOperatorField  *op_input_fields, *op_output_fields;
  CeedOperator_Hip   *impl;

  CeedCallBackend(CeedOperatorIsSetupDone(op, &setup_done));
  if (setup_done) return CEED_ERROR_SUCCESS;

  CeedCallBackend(CeedOperatorGetData(op, &impl));
  CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
  CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
  CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
  CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
  {
    // Point counts come from the points restriction of the operator
    CeedElemRestriction points_rstr = NULL;

    CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &points_rstr, NULL));
    CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(points_rstr, &max_num_points));
    CeedCallBackend(CeedCalloc(num_elem, &impl->num_points));
    for (CeedInt elem = 0; elem < num_elem; elem++) {
      CeedInt points_in_elem;

      CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(points_rstr, elem, &points_in_elem));
      impl->num_points[elem] = points_in_elem;
    }
    CeedCallBackend(CeedElemRestrictionDestroy(&points_rstr));
  }
  impl->max_num_points = max_num_points;

  // Per-field bookkeeping arrays, zero-initialized
  CeedCallBackend(CeedCalloc(num_input_fields, &impl->e_vecs_in));
  CeedCallBackend(CeedCalloc(num_output_fields, &impl->e_vecs_out));
  CeedCallBackend(CeedCalloc(num_input_fields, &impl->skip_rstr_in));
  CeedCallBackend(CeedCalloc(num_output_fields, &impl->skip_rstr_out));
  CeedCallBackend(CeedCalloc(num_output_fields, &impl->apply_add_basis_out));
  CeedCallBackend(CeedCalloc(num_input_fields, &impl->input_field_order));
  CeedCallBackend(CeedCalloc(num_output_fields, &impl->output_field_order));
  CeedCallBackend(CeedCalloc(num_input_fields, &impl->input_states));
  CeedCallBackend(CeedCalloc(num_input_fields, &impl->q_vecs_in));
  CeedCallBackend(CeedCalloc(num_output_fields, &impl->q_vecs_out));
  impl->num_inputs  = num_input_fields;
  impl->num_outputs = num_output_fields;

  // E-vectors and Q-vectors: inputs first, then outputs, with max_num_points in place of Q
  CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, true, impl->skip_rstr_in, NULL, impl->e_vecs_in, impl->q_vecs_in, num_input_fields,
                                              max_num_points, num_elem));
  CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs_out, impl->q_vecs_out,
                                              num_output_fields, max_num_points, num_elem));

  // Order the fields so that those sharing a (vector, restriction) pair are processed consecutively
  impl->max_active_e_vec_len = 0;
  {
    // Inputs
    bool    placed[CEED_FIELD_MAX];
    CeedInt next_slot = 0;

    for (CeedInt f = 0; f < num_input_fields; f++) placed[f] = false;
    for (CeedInt f = 0; f < num_input_fields; f++) {
      CeedSize            e_vec_len;
      CeedVector          vec_f;
      CeedElemRestriction rstr_f;

      if (placed[f]) continue;
      placed[f]                            = true;
      impl->input_field_order[next_slot++] = f;
      CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[f], &vec_f));
      if (vec_f == CEED_VECTOR_NONE) {
        // CEED_EVAL_WEIGHT field: nothing to match and no E-vector length to record
        CeedCallBackend(CeedVectorDestroy(&vec_f));
        continue;
      }
      CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[f], &rstr_f));
      CeedCallBackend(CeedElemRestrictionGetEVectorSize(rstr_f, &e_vec_len));
      if (e_vec_len > impl->max_active_e_vec_len) impl->max_active_e_vec_len = e_vec_len;
      for (CeedInt g = f + 1; g < num_input_fields; g++) {
        CeedVector          vec_g;
        CeedElemRestriction rstr_g;

        CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[g], &vec_g));
        CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[g], &rstr_g));
        if (rstr_f == rstr_g && vec_f == vec_g) {
          placed[g]                            = true;
          impl->input_field_order[next_slot++] = g;
        }
        CeedCallBackend(CeedVectorDestroy(&vec_g));
        CeedCallBackend(CeedElemRestrictionDestroy(&rstr_g));
      }
      CeedCallBackend(CeedVectorDestroy(&vec_f));
      CeedCallBackend(CeedElemRestrictionDestroy(&rstr_f));
    }
  }
  {
    // Outputs
    bool    placed[CEED_FIELD_MAX];
    CeedInt next_slot = 0;

    for (CeedInt f = 0; f < num_output_fields; f++) placed[f] = false;
    for (CeedInt f = 0; f < num_output_fields; f++) {
      CeedSize            e_vec_len;
      CeedVector          vec_f;
      CeedElemRestriction rstr_f;

      if (placed[f]) continue;
      placed[f]                             = true;
      impl->output_field_order[next_slot++] = f;
      CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[f], &vec_f));
      CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[f], &rstr_f));
      CeedCallBackend(CeedElemRestrictionGetEVectorSize(rstr_f, &e_vec_len));
      if (e_vec_len > impl->max_active_e_vec_len) impl->max_active_e_vec_len = e_vec_len;
      for (CeedInt g = f + 1; g < num_output_fields; g++) {
        CeedVector          vec_g;
        CeedElemRestriction rstr_g;

        CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[g], &vec_g));
        CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[g], &rstr_g));
        if (rstr_f == rstr_g && vec_f == vec_g) {
          placed[g]                             = true;
          impl->output_field_order[next_slot++] = g;
        }
        CeedCallBackend(CeedVectorDestroy(&vec_g));
        CeedCallBackend(CeedElemRestrictionDestroy(&rstr_g));
      }
      CeedCallBackend(CeedVectorDestroy(&vec_f));
      CeedCallBackend(CeedElemRestrictionDestroy(&rstr_f));
    }
  }
  CeedCallBackend(CeedClearWorkVectors(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len));
  {
    // Create two work vectors for diagonal assembly, then hand them straight back
    CeedVector work_a, work_b;

    CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &work_a));
    CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &work_b));
    CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &work_a));
    CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &work_b));
  }
  CeedCallBackend(CeedOperatorSetSetupDone(op));
  CeedCallBackend(CeedQFunctionDestroy(&qf));
  return CEED_ERROR_SUCCESS;
}
765
766 //------------------------------------------------------------------------------
767 // Input Basis Action AtPoints
768 //------------------------------------------------------------------------------
// Apply the input basis action for a single input field of an AtPoints operator.
//   op_input_field / qf_input_field : operator and QFunction descriptions of the field
//   input_field                     : index of the field in impl->e_vecs_in / impl->q_vecs_in
//   in_vec                          : vector used in place of CEED_VECTOR_ACTIVE
//   active_e_vec                    : E-vector used for an active field that has no E-vector of its own
//   num_elem, num_points            : forwarded to CeedBasisApplyAtPoints
//   skip_active / skip_passive      : return immediately for active / passive fields respectively
// Returns CEED_ERROR_SUCCESS when no call fails.
static inline int CeedOperatorInputBasisAtPoints_Hip(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field,
                                                     CeedVector in_vec, CeedVector active_e_vec, CeedInt num_elem, const CeedInt *num_points,
                                                     const bool skip_active, const bool skip_passive, CeedOperator_Hip *impl) {
  CeedEvalMode eval_mode;
  CeedVector   field_vec;
  CeedVector   elem_vec = impl->e_vecs_in[input_field];
  CeedVector   quad_vec = impl->q_vecs_in[input_field];

  // Classify the field and honor the skip flags
  CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &field_vec));
  const bool field_is_active = (field_vec == CEED_VECTOR_ACTIVE);

  if (field_is_active) {
    if (skip_active) return CEED_ERROR_SUCCESS;
    // Active fields read from the caller's input vector
    field_vec = in_vec;
    if (!elem_vec) elem_vec = active_e_vec;
  } else if (skip_passive) {
    CeedCallBackend(CeedVectorDestroy(&field_vec));
    return CEED_ERROR_SUCCESS;
  }

  // Basis action
  CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_field, &eval_mode));
  switch (eval_mode) {
    case CEED_EVAL_WEIGHT:
      // No action
      break;
    case CEED_EVAL_NONE: {
      // The Q-vector borrows the device array of the E-vector (or of the L-vector when there is no E-vector)
      const CeedScalar *borrowed_array;
      CeedVector        array_source = elem_vec ? elem_vec : field_vec;

      CeedCallBackend(CeedVectorGetArrayRead(array_source, CEED_MEM_DEVICE, &borrowed_array));
      CeedCallBackend(CeedVectorSetArray(quad_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)borrowed_array));
      break;
    }
    case CEED_EVAL_INTERP:
    case CEED_EVAL_GRAD:
    case CEED_EVAL_DIV:
    case CEED_EVAL_CURL: {
      CeedBasis field_basis;

      CeedCallBackend(CeedOperatorFieldGetBasis(op_input_field, &field_basis));
      CeedCallBackend(
          CeedBasisApplyAtPoints(field_basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, elem_vec, quad_vec));
      CeedCallBackend(CeedBasisDestroy(&field_basis));
      break;
    }
  }

  // Only the passive field vector obtained above is released here
  if (!field_is_active) {
    CeedCallBackend(CeedVectorDestroy(&field_vec));
  }
  return CEED_ERROR_SUCCESS;
}
820
821 //------------------------------------------------------------------------------
822 // Apply and add to output AtPoints
823 //------------------------------------------------------------------------------
// Apply an AtPoints operator and add the result into the output.
//   op      : operator to apply
//   in_vec  : vector substituted for CEED_VECTOR_ACTIVE inputs
//   out_vec : vector substituted for CEED_VECTOR_ACTIVE outputs
//   request : forwarded to every element restriction application
// Steps: setup, borrow a work E-vector, refresh the element-wise point coordinates when the
// coordinate vector changed, restrict + apply basis for inputs, run the QFunction, restore the
// inputs, then apply the transpose basis and transpose restriction for each output.
// Returns CEED_ERROR_SUCCESS when no call fails.
static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) {
  CeedInt             max_num_points, *num_points, num_elem, num_input_fields, num_output_fields;
  Ceed                ceed;
  CeedVector          active_e_vec;
  CeedQFunctionField *qf_input_fields, *qf_output_fields;
  CeedQFunction       qf;
  CeedOperatorField  *op_input_fields, *op_output_fields;
  CeedOperator_Hip   *impl;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &impl));
  CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
  CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
  CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
  CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));

  // Setup
  CeedCallBackend(CeedOperatorSetupAtPoints_Hip(op));
  num_points     = impl->num_points;
  max_num_points = impl->max_num_points;

  // Work vector
  CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec));

  // Get point coordinates
  {
    CeedVector          point_coords = NULL;
    CeedElemRestriction rstr_points  = NULL;

    CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords));
    if (!impl->point_coords_elem) CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem));
    {
      uint64_t state;

      CeedCallBackend(CeedVectorGetState(point_coords, &state));
      if (impl->points_state != state) {
        CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request));
        // Remember the state that was restricted so the restriction is not repeated
        // on later applications while the coordinates are unchanged
        impl->points_state = state;
      }
    }
    CeedCallBackend(CeedVectorDestroy(&point_coords));
    CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
  }

  // Process inputs, in the order recorded in impl->input_field_order
  for (CeedInt i = 0; i < num_input_fields; i++) {
    CeedInt field = impl->input_field_order[i];

    CeedCallBackend(CeedOperatorInputRestrict_Hip(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, false, impl, request));
    CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, num_elem,
                                                       num_points, false, false, impl));
  }

  // Output pointers, as necessary
  for (CeedInt i = 0; i < num_output_fields; i++) {
    CeedEvalMode eval_mode;

    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
    if (eval_mode == CEED_EVAL_NONE) {
      CeedScalar *e_vec_array;

      // With no basis action the QFunction writes directly into the E-vector array
      CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array));
      CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array));
    }
  }

  // Q function
  CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out));

  // Restore input arrays
  for (CeedInt i = 0; i < num_input_fields; i++) {
    CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, in_vec, active_e_vec, false, impl));
  }

  // Output basis and restriction, in the order recorded in impl->output_field_order
  for (CeedInt i = 0; i < num_output_fields; i++) {
    bool         is_active = false;
    CeedInt      field     = impl->output_field_order[i];
    CeedEvalMode eval_mode;
    CeedVector   l_vec, e_vec = impl->e_vecs_out[field], q_vec = impl->q_vecs_out[field];

    // Output vector
    CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[field], &l_vec));
    is_active = l_vec == CEED_VECTOR_ACTIVE;
    if (is_active) {
      l_vec = out_vec;
      if (!e_vec) e_vec = active_e_vec;
    }

    // Basis action
    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[field], &eval_mode));
    switch (eval_mode) {
      case CEED_EVAL_NONE:
        break;  // No action
      case CEED_EVAL_INTERP:
      case CEED_EVAL_GRAD:
      case CEED_EVAL_DIV:
      case CEED_EVAL_CURL: {
        CeedBasis basis;

        CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[field], &basis));
        // apply_add_basis_out selects accumulation into the E-vector instead of overwriting it
        if (impl->apply_add_basis_out[field]) {
          CeedCallBackend(CeedBasisApplyAddAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec));
        } else {
          CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec));
        }
        CeedCallBackend(CeedBasisDestroy(&basis));
        break;
      }
      // LCOV_EXCL_START
      case CEED_EVAL_WEIGHT: {
        return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
        // LCOV_EXCL_STOP
      }
    }

    // Restore evec
    if (eval_mode == CEED_EVAL_NONE) {
      CeedScalar *e_vec_array;

      // Detach the borrowed array from the Q-vector before handing it back to the E-vector
      CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[field], CEED_MEM_DEVICE, &e_vec_array));
      CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_vec_array));
    }

    // Restrict
    if (!impl->skip_rstr_out[field]) {
      CeedElemRestriction elem_rstr;

      CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[field], &elem_rstr));
      CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, e_vec, l_vec, request));
      CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
    }
    if (!is_active) CeedCallBackend(CeedVectorDestroy(&l_vec));
  }

  // Restore work vector
  CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec));
  CeedCallBackend(CeedDestroy(&ceed));
  CeedCallBackend(CeedQFunctionDestroy(&qf));
  return CEED_ERROR_SUCCESS;
}
963
964 //------------------------------------------------------------------------------
965 // Linear QFunction Assembly Core
966 //------------------------------------------------------------------------------
// Assemble the pointwise action of the QFunction into `assembled`.
//   op            : operator whose QFunction is assembled
//   build_objects : when true, `*rstr` and `*assembled` are created here; otherwise they are reused
//   assembled     : assembled QFunction values (zeroed, then filled)
//   rstr          : strided restriction describing the layout of `assembled`
//   request       : forwarded to the input restrictions
// Each active input component is set to 1.0 in turn (the previous one is reset to 0.0) and the
// QFunction is applied with the active output Q-vectors pointed at consecutive slices of `assembled`.
// Returns CEED_ERROR_SUCCESS when no call fails.
static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr,
                                                              CeedRequest *request) {
  Ceed                ceed, ceed_parent;
  CeedInt             num_active_in, num_active_out, Q, num_elem, num_input_fields, num_output_fields, size;
  CeedScalar         *assembled_array;
  CeedVector         *active_inputs;
  CeedQFunctionField *qf_input_fields, *qf_output_fields;
  CeedQFunction       qf;
  CeedOperatorField  *op_input_fields, *op_output_fields;
  CeedOperator_Hip   *impl;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetFallbackParentCeed(op, &ceed_parent));
  CeedCallBackend(CeedOperatorGetData(op, &impl));
  CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
  CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
  CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
  CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
  CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
  active_inputs = impl->qf_active_in;
  num_active_in = impl->num_active_in, num_active_out = impl->num_active_out;

  // Entries per component (Q points on each of num_elem elements); kept in CeedSize so the
  // pointer offsets computed from it below are not evaluated in CeedInt arithmetic
  const CeedSize q_size = (CeedSize)Q * num_elem;

  // Setup
  CeedCallBackend(CeedOperatorSetup_Hip(op));

  // Process inputs
  for (CeedInt i = 0; i < num_input_fields; i++) {
    CeedCallBackend(CeedOperatorInputRestrict_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl, request));
    CeedCallBackend(CeedOperatorInputBasis_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, num_elem, true, impl));
  }

  // Count number of active input fields (only on the first call; the result is cached in impl)
  if (!num_active_in) {
    for (CeedInt i = 0; i < num_input_fields; i++) {
      CeedScalar *q_vec_array;
      CeedVector  l_vec;

      // Check if active input
      CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &l_vec));
      if (l_vec == CEED_VECTOR_ACTIVE) {
        CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size));
        CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0));
        CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, &q_vec_array));
        CeedCallBackend(CeedRealloc(num_active_in + size, &active_inputs));
        // One vector per component, each viewing its own slice of the field's Q-vector
        for (CeedInt field = 0; field < size; field++) {
          CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_inputs[num_active_in + field]));
          CeedCallBackend(
              CeedVectorSetArray(active_inputs[num_active_in + field], CEED_MEM_DEVICE, CEED_USE_POINTER, &q_vec_array[field * q_size]));
        }
        num_active_in += size;
        CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array));
      }
      CeedCallBackend(CeedVectorDestroy(&l_vec));
    }
    impl->num_active_in = num_active_in;
    impl->qf_active_in  = active_inputs;
  }

  // Count number of active output fields (only on the first call; the result is cached in impl)
  if (!num_active_out) {
    for (CeedInt i = 0; i < num_output_fields; i++) {
      CeedVector l_vec;

      // Check if active output
      CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &l_vec));
      if (l_vec == CEED_VECTOR_ACTIVE) {
        CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size));
        num_active_out += size;
      }
      CeedCallBackend(CeedVectorDestroy(&l_vec));
    }
    impl->num_active_out = num_active_out;
  }

  // Check sizes
  CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");

  // Build objects if needed
  if (build_objects) {
    CeedSize l_size     = q_size * num_active_in * num_active_out;
    CeedInt  strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */

    // Create output restriction
    CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out, l_size, strides, rstr));
    // Create assembled vector
    CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled));
  }
  CeedCallBackend(CeedVectorSetValue(*assembled, 0.0));
  CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_DEVICE, &assembled_array));

  // Assemble QFunction
  for (CeedInt in = 0; in < num_active_in; in++) {
    // Set Inputs: current component to 1.0, the previously used one back to 0.0
    CeedCallBackend(CeedVectorSetValue(active_inputs[in], 1.0));
    if (num_active_in > 1) {
      CeedCallBackend(CeedVectorSetValue(active_inputs[(in + num_active_in - 1) % num_active_in], 0.0));
    }
    // Set Outputs
    for (CeedInt out = 0; out < num_output_fields; out++) {
      CeedVector l_vec;

      // Check if active output
      CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &l_vec));
      if (l_vec == CEED_VECTOR_ACTIVE) {
        CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, CEED_USE_POINTER, assembled_array));
        CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size));
        assembled_array += size * q_size;  // Advance the pointer by the size of the output
      }
      CeedCallBackend(CeedVectorDestroy(&l_vec));
    }
    // Apply QFunction
    CeedCallBackend(CeedQFunctionApply(qf, Q * num_elem, impl->q_vecs_in, impl->q_vecs_out));
  }

  // Un-set output q-vecs to prevent accidental overwrite of Assembled
  for (CeedInt out = 0; out < num_output_fields; out++) {
    CeedVector l_vec;

    CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &l_vec));
    if (l_vec == CEED_VECTOR_ACTIVE) {
      CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, NULL));
    }
    CeedCallBackend(CeedVectorDestroy(&l_vec));
  }

  // Restore input arrays
  for (CeedInt i = 0; i < num_input_fields; i++) {
    CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl));
  }

  // Restore output
  CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array));
  CeedCallBackend(CeedDestroy(&ceed));
  CeedCallBackend(CeedDestroy(&ceed_parent));
  CeedCallBackend(CeedQFunctionDestroy(&qf));
  return CEED_ERROR_SUCCESS;
}
1108
1109 //------------------------------------------------------------------------------
1110 // Assemble Linear QFunction
1111 //------------------------------------------------------------------------------
// Assemble the QFunction, creating the assembled vector and its restriction.
static int CeedOperatorLinearAssembleQFunction_Hip(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) {
  // The core routine builds `*assembled` and `*rstr` when asked to build objects
  const bool build_objects = true;

  return CeedOperatorLinearAssembleQFunctionCore_Hip(op, build_objects, assembled, rstr, request);
}
1115
1116 //------------------------------------------------------------------------------
1117 // Update Assembled Linear QFunction
1118 //------------------------------------------------------------------------------
// Re-assemble the QFunction into an existing assembled vector and restriction.
static int CeedOperatorLinearAssembleQFunctionUpdate_Hip(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) {
  // The caller's objects are reused, so the core routine must not build new ones
  const bool build_objects = false;

  return CeedOperatorLinearAssembleQFunctionCore_Hip(op, build_objects, &assembled, &rstr, request);
}
1122
1123 //------------------------------------------------------------------------------
1124 // Assemble Diagonal Setup
1125 //------------------------------------------------------------------------------
CeedOperatorAssembleDiagonalSetup_Hip(CeedOperator op)1126 static inline int CeedOperatorAssembleDiagonalSetup_Hip(CeedOperator op) {
1127 Ceed ceed;
1128 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0;
1129 CeedInt q_comp, num_nodes, num_qpts;
1130 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL;
1131 CeedBasis basis_in = NULL, basis_out = NULL;
1132 CeedQFunctionField *qf_fields;
1133 CeedQFunction qf;
1134 CeedOperatorField *op_fields;
1135 CeedOperator_Hip *impl;
1136
1137 CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
1138 CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
1139 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields));
1140
1141 // Determine active input basis
1142 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL));
1143 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL));
1144 for (CeedInt i = 0; i < num_input_fields; i++) {
1145 CeedVector vec;
1146
1147 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec));
1148 if (vec == CEED_VECTOR_ACTIVE) {
1149 CeedEvalMode eval_mode;
1150 CeedBasis basis;
1151
1152 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
1153 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND,
1154 "Backend does not implement operator diagonal assembly with multiple active bases");
1155 if (!basis_in) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_in));
1156 CeedCallBackend(CeedBasisDestroy(&basis));
1157 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
1158 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp));
1159 if (eval_mode != CEED_EVAL_WEIGHT) {
1160 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly
1161 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in));
1162 for (CeedInt d = 0; d < q_comp; d++) eval_modes_in[num_eval_modes_in + d] = eval_mode;
1163 num_eval_modes_in += q_comp;
1164 }
1165 }
1166 CeedCallBackend(CeedVectorDestroy(&vec));
1167 }
1168
1169 // Determine active output basis
1170 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields));
1171 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields));
1172 for (CeedInt i = 0; i < num_output_fields; i++) {
1173 CeedVector vec;
1174
1175 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec));
1176 if (vec == CEED_VECTOR_ACTIVE) {
1177 CeedBasis basis;
1178 CeedEvalMode eval_mode;
1179
1180 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
1181 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND,
1182 "Backend does not implement operator diagonal assembly with multiple active bases");
1183 if (!basis_out) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_out));
1184 CeedCallBackend(CeedBasisDestroy(&basis));
1185 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
1186 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp));
1187 if (eval_mode != CEED_EVAL_WEIGHT) {
1188 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly
1189 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out));
1190 for (CeedInt d = 0; d < q_comp; d++) eval_modes_out[num_eval_modes_out + d] = eval_mode;
1191 num_eval_modes_out += q_comp;
1192 }
1193 }
1194 CeedCallBackend(CeedVectorDestroy(&vec));
1195 }
1196
1197 // Operator data struct
1198 CeedCallBackend(CeedOperatorGetData(op, &impl));
1199 CeedCallBackend(CeedCalloc(1, &impl->diag));
1200 CeedOperatorDiag_Hip *diag = impl->diag;
1201
1202 // Basis matrices
1203 CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes));
1204 if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes;
1205 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts));
1206 const CeedInt interp_bytes = num_nodes * num_qpts * sizeof(CeedScalar);
1207 const CeedInt eval_modes_bytes = sizeof(CeedEvalMode);
1208 bool has_eval_none = false;
1209
1210 // CEED_EVAL_NONE
1211 for (CeedInt i = 0; i < num_eval_modes_in; i++) has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE);
1212 for (CeedInt i = 0; i < num_eval_modes_out; i++) has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE);
1213 if (has_eval_none) {
1214 CeedScalar *identity = NULL;
1215
1216 CeedCallBackend(CeedCalloc(num_nodes * num_qpts, &identity));
1217 for (CeedInt i = 0; i < (num_nodes < num_qpts ? num_nodes : num_qpts); i++) identity[i * num_nodes + i] = 1.0;
1218 CeedCallHip(ceed, hipMalloc((void **)&diag->d_identity, interp_bytes));
1219 CeedCallHip(ceed, hipMemcpy(diag->d_identity, identity, interp_bytes, hipMemcpyHostToDevice));
1220 CeedCallBackend(CeedFree(&identity));
1221 }
1222
1223 // CEED_EVAL_INTERP, CEED_EVAL_GRAD, CEED_EVAL_DIV, and CEED_EVAL_CURL
1224 for (CeedInt in = 0; in < 2; in++) {
1225 CeedFESpace fespace;
1226 CeedBasis basis = in ? basis_in : basis_out;
1227
1228 CeedCallBackend(CeedBasisGetFESpace(basis, &fespace));
1229 switch (fespace) {
1230 case CEED_FE_SPACE_H1: {
1231 CeedInt q_comp_interp, q_comp_grad;
1232 const CeedScalar *interp, *grad;
1233 CeedScalar *d_interp, *d_grad;
1234
1235 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
1236 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_GRAD, &q_comp_grad));
1237
1238 CeedCallBackend(CeedBasisGetInterp(basis, &interp));
1239 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp));
1240 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice));
1241 CeedCallBackend(CeedBasisGetGrad(basis, &grad));
1242 CeedCallHip(ceed, hipMalloc((void **)&d_grad, interp_bytes * q_comp_grad));
1243 CeedCallHip(ceed, hipMemcpy(d_grad, grad, interp_bytes * q_comp_grad, hipMemcpyHostToDevice));
1244 if (in) {
1245 diag->d_interp_in = d_interp;
1246 diag->d_grad_in = d_grad;
1247 } else {
1248 diag->d_interp_out = d_interp;
1249 diag->d_grad_out = d_grad;
1250 }
1251 } break;
1252 case CEED_FE_SPACE_HDIV: {
1253 CeedInt q_comp_interp, q_comp_div;
1254 const CeedScalar *interp, *div;
1255 CeedScalar *d_interp, *d_div;
1256
1257 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
1258 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_DIV, &q_comp_div));
1259
1260 CeedCallBackend(CeedBasisGetInterp(basis, &interp));
1261 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp));
1262 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice));
1263 CeedCallBackend(CeedBasisGetDiv(basis, &div));
1264 CeedCallHip(ceed, hipMalloc((void **)&d_div, interp_bytes * q_comp_div));
1265 CeedCallHip(ceed, hipMemcpy(d_div, div, interp_bytes * q_comp_div, hipMemcpyHostToDevice));
1266 if (in) {
1267 diag->d_interp_in = d_interp;
1268 diag->d_div_in = d_div;
1269 } else {
1270 diag->d_interp_out = d_interp;
1271 diag->d_div_out = d_div;
1272 }
1273 } break;
1274 case CEED_FE_SPACE_HCURL: {
1275 CeedInt q_comp_interp, q_comp_curl;
1276 const CeedScalar *interp, *curl;
1277 CeedScalar *d_interp, *d_curl;
1278
1279 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
1280 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_CURL, &q_comp_curl));
1281
1282 CeedCallBackend(CeedBasisGetInterp(basis, &interp));
1283 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp));
1284 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice));
1285 CeedCallBackend(CeedBasisGetCurl(basis, &curl));
1286 CeedCallHip(ceed, hipMalloc((void **)&d_curl, interp_bytes * q_comp_curl));
1287 CeedCallHip(ceed, hipMemcpy(d_curl, curl, interp_bytes * q_comp_curl, hipMemcpyHostToDevice));
1288 if (in) {
1289 diag->d_interp_in = d_interp;
1290 diag->d_curl_in = d_curl;
1291 } else {
1292 diag->d_interp_out = d_interp;
1293 diag->d_curl_out = d_curl;
1294 }
1295 } break;
1296 }
1297 }
1298
1299 // Arrays of eval_modes
1300 CeedCallHip(ceed, hipMalloc((void **)&diag->d_eval_modes_in, num_eval_modes_in * eval_modes_bytes));
1301 CeedCallHip(ceed, hipMemcpy(diag->d_eval_modes_in, eval_modes_in, num_eval_modes_in * eval_modes_bytes, hipMemcpyHostToDevice));
1302 CeedCallHip(ceed, hipMalloc((void **)&diag->d_eval_modes_out, num_eval_modes_out * eval_modes_bytes));
1303 CeedCallHip(ceed, hipMemcpy(diag->d_eval_modes_out, eval_modes_out, num_eval_modes_out * eval_modes_bytes, hipMemcpyHostToDevice));
1304 CeedCallBackend(CeedFree(&eval_modes_in));
1305 CeedCallBackend(CeedFree(&eval_modes_out));
1306 CeedCallBackend(CeedDestroy(&ceed));
1307 CeedCallBackend(CeedBasisDestroy(&basis_in));
1308 CeedCallBackend(CeedBasisDestroy(&basis_out));
1309 CeedCallBackend(CeedQFunctionDestroy(&qf));
1310 return CEED_ERROR_SUCCESS;
1311 }
1312
1313 //------------------------------------------------------------------------------
1314 // Assemble Diagonal Setup (Compilation)
1315 //------------------------------------------------------------------------------
// JiT-compile the diagonal assembly kernel and store it in impl->diag.
//   op               : operator being assembled (impl->diag must already exist; it is dereferenced here)
//   use_ceedsize_idx : value passed to the kernel source as USE_CEEDSIZE
//   is_point_block   : selects the point-block module / kernel slot instead of the plain diagonal one
// The number of active input / output eval modes (one per quadrature component, CEED_EVAL_WEIGHT
// excluded) is recounted here because it is a compile-time constant of the kernel.
// Returns CEED_ERROR_SUCCESS when no call fails.
static inline int CeedOperatorAssembleDiagonalSetupCompile_Hip(CeedOperator op, CeedInt use_ceedsize_idx, const bool is_point_block) {
  Ceed                ceed;
  CeedInt             num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0;
  CeedInt             num_comp, q_comp, num_nodes, num_qpts;
  CeedBasis           basis_in = NULL, basis_out = NULL;
  CeedQFunctionField *qf_fields;
  CeedQFunction       qf;
  CeedOperatorField  *op_fields;
  CeedOperator_Hip   *impl;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
  CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields));

  // Determine active input basis
  CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL));
  CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL));
  for (CeedInt i = 0; i < num_input_fields; i++) {
    CeedVector vec;

    CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec));
    if (vec == CEED_VECTOR_ACTIVE) {
      CeedEvalMode eval_mode;
      CeedBasis    basis;

      CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
      if (!basis_in) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_in));
      CeedCallBackend(CeedBasisDestroy(&basis));
      CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
      CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp));
      if (eval_mode != CEED_EVAL_WEIGHT) {
        num_eval_modes_in += q_comp;
      }
    }
    CeedCallBackend(CeedVectorDestroy(&vec));
  }

  // Determine active output basis
  CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields));
  CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields));
  for (CeedInt i = 0; i < num_output_fields; i++) {
    CeedVector vec;

    CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec));
    if (vec == CEED_VECTOR_ACTIVE) {
      CeedEvalMode eval_mode;
      CeedBasis    basis;

      CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
      if (!basis_out) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_out));
      CeedCallBackend(CeedBasisDestroy(&basis));
      CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
      CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp));
      if (eval_mode != CEED_EVAL_WEIGHT) {
        num_eval_modes_out += q_comp;
      }
    }
    CeedCallBackend(CeedVectorDestroy(&vec));
  }

  // Operator data struct
  CeedCallBackend(CeedOperatorGetData(op, &impl));
  CeedOperatorDiag_Hip *diag = impl->diag;

  // Assemble kernel
  const char   diagonal_kernel_source[] = "// Diagonal assembly source\n#include <ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h>\n";
  hipModule_t *module                   = is_point_block ? &diag->module_point_block : &diag->module;
  CeedInt      elems_per_block          = 1;

  CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes));
  CeedCallBackend(CeedBasisGetNumComponents(basis_in, &num_comp));
  if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes;
  else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts));
  // NOTE(review): CeedCompile_Hip and CeedGetKernel_Hip are assumed to return Ceed error codes
  // (not hipError_t), so they are checked with CeedCallBackend rather than CeedCallHip, which
  // would reinterpret a Ceed error code as a HIP error -- confirm against ceed-hip-compile.h.
  // The literal 8 is the number of name/value define pairs that follow.
  CeedCallBackend(CeedCompile_Hip(ceed, diagonal_kernel_source, module, 8, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT",
                                  num_eval_modes_out, "NUM_COMP", num_comp, "NUM_NODES", num_nodes, "NUM_QPTS", num_qpts, "USE_CEEDSIZE",
                                  use_ceedsize_idx, "USE_POINT_BLOCK", is_point_block ? 1 : 0, "BLOCK_SIZE", num_nodes * elems_per_block));
  CeedCallBackend(CeedGetKernel_Hip(ceed, *module, "LinearDiagonal", is_point_block ? &diag->LinearPointBlock : &diag->LinearDiagonal));
  CeedCallBackend(CeedDestroy(&ceed));
  CeedCallBackend(CeedBasisDestroy(&basis_in));
  CeedCallBackend(CeedBasisDestroy(&basis_out));
  CeedCallBackend(CeedQFunctionDestroy(&qf));
  return CEED_ERROR_SUCCESS;
}
1399
1400 //------------------------------------------------------------------------------
1401 // Assemble Diagonal Core
1402 //------------------------------------------------------------------------------
CeedOperatorAssembleDiagonalCore_Hip(CeedOperator op,CeedVector assembled,CeedRequest * request,const bool is_point_block)1403 static inline int CeedOperatorAssembleDiagonalCore_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request, const bool is_point_block) {
1404 Ceed ceed;
1405 CeedInt num_elem, num_nodes;
1406 CeedScalar *elem_diag_array;
1407 const CeedScalar *assembled_qf_array;
1408 CeedVector assembled_qf = NULL, elem_diag;
1409 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out, diag_rstr;
1410 CeedOperator_Hip *impl;
1411
1412 CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
1413 CeedCallBackend(CeedOperatorGetData(op, &impl));
1414
1415 // Assemble QFunction
1416 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, request));
1417 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr));
1418 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array));
1419
1420 // Setup
1421 if (!impl->diag) CeedCallBackend(CeedOperatorAssembleDiagonalSetup_Hip(op));
1422 CeedOperatorDiag_Hip *diag = impl->diag;
1423
1424 assert(diag != NULL);
1425
1426 // Assemble kernel if needed
1427 if ((!is_point_block && !diag->LinearDiagonal) || (is_point_block && !diag->LinearPointBlock)) {
1428 CeedSize assembled_length, assembled_qf_length;
1429 CeedInt use_ceedsize_idx = 0;
1430 CeedCallBackend(CeedVectorGetLength(assembled, &assembled_length));
1431 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length));
1432 if ((assembled_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1;
1433
1434 CeedCallBackend(CeedOperatorAssembleDiagonalSetupCompile_Hip(op, use_ceedsize_idx, is_point_block));
1435 }
1436
1437 // Restriction and diagonal vector
1438 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out));
1439 CeedCheck(rstr_in == rstr_out, ceed, CEED_ERROR_BACKEND,
1440 "Cannot assemble operator diagonal with different input and output active element restrictions");
1441 if (!is_point_block && !diag->diag_rstr) {
1442 CeedCallBackend(CeedElemRestrictionCreateUnsignedCopy(rstr_out, &diag->diag_rstr));
1443 CeedCallBackend(CeedElemRestrictionCreateVector(diag->diag_rstr, NULL, &diag->elem_diag));
1444 } else if (is_point_block && !diag->point_block_diag_rstr) {
1445 CeedCallBackend(CeedOperatorCreateActivePointBlockRestriction(rstr_out, &diag->point_block_diag_rstr));
1446 CeedCallBackend(CeedElemRestrictionCreateVector(diag->point_block_diag_rstr, NULL, &diag->point_block_elem_diag));
1447 }
1448 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_in));
1449 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_out));
1450 diag_rstr = is_point_block ? diag->point_block_diag_rstr : diag->diag_rstr;
1451 elem_diag = is_point_block ? diag->point_block_elem_diag : diag->elem_diag;
1452 CeedCallBackend(CeedVectorSetValue(elem_diag, 0.0));
1453
1454 // Only assemble diagonal if the basis has nodes, otherwise inputs are null pointers
1455 CeedCallBackend(CeedElemRestrictionGetElementSize(diag_rstr, &num_nodes));
1456 if (num_nodes > 0) {
1457 // Assemble element operator diagonals
1458 CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem));
1459 CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array));
1460
1461 // Compute the diagonal of B^T D B
1462 CeedInt elems_per_block = 1;
1463 CeedInt grid = CeedDivUpInt(num_elem, elems_per_block);
1464 void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_div_in,
1465 &diag->d_curl_in, &diag->d_interp_out, &diag->d_grad_out, &diag->d_div_out, &diag->d_curl_out,
1466 &diag->d_eval_modes_in, &diag->d_eval_modes_out, &assembled_qf_array, &elem_diag_array};
1467
1468 if (is_point_block) {
1469 CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->LinearPointBlock, grid, num_nodes, 1, elems_per_block, args));
1470 } else {
1471 CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->LinearDiagonal, grid, num_nodes, 1, elems_per_block, args));
1472 }
1473
1474 // Restore arrays
1475 CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array));
1476 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array));
1477 }
1478
1479 // Assemble local operator diagonal
1480 CeedCallBackend(CeedElemRestrictionApply(diag_rstr, CEED_TRANSPOSE, elem_diag, assembled, request));
1481
1482 // Cleanup
1483 CeedCallBackend(CeedDestroy(&ceed));
1484 CeedCallBackend(CeedVectorDestroy(&assembled_qf));
1485 return CEED_ERROR_SUCCESS;
1486 }
1487
1488 //------------------------------------------------------------------------------
1489 // Assemble Linear Diagonal
1490 //------------------------------------------------------------------------------
CeedOperatorLinearAssembleAddDiagonal_Hip(CeedOperator op,CeedVector assembled,CeedRequest * request)1491 static int CeedOperatorLinearAssembleAddDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) {
1492 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, false));
1493 return CEED_ERROR_SUCCESS;
1494 }
1495
1496 //------------------------------------------------------------------------------
1497 // Assemble Linear Point Block Diagonal
1498 //------------------------------------------------------------------------------
CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip(CeedOperator op,CeedVector assembled,CeedRequest * request)1499 static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) {
1500 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, true));
1501 return CEED_ERROR_SUCCESS;
1502 }
1503
1504 //------------------------------------------------------------------------------
1505 // Single Operator Assembly Setup
1506 //------------------------------------------------------------------------------
CeedOperatorAssembleSingleSetup_Hip(CeedOperator op,CeedInt use_ceedsize_idx)1507 static int CeedOperatorAssembleSingleSetup_Hip(CeedOperator op, CeedInt use_ceedsize_idx) {
1508 Ceed ceed;
1509 Ceed_Hip *hip_data;
1510 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0;
1511 CeedInt elem_size_in, num_qpts_in = 0, num_comp_in, elem_size_out, num_qpts_out, num_comp_out, q_comp;
1512 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL;
1513 CeedElemRestriction rstr_in = NULL, rstr_out = NULL;
1514 CeedBasis basis_in = NULL, basis_out = NULL;
1515 CeedQFunctionField *qf_fields;
1516 CeedQFunction qf;
1517 CeedOperatorField *input_fields, *output_fields;
1518 CeedOperator_Hip *impl;
1519
1520 CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
1521 CeedCallBackend(CeedOperatorGetData(op, &impl));
1522
1523 // Get intput and output fields
1524 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields));
1525
1526 // Determine active input basis eval mode
1527 CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
1528 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL));
1529 for (CeedInt i = 0; i < num_input_fields; i++) {
1530 CeedVector vec;
1531
1532 CeedCallBackend(CeedOperatorFieldGetVector(input_fields[i], &vec));
1533 if (vec == CEED_VECTOR_ACTIVE) {
1534 CeedEvalMode eval_mode;
1535 CeedElemRestriction elem_rstr;
1536 CeedBasis basis;
1537
1538 CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis));
1539 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases");
1540 if (!basis_in) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_in));
1541 CeedCallBackend(CeedBasisDestroy(&basis));
1542 CeedCallBackend(CeedOperatorFieldGetElemRestriction(input_fields[i], &elem_rstr));
1543 if (!rstr_in) CeedCallBackend(CeedElemRestrictionReferenceCopy(elem_rstr, &rstr_in));
1544 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
1545 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in));
1546 if (basis_in == CEED_BASIS_NONE) num_qpts_in = elem_size_in;
1547 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts_in));
1548 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
1549 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp));
1550 if (eval_mode != CEED_EVAL_WEIGHT) {
1551 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly
1552 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in));
1553 for (CeedInt d = 0; d < q_comp; d++) {
1554 eval_modes_in[num_eval_modes_in + d] = eval_mode;
1555 }
1556 num_eval_modes_in += q_comp;
1557 }
1558 }
1559 CeedCallBackend(CeedVectorDestroy(&vec));
1560 }
1561
1562 // Determine active output basis; basis_out and rstr_out only used if same as input, TODO
1563 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields));
1564 for (CeedInt i = 0; i < num_output_fields; i++) {
1565 CeedVector vec;
1566
1567 CeedCallBackend(CeedOperatorFieldGetVector(output_fields[i], &vec));
1568 if (vec == CEED_VECTOR_ACTIVE) {
1569 CeedEvalMode eval_mode;
1570 CeedElemRestriction elem_rstr;
1571 CeedBasis basis;
1572
1573 CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis));
1574 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND,
1575 "Backend does not implement operator assembly with multiple active bases");
1576 if (!basis_out) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_out));
1577 CeedCallBackend(CeedBasisDestroy(&basis));
1578 CeedCallBackend(CeedOperatorFieldGetElemRestriction(output_fields[i], &elem_rstr));
1579 if (!rstr_out) CeedCallBackend(CeedElemRestrictionReferenceCopy(elem_rstr, &rstr_out));
1580 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
1581 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out));
1582 if (basis_out == CEED_BASIS_NONE) num_qpts_out = elem_size_out;
1583 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_out, &num_qpts_out));
1584 CeedCheck(num_qpts_in == num_qpts_out, ceed, CEED_ERROR_UNSUPPORTED,
1585 "Active input and output bases must have the same number of quadrature points");
1586 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
1587 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp));
1588 if (eval_mode != CEED_EVAL_WEIGHT) {
1589 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly
1590 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out));
1591 for (CeedInt d = 0; d < q_comp; d++) {
1592 eval_modes_out[num_eval_modes_out + d] = eval_mode;
1593 }
1594 num_eval_modes_out += q_comp;
1595 }
1596 }
1597 CeedCallBackend(CeedVectorDestroy(&vec));
1598 }
1599 CeedCheck(num_eval_modes_in > 0 && num_eval_modes_out > 0, ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs");
1600
1601 CeedCallBackend(CeedCalloc(1, &impl->asmb));
1602 CeedOperatorAssemble_Hip *asmb = impl->asmb;
1603 asmb->elems_per_block = 1;
1604 asmb->block_size_x = elem_size_in;
1605 asmb->block_size_y = elem_size_out;
1606
1607 CeedCallBackend(CeedGetData(ceed, &hip_data));
1608 bool fallback = asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block > hip_data->device_prop.maxThreadsPerBlock;
1609
1610 if (fallback) {
1611 // Use fallback kernel with 1D threadblock
1612 asmb->block_size_y = 1;
1613 }
1614
1615 // Compile kernels
1616 const char assembly_kernel_source[] = "// Full assembly source\n#include <ceed/jit-source/hip/hip-ref-operator-assemble.h>\n";
1617
1618 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &num_comp_in));
1619 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_out, &num_comp_out));
1620 CeedCallBackend(CeedCompile_Hip(ceed, assembly_kernel_source, &asmb->module, 10, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT",
1621 num_eval_modes_out, "NUM_COMP_IN", num_comp_in, "NUM_COMP_OUT", num_comp_out, "NUM_NODES_IN", elem_size_in,
1622 "NUM_NODES_OUT", elem_size_out, "NUM_QPTS", num_qpts_in, "BLOCK_SIZE",
1623 asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block, "BLOCK_SIZE_Y", asmb->block_size_y, "USE_CEEDSIZE",
1624 use_ceedsize_idx));
1625 CeedCallBackend(CeedGetKernel_Hip(ceed, asmb->module, "LinearAssemble", &asmb->LinearAssemble));
1626
1627 // Load into B_in, in order that they will be used in eval_modes_in
1628 {
1629 const CeedInt in_bytes = elem_size_in * num_qpts_in * num_eval_modes_in * sizeof(CeedScalar);
1630 CeedInt d_in = 0;
1631 CeedEvalMode eval_modes_in_prev = CEED_EVAL_NONE;
1632 bool has_eval_none = false;
1633 CeedScalar *identity = NULL;
1634
1635 for (CeedInt i = 0; i < num_eval_modes_in; i++) {
1636 has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE);
1637 }
1638 if (has_eval_none) {
1639 CeedCallBackend(CeedCalloc(elem_size_in * num_qpts_in, &identity));
1640 for (CeedInt i = 0; i < (elem_size_in < num_qpts_in ? elem_size_in : num_qpts_in); i++) identity[i * elem_size_in + i] = 1.0;
1641 }
1642
1643 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_in, in_bytes));
1644 for (CeedInt i = 0; i < num_eval_modes_in; i++) {
1645 const CeedScalar *h_B_in;
1646
1647 CeedCallBackend(CeedOperatorGetBasisPointer(basis_in, eval_modes_in[i], identity, &h_B_in));
1648 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_modes_in[i], &q_comp));
1649 if (q_comp > 1) {
1650 if (i == 0 || eval_modes_in[i] != eval_modes_in_prev) d_in = 0;
1651 else h_B_in = &h_B_in[(++d_in) * elem_size_in * num_qpts_in];
1652 }
1653 eval_modes_in_prev = eval_modes_in[i];
1654
1655 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_in[i * elem_size_in * num_qpts_in], h_B_in, elem_size_in * num_qpts_in * sizeof(CeedScalar),
1656 hipMemcpyHostToDevice));
1657 }
1658 CeedCallBackend(CeedFree(&identity));
1659 }
1660 CeedCallBackend(CeedFree(&eval_modes_in));
1661
1662 // Load into B_out, in order that they will be used in eval_modes_out
1663 {
1664 const CeedInt out_bytes = elem_size_out * num_qpts_out * num_eval_modes_out * sizeof(CeedScalar);
1665 CeedInt d_out = 0;
1666 CeedEvalMode eval_modes_out_prev = CEED_EVAL_NONE;
1667 bool has_eval_none = false;
1668 CeedScalar *identity = NULL;
1669
1670 for (CeedInt i = 0; i < num_eval_modes_out; i++) {
1671 has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE);
1672 }
1673 if (has_eval_none) {
1674 CeedCallBackend(CeedCalloc(elem_size_out * num_qpts_out, &identity));
1675 for (CeedInt i = 0; i < (elem_size_out < num_qpts_out ? elem_size_out : num_qpts_out); i++) identity[i * elem_size_out + i] = 1.0;
1676 }
1677
1678 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_out, out_bytes));
1679 for (CeedInt i = 0; i < num_eval_modes_out; i++) {
1680 const CeedScalar *h_B_out;
1681
1682 CeedCallBackend(CeedOperatorGetBasisPointer(basis_out, eval_modes_out[i], identity, &h_B_out));
1683 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_modes_out[i], &q_comp));
1684 if (q_comp > 1) {
1685 if (i == 0 || eval_modes_out[i] != eval_modes_out_prev) d_out = 0;
1686 else h_B_out = &h_B_out[(++d_out) * elem_size_out * num_qpts_out];
1687 }
1688 eval_modes_out_prev = eval_modes_out[i];
1689
1690 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_out[i * elem_size_out * num_qpts_out], h_B_out, elem_size_out * num_qpts_out * sizeof(CeedScalar),
1691 hipMemcpyHostToDevice));
1692 }
1693 CeedCallBackend(CeedFree(&identity));
1694 }
1695 CeedCallBackend(CeedFree(&eval_modes_out));
1696 CeedCallBackend(CeedDestroy(&ceed));
1697 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_in));
1698 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_out));
1699 CeedCallBackend(CeedBasisDestroy(&basis_in));
1700 CeedCallBackend(CeedBasisDestroy(&basis_out));
1701 CeedCallBackend(CeedQFunctionDestroy(&qf));
1702 return CEED_ERROR_SUCCESS;
1703 }
1704
1705 //------------------------------------------------------------------------------
1706 // Assemble matrix data for COO matrix of assembled operator.
1707 // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic.
1708 //
1709 // Note that this (and other assembly routines) currently assume only one active input restriction/basis per operator
1710 // (could have multiple basis eval modes).
1711 // TODO: allow multiple active input restrictions/basis objects
1712 //------------------------------------------------------------------------------
CeedOperatorAssembleSingle_Hip(CeedOperator op,CeedInt offset,CeedVector values)1713 static int CeedOperatorAssembleSingle_Hip(CeedOperator op, CeedInt offset, CeedVector values) {
1714 Ceed ceed;
1715 CeedSize values_length = 0, assembled_qf_length = 0;
1716 CeedInt use_ceedsize_idx = 0, num_elem_in, num_elem_out, elem_size_in, elem_size_out;
1717 CeedScalar *values_array;
1718 const CeedScalar *assembled_qf_array;
1719 CeedVector assembled_qf = NULL;
1720 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out;
1721 CeedRestrictionType rstr_type_in, rstr_type_out;
1722 const bool *orients_in = NULL, *orients_out = NULL;
1723 const CeedInt8 *curl_orients_in = NULL, *curl_orients_out = NULL;
1724 CeedOperator_Hip *impl;
1725
1726 CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
1727 CeedCallBackend(CeedOperatorGetData(op, &impl));
1728
1729 // Assemble QFunction
1730 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, CEED_REQUEST_IMMEDIATE));
1731 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr));
1732 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array));
1733
1734 CeedCallBackend(CeedVectorGetLength(values, &values_length));
1735 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length));
1736 if ((values_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1;
1737
1738 // Setup
1739 if (!impl->asmb) CeedCallBackend(CeedOperatorAssembleSingleSetup_Hip(op, use_ceedsize_idx));
1740 CeedOperatorAssemble_Hip *asmb = impl->asmb;
1741
1742 assert(asmb != NULL);
1743
1744 // Assemble element operator
1745 CeedCallBackend(CeedVectorGetArray(values, CEED_MEM_DEVICE, &values_array));
1746 values_array += offset;
1747
1748 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out));
1749 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, &num_elem_in));
1750 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in));
1751
1752 CeedCallBackend(CeedElemRestrictionGetType(rstr_in, &rstr_type_in));
1753 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) {
1754 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_in, CEED_MEM_DEVICE, &orients_in));
1755 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) {
1756 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_in, CEED_MEM_DEVICE, &curl_orients_in));
1757 }
1758
1759 if (rstr_in != rstr_out) {
1760 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_out, &num_elem_out));
1761 CeedCheck(num_elem_in == num_elem_out, ceed, CEED_ERROR_UNSUPPORTED,
1762 "Active input and output operator restrictions must have the same number of elements");
1763 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out));
1764
1765 CeedCallBackend(CeedElemRestrictionGetType(rstr_out, &rstr_type_out));
1766 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) {
1767 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_out, CEED_MEM_DEVICE, &orients_out));
1768 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) {
1769 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_out, CEED_MEM_DEVICE, &curl_orients_out));
1770 }
1771 } else {
1772 elem_size_out = elem_size_in;
1773 orients_out = orients_in;
1774 curl_orients_out = curl_orients_in;
1775 }
1776
1777 // Compute B^T D B
1778 CeedInt shared_mem =
1779 ((curl_orients_in || curl_orients_out ? elem_size_in * elem_size_out : 0) + (curl_orients_in ? elem_size_in * asmb->block_size_y : 0)) *
1780 sizeof(CeedScalar);
1781 CeedInt grid = CeedDivUpInt(num_elem_in, asmb->elems_per_block);
1782 void *args[] = {(void *)&num_elem_in, &asmb->d_B_in, &asmb->d_B_out, &orients_in, &curl_orients_in,
1783 &orients_out, &curl_orients_out, &assembled_qf_array, &values_array};
1784
1785 CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, asmb->LinearAssemble, NULL, grid, asmb->block_size_x, asmb->block_size_y, asmb->elems_per_block,
1786 shared_mem, args));
1787
1788 // Restore arrays
1789 CeedCallBackend(CeedVectorRestoreArray(values, &values_array));
1790 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array));
1791
1792 // Cleanup
1793 CeedCallBackend(CeedVectorDestroy(&assembled_qf));
1794 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) {
1795 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_in, &orients_in));
1796 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) {
1797 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_in, &curl_orients_in));
1798 }
1799 if (rstr_in != rstr_out) {
1800 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) {
1801 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_out, &orients_out));
1802 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) {
1803 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_out, &curl_orients_out));
1804 }
1805 }
1806 CeedCallBackend(CeedDestroy(&ceed));
1807 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_in));
1808 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_out));
1809 return CEED_ERROR_SUCCESS;
1810 }
1811
1812 //------------------------------------------------------------------------------
1813 // Assemble Linear QFunction AtPoints
1814 //------------------------------------------------------------------------------
CeedOperatorLinearAssembleQFunctionAtPoints_Hip(CeedOperator op,CeedVector * assembled,CeedElemRestriction * rstr,CeedRequest * request)1815 static int CeedOperatorLinearAssembleQFunctionAtPoints_Hip(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) {
1816 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "Backend does not implement CeedOperatorLinearAssembleQFunction");
1817 }
1818
1819 //------------------------------------------------------------------------------
1820 // Assemble Linear Diagonal AtPoints
1821 //------------------------------------------------------------------------------
CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op,CeedVector assembled,CeedRequest * request)1822 static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) {
1823 CeedInt max_num_points, *num_points, num_elem, num_input_fields, num_output_fields;
1824 Ceed ceed;
1825 CeedVector active_e_vec_in, active_e_vec_out;
1826 CeedQFunctionField *qf_input_fields, *qf_output_fields;
1827 CeedQFunction qf;
1828 CeedOperatorField *op_input_fields, *op_output_fields;
1829 CeedOperator_Hip *impl;
1830
1831 CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
1832 CeedCallBackend(CeedOperatorGetData(op, &impl));
1833 CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
1834 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
1835 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
1836 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
1837
1838 // Setup
1839 CeedCallBackend(CeedOperatorSetupAtPoints_Hip(op));
1840 num_points = impl->num_points;
1841 max_num_points = impl->max_num_points;
1842
1843 // Work vector
1844 CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_in));
1845 CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_out));
1846 {
1847 CeedSize length_in, length_out;
1848
1849 CeedCallBackend(CeedVectorGetLength(active_e_vec_in, &length_in));
1850 CeedCallBackend(CeedVectorGetLength(active_e_vec_out, &length_out));
1851 // Need input e_vec to be longer
1852 if (length_in < length_out) {
1853 CeedVector temp = active_e_vec_in;
1854
1855 active_e_vec_in = active_e_vec_out;
1856 active_e_vec_out = temp;
1857 }
1858 }
1859
1860 // Get point coordinates
1861 {
1862 CeedVector point_coords = NULL;
1863 CeedElemRestriction rstr_points = NULL;
1864
1865 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords));
1866 if (!impl->point_coords_elem) CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem));
1867 {
1868 uint64_t state;
1869 CeedCallBackend(CeedVectorGetState(point_coords, &state));
1870 if (impl->points_state != state) {
1871 CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request));
1872 }
1873 }
1874 CeedCallBackend(CeedVectorDestroy(&point_coords));
1875 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
1876 }
1877
1878 // Process inputs
1879 for (CeedInt i = 0; i < num_input_fields; i++) {
1880 CeedCallBackend(CeedOperatorInputRestrict_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl, request));
1881 CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, num_elem, num_points, true, false,
1882 impl));
1883 }
1884
1885 // Output pointers, as necessary
1886 for (CeedInt i = 0; i < num_output_fields; i++) {
1887 CeedEvalMode eval_mode;
1888
1889 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
1890 if (eval_mode == CEED_EVAL_NONE) {
1891 CeedScalar *e_vec_array;
1892
1893 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array));
1894 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array));
1895 }
1896 }
1897
1898 // Loop over active fields
1899 for (CeedInt i = 0; i < num_input_fields; i++) {
1900 bool is_active = false, is_active_at_points = true;
1901 CeedInt elem_size = 1, num_comp_active = 1, e_vec_size = 0, field_in = impl->input_field_order[i];
1902 CeedRestrictionType rstr_type;
1903 CeedVector l_vec;
1904 CeedElemRestriction elem_rstr;
1905
1906 // -- Skip non-active input
1907 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[field_in], &l_vec));
1908 is_active = l_vec == CEED_VECTOR_ACTIVE;
1909 CeedCallBackend(CeedVectorDestroy(&l_vec));
1910 if (!is_active || impl->skip_rstr_in[field_in]) continue;
1911
1912 // -- Get active restriction type
1913 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[field_in], &elem_rstr));
1914 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
1915 is_active_at_points = rstr_type == CEED_RESTRICTION_POINTS;
1916 if (!is_active_at_points) CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
1917 else elem_size = max_num_points;
1918 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp_active));
1919 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
1920
1921 e_vec_size = elem_size * num_comp_active;
1922 CeedCallBackend(CeedVectorSetValue(active_e_vec_in, 0.0));
1923 for (CeedInt s = 0; s < e_vec_size; s++) {
1924 CeedVector q_vec = impl->q_vecs_in[field_in];
1925
1926 // Update unit vector
1927 {
1928 // Note: E-vec strides are node * (1) + comp * (elem_size * num_elem) + elem * (elem_size)
1929 CeedInt node = (s - 1) % elem_size, comp = (s - 1) / elem_size;
1930 CeedSize start = node * 1 + comp * (elem_size * num_elem);
1931 CeedSize stop = (comp + 1) * (elem_size * num_elem);
1932
1933 if (s != 0) CeedCallBackend(CeedVectorSetValueStrided(active_e_vec_in, start, stop, elem_size, 0.0));
1934
1935 node = s % elem_size, comp = s / elem_size;
1936 start = node * 1 + comp * (elem_size * num_elem);
1937 stop = (comp + 1) * (elem_size * num_elem);
1938 CeedCallBackend(CeedVectorSetValueStrided(active_e_vec_in, start, stop, elem_size, 1.0));
1939 }
1940
1941 // Basis action
1942 for (CeedInt j = 0; j < num_input_fields; j++) {
1943 CeedInt field = impl->input_field_order[j];
1944
1945 CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(op_input_fields[field], qf_input_fields[field], field, NULL, active_e_vec_in, num_elem,
1946 num_points, false, true, impl));
1947 }
1948
1949 // Q function
1950 CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out));
1951
1952 // Output basis apply if needed
1953 for (CeedInt j = 0; j < num_output_fields; j++) {
1954 bool is_active = false;
1955 CeedInt elem_size = 0;
1956 CeedInt field_out = impl->output_field_order[j];
1957 CeedRestrictionType rstr_type;
1958 CeedEvalMode eval_mode;
1959 CeedVector l_vec, e_vec = impl->e_vecs_out[field_out], q_vec = impl->q_vecs_out[field_out];
1960 CeedElemRestriction elem_rstr;
1961
1962 // ---- Skip non-active output
1963 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[field_out], &l_vec));
1964 is_active = l_vec == CEED_VECTOR_ACTIVE;
1965 CeedCallBackend(CeedVectorDestroy(&l_vec));
1966 if (!is_active) continue;
1967 if (!e_vec) e_vec = active_e_vec_out;
1968
1969 // ---- Check if elem size matches
1970 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[field_out], &elem_rstr));
1971 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
1972 if (is_active_at_points && rstr_type != CEED_RESTRICTION_POINTS) continue;
1973 if (rstr_type == CEED_RESTRICTION_POINTS) {
1974 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(elem_rstr, &elem_size));
1975 } else {
1976 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
1977 }
1978 {
1979 CeedInt num_comp = 0;
1980
1981 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
1982 if (e_vec_size != num_comp * elem_size) continue;
1983 }
1984
1985 // Basis action
1986 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[field_out], &eval_mode));
1987 switch (eval_mode) {
1988 case CEED_EVAL_NONE: {
1989 CeedScalar *e_vec_array;
1990
1991 CeedCallBackend(CeedVectorTakeArray(q_vec, CEED_MEM_DEVICE, &e_vec_array));
1992 CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_vec_array));
1993 break;
1994 }
1995 case CEED_EVAL_INTERP:
1996 case CEED_EVAL_GRAD:
1997 case CEED_EVAL_DIV:
1998 case CEED_EVAL_CURL: {
1999 CeedBasis basis;
2000
2001 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[field_out], &basis));
2002 if (impl->apply_add_basis_out[field_out]) {
2003 CeedCallBackend(CeedBasisApplyAddAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec,
2004 e_vec));
2005 } else {
2006 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec));
2007 }
2008 CeedCallBackend(CeedBasisDestroy(&basis));
2009 break;
2010 }
2011 // LCOV_EXCL_START
2012 case CEED_EVAL_WEIGHT: {
2013 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
2014 // LCOV_EXCL_STOP
2015 }
2016 }
2017
2018 // Continue if a field that is summed into
2019 if (impl->skip_rstr_out[field_out]) {
2020 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
2021 continue;
2022 }
2023
2024 // Mask output e-vec
2025 CeedCallBackend(CeedVectorPointwiseMult(e_vec, active_e_vec_in, e_vec));
2026
2027 // Restrict
2028 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, e_vec, assembled, request));
2029 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
2030
2031 // Reset q_vec for
2032 if (eval_mode == CEED_EVAL_NONE) {
2033 CeedScalar *e_vec_array;
2034
2035 CeedCallBackend(CeedVectorGetArrayWrite(e_vec, CEED_MEM_DEVICE, &e_vec_array));
2036 CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array));
2037 }
2038 }
2039
2040 // Reset vec
2041 if (s == e_vec_size - 1 && i != num_input_fields - 1) CeedCallBackend(CeedVectorSetValue(q_vec, 0.0));
2042 }
2043 }
2044
2045 // Restore CEED_EVAL_NONE
2046 for (CeedInt i = 0; i < num_output_fields; i++) {
2047 CeedEvalMode eval_mode;
2048
2049 // Get eval_mode
2050 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
2051
2052 // Restore evec
2053 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
2054 if (eval_mode == CEED_EVAL_NONE) {
2055 CeedScalar *e_vec_array;
2056
2057 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array));
2058 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_out[i], &e_vec_array));
2059 }
2060 }
2061
2062 // Restore input arrays
2063 for (CeedInt i = 0; i < num_input_fields; i++) {
2064 CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl));
2065 }
2066
2067 // Restore work vector
2068 CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec_in));
2069 CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec_out));
2070 CeedCallBackend(CeedDestroy(&ceed));
2071 CeedCallBackend(CeedQFunctionDestroy(&qf));
2072 return CEED_ERROR_SUCCESS;
2073 }
2074
2075 //------------------------------------------------------------------------------
2076 // Create operator
2077 //------------------------------------------------------------------------------
int CeedOperatorCreate_Hip(CeedOperator op) {
  Ceed              op_ceed;
  CeedOperator_Hip *op_data;

  // Backend data: allocate one CeedOperator_Hip and attach it to the operator
  CeedCallBackend(CeedOperatorGetCeed(op, &op_ceed));
  CeedCallBackend(CeedCalloc(1, &op_data));
  CeedCallBackend(CeedOperatorSetData(op, op_data));

  // Register the HIP implementations for assembly
  CeedCallBackend(CeedSetBackendFunction(op_ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Hip));
  CeedCallBackend(
      CeedSetBackendFunction(op_ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Hip));
  CeedCallBackend(CeedSetBackendFunction(op_ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonal_Hip));
  CeedCallBackend(CeedSetBackendFunction(op_ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal",
                                         CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip));
  CeedCallBackend(CeedSetBackendFunction(op_ceed, "Operator", op, "LinearAssembleSingle", CeedOperatorAssembleSingle_Hip));

  // Register the HIP implementations for application and teardown
  CeedCallBackend(CeedSetBackendFunction(op_ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip));
  CeedCallBackend(CeedSetBackendFunction(op_ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip));

  // Done with the Ceed handle obtained from CeedOperatorGetCeed above
  CeedCallBackend(CeedDestroy(&op_ceed));
  return CEED_ERROR_SUCCESS;
}
2097
2098 //------------------------------------------------------------------------------
2099 // Create operator AtPoints
2100 //------------------------------------------------------------------------------
int CeedOperatorCreateAtPoints_Hip(CeedOperator op) {
  Ceed              op_ceed;
  CeedOperator_Hip *op_data;

  // Backend data: allocate one CeedOperator_Hip and attach it to the operator
  CeedCallBackend(CeedOperatorGetCeed(op, &op_ceed));
  CeedCallBackend(CeedCalloc(1, &op_data));
  CeedCallBackend(CeedOperatorSetData(op, op_data));

  // Register the AtPoints variants for assembly and application
  CeedCallBackend(CeedSetBackendFunction(op_ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunctionAtPoints_Hip));
  CeedCallBackend(
      CeedSetBackendFunction(op_ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip));
  CeedCallBackend(CeedSetBackendFunction(op_ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAddAtPoints_Hip));

  // Teardown uses the same destroy routine that CeedOperatorCreate_Hip registers
  CeedCallBackend(CeedSetBackendFunction(op_ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip));

  // Done with the Ceed handle obtained from CeedOperatorGetCeed above
  CeedCallBackend(CeedDestroy(&op_ceed));
  return CEED_ERROR_SUCCESS;
}
2116
2117 //------------------------------------------------------------------------------
2118