1 // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <assert.h> 12 #include <stdbool.h> 13 #include <string.h> 14 #include <hip/hip_runtime.h> 15 16 #include "../hip/ceed-hip-common.h" 17 #include "../hip/ceed-hip-compile.h" 18 #include "ceed-hip-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Destroy operator 22 //------------------------------------------------------------------------------ 23 static int CeedOperatorDestroy_Hip(CeedOperator op) { 24 CeedOperator_Hip *impl; 25 26 CeedCallBackend(CeedOperatorGetData(op, &impl)); 27 28 // Apply data 29 CeedCallBackend(CeedFree(&impl->num_points)); 30 CeedCallBackend(CeedFree(&impl->skip_rstr_in)); 31 CeedCallBackend(CeedFree(&impl->skip_rstr_out)); 32 CeedCallBackend(CeedFree(&impl->apply_add_basis_out)); 33 CeedCallBackend(CeedFree(&impl->input_field_order)); 34 CeedCallBackend(CeedFree(&impl->output_field_order)); 35 CeedCallBackend(CeedFree(&impl->input_states)); 36 37 for (CeedInt i = 0; i < impl->num_inputs; i++) { 38 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_in[i])); 39 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i])); 40 } 41 CeedCallBackend(CeedFree(&impl->e_vecs_in)); 42 CeedCallBackend(CeedFree(&impl->q_vecs_in)); 43 44 for (CeedInt i = 0; i < impl->num_outputs; i++) { 45 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_out[i])); 46 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 47 } 48 CeedCallBackend(CeedFree(&impl->e_vecs_out)); 49 CeedCallBackend(CeedFree(&impl->q_vecs_out)); 50 CeedCallBackend(CeedVectorDestroy(&impl->point_coords_elem)); 51 52 // QFunction assembly data 53 for (CeedInt i = 0; i < impl->num_active_in; i++) { 54 CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i])); 55 } 56 CeedCallBackend(CeedFree(&impl->qf_active_in)); 57 58 // Diag data 59 if (impl->diag) { 60 Ceed ceed; 61 62 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 63 if (impl->diag->module) { 64 CeedCallHip(ceed, hipModuleUnload(impl->diag->module)); 65 } 66 if (impl->diag->module_point_block) { 67 CeedCallHip(ceed, hipModuleUnload(impl->diag->module_point_block)); 68 } 69 CeedCallHip(ceed, hipFree(impl->diag->d_eval_modes_in)); 70 CeedCallHip(ceed, hipFree(impl->diag->d_eval_modes_out)); 71 CeedCallHip(ceed, hipFree(impl->diag->d_identity)); 72 CeedCallHip(ceed, hipFree(impl->diag->d_interp_in)); 73 CeedCallHip(ceed, hipFree(impl->diag->d_interp_out)); 74 CeedCallHip(ceed, hipFree(impl->diag->d_grad_in)); 75 CeedCallHip(ceed, hipFree(impl->diag->d_grad_out)); 76 CeedCallHip(ceed, hipFree(impl->diag->d_div_in)); 77 CeedCallHip(ceed, hipFree(impl->diag->d_div_out)); 78 CeedCallHip(ceed, hipFree(impl->diag->d_curl_in)); 79 CeedCallHip(ceed, hipFree(impl->diag->d_curl_out)); 80 CeedCallBackend(CeedDestroy(&ceed)); 81 CeedCallBackend(CeedVectorDestroy(&impl->diag->elem_diag)); 82 CeedCallBackend(CeedVectorDestroy(&impl->diag->point_block_elem_diag)); 83 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr)); 84 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr)); 85 } 86 CeedCallBackend(CeedFree(&impl->diag)); 87 88 if (impl->asmb) { 89 Ceed ceed; 90 91 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 92 CeedCallHip(ceed, hipModuleUnload(impl->asmb->module)); 93 CeedCallHip(ceed, hipFree(impl->asmb->d_B_in)); 94 CeedCallHip(ceed, hipFree(impl->asmb->d_B_out)); 95 CeedCallBackend(CeedDestroy(&ceed)); 96 } 97 CeedCallBackend(CeedFree(&impl->asmb)); 98 99 CeedCallBackend(CeedFree(&impl)); 100 return CEED_ERROR_SUCCESS; 101 } 102 103 //------------------------------------------------------------------------------ 104 // Setup infields or outfields 105 //------------------------------------------------------------------------------ 106 static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, bool *apply_add_basis, 107 CeedVector *e_vecs, CeedVector *q_vecs, CeedInt num_fields, CeedInt Q, CeedInt num_elem) { 108 Ceed ceed; 109 CeedQFunctionField *qf_fields; 110 CeedOperatorField *op_fields; 111 112 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 113 if (is_input) { 114 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 115 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 116 } else { 117 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 118 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 119 } 120 121 // Loop over fields 122 for (CeedInt i = 0; i < num_fields; i++) { 123 bool is_active = false, is_strided = false, skip_e_vec = false; 124 CeedSize q_size; 125 CeedInt size; 126 CeedEvalMode eval_mode; 127 CeedVector l_vec; 128 CeedElemRestriction elem_rstr; 129 130 // Check whether this field can skip the element restriction: 131 // Input CEED_VECTOR_ACTIVE 132 // Output CEED_VECTOR_ACTIVE without CEED_EVAL_NONE 133 // Input CEED_VECTOR_NONE with CEED_EVAL_WEIGHT 134 // Input passive vector with CEED_EVAL_NONE and strided restriction with CEED_STRIDES_BACKEND 135 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &l_vec)); 136 is_active = l_vec == CEED_VECTOR_ACTIVE; 137 CeedCallBackend(CeedVectorDestroy(&l_vec)); 138 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); 139 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 140 skip_e_vec = (is_input && is_active) || (is_active && eval_mode != CEED_EVAL_NONE) || (eval_mode == CEED_EVAL_WEIGHT); 141 if (!skip_e_vec && is_input && !is_active && eval_mode == CEED_EVAL_NONE) { 142 CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided)); 143 if (is_strided) CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &skip_e_vec)); 144 } 145 if (skip_e_vec) { 146 e_vecs[i] = NULL; 147 } else { 148 CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs[i])); 149 } 150 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 151 152 switch (eval_mode) { 153 case CEED_EVAL_NONE: 154 case CEED_EVAL_INTERP: 155 case CEED_EVAL_GRAD: 156 case CEED_EVAL_DIV: 157 case CEED_EVAL_CURL: 158 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 159 q_size = (CeedSize)num_elem * (CeedSize)Q * (CeedSize)size; 160 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 161 break; 162 case CEED_EVAL_WEIGHT: { 163 CeedBasis basis; 164 165 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 166 q_size = (CeedSize)num_elem * (CeedSize)Q; 167 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 168 if (is_at_points) { 169 CeedInt num_points[num_elem]; 170 171 for (CeedInt i = 0; i < num_elem; i++) num_points[i] = Q; 172 CeedCallBackend( 173 CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, CEED_VECTOR_NONE, q_vecs[i])); 174 } else { 175 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); 176 } 177 CeedCallBackend(CeedBasisDestroy(&basis)); 178 break; 179 } 180 } 181 } 182 // Drop duplicate restrictions 183 if (is_input) { 184 for (CeedInt i = 0; i < num_fields; i++) { 185 CeedVector vec_i; 186 CeedElemRestriction rstr_i; 187 188 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); 189 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); 190 for (CeedInt j = i + 1; j < num_fields; j++) { 191 CeedVector vec_j; 192 CeedElemRestriction rstr_j; 193 194 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); 195 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); 196 if (vec_i == vec_j && rstr_i == rstr_j) { 197 if (e_vecs[i]) CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); 198 skip_rstr[j] = true; 199 } 200 CeedCallBackend(CeedVectorDestroy(&vec_j)); 201 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j)); 202 } 203 CeedCallBackend(CeedVectorDestroy(&vec_i)); 204 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i)); 205 } 206 } else { 207 for (CeedInt i = num_fields - 1; i >= 0; i--) { 208 CeedVector vec_i; 209 CeedElemRestriction rstr_i; 210 211 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); 212 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); 213 for (CeedInt j = i - 1; j >= 0; j--) { 214 CeedVector vec_j; 215 CeedElemRestriction rstr_j; 216 217 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); 218 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); 219 if (vec_i == vec_j && rstr_i == rstr_j) { 220 if (e_vecs[i]) CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); 221 skip_rstr[j] = true; 222 apply_add_basis[i] = true; 223 } 224 CeedCallBackend(CeedVectorDestroy(&vec_j)); 225 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j)); 226 } 227 CeedCallBackend(CeedVectorDestroy(&vec_i)); 228 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i)); 229 } 230 } 231 CeedCallBackend(CeedDestroy(&ceed)); 232 return CEED_ERROR_SUCCESS; 233 } 234 235 //------------------------------------------------------------------------------ 236 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction. 237 //------------------------------------------------------------------------------ 238 static int CeedOperatorSetup_Hip(CeedOperator op) { 239 bool is_setup_done; 240 CeedInt Q, num_elem, num_input_fields, num_output_fields; 241 CeedQFunctionField *qf_input_fields, *qf_output_fields; 242 CeedQFunction qf; 243 CeedOperatorField *op_input_fields, *op_output_fields; 244 CeedOperator_Hip *impl; 245 246 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 247 if (is_setup_done) return CEED_ERROR_SUCCESS; 248 249 CeedCallBackend(CeedOperatorGetData(op, &impl)); 250 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 251 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 252 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 253 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 254 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 255 256 // Allocate 257 CeedCallBackend(CeedCalloc(num_input_fields, &impl->e_vecs_in)); 258 CeedCallBackend(CeedCalloc(num_output_fields, &impl->e_vecs_out)); 259 CeedCallBackend(CeedCalloc(num_input_fields, &impl->skip_rstr_in)); 260 CeedCallBackend(CeedCalloc(num_output_fields, &impl->skip_rstr_out)); 261 CeedCallBackend(CeedCalloc(num_output_fields, &impl->apply_add_basis_out)); 262 CeedCallBackend(CeedCalloc(num_input_fields, &impl->input_field_order)); 263 CeedCallBackend(CeedCalloc(num_output_fields, &impl->output_field_order)); 264 CeedCallBackend(CeedCalloc(num_input_fields, &impl->input_states)); 265 CeedCallBackend(CeedCalloc(num_input_fields, &impl->q_vecs_in)); 266 CeedCallBackend(CeedCalloc(num_output_fields, &impl->q_vecs_out)); 267 impl->num_inputs = num_input_fields; 268 impl->num_outputs = num_output_fields; 269 270 // Set up infield and outfield e-vecs and q-vecs 271 CeedCallBackend( 272 CeedOperatorSetupFields_Hip(qf, op, true, false, impl->skip_rstr_in, NULL, impl->e_vecs_in, impl->q_vecs_in, num_input_fields, Q, num_elem)); 273 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, false, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs_out, 274 impl->q_vecs_out, num_output_fields, Q, num_elem)); 275 276 // Reorder fields to allow reuse of buffers 277 impl->max_active_e_vec_len = 0; 278 { 279 bool is_ordered[CEED_FIELD_MAX]; 280 CeedInt curr_index = 0; 281 282 for (CeedInt i = 0; i < num_input_fields; i++) is_ordered[i] = false; 283 for (CeedInt i = 0; i < num_input_fields; i++) { 284 CeedSize e_vec_len_i; 285 CeedVector vec_i; 286 CeedElemRestriction rstr_i; 287 288 if (is_ordered[i]) continue; 289 is_ordered[i] = true; 290 impl->input_field_order[curr_index] = i; 291 curr_index++; 292 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i)); 293 if (vec_i == CEED_VECTOR_NONE) { 294 // CEED_EVAL_WEIGHT 295 CeedCallBackend(CeedVectorDestroy(&vec_i)); 296 continue; 297 }; 298 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr_i)); 299 CeedCallBackend(CeedElemRestrictionGetEVectorSize(rstr_i, &e_vec_len_i)); 300 impl->max_active_e_vec_len = e_vec_len_i > impl->max_active_e_vec_len ? e_vec_len_i : impl->max_active_e_vec_len; 301 for (CeedInt j = i + 1; j < num_input_fields; j++) { 302 CeedVector vec_j; 303 CeedElemRestriction rstr_j; 304 305 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[j], &vec_j)); 306 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[j], &rstr_j)); 307 if (rstr_i == rstr_j && vec_i == vec_j) { 308 is_ordered[j] = true; 309 impl->input_field_order[curr_index] = j; 310 curr_index++; 311 } 312 CeedCallBackend(CeedVectorDestroy(&vec_j)); 313 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j)); 314 } 315 CeedCallBackend(CeedVectorDestroy(&vec_i)); 316 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i)); 317 } 318 } 319 { 320 bool is_ordered[CEED_FIELD_MAX]; 321 CeedInt curr_index = 0; 322 323 for (CeedInt i = 0; i < num_output_fields; i++) is_ordered[i] = false; 324 for (CeedInt i = 0; i < num_output_fields; i++) { 325 CeedSize e_vec_len_i; 326 CeedVector vec_i; 327 CeedElemRestriction rstr_i; 328 329 if (is_ordered[i]) continue; 330 is_ordered[i] = true; 331 impl->output_field_order[curr_index] = i; 332 curr_index++; 333 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec_i)); 334 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &rstr_i)); 335 CeedCallBackend(CeedElemRestrictionGetEVectorSize(rstr_i, &e_vec_len_i)); 336 impl->max_active_e_vec_len = e_vec_len_i > impl->max_active_e_vec_len ? e_vec_len_i : impl->max_active_e_vec_len; 337 for (CeedInt j = i + 1; j < num_output_fields; j++) { 338 CeedVector vec_j; 339 CeedElemRestriction rstr_j; 340 341 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec_j)); 342 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &rstr_j)); 343 if (rstr_i == rstr_j && vec_i == vec_j) { 344 is_ordered[j] = true; 345 impl->output_field_order[curr_index] = j; 346 curr_index++; 347 } 348 CeedCallBackend(CeedVectorDestroy(&vec_j)); 349 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j)); 350 } 351 CeedCallBackend(CeedVectorDestroy(&vec_i)); 352 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i)); 353 } 354 } 355 CeedCallBackend(CeedClearWorkVectors(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len)); 356 { 357 // Create two work vectors for diagonal assembly 358 CeedVector temp_1, temp_2; 359 360 CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &temp_1)); 361 CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &temp_2)); 362 CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &temp_1)); 363 CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &temp_2)); 364 } 365 CeedCallBackend(CeedOperatorSetSetupDone(op)); 366 CeedCallBackend(CeedQFunctionDestroy(&qf)); 367 return CEED_ERROR_SUCCESS; 368 } 369 370 //------------------------------------------------------------------------------ 371 // Restrict Operator Inputs 372 //------------------------------------------------------------------------------ 373 static inline int CeedOperatorInputRestrict_Hip(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field, 374 CeedVector in_vec, CeedVector active_e_vec, const bool skip_active, CeedOperator_Hip *impl, 375 CeedRequest *request) { 376 bool is_active = false; 377 CeedVector l_vec, e_vec = impl->e_vecs_in[input_field]; 378 379 // Get input vector 380 CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); 381 is_active = l_vec == CEED_VECTOR_ACTIVE; 382 if (is_active && skip_active) return CEED_ERROR_SUCCESS; 383 if (is_active) { 384 l_vec = in_vec; 385 if (!e_vec) e_vec = active_e_vec; 386 } 387 388 // Restriction action 389 if (e_vec) { 390 // Restrict, if necessary 391 if (!impl->skip_rstr_in[input_field]) { 392 uint64_t state; 393 394 CeedCallBackend(CeedVectorGetState(l_vec, &state)); 395 if (is_active || state != impl->input_states[input_field]) { 396 CeedElemRestriction elem_rstr; 397 398 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_field, &elem_rstr)); 399 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, l_vec, e_vec, request)); 400 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 401 } 402 impl->input_states[input_field] = state; 403 } 404 } 405 if (!is_active) CeedCallBackend(CeedVectorDestroy(&l_vec)); 406 return CEED_ERROR_SUCCESS; 407 } 408 409 //------------------------------------------------------------------------------ 410 // Input Basis Action 411 //------------------------------------------------------------------------------ 412 static inline int CeedOperatorInputBasis_Hip(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field, 413 CeedVector in_vec, CeedVector active_e_vec, CeedInt num_elem, const bool skip_active, 414 CeedOperator_Hip *impl) { 415 bool is_active = false; 416 CeedEvalMode eval_mode; 417 CeedVector l_vec, e_vec = impl->e_vecs_in[input_field], q_vec = impl->q_vecs_in[input_field]; 418 419 // Skip active input 420 CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); 421 is_active = l_vec == CEED_VECTOR_ACTIVE; 422 if (is_active && skip_active) return CEED_ERROR_SUCCESS; 423 if (is_active) { 424 l_vec = in_vec; 425 if (!e_vec) e_vec = active_e_vec; 426 } 427 428 // Basis action 429 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_field, &eval_mode)); 430 switch (eval_mode) { 431 case CEED_EVAL_NONE: { 432 const CeedScalar *e_vec_array; 433 434 if (e_vec) { 435 CeedCallBackend(CeedVectorGetArrayRead(e_vec, CEED_MEM_DEVICE, &e_vec_array)); 436 } else { 437 CeedCallBackend(CeedVectorGetArrayRead(l_vec, CEED_MEM_DEVICE, &e_vec_array)); 438 } 439 CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array)); 440 break; 441 } 442 case CEED_EVAL_INTERP: 443 case CEED_EVAL_GRAD: 444 case CEED_EVAL_DIV: 445 case CEED_EVAL_CURL: { 446 CeedBasis basis; 447 448 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_field, &basis)); 449 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, eval_mode, e_vec, q_vec)); 450 CeedCallBackend(CeedBasisDestroy(&basis)); 451 break; 452 } 453 case CEED_EVAL_WEIGHT: 454 break; // No action 455 } 456 if (!is_active) CeedCallBackend(CeedVectorDestroy(&l_vec)); 457 return CEED_ERROR_SUCCESS; 458 } 459 460 //------------------------------------------------------------------------------ 461 // Restore Input Vectors 462 //------------------------------------------------------------------------------ 463 static inline int CeedOperatorInputRestore_Hip(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field, 464 CeedVector in_vec, CeedVector active_e_vec, const bool skip_active, CeedOperator_Hip *impl) { 465 bool is_active = false; 466 CeedEvalMode eval_mode; 467 CeedVector l_vec, e_vec = impl->e_vecs_in[input_field]; 468 469 // Skip active input 470 CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); 471 is_active = l_vec == CEED_VECTOR_ACTIVE; 472 if (is_active && skip_active) return CEED_ERROR_SUCCESS; 473 if (is_active) { 474 l_vec = in_vec; 475 if (!e_vec) e_vec = active_e_vec; 476 } 477 478 // Restore e-vec 479 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_field, &eval_mode)); 480 if (eval_mode == CEED_EVAL_NONE) { 481 const CeedScalar *e_vec_array; 482 483 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_in[input_field], CEED_MEM_DEVICE, (CeedScalar **)&e_vec_array)); 484 if (e_vec) { 485 CeedCallBackend(CeedVectorRestoreArrayRead(e_vec, &e_vec_array)); 486 } else { 487 CeedCallBackend(CeedVectorRestoreArrayRead(l_vec, &e_vec_array)); 488 } 489 } 490 if (!is_active) CeedCallBackend(CeedVectorDestroy(&l_vec)); 491 return CEED_ERROR_SUCCESS; 492 } 493 494 //------------------------------------------------------------------------------ 495 // Apply and add to output 496 //------------------------------------------------------------------------------ 497 static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 498 CeedInt Q, num_elem, num_input_fields, num_output_fields; 499 Ceed ceed; 500 CeedVector active_e_vec; 501 CeedQFunctionField *qf_input_fields, *qf_output_fields; 502 CeedQFunction qf; 503 CeedOperatorField *op_input_fields, *op_output_fields; 504 CeedOperator_Hip *impl; 505 506 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 507 CeedCallBackend(CeedOperatorGetData(op, &impl)); 508 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 509 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 510 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 511 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 512 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 513 514 // Setup 515 CeedCallBackend(CeedOperatorSetup_Hip(op)); 516 517 // Work vector 518 CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec)); 519 520 // Process inputs 521 for (CeedInt i = 0; i < num_input_fields; i++) { 522 CeedInt field = impl->input_field_order[i]; 523 524 CeedCallBackend(CeedOperatorInputRestrict_Hip(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, false, impl, request)); 525 CeedCallBackend(CeedOperatorInputBasis_Hip(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, num_elem, false, impl)); 526 } 527 528 // Output pointers, as necessary 529 for (CeedInt i = 0; i < num_output_fields; i++) { 530 CeedEvalMode eval_mode; 531 532 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 533 if (eval_mode == CEED_EVAL_NONE) { 534 CeedScalar *e_vec_array; 535 536 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); 537 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array)); 538 } 539 } 540 541 // Q function 542 CeedCallBackend(CeedQFunctionApply(qf, num_elem * Q, impl->q_vecs_in, impl->q_vecs_out)); 543 544 // Restore input arrays 545 for (CeedInt i = 0; i < num_input_fields; i++) { 546 CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, in_vec, active_e_vec, false, impl)); 547 } 548 549 // Output basis and restriction 550 for (CeedInt i = 0; i < num_output_fields; i++) { 551 bool is_active = false; 552 CeedInt field = impl->output_field_order[i]; 553 CeedEvalMode eval_mode; 554 CeedVector l_vec, e_vec = impl->e_vecs_out[field], q_vec = impl->q_vecs_out[field]; 555 556 // Output vector 557 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[field], &l_vec)); 558 is_active = l_vec == CEED_VECTOR_ACTIVE; 559 if (is_active) { 560 l_vec = out_vec; 561 if (!e_vec) e_vec = active_e_vec; 562 } 563 564 // Basis action 565 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[field], &eval_mode)); 566 switch (eval_mode) { 567 case CEED_EVAL_NONE: 568 break; // No action 569 case CEED_EVAL_INTERP: 570 case CEED_EVAL_GRAD: 571 case CEED_EVAL_DIV: 572 case CEED_EVAL_CURL: { 573 CeedBasis basis; 574 575 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[field], &basis)); 576 if (impl->apply_add_basis_out[field]) { 577 CeedCallBackend(CeedBasisApplyAdd(basis, num_elem, CEED_TRANSPOSE, eval_mode, q_vec, e_vec)); 578 } else { 579 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, q_vec, e_vec)); 580 } 581 CeedCallBackend(CeedBasisDestroy(&basis)); 582 break; 583 } 584 // LCOV_EXCL_START 585 case CEED_EVAL_WEIGHT: { 586 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 587 // LCOV_EXCL_STOP 588 } 589 } 590 591 // Restore evec 592 if (eval_mode == CEED_EVAL_NONE) { 593 CeedScalar *e_vec_array; 594 595 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); 596 CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_vec_array)); 597 } 598 599 // Restrict 600 if (!impl->skip_rstr_out[field]) { 601 CeedElemRestriction elem_rstr; 602 603 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[field], &elem_rstr)); 604 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, e_vec, l_vec, request)); 605 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 606 } 607 if (!is_active) CeedCallBackend(CeedVectorDestroy(&l_vec)); 608 } 609 610 // Return work vector 611 CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec)); 612 CeedCallBackend(CeedDestroy(&ceed)); 613 CeedCallBackend(CeedQFunctionDestroy(&qf)); 614 return CEED_ERROR_SUCCESS; 615 } 616 617 //------------------------------------------------------------------------------ 618 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction. 619 //------------------------------------------------------------------------------ 620 static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) { 621 bool is_setup_done; 622 CeedInt max_num_points = -1, num_elem, num_input_fields, num_output_fields; 623 CeedQFunctionField *qf_input_fields, *qf_output_fields; 624 CeedQFunction qf; 625 CeedOperatorField *op_input_fields, *op_output_fields; 626 CeedOperator_Hip *impl; 627 628 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 629 if (is_setup_done) return CEED_ERROR_SUCCESS; 630 631 CeedCallBackend(CeedOperatorGetData(op, &impl)); 632 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 633 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 634 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 635 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 636 { 637 CeedElemRestriction rstr_points = NULL; 638 639 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 640 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(rstr_points, &max_num_points)); 641 CeedCallBackend(CeedCalloc(num_elem, &impl->num_points)); 642 for (CeedInt e = 0; e < num_elem; e++) { 643 CeedInt num_points_elem; 644 645 CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem)); 646 impl->num_points[e] = num_points_elem; 647 } 648 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 649 } 650 impl->max_num_points = max_num_points; 651 652 // Allocate 653 CeedCallBackend(CeedCalloc(num_input_fields, &impl->e_vecs_in)); 654 CeedCallBackend(CeedCalloc(num_output_fields, &impl->e_vecs_out)); 655 CeedCallBackend(CeedCalloc(num_input_fields, &impl->skip_rstr_in)); 656 CeedCallBackend(CeedCalloc(num_output_fields, &impl->skip_rstr_out)); 657 CeedCallBackend(CeedCalloc(num_output_fields, &impl->apply_add_basis_out)); 658 CeedCallBackend(CeedCalloc(num_input_fields, &impl->input_field_order)); 659 CeedCallBackend(CeedCalloc(num_output_fields, &impl->output_field_order)); 660 CeedCallBackend(CeedCalloc(num_input_fields, &impl->input_states)); 661 CeedCallBackend(CeedCalloc(num_input_fields, &impl->q_vecs_in)); 662 CeedCallBackend(CeedCalloc(num_output_fields, &impl->q_vecs_out)); 663 impl->num_inputs = num_input_fields; 664 impl->num_outputs = num_output_fields; 665 666 // Set up infield and outfield e-vecs and q-vecs 667 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, true, impl->skip_rstr_in, NULL, impl->e_vecs_in, impl->q_vecs_in, num_input_fields, 668 max_num_points, num_elem)); 669 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs_out, impl->q_vecs_out, 670 num_output_fields, max_num_points, num_elem)); 671 672 // Reorder fields to allow reuse of buffers 673 impl->max_active_e_vec_len = 0; 674 { 675 bool is_ordered[CEED_FIELD_MAX]; 676 CeedInt curr_index = 0; 677 678 for (CeedInt i = 0; i < num_input_fields; i++) is_ordered[i] = false; 679 for (CeedInt i = 0; i < num_input_fields; i++) { 680 CeedSize e_vec_len_i; 681 CeedVector vec_i; 682 CeedElemRestriction rstr_i; 683 684 if (is_ordered[i]) continue; 685 is_ordered[i] = true; 686 impl->input_field_order[curr_index] = i; 687 curr_index++; 688 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i)); 689 if (vec_i == CEED_VECTOR_NONE) { 690 // CEED_EVAL_WEIGHT 691 CeedCallBackend(CeedVectorDestroy(&vec_i)); 692 continue; 693 }; 694 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr_i)); 695 CeedCallBackend(CeedElemRestrictionGetEVectorSize(rstr_i, &e_vec_len_i)); 696 impl->max_active_e_vec_len = e_vec_len_i > impl->max_active_e_vec_len ? e_vec_len_i : impl->max_active_e_vec_len; 697 for (CeedInt j = i + 1; j < num_input_fields; j++) { 698 CeedVector vec_j; 699 CeedElemRestriction rstr_j; 700 701 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[j], &vec_j)); 702 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[j], &rstr_j)); 703 if (rstr_i == rstr_j && vec_i == vec_j) { 704 is_ordered[j] = true; 705 impl->input_field_order[curr_index] = j; 706 curr_index++; 707 } 708 CeedCallBackend(CeedVectorDestroy(&vec_j)); 709 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j)); 710 } 711 CeedCallBackend(CeedVectorDestroy(&vec_i)); 712 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i)); 713 } 714 } 715 { 716 bool is_ordered[CEED_FIELD_MAX]; 717 CeedInt curr_index = 0; 718 719 for (CeedInt i = 0; i < num_output_fields; i++) is_ordered[i] = false; 720 for (CeedInt i = 0; i < num_output_fields; i++) { 721 CeedSize e_vec_len_i; 722 CeedVector vec_i; 723 CeedElemRestriction rstr_i; 724 725 if (is_ordered[i]) continue; 726 is_ordered[i] = true; 727 impl->output_field_order[curr_index] = i; 728 curr_index++; 729 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec_i)); 730 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &rstr_i)); 731 CeedCallBackend(CeedElemRestrictionGetEVectorSize(rstr_i, &e_vec_len_i)); 732 impl->max_active_e_vec_len = e_vec_len_i > impl->max_active_e_vec_len ? e_vec_len_i : impl->max_active_e_vec_len; 733 for (CeedInt j = i + 1; j < num_output_fields; j++) { 734 CeedVector vec_j; 735 CeedElemRestriction rstr_j; 736 737 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec_j)); 738 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &rstr_j)); 739 if (rstr_i == rstr_j && vec_i == vec_j) { 740 is_ordered[j] = true; 741 impl->output_field_order[curr_index] = j; 742 curr_index++; 743 } 744 CeedCallBackend(CeedVectorDestroy(&vec_j)); 745 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j)); 746 } 747 CeedCallBackend(CeedVectorDestroy(&vec_i)); 748 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i)); 749 } 750 } 751 CeedCallBackend(CeedClearWorkVectors(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len)); 752 { 753 // Create two work vectors for diagonal assembly 754 CeedVector temp_1, temp_2; 755 756 CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &temp_1)); 757 CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &temp_2)); 758 CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &temp_1)); 759 CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &temp_2)); 760 } 761 CeedCallBackend(CeedOperatorSetSetupDone(op)); 762 CeedCallBackend(CeedQFunctionDestroy(&qf)); 763 return CEED_ERROR_SUCCESS; 764 } 765 766 //------------------------------------------------------------------------------ 767 // Input Basis Action AtPoints 768 //------------------------------------------------------------------------------ 769 static inline int CeedOperatorInputBasisAtPoints_Hip(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field, 770 CeedVector in_vec, CeedVector active_e_vec, CeedInt num_elem, const CeedInt *num_points, 771 const bool skip_active, const bool skip_passive, CeedOperator_Hip *impl) { 772 bool is_active = false; 773 CeedEvalMode eval_mode; 774 CeedVector l_vec, e_vec = impl->e_vecs_in[input_field], q_vec = impl->q_vecs_in[input_field]; 775 776 // Skip active input 777 CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); 778 is_active = l_vec == CEED_VECTOR_ACTIVE; 779 if (skip_active && is_active) return CEED_ERROR_SUCCESS; 780 if (skip_passive && !is_active) { 781 CeedCallBackend(CeedVectorDestroy(&l_vec)); 782 return CEED_ERROR_SUCCESS; 783 } 784 if (is_active) { 785 l_vec = in_vec; 786 if (!e_vec) e_vec = active_e_vec; 787 } 788 789 // Basis action 790 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_field, &eval_mode)); 791 switch (eval_mode) { 792 case CEED_EVAL_NONE: { 793 const CeedScalar *e_vec_array; 794 795 if (e_vec) { 796 CeedCallBackend(CeedVectorGetArrayRead(e_vec, CEED_MEM_DEVICE, &e_vec_array)); 797 } else { 798 CeedCallBackend(CeedVectorGetArrayRead(l_vec, CEED_MEM_DEVICE, &e_vec_array)); 799 } 800 CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array)); 801 break; 802 } 803 case CEED_EVAL_INTERP: 804 case CEED_EVAL_GRAD: 805 case CEED_EVAL_DIV: 806 case CEED_EVAL_CURL: { 807 CeedBasis basis; 808 809 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_field, &basis)); 810 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, e_vec, q_vec)); 811 CeedCallBackend(CeedBasisDestroy(&basis)); 812 break; 813 } 814 case CEED_EVAL_WEIGHT: 815 break; // No action 816 } 817 if (!is_active) CeedCallBackend(CeedVectorDestroy(&l_vec)); 818 return CEED_ERROR_SUCCESS; 819 } 820 821 //------------------------------------------------------------------------------ 822 // Apply and add to output AtPoints 823 //------------------------------------------------------------------------------ 824 static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 825 CeedInt max_num_points, *num_points, num_elem, num_input_fields, num_output_fields; 826 Ceed ceed; 827 CeedVector active_e_vec; 828 CeedQFunctionField *qf_input_fields, *qf_output_fields; 829 CeedQFunction qf; 830 CeedOperatorField *op_input_fields, *op_output_fields; 831 CeedOperator_Hip *impl; 832 833 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 834 CeedCallBackend(CeedOperatorGetData(op, &impl)); 835 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 836 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 837 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 838 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 839 840 // Setup 841 CeedCallBackend(CeedOperatorSetupAtPoints_Hip(op)); 842 num_points = impl->num_points; 843 max_num_points = impl->max_num_points; 844 845 // Work vector 846 CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec)); 847 848 // Get point coordinates 849 if (!impl->point_coords_elem) { 850 CeedVector point_coords = NULL; 851 CeedElemRestriction rstr_points = NULL; 852 853 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 854 CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem)); 855 CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 856 CeedCallBackend(CeedVectorDestroy(&point_coords)); 857 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 858 } 859 860 // Process inputs 861 for (CeedInt i = 0; i < num_input_fields; i++) { 862 CeedInt field = impl->input_field_order[i]; 863 864 CeedCallBackend(CeedOperatorInputRestrict_Hip(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, false, impl, request)); 865 CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, num_elem, 866 num_points, false, false, impl)); 867 } 868 869 // Output pointers, as necessary 870 for (CeedInt i = 0; i < num_output_fields; i++) { 871 CeedEvalMode eval_mode; 872 873 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 874 if (eval_mode == CEED_EVAL_NONE) { 875 CeedScalar *e_vec_array; 876 877 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); 878 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array)); 879 } 880 } 881 882 // Q function 883 CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out)); 884 885 // Restore input arrays 886 for (CeedInt i = 0; i < num_input_fields; i++) { 887 CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, in_vec, active_e_vec, false, impl)); 888 } 889 890 // Output basis and restriction 891 for (CeedInt i = 0; i < num_output_fields; i++) { 892 bool is_active = false; 893 CeedInt field = impl->output_field_order[i]; 894 CeedEvalMode eval_mode; 895 CeedVector l_vec, e_vec = impl->e_vecs_out[field], q_vec = impl->q_vecs_out[field]; 896 897 // Output vector 898 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[field], &l_vec)); 899 is_active = l_vec == CEED_VECTOR_ACTIVE; 900 if (is_active) { 901 l_vec = out_vec; 902 if (!e_vec) e_vec = active_e_vec; 903 } 904 905 // Basis action 906 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[field], &eval_mode)); 907 switch (eval_mode) { 908 case CEED_EVAL_NONE: 909 break; // No action 910 case CEED_EVAL_INTERP: 911 case CEED_EVAL_GRAD: 912 case CEED_EVAL_DIV: 913 case CEED_EVAL_CURL: { 914 CeedBasis basis; 915 916 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[field], &basis)); 917 if (impl->apply_add_basis_out[field]) { 918 CeedCallBackend(CeedBasisApplyAddAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec)); 919 } else { 920 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec)); 921 } 922 CeedCallBackend(CeedBasisDestroy(&basis)); 923 break; 924 } 925 // LCOV_EXCL_START 926 case CEED_EVAL_WEIGHT: { 927 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 928 // LCOV_EXCL_STOP 929 } 930 } 931 932 // Restore evec 933 if (eval_mode == CEED_EVAL_NONE) { 934 CeedScalar *e_vec_array; 935 936 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); 937 CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_vec_array)); 938 } 939 940 // Restrict 941 if (!impl->skip_rstr_out[field]) { 942 CeedElemRestriction elem_rstr; 943 944 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[field], &elem_rstr)); 945 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, e_vec, l_vec, request)); 946 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 947 } 948 if (!is_active) CeedCallBackend(CeedVectorDestroy(&l_vec)); 949 } 950 951 // Restore work vector 952 CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec)); 953 CeedCallBackend(CeedDestroy(&ceed)); 954 CeedCallBackend(CeedQFunctionDestroy(&qf)); 955 return CEED_ERROR_SUCCESS; 956 } 957 958 //------------------------------------------------------------------------------ 959 // Linear QFunction Assembly Core 960 //------------------------------------------------------------------------------ 961 static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 962 CeedRequest *request) { 963 Ceed ceed, ceed_parent; 964 CeedInt num_active_in, num_active_out, Q, num_elem, num_input_fields, num_output_fields, size; 965 CeedScalar *assembled_array; 966 CeedVector *active_inputs; 967 CeedQFunctionField *qf_input_fields, *qf_output_fields; 968 CeedQFunction qf; 969 CeedOperatorField *op_input_fields, *op_output_fields; 970 CeedOperator_Hip *impl; 971 972 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 973 CeedCallBackend(CeedOperatorGetFallbackParentCeed(op, &ceed_parent)); 974 CeedCallBackend(CeedOperatorGetData(op, &impl)); 975 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 976 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 977 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 978 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 979 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 980 active_inputs = impl->qf_active_in; 981 num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 982 983 // Setup 984 CeedCallBackend(CeedOperatorSetup_Hip(op)); 985 986 // Process inputs 987 for (CeedInt i = 0; i < num_input_fields; i++) { 988 CeedCallBackend(CeedOperatorInputRestrict_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl, request)); 989 CeedCallBackend(CeedOperatorInputBasis_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, num_elem, true, impl)); 990 } 991 992 // Count number of active input fields 993 if (!num_active_in) { 994 for (CeedInt i = 0; i < num_input_fields; i++) { 995 CeedScalar *q_vec_array; 996 CeedVector l_vec; 997 998 // Check if active input 999 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &l_vec)); 1000 if (l_vec == CEED_VECTOR_ACTIVE) { 1001 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 1002 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 1003 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, &q_vec_array)); 1004 CeedCallBackend(CeedRealloc(num_active_in + size, &active_inputs)); 1005 for (CeedInt field = 0; field < size; field++) { 1006 CeedSize q_size = (CeedSize)Q * num_elem; 1007 1008 CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_inputs[num_active_in + field])); 1009 CeedCallBackend( 1010 CeedVectorSetArray(active_inputs[num_active_in + field], CEED_MEM_DEVICE, CEED_USE_POINTER, &q_vec_array[field * Q * num_elem])); 1011 } 1012 num_active_in += size; 1013 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array)); 1014 } 1015 CeedCallBackend(CeedVectorDestroy(&l_vec)); 1016 } 1017 impl->num_active_in = num_active_in; 1018 impl->qf_active_in = active_inputs; 1019 } 1020 1021 // Count number of active output fields 1022 if (!num_active_out) { 1023 for (CeedInt i = 0; i < num_output_fields; i++) { 1024 CeedVector l_vec; 1025 1026 // Check if active output 1027 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &l_vec)); 1028 if (l_vec == CEED_VECTOR_ACTIVE) { 1029 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 1030 num_active_out += size; 1031 } 1032 CeedCallBackend(CeedVectorDestroy(&l_vec)); 1033 } 1034 impl->num_active_out = num_active_out; 1035 } 1036 1037 // Check sizes 1038 CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 1039 1040 // Build objects if needed 1041 if (build_objects) { 1042 CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out; 1043 CeedInt strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */ 1044 1045 // Create output restriction 1046 CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out, 1047 (CeedSize)num_active_in * (CeedSize)num_active_out * (CeedSize)num_elem * (CeedSize)Q, strides, 1048 rstr)); 1049 // Create assembled vector 1050 CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled)); 1051 } 1052 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 1053 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_DEVICE, &assembled_array)); 1054 1055 // Assemble QFunction 1056 for (CeedInt in = 0; in < num_active_in; in++) { 1057 // Set Inputs 1058 CeedCallBackend(CeedVectorSetValue(active_inputs[in], 1.0)); 1059 if (num_active_in > 1) { 1060 CeedCallBackend(CeedVectorSetValue(active_inputs[(in + num_active_in - 1) % num_active_in], 0.0)); 1061 } 1062 // Set Outputs 1063 for (CeedInt out = 0; out < num_output_fields; out++) { 1064 CeedVector l_vec; 1065 1066 // Check if active output 1067 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &l_vec)); 1068 if (l_vec == CEED_VECTOR_ACTIVE) { 1069 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, CEED_USE_POINTER, assembled_array)); 1070 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size)); 1071 assembled_array += size * Q * num_elem; // Advance the pointer by the size of the output 1072 } 1073 CeedCallBackend(CeedVectorDestroy(&l_vec)); 1074 } 1075 // Apply QFunction 1076 CeedCallBackend(CeedQFunctionApply(qf, Q * num_elem, impl->q_vecs_in, impl->q_vecs_out)); 1077 } 1078 1079 // Un-set output q-vecs to prevent accidental overwrite of Assembled 1080 for (CeedInt out = 0; out < num_output_fields; out++) { 1081 CeedVector l_vec; 1082 1083 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &l_vec)); 1084 if (l_vec == CEED_VECTOR_ACTIVE) { 1085 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, NULL)); 1086 } 1087 CeedCallBackend(CeedVectorDestroy(&l_vec)); 1088 } 1089 1090 // Restore input arrays 1091 for (CeedInt i = 0; i < num_input_fields; i++) { 1092 CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl)); 1093 } 1094 1095 // Restore output 1096 CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 1097 CeedCallBackend(CeedDestroy(&ceed)); 1098 CeedCallBackend(CeedDestroy(&ceed_parent)); 1099 CeedCallBackend(CeedQFunctionDestroy(&qf)); 1100 return CEED_ERROR_SUCCESS; 1101 } 1102 1103 //------------------------------------------------------------------------------ 1104 // Assemble Linear QFunction 1105 //------------------------------------------------------------------------------ 1106 static int CeedOperatorLinearAssembleQFunction_Hip(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 1107 return CeedOperatorLinearAssembleQFunctionCore_Hip(op, true, assembled, rstr, request); 1108 } 1109 1110 //------------------------------------------------------------------------------ 1111 // Update Assembled Linear QFunction 1112 //------------------------------------------------------------------------------ 1113 static int CeedOperatorLinearAssembleQFunctionUpdate_Hip(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 1114 return CeedOperatorLinearAssembleQFunctionCore_Hip(op, false, &assembled, &rstr, request); 1115 } 1116 1117 //------------------------------------------------------------------------------ 1118 // Assemble Diagonal Setup 1119 //------------------------------------------------------------------------------ 1120 static inline int CeedOperatorAssembleDiagonalSetup_Hip(CeedOperator op) { 1121 Ceed ceed; 1122 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 1123 CeedInt q_comp, num_nodes, num_qpts; 1124 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL; 1125 CeedBasis basis_in = NULL, basis_out = NULL; 1126 CeedQFunctionField *qf_fields; 1127 CeedQFunction qf; 1128 CeedOperatorField *op_fields; 1129 CeedOperator_Hip *impl; 1130 1131 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1132 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1133 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 1134 1135 // Determine active input basis 1136 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 1137 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 1138 for (CeedInt i = 0; i < num_input_fields; i++) { 1139 CeedVector vec; 1140 1141 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 1142 if (vec == CEED_VECTOR_ACTIVE) { 1143 CeedEvalMode eval_mode; 1144 CeedBasis basis; 1145 1146 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 1147 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, 1148 "Backend does not implement operator diagonal assembly with multiple active bases"); 1149 if (!basis_in) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_in)); 1150 CeedCallBackend(CeedBasisDestroy(&basis)); 1151 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1152 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 1153 if (eval_mode != CEED_EVAL_WEIGHT) { 1154 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly 1155 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in)); 1156 for (CeedInt d = 0; d < q_comp; d++) eval_modes_in[num_eval_modes_in + d] = eval_mode; 1157 num_eval_modes_in += q_comp; 1158 } 1159 } 1160 CeedCallBackend(CeedVectorDestroy(&vec)); 1161 } 1162 1163 // Determine active output basis 1164 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 1165 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 1166 for (CeedInt i = 0; i < num_output_fields; i++) { 1167 CeedVector vec; 1168 1169 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 1170 if (vec == CEED_VECTOR_ACTIVE) { 1171 CeedBasis basis; 1172 CeedEvalMode eval_mode; 1173 1174 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 1175 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND, 1176 "Backend does not implement operator diagonal assembly with multiple active bases"); 1177 if (!basis_out) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_out)); 1178 CeedCallBackend(CeedBasisDestroy(&basis)); 1179 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1180 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 1181 if (eval_mode != CEED_EVAL_WEIGHT) { 1182 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly 1183 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out)); 1184 for (CeedInt d = 0; d < q_comp; d++) eval_modes_out[num_eval_modes_out + d] = eval_mode; 1185 num_eval_modes_out += q_comp; 1186 } 1187 } 1188 CeedCallBackend(CeedVectorDestroy(&vec)); 1189 } 1190 1191 // Operator data struct 1192 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1193 CeedCallBackend(CeedCalloc(1, &impl->diag)); 1194 CeedOperatorDiag_Hip *diag = impl->diag; 1195 1196 // Basis matrices 1197 CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes)); 1198 if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes; 1199 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts)); 1200 const CeedInt interp_bytes = num_nodes * num_qpts * sizeof(CeedScalar); 1201 const CeedInt eval_modes_bytes = sizeof(CeedEvalMode); 1202 bool has_eval_none = false; 1203 1204 // CEED_EVAL_NONE 1205 for (CeedInt i = 0; i < num_eval_modes_in; i++) has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE); 1206 for (CeedInt i = 0; i < num_eval_modes_out; i++) has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE); 1207 if (has_eval_none) { 1208 CeedScalar *identity = NULL; 1209 1210 CeedCallBackend(CeedCalloc(num_nodes * num_qpts, &identity)); 1211 for (CeedInt i = 0; i < (num_nodes < num_qpts ? num_nodes : num_qpts); i++) identity[i * num_nodes + i] = 1.0; 1212 CeedCallHip(ceed, hipMalloc((void **)&diag->d_identity, interp_bytes)); 1213 CeedCallHip(ceed, hipMemcpy(diag->d_identity, identity, interp_bytes, hipMemcpyHostToDevice)); 1214 CeedCallBackend(CeedFree(&identity)); 1215 } 1216 1217 // CEED_EVAL_INTERP, CEED_EVAL_GRAD, CEED_EVAL_DIV, and CEED_EVAL_CURL 1218 for (CeedInt in = 0; in < 2; in++) { 1219 CeedFESpace fespace; 1220 CeedBasis basis = in ? basis_in : basis_out; 1221 1222 CeedCallBackend(CeedBasisGetFESpace(basis, &fespace)); 1223 switch (fespace) { 1224 case CEED_FE_SPACE_H1: { 1225 CeedInt q_comp_interp, q_comp_grad; 1226 const CeedScalar *interp, *grad; 1227 CeedScalar *d_interp, *d_grad; 1228 1229 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 1230 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_GRAD, &q_comp_grad)); 1231 1232 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 1233 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 1234 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice)); 1235 CeedCallBackend(CeedBasisGetGrad(basis, &grad)); 1236 CeedCallHip(ceed, hipMalloc((void **)&d_grad, interp_bytes * q_comp_grad)); 1237 CeedCallHip(ceed, hipMemcpy(d_grad, grad, interp_bytes * q_comp_grad, hipMemcpyHostToDevice)); 1238 if (in) { 1239 diag->d_interp_in = d_interp; 1240 diag->d_grad_in = d_grad; 1241 } else { 1242 diag->d_interp_out = d_interp; 1243 diag->d_grad_out = d_grad; 1244 } 1245 } break; 1246 case CEED_FE_SPACE_HDIV: { 1247 CeedInt q_comp_interp, q_comp_div; 1248 const CeedScalar *interp, *div; 1249 CeedScalar *d_interp, *d_div; 1250 1251 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 1252 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_DIV, &q_comp_div)); 1253 1254 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 1255 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 1256 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice)); 1257 CeedCallBackend(CeedBasisGetDiv(basis, &div)); 1258 CeedCallHip(ceed, hipMalloc((void **)&d_div, interp_bytes * q_comp_div)); 1259 CeedCallHip(ceed, hipMemcpy(d_div, div, interp_bytes * q_comp_div, hipMemcpyHostToDevice)); 1260 if (in) { 1261 diag->d_interp_in = d_interp; 1262 diag->d_div_in = d_div; 1263 } else { 1264 diag->d_interp_out = d_interp; 1265 diag->d_div_out = d_div; 1266 } 1267 } break; 1268 case CEED_FE_SPACE_HCURL: { 1269 CeedInt q_comp_interp, q_comp_curl; 1270 const CeedScalar *interp, *curl; 1271 CeedScalar *d_interp, *d_curl; 1272 1273 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 1274 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_CURL, &q_comp_curl)); 1275 1276 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 1277 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 1278 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice)); 1279 CeedCallBackend(CeedBasisGetCurl(basis, &curl)); 1280 CeedCallHip(ceed, hipMalloc((void **)&d_curl, interp_bytes * q_comp_curl)); 1281 CeedCallHip(ceed, hipMemcpy(d_curl, curl, interp_bytes * q_comp_curl, hipMemcpyHostToDevice)); 1282 if (in) { 1283 diag->d_interp_in = d_interp; 1284 diag->d_curl_in = d_curl; 1285 } else { 1286 diag->d_interp_out = d_interp; 1287 diag->d_curl_out = d_curl; 1288 } 1289 } break; 1290 } 1291 } 1292 1293 // Arrays of eval_modes 1294 CeedCallHip(ceed, hipMalloc((void **)&diag->d_eval_modes_in, num_eval_modes_in * eval_modes_bytes)); 1295 CeedCallHip(ceed, hipMemcpy(diag->d_eval_modes_in, eval_modes_in, num_eval_modes_in * eval_modes_bytes, hipMemcpyHostToDevice)); 1296 CeedCallHip(ceed, hipMalloc((void **)&diag->d_eval_modes_out, num_eval_modes_out * eval_modes_bytes)); 1297 CeedCallHip(ceed, hipMemcpy(diag->d_eval_modes_out, eval_modes_out, num_eval_modes_out * eval_modes_bytes, hipMemcpyHostToDevice)); 1298 CeedCallBackend(CeedFree(&eval_modes_in)); 1299 CeedCallBackend(CeedFree(&eval_modes_out)); 1300 CeedCallBackend(CeedDestroy(&ceed)); 1301 CeedCallBackend(CeedBasisDestroy(&basis_in)); 1302 CeedCallBackend(CeedBasisDestroy(&basis_out)); 1303 CeedCallBackend(CeedQFunctionDestroy(&qf)); 1304 return CEED_ERROR_SUCCESS; 1305 } 1306 1307 //------------------------------------------------------------------------------ 1308 // Assemble Diagonal Setup (Compilation) 1309 //------------------------------------------------------------------------------ 1310 static inline int CeedOperatorAssembleDiagonalSetupCompile_Hip(CeedOperator op, CeedInt use_ceedsize_idx, const bool is_point_block) { 1311 Ceed ceed; 1312 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 1313 CeedInt num_comp, q_comp, num_nodes, num_qpts; 1314 CeedBasis basis_in = NULL, basis_out = NULL; 1315 CeedQFunctionField *qf_fields; 1316 CeedQFunction qf; 1317 CeedOperatorField *op_fields; 1318 CeedOperator_Hip *impl; 1319 1320 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1321 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1322 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 1323 1324 // Determine active input basis 1325 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 1326 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 1327 for (CeedInt i = 0; i < num_input_fields; i++) { 1328 CeedVector vec; 1329 1330 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 1331 if (vec == CEED_VECTOR_ACTIVE) { 1332 CeedEvalMode eval_mode; 1333 CeedBasis basis; 1334 1335 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 1336 if (!basis_in) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_in)); 1337 CeedCallBackend(CeedBasisDestroy(&basis)); 1338 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1339 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 1340 if (eval_mode != CEED_EVAL_WEIGHT) { 1341 num_eval_modes_in += q_comp; 1342 } 1343 } 1344 CeedCallBackend(CeedVectorDestroy(&vec)); 1345 } 1346 1347 // Determine active output basis 1348 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 1349 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 1350 for (CeedInt i = 0; i < num_output_fields; i++) { 1351 CeedVector vec; 1352 1353 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 1354 if (vec == CEED_VECTOR_ACTIVE) { 1355 CeedEvalMode eval_mode; 1356 CeedBasis basis; 1357 1358 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 1359 if (!basis_out) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_out)); 1360 CeedCallBackend(CeedBasisDestroy(&basis)); 1361 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1362 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 1363 if (eval_mode != CEED_EVAL_WEIGHT) { 1364 num_eval_modes_out += q_comp; 1365 } 1366 } 1367 CeedCallBackend(CeedVectorDestroy(&vec)); 1368 } 1369 1370 // Operator data struct 1371 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1372 CeedOperatorDiag_Hip *diag = impl->diag; 1373 1374 // Assemble kernel 1375 const char diagonal_kernel_source[] = "// Diagonal assembly source\n#include <ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h>\n"; 1376 hipModule_t *module = is_point_block ? &diag->module_point_block : &diag->module; 1377 CeedInt elems_per_block = 1; 1378 1379 CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes)); 1380 CeedCallBackend(CeedBasisGetNumComponents(basis_in, &num_comp)); 1381 if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes; 1382 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts)); 1383 CeedCallHip(ceed, CeedCompile_Hip(ceed, diagonal_kernel_source, module, 8, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", 1384 num_eval_modes_out, "NUM_COMP", num_comp, "NUM_NODES", num_nodes, "NUM_QPTS", num_qpts, "USE_CEEDSIZE", 1385 use_ceedsize_idx, "USE_POINT_BLOCK", is_point_block ? 1 : 0, "BLOCK_SIZE", num_nodes * elems_per_block)); 1386 CeedCallHip(ceed, CeedGetKernel_Hip(ceed, *module, "LinearDiagonal", is_point_block ? &diag->LinearPointBlock : &diag->LinearDiagonal)); 1387 CeedCallBackend(CeedDestroy(&ceed)); 1388 CeedCallBackend(CeedBasisDestroy(&basis_in)); 1389 CeedCallBackend(CeedBasisDestroy(&basis_out)); 1390 CeedCallBackend(CeedQFunctionDestroy(&qf)); 1391 return CEED_ERROR_SUCCESS; 1392 } 1393 1394 //------------------------------------------------------------------------------ 1395 // Assemble Diagonal Core 1396 //------------------------------------------------------------------------------ 1397 static inline int CeedOperatorAssembleDiagonalCore_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request, const bool is_point_block) { 1398 Ceed ceed; 1399 CeedInt num_elem, num_nodes; 1400 CeedScalar *elem_diag_array; 1401 const CeedScalar *assembled_qf_array; 1402 CeedVector assembled_qf = NULL, elem_diag; 1403 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out, diag_rstr; 1404 CeedOperator_Hip *impl; 1405 1406 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1407 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1408 1409 // Assemble QFunction 1410 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, request)); 1411 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr)); 1412 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); 1413 1414 // Setup 1415 if (!impl->diag) CeedCallBackend(CeedOperatorAssembleDiagonalSetup_Hip(op)); 1416 CeedOperatorDiag_Hip *diag = impl->diag; 1417 1418 assert(diag != NULL); 1419 1420 // Assemble kernel if needed 1421 if ((!is_point_block && !diag->LinearDiagonal) || (is_point_block && !diag->LinearPointBlock)) { 1422 CeedSize assembled_length, assembled_qf_length; 1423 CeedInt use_ceedsize_idx = 0; 1424 CeedCallBackend(CeedVectorGetLength(assembled, &assembled_length)); 1425 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length)); 1426 if ((assembled_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1; 1427 1428 CeedCallBackend(CeedOperatorAssembleDiagonalSetupCompile_Hip(op, use_ceedsize_idx, is_point_block)); 1429 } 1430 1431 // Restriction and diagonal vector 1432 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out)); 1433 CeedCheck(rstr_in == rstr_out, ceed, CEED_ERROR_BACKEND, 1434 "Cannot assemble operator diagonal with different input and output active element restrictions"); 1435 if (!is_point_block && !diag->diag_rstr) { 1436 CeedCallBackend(CeedElemRestrictionCreateUnsignedCopy(rstr_out, &diag->diag_rstr)); 1437 CeedCallBackend(CeedElemRestrictionCreateVector(diag->diag_rstr, NULL, &diag->elem_diag)); 1438 } else if (is_point_block && !diag->point_block_diag_rstr) { 1439 CeedCallBackend(CeedOperatorCreateActivePointBlockRestriction(rstr_out, &diag->point_block_diag_rstr)); 1440 CeedCallBackend(CeedElemRestrictionCreateVector(diag->point_block_diag_rstr, NULL, &diag->point_block_elem_diag)); 1441 } 1442 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_in)); 1443 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_out)); 1444 diag_rstr = is_point_block ? diag->point_block_diag_rstr : diag->diag_rstr; 1445 elem_diag = is_point_block ? diag->point_block_elem_diag : diag->elem_diag; 1446 CeedCallBackend(CeedVectorSetValue(elem_diag, 0.0)); 1447 1448 // Only assemble diagonal if the basis has nodes, otherwise inputs are null pointers 1449 CeedCallBackend(CeedElemRestrictionGetElementSize(diag_rstr, &num_nodes)); 1450 if (num_nodes > 0) { 1451 // Assemble element operator diagonals 1452 CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem)); 1453 CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array)); 1454 1455 // Compute the diagonal of B^T D B 1456 CeedInt elems_per_block = 1; 1457 CeedInt grid = CeedDivUpInt(num_elem, elems_per_block); 1458 void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_div_in, 1459 &diag->d_curl_in, &diag->d_interp_out, &diag->d_grad_out, &diag->d_div_out, &diag->d_curl_out, 1460 &diag->d_eval_modes_in, &diag->d_eval_modes_out, &assembled_qf_array, &elem_diag_array}; 1461 1462 if (is_point_block) { 1463 CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->LinearPointBlock, grid, num_nodes, 1, elems_per_block, args)); 1464 } else { 1465 CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->LinearDiagonal, grid, num_nodes, 1, elems_per_block, args)); 1466 } 1467 1468 // Restore arrays 1469 CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array)); 1470 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); 1471 } 1472 1473 // Assemble local operator diagonal 1474 CeedCallBackend(CeedElemRestrictionApply(diag_rstr, CEED_TRANSPOSE, elem_diag, assembled, request)); 1475 1476 // Cleanup 1477 CeedCallBackend(CeedDestroy(&ceed)); 1478 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1479 return CEED_ERROR_SUCCESS; 1480 } 1481 1482 //------------------------------------------------------------------------------ 1483 // Assemble Linear Diagonal 1484 //------------------------------------------------------------------------------ 1485 static int CeedOperatorLinearAssembleAddDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1486 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, false)); 1487 return CEED_ERROR_SUCCESS; 1488 } 1489 1490 //------------------------------------------------------------------------------ 1491 // Assemble Linear Point Block Diagonal 1492 //------------------------------------------------------------------------------ 1493 static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1494 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, true)); 1495 return CEED_ERROR_SUCCESS; 1496 } 1497 1498 //------------------------------------------------------------------------------ 1499 // Single Operator Assembly Setup 1500 //------------------------------------------------------------------------------ 1501 static int CeedSingleOperatorAssembleSetup_Hip(CeedOperator op, CeedInt use_ceedsize_idx) { 1502 Ceed ceed; 1503 Ceed_Hip *hip_data; 1504 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 1505 CeedInt elem_size_in, num_qpts_in = 0, num_comp_in, elem_size_out, num_qpts_out, num_comp_out, q_comp; 1506 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL; 1507 CeedElemRestriction rstr_in = NULL, rstr_out = NULL; 1508 CeedBasis basis_in = NULL, basis_out = NULL; 1509 CeedQFunctionField *qf_fields; 1510 CeedQFunction qf; 1511 CeedOperatorField *input_fields, *output_fields; 1512 CeedOperator_Hip *impl; 1513 1514 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1515 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1516 1517 // Get intput and output fields 1518 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields)); 1519 1520 // Determine active input basis eval mode 1521 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1522 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 1523 for (CeedInt i = 0; i < num_input_fields; i++) { 1524 CeedVector vec; 1525 1526 CeedCallBackend(CeedOperatorFieldGetVector(input_fields[i], &vec)); 1527 if (vec == CEED_VECTOR_ACTIVE) { 1528 CeedEvalMode eval_mode; 1529 CeedElemRestriction elem_rstr; 1530 CeedBasis basis; 1531 1532 CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis)); 1533 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases"); 1534 if (!basis_in) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_in)); 1535 CeedCallBackend(CeedBasisDestroy(&basis)); 1536 CeedCallBackend(CeedOperatorFieldGetElemRestriction(input_fields[i], &elem_rstr)); 1537 if (!rstr_in) CeedCallBackend(CeedElemRestrictionReferenceCopy(elem_rstr, &rstr_in)); 1538 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 1539 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in)); 1540 if (basis_in == CEED_BASIS_NONE) num_qpts_in = elem_size_in; 1541 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts_in)); 1542 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1543 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 1544 if (eval_mode != CEED_EVAL_WEIGHT) { 1545 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly 1546 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in)); 1547 for (CeedInt d = 0; d < q_comp; d++) { 1548 eval_modes_in[num_eval_modes_in + d] = eval_mode; 1549 } 1550 num_eval_modes_in += q_comp; 1551 } 1552 } 1553 CeedCallBackend(CeedVectorDestroy(&vec)); 1554 } 1555 1556 // Determine active output basis; basis_out and rstr_out only used if same as input, TODO 1557 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 1558 for (CeedInt i = 0; i < num_output_fields; i++) { 1559 CeedVector vec; 1560 1561 CeedCallBackend(CeedOperatorFieldGetVector(output_fields[i], &vec)); 1562 if (vec == CEED_VECTOR_ACTIVE) { 1563 CeedEvalMode eval_mode; 1564 CeedElemRestriction elem_rstr; 1565 CeedBasis basis; 1566 1567 CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis)); 1568 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND, 1569 "Backend does not implement operator assembly with multiple active bases"); 1570 if (!basis_out) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_out)); 1571 CeedCallBackend(CeedBasisDestroy(&basis)); 1572 CeedCallBackend(CeedOperatorFieldGetElemRestriction(output_fields[i], &elem_rstr)); 1573 if (!rstr_out) CeedCallBackend(CeedElemRestrictionReferenceCopy(elem_rstr, &rstr_out)); 1574 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 1575 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out)); 1576 if (basis_out == CEED_BASIS_NONE) num_qpts_out = elem_size_out; 1577 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_out, &num_qpts_out)); 1578 CeedCheck(num_qpts_in == num_qpts_out, ceed, CEED_ERROR_UNSUPPORTED, 1579 "Active input and output bases must have the same number of quadrature points"); 1580 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1581 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 1582 if (eval_mode != CEED_EVAL_WEIGHT) { 1583 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly 1584 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out)); 1585 for (CeedInt d = 0; d < q_comp; d++) { 1586 eval_modes_out[num_eval_modes_out + d] = eval_mode; 1587 } 1588 num_eval_modes_out += q_comp; 1589 } 1590 } 1591 CeedCallBackend(CeedVectorDestroy(&vec)); 1592 } 1593 CeedCheck(num_eval_modes_in > 0 && num_eval_modes_out > 0, ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs"); 1594 1595 CeedCallBackend(CeedCalloc(1, &impl->asmb)); 1596 CeedOperatorAssemble_Hip *asmb = impl->asmb; 1597 asmb->elems_per_block = 1; 1598 asmb->block_size_x = elem_size_in; 1599 asmb->block_size_y = elem_size_out; 1600 1601 CeedCallBackend(CeedGetData(ceed, &hip_data)); 1602 bool fallback = asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block > hip_data->device_prop.maxThreadsPerBlock; 1603 1604 if (fallback) { 1605 // Use fallback kernel with 1D threadblock 1606 asmb->block_size_y = 1; 1607 } 1608 1609 // Compile kernels 1610 const char assembly_kernel_source[] = "// Full assembly source\n#include <ceed/jit-source/hip/hip-ref-operator-assemble.h>\n"; 1611 1612 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &num_comp_in)); 1613 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_out, &num_comp_out)); 1614 CeedCallBackend(CeedCompile_Hip(ceed, assembly_kernel_source, &asmb->module, 10, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", 1615 num_eval_modes_out, "NUM_COMP_IN", num_comp_in, "NUM_COMP_OUT", num_comp_out, "NUM_NODES_IN", elem_size_in, 1616 "NUM_NODES_OUT", elem_size_out, "NUM_QPTS", num_qpts_in, "BLOCK_SIZE", 1617 asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block, "BLOCK_SIZE_Y", asmb->block_size_y, "USE_CEEDSIZE", 1618 use_ceedsize_idx)); 1619 CeedCallBackend(CeedGetKernel_Hip(ceed, asmb->module, "LinearAssemble", &asmb->LinearAssemble)); 1620 1621 // Load into B_in, in order that they will be used in eval_modes_in 1622 { 1623 const CeedInt in_bytes = elem_size_in * num_qpts_in * num_eval_modes_in * sizeof(CeedScalar); 1624 CeedInt d_in = 0; 1625 CeedEvalMode eval_modes_in_prev = CEED_EVAL_NONE; 1626 bool has_eval_none = false; 1627 CeedScalar *identity = NULL; 1628 1629 for (CeedInt i = 0; i < num_eval_modes_in; i++) { 1630 has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE); 1631 } 1632 if (has_eval_none) { 1633 CeedCallBackend(CeedCalloc(elem_size_in * num_qpts_in, &identity)); 1634 for (CeedInt i = 0; i < (elem_size_in < num_qpts_in ? elem_size_in : num_qpts_in); i++) identity[i * elem_size_in + i] = 1.0; 1635 } 1636 1637 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_in, in_bytes)); 1638 for (CeedInt i = 0; i < num_eval_modes_in; i++) { 1639 const CeedScalar *h_B_in; 1640 1641 CeedCallBackend(CeedOperatorGetBasisPointer(basis_in, eval_modes_in[i], identity, &h_B_in)); 1642 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_modes_in[i], &q_comp)); 1643 if (q_comp > 1) { 1644 if (i == 0 || eval_modes_in[i] != eval_modes_in_prev) d_in = 0; 1645 else h_B_in = &h_B_in[(++d_in) * elem_size_in * num_qpts_in]; 1646 } 1647 eval_modes_in_prev = eval_modes_in[i]; 1648 1649 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_in[i * elem_size_in * num_qpts_in], h_B_in, elem_size_in * num_qpts_in * sizeof(CeedScalar), 1650 hipMemcpyHostToDevice)); 1651 } 1652 CeedCallBackend(CeedFree(&identity)); 1653 } 1654 CeedCallBackend(CeedFree(&eval_modes_in)); 1655 1656 // Load into B_out, in order that they will be used in eval_modes_out 1657 { 1658 const CeedInt out_bytes = elem_size_out * num_qpts_out * num_eval_modes_out * sizeof(CeedScalar); 1659 CeedInt d_out = 0; 1660 CeedEvalMode eval_modes_out_prev = CEED_EVAL_NONE; 1661 bool has_eval_none = false; 1662 CeedScalar *identity = NULL; 1663 1664 for (CeedInt i = 0; i < num_eval_modes_out; i++) { 1665 has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE); 1666 } 1667 if (has_eval_none) { 1668 CeedCallBackend(CeedCalloc(elem_size_out * num_qpts_out, &identity)); 1669 for (CeedInt i = 0; i < (elem_size_out < num_qpts_out ? elem_size_out : num_qpts_out); i++) identity[i * elem_size_out + i] = 1.0; 1670 } 1671 1672 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_out, out_bytes)); 1673 for (CeedInt i = 0; i < num_eval_modes_out; i++) { 1674 const CeedScalar *h_B_out; 1675 1676 CeedCallBackend(CeedOperatorGetBasisPointer(basis_out, eval_modes_out[i], identity, &h_B_out)); 1677 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_modes_out[i], &q_comp)); 1678 if (q_comp > 1) { 1679 if (i == 0 || eval_modes_out[i] != eval_modes_out_prev) d_out = 0; 1680 else h_B_out = &h_B_out[(++d_out) * elem_size_out * num_qpts_out]; 1681 } 1682 eval_modes_out_prev = eval_modes_out[i]; 1683 1684 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_out[i * elem_size_out * num_qpts_out], h_B_out, elem_size_out * num_qpts_out * sizeof(CeedScalar), 1685 hipMemcpyHostToDevice)); 1686 } 1687 CeedCallBackend(CeedFree(&identity)); 1688 } 1689 CeedCallBackend(CeedFree(&eval_modes_out)); 1690 CeedCallBackend(CeedDestroy(&ceed)); 1691 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_in)); 1692 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_out)); 1693 CeedCallBackend(CeedBasisDestroy(&basis_in)); 1694 CeedCallBackend(CeedBasisDestroy(&basis_out)); 1695 CeedCallBackend(CeedQFunctionDestroy(&qf)); 1696 return CEED_ERROR_SUCCESS; 1697 } 1698 1699 //------------------------------------------------------------------------------ 1700 // Assemble matrix data for COO matrix of assembled operator. 1701 // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic. 1702 // 1703 // Note that this (and other assembly routines) currently assume only one active input restriction/basis per operator 1704 // (could have multiple basis eval modes). 1705 // TODO: allow multiple active input restrictions/basis objects 1706 //------------------------------------------------------------------------------ 1707 static int CeedSingleOperatorAssemble_Hip(CeedOperator op, CeedInt offset, CeedVector values) { 1708 Ceed ceed; 1709 CeedSize values_length = 0, assembled_qf_length = 0; 1710 CeedInt use_ceedsize_idx = 0, num_elem_in, num_elem_out, elem_size_in, elem_size_out; 1711 CeedScalar *values_array; 1712 const CeedScalar *assembled_qf_array; 1713 CeedVector assembled_qf = NULL; 1714 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out; 1715 CeedRestrictionType rstr_type_in, rstr_type_out; 1716 const bool *orients_in = NULL, *orients_out = NULL; 1717 const CeedInt8 *curl_orients_in = NULL, *curl_orients_out = NULL; 1718 CeedOperator_Hip *impl; 1719 1720 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1721 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1722 1723 // Assemble QFunction 1724 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, CEED_REQUEST_IMMEDIATE)); 1725 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr)); 1726 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); 1727 1728 CeedCallBackend(CeedVectorGetLength(values, &values_length)); 1729 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length)); 1730 if ((values_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1; 1731 1732 // Setup 1733 if (!impl->asmb) CeedCallBackend(CeedSingleOperatorAssembleSetup_Hip(op, use_ceedsize_idx)); 1734 CeedOperatorAssemble_Hip *asmb = impl->asmb; 1735 1736 assert(asmb != NULL); 1737 1738 // Assemble element operator 1739 CeedCallBackend(CeedVectorGetArray(values, CEED_MEM_DEVICE, &values_array)); 1740 values_array += offset; 1741 1742 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out)); 1743 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, &num_elem_in)); 1744 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in)); 1745 1746 CeedCallBackend(CeedElemRestrictionGetType(rstr_in, &rstr_type_in)); 1747 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) { 1748 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_in, CEED_MEM_DEVICE, &orients_in)); 1749 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) { 1750 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_in, CEED_MEM_DEVICE, &curl_orients_in)); 1751 } 1752 1753 if (rstr_in != rstr_out) { 1754 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_out, &num_elem_out)); 1755 CeedCheck(num_elem_in == num_elem_out, ceed, CEED_ERROR_UNSUPPORTED, 1756 "Active input and output operator restrictions must have the same number of elements"); 1757 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out)); 1758 1759 CeedCallBackend(CeedElemRestrictionGetType(rstr_out, &rstr_type_out)); 1760 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) { 1761 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_out, CEED_MEM_DEVICE, &orients_out)); 1762 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) { 1763 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_out, CEED_MEM_DEVICE, &curl_orients_out)); 1764 } 1765 } else { 1766 elem_size_out = elem_size_in; 1767 orients_out = orients_in; 1768 curl_orients_out = curl_orients_in; 1769 } 1770 1771 // Compute B^T D B 1772 CeedInt shared_mem = 1773 ((curl_orients_in || curl_orients_out ? elem_size_in * elem_size_out : 0) + (curl_orients_in ? elem_size_in * asmb->block_size_y : 0)) * 1774 sizeof(CeedScalar); 1775 CeedInt grid = CeedDivUpInt(num_elem_in, asmb->elems_per_block); 1776 void *args[] = {(void *)&num_elem_in, &asmb->d_B_in, &asmb->d_B_out, &orients_in, &curl_orients_in, 1777 &orients_out, &curl_orients_out, &assembled_qf_array, &values_array}; 1778 1779 CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, asmb->LinearAssemble, NULL, grid, asmb->block_size_x, asmb->block_size_y, asmb->elems_per_block, 1780 shared_mem, args)); 1781 1782 // Restore arrays 1783 CeedCallBackend(CeedVectorRestoreArray(values, &values_array)); 1784 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); 1785 1786 // Cleanup 1787 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1788 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) { 1789 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_in, &orients_in)); 1790 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) { 1791 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_in, &curl_orients_in)); 1792 } 1793 if (rstr_in != rstr_out) { 1794 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) { 1795 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_out, &orients_out)); 1796 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) { 1797 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_out, &curl_orients_out)); 1798 } 1799 } 1800 CeedCallBackend(CeedDestroy(&ceed)); 1801 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_in)); 1802 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_out)); 1803 return CEED_ERROR_SUCCESS; 1804 } 1805 1806 //------------------------------------------------------------------------------ 1807 // Assemble Linear QFunction AtPoints 1808 //------------------------------------------------------------------------------ 1809 static int CeedOperatorLinearAssembleQFunctionAtPoints_Hip(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 1810 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "Backend does not implement CeedOperatorLinearAssembleQFunction"); 1811 } 1812 1813 //------------------------------------------------------------------------------ 1814 // Assemble Linear Diagonal AtPoints 1815 //------------------------------------------------------------------------------ 1816 static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1817 CeedInt max_num_points, *num_points, num_elem, num_input_fields, num_output_fields; 1818 Ceed ceed; 1819 CeedVector active_e_vec_in, active_e_vec_out; 1820 CeedQFunctionField *qf_input_fields, *qf_output_fields; 1821 CeedQFunction qf; 1822 CeedOperatorField *op_input_fields, *op_output_fields; 1823 CeedOperator_Hip *impl; 1824 1825 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1826 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1827 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1828 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 1829 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 1830 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 1831 1832 // Setup 1833 CeedCallBackend(CeedOperatorSetupAtPoints_Hip(op)); 1834 num_points = impl->num_points; 1835 max_num_points = impl->max_num_points; 1836 1837 // Work vector 1838 CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_in)); 1839 CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_out)); 1840 { 1841 CeedSize length_in, length_out; 1842 1843 CeedCallBackend(CeedVectorGetLength(active_e_vec_in, &length_in)); 1844 CeedCallBackend(CeedVectorGetLength(active_e_vec_out, &length_out)); 1845 // Need input e_vec to be longer 1846 if (length_in < length_out) { 1847 CeedVector temp = active_e_vec_in; 1848 1849 active_e_vec_in = active_e_vec_out; 1850 active_e_vec_out = temp; 1851 } 1852 } 1853 1854 // Get point coordinates 1855 if (!impl->point_coords_elem) { 1856 CeedVector point_coords = NULL; 1857 CeedElemRestriction rstr_points = NULL; 1858 1859 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 1860 CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem)); 1861 CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 1862 CeedCallBackend(CeedVectorDestroy(&point_coords)); 1863 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 1864 } 1865 1866 // Process inputs 1867 for (CeedInt i = 0; i < num_input_fields; i++) { 1868 CeedCallBackend(CeedOperatorInputRestrict_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl, request)); 1869 CeedCallBackend( 1870 CeedOperatorInputBasisAtPoints_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, num_elem, num_points, true, false, impl)); 1871 } 1872 1873 // Output pointers, as necessary 1874 for (CeedInt i = 0; i < num_output_fields; i++) { 1875 CeedEvalMode eval_mode; 1876 1877 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1878 if (eval_mode == CEED_EVAL_NONE) { 1879 CeedScalar *e_vec_array; 1880 1881 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); 1882 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array)); 1883 } 1884 } 1885 1886 // Loop over active fields 1887 for (CeedInt i = 0; i < num_input_fields; i++) { 1888 bool is_active = false, is_active_at_points = true; 1889 CeedInt elem_size = 1, num_comp_active = 1, e_vec_size = 0, field_in = impl->input_field_order[i]; 1890 CeedRestrictionType rstr_type; 1891 CeedVector l_vec; 1892 CeedElemRestriction elem_rstr; 1893 1894 // -- Skip non-active input 1895 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[field_in], &l_vec)); 1896 is_active = l_vec == CEED_VECTOR_ACTIVE; 1897 CeedCallBackend(CeedVectorDestroy(&l_vec)); 1898 if (!is_active || impl->skip_rstr_in[field_in]) continue; 1899 1900 // -- Get active restriction type 1901 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[field_in], &elem_rstr)); 1902 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 1903 is_active_at_points = rstr_type == CEED_RESTRICTION_POINTS; 1904 if (!is_active_at_points) CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 1905 else elem_size = max_num_points; 1906 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp_active)); 1907 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 1908 1909 e_vec_size = elem_size * num_comp_active; 1910 CeedCallBackend(CeedVectorSetValue(active_e_vec_in, 0.0)); 1911 for (CeedInt s = 0; s < e_vec_size; s++) { 1912 CeedVector q_vec = impl->q_vecs_in[field_in]; 1913 1914 // Update unit vector 1915 { 1916 // Note: E-vec strides are node * (1) + comp * (elem_size * num_elem) + elem * (elem_size) 1917 CeedInt node = (s - 1) % elem_size, comp = (s - 1) / elem_size; 1918 CeedSize start = node * 1 + comp * (elem_size * num_elem); 1919 CeedSize stop = (comp + 1) * (elem_size * num_elem); 1920 1921 if (s != 0) CeedCallBackend(CeedVectorSetValueStrided(active_e_vec_in, start, stop, elem_size, 0.0)); 1922 1923 node = s % elem_size, comp = s / elem_size; 1924 start = node * 1 + comp * (elem_size * num_elem); 1925 stop = (comp + 1) * (elem_size * num_elem); 1926 CeedCallBackend(CeedVectorSetValueStrided(active_e_vec_in, start, stop, elem_size, 1.0)); 1927 } 1928 1929 // Basis action 1930 for (CeedInt j = 0; j < num_input_fields; j++) { 1931 CeedInt field = impl->input_field_order[j]; 1932 1933 CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(op_input_fields[field], qf_input_fields[field], field, NULL, active_e_vec_in, num_elem, 1934 num_points, false, true, impl)); 1935 } 1936 1937 // Q function 1938 CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out)); 1939 1940 // Output basis apply if needed 1941 for (CeedInt j = 0; j < num_output_fields; j++) { 1942 bool is_active = false; 1943 CeedInt elem_size = 0; 1944 CeedInt field_out = impl->output_field_order[j]; 1945 CeedRestrictionType rstr_type; 1946 CeedEvalMode eval_mode; 1947 CeedVector l_vec, e_vec = impl->e_vecs_out[field_out], q_vec = impl->q_vecs_out[field_out]; 1948 CeedElemRestriction elem_rstr; 1949 1950 // ---- Skip non-active output 1951 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[field_out], &l_vec)); 1952 is_active = l_vec == CEED_VECTOR_ACTIVE; 1953 CeedCallBackend(CeedVectorDestroy(&l_vec)); 1954 if (!is_active) continue; 1955 if (!e_vec) e_vec = active_e_vec_out; 1956 1957 // ---- Check if elem size matches 1958 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[field_out], &elem_rstr)); 1959 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 1960 if (is_active_at_points && rstr_type != CEED_RESTRICTION_POINTS) continue; 1961 if (rstr_type == CEED_RESTRICTION_POINTS) { 1962 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(elem_rstr, &elem_size)); 1963 } else { 1964 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 1965 } 1966 { 1967 CeedInt num_comp = 0; 1968 1969 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp)); 1970 if (e_vec_size != num_comp * elem_size) continue; 1971 } 1972 1973 // Basis action 1974 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[field_out], &eval_mode)); 1975 switch (eval_mode) { 1976 case CEED_EVAL_NONE: { 1977 CeedScalar *e_vec_array; 1978 1979 CeedCallBackend(CeedVectorTakeArray(q_vec, CEED_MEM_DEVICE, &e_vec_array)); 1980 CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_vec_array)); 1981 break; 1982 } 1983 case CEED_EVAL_INTERP: 1984 case CEED_EVAL_GRAD: 1985 case CEED_EVAL_DIV: 1986 case CEED_EVAL_CURL: { 1987 CeedBasis basis; 1988 1989 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[field_out], &basis)); 1990 if (impl->apply_add_basis_out[field_out]) { 1991 CeedCallBackend( 1992 CeedBasisApplyAddAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec)); 1993 } else { 1994 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec)); 1995 } 1996 CeedCallBackend(CeedBasisDestroy(&basis)); 1997 break; 1998 } 1999 // LCOV_EXCL_START 2000 case CEED_EVAL_WEIGHT: { 2001 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 2002 // LCOV_EXCL_STOP 2003 } 2004 } 2005 2006 // Continue if a field that is summed into 2007 if (impl->skip_rstr_out[field_out]) { 2008 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 2009 continue; 2010 } 2011 2012 // Mask output e-vec 2013 CeedCallBackend(CeedVectorPointwiseMult(e_vec, active_e_vec_in, e_vec)); 2014 2015 // Restrict 2016 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, e_vec, assembled, request)); 2017 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 2018 2019 // Reset q_vec for 2020 if (eval_mode == CEED_EVAL_NONE) { 2021 CeedScalar *e_vec_array; 2022 2023 CeedCallBackend(CeedVectorGetArrayWrite(e_vec, CEED_MEM_DEVICE, &e_vec_array)); 2024 CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array)); 2025 } 2026 } 2027 2028 // Reset vec 2029 if (s == e_vec_size - 1 && i != num_input_fields - 1) CeedCallBackend(CeedVectorSetValue(q_vec, 0.0)); 2030 } 2031 } 2032 2033 // Restore CEED_EVAL_NONE 2034 for (CeedInt i = 0; i < num_output_fields; i++) { 2035 CeedEvalMode eval_mode; 2036 2037 // Get eval_mode 2038 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 2039 2040 // Restore evec 2041 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 2042 if (eval_mode == CEED_EVAL_NONE) { 2043 CeedScalar *e_vec_array; 2044 2045 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, &e_vec_array)); 2046 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_in[i], &e_vec_array)); 2047 } 2048 } 2049 2050 // Restore input arrays 2051 for (CeedInt i = 0; i < num_input_fields; i++) { 2052 CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl)); 2053 } 2054 2055 // Restore work vector 2056 CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec_in)); 2057 CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec_out)); 2058 CeedCallBackend(CeedDestroy(&ceed)); 2059 CeedCallBackend(CeedQFunctionDestroy(&qf)); 2060 return CEED_ERROR_SUCCESS; 2061 } 2062 2063 //------------------------------------------------------------------------------ 2064 // Create operator 2065 //------------------------------------------------------------------------------ 2066 int CeedOperatorCreate_Hip(CeedOperator op) { 2067 Ceed ceed; 2068 CeedOperator_Hip *impl; 2069 2070 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 2071 CeedCallBackend(CeedCalloc(1, &impl)); 2072 CeedCallBackend(CeedOperatorSetData(op, impl)); 2073 2074 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Hip)); 2075 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Hip)); 2076 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonal_Hip)); 2077 CeedCallBackend( 2078 CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal", CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip)); 2079 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssemble_Hip)); 2080 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip)); 2081 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip)); 2082 CeedCallBackend(CeedDestroy(&ceed)); 2083 return CEED_ERROR_SUCCESS; 2084 } 2085 2086 //------------------------------------------------------------------------------ 2087 // Create operator AtPoints 2088 //------------------------------------------------------------------------------ 2089 int CeedOperatorCreateAtPoints_Hip(CeedOperator op) { 2090 Ceed ceed; 2091 CeedOperator_Hip *impl; 2092 2093 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 2094 CeedCallBackend(CeedCalloc(1, &impl)); 2095 CeedCallBackend(CeedOperatorSetData(op, impl)); 2096 2097 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunctionAtPoints_Hip)); 2098 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip)); 2099 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAddAtPoints_Hip)); 2100 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip)); 2101 CeedCallBackend(CeedDestroy(&ceed)); 2102 return CEED_ERROR_SUCCESS; 2103 } 2104 2105 //------------------------------------------------------------------------------ 2106