1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <assert.h> 12 #include <stdbool.h> 13 #include <string.h> 14 #include <hip/hip_runtime.h> 15 16 #include "../hip/ceed-hip-common.h" 17 #include "../hip/ceed-hip-compile.h" 18 #include "ceed-hip-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Destroy operator 22 //------------------------------------------------------------------------------ 23 static int CeedOperatorDestroy_Hip(CeedOperator op) { 24 CeedOperator_Hip *impl; 25 26 CeedCallBackend(CeedOperatorGetData(op, &impl)); 27 28 // Apply data 29 CeedCallBackend(CeedFree(&impl->skip_rstr_in)); 30 for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { 31 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[i])); 32 } 33 CeedCallBackend(CeedFree(&impl->e_vecs)); 34 CeedCallBackend(CeedFree(&impl->input_states)); 35 36 for (CeedInt i = 0; i < impl->num_inputs; i++) { 37 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i])); 38 } 39 CeedCallBackend(CeedFree(&impl->q_vecs_in)); 40 41 for (CeedInt i = 0; i < impl->num_outputs; i++) { 42 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 43 } 44 CeedCallBackend(CeedFree(&impl->q_vecs_out)); 45 CeedCallBackend(CeedVectorDestroy(&impl->point_coords_elem)); 46 47 // QFunction assembly data 48 for (CeedInt i = 0; i < impl->num_active_in; i++) { 49 CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i])); 50 } 51 CeedCallBackend(CeedFree(&impl->qf_active_in)); 52 53 // Diag data 54 if (impl->diag) { 55 Ceed ceed; 56 57 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 58 if (impl->diag->module) { 59 CeedCallHip(ceed, hipModuleUnload(impl->diag->module)); 60 } 61 if (impl->diag->module_point_block) { 62 CeedCallHip(ceed, hipModuleUnload(impl->diag->module_point_block)); 63 } 64 CeedCallHip(ceed, hipFree(impl->diag->d_eval_modes_in)); 65 CeedCallHip(ceed, hipFree(impl->diag->d_eval_modes_out)); 66 CeedCallHip(ceed, hipFree(impl->diag->d_identity)); 67 CeedCallHip(ceed, hipFree(impl->diag->d_interp_in)); 68 CeedCallHip(ceed, hipFree(impl->diag->d_interp_out)); 69 CeedCallHip(ceed, hipFree(impl->diag->d_grad_in)); 70 CeedCallHip(ceed, hipFree(impl->diag->d_grad_out)); 71 CeedCallHip(ceed, hipFree(impl->diag->d_div_in)); 72 CeedCallHip(ceed, hipFree(impl->diag->d_div_out)); 73 CeedCallHip(ceed, hipFree(impl->diag->d_curl_in)); 74 CeedCallHip(ceed, hipFree(impl->diag->d_curl_out)); 75 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr)); 76 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr)); 77 CeedCallBackend(CeedVectorDestroy(&impl->diag->elem_diag)); 78 CeedCallBackend(CeedVectorDestroy(&impl->diag->point_block_elem_diag)); 79 } 80 CeedCallBackend(CeedFree(&impl->diag)); 81 82 if (impl->asmb) { 83 Ceed ceed; 84 85 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 86 CeedCallHip(ceed, hipModuleUnload(impl->asmb->module)); 87 CeedCallHip(ceed, hipFree(impl->asmb->d_B_in)); 88 CeedCallHip(ceed, hipFree(impl->asmb->d_B_out)); 89 } 90 CeedCallBackend(CeedFree(&impl->asmb)); 91 92 CeedCallBackend(CeedFree(&impl)); 93 return CEED_ERROR_SUCCESS; 94 } 95 96 //------------------------------------------------------------------------------ 97 // Setup infields or outfields 98 //------------------------------------------------------------------------------ 99 static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, CeedVector *e_vecs, 100 CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) { 101 Ceed ceed; 102 CeedQFunctionField *qf_fields; 103 CeedOperatorField *op_fields; 104 105 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 106 if (is_input) { 107 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 108 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 109 } else { 110 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 111 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 112 } 113 114 // Loop over fields 115 for (CeedInt i = 0; i < num_fields; i++) { 116 bool is_strided = false, skip_restriction = false; 117 CeedSize q_size; 118 CeedInt size; 119 CeedEvalMode eval_mode; 120 CeedBasis basis; 121 122 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 123 if (eval_mode != CEED_EVAL_WEIGHT) { 124 CeedElemRestriction elem_rstr; 125 126 // Check whether this field can skip the element restriction: 127 // Must be passive input, with eval_mode NONE, and have a strided restriction with CEED_STRIDES_BACKEND. 128 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); 129 130 // First, check whether the field is input or output: 131 if (is_input) { 132 CeedVector vec; 133 134 // Check for passive input 135 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 136 if (vec != CEED_VECTOR_ACTIVE) { 137 // Check eval_mode 138 if (eval_mode == CEED_EVAL_NONE) { 139 // Check for strided restriction 140 CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided)); 141 if (is_strided) { 142 // Check if vector is already in preferred backend ordering 143 CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &skip_restriction)); 144 } 145 } 146 } 147 } 148 if (skip_restriction) { 149 // We do not need an E-Vector, but will use the input field vector's data directly in the operator application. 150 e_vecs[i + start_e] = NULL; 151 } else { 152 CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs[i + start_e])); 153 } 154 } 155 156 switch (eval_mode) { 157 case CEED_EVAL_NONE: 158 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 159 q_size = (CeedSize)num_elem * Q * size; 160 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 161 break; 162 case CEED_EVAL_INTERP: 163 case CEED_EVAL_GRAD: 164 case CEED_EVAL_DIV: 165 case CEED_EVAL_CURL: 166 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 167 q_size = (CeedSize)num_elem * Q * size; 168 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 169 break; 170 case CEED_EVAL_WEIGHT: // Only on input fields 171 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 172 q_size = (CeedSize)num_elem * Q; 173 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 174 if (is_at_points) { 175 CeedInt num_points[num_elem]; 176 177 for (CeedInt i = 0; i < num_elem; i++) num_points[i] = Q; 178 CeedCallBackend( 179 CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, CEED_VECTOR_NONE, q_vecs[i])); 180 } else { 181 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); 182 } 183 break; 184 } 185 } 186 // Drop duplicate input restrictions 187 if (is_input) { 188 for (CeedInt i = 0; i < num_fields; i++) { 189 CeedVector vec_i; 190 CeedElemRestriction rstr_i; 191 192 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); 193 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); 194 for (CeedInt j = i + 1; j < num_fields; j++) { 195 CeedVector vec_j; 196 CeedElemRestriction rstr_j; 197 198 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); 199 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); 200 if (vec_i == vec_j && rstr_i == rstr_j) { 201 CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); 202 skip_rstr[j] = true; 203 } 204 } 205 } 206 } 207 return CEED_ERROR_SUCCESS; 208 } 209 210 //------------------------------------------------------------------------------ 211 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction. 212 //------------------------------------------------------------------------------ 213 static int CeedOperatorSetup_Hip(CeedOperator op) { 214 Ceed ceed; 215 bool is_setup_done; 216 CeedInt Q, num_elem, num_input_fields, num_output_fields; 217 CeedQFunctionField *qf_input_fields, *qf_output_fields; 218 CeedQFunction qf; 219 CeedOperatorField *op_input_fields, *op_output_fields; 220 CeedOperator_Hip *impl; 221 222 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 223 if (is_setup_done) return CEED_ERROR_SUCCESS; 224 225 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 226 CeedCallBackend(CeedOperatorGetData(op, &impl)); 227 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 228 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 229 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 230 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 231 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 232 233 // Allocate 234 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); 235 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); 236 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 237 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 238 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 239 impl->num_inputs = num_input_fields; 240 impl->num_outputs = num_output_fields; 241 242 // Set up infield and outfield e_vecs and q_vecs 243 // Infields 244 CeedCallBackend( 245 CeedOperatorSetupFields_Hip(qf, op, true, false, impl->skip_rstr_in, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem)); 246 // Outfields 247 CeedCallBackend( 248 CeedOperatorSetupFields_Hip(qf, op, false, false, NULL, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, Q, num_elem)); 249 250 CeedCallBackend(CeedOperatorSetSetupDone(op)); 251 return CEED_ERROR_SUCCESS; 252 } 253 254 //------------------------------------------------------------------------------ 255 // Setup Operator Inputs 256 //------------------------------------------------------------------------------ 257 static inline int CeedOperatorSetupInputs_Hip(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 258 CeedVector in_vec, const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], 259 CeedOperator_Hip *impl, CeedRequest *request) { 260 for (CeedInt i = 0; i < num_input_fields; i++) { 261 CeedEvalMode eval_mode; 262 CeedVector vec; 263 CeedElemRestriction elem_rstr; 264 265 // Get input vector 266 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 267 if (vec == CEED_VECTOR_ACTIVE) { 268 if (skip_active) continue; 269 else vec = in_vec; 270 } 271 272 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 273 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 274 } else { 275 // Get input vector 276 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 277 // Get input element restriction 278 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 279 if (vec == CEED_VECTOR_ACTIVE) vec = in_vec; 280 // Restrict, if necessary 281 if (!impl->e_vecs[i]) { 282 // No restriction for this field; read data directly from vec. 283 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i])); 284 } else { 285 uint64_t state; 286 287 CeedCallBackend(CeedVectorGetState(vec, &state)); 288 if (state != impl->input_states[i] && !impl->skip_rstr_in[i]) { 289 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, vec, impl->e_vecs[i], request)); 290 } 291 impl->input_states[i] = state; 292 // Get evec 293 CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs[i], CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i])); 294 } 295 } 296 } 297 return CEED_ERROR_SUCCESS; 298 } 299 300 //------------------------------------------------------------------------------ 301 // Input Basis Action 302 //------------------------------------------------------------------------------ 303 static inline int CeedOperatorInputBasis_Hip(CeedInt num_elem, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 304 CeedInt num_input_fields, const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], 305 CeedOperator_Hip *impl) { 306 for (CeedInt i = 0; i < num_input_fields; i++) { 307 CeedInt elem_size, size; 308 CeedEvalMode eval_mode; 309 CeedElemRestriction elem_rstr; 310 CeedBasis basis; 311 312 // Skip active input 313 if (skip_active) { 314 CeedVector vec; 315 316 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 317 if (vec == CEED_VECTOR_ACTIVE) continue; 318 } 319 // Get elem_size, eval_mode, size 320 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 321 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 322 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 323 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 324 // Basis action 325 switch (eval_mode) { 326 case CEED_EVAL_NONE: 327 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 328 break; 329 case CEED_EVAL_INTERP: 330 case CEED_EVAL_GRAD: 331 case CEED_EVAL_DIV: 332 case CEED_EVAL_CURL: 333 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 334 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs[i], impl->q_vecs_in[i])); 335 break; 336 case CEED_EVAL_WEIGHT: 337 break; // No action 338 } 339 } 340 return CEED_ERROR_SUCCESS; 341 } 342 343 //------------------------------------------------------------------------------ 344 // Restore Input Vectors 345 //------------------------------------------------------------------------------ 346 static inline int CeedOperatorRestoreInputs_Hip(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 347 const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Hip *impl) { 348 for (CeedInt i = 0; i < num_input_fields; i++) { 349 CeedEvalMode eval_mode; 350 CeedVector vec; 351 352 // Skip active input 353 if (skip_active) { 354 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 355 if (vec == CEED_VECTOR_ACTIVE) continue; 356 } 357 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 358 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 359 } else { 360 if (!impl->e_vecs[i]) { // This was a skip_restriction case 361 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 362 CeedCallBackend(CeedVectorRestoreArrayRead(vec, (const CeedScalar **)&e_data[i])); 363 } else { 364 CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs[i], (const CeedScalar **)&e_data[i])); 365 } 366 } 367 } 368 return CEED_ERROR_SUCCESS; 369 } 370 371 //------------------------------------------------------------------------------ 372 // Apply and add to output 373 //------------------------------------------------------------------------------ 374 static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 375 CeedInt Q, num_elem, elem_size, num_input_fields, num_output_fields, size; 376 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 377 CeedQFunctionField *qf_input_fields, *qf_output_fields; 378 CeedQFunction qf; 379 CeedOperatorField *op_input_fields, *op_output_fields; 380 CeedOperator_Hip *impl; 381 382 CeedCallBackend(CeedOperatorGetData(op, &impl)); 383 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 384 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 385 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 386 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 387 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 388 389 // Setup 390 CeedCallBackend(CeedOperatorSetup_Hip(op)); 391 392 // Input Evecs and Restriction 393 CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data, impl, request)); 394 395 // Input basis apply if needed 396 CeedCallBackend(CeedOperatorInputBasis_Hip(num_elem, qf_input_fields, op_input_fields, num_input_fields, false, e_data, impl)); 397 398 // Output pointers, as necessary 399 for (CeedInt i = 0; i < num_output_fields; i++) { 400 CeedEvalMode eval_mode; 401 402 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 403 if (eval_mode == CEED_EVAL_NONE) { 404 // Set the output Q-Vector to use the E-Vector data directly. 405 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 406 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 407 } 408 } 409 410 // Q function 411 CeedCallBackend(CeedQFunctionApply(qf, num_elem * Q, impl->q_vecs_in, impl->q_vecs_out)); 412 413 // Output basis apply if needed 414 for (CeedInt i = 0; i < num_output_fields; i++) { 415 CeedEvalMode eval_mode; 416 CeedElemRestriction elem_rstr; 417 CeedBasis basis; 418 419 // Get elem_size, eval_mode, size 420 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 421 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 422 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 423 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 424 // Basis action 425 switch (eval_mode) { 426 case CEED_EVAL_NONE: 427 break; // No action 428 case CEED_EVAL_INTERP: 429 case CEED_EVAL_GRAD: 430 case CEED_EVAL_DIV: 431 case CEED_EVAL_CURL: 432 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 433 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); 434 break; 435 // LCOV_EXCL_START 436 case CEED_EVAL_WEIGHT: { 437 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 438 // LCOV_EXCL_STOP 439 } 440 } 441 } 442 443 // Output restriction 444 for (CeedInt i = 0; i < num_output_fields; i++) { 445 CeedEvalMode eval_mode; 446 CeedVector vec; 447 CeedElemRestriction elem_rstr; 448 449 // Restore evec 450 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 451 if (eval_mode == CEED_EVAL_NONE) { 452 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 453 } 454 // Get output vector 455 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 456 // Restrict 457 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 458 // Active 459 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 460 461 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request)); 462 } 463 464 // Restore input arrays 465 CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl)); 466 return CEED_ERROR_SUCCESS; 467 } 468 469 //------------------------------------------------------------------------------ 470 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction. 471 //------------------------------------------------------------------------------ 472 static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) { 473 Ceed ceed; 474 bool is_setup_done; 475 CeedInt max_num_points = -1, num_elem, num_input_fields, num_output_fields; 476 CeedQFunctionField *qf_input_fields, *qf_output_fields; 477 CeedQFunction qf; 478 CeedOperatorField *op_input_fields, *op_output_fields; 479 CeedOperator_Hip *impl; 480 481 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 482 if (is_setup_done) return CEED_ERROR_SUCCESS; 483 484 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 485 CeedCallBackend(CeedOperatorGetData(op, &impl)); 486 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 487 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 488 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 489 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 490 { 491 CeedElemRestriction elem_rstr = NULL; 492 493 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &elem_rstr, NULL)); 494 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(elem_rstr, &max_num_points)); 495 } 496 impl->max_num_points = max_num_points; 497 498 // Allocate 499 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); 500 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); 501 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 502 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 503 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 504 impl->num_inputs = num_input_fields; 505 impl->num_outputs = num_output_fields; 506 507 // Set up infield and outfield e_vecs and q_vecs 508 // Infields 509 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, true, impl->skip_rstr_in, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, 510 max_num_points, num_elem)); 511 // Outfields 512 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, NULL, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, 513 max_num_points, num_elem)); 514 515 CeedCallBackend(CeedOperatorSetSetupDone(op)); 516 return CEED_ERROR_SUCCESS; 517 } 518 519 //------------------------------------------------------------------------------ 520 // Input Basis Action AtPoints 521 //------------------------------------------------------------------------------ 522 static inline int CeedOperatorInputBasisAtPoints_Hip(CeedInt num_elem, const CeedInt *num_points, CeedQFunctionField *qf_input_fields, 523 CeedOperatorField *op_input_fields, CeedInt num_input_fields, const bool skip_active, 524 CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Hip *impl) { 525 for (CeedInt i = 0; i < num_input_fields; i++) { 526 CeedInt elem_size, size; 527 CeedEvalMode eval_mode; 528 CeedElemRestriction elem_rstr; 529 CeedBasis basis; 530 531 // Skip active input 532 if (skip_active) { 533 CeedVector vec; 534 535 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 536 if (vec == CEED_VECTOR_ACTIVE) continue; 537 } 538 // Get elem_size, eval_mode, size 539 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 540 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 541 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 542 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 543 // Basis action 544 switch (eval_mode) { 545 case CEED_EVAL_NONE: 546 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 547 break; 548 case CEED_EVAL_INTERP: 549 case CEED_EVAL_GRAD: 550 case CEED_EVAL_DIV: 551 case CEED_EVAL_CURL: 552 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 553 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, impl->e_vecs[i], 554 impl->q_vecs_in[i])); 555 break; 556 case CEED_EVAL_WEIGHT: 557 break; // No action 558 } 559 } 560 return CEED_ERROR_SUCCESS; 561 } 562 563 //------------------------------------------------------------------------------ 564 // Apply and add to output AtPoints 565 //------------------------------------------------------------------------------ 566 static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 567 CeedInt max_num_points, num_elem, elem_size, num_input_fields, num_output_fields, size; 568 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 569 CeedQFunctionField *qf_input_fields, *qf_output_fields; 570 CeedQFunction qf; 571 CeedOperatorField *op_input_fields, *op_output_fields; 572 CeedOperator_Hip *impl; 573 574 CeedCallBackend(CeedOperatorGetData(op, &impl)); 575 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 576 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 577 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 578 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 579 CeedInt num_points[num_elem]; 580 581 // Setup 582 CeedCallBackend(CeedOperatorSetupAtPoints_Hip(op)); 583 max_num_points = impl->max_num_points; 584 for (CeedInt i = 0; i < num_elem; i++) num_points[i] = max_num_points; 585 586 // Input Evecs and Restriction 587 CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data, impl, request)); 588 589 // Get point coordinates 590 if (!impl->point_coords_elem) { 591 CeedVector point_coords = NULL; 592 CeedElemRestriction rstr_points = NULL; 593 594 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 595 CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem)); 596 CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 597 } 598 599 // Input basis apply if needed 600 CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(num_elem, num_points, qf_input_fields, op_input_fields, num_input_fields, false, e_data, impl)); 601 602 // Output pointers, as necessary 603 for (CeedInt i = 0; i < num_output_fields; i++) { 604 CeedEvalMode eval_mode; 605 606 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 607 if (eval_mode == CEED_EVAL_NONE) { 608 // Set the output Q-Vector to use the E-Vector data directly. 609 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 610 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 611 } 612 } 613 614 // Q function 615 CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out)); 616 617 // Output basis apply if needed 618 for (CeedInt i = 0; i < num_output_fields; i++) { 619 CeedEvalMode eval_mode; 620 CeedElemRestriction elem_rstr; 621 CeedBasis basis; 622 623 // Get elem_size, eval_mode, size 624 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 625 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 626 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 627 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 628 // Basis action 629 switch (eval_mode) { 630 case CEED_EVAL_NONE: 631 break; // No action 632 case CEED_EVAL_INTERP: 633 case CEED_EVAL_GRAD: 634 case CEED_EVAL_DIV: 635 case CEED_EVAL_CURL: 636 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 637 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i], 638 impl->e_vecs[i + impl->num_inputs])); 639 break; 640 // LCOV_EXCL_START 641 case CEED_EVAL_WEIGHT: { 642 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 643 // LCOV_EXCL_STOP 644 } 645 } 646 } 647 648 // Output restriction 649 for (CeedInt i = 0; i < num_output_fields; i++) { 650 CeedEvalMode eval_mode; 651 CeedVector vec; 652 CeedElemRestriction elem_rstr; 653 654 // Restore evec 655 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 656 if (eval_mode == CEED_EVAL_NONE) { 657 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 658 } 659 // Get output vector 660 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 661 // Restrict 662 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 663 // Active 664 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 665 666 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request)); 667 } 668 669 // Restore input arrays 670 CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl)); 671 return CEED_ERROR_SUCCESS; 672 } 673 674 //------------------------------------------------------------------------------ 675 // Linear QFunction Assembly Core 676 //------------------------------------------------------------------------------ 677 static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 678 CeedRequest *request) { 679 Ceed ceed, ceed_parent; 680 CeedInt num_active_in, num_active_out, Q, num_elem, num_input_fields, num_output_fields, size; 681 CeedScalar *assembled_array, *e_data[2 * CEED_FIELD_MAX] = {NULL}; 682 CeedVector *active_inputs; 683 CeedQFunctionField *qf_input_fields, *qf_output_fields; 684 CeedQFunction qf; 685 CeedOperatorField *op_input_fields, *op_output_fields; 686 CeedOperator_Hip *impl; 687 688 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 689 CeedCallBackend(CeedOperatorGetFallbackParentCeed(op, &ceed_parent)); 690 CeedCallBackend(CeedOperatorGetData(op, &impl)); 691 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 692 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 693 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 694 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 695 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 696 active_inputs = impl->qf_active_in; 697 num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 698 699 // Setup 700 CeedCallBackend(CeedOperatorSetup_Hip(op)); 701 702 // Input Evecs and Restriction 703 CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 704 705 // Count number of active input fields 706 if (!num_active_in) { 707 for (CeedInt i = 0; i < num_input_fields; i++) { 708 CeedScalar *q_vec_array; 709 CeedVector vec; 710 711 // Get input vector 712 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 713 // Check if active input 714 if (vec == CEED_VECTOR_ACTIVE) { 715 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 716 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 717 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, &q_vec_array)); 718 CeedCallBackend(CeedRealloc(num_active_in + size, &active_inputs)); 719 for (CeedInt field = 0; field < size; field++) { 720 CeedSize q_size = (CeedSize)Q * num_elem; 721 722 CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_inputs[num_active_in + field])); 723 CeedCallBackend( 724 CeedVectorSetArray(active_inputs[num_active_in + field], CEED_MEM_DEVICE, CEED_USE_POINTER, &q_vec_array[field * Q * num_elem])); 725 } 726 num_active_in += size; 727 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array)); 728 } 729 } 730 impl->num_active_in = num_active_in; 731 impl->qf_active_in = active_inputs; 732 } 733 734 // Count number of active output fields 735 if (!num_active_out) { 736 for (CeedInt i = 0; i < num_output_fields; i++) { 737 CeedVector vec; 738 739 // Get output vector 740 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 741 // Check if active output 742 if (vec == CEED_VECTOR_ACTIVE) { 743 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 744 num_active_out += size; 745 } 746 } 747 impl->num_active_out = num_active_out; 748 } 749 750 // Check sizes 751 CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 752 753 // Build objects if needed 754 if (build_objects) { 755 CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out; 756 CeedInt strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */ 757 758 // Create output restriction 759 CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out, 760 (CeedSize)num_active_in * (CeedSize)num_active_out * (CeedSize)num_elem * (CeedSize)Q, strides, 761 rstr)); 762 // Create assembled vector 763 CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled)); 764 } 765 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 766 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_DEVICE, &assembled_array)); 767 768 // Input basis apply 769 CeedCallBackend(CeedOperatorInputBasis_Hip(num_elem, qf_input_fields, op_input_fields, num_input_fields, true, e_data, impl)); 770 771 // Assemble QFunction 772 for (CeedInt in = 0; in < num_active_in; in++) { 773 // Set Inputs 774 CeedCallBackend(CeedVectorSetValue(active_inputs[in], 1.0)); 775 if (num_active_in > 1) { 776 CeedCallBackend(CeedVectorSetValue(active_inputs[(in + num_active_in - 1) % num_active_in], 0.0)); 777 } 778 // Set Outputs 779 for (CeedInt out = 0; out < num_output_fields; out++) { 780 CeedVector vec; 781 782 // Get output vector 783 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 784 // Check if active output 785 if (vec == CEED_VECTOR_ACTIVE) { 786 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, CEED_USE_POINTER, assembled_array)); 787 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size)); 788 assembled_array += size * Q * num_elem; // Advance the pointer by the size of the output 789 } 790 } 791 // Apply QFunction 792 CeedCallBackend(CeedQFunctionApply(qf, Q * num_elem, impl->q_vecs_in, impl->q_vecs_out)); 793 } 794 795 // Un-set output q_vecs to prevent accidental overwrite of Assembled 796 for (CeedInt out = 0; out < num_output_fields; out++) { 797 CeedVector vec; 798 799 // Get output vector 800 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 801 // Check if active output 802 if (vec == CEED_VECTOR_ACTIVE) { 803 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, NULL)); 804 } 805 } 806 807 // Restore input arrays 808 CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 809 810 // Restore output 811 CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 812 return CEED_ERROR_SUCCESS; 813 } 814 815 //------------------------------------------------------------------------------ 816 // Assemble Linear QFunction 817 //------------------------------------------------------------------------------ 818 static int CeedOperatorLinearAssembleQFunction_Hip(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 819 return CeedOperatorLinearAssembleQFunctionCore_Hip(op, true, assembled, rstr, request); 820 } 821 822 //------------------------------------------------------------------------------ 823 // Update Assembled Linear QFunction 824 //------------------------------------------------------------------------------ 825 static int CeedOperatorLinearAssembleQFunctionUpdate_Hip(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 826 return CeedOperatorLinearAssembleQFunctionCore_Hip(op, false, &assembled, &rstr, request); 827 } 828 829 //------------------------------------------------------------------------------ 830 // Assemble Diagonal Setup 831 //------------------------------------------------------------------------------ 832 static inline int CeedOperatorAssembleDiagonalSetup_Hip(CeedOperator op) { 833 Ceed ceed; 834 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 835 CeedInt q_comp, num_nodes, num_qpts; 836 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL; 837 CeedBasis basis_in = NULL, basis_out = NULL; 838 CeedQFunctionField *qf_fields; 839 CeedQFunction qf; 840 CeedOperatorField *op_fields; 841 CeedOperator_Hip *impl; 842 843 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 844 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 845 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 846 847 // Determine active input basis 848 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 849 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 850 for (CeedInt i = 0; i < num_input_fields; i++) { 851 CeedVector vec; 852 853 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 854 if (vec == CEED_VECTOR_ACTIVE) { 855 CeedBasis basis; 856 CeedEvalMode eval_mode; 857 858 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 859 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, 860 "Backend does not implement operator diagonal assembly with multiple active bases"); 861 basis_in = basis; 862 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 863 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 864 if (eval_mode != CEED_EVAL_WEIGHT) { 865 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly 866 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in)); 867 for (CeedInt d = 0; d < q_comp; d++) eval_modes_in[num_eval_modes_in + d] = eval_mode; 868 num_eval_modes_in += q_comp; 869 } 870 } 871 } 872 873 // Determine active output basis 874 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 875 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 876 for (CeedInt i = 0; i < num_output_fields; i++) { 877 CeedVector vec; 878 879 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 880 if (vec == CEED_VECTOR_ACTIVE) { 881 CeedBasis basis; 882 CeedEvalMode eval_mode; 883 884 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 885 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND, 886 "Backend does not implement operator diagonal assembly with multiple active bases"); 887 basis_out = basis; 888 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 889 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 890 if (eval_mode != CEED_EVAL_WEIGHT) { 891 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly 892 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out)); 893 for (CeedInt d = 0; d < q_comp; d++) eval_modes_out[num_eval_modes_out + d] = eval_mode; 894 num_eval_modes_out += q_comp; 895 } 896 } 897 } 898 899 // Operator data struct 900 CeedCallBackend(CeedOperatorGetData(op, &impl)); 901 CeedCallBackend(CeedCalloc(1, &impl->diag)); 902 CeedOperatorDiag_Hip *diag = impl->diag; 903 904 // Basis matrices 905 CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes)); 906 if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes; 907 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts)); 908 const CeedInt interp_bytes = num_nodes * num_qpts * sizeof(CeedScalar); 909 const CeedInt eval_modes_bytes = sizeof(CeedEvalMode); 910 bool has_eval_none = false; 911 912 // CEED_EVAL_NONE 913 for (CeedInt i = 0; i < num_eval_modes_in; i++) has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE); 914 for (CeedInt i = 0; i < num_eval_modes_out; i++) has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE); 915 if (has_eval_none) { 916 CeedScalar *identity = NULL; 917 918 CeedCallBackend(CeedCalloc(num_nodes * num_qpts, &identity)); 919 for (CeedInt i = 0; i < (num_nodes < num_qpts ? num_nodes : num_qpts); i++) identity[i * num_nodes + i] = 1.0; 920 CeedCallHip(ceed, hipMalloc((void **)&diag->d_identity, interp_bytes)); 921 CeedCallHip(ceed, hipMemcpy(diag->d_identity, identity, interp_bytes, hipMemcpyHostToDevice)); 922 CeedCallBackend(CeedFree(&identity)); 923 } 924 925 // CEED_EVAL_INTERP, CEED_EVAL_GRAD, CEED_EVAL_DIV, and CEED_EVAL_CURL 926 for (CeedInt in = 0; in < 2; in++) { 927 CeedFESpace fespace; 928 CeedBasis basis = in ? basis_in : basis_out; 929 930 CeedCallBackend(CeedBasisGetFESpace(basis, &fespace)); 931 switch (fespace) { 932 case CEED_FE_SPACE_H1: { 933 CeedInt q_comp_interp, q_comp_grad; 934 const CeedScalar *interp, *grad; 935 CeedScalar *d_interp, *d_grad; 936 937 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 938 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_GRAD, &q_comp_grad)); 939 940 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 941 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 942 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice)); 943 CeedCallBackend(CeedBasisGetGrad(basis, &grad)); 944 CeedCallHip(ceed, hipMalloc((void **)&d_grad, interp_bytes * q_comp_grad)); 945 CeedCallHip(ceed, hipMemcpy(d_grad, grad, interp_bytes * q_comp_grad, hipMemcpyHostToDevice)); 946 if (in) { 947 diag->d_interp_in = d_interp; 948 diag->d_grad_in = d_grad; 949 } else { 950 diag->d_interp_out = d_interp; 951 diag->d_grad_out = d_grad; 952 } 953 } break; 954 case CEED_FE_SPACE_HDIV: { 955 CeedInt q_comp_interp, q_comp_div; 956 const CeedScalar *interp, *div; 957 CeedScalar *d_interp, *d_div; 958 959 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 960 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_DIV, &q_comp_div)); 961 962 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 963 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 964 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice)); 965 CeedCallBackend(CeedBasisGetDiv(basis, &div)); 966 CeedCallHip(ceed, hipMalloc((void **)&d_div, interp_bytes * q_comp_div)); 967 CeedCallHip(ceed, hipMemcpy(d_div, div, interp_bytes * q_comp_div, hipMemcpyHostToDevice)); 968 if (in) { 969 diag->d_interp_in = d_interp; 970 diag->d_div_in = d_div; 971 } else { 972 diag->d_interp_out = d_interp; 973 diag->d_div_out = d_div; 974 } 975 } break; 976 case CEED_FE_SPACE_HCURL: { 977 CeedInt q_comp_interp, q_comp_curl; 978 const CeedScalar *interp, *curl; 979 CeedScalar *d_interp, *d_curl; 980 981 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 982 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_CURL, &q_comp_curl)); 983 984 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 985 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 986 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice)); 987 CeedCallBackend(CeedBasisGetCurl(basis, &curl)); 988 CeedCallHip(ceed, hipMalloc((void **)&d_curl, interp_bytes * q_comp_curl)); 989 CeedCallHip(ceed, hipMemcpy(d_curl, curl, interp_bytes * q_comp_curl, hipMemcpyHostToDevice)); 990 if (in) { 991 diag->d_interp_in = d_interp; 992 diag->d_curl_in = d_curl; 993 } else { 994 diag->d_interp_out = d_interp; 995 diag->d_curl_out = d_curl; 996 } 997 } break; 998 } 999 } 1000 1001 // Arrays of eval_modes 1002 CeedCallHip(ceed, hipMalloc((void **)&diag->d_eval_modes_in, num_eval_modes_in * eval_modes_bytes)); 1003 CeedCallHip(ceed, hipMemcpy(diag->d_eval_modes_in, eval_modes_in, num_eval_modes_in * eval_modes_bytes, hipMemcpyHostToDevice)); 1004 CeedCallHip(ceed, hipMalloc((void **)&diag->d_eval_modes_out, num_eval_modes_out * eval_modes_bytes)); 1005 CeedCallHip(ceed, hipMemcpy(diag->d_eval_modes_out, eval_modes_out, num_eval_modes_out * eval_modes_bytes, hipMemcpyHostToDevice)); 1006 CeedCallBackend(CeedFree(&eval_modes_in)); 1007 CeedCallBackend(CeedFree(&eval_modes_out)); 1008 return CEED_ERROR_SUCCESS; 1009 } 1010 1011 //------------------------------------------------------------------------------ 1012 // Assemble Diagonal Setup (Compilation) 1013 //------------------------------------------------------------------------------ 1014 static inline int CeedOperatorAssembleDiagonalSetupCompile_Hip(CeedOperator op, CeedInt use_ceedsize_idx, const bool is_point_block) { 1015 Ceed ceed; 1016 char *diagonal_kernel_source; 1017 const char *diagonal_kernel_path; 1018 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 1019 CeedInt num_comp, q_comp, num_nodes, num_qpts; 1020 CeedBasis basis_in = NULL, basis_out = NULL; 1021 CeedQFunctionField *qf_fields; 1022 CeedQFunction qf; 1023 CeedOperatorField *op_fields; 1024 CeedOperator_Hip *impl; 1025 1026 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1027 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1028 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 1029 1030 // Determine active input basis 1031 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 1032 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 1033 for (CeedInt i = 0; i < num_input_fields; i++) { 1034 CeedVector vec; 1035 1036 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 1037 if (vec == CEED_VECTOR_ACTIVE) { 1038 CeedEvalMode eval_mode; 1039 1040 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis_in)); 1041 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1042 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 1043 if (eval_mode != CEED_EVAL_WEIGHT) { 1044 num_eval_modes_in += q_comp; 1045 } 1046 } 1047 } 1048 1049 // Determine active output basis 1050 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 1051 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 1052 for (CeedInt i = 0; i < num_output_fields; i++) { 1053 CeedVector vec; 1054 1055 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 1056 if (vec == CEED_VECTOR_ACTIVE) { 1057 CeedEvalMode eval_mode; 1058 1059 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis_out)); 1060 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1061 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 1062 if (eval_mode != CEED_EVAL_WEIGHT) { 1063 num_eval_modes_out += q_comp; 1064 } 1065 } 1066 } 1067 1068 // Operator data struct 1069 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1070 CeedOperatorDiag_Hip *diag = impl->diag; 1071 1072 // Assemble kernel 1073 hipModule_t *module = is_point_block ? &diag->module_point_block : &diag->module; 1074 CeedInt elems_per_block = 1; 1075 CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes)); 1076 CeedCallBackend(CeedBasisGetNumComponents(basis_in, &num_comp)); 1077 if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes; 1078 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts)); 1079 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h", &diagonal_kernel_path)); 1080 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Kernel Source -----\n"); 1081 CeedCallBackend(CeedLoadSourceToBuffer(ceed, diagonal_kernel_path, &diagonal_kernel_source)); 1082 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Source Complete! -----\n"); 1083 CeedCallHip(ceed, CeedCompile_Hip(ceed, diagonal_kernel_source, module, 8, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", 1084 num_eval_modes_out, "NUM_COMP", num_comp, "NUM_NODES", num_nodes, "NUM_QPTS", num_qpts, "USE_CEEDSIZE", 1085 use_ceedsize_idx, "USE_POINT_BLOCK", is_point_block ? 1 : 0, "BLOCK_SIZE", num_nodes * elems_per_block)); 1086 CeedCallHip(ceed, CeedGetKernel_Hip(ceed, *module, "LinearDiagonal", is_point_block ? &diag->LinearPointBlock : &diag->LinearDiagonal)); 1087 CeedCallBackend(CeedFree(&diagonal_kernel_path)); 1088 CeedCallBackend(CeedFree(&diagonal_kernel_source)); 1089 return CEED_ERROR_SUCCESS; 1090 } 1091 1092 //------------------------------------------------------------------------------ 1093 // Assemble Diagonal Core 1094 //------------------------------------------------------------------------------ 1095 static inline int CeedOperatorAssembleDiagonalCore_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request, const bool is_point_block) { 1096 Ceed ceed; 1097 CeedInt num_elem, num_nodes; 1098 CeedScalar *elem_diag_array; 1099 const CeedScalar *assembled_qf_array; 1100 CeedVector assembled_qf = NULL, elem_diag; 1101 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out, diag_rstr; 1102 CeedOperator_Hip *impl; 1103 1104 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1105 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1106 1107 // Assemble QFunction 1108 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, request)); 1109 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr)); 1110 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); 1111 1112 // Setup 1113 if (!impl->diag) CeedCallBackend(CeedOperatorAssembleDiagonalSetup_Hip(op)); 1114 CeedOperatorDiag_Hip *diag = impl->diag; 1115 1116 assert(diag != NULL); 1117 1118 // Assemble kernel if needed 1119 if ((!is_point_block && !diag->LinearDiagonal) || (is_point_block && !diag->LinearPointBlock)) { 1120 CeedSize assembled_length, assembled_qf_length; 1121 CeedInt use_ceedsize_idx = 0; 1122 CeedCallBackend(CeedVectorGetLength(assembled, &assembled_length)); 1123 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length)); 1124 if ((assembled_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1; 1125 1126 CeedCallBackend(CeedOperatorAssembleDiagonalSetupCompile_Hip(op, use_ceedsize_idx, is_point_block)); 1127 } 1128 1129 // Restriction and diagonal vector 1130 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out)); 1131 CeedCheck(rstr_in == rstr_out, ceed, CEED_ERROR_BACKEND, 1132 "Cannot assemble operator diagonal with different input and output active element restrictions"); 1133 if (!is_point_block && !diag->diag_rstr) { 1134 CeedCallBackend(CeedElemRestrictionCreateUnsignedCopy(rstr_out, &diag->diag_rstr)); 1135 CeedCallBackend(CeedElemRestrictionCreateVector(diag->diag_rstr, NULL, &diag->elem_diag)); 1136 } else if (is_point_block && !diag->point_block_diag_rstr) { 1137 CeedCallBackend(CeedOperatorCreateActivePointBlockRestriction(rstr_out, &diag->point_block_diag_rstr)); 1138 CeedCallBackend(CeedElemRestrictionCreateVector(diag->point_block_diag_rstr, NULL, &diag->point_block_elem_diag)); 1139 } 1140 diag_rstr = is_point_block ? diag->point_block_diag_rstr : diag->diag_rstr; 1141 elem_diag = is_point_block ? diag->point_block_elem_diag : diag->elem_diag; 1142 CeedCallBackend(CeedVectorSetValue(elem_diag, 0.0)); 1143 1144 // Only assemble diagonal if the basis has nodes, otherwise inputs are null pointers 1145 CeedCallBackend(CeedElemRestrictionGetElementSize(diag_rstr, &num_nodes)); 1146 if (num_nodes > 0) { 1147 // Assemble element operator diagonals 1148 CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array)); 1149 CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem)); 1150 1151 // Compute the diagonal of B^T D B 1152 CeedInt elems_per_block = 1; 1153 CeedInt grid = CeedDivUpInt(num_elem, elems_per_block); 1154 void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_div_in, 1155 &diag->d_curl_in, &diag->d_interp_out, &diag->d_grad_out, &diag->d_div_out, &diag->d_curl_out, 1156 &diag->d_eval_modes_in, &diag->d_eval_modes_out, &assembled_qf_array, &elem_diag_array}; 1157 1158 if (is_point_block) { 1159 CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->LinearPointBlock, grid, num_nodes, 1, elems_per_block, args)); 1160 } else { 1161 CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->LinearDiagonal, grid, num_nodes, 1, elems_per_block, args)); 1162 } 1163 1164 // Restore arrays 1165 CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array)); 1166 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); 1167 } 1168 1169 // Assemble local operator diagonal 1170 CeedCallBackend(CeedElemRestrictionApply(diag_rstr, CEED_TRANSPOSE, elem_diag, assembled, request)); 1171 1172 // Cleanup 1173 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1174 return CEED_ERROR_SUCCESS; 1175 } 1176 1177 //------------------------------------------------------------------------------ 1178 // Assemble Linear Diagonal 1179 //------------------------------------------------------------------------------ 1180 static int CeedOperatorLinearAssembleAddDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1181 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, false)); 1182 return CEED_ERROR_SUCCESS; 1183 } 1184 1185 //------------------------------------------------------------------------------ 1186 // Assemble Linear Point Block Diagonal 1187 //------------------------------------------------------------------------------ 1188 static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1189 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, true)); 1190 return CEED_ERROR_SUCCESS; 1191 } 1192 1193 //------------------------------------------------------------------------------ 1194 // Single Operator Assembly Setup 1195 //------------------------------------------------------------------------------ 1196 static int CeedSingleOperatorAssembleSetup_Hip(CeedOperator op, CeedInt use_ceedsize_idx) { 1197 Ceed ceed; 1198 char *assembly_kernel_source; 1199 const char *assembly_kernel_path; 1200 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 1201 CeedInt elem_size_in, num_qpts_in = 0, num_comp_in, elem_size_out, num_qpts_out, num_comp_out, q_comp; 1202 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL; 1203 CeedElemRestriction rstr_in = NULL, rstr_out = NULL; 1204 CeedBasis basis_in = NULL, basis_out = NULL; 1205 CeedQFunctionField *qf_fields; 1206 CeedQFunction qf; 1207 CeedOperatorField *input_fields, *output_fields; 1208 CeedOperator_Hip *impl; 1209 1210 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1211 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1212 1213 // Get intput and output fields 1214 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields)); 1215 1216 // Determine active input basis eval mode 1217 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1218 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 1219 for (CeedInt i = 0; i < num_input_fields; i++) { 1220 CeedVector vec; 1221 1222 CeedCallBackend(CeedOperatorFieldGetVector(input_fields[i], &vec)); 1223 if (vec == CEED_VECTOR_ACTIVE) { 1224 CeedBasis basis; 1225 CeedEvalMode eval_mode; 1226 1227 CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis)); 1228 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases"); 1229 basis_in = basis; 1230 CeedCallBackend(CeedOperatorFieldGetElemRestriction(input_fields[i], &rstr_in)); 1231 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in)); 1232 if (basis_in == CEED_BASIS_NONE) num_qpts_in = elem_size_in; 1233 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts_in)); 1234 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1235 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 1236 if (eval_mode != CEED_EVAL_WEIGHT) { 1237 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly 1238 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in)); 1239 for (CeedInt d = 0; d < q_comp; d++) { 1240 eval_modes_in[num_eval_modes_in + d] = eval_mode; 1241 } 1242 num_eval_modes_in += q_comp; 1243 } 1244 } 1245 } 1246 1247 // Determine active output basis; basis_out and rstr_out only used if same as input, TODO 1248 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 1249 for (CeedInt i = 0; i < num_output_fields; i++) { 1250 CeedVector vec; 1251 1252 CeedCallBackend(CeedOperatorFieldGetVector(output_fields[i], &vec)); 1253 if (vec == CEED_VECTOR_ACTIVE) { 1254 CeedBasis basis; 1255 CeedEvalMode eval_mode; 1256 1257 CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis)); 1258 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND, 1259 "Backend does not implement operator assembly with multiple active bases"); 1260 basis_out = basis; 1261 CeedCallBackend(CeedOperatorFieldGetElemRestriction(output_fields[i], &rstr_out)); 1262 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out)); 1263 if (basis_out == CEED_BASIS_NONE) num_qpts_out = elem_size_out; 1264 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_out, &num_qpts_out)); 1265 CeedCheck(num_qpts_in == num_qpts_out, ceed, CEED_ERROR_UNSUPPORTED, 1266 "Active input and output bases must have the same number of quadrature points"); 1267 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1268 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 1269 if (eval_mode != CEED_EVAL_WEIGHT) { 1270 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly 1271 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out)); 1272 for (CeedInt d = 0; d < q_comp; d++) { 1273 eval_modes_out[num_eval_modes_out + d] = eval_mode; 1274 } 1275 num_eval_modes_out += q_comp; 1276 } 1277 } 1278 } 1279 CeedCheck(num_eval_modes_in > 0 && num_eval_modes_out > 0, ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs"); 1280 1281 CeedCallBackend(CeedCalloc(1, &impl->asmb)); 1282 CeedOperatorAssemble_Hip *asmb = impl->asmb; 1283 asmb->elems_per_block = 1; 1284 asmb->block_size_x = elem_size_in; 1285 asmb->block_size_y = elem_size_out; 1286 1287 bool fallback = asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block > 1024; 1288 1289 if (fallback) { 1290 // Use fallback kernel with 1D threadblock 1291 asmb->block_size_y = 1; 1292 } 1293 1294 // Compile kernels 1295 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &num_comp_in)); 1296 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_out, &num_comp_out)); 1297 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-operator-assemble.h", &assembly_kernel_path)); 1298 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Kernel Source -----\n"); 1299 CeedCallBackend(CeedLoadSourceToBuffer(ceed, assembly_kernel_path, &assembly_kernel_source)); 1300 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Source Complete! -----\n"); 1301 CeedCallBackend(CeedCompile_Hip(ceed, assembly_kernel_source, &asmb->module, 10, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", 1302 num_eval_modes_out, "NUM_COMP_IN", num_comp_in, "NUM_COMP_OUT", num_comp_out, "NUM_NODES_IN", elem_size_in, 1303 "NUM_NODES_OUT", elem_size_out, "NUM_QPTS", num_qpts_in, "BLOCK_SIZE", 1304 asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block, "BLOCK_SIZE_Y", asmb->block_size_y, "USE_CEEDSIZE", 1305 use_ceedsize_idx)); 1306 CeedCallBackend(CeedGetKernel_Hip(ceed, asmb->module, "LinearAssemble", &asmb->LinearAssemble)); 1307 CeedCallBackend(CeedFree(&assembly_kernel_path)); 1308 CeedCallBackend(CeedFree(&assembly_kernel_source)); 1309 1310 // Load into B_in, in order that they will be used in eval_modes_in 1311 { 1312 const CeedInt in_bytes = elem_size_in * num_qpts_in * num_eval_modes_in * sizeof(CeedScalar); 1313 CeedInt d_in = 0; 1314 CeedEvalMode eval_modes_in_prev = CEED_EVAL_NONE; 1315 bool has_eval_none = false; 1316 CeedScalar *identity = NULL; 1317 1318 for (CeedInt i = 0; i < num_eval_modes_in; i++) { 1319 has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE); 1320 } 1321 if (has_eval_none) { 1322 CeedCallBackend(CeedCalloc(elem_size_in * num_qpts_in, &identity)); 1323 for (CeedInt i = 0; i < (elem_size_in < num_qpts_in ? elem_size_in : num_qpts_in); i++) identity[i * elem_size_in + i] = 1.0; 1324 } 1325 1326 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_in, in_bytes)); 1327 for (CeedInt i = 0; i < num_eval_modes_in; i++) { 1328 const CeedScalar *h_B_in; 1329 1330 CeedCallBackend(CeedOperatorGetBasisPointer(basis_in, eval_modes_in[i], identity, &h_B_in)); 1331 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_modes_in[i], &q_comp)); 1332 if (q_comp > 1) { 1333 if (i == 0 || eval_modes_in[i] != eval_modes_in_prev) d_in = 0; 1334 else h_B_in = &h_B_in[(++d_in) * elem_size_in * num_qpts_in]; 1335 } 1336 eval_modes_in_prev = eval_modes_in[i]; 1337 1338 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_in[i * elem_size_in * num_qpts_in], h_B_in, elem_size_in * num_qpts_in * sizeof(CeedScalar), 1339 hipMemcpyHostToDevice)); 1340 } 1341 1342 if (identity) { 1343 CeedCallBackend(CeedFree(&identity)); 1344 } 1345 } 1346 1347 // Load into B_out, in order that they will be used in eval_modes_out 1348 { 1349 const CeedInt out_bytes = elem_size_out * num_qpts_out * num_eval_modes_out * sizeof(CeedScalar); 1350 CeedInt d_out = 0; 1351 CeedEvalMode eval_modes_out_prev = CEED_EVAL_NONE; 1352 bool has_eval_none = false; 1353 CeedScalar *identity = NULL; 1354 1355 for (CeedInt i = 0; i < num_eval_modes_out; i++) { 1356 has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE); 1357 } 1358 if (has_eval_none) { 1359 CeedCallBackend(CeedCalloc(elem_size_out * num_qpts_out, &identity)); 1360 for (CeedInt i = 0; i < (elem_size_out < num_qpts_out ? elem_size_out : num_qpts_out); i++) identity[i * elem_size_out + i] = 1.0; 1361 } 1362 1363 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_out, out_bytes)); 1364 for (CeedInt i = 0; i < num_eval_modes_out; i++) { 1365 const CeedScalar *h_B_out; 1366 1367 CeedCallBackend(CeedOperatorGetBasisPointer(basis_out, eval_modes_out[i], identity, &h_B_out)); 1368 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_modes_out[i], &q_comp)); 1369 if (q_comp > 1) { 1370 if (i == 0 || eval_modes_out[i] != eval_modes_out_prev) d_out = 0; 1371 else h_B_out = &h_B_out[(++d_out) * elem_size_out * num_qpts_out]; 1372 } 1373 eval_modes_out_prev = eval_modes_out[i]; 1374 1375 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_out[i * elem_size_out * num_qpts_out], h_B_out, elem_size_out * num_qpts_out * sizeof(CeedScalar), 1376 hipMemcpyHostToDevice)); 1377 } 1378 1379 if (identity) { 1380 CeedCallBackend(CeedFree(&identity)); 1381 } 1382 } 1383 return CEED_ERROR_SUCCESS; 1384 } 1385 1386 //------------------------------------------------------------------------------ 1387 // Assemble matrix data for COO matrix of assembled operator. 1388 // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic. 1389 // 1390 // Note that this (and other assembly routines) currently assume only one active input restriction/basis per operator (could have multiple basis eval 1391 // modes). 1392 // TODO: allow multiple active input restrictions/basis objects 1393 //------------------------------------------------------------------------------ 1394 static int CeedSingleOperatorAssemble_Hip(CeedOperator op, CeedInt offset, CeedVector values) { 1395 Ceed ceed; 1396 CeedSize values_length = 0, assembled_qf_length = 0; 1397 CeedInt use_ceedsize_idx = 0, num_elem_in, num_elem_out, elem_size_in, elem_size_out; 1398 CeedScalar *values_array; 1399 const CeedScalar *assembled_qf_array; 1400 CeedVector assembled_qf = NULL; 1401 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out; 1402 CeedRestrictionType rstr_type_in, rstr_type_out; 1403 const bool *orients_in = NULL, *orients_out = NULL; 1404 const CeedInt8 *curl_orients_in = NULL, *curl_orients_out = NULL; 1405 CeedOperator_Hip *impl; 1406 1407 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1408 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1409 1410 // Assemble QFunction 1411 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, CEED_REQUEST_IMMEDIATE)); 1412 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr)); 1413 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); 1414 1415 CeedCallBackend(CeedVectorGetLength(values, &values_length)); 1416 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length)); 1417 if ((values_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1; 1418 1419 // Setup 1420 if (!impl->asmb) CeedCallBackend(CeedSingleOperatorAssembleSetup_Hip(op, use_ceedsize_idx)); 1421 CeedOperatorAssemble_Hip *asmb = impl->asmb; 1422 1423 assert(asmb != NULL); 1424 1425 // Assemble element operator 1426 CeedCallBackend(CeedVectorGetArray(values, CEED_MEM_DEVICE, &values_array)); 1427 values_array += offset; 1428 1429 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out)); 1430 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, &num_elem_in)); 1431 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in)); 1432 1433 CeedCallBackend(CeedElemRestrictionGetType(rstr_in, &rstr_type_in)); 1434 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) { 1435 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_in, CEED_MEM_DEVICE, &orients_in)); 1436 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) { 1437 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_in, CEED_MEM_DEVICE, &curl_orients_in)); 1438 } 1439 1440 if (rstr_in != rstr_out) { 1441 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_out, &num_elem_out)); 1442 CeedCheck(num_elem_in == num_elem_out, ceed, CEED_ERROR_UNSUPPORTED, 1443 "Active input and output operator restrictions must have the same number of elements"); 1444 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out)); 1445 1446 CeedCallBackend(CeedElemRestrictionGetType(rstr_out, &rstr_type_out)); 1447 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) { 1448 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_out, CEED_MEM_DEVICE, &orients_out)); 1449 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) { 1450 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_out, CEED_MEM_DEVICE, &curl_orients_out)); 1451 } 1452 } else { 1453 elem_size_out = elem_size_in; 1454 orients_out = orients_in; 1455 curl_orients_out = curl_orients_in; 1456 } 1457 1458 // Compute B^T D B 1459 CeedInt shared_mem = 1460 ((curl_orients_in || curl_orients_out ? elem_size_in * elem_size_out : 0) + (curl_orients_in ? elem_size_in * asmb->block_size_y : 0)) * 1461 sizeof(CeedScalar); 1462 CeedInt grid = CeedDivUpInt(num_elem_in, asmb->elems_per_block); 1463 void *args[] = {(void *)&num_elem_in, &asmb->d_B_in, &asmb->d_B_out, &orients_in, &curl_orients_in, 1464 &orients_out, &curl_orients_out, &assembled_qf_array, &values_array}; 1465 1466 CeedCallBackend( 1467 CeedRunKernelDimShared_Hip(ceed, asmb->LinearAssemble, grid, asmb->block_size_x, asmb->block_size_y, asmb->elems_per_block, shared_mem, args)); 1468 1469 // Restore arrays 1470 CeedCallBackend(CeedVectorRestoreArray(values, &values_array)); 1471 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); 1472 1473 // Cleanup 1474 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1475 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) { 1476 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_in, &orients_in)); 1477 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) { 1478 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_in, &curl_orients_in)); 1479 } 1480 if (rstr_in != rstr_out) { 1481 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) { 1482 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_out, &orients_out)); 1483 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) { 1484 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_out, &curl_orients_out)); 1485 } 1486 } 1487 return CEED_ERROR_SUCCESS; 1488 } 1489 1490 //------------------------------------------------------------------------------ 1491 // Assemble Linear QFunction AtPoints 1492 //------------------------------------------------------------------------------ 1493 static int CeedOperatorLinearAssembleQFunctionAtPoints_Hip(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 1494 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "Backend does not implement CeedOperatorLinearAssembleQFunction"); 1495 } 1496 1497 //------------------------------------------------------------------------------ 1498 // Assemble Linear Diagonal AtPoints 1499 //------------------------------------------------------------------------------ 1500 static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1501 CeedInt max_num_points, num_elem, num_input_fields, num_output_fields; 1502 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 1503 CeedQFunctionField *qf_input_fields, *qf_output_fields; 1504 CeedQFunction qf; 1505 CeedOperatorField *op_input_fields, *op_output_fields; 1506 CeedOperator_Hip *impl; 1507 1508 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1509 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1510 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 1511 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 1512 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 1513 CeedInt num_points[num_elem]; 1514 1515 // Setup 1516 CeedCallBackend(CeedOperatorSetupAtPoints_Hip(op)); 1517 max_num_points = impl->max_num_points; 1518 for (CeedInt i = 0; i < num_elem; i++) num_points[i] = max_num_points; 1519 1520 // Input Evecs and Restriction 1521 CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 1522 1523 // Get point coordinates 1524 if (!impl->point_coords_elem) { 1525 CeedVector point_coords = NULL; 1526 CeedElemRestriction rstr_points = NULL; 1527 1528 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 1529 CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem)); 1530 CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 1531 } 1532 1533 // Clear active input Qvecs 1534 for (CeedInt i = 0; i < num_input_fields; i++) { 1535 CeedVector vec; 1536 1537 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1538 if (vec != CEED_VECTOR_ACTIVE) continue; 1539 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 1540 } 1541 1542 // Input basis apply if needed 1543 CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(num_elem, num_points, qf_input_fields, op_input_fields, num_input_fields, true, e_data, impl)); 1544 1545 // Output pointers, as necessary 1546 for (CeedInt i = 0; i < num_output_fields; i++) { 1547 CeedEvalMode eval_mode; 1548 1549 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1550 if (eval_mode == CEED_EVAL_NONE) { 1551 // Set the output Q-Vector to use the E-Vector data directly. 1552 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 1553 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 1554 } 1555 } 1556 1557 // Loop over active fields 1558 for (CeedInt i = 0; i < num_input_fields; i++) { 1559 bool is_active_at_points = true; 1560 CeedInt elem_size = 1, num_comp_active = 1, e_vec_size = 0; 1561 CeedRestrictionType rstr_type; 1562 CeedVector vec; 1563 CeedElemRestriction elem_rstr; 1564 1565 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1566 // -- Skip non-active input 1567 if (vec != CEED_VECTOR_ACTIVE) continue; 1568 1569 // -- Get active restriction type 1570 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 1571 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 1572 is_active_at_points = rstr_type == CEED_RESTRICTION_POINTS; 1573 if (!is_active_at_points) CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 1574 else elem_size = max_num_points; 1575 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp_active)); 1576 1577 e_vec_size = elem_size * num_comp_active; 1578 for (CeedInt s = 0; s < e_vec_size; s++) { 1579 bool is_active_input = false; 1580 CeedEvalMode eval_mode; 1581 CeedVector vec; 1582 CeedBasis basis; 1583 1584 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1585 // Skip non-active input 1586 is_active_input = vec == CEED_VECTOR_ACTIVE; 1587 if (!is_active_input) continue; 1588 1589 // Update unit vector 1590 if (s == 0) CeedCallBackend(CeedVectorSetValue(impl->e_vecs[i], 0.0)); 1591 else CeedCallBackend(CeedVectorSetValueStrided(impl->e_vecs[i], s - 1, e_vec_size, 0.0)); 1592 CeedCallBackend(CeedVectorSetValueStrided(impl->e_vecs[i], s, e_vec_size, 1.0)); 1593 1594 // Basis action 1595 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 1596 switch (eval_mode) { 1597 case CEED_EVAL_NONE: 1598 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 1599 break; 1600 case CEED_EVAL_INTERP: 1601 case CEED_EVAL_GRAD: 1602 case CEED_EVAL_DIV: 1603 case CEED_EVAL_CURL: 1604 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 1605 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, impl->e_vecs[i], 1606 impl->q_vecs_in[i])); 1607 break; 1608 case CEED_EVAL_WEIGHT: 1609 break; // No action 1610 } 1611 1612 // Q function 1613 CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out)); 1614 1615 // Output basis apply if needed 1616 for (CeedInt j = 0; j < num_output_fields; j++) { 1617 bool is_active_output = false; 1618 CeedInt elem_size = 0; 1619 CeedRestrictionType rstr_type; 1620 CeedEvalMode eval_mode; 1621 CeedVector vec; 1622 CeedElemRestriction elem_rstr; 1623 CeedBasis basis; 1624 1625 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec)); 1626 // ---- Skip non-active output 1627 is_active_output = vec == CEED_VECTOR_ACTIVE; 1628 if (!is_active_output) continue; 1629 1630 // ---- Check if elem size matches 1631 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr)); 1632 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 1633 if (is_active_at_points && rstr_type != CEED_RESTRICTION_POINTS) continue; 1634 if (rstr_type == CEED_RESTRICTION_POINTS) { 1635 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(elem_rstr, &elem_size)); 1636 } else { 1637 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 1638 } 1639 { 1640 CeedInt num_comp = 0; 1641 1642 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp)); 1643 if (e_vec_size != num_comp * elem_size) continue; 1644 } 1645 1646 // Basis action 1647 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode)); 1648 switch (eval_mode) { 1649 case CEED_EVAL_NONE: 1650 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[j + impl->num_inputs], &e_data[j + num_input_fields])); 1651 break; 1652 case CEED_EVAL_INTERP: 1653 case CEED_EVAL_GRAD: 1654 case CEED_EVAL_DIV: 1655 case CEED_EVAL_CURL: 1656 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[j], &basis)); 1657 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, 1658 impl->q_vecs_out[j], impl->e_vecs[j + impl->num_inputs])); 1659 break; 1660 // LCOV_EXCL_START 1661 case CEED_EVAL_WEIGHT: { 1662 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 1663 // LCOV_EXCL_STOP 1664 } 1665 } 1666 1667 // Mask output e-vec 1668 CeedCallBackend(CeedVectorPointwiseMult(impl->e_vecs[j + impl->num_inputs], impl->e_vecs[i], impl->e_vecs[j + impl->num_inputs])); 1669 1670 // Restrict 1671 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr)); 1672 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[j + impl->num_inputs], assembled, request)); 1673 1674 // Reset q_vec for 1675 if (eval_mode == CEED_EVAL_NONE) { 1676 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[j + impl->num_inputs], CEED_MEM_DEVICE, &e_data[j + num_input_fields])); 1677 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[j], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[j + num_input_fields])); 1678 } 1679 } 1680 1681 // Reset vec 1682 if (s == e_vec_size - 1 && i != num_input_fields - 1) CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 1683 } 1684 } 1685 1686 // Restore CEED_EVAL_NONE 1687 for (CeedInt i = 0; i < num_output_fields; i++) { 1688 CeedEvalMode eval_mode; 1689 1690 // Get eval_mode 1691 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1692 1693 // Restore evec 1694 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1695 if (eval_mode == CEED_EVAL_NONE) { 1696 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 1697 } 1698 } 1699 1700 // Restore input arrays 1701 CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 1702 return CEED_ERROR_SUCCESS; 1703 } 1704 1705 //------------------------------------------------------------------------------ 1706 // Create operator 1707 //------------------------------------------------------------------------------ 1708 int CeedOperatorCreate_Hip(CeedOperator op) { 1709 Ceed ceed; 1710 CeedOperator_Hip *impl; 1711 1712 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1713 CeedCallBackend(CeedCalloc(1, &impl)); 1714 CeedCallBackend(CeedOperatorSetData(op, impl)); 1715 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Hip)); 1716 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Hip)); 1717 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonal_Hip)); 1718 CeedCallBackend( 1719 CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal", CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip)); 1720 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssemble_Hip)); 1721 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip)); 1722 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip)); 1723 return CEED_ERROR_SUCCESS; 1724 } 1725 1726 //------------------------------------------------------------------------------ 1727 // Create operator AtPoints 1728 //------------------------------------------------------------------------------ 1729 int CeedOperatorCreateAtPoints_Hip(CeedOperator op) { 1730 Ceed ceed; 1731 CeedOperator_Hip *impl; 1732 1733 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1734 CeedCallBackend(CeedCalloc(1, &impl)); 1735 CeedCallBackend(CeedOperatorSetData(op, impl)); 1736 1737 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunctionAtPoints_Hip)); 1738 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip)); 1739 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAddAtPoints_Hip)); 1740 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip)); 1741 return CEED_ERROR_SUCCESS; 1742 } 1743 1744 //------------------------------------------------------------------------------ 1745