1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <assert.h> 12 #include <stdbool.h> 13 #include <string.h> 14 #include <hip/hip_runtime.h> 15 16 #include "../hip/ceed-hip-common.h" 17 #include "../hip/ceed-hip-compile.h" 18 #include "ceed-hip-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Destroy operator 22 //------------------------------------------------------------------------------ 23 static int CeedOperatorDestroy_Hip(CeedOperator op) { 24 CeedOperator_Hip *impl; 25 26 CeedCallBackend(CeedOperatorGetData(op, &impl)); 27 28 // Apply data 29 for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { 30 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[i])); 31 } 32 CeedCallBackend(CeedFree(&impl->e_vecs)); 33 CeedCallBackend(CeedFree(&impl->input_states)); 34 35 for (CeedInt i = 0; i < impl->num_inputs; i++) { 36 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i])); 37 } 38 CeedCallBackend(CeedFree(&impl->q_vecs_in)); 39 40 for (CeedInt i = 0; i < impl->num_outputs; i++) { 41 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 42 } 43 CeedCallBackend(CeedFree(&impl->q_vecs_out)); 44 CeedCallBackend(CeedVectorDestroy(&impl->point_coords_elem)); 45 46 // QFunction assembly data 47 for (CeedInt i = 0; i < impl->num_active_in; i++) { 48 CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i])); 49 } 50 CeedCallBackend(CeedFree(&impl->qf_active_in)); 51 52 // Diag data 53 if (impl->diag) { 54 Ceed ceed; 55 56 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 57 if (impl->diag->module) { 58 CeedCallHip(ceed, hipModuleUnload(impl->diag->module)); 59 } 60 if (impl->diag->module_point_block) { 61 CeedCallHip(ceed, hipModuleUnload(impl->diag->module_point_block)); 62 } 63 CeedCallHip(ceed, hipFree(impl->diag->d_eval_modes_in)); 64 CeedCallHip(ceed, hipFree(impl->diag->d_eval_modes_out)); 65 CeedCallHip(ceed, hipFree(impl->diag->d_identity)); 66 CeedCallHip(ceed, hipFree(impl->diag->d_interp_in)); 67 CeedCallHip(ceed, hipFree(impl->diag->d_interp_out)); 68 CeedCallHip(ceed, hipFree(impl->diag->d_grad_in)); 69 CeedCallHip(ceed, hipFree(impl->diag->d_grad_out)); 70 CeedCallHip(ceed, hipFree(impl->diag->d_div_in)); 71 CeedCallHip(ceed, hipFree(impl->diag->d_div_out)); 72 CeedCallHip(ceed, hipFree(impl->diag->d_curl_in)); 73 CeedCallHip(ceed, hipFree(impl->diag->d_curl_out)); 74 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr)); 75 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr)); 76 CeedCallBackend(CeedVectorDestroy(&impl->diag->elem_diag)); 77 CeedCallBackend(CeedVectorDestroy(&impl->diag->point_block_elem_diag)); 78 } 79 CeedCallBackend(CeedFree(&impl->diag)); 80 81 if (impl->asmb) { 82 Ceed ceed; 83 84 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 85 CeedCallHip(ceed, hipModuleUnload(impl->asmb->module)); 86 CeedCallHip(ceed, hipFree(impl->asmb->d_B_in)); 87 CeedCallHip(ceed, hipFree(impl->asmb->d_B_out)); 88 } 89 CeedCallBackend(CeedFree(&impl->asmb)); 90 91 CeedCallBackend(CeedFree(&impl)); 92 return CEED_ERROR_SUCCESS; 93 } 94 95 //------------------------------------------------------------------------------ 96 // Setup infields or outfields 97 //------------------------------------------------------------------------------ 98 static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, CeedVector *e_vecs, CeedVector *q_vecs, 99 CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) { 100 Ceed ceed; 101 CeedQFunctionField *qf_fields; 102 CeedOperatorField *op_fields; 103 104 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 105 if (is_input) { 106 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 107 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 108 } else { 109 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 110 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 111 } 112 113 // Loop over fields 114 for (CeedInt i = 0; i < num_fields; i++) { 115 bool is_strided = false, skip_restriction = false; 116 CeedSize q_size; 117 CeedInt size; 118 CeedEvalMode eval_mode; 119 CeedBasis basis; 120 121 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 122 if (eval_mode != CEED_EVAL_WEIGHT) { 123 CeedElemRestriction elem_rstr; 124 125 // Check whether this field can skip the element restriction: 126 // Must be passive input, with eval_mode NONE, and have a strided restriction with CEED_STRIDES_BACKEND. 127 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); 128 129 // First, check whether the field is input or output: 130 if (is_input) { 131 CeedVector vec; 132 133 // Check for passive input 134 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 135 if (vec != CEED_VECTOR_ACTIVE) { 136 // Check eval_mode 137 if (eval_mode == CEED_EVAL_NONE) { 138 // Check for strided restriction 139 CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided)); 140 if (is_strided) { 141 // Check if vector is already in preferred backend ordering 142 CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &skip_restriction)); 143 } 144 } 145 } 146 } 147 if (skip_restriction) { 148 // We do not need an E-Vector, but will use the input field vector's data directly in the operator application. 149 e_vecs[i + start_e] = NULL; 150 } else { 151 CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs[i + start_e])); 152 } 153 } 154 155 switch (eval_mode) { 156 case CEED_EVAL_NONE: 157 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 158 q_size = (CeedSize)num_elem * Q * size; 159 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 160 break; 161 case CEED_EVAL_INTERP: 162 case CEED_EVAL_GRAD: 163 case CEED_EVAL_DIV: 164 case CEED_EVAL_CURL: 165 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 166 q_size = (CeedSize)num_elem * Q * size; 167 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 168 break; 169 case CEED_EVAL_WEIGHT: // Only on input fields 170 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 171 q_size = (CeedSize)num_elem * Q; 172 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 173 if (is_at_points) { 174 CeedInt num_points[num_elem]; 175 176 for (CeedInt i = 0; i < num_elem; i++) num_points[i] = Q; 177 CeedCallBackend( 178 CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, CEED_VECTOR_NONE, q_vecs[i])); 179 } else { 180 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); 181 } 182 break; 183 } 184 } 185 return CEED_ERROR_SUCCESS; 186 } 187 188 //------------------------------------------------------------------------------ 189 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction. 190 //------------------------------------------------------------------------------ 191 static int CeedOperatorSetup_Hip(CeedOperator op) { 192 Ceed ceed; 193 bool is_setup_done; 194 CeedInt Q, num_elem, num_input_fields, num_output_fields; 195 CeedQFunctionField *qf_input_fields, *qf_output_fields; 196 CeedQFunction qf; 197 CeedOperatorField *op_input_fields, *op_output_fields; 198 CeedOperator_Hip *impl; 199 200 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 201 if (is_setup_done) return CEED_ERROR_SUCCESS; 202 203 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 204 CeedCallBackend(CeedOperatorGetData(op, &impl)); 205 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 206 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 207 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 208 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 209 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 210 211 // Allocate 212 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); 213 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 214 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 215 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 216 impl->num_inputs = num_input_fields; 217 impl->num_outputs = num_output_fields; 218 219 // Set up infield and outfield e_vecs and q_vecs 220 // Infields 221 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, false, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem)); 222 // Outfields 223 CeedCallBackend( 224 CeedOperatorSetupFields_Hip(qf, op, false, false, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, Q, num_elem)); 225 226 CeedCallBackend(CeedOperatorSetSetupDone(op)); 227 return CEED_ERROR_SUCCESS; 228 } 229 230 //------------------------------------------------------------------------------ 231 // Setup Operator Inputs 232 //------------------------------------------------------------------------------ 233 static inline int CeedOperatorSetupInputs_Hip(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 234 CeedVector in_vec, const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], 235 CeedOperator_Hip *impl, CeedRequest *request) { 236 for (CeedInt i = 0; i < num_input_fields; i++) { 237 CeedEvalMode eval_mode; 238 CeedVector vec; 239 CeedElemRestriction elem_rstr; 240 241 // Get input vector 242 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 243 if (vec == CEED_VECTOR_ACTIVE) { 244 if (skip_active) continue; 245 else vec = in_vec; 246 } 247 248 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 249 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 250 } else { 251 // Get input vector 252 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 253 // Get input element restriction 254 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 255 if (vec == CEED_VECTOR_ACTIVE) vec = in_vec; 256 // Restrict, if necessary 257 if (!impl->e_vecs[i]) { 258 // No restriction for this field; read data directly from vec. 259 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i])); 260 } else { 261 uint64_t state; 262 263 CeedCallBackend(CeedVectorGetState(vec, &state)); 264 if (state != impl->input_states[i]) { 265 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, vec, impl->e_vecs[i], request)); 266 impl->input_states[i] = state; 267 } 268 // Get evec 269 CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs[i], CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i])); 270 } 271 } 272 } 273 return CEED_ERROR_SUCCESS; 274 } 275 276 //------------------------------------------------------------------------------ 277 // Input Basis Action 278 //------------------------------------------------------------------------------ 279 static inline int CeedOperatorInputBasis_Hip(CeedInt num_elem, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 280 CeedInt num_input_fields, const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], 281 CeedOperator_Hip *impl) { 282 for (CeedInt i = 0; i < num_input_fields; i++) { 283 CeedInt elem_size, size; 284 CeedEvalMode eval_mode; 285 CeedElemRestriction elem_rstr; 286 CeedBasis basis; 287 288 // Skip active input 289 if (skip_active) { 290 CeedVector vec; 291 292 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 293 if (vec == CEED_VECTOR_ACTIVE) continue; 294 } 295 // Get elem_size, eval_mode, size 296 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 297 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 298 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 299 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 300 // Basis action 301 switch (eval_mode) { 302 case CEED_EVAL_NONE: 303 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 304 break; 305 case CEED_EVAL_INTERP: 306 case CEED_EVAL_GRAD: 307 case CEED_EVAL_DIV: 308 case CEED_EVAL_CURL: 309 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 310 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs[i], impl->q_vecs_in[i])); 311 break; 312 case CEED_EVAL_WEIGHT: 313 break; // No action 314 } 315 } 316 return CEED_ERROR_SUCCESS; 317 } 318 319 //------------------------------------------------------------------------------ 320 // Restore Input Vectors 321 //------------------------------------------------------------------------------ 322 static inline int CeedOperatorRestoreInputs_Hip(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 323 const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Hip *impl) { 324 for (CeedInt i = 0; i < num_input_fields; i++) { 325 CeedEvalMode eval_mode; 326 CeedVector vec; 327 328 // Skip active input 329 if (skip_active) { 330 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 331 if (vec == CEED_VECTOR_ACTIVE) continue; 332 } 333 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 334 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 335 } else { 336 if (!impl->e_vecs[i]) { // This was a skip_restriction case 337 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 338 CeedCallBackend(CeedVectorRestoreArrayRead(vec, (const CeedScalar **)&e_data[i])); 339 } else { 340 CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs[i], (const CeedScalar **)&e_data[i])); 341 } 342 } 343 } 344 return CEED_ERROR_SUCCESS; 345 } 346 347 //------------------------------------------------------------------------------ 348 // Apply and add to output 349 //------------------------------------------------------------------------------ 350 static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 351 CeedInt Q, num_elem, elem_size, num_input_fields, num_output_fields, size; 352 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 353 CeedQFunctionField *qf_input_fields, *qf_output_fields; 354 CeedQFunction qf; 355 CeedOperatorField *op_input_fields, *op_output_fields; 356 CeedOperator_Hip *impl; 357 358 CeedCallBackend(CeedOperatorGetData(op, &impl)); 359 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 360 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 361 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 362 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 363 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 364 365 // Setup 366 CeedCallBackend(CeedOperatorSetup_Hip(op)); 367 368 // Input Evecs and Restriction 369 CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data, impl, request)); 370 371 // Input basis apply if needed 372 CeedCallBackend(CeedOperatorInputBasis_Hip(num_elem, qf_input_fields, op_input_fields, num_input_fields, false, e_data, impl)); 373 374 // Output pointers, as necessary 375 for (CeedInt i = 0; i < num_output_fields; i++) { 376 CeedEvalMode eval_mode; 377 378 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 379 if (eval_mode == CEED_EVAL_NONE) { 380 // Set the output Q-Vector to use the E-Vector data directly. 381 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 382 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 383 } 384 } 385 386 // Q function 387 CeedCallBackend(CeedQFunctionApply(qf, num_elem * Q, impl->q_vecs_in, impl->q_vecs_out)); 388 389 // Output basis apply if needed 390 for (CeedInt i = 0; i < num_output_fields; i++) { 391 CeedEvalMode eval_mode; 392 CeedElemRestriction elem_rstr; 393 CeedBasis basis; 394 395 // Get elem_size, eval_mode, size 396 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 397 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 398 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 399 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 400 // Basis action 401 switch (eval_mode) { 402 case CEED_EVAL_NONE: 403 break; // No action 404 case CEED_EVAL_INTERP: 405 case CEED_EVAL_GRAD: 406 case CEED_EVAL_DIV: 407 case CEED_EVAL_CURL: 408 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 409 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); 410 break; 411 // LCOV_EXCL_START 412 case CEED_EVAL_WEIGHT: { 413 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 414 // LCOV_EXCL_STOP 415 } 416 } 417 } 418 419 // Output restriction 420 for (CeedInt i = 0; i < num_output_fields; i++) { 421 CeedEvalMode eval_mode; 422 CeedVector vec; 423 CeedElemRestriction elem_rstr; 424 425 // Restore evec 426 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 427 if (eval_mode == CEED_EVAL_NONE) { 428 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 429 } 430 // Get output vector 431 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 432 // Restrict 433 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 434 // Active 435 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 436 437 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request)); 438 } 439 440 // Restore input arrays 441 CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl)); 442 return CEED_ERROR_SUCCESS; 443 } 444 445 //------------------------------------------------------------------------------ 446 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction. 447 //------------------------------------------------------------------------------ 448 static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) { 449 Ceed ceed; 450 bool is_setup_done; 451 CeedInt max_num_points = -1, num_elem, num_input_fields, num_output_fields; 452 CeedQFunctionField *qf_input_fields, *qf_output_fields; 453 CeedQFunction qf; 454 CeedOperatorField *op_input_fields, *op_output_fields; 455 CeedOperator_Hip *impl; 456 457 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 458 if (is_setup_done) return CEED_ERROR_SUCCESS; 459 460 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 461 CeedCallBackend(CeedOperatorGetData(op, &impl)); 462 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 463 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 464 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 465 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 466 { 467 CeedElemRestriction elem_rstr = NULL; 468 469 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &elem_rstr, NULL)); 470 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(elem_rstr, &max_num_points)); 471 } 472 impl->max_num_points = max_num_points; 473 474 // Allocate 475 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); 476 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 477 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 478 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 479 impl->num_inputs = num_input_fields; 480 impl->num_outputs = num_output_fields; 481 482 // Set up infield and outfield e_vecs and q_vecs 483 // Infields 484 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, true, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, max_num_points, num_elem)); 485 // Outfields 486 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, 487 max_num_points, num_elem)); 488 489 CeedCallBackend(CeedOperatorSetSetupDone(op)); 490 return CEED_ERROR_SUCCESS; 491 } 492 493 //------------------------------------------------------------------------------ 494 // Input Basis Action AtPoints 495 //------------------------------------------------------------------------------ 496 static inline int CeedOperatorInputBasisAtPoints_Hip(CeedInt num_elem, const CeedInt *num_points, CeedQFunctionField *qf_input_fields, 497 CeedOperatorField *op_input_fields, CeedInt num_input_fields, const bool skip_active, 498 CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Hip *impl) { 499 for (CeedInt i = 0; i < num_input_fields; i++) { 500 CeedInt elem_size, size; 501 CeedEvalMode eval_mode; 502 CeedElemRestriction elem_rstr; 503 CeedBasis basis; 504 505 // Skip active input 506 if (skip_active) { 507 CeedVector vec; 508 509 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 510 if (vec == CEED_VECTOR_ACTIVE) continue; 511 } 512 // Get elem_size, eval_mode, size 513 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 514 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 515 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 516 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 517 // Basis action 518 switch (eval_mode) { 519 case CEED_EVAL_NONE: 520 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 521 break; 522 case CEED_EVAL_INTERP: 523 case CEED_EVAL_GRAD: 524 case CEED_EVAL_DIV: 525 case CEED_EVAL_CURL: 526 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 527 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, impl->e_vecs[i], 528 impl->q_vecs_in[i])); 529 break; 530 case CEED_EVAL_WEIGHT: 531 break; // No action 532 } 533 } 534 return CEED_ERROR_SUCCESS; 535 } 536 537 //------------------------------------------------------------------------------ 538 // Apply and add to output AtPoints 539 //------------------------------------------------------------------------------ 540 static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 541 CeedInt max_num_points, num_elem, elem_size, num_input_fields, num_output_fields, size; 542 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 543 CeedQFunctionField *qf_input_fields, *qf_output_fields; 544 CeedQFunction qf; 545 CeedOperatorField *op_input_fields, *op_output_fields; 546 CeedOperator_Hip *impl; 547 548 CeedCallBackend(CeedOperatorGetData(op, &impl)); 549 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 550 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 551 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 552 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 553 CeedInt num_points[num_elem]; 554 555 // Setup 556 CeedCallBackend(CeedOperatorSetupAtPoints_Hip(op)); 557 max_num_points = impl->max_num_points; 558 for (CeedInt i = 0; i < num_elem; i++) num_points[i] = max_num_points; 559 560 // Input Evecs and Restriction 561 CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data, impl, request)); 562 563 // Get point coordinates 564 if (!impl->point_coords_elem) { 565 CeedVector point_coords = NULL; 566 CeedElemRestriction rstr_points = NULL; 567 568 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 569 CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem)); 570 CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 571 } 572 573 // Input basis apply if needed 574 CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(num_elem, num_points, qf_input_fields, op_input_fields, num_input_fields, false, e_data, impl)); 575 576 // Output pointers, as necessary 577 for (CeedInt i = 0; i < num_output_fields; i++) { 578 CeedEvalMode eval_mode; 579 580 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 581 if (eval_mode == CEED_EVAL_NONE) { 582 // Set the output Q-Vector to use the E-Vector data directly. 583 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 584 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 585 } 586 } 587 588 // Q function 589 CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out)); 590 591 // Output basis apply if needed 592 for (CeedInt i = 0; i < num_output_fields; i++) { 593 CeedEvalMode eval_mode; 594 CeedElemRestriction elem_rstr; 595 CeedBasis basis; 596 597 // Get elem_size, eval_mode, size 598 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 599 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 600 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 601 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 602 // Basis action 603 switch (eval_mode) { 604 case CEED_EVAL_NONE: 605 break; // No action 606 case CEED_EVAL_INTERP: 607 case CEED_EVAL_GRAD: 608 case CEED_EVAL_DIV: 609 case CEED_EVAL_CURL: 610 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 611 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i], 612 impl->e_vecs[i + impl->num_inputs])); 613 break; 614 // LCOV_EXCL_START 615 case CEED_EVAL_WEIGHT: { 616 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 617 // LCOV_EXCL_STOP 618 } 619 } 620 } 621 622 // Output restriction 623 for (CeedInt i = 0; i < num_output_fields; i++) { 624 CeedEvalMode eval_mode; 625 CeedVector vec; 626 CeedElemRestriction elem_rstr; 627 628 // Restore evec 629 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 630 if (eval_mode == CEED_EVAL_NONE) { 631 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 632 } 633 // Get output vector 634 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 635 // Restrict 636 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 637 // Active 638 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 639 640 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request)); 641 } 642 643 // Restore input arrays 644 CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl)); 645 return CEED_ERROR_SUCCESS; 646 } 647 648 //------------------------------------------------------------------------------ 649 // Linear QFunction Assembly Core 650 //------------------------------------------------------------------------------ 651 static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 652 CeedRequest *request) { 653 Ceed ceed, ceed_parent; 654 CeedInt num_active_in, num_active_out, Q, num_elem, num_input_fields, num_output_fields, size; 655 CeedScalar *assembled_array, *e_data[2 * CEED_FIELD_MAX] = {NULL}; 656 CeedVector *active_inputs; 657 CeedQFunctionField *qf_input_fields, *qf_output_fields; 658 CeedQFunction qf; 659 CeedOperatorField *op_input_fields, *op_output_fields; 660 CeedOperator_Hip *impl; 661 662 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 663 CeedCallBackend(CeedOperatorGetFallbackParentCeed(op, &ceed_parent)); 664 CeedCallBackend(CeedOperatorGetData(op, &impl)); 665 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 666 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 667 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 668 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 669 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 670 active_inputs = impl->qf_active_in; 671 num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 672 673 // Setup 674 CeedCallBackend(CeedOperatorSetup_Hip(op)); 675 676 // Input Evecs and Restriction 677 CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 678 679 // Count number of active input fields 680 if (!num_active_in) { 681 for (CeedInt i = 0; i < num_input_fields; i++) { 682 CeedScalar *q_vec_array; 683 CeedVector vec; 684 685 // Get input vector 686 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 687 // Check if active input 688 if (vec == CEED_VECTOR_ACTIVE) { 689 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 690 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 691 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, &q_vec_array)); 692 CeedCallBackend(CeedRealloc(num_active_in + size, &active_inputs)); 693 for (CeedInt field = 0; field < size; field++) { 694 CeedSize q_size = (CeedSize)Q * num_elem; 695 696 CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_inputs[num_active_in + field])); 697 CeedCallBackend( 698 CeedVectorSetArray(active_inputs[num_active_in + field], CEED_MEM_DEVICE, CEED_USE_POINTER, &q_vec_array[field * Q * num_elem])); 699 } 700 num_active_in += size; 701 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array)); 702 } 703 } 704 impl->num_active_in = num_active_in; 705 impl->qf_active_in = active_inputs; 706 } 707 708 // Count number of active output fields 709 if (!num_active_out) { 710 for (CeedInt i = 0; i < num_output_fields; i++) { 711 CeedVector vec; 712 713 // Get output vector 714 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 715 // Check if active output 716 if (vec == CEED_VECTOR_ACTIVE) { 717 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 718 num_active_out += size; 719 } 720 } 721 impl->num_active_out = num_active_out; 722 } 723 724 // Check sizes 725 CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 726 727 // Build objects if needed 728 if (build_objects) { 729 CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out; 730 CeedInt strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */ 731 732 // Create output restriction 733 CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out, 734 num_active_in * num_active_out * num_elem * Q, strides, rstr)); 735 // Create assembled vector 736 CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled)); 737 } 738 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 739 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_DEVICE, &assembled_array)); 740 741 // Input basis apply 742 CeedCallBackend(CeedOperatorInputBasis_Hip(num_elem, qf_input_fields, op_input_fields, num_input_fields, true, e_data, impl)); 743 744 // Assemble QFunction 745 for (CeedInt in = 0; in < num_active_in; in++) { 746 // Set Inputs 747 CeedCallBackend(CeedVectorSetValue(active_inputs[in], 1.0)); 748 if (num_active_in > 1) { 749 CeedCallBackend(CeedVectorSetValue(active_inputs[(in + num_active_in - 1) % num_active_in], 0.0)); 750 } 751 // Set Outputs 752 for (CeedInt out = 0; out < num_output_fields; out++) { 753 CeedVector vec; 754 755 // Get output vector 756 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 757 // Check if active output 758 if (vec == CEED_VECTOR_ACTIVE) { 759 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, CEED_USE_POINTER, assembled_array)); 760 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size)); 761 assembled_array += size * Q * num_elem; // Advance the pointer by the size of the output 762 } 763 } 764 // Apply QFunction 765 CeedCallBackend(CeedQFunctionApply(qf, Q * num_elem, impl->q_vecs_in, impl->q_vecs_out)); 766 } 767 768 // Un-set output q_vecs to prevent accidental overwrite of Assembled 769 for (CeedInt out = 0; out < num_output_fields; out++) { 770 CeedVector vec; 771 772 // Get output vector 773 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 774 // Check if active output 775 if (vec == CEED_VECTOR_ACTIVE) { 776 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, NULL)); 777 } 778 } 779 780 // Restore input arrays 781 CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 782 783 // Restore output 784 CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 785 return CEED_ERROR_SUCCESS; 786 } 787 788 //------------------------------------------------------------------------------ 789 // Assemble Linear QFunction 790 //------------------------------------------------------------------------------ 791 static int CeedOperatorLinearAssembleQFunction_Hip(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 792 return CeedOperatorLinearAssembleQFunctionCore_Hip(op, true, assembled, rstr, request); 793 } 794 795 //------------------------------------------------------------------------------ 796 // Update Assembled Linear QFunction 797 //------------------------------------------------------------------------------ 798 static int CeedOperatorLinearAssembleQFunctionUpdate_Hip(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 799 return CeedOperatorLinearAssembleQFunctionCore_Hip(op, false, &assembled, &rstr, request); 800 } 801 802 //------------------------------------------------------------------------------ 803 // Assemble Diagonal Setup 804 //------------------------------------------------------------------------------ 805 static inline int CeedOperatorAssembleDiagonalSetup_Hip(CeedOperator op) { 806 Ceed ceed; 807 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 808 CeedInt q_comp, num_nodes, num_qpts; 809 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL; 810 CeedBasis basis_in = NULL, basis_out = NULL; 811 CeedQFunctionField *qf_fields; 812 CeedQFunction qf; 813 CeedOperatorField *op_fields; 814 CeedOperator_Hip *impl; 815 816 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 817 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 818 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 819 820 // Determine active input basis 821 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 822 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 823 for (CeedInt i = 0; i < num_input_fields; i++) { 824 CeedVector vec; 825 826 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 827 if (vec == CEED_VECTOR_ACTIVE) { 828 CeedBasis basis; 829 CeedEvalMode eval_mode; 830 831 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 832 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, 833 "Backend does not implement operator diagonal assembly with multiple active bases"); 834 basis_in = basis; 835 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 836 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 837 if (eval_mode != CEED_EVAL_WEIGHT) { 838 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly 839 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in)); 840 for (CeedInt d = 0; d < q_comp; d++) eval_modes_in[num_eval_modes_in + d] = eval_mode; 841 num_eval_modes_in += q_comp; 842 } 843 } 844 } 845 846 // Determine active output basis 847 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 848 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 849 for (CeedInt i = 0; i < num_output_fields; i++) { 850 CeedVector vec; 851 852 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 853 if (vec == CEED_VECTOR_ACTIVE) { 854 CeedBasis basis; 855 CeedEvalMode eval_mode; 856 857 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 858 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND, 859 "Backend does not implement operator diagonal assembly with multiple active bases"); 860 basis_out = basis; 861 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 862 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 863 if (eval_mode != CEED_EVAL_WEIGHT) { 864 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly 865 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out)); 866 for (CeedInt d = 0; d < q_comp; d++) eval_modes_out[num_eval_modes_out + d] = eval_mode; 867 num_eval_modes_out += q_comp; 868 } 869 } 870 } 871 872 // Operator data struct 873 CeedCallBackend(CeedOperatorGetData(op, &impl)); 874 CeedCallBackend(CeedCalloc(1, &impl->diag)); 875 CeedOperatorDiag_Hip *diag = impl->diag; 876 877 // Basis matrices 878 CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes)); 879 if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes; 880 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts)); 881 const CeedInt interp_bytes = num_nodes * num_qpts * sizeof(CeedScalar); 882 const CeedInt eval_modes_bytes = sizeof(CeedEvalMode); 883 bool has_eval_none = false; 884 885 // CEED_EVAL_NONE 886 for (CeedInt i = 0; i < num_eval_modes_in; i++) has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE); 887 for (CeedInt i = 0; i < num_eval_modes_out; i++) has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE); 888 if (has_eval_none) { 889 CeedScalar *identity = NULL; 890 891 CeedCallBackend(CeedCalloc(num_nodes * num_qpts, &identity)); 892 for (CeedInt i = 0; i < (num_nodes < num_qpts ? num_nodes : num_qpts); i++) identity[i * num_nodes + i] = 1.0; 893 CeedCallHip(ceed, hipMalloc((void **)&diag->d_identity, interp_bytes)); 894 CeedCallHip(ceed, hipMemcpy(diag->d_identity, identity, interp_bytes, hipMemcpyHostToDevice)); 895 CeedCallBackend(CeedFree(&identity)); 896 } 897 898 // CEED_EVAL_INTERP, CEED_EVAL_GRAD, CEED_EVAL_DIV, and CEED_EVAL_CURL 899 for (CeedInt in = 0; in < 2; in++) { 900 CeedFESpace fespace; 901 CeedBasis basis = in ? basis_in : basis_out; 902 903 CeedCallBackend(CeedBasisGetFESpace(basis, &fespace)); 904 switch (fespace) { 905 case CEED_FE_SPACE_H1: { 906 CeedInt q_comp_interp, q_comp_grad; 907 const CeedScalar *interp, *grad; 908 CeedScalar *d_interp, *d_grad; 909 910 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 911 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_GRAD, &q_comp_grad)); 912 913 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 914 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 915 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice)); 916 CeedCallBackend(CeedBasisGetGrad(basis, &grad)); 917 CeedCallHip(ceed, hipMalloc((void **)&d_grad, interp_bytes * q_comp_grad)); 918 CeedCallHip(ceed, hipMemcpy(d_grad, grad, interp_bytes * q_comp_grad, hipMemcpyHostToDevice)); 919 if (in) { 920 diag->d_interp_in = d_interp; 921 diag->d_grad_in = d_grad; 922 } else { 923 diag->d_interp_out = d_interp; 924 diag->d_grad_out = d_grad; 925 } 926 } break; 927 case CEED_FE_SPACE_HDIV: { 928 CeedInt q_comp_interp, q_comp_div; 929 const CeedScalar *interp, *div; 930 CeedScalar *d_interp, *d_div; 931 932 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 933 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_DIV, &q_comp_div)); 934 935 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 936 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 937 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice)); 938 CeedCallBackend(CeedBasisGetDiv(basis, &div)); 939 CeedCallHip(ceed, hipMalloc((void **)&d_div, interp_bytes * q_comp_div)); 940 CeedCallHip(ceed, hipMemcpy(d_div, div, interp_bytes * q_comp_div, hipMemcpyHostToDevice)); 941 if (in) { 942 diag->d_interp_in = d_interp; 943 diag->d_div_in = d_div; 944 } else { 945 diag->d_interp_out = d_interp; 946 diag->d_div_out = d_div; 947 } 948 } break; 949 case CEED_FE_SPACE_HCURL: { 950 CeedInt q_comp_interp, q_comp_curl; 951 const CeedScalar *interp, *curl; 952 CeedScalar *d_interp, *d_curl; 953 954 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 955 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_CURL, &q_comp_curl)); 956 957 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 958 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 959 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice)); 960 CeedCallBackend(CeedBasisGetCurl(basis, &curl)); 961 CeedCallHip(ceed, hipMalloc((void **)&d_curl, interp_bytes * q_comp_curl)); 962 CeedCallHip(ceed, hipMemcpy(d_curl, curl, interp_bytes * q_comp_curl, hipMemcpyHostToDevice)); 963 if (in) { 964 diag->d_interp_in = d_interp; 965 diag->d_curl_in = d_curl; 966 } else { 967 diag->d_interp_out = d_interp; 968 diag->d_curl_out = d_curl; 969 } 970 } break; 971 } 972 } 973 974 // Arrays of eval_modes 975 CeedCallHip(ceed, hipMalloc((void **)&diag->d_eval_modes_in, num_eval_modes_in * eval_modes_bytes)); 976 CeedCallHip(ceed, hipMemcpy(diag->d_eval_modes_in, eval_modes_in, num_eval_modes_in * eval_modes_bytes, hipMemcpyHostToDevice)); 977 CeedCallHip(ceed, hipMalloc((void **)&diag->d_eval_modes_out, num_eval_modes_out * eval_modes_bytes)); 978 CeedCallHip(ceed, hipMemcpy(diag->d_eval_modes_out, eval_modes_out, num_eval_modes_out * eval_modes_bytes, hipMemcpyHostToDevice)); 979 CeedCallBackend(CeedFree(&eval_modes_in)); 980 CeedCallBackend(CeedFree(&eval_modes_out)); 981 return CEED_ERROR_SUCCESS; 982 } 983 984 //------------------------------------------------------------------------------ 985 // Assemble Diagonal Setup (Compilation) 986 //------------------------------------------------------------------------------ 987 static inline int CeedOperatorAssembleDiagonalSetupCompile_Hip(CeedOperator op, CeedInt use_ceedsize_idx, const bool is_point_block) { 988 Ceed ceed; 989 char *diagonal_kernel_source; 990 const char *diagonal_kernel_path; 991 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 992 CeedInt num_comp, q_comp, num_nodes, num_qpts; 993 CeedBasis basis_in = NULL, basis_out = NULL; 994 CeedQFunctionField *qf_fields; 995 CeedQFunction qf; 996 CeedOperatorField *op_fields; 997 CeedOperator_Hip *impl; 998 999 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1000 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1001 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 1002 1003 // Determine active input basis 1004 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 1005 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 1006 for (CeedInt i = 0; i < num_input_fields; i++) { 1007 CeedVector vec; 1008 1009 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 1010 if (vec == CEED_VECTOR_ACTIVE) { 1011 CeedEvalMode eval_mode; 1012 1013 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis_in)); 1014 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1015 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 1016 if (eval_mode != CEED_EVAL_WEIGHT) { 1017 num_eval_modes_in += q_comp; 1018 } 1019 } 1020 } 1021 1022 // Determine active output basis 1023 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 1024 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 1025 for (CeedInt i = 0; i < num_output_fields; i++) { 1026 CeedVector vec; 1027 1028 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 1029 if (vec == CEED_VECTOR_ACTIVE) { 1030 CeedEvalMode eval_mode; 1031 1032 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis_out)); 1033 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1034 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 1035 if (eval_mode != CEED_EVAL_WEIGHT) { 1036 num_eval_modes_out += q_comp; 1037 } 1038 } 1039 } 1040 1041 // Operator data struct 1042 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1043 CeedOperatorDiag_Hip *diag = impl->diag; 1044 1045 // Assemble kernel 1046 hipModule_t *module = is_point_block ? &diag->module_point_block : &diag->module; 1047 CeedInt elems_per_block = 1; 1048 CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes)); 1049 CeedCallBackend(CeedBasisGetNumComponents(basis_in, &num_comp)); 1050 if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes; 1051 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts)); 1052 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h", &diagonal_kernel_path)); 1053 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Kernel Source -----\n"); 1054 CeedCallBackend(CeedLoadSourceToBuffer(ceed, diagonal_kernel_path, &diagonal_kernel_source)); 1055 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Source Complete! -----\n"); 1056 CeedCallHip(ceed, CeedCompile_Hip(ceed, diagonal_kernel_source, module, 8, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", 1057 num_eval_modes_out, "NUM_COMP", num_comp, "NUM_NODES", num_nodes, "NUM_QPTS", num_qpts, "USE_CEEDSIZE", 1058 use_ceedsize_idx, "USE_POINT_BLOCK", is_point_block ? 1 : 0, "BLOCK_SIZE", num_nodes * elems_per_block)); 1059 CeedCallHip(ceed, CeedGetKernel_Hip(ceed, *module, "LinearDiagonal", is_point_block ? &diag->LinearPointBlock : &diag->LinearDiagonal)); 1060 CeedCallBackend(CeedFree(&diagonal_kernel_path)); 1061 CeedCallBackend(CeedFree(&diagonal_kernel_source)); 1062 return CEED_ERROR_SUCCESS; 1063 } 1064 1065 //------------------------------------------------------------------------------ 1066 // Assemble Diagonal Core 1067 //------------------------------------------------------------------------------ 1068 static inline int CeedOperatorAssembleDiagonalCore_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request, const bool is_point_block) { 1069 Ceed ceed; 1070 CeedInt num_elem, num_nodes; 1071 CeedScalar *elem_diag_array; 1072 const CeedScalar *assembled_qf_array; 1073 CeedVector assembled_qf = NULL, elem_diag; 1074 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out, diag_rstr; 1075 CeedOperator_Hip *impl; 1076 1077 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1078 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1079 1080 // Assemble QFunction 1081 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, request)); 1082 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr)); 1083 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); 1084 1085 // Setup 1086 if (!impl->diag) CeedCallBackend(CeedOperatorAssembleDiagonalSetup_Hip(op)); 1087 CeedOperatorDiag_Hip *diag = impl->diag; 1088 1089 assert(diag != NULL); 1090 1091 // Assemble kernel if needed 1092 if ((!is_point_block && !diag->LinearDiagonal) || (is_point_block && !diag->LinearPointBlock)) { 1093 CeedSize assembled_length, assembled_qf_length; 1094 CeedInt use_ceedsize_idx = 0; 1095 CeedCallBackend(CeedVectorGetLength(assembled, &assembled_length)); 1096 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length)); 1097 if ((assembled_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1; 1098 1099 CeedCallBackend(CeedOperatorAssembleDiagonalSetupCompile_Hip(op, use_ceedsize_idx, is_point_block)); 1100 } 1101 1102 // Restriction and diagonal vector 1103 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out)); 1104 CeedCheck(rstr_in == rstr_out, ceed, CEED_ERROR_BACKEND, 1105 "Cannot assemble operator diagonal with different input and output active element restrictions"); 1106 if (!is_point_block && !diag->diag_rstr) { 1107 CeedCallBackend(CeedElemRestrictionCreateUnsignedCopy(rstr_out, &diag->diag_rstr)); 1108 CeedCallBackend(CeedElemRestrictionCreateVector(diag->diag_rstr, NULL, &diag->elem_diag)); 1109 } else if (is_point_block && !diag->point_block_diag_rstr) { 1110 CeedCallBackend(CeedOperatorCreateActivePointBlockRestriction(rstr_out, &diag->point_block_diag_rstr)); 1111 CeedCallBackend(CeedElemRestrictionCreateVector(diag->point_block_diag_rstr, NULL, &diag->point_block_elem_diag)); 1112 } 1113 diag_rstr = is_point_block ? diag->point_block_diag_rstr : diag->diag_rstr; 1114 elem_diag = is_point_block ? diag->point_block_elem_diag : diag->elem_diag; 1115 CeedCallBackend(CeedVectorSetValue(elem_diag, 0.0)); 1116 1117 // Only assemble diagonal if the basis has nodes, otherwise inputs are null pointers 1118 CeedCallBackend(CeedElemRestrictionGetElementSize(diag_rstr, &num_nodes)); 1119 if (num_nodes > 0) { 1120 // Assemble element operator diagonals 1121 CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array)); 1122 CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem)); 1123 1124 // Compute the diagonal of B^T D B 1125 CeedInt elems_per_block = 1; 1126 CeedInt grid = CeedDivUpInt(num_elem, elems_per_block); 1127 void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_div_in, 1128 &diag->d_curl_in, &diag->d_interp_out, &diag->d_grad_out, &diag->d_div_out, &diag->d_curl_out, 1129 &diag->d_eval_modes_in, &diag->d_eval_modes_out, &assembled_qf_array, &elem_diag_array}; 1130 1131 if (is_point_block) { 1132 CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->LinearPointBlock, grid, num_nodes, 1, elems_per_block, args)); 1133 } else { 1134 CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->LinearDiagonal, grid, num_nodes, 1, elems_per_block, args)); 1135 } 1136 1137 // Restore arrays 1138 CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array)); 1139 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); 1140 } 1141 1142 // Assemble local operator diagonal 1143 CeedCallBackend(CeedElemRestrictionApply(diag_rstr, CEED_TRANSPOSE, elem_diag, assembled, request)); 1144 1145 // Cleanup 1146 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1147 return CEED_ERROR_SUCCESS; 1148 } 1149 1150 //------------------------------------------------------------------------------ 1151 // Assemble Linear Diagonal 1152 //------------------------------------------------------------------------------ 1153 static int CeedOperatorLinearAssembleAddDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1154 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, false)); 1155 return CEED_ERROR_SUCCESS; 1156 } 1157 1158 //------------------------------------------------------------------------------ 1159 // Assemble Linear Point Block Diagonal 1160 //------------------------------------------------------------------------------ 1161 static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1162 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, true)); 1163 return CEED_ERROR_SUCCESS; 1164 } 1165 1166 //------------------------------------------------------------------------------ 1167 // Single Operator Assembly Setup 1168 //------------------------------------------------------------------------------ 1169 static int CeedSingleOperatorAssembleSetup_Hip(CeedOperator op, CeedInt use_ceedsize_idx) { 1170 Ceed ceed; 1171 char *assembly_kernel_source; 1172 const char *assembly_kernel_path; 1173 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 1174 CeedInt elem_size_in, num_qpts_in = 0, num_comp_in, elem_size_out, num_qpts_out, num_comp_out, q_comp; 1175 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL; 1176 CeedElemRestriction rstr_in = NULL, rstr_out = NULL; 1177 CeedBasis basis_in = NULL, basis_out = NULL; 1178 CeedQFunctionField *qf_fields; 1179 CeedQFunction qf; 1180 CeedOperatorField *input_fields, *output_fields; 1181 CeedOperator_Hip *impl; 1182 1183 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1184 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1185 1186 // Get intput and output fields 1187 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields)); 1188 1189 // Determine active input basis eval mode 1190 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1191 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 1192 for (CeedInt i = 0; i < num_input_fields; i++) { 1193 CeedVector vec; 1194 1195 CeedCallBackend(CeedOperatorFieldGetVector(input_fields[i], &vec)); 1196 if (vec == CEED_VECTOR_ACTIVE) { 1197 CeedBasis basis; 1198 CeedEvalMode eval_mode; 1199 1200 CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis)); 1201 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases"); 1202 basis_in = basis; 1203 CeedCallBackend(CeedOperatorFieldGetElemRestriction(input_fields[i], &rstr_in)); 1204 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in)); 1205 if (basis_in == CEED_BASIS_NONE) num_qpts_in = elem_size_in; 1206 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts_in)); 1207 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1208 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 1209 if (eval_mode != CEED_EVAL_WEIGHT) { 1210 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly 1211 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in)); 1212 for (CeedInt d = 0; d < q_comp; d++) { 1213 eval_modes_in[num_eval_modes_in + d] = eval_mode; 1214 } 1215 num_eval_modes_in += q_comp; 1216 } 1217 } 1218 } 1219 1220 // Determine active output basis; basis_out and rstr_out only used if same as input, TODO 1221 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 1222 for (CeedInt i = 0; i < num_output_fields; i++) { 1223 CeedVector vec; 1224 1225 CeedCallBackend(CeedOperatorFieldGetVector(output_fields[i], &vec)); 1226 if (vec == CEED_VECTOR_ACTIVE) { 1227 CeedBasis basis; 1228 CeedEvalMode eval_mode; 1229 1230 CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis)); 1231 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND, 1232 "Backend does not implement operator assembly with multiple active bases"); 1233 basis_out = basis; 1234 CeedCallBackend(CeedOperatorFieldGetElemRestriction(output_fields[i], &rstr_out)); 1235 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out)); 1236 if (basis_out == CEED_BASIS_NONE) num_qpts_out = elem_size_out; 1237 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_out, &num_qpts_out)); 1238 CeedCheck(num_qpts_in == num_qpts_out, ceed, CEED_ERROR_UNSUPPORTED, 1239 "Active input and output bases must have the same number of quadrature points"); 1240 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1241 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 1242 if (eval_mode != CEED_EVAL_WEIGHT) { 1243 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly 1244 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out)); 1245 for (CeedInt d = 0; d < q_comp; d++) { 1246 eval_modes_out[num_eval_modes_out + d] = eval_mode; 1247 } 1248 num_eval_modes_out += q_comp; 1249 } 1250 } 1251 } 1252 CeedCheck(num_eval_modes_in > 0 && num_eval_modes_out > 0, ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs"); 1253 1254 CeedCallBackend(CeedCalloc(1, &impl->asmb)); 1255 CeedOperatorAssemble_Hip *asmb = impl->asmb; 1256 asmb->elems_per_block = 1; 1257 asmb->block_size_x = elem_size_in; 1258 asmb->block_size_y = elem_size_out; 1259 1260 bool fallback = asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block > 1024; 1261 1262 if (fallback) { 1263 // Use fallback kernel with 1D threadblock 1264 asmb->block_size_y = 1; 1265 } 1266 1267 // Compile kernels 1268 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &num_comp_in)); 1269 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_out, &num_comp_out)); 1270 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-operator-assemble.h", &assembly_kernel_path)); 1271 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Kernel Source -----\n"); 1272 CeedCallBackend(CeedLoadSourceToBuffer(ceed, assembly_kernel_path, &assembly_kernel_source)); 1273 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Source Complete! -----\n"); 1274 CeedCallBackend(CeedCompile_Hip(ceed, assembly_kernel_source, &asmb->module, 10, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", 1275 num_eval_modes_out, "NUM_COMP_IN", num_comp_in, "NUM_COMP_OUT", num_comp_out, "NUM_NODES_IN", elem_size_in, 1276 "NUM_NODES_OUT", elem_size_out, "NUM_QPTS", num_qpts_in, "BLOCK_SIZE", 1277 asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block, "BLOCK_SIZE_Y", asmb->block_size_y, "USE_CEEDSIZE", 1278 use_ceedsize_idx)); 1279 CeedCallBackend(CeedGetKernel_Hip(ceed, asmb->module, "LinearAssemble", &asmb->LinearAssemble)); 1280 CeedCallBackend(CeedFree(&assembly_kernel_path)); 1281 CeedCallBackend(CeedFree(&assembly_kernel_source)); 1282 1283 // Load into B_in, in order that they will be used in eval_modes_in 1284 { 1285 const CeedInt in_bytes = elem_size_in * num_qpts_in * num_eval_modes_in * sizeof(CeedScalar); 1286 CeedInt d_in = 0; 1287 CeedEvalMode eval_modes_in_prev = CEED_EVAL_NONE; 1288 bool has_eval_none = false; 1289 CeedScalar *identity = NULL; 1290 1291 for (CeedInt i = 0; i < num_eval_modes_in; i++) { 1292 has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE); 1293 } 1294 if (has_eval_none) { 1295 CeedCallBackend(CeedCalloc(elem_size_in * num_qpts_in, &identity)); 1296 for (CeedInt i = 0; i < (elem_size_in < num_qpts_in ? elem_size_in : num_qpts_in); i++) identity[i * elem_size_in + i] = 1.0; 1297 } 1298 1299 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_in, in_bytes)); 1300 for (CeedInt i = 0; i < num_eval_modes_in; i++) { 1301 const CeedScalar *h_B_in; 1302 1303 CeedCallBackend(CeedOperatorGetBasisPointer(basis_in, eval_modes_in[i], identity, &h_B_in)); 1304 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_modes_in[i], &q_comp)); 1305 if (q_comp > 1) { 1306 if (i == 0 || eval_modes_in[i] != eval_modes_in_prev) d_in = 0; 1307 else h_B_in = &h_B_in[(++d_in) * elem_size_in * num_qpts_in]; 1308 } 1309 eval_modes_in_prev = eval_modes_in[i]; 1310 1311 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_in[i * elem_size_in * num_qpts_in], h_B_in, elem_size_in * num_qpts_in * sizeof(CeedScalar), 1312 hipMemcpyHostToDevice)); 1313 } 1314 1315 if (identity) { 1316 CeedCallBackend(CeedFree(&identity)); 1317 } 1318 } 1319 1320 // Load into B_out, in order that they will be used in eval_modes_out 1321 { 1322 const CeedInt out_bytes = elem_size_out * num_qpts_out * num_eval_modes_out * sizeof(CeedScalar); 1323 CeedInt d_out = 0; 1324 CeedEvalMode eval_modes_out_prev = CEED_EVAL_NONE; 1325 bool has_eval_none = false; 1326 CeedScalar *identity = NULL; 1327 1328 for (CeedInt i = 0; i < num_eval_modes_out; i++) { 1329 has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE); 1330 } 1331 if (has_eval_none) { 1332 CeedCallBackend(CeedCalloc(elem_size_out * num_qpts_out, &identity)); 1333 for (CeedInt i = 0; i < (elem_size_out < num_qpts_out ? elem_size_out : num_qpts_out); i++) identity[i * elem_size_out + i] = 1.0; 1334 } 1335 1336 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_out, out_bytes)); 1337 for (CeedInt i = 0; i < num_eval_modes_out; i++) { 1338 const CeedScalar *h_B_out; 1339 1340 CeedCallBackend(CeedOperatorGetBasisPointer(basis_out, eval_modes_out[i], identity, &h_B_out)); 1341 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_modes_out[i], &q_comp)); 1342 if (q_comp > 1) { 1343 if (i == 0 || eval_modes_out[i] != eval_modes_out_prev) d_out = 0; 1344 else h_B_out = &h_B_out[(++d_out) * elem_size_out * num_qpts_out]; 1345 } 1346 eval_modes_out_prev = eval_modes_out[i]; 1347 1348 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_out[i * elem_size_out * num_qpts_out], h_B_out, elem_size_out * num_qpts_out * sizeof(CeedScalar), 1349 hipMemcpyHostToDevice)); 1350 } 1351 1352 if (identity) { 1353 CeedCallBackend(CeedFree(&identity)); 1354 } 1355 } 1356 return CEED_ERROR_SUCCESS; 1357 } 1358 1359 //------------------------------------------------------------------------------ 1360 // Assemble matrix data for COO matrix of assembled operator. 1361 // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic. 1362 // 1363 // Note that this (and other assembly routines) currently assume only one active input restriction/basis per operator (could have multiple basis eval 1364 // modes). 1365 // TODO: allow multiple active input restrictions/basis objects 1366 //------------------------------------------------------------------------------ 1367 static int CeedSingleOperatorAssemble_Hip(CeedOperator op, CeedInt offset, CeedVector values) { 1368 Ceed ceed; 1369 CeedSize values_length = 0, assembled_qf_length = 0; 1370 CeedInt use_ceedsize_idx = 0, num_elem_in, num_elem_out, elem_size_in, elem_size_out; 1371 CeedScalar *values_array; 1372 const CeedScalar *assembled_qf_array; 1373 CeedVector assembled_qf = NULL; 1374 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out; 1375 CeedRestrictionType rstr_type_in, rstr_type_out; 1376 const bool *orients_in = NULL, *orients_out = NULL; 1377 const CeedInt8 *curl_orients_in = NULL, *curl_orients_out = NULL; 1378 CeedOperator_Hip *impl; 1379 1380 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1381 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1382 1383 // Assemble QFunction 1384 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, CEED_REQUEST_IMMEDIATE)); 1385 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr)); 1386 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); 1387 1388 CeedCallBackend(CeedVectorGetLength(values, &values_length)); 1389 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length)); 1390 if ((values_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1; 1391 1392 // Setup 1393 if (!impl->asmb) CeedCallBackend(CeedSingleOperatorAssembleSetup_Hip(op, use_ceedsize_idx)); 1394 CeedOperatorAssemble_Hip *asmb = impl->asmb; 1395 1396 assert(asmb != NULL); 1397 1398 // Assemble element operator 1399 CeedCallBackend(CeedVectorGetArray(values, CEED_MEM_DEVICE, &values_array)); 1400 values_array += offset; 1401 1402 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out)); 1403 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, &num_elem_in)); 1404 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in)); 1405 1406 CeedCallBackend(CeedElemRestrictionGetType(rstr_in, &rstr_type_in)); 1407 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) { 1408 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_in, CEED_MEM_DEVICE, &orients_in)); 1409 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) { 1410 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_in, CEED_MEM_DEVICE, &curl_orients_in)); 1411 } 1412 1413 if (rstr_in != rstr_out) { 1414 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_out, &num_elem_out)); 1415 CeedCheck(num_elem_in == num_elem_out, ceed, CEED_ERROR_UNSUPPORTED, 1416 "Active input and output operator restrictions must have the same number of elements"); 1417 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out)); 1418 1419 CeedCallBackend(CeedElemRestrictionGetType(rstr_out, &rstr_type_out)); 1420 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) { 1421 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_out, CEED_MEM_DEVICE, &orients_out)); 1422 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) { 1423 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_out, CEED_MEM_DEVICE, &curl_orients_out)); 1424 } 1425 } else { 1426 elem_size_out = elem_size_in; 1427 orients_out = orients_in; 1428 curl_orients_out = curl_orients_in; 1429 } 1430 1431 // Compute B^T D B 1432 CeedInt shared_mem = 1433 ((curl_orients_in || curl_orients_out ? elem_size_in * elem_size_out : 0) + (curl_orients_in ? elem_size_in * asmb->block_size_y : 0)) * 1434 sizeof(CeedScalar); 1435 CeedInt grid = CeedDivUpInt(num_elem_in, asmb->elems_per_block); 1436 void *args[] = {(void *)&num_elem_in, &asmb->d_B_in, &asmb->d_B_out, &orients_in, &curl_orients_in, 1437 &orients_out, &curl_orients_out, &assembled_qf_array, &values_array}; 1438 1439 CeedCallBackend( 1440 CeedRunKernelDimShared_Hip(ceed, asmb->LinearAssemble, grid, asmb->block_size_x, asmb->block_size_y, asmb->elems_per_block, shared_mem, args)); 1441 1442 // Restore arrays 1443 CeedCallBackend(CeedVectorRestoreArray(values, &values_array)); 1444 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); 1445 1446 // Cleanup 1447 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1448 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) { 1449 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_in, &orients_in)); 1450 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) { 1451 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_in, &curl_orients_in)); 1452 } 1453 if (rstr_in != rstr_out) { 1454 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) { 1455 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_out, &orients_out)); 1456 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) { 1457 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_out, &curl_orients_out)); 1458 } 1459 } 1460 return CEED_ERROR_SUCCESS; 1461 } 1462 1463 //------------------------------------------------------------------------------ 1464 // Assemble Linear QFunction AtPoints 1465 //------------------------------------------------------------------------------ 1466 static int CeedOperatorLinearAssembleQFunctionAtPoints_Hip(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 1467 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "Backend does not implement CeedOperatorLinearAssembleQFunction"); 1468 } 1469 1470 //------------------------------------------------------------------------------ 1471 // Assemble Linear Diagonal AtPoints 1472 //------------------------------------------------------------------------------ 1473 static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1474 bool is_active_at_points = true; 1475 CeedSize e_vec_size = 0; 1476 CeedInt max_num_points, num_elem, num_input_fields, num_output_fields, elem_size_active = 1, num_comp_active = 1; 1477 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 1478 CeedQFunctionField *qf_input_fields, *qf_output_fields; 1479 CeedQFunction qf; 1480 CeedOperatorField *op_input_fields, *op_output_fields; 1481 CeedOperator_Hip *impl; 1482 1483 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1484 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1485 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 1486 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 1487 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 1488 CeedInt num_points[num_elem]; 1489 1490 // Setup 1491 CeedCallBackend(CeedOperatorSetupAtPoints_Hip(op)); 1492 max_num_points = impl->max_num_points; 1493 for (CeedInt i = 0; i < num_elem; i++) num_points[i] = max_num_points; 1494 1495 // Input Evecs and Restriction 1496 CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 1497 1498 // Check if active field is at points 1499 for (CeedInt i = 0; i < num_input_fields; i++) { 1500 CeedRestrictionType rstr_type; 1501 CeedVector vec; 1502 CeedElemRestriction elem_rstr; 1503 1504 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1505 // Skip non-active input 1506 if (vec != CEED_VECTOR_ACTIVE) continue; 1507 1508 // Get active restriction type 1509 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 1510 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 1511 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp_active)); 1512 is_active_at_points = rstr_type == CEED_RESTRICTION_POINTS; 1513 if (!is_active_at_points) CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size_active)); 1514 } 1515 1516 // Get point coordinates 1517 if (!impl->point_coords_elem) { 1518 CeedVector point_coords = NULL; 1519 CeedElemRestriction rstr_points = NULL; 1520 1521 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 1522 CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem)); 1523 CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 1524 } 1525 1526 // Input basis apply if needed 1527 CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(num_elem, num_points, qf_input_fields, op_input_fields, num_input_fields, true, e_data, impl)); 1528 1529 // Output pointers, as necessary 1530 for (CeedInt i = 0; i < num_output_fields; i++) { 1531 CeedEvalMode eval_mode; 1532 1533 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1534 if (eval_mode == CEED_EVAL_NONE) { 1535 // Set the output Q-Vector to use the E-Vector data directly. 1536 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 1537 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 1538 } 1539 } 1540 1541 // Loop over active fields 1542 e_vec_size = (is_active_at_points ? max_num_points : elem_size_active) * num_comp_active; 1543 for (CeedInt s = 0; s < e_vec_size; s++) { 1544 for (CeedInt i = 0; i < num_input_fields; i++) { 1545 bool is_active_input = false; 1546 CeedEvalMode eval_mode; 1547 CeedVector vec; 1548 CeedBasis basis; 1549 1550 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1551 // Skip non-active input 1552 is_active_input = vec == CEED_VECTOR_ACTIVE; 1553 if (!is_active_input) continue; 1554 1555 // Update unit vector 1556 if (s == 0) CeedCallBackend(CeedVectorSetValue(impl->e_vecs[i], 0.0)); 1557 else CeedCallBackend(CeedVectorSetValueStrided(impl->e_vecs[i], s - 1, e_vec_size, 0.0)); 1558 CeedCallBackend(CeedVectorSetValueStrided(impl->e_vecs[i], s, e_vec_size, 1.0)); 1559 1560 // Basis action 1561 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 1562 switch (eval_mode) { 1563 case CEED_EVAL_NONE: 1564 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 1565 break; 1566 case CEED_EVAL_INTERP: 1567 case CEED_EVAL_GRAD: 1568 case CEED_EVAL_DIV: 1569 case CEED_EVAL_CURL: 1570 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 1571 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, impl->e_vecs[i], 1572 impl->q_vecs_in[i])); 1573 break; 1574 case CEED_EVAL_WEIGHT: 1575 break; // No action 1576 } 1577 } 1578 1579 // Q function 1580 CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out)); 1581 1582 // Output basis apply if needed 1583 for (CeedInt i = 0; i < num_output_fields; i++) { 1584 bool is_active_output = false; 1585 CeedEvalMode eval_mode; 1586 CeedVector vec; 1587 CeedElemRestriction elem_rstr; 1588 CeedBasis basis; 1589 1590 // Get output vector 1591 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 1592 is_active_output = vec == CEED_VECTOR_ACTIVE; 1593 if (!is_active_output) continue; 1594 1595 // Basis action 1596 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1597 switch (eval_mode) { 1598 case CEED_EVAL_NONE: 1599 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 1600 break; 1601 case CEED_EVAL_INTERP: 1602 case CEED_EVAL_GRAD: 1603 case CEED_EVAL_DIV: 1604 case CEED_EVAL_CURL: 1605 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 1606 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i], 1607 impl->e_vecs[i + impl->num_inputs])); 1608 break; 1609 // LCOV_EXCL_START 1610 case CEED_EVAL_WEIGHT: { 1611 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 1612 // LCOV_EXCL_STOP 1613 } 1614 } 1615 1616 // Mask output e-vec 1617 { 1618 CeedInt j = num_input_fields; 1619 CeedSize out_size; 1620 1621 CeedCallBackend(CeedVectorGetLength(impl->e_vecs[i + impl->num_inputs], &out_size)); 1622 for (j = 0; j < num_input_fields; j++) { 1623 bool is_active_input = false; 1624 CeedSize in_size; 1625 CeedVector vec; 1626 1627 // Skip non-active input 1628 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[j], &vec)); 1629 is_active_input = vec == CEED_VECTOR_ACTIVE; 1630 if (!is_active_input) continue; 1631 CeedCallBackend(CeedVectorGetLength(impl->e_vecs[j], &in_size)); 1632 if (in_size == out_size) break; 1633 } 1634 CeedCheck(j < num_input_fields, CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "Matching input field not found"); 1635 CeedCallBackend(CeedVectorPointwiseMult(impl->e_vecs[i + impl->num_inputs], impl->e_vecs[j], impl->e_vecs[i + impl->num_inputs])); 1636 } 1637 1638 // Restrict 1639 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 1640 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], assembled, request)); 1641 1642 // Reset q_vec for 1643 if (eval_mode == CEED_EVAL_NONE) { 1644 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 1645 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 1646 } 1647 } 1648 } 1649 1650 // Restore CEED_EVAL_NONE 1651 for (CeedInt i = 0; i < num_output_fields; i++) { 1652 CeedEvalMode eval_mode; 1653 CeedElemRestriction elem_rstr; 1654 1655 // Get eval_mode 1656 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 1657 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1658 1659 // Restore evec 1660 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1661 if (eval_mode == CEED_EVAL_NONE) { 1662 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 1663 } 1664 } 1665 1666 // Restore input arrays 1667 CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 1668 return CEED_ERROR_SUCCESS; 1669 } 1670 1671 //------------------------------------------------------------------------------ 1672 // Create operator 1673 //------------------------------------------------------------------------------ 1674 int CeedOperatorCreate_Hip(CeedOperator op) { 1675 Ceed ceed; 1676 CeedOperator_Hip *impl; 1677 1678 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1679 CeedCallBackend(CeedCalloc(1, &impl)); 1680 CeedCallBackend(CeedOperatorSetData(op, impl)); 1681 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Hip)); 1682 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Hip)); 1683 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonal_Hip)); 1684 CeedCallBackend( 1685 CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal", CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip)); 1686 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssemble_Hip)); 1687 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip)); 1688 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip)); 1689 return CEED_ERROR_SUCCESS; 1690 } 1691 1692 //------------------------------------------------------------------------------ 1693 // Create operator AtPoints 1694 //------------------------------------------------------------------------------ 1695 int CeedOperatorCreateAtPoints_Hip(CeedOperator op) { 1696 Ceed ceed; 1697 CeedOperator_Hip *impl; 1698 1699 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1700 CeedCallBackend(CeedCalloc(1, &impl)); 1701 CeedCallBackend(CeedOperatorSetData(op, impl)); 1702 1703 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunctionAtPoints_Hip)); 1704 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip)); 1705 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAddAtPoints_Hip)); 1706 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip)); 1707 return CEED_ERROR_SUCCESS; 1708 } 1709 1710 //------------------------------------------------------------------------------ 1711