1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <assert.h> 12 #include <stdbool.h> 13 #include <string.h> 14 #include <hip/hip_runtime.h> 15 16 #include "../hip/ceed-hip-common.h" 17 #include "../hip/ceed-hip-compile.h" 18 #include "ceed-hip-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Destroy operator 22 //------------------------------------------------------------------------------ 23 static int CeedOperatorDestroy_Hip(CeedOperator op) { 24 CeedOperator_Hip *impl; 25 26 CeedCallBackend(CeedOperatorGetData(op, &impl)); 27 28 // Apply data 29 CeedCallBackend(CeedFree(&impl->num_points)); 30 CeedCallBackend(CeedFree(&impl->skip_rstr_in)); 31 CeedCallBackend(CeedFree(&impl->skip_rstr_out)); 32 CeedCallBackend(CeedFree(&impl->apply_add_basis_out)); 33 for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { 34 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[i])); 35 } 36 CeedCallBackend(CeedFree(&impl->e_vecs)); 37 CeedCallBackend(CeedFree(&impl->input_states)); 38 39 for (CeedInt i = 0; i < impl->num_inputs; i++) { 40 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i])); 41 } 42 CeedCallBackend(CeedFree(&impl->q_vecs_in)); 43 44 for (CeedInt i = 0; i < impl->num_outputs; i++) { 45 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 46 } 47 CeedCallBackend(CeedFree(&impl->q_vecs_out)); 48 CeedCallBackend(CeedVectorDestroy(&impl->point_coords_elem)); 49 50 // QFunction assembly data 51 for (CeedInt i = 0; i < impl->num_active_in; i++) { 52 CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i])); 53 } 54 CeedCallBackend(CeedFree(&impl->qf_active_in)); 55 56 // 
Diag data 57 if (impl->diag) { 58 Ceed ceed; 59 60 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 61 if (impl->diag->module) { 62 CeedCallHip(ceed, hipModuleUnload(impl->diag->module)); 63 } 64 if (impl->diag->module_point_block) { 65 CeedCallHip(ceed, hipModuleUnload(impl->diag->module_point_block)); 66 } 67 CeedCallHip(ceed, hipFree(impl->diag->d_eval_modes_in)); 68 CeedCallHip(ceed, hipFree(impl->diag->d_eval_modes_out)); 69 CeedCallHip(ceed, hipFree(impl->diag->d_identity)); 70 CeedCallHip(ceed, hipFree(impl->diag->d_interp_in)); 71 CeedCallHip(ceed, hipFree(impl->diag->d_interp_out)); 72 CeedCallHip(ceed, hipFree(impl->diag->d_grad_in)); 73 CeedCallHip(ceed, hipFree(impl->diag->d_grad_out)); 74 CeedCallHip(ceed, hipFree(impl->diag->d_div_in)); 75 CeedCallHip(ceed, hipFree(impl->diag->d_div_out)); 76 CeedCallHip(ceed, hipFree(impl->diag->d_curl_in)); 77 CeedCallHip(ceed, hipFree(impl->diag->d_curl_out)); 78 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr)); 79 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr)); 80 CeedCallBackend(CeedVectorDestroy(&impl->diag->elem_diag)); 81 CeedCallBackend(CeedVectorDestroy(&impl->diag->point_block_elem_diag)); 82 } 83 CeedCallBackend(CeedFree(&impl->diag)); 84 85 if (impl->asmb) { 86 Ceed ceed; 87 88 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 89 CeedCallHip(ceed, hipModuleUnload(impl->asmb->module)); 90 CeedCallHip(ceed, hipFree(impl->asmb->d_B_in)); 91 CeedCallHip(ceed, hipFree(impl->asmb->d_B_out)); 92 } 93 CeedCallBackend(CeedFree(&impl->asmb)); 94 95 CeedCallBackend(CeedFree(&impl)); 96 return CEED_ERROR_SUCCESS; 97 } 98 99 //------------------------------------------------------------------------------ 100 // Setup infields or outfields 101 //------------------------------------------------------------------------------ 102 static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, bool 
*apply_add_basis, 103 CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) { 104 Ceed ceed; 105 CeedQFunctionField *qf_fields; 106 CeedOperatorField *op_fields; 107 108 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 109 if (is_input) { 110 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 111 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 112 } else { 113 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 114 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 115 } 116 117 // Loop over fields 118 for (CeedInt i = 0; i < num_fields; i++) { 119 bool is_strided = false, skip_restriction = false; 120 CeedSize q_size; 121 CeedInt size; 122 CeedEvalMode eval_mode; 123 CeedBasis basis; 124 125 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 126 if (eval_mode != CEED_EVAL_WEIGHT) { 127 CeedElemRestriction elem_rstr; 128 129 // Check whether this field can skip the element restriction: 130 // Must be passive input, with eval_mode NONE, and have a strided restriction with CEED_STRIDES_BACKEND. 
131 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); 132 133 // First, check whether the field is input or output: 134 if (is_input) { 135 CeedVector vec; 136 137 // Check for passive input 138 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 139 if (vec != CEED_VECTOR_ACTIVE) { 140 // Check eval_mode 141 if (eval_mode == CEED_EVAL_NONE) { 142 // Check for strided restriction 143 CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided)); 144 if (is_strided) { 145 // Check if vector is already in preferred backend ordering 146 CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &skip_restriction)); 147 } 148 } 149 } 150 } 151 if (skip_restriction) { 152 // We do not need an E-Vector, but will use the input field vector's data directly in the operator application. 153 e_vecs[i + start_e] = NULL; 154 } else { 155 CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs[i + start_e])); 156 } 157 } 158 159 switch (eval_mode) { 160 case CEED_EVAL_NONE: 161 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 162 q_size = (CeedSize)num_elem * Q * size; 163 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 164 break; 165 case CEED_EVAL_INTERP: 166 case CEED_EVAL_GRAD: 167 case CEED_EVAL_DIV: 168 case CEED_EVAL_CURL: 169 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 170 q_size = (CeedSize)num_elem * Q * size; 171 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 172 break; 173 case CEED_EVAL_WEIGHT: // Only on input fields 174 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 175 q_size = (CeedSize)num_elem * Q; 176 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 177 if (is_at_points) { 178 CeedInt num_points[num_elem]; 179 180 for (CeedInt i = 0; i < num_elem; i++) num_points[i] = Q; 181 CeedCallBackend( 182 CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, 
CEED_VECTOR_NONE, CEED_VECTOR_NONE, q_vecs[i])); 183 } else { 184 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); 185 } 186 break; 187 } 188 } 189 // Drop duplicate restrictions 190 if (is_input) { 191 for (CeedInt i = 0; i < num_fields; i++) { 192 CeedVector vec_i; 193 CeedElemRestriction rstr_i; 194 195 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); 196 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); 197 for (CeedInt j = i + 1; j < num_fields; j++) { 198 CeedVector vec_j; 199 CeedElemRestriction rstr_j; 200 201 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); 202 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); 203 if (vec_i == vec_j && rstr_i == rstr_j) { 204 CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i + start_e], &e_vecs[j + start_e])); 205 skip_rstr[j] = true; 206 } 207 } 208 } 209 } else { 210 for (CeedInt i = num_fields - 1; i >= 0; i--) { 211 CeedVector vec_i; 212 CeedElemRestriction rstr_i; 213 214 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); 215 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); 216 for (CeedInt j = i - 1; j >= 0; j--) { 217 CeedVector vec_j; 218 CeedElemRestriction rstr_j; 219 220 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); 221 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); 222 if (vec_i == vec_j && rstr_i == rstr_j) { 223 CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i + start_e], &e_vecs[j + start_e])); 224 skip_rstr[j] = true; 225 apply_add_basis[i] = true; 226 } 227 } 228 } 229 } 230 return CEED_ERROR_SUCCESS; 231 } 232 233 //------------------------------------------------------------------------------ 234 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction. 
235 //------------------------------------------------------------------------------ 236 static int CeedOperatorSetup_Hip(CeedOperator op) { 237 Ceed ceed; 238 bool is_setup_done; 239 CeedInt Q, num_elem, num_input_fields, num_output_fields; 240 CeedQFunctionField *qf_input_fields, *qf_output_fields; 241 CeedQFunction qf; 242 CeedOperatorField *op_input_fields, *op_output_fields; 243 CeedOperator_Hip *impl; 244 245 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 246 if (is_setup_done) return CEED_ERROR_SUCCESS; 247 248 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 249 CeedCallBackend(CeedOperatorGetData(op, &impl)); 250 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 251 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 252 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 253 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 254 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 255 256 // Allocate 257 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); 258 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); 259 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out)); 260 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out)); 261 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 262 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 263 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 264 impl->num_inputs = num_input_fields; 265 impl->num_outputs = num_output_fields; 266 267 // Set up infield and outfield e-vecs and q-vecs 268 // Infields 269 CeedCallBackend( 270 CeedOperatorSetupFields_Hip(qf, op, true, false, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem)); 271 // Outfields 272 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, false, 
impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out, 273 num_input_fields, num_output_fields, Q, num_elem)); 274 275 // Reuse active e-vecs where able 276 { 277 CeedInt num_used = 0; 278 CeedElemRestriction *rstr_used = NULL; 279 280 for (CeedInt i = 0; i < num_input_fields; i++) { 281 bool is_used = false; 282 CeedVector vec_i; 283 CeedElemRestriction rstr_i; 284 285 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i)); 286 if (vec_i != CEED_VECTOR_ACTIVE) continue; 287 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr_i)); 288 for (CeedInt j = 0; j < num_used; j++) { 289 if (rstr_i == rstr_used[i]) is_used = true; 290 } 291 if (is_used) continue; 292 num_used++; 293 if (num_used == 1) CeedCallBackend(CeedCalloc(num_used, &rstr_used)); 294 else CeedCallBackend(CeedRealloc(num_used, &rstr_used)); 295 rstr_used[num_used - 1] = rstr_i; 296 for (CeedInt j = num_output_fields - 1; j >= 0; j--) { 297 CeedEvalMode eval_mode; 298 CeedVector vec_j; 299 CeedElemRestriction rstr_j; 300 301 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec_j)); 302 if (vec_j != CEED_VECTOR_ACTIVE) continue; 303 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode)); 304 if (eval_mode == CEED_EVAL_NONE) continue; 305 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &rstr_j)); 306 if (rstr_i == rstr_j) { 307 CeedCallBackend(CeedVectorReferenceCopy(impl->e_vecs[i], &impl->e_vecs[j + impl->num_inputs])); 308 } 309 } 310 } 311 CeedCallBackend(CeedFree(&rstr_used)); 312 } 313 impl->has_shared_e_vecs = true; 314 CeedCallBackend(CeedOperatorSetSetupDone(op)); 315 return CEED_ERROR_SUCCESS; 316 } 317 318 //------------------------------------------------------------------------------ 319 // Setup Operator Inputs 320 //------------------------------------------------------------------------------ 321 static inline int CeedOperatorSetupInputs_Hip(CeedInt 
num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 322 CeedVector in_vec, const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], 323 CeedOperator_Hip *impl, CeedRequest *request) { 324 for (CeedInt i = 0; i < num_input_fields; i++) { 325 CeedEvalMode eval_mode; 326 CeedVector vec; 327 CeedElemRestriction elem_rstr; 328 329 // Get input vector 330 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 331 if (vec == CEED_VECTOR_ACTIVE) { 332 if (skip_active) continue; 333 else vec = in_vec; 334 } 335 336 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 337 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 338 } else { 339 // Get input vector 340 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 341 // Get input element restriction 342 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 343 if (vec == CEED_VECTOR_ACTIVE) vec = in_vec; 344 // Restrict, if necessary 345 if (!impl->e_vecs[i]) { 346 // No restriction for this field; read data directly from vec. 
347 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i])); 348 } else { 349 uint64_t state; 350 351 CeedCallBackend(CeedVectorGetState(vec, &state)); 352 if ((state != impl->input_states[i] || vec == in_vec) && !impl->skip_rstr_in[i]) { 353 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, vec, impl->e_vecs[i], request)); 354 } 355 impl->input_states[i] = state; 356 // Get evec 357 CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs[i], CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i])); 358 } 359 } 360 } 361 return CEED_ERROR_SUCCESS; 362 } 363 364 //------------------------------------------------------------------------------ 365 // Input Basis Action 366 //------------------------------------------------------------------------------ 367 static inline int CeedOperatorInputBasis_Hip(CeedInt num_elem, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 368 CeedInt num_input_fields, const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], 369 CeedOperator_Hip *impl) { 370 for (CeedInt i = 0; i < num_input_fields; i++) { 371 CeedInt elem_size, size; 372 CeedEvalMode eval_mode; 373 CeedElemRestriction elem_rstr; 374 CeedBasis basis; 375 376 // Skip active input 377 if (skip_active) { 378 CeedVector vec; 379 380 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 381 if (vec == CEED_VECTOR_ACTIVE) continue; 382 } 383 // Get elem_size, eval_mode, size 384 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 385 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 386 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 387 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 388 // Basis action 389 switch (eval_mode) { 390 case CEED_EVAL_NONE: 391 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 392 break; 393 
case CEED_EVAL_INTERP: 394 case CEED_EVAL_GRAD: 395 case CEED_EVAL_DIV: 396 case CEED_EVAL_CURL: 397 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 398 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs[i], impl->q_vecs_in[i])); 399 break; 400 case CEED_EVAL_WEIGHT: 401 break; // No action 402 } 403 } 404 return CEED_ERROR_SUCCESS; 405 } 406 407 //------------------------------------------------------------------------------ 408 // Restore Input Vectors 409 //------------------------------------------------------------------------------ 410 static inline int CeedOperatorRestoreInputs_Hip(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 411 const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Hip *impl) { 412 for (CeedInt i = 0; i < num_input_fields; i++) { 413 CeedEvalMode eval_mode; 414 CeedVector vec; 415 416 // Skip active input 417 if (skip_active) { 418 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 419 if (vec == CEED_VECTOR_ACTIVE) continue; 420 } 421 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 422 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 423 } else { 424 if (!impl->e_vecs[i]) { // This was a skip_restriction case 425 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 426 CeedCallBackend(CeedVectorRestoreArrayRead(vec, (const CeedScalar **)&e_data[i])); 427 } else { 428 CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs[i], (const CeedScalar **)&e_data[i])); 429 } 430 } 431 } 432 return CEED_ERROR_SUCCESS; 433 } 434 435 //------------------------------------------------------------------------------ 436 // Apply and add to output 437 //------------------------------------------------------------------------------ 438 static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 439 
CeedInt Q, num_elem, elem_size, num_input_fields, num_output_fields, size; 440 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 441 CeedQFunctionField *qf_input_fields, *qf_output_fields; 442 CeedQFunction qf; 443 CeedOperatorField *op_input_fields, *op_output_fields; 444 CeedOperator_Hip *impl; 445 446 CeedCallBackend(CeedOperatorGetData(op, &impl)); 447 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 448 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 449 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 450 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 451 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 452 453 // Setup 454 CeedCallBackend(CeedOperatorSetup_Hip(op)); 455 456 // Input Evecs and Restriction 457 CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data, impl, request)); 458 459 // Input basis apply if needed 460 CeedCallBackend(CeedOperatorInputBasis_Hip(num_elem, qf_input_fields, op_input_fields, num_input_fields, false, e_data, impl)); 461 462 // Output pointers, as necessary 463 for (CeedInt i = 0; i < num_output_fields; i++) { 464 CeedEvalMode eval_mode; 465 466 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 467 if (eval_mode == CEED_EVAL_NONE) { 468 // Set the output Q-Vector to use the E-Vector data directly. 
469 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 470 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 471 } 472 } 473 474 // Q function 475 CeedCallBackend(CeedQFunctionApply(qf, num_elem * Q, impl->q_vecs_in, impl->q_vecs_out)); 476 477 // Restore input arrays 478 CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl)); 479 480 // Output basis apply if needed 481 for (CeedInt i = 0; i < num_output_fields; i++) { 482 CeedEvalMode eval_mode; 483 CeedElemRestriction elem_rstr; 484 CeedBasis basis; 485 486 // Get elem_size, eval_mode, size 487 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 488 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 489 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 490 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 491 // Basis action 492 switch (eval_mode) { 493 case CEED_EVAL_NONE: 494 break; // No action 495 case CEED_EVAL_INTERP: 496 case CEED_EVAL_GRAD: 497 case CEED_EVAL_DIV: 498 case CEED_EVAL_CURL: 499 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 500 if (impl->apply_add_basis_out[i]) { 501 CeedCallBackend(CeedBasisApplyAdd(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); 502 } else { 503 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); 504 } 505 break; 506 // LCOV_EXCL_START 507 case CEED_EVAL_WEIGHT: { 508 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 509 // LCOV_EXCL_STOP 510 } 511 } 512 } 513 514 // Output restriction 515 for (CeedInt i = 0; i < 
num_output_fields; i++) { 516 CeedEvalMode eval_mode; 517 CeedVector vec; 518 CeedElemRestriction elem_rstr; 519 520 // Restore evec 521 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 522 if (eval_mode == CEED_EVAL_NONE) { 523 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 524 } 525 if (impl->skip_rstr_out[i]) continue; 526 // Get output vector 527 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 528 // Restrict 529 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 530 // Active 531 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 532 533 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request)); 534 } 535 return CEED_ERROR_SUCCESS; 536 } 537 538 //------------------------------------------------------------------------------ 539 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction. 
540 //------------------------------------------------------------------------------ 541 static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) { 542 Ceed ceed; 543 bool is_setup_done; 544 CeedInt max_num_points = -1, num_elem, num_input_fields, num_output_fields; 545 CeedQFunctionField *qf_input_fields, *qf_output_fields; 546 CeedQFunction qf; 547 CeedOperatorField *op_input_fields, *op_output_fields; 548 CeedOperator_Hip *impl; 549 550 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 551 if (is_setup_done) return CEED_ERROR_SUCCESS; 552 553 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 554 CeedCallBackend(CeedOperatorGetData(op, &impl)); 555 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 556 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 557 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 558 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 559 { 560 CeedElemRestriction rstr_points = NULL; 561 562 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 563 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(rstr_points, &max_num_points)); 564 CeedCallBackend(CeedCalloc(num_elem, &impl->num_points)); 565 for (CeedInt e = 0; e < num_elem; e++) { 566 CeedInt num_points_elem; 567 568 CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem)); 569 impl->num_points[e] = num_points_elem; 570 } 571 } 572 impl->max_num_points = max_num_points; 573 574 // Allocate 575 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); 576 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); 577 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out)); 578 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out)); 579 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 580 
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 581 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 582 impl->num_inputs = num_input_fields; 583 impl->num_outputs = num_output_fields; 584 585 // Set up infield and outfield e-vecs and q-vecs 586 // Infields 587 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, true, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, 588 max_num_points, num_elem)); 589 // Outfields 590 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out, 591 num_input_fields, num_output_fields, max_num_points, num_elem)); 592 593 // Reuse active e-vecs where able 594 { 595 CeedInt num_used = 0; 596 CeedElemRestriction *rstr_used = NULL; 597 598 for (CeedInt i = 0; i < num_input_fields; i++) { 599 bool is_used = false; 600 CeedVector vec_i; 601 CeedElemRestriction rstr_i; 602 603 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i)); 604 if (vec_i != CEED_VECTOR_ACTIVE) continue; 605 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr_i)); 606 for (CeedInt j = 0; j < num_used; j++) { 607 if (rstr_i == rstr_used[i]) is_used = true; 608 } 609 if (is_used) continue; 610 num_used++; 611 if (num_used == 1) CeedCallBackend(CeedCalloc(num_used, &rstr_used)); 612 else CeedCallBackend(CeedRealloc(num_used, &rstr_used)); 613 rstr_used[num_used - 1] = rstr_i; 614 for (CeedInt j = num_output_fields - 1; j >= 0; j--) { 615 CeedEvalMode eval_mode; 616 CeedVector vec_j; 617 CeedElemRestriction rstr_j; 618 619 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec_j)); 620 if (vec_j != CEED_VECTOR_ACTIVE) continue; 621 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode)); 622 if (eval_mode == CEED_EVAL_NONE) continue; 623 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &rstr_j)); 624 if (rstr_i == 
rstr_j) { 625 CeedCallBackend(CeedVectorReferenceCopy(impl->e_vecs[i], &impl->e_vecs[j + impl->num_inputs])); 626 } 627 } 628 } 629 CeedCallBackend(CeedFree(&rstr_used)); 630 } 631 impl->has_shared_e_vecs = true; 632 CeedCallBackend(CeedOperatorSetSetupDone(op)); 633 return CEED_ERROR_SUCCESS; 634 } 635 636 //------------------------------------------------------------------------------ 637 // Input Basis Action AtPoints 638 //------------------------------------------------------------------------------ 639 static inline int CeedOperatorInputBasisAtPoints_Hip(CeedInt num_elem, const CeedInt *num_points, CeedQFunctionField *qf_input_fields, 640 CeedOperatorField *op_input_fields, CeedInt num_input_fields, const bool skip_active, 641 CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Hip *impl) { 642 for (CeedInt i = 0; i < num_input_fields; i++) { 643 CeedInt elem_size, size; 644 CeedEvalMode eval_mode; 645 CeedElemRestriction elem_rstr; 646 CeedBasis basis; 647 648 // Skip active input 649 if (skip_active) { 650 CeedVector vec; 651 652 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 653 if (vec == CEED_VECTOR_ACTIVE) continue; 654 } 655 // Get elem_size, eval_mode, size 656 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 657 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 658 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 659 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 660 // Basis action 661 switch (eval_mode) { 662 case CEED_EVAL_NONE: 663 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 664 break; 665 case CEED_EVAL_INTERP: 666 case CEED_EVAL_GRAD: 667 case CEED_EVAL_DIV: 668 case CEED_EVAL_CURL: 669 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 670 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, 
eval_mode, impl->point_coords_elem, impl->e_vecs[i], 671 impl->q_vecs_in[i])); 672 break; 673 case CEED_EVAL_WEIGHT: 674 break; // No action 675 } 676 } 677 return CEED_ERROR_SUCCESS; 678 } 679 680 //------------------------------------------------------------------------------ 681 // Apply and add to output AtPoints 682 //------------------------------------------------------------------------------ 683 static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 684 CeedInt max_num_points, *num_points, num_elem, elem_size, num_input_fields, num_output_fields, size; 685 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 686 CeedQFunctionField *qf_input_fields, *qf_output_fields; 687 CeedQFunction qf; 688 CeedOperatorField *op_input_fields, *op_output_fields; 689 CeedOperator_Hip *impl; 690 691 CeedCallBackend(CeedOperatorGetData(op, &impl)); 692 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 693 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 694 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 695 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 696 697 // Setup 698 CeedCallBackend(CeedOperatorSetupAtPoints_Hip(op)); 699 num_points = impl->num_points; 700 max_num_points = impl->max_num_points; 701 702 // Input Evecs and Restriction 703 CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data, impl, request)); 704 705 // Get point coordinates 706 if (!impl->point_coords_elem) { 707 CeedVector point_coords = NULL; 708 CeedElemRestriction rstr_points = NULL; 709 710 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 711 CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem)); 712 CeedCallBackend(CeedElemRestrictionApply(rstr_points, 
CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 713 } 714 715 // Input basis apply if needed 716 CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(num_elem, num_points, qf_input_fields, op_input_fields, num_input_fields, false, e_data, impl)); 717 718 // Output pointers, as necessary 719 for (CeedInt i = 0; i < num_output_fields; i++) { 720 CeedEvalMode eval_mode; 721 722 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 723 if (eval_mode == CEED_EVAL_NONE) { 724 // Set the output Q-Vector to use the E-Vector data directly. 725 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 726 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 727 } 728 } 729 730 // Q function 731 CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out)); 732 733 // Restore input arrays 734 CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl)); 735 736 // Output basis apply if needed 737 for (CeedInt i = 0; i < num_output_fields; i++) { 738 CeedEvalMode eval_mode; 739 CeedElemRestriction elem_rstr; 740 CeedBasis basis; 741 742 // Get elem_size, eval_mode, size 743 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 744 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 745 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 746 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 747 // Basis action 748 switch (eval_mode) { 749 case CEED_EVAL_NONE: 750 break; // No action 751 case CEED_EVAL_INTERP: 752 case CEED_EVAL_GRAD: 753 case CEED_EVAL_DIV: 754 case CEED_EVAL_CURL: 755 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 756 if (impl->apply_add_basis_out[i]) { 
757 CeedCallBackend(CeedBasisApplyAddAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, 758 impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); 759 } else { 760 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i], 761 impl->e_vecs[i + impl->num_inputs])); 762 } 763 break; 764 // LCOV_EXCL_START 765 case CEED_EVAL_WEIGHT: { 766 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 767 // LCOV_EXCL_STOP 768 } 769 } 770 } 771 772 // Output restriction 773 for (CeedInt i = 0; i < num_output_fields; i++) { 774 CeedEvalMode eval_mode; 775 CeedVector vec; 776 CeedElemRestriction elem_rstr; 777 778 // Restore evec 779 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 780 if (eval_mode == CEED_EVAL_NONE) { 781 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 782 } 783 if (impl->skip_rstr_out[i]) continue; 784 // Get output vector 785 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 786 // Restrict 787 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 788 // Active 789 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 790 791 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request)); 792 } 793 return CEED_ERROR_SUCCESS; 794 } 795 796 //------------------------------------------------------------------------------ 797 // Linear QFunction Assembly Core 798 //------------------------------------------------------------------------------ 799 static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 800 CeedRequest *request) { 801 Ceed ceed, ceed_parent; 802 CeedInt num_active_in, 
num_active_out, Q, num_elem, num_input_fields, num_output_fields, size; 803 CeedScalar *assembled_array, *e_data[2 * CEED_FIELD_MAX] = {NULL}; 804 CeedVector *active_inputs; 805 CeedQFunctionField *qf_input_fields, *qf_output_fields; 806 CeedQFunction qf; 807 CeedOperatorField *op_input_fields, *op_output_fields; 808 CeedOperator_Hip *impl; 809 810 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 811 CeedCallBackend(CeedOperatorGetFallbackParentCeed(op, &ceed_parent)); 812 CeedCallBackend(CeedOperatorGetData(op, &impl)); 813 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 814 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 815 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 816 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 817 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 818 active_inputs = impl->qf_active_in; 819 num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 820 821 // Setup 822 CeedCallBackend(CeedOperatorSetup_Hip(op)); 823 824 // Input Evecs and Restriction 825 CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 826 827 // Count number of active input fields 828 if (!num_active_in) { 829 for (CeedInt i = 0; i < num_input_fields; i++) { 830 CeedScalar *q_vec_array; 831 CeedVector vec; 832 833 // Get input vector 834 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 835 // Check if active input 836 if (vec == CEED_VECTOR_ACTIVE) { 837 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 838 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 839 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, &q_vec_array)); 840 CeedCallBackend(CeedRealloc(num_active_in + size, &active_inputs)); 841 for (CeedInt field = 0; field < size; field++) { 
842 CeedSize q_size = (CeedSize)Q * num_elem; 843 844 CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_inputs[num_active_in + field])); 845 CeedCallBackend( 846 CeedVectorSetArray(active_inputs[num_active_in + field], CEED_MEM_DEVICE, CEED_USE_POINTER, &q_vec_array[field * Q * num_elem])); 847 } 848 num_active_in += size; 849 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array)); 850 } 851 } 852 impl->num_active_in = num_active_in; 853 impl->qf_active_in = active_inputs; 854 } 855 856 // Count number of active output fields 857 if (!num_active_out) { 858 for (CeedInt i = 0; i < num_output_fields; i++) { 859 CeedVector vec; 860 861 // Get output vector 862 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 863 // Check if active output 864 if (vec == CEED_VECTOR_ACTIVE) { 865 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 866 num_active_out += size; 867 } 868 } 869 impl->num_active_out = num_active_out; 870 } 871 872 // Check sizes 873 CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 874 875 // Build objects if needed 876 if (build_objects) { 877 CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out; 878 CeedInt strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */ 879 880 // Create output restriction 881 CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out, 882 (CeedSize)num_active_in * (CeedSize)num_active_out * (CeedSize)num_elem * (CeedSize)Q, strides, 883 rstr)); 884 // Create assembled vector 885 CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled)); 886 } 887 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 888 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_DEVICE, &assembled_array)); 889 890 // Input basis apply 891 CeedCallBackend(CeedOperatorInputBasis_Hip(num_elem, qf_input_fields, op_input_fields, 
num_input_fields, true, e_data, impl)); 892 893 // Assemble QFunction 894 for (CeedInt in = 0; in < num_active_in; in++) { 895 // Set Inputs 896 CeedCallBackend(CeedVectorSetValue(active_inputs[in], 1.0)); 897 if (num_active_in > 1) { 898 CeedCallBackend(CeedVectorSetValue(active_inputs[(in + num_active_in - 1) % num_active_in], 0.0)); 899 } 900 // Set Outputs 901 for (CeedInt out = 0; out < num_output_fields; out++) { 902 CeedVector vec; 903 904 // Get output vector 905 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 906 // Check if active output 907 if (vec == CEED_VECTOR_ACTIVE) { 908 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, CEED_USE_POINTER, assembled_array)); 909 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size)); 910 assembled_array += size * Q * num_elem; // Advance the pointer by the size of the output 911 } 912 } 913 // Apply QFunction 914 CeedCallBackend(CeedQFunctionApply(qf, Q * num_elem, impl->q_vecs_in, impl->q_vecs_out)); 915 } 916 917 // Un-set output q-vecs to prevent accidental overwrite of Assembled 918 for (CeedInt out = 0; out < num_output_fields; out++) { 919 CeedVector vec; 920 921 // Get output vector 922 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 923 // Check if active output 924 if (vec == CEED_VECTOR_ACTIVE) { 925 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, NULL)); 926 } 927 } 928 929 // Restore input arrays 930 CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 931 932 // Restore output 933 CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 934 return CEED_ERROR_SUCCESS; 935 } 936 937 //------------------------------------------------------------------------------ 938 // Assemble Linear QFunction 939 //------------------------------------------------------------------------------ 940 static int 
//------------------------------------------------------------------------------
// Assemble Linear QFunction
//
// Entry point that has the core routine create `*assembled` and `*rstr`.
//------------------------------------------------------------------------------
static int CeedOperatorLinearAssembleQFunction_Hip(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) {
  const bool build_objects = true;

  return CeedOperatorLinearAssembleQFunctionCore_Hip(op, build_objects, assembled, rstr, request);
}

//------------------------------------------------------------------------------
// Update Assembled Linear QFunction
//
// Entry point that re-runs the core routine on an existing `assembled` vector without creating objects.
//------------------------------------------------------------------------------
static int CeedOperatorLinearAssembleQFunctionUpdate_Hip(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) {
  const bool build_objects = false;

  return CeedOperatorLinearAssembleQFunctionCore_Hip(op, build_objects, &assembled, &rstr, request);
}

//------------------------------------------------------------------------------
// Assemble Diagonal Setup
//
// Allocates impl->diag and uploads what the diagonal kernel reads from the device:
// an identity matrix (only when some active field uses CEED_EVAL_NONE), the interp and
// grad/div/curl matrices of the active input and output bases, and the per-component
// eval mode lists. Errors out if more than one distinct active basis is found on
// either side.
//------------------------------------------------------------------------------
static inline int CeedOperatorAssembleDiagonalSetup_Hip(CeedOperator op) {
  Ceed                ceed;
  CeedInt             num_input_fields, num_output_fields, num_modes_in = 0, num_modes_out = 0;
  CeedInt             q_comp, num_nodes, num_qpts;
  CeedEvalMode       *modes_in = NULL, *modes_out = NULL;
  CeedBasis           basis_in = NULL, basis_out = NULL;
  CeedQFunctionField *qf_fields;
  CeedQFunction       qf;
  CeedOperatorField  *op_fields;
  CeedOperator_Hip   *impl;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
  CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields));

  // Active inputs: record the basis and one eval mode entry per quadrature component
  CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL));
  CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL));
  for (CeedInt i = 0; i < num_input_fields; i++) {
    CeedVector   field_vec;
    CeedBasis    field_basis;
    CeedEvalMode mode;

    CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &field_vec));
    if (field_vec != CEED_VECTOR_ACTIVE) continue;

    CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &field_basis));
    CeedCheck(!basis_in || basis_in == field_basis, ceed, CEED_ERROR_BACKEND,
              "Backend does not implement operator diagonal assembly with multiple active bases");
    basis_in = field_basis;
    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &mode));
    CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, mode, &q_comp));
    if (mode == CEED_EVAL_WEIGHT) continue;

    // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly
    CeedCallBackend(CeedRealloc(num_modes_in + q_comp, &modes_in));
    for (CeedInt d = 0; d < q_comp; d++) modes_in[num_modes_in + d] = mode;
    num_modes_in += q_comp;
  }

  // Active outputs: same bookkeeping on the output side
  CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields));
  CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields));
  for (CeedInt i = 0; i < num_output_fields; i++) {
    CeedVector   field_vec;
    CeedBasis    field_basis;
    CeedEvalMode mode;

    CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &field_vec));
    if (field_vec != CEED_VECTOR_ACTIVE) continue;

    CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &field_basis));
    CeedCheck(!basis_out || basis_out == field_basis, ceed, CEED_ERROR_BACKEND,
              "Backend does not implement operator diagonal assembly with multiple active bases");
    basis_out = field_basis;
    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &mode));
    CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, mode, &q_comp));
    if (mode == CEED_EVAL_WEIGHT) continue;

    // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly
    CeedCallBackend(CeedRealloc(num_modes_out + q_comp, &modes_out));
    for (CeedInt d = 0; d < q_comp; d++) modes_out[num_modes_out + d] = mode;
    num_modes_out += q_comp;
  }

  // Operator data struct
  CeedCallBackend(CeedOperatorGetData(op, &impl));
  CeedCallBackend(CeedCalloc(1, &impl->diag));
  CeedOperatorDiag_Hip *diag = impl->diag;

  // Basis matrices: all sizes below are taken from the active input basis
  CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes));
  if (basis_in == CEED_BASIS_NONE) {
    num_qpts = num_nodes;
  } else {
    CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts));
  }
  const CeedInt interp_bytes     = num_nodes * num_qpts * sizeof(CeedScalar);
  const CeedInt eval_modes_bytes = sizeof(CeedEvalMode);

  // CEED_EVAL_NONE: upload an identity matrix to stand in for the basis action
  bool needs_identity = false;

  for (CeedInt i = 0; i < num_modes_in; i++) {
    if (modes_in[i] == CEED_EVAL_NONE) needs_identity = true;
  }
  for (CeedInt i = 0; i < num_modes_out; i++) {
    if (modes_out[i] == CEED_EVAL_NONE) needs_identity = true;
  }
  if (needs_identity) {
    const CeedInt diag_len   = num_nodes < num_qpts ? num_nodes : num_qpts;
    CeedScalar   *h_identity = NULL;

    CeedCallBackend(CeedCalloc(num_nodes * num_qpts, &h_identity));
    for (CeedInt k = 0; k < diag_len; k++) h_identity[k * num_nodes + k] = 1.0;
    CeedCallHip(ceed, hipMalloc((void **)&diag->d_identity, interp_bytes));
    CeedCallHip(ceed, hipMemcpy(diag->d_identity, h_identity, interp_bytes, hipMemcpyHostToDevice));
    CeedCallBackend(CeedFree(&h_identity));
  }

  // CEED_EVAL_INTERP, CEED_EVAL_GRAD, CEED_EVAL_DIV, and CEED_EVAL_CURL
  // Pass 0 handles the output basis, pass 1 the input basis
  for (CeedInt in = 0; in < 2; in++) {
    CeedFESpace       fespace;
    CeedEvalMode      second_mode;
    CeedInt           q_comp_interp, q_comp_second;
    const CeedScalar *h_interp, *h_second = NULL;
    CeedScalar       *d_interp, *d_second;
    CeedBasis         basis = in ? basis_in : basis_out;

    CeedCallBackend(CeedBasisGetFESpace(basis, &fespace));

    // Each handled FE space uploads interp plus exactly one further matrix
    switch (fespace) {
      case CEED_FE_SPACE_H1:
        second_mode = CEED_EVAL_GRAD;
        break;
      case CEED_FE_SPACE_HDIV:
        second_mode = CEED_EVAL_DIV;
        break;
      case CEED_FE_SPACE_HCURL:
        second_mode = CEED_EVAL_CURL;
        break;
      default:
        continue;  // any other space: nothing is uploaded for this basis
    }

    CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
    CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, second_mode, &q_comp_second));

    CeedCallBackend(CeedBasisGetInterp(basis, &h_interp));
    CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp));
    CeedCallHip(ceed, hipMemcpy(d_interp, h_interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice));

    if (second_mode == CEED_EVAL_GRAD) {
      CeedCallBackend(CeedBasisGetGrad(basis, &h_second));
    } else if (second_mode == CEED_EVAL_DIV) {
      CeedCallBackend(CeedBasisGetDiv(basis, &h_second));
    } else {
      CeedCallBackend(CeedBasisGetCurl(basis, &h_second));
    }
    CeedCallHip(ceed, hipMalloc((void **)&d_second, interp_bytes * q_comp_second));
    CeedCallHip(ceed, hipMemcpy(d_second, h_second, interp_bytes * q_comp_second, hipMemcpyHostToDevice));

    // Publish the device pointers only after both uploads succeeded
    if (in) diag->d_interp_in = d_interp;
    else diag->d_interp_out = d_interp;
    if (second_mode == CEED_EVAL_GRAD) {
      if (in) diag->d_grad_in = d_second;
      else diag->d_grad_out = d_second;
    } else if (second_mode == CEED_EVAL_DIV) {
      if (in) diag->d_div_in = d_second;
      else diag->d_div_out = d_second;
    } else {
      if (in) diag->d_curl_in = d_second;
      else diag->d_curl_out = d_second;
    }
  }

  // Arrays of eval_modes
  CeedCallHip(ceed, hipMalloc((void **)&diag->d_eval_modes_in, num_modes_in * eval_modes_bytes));
  CeedCallHip(ceed, hipMemcpy(diag->d_eval_modes_in, modes_in, num_modes_in * eval_modes_bytes, hipMemcpyHostToDevice));
  CeedCallHip(ceed, hipMalloc((void **)&diag->d_eval_modes_out, num_modes_out * eval_modes_bytes));
  CeedCallHip(ceed, hipMemcpy(diag->d_eval_modes_out, modes_out, num_modes_out * eval_modes_bytes, hipMemcpyHostToDevice));
  CeedCallBackend(CeedFree(&modes_in));
  CeedCallBackend(CeedFree(&modes_out));
  return CEED_ERROR_SUCCESS;
}
//------------------------------------------------------------------------------ 1136 static inline int CeedOperatorAssembleDiagonalSetupCompile_Hip(CeedOperator op, CeedInt use_ceedsize_idx, const bool is_point_block) { 1137 Ceed ceed; 1138 char *diagonal_kernel_source; 1139 const char *diagonal_kernel_path; 1140 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 1141 CeedInt num_comp, q_comp, num_nodes, num_qpts; 1142 CeedBasis basis_in = NULL, basis_out = NULL; 1143 CeedQFunctionField *qf_fields; 1144 CeedQFunction qf; 1145 CeedOperatorField *op_fields; 1146 CeedOperator_Hip *impl; 1147 1148 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1149 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1150 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 1151 1152 // Determine active input basis 1153 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 1154 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 1155 for (CeedInt i = 0; i < num_input_fields; i++) { 1156 CeedVector vec; 1157 1158 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 1159 if (vec == CEED_VECTOR_ACTIVE) { 1160 CeedEvalMode eval_mode; 1161 1162 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis_in)); 1163 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1164 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 1165 if (eval_mode != CEED_EVAL_WEIGHT) { 1166 num_eval_modes_in += q_comp; 1167 } 1168 } 1169 } 1170 1171 // Determine active output basis 1172 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 1173 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 1174 for (CeedInt i = 0; i < num_output_fields; i++) { 1175 CeedVector vec; 1176 1177 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 1178 if (vec == CEED_VECTOR_ACTIVE) { 1179 
CeedEvalMode eval_mode; 1180 1181 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis_out)); 1182 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1183 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 1184 if (eval_mode != CEED_EVAL_WEIGHT) { 1185 num_eval_modes_out += q_comp; 1186 } 1187 } 1188 } 1189 1190 // Operator data struct 1191 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1192 CeedOperatorDiag_Hip *diag = impl->diag; 1193 1194 // Assemble kernel 1195 hipModule_t *module = is_point_block ? &diag->module_point_block : &diag->module; 1196 CeedInt elems_per_block = 1; 1197 CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes)); 1198 CeedCallBackend(CeedBasisGetNumComponents(basis_in, &num_comp)); 1199 if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes; 1200 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts)); 1201 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h", &diagonal_kernel_path)); 1202 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Kernel Source -----\n"); 1203 CeedCallBackend(CeedLoadSourceToBuffer(ceed, diagonal_kernel_path, &diagonal_kernel_source)); 1204 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Source Complete! -----\n"); 1205 CeedCallHip(ceed, CeedCompile_Hip(ceed, diagonal_kernel_source, module, 8, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", 1206 num_eval_modes_out, "NUM_COMP", num_comp, "NUM_NODES", num_nodes, "NUM_QPTS", num_qpts, "USE_CEEDSIZE", 1207 use_ceedsize_idx, "USE_POINT_BLOCK", is_point_block ? 1 : 0, "BLOCK_SIZE", num_nodes * elems_per_block)); 1208 CeedCallHip(ceed, CeedGetKernel_Hip(ceed, *module, "LinearDiagonal", is_point_block ? 
&diag->LinearPointBlock : &diag->LinearDiagonal)); 1209 CeedCallBackend(CeedFree(&diagonal_kernel_path)); 1210 CeedCallBackend(CeedFree(&diagonal_kernel_source)); 1211 return CEED_ERROR_SUCCESS; 1212 } 1213 1214 //------------------------------------------------------------------------------ 1215 // Assemble Diagonal Core 1216 //------------------------------------------------------------------------------ 1217 static inline int CeedOperatorAssembleDiagonalCore_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request, const bool is_point_block) { 1218 Ceed ceed; 1219 CeedInt num_elem, num_nodes; 1220 CeedScalar *elem_diag_array; 1221 const CeedScalar *assembled_qf_array; 1222 CeedVector assembled_qf = NULL, elem_diag; 1223 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out, diag_rstr; 1224 CeedOperator_Hip *impl; 1225 1226 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1227 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1228 1229 // Assemble QFunction 1230 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, request)); 1231 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr)); 1232 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); 1233 1234 // Setup 1235 if (!impl->diag) CeedCallBackend(CeedOperatorAssembleDiagonalSetup_Hip(op)); 1236 CeedOperatorDiag_Hip *diag = impl->diag; 1237 1238 assert(diag != NULL); 1239 1240 // Assemble kernel if needed 1241 if ((!is_point_block && !diag->LinearDiagonal) || (is_point_block && !diag->LinearPointBlock)) { 1242 CeedSize assembled_length, assembled_qf_length; 1243 CeedInt use_ceedsize_idx = 0; 1244 CeedCallBackend(CeedVectorGetLength(assembled, &assembled_length)); 1245 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length)); 1246 if ((assembled_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1; 1247 1248 CeedCallBackend(CeedOperatorAssembleDiagonalSetupCompile_Hip(op, 
use_ceedsize_idx, is_point_block)); 1249 } 1250 1251 // Restriction and diagonal vector 1252 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out)); 1253 CeedCheck(rstr_in == rstr_out, ceed, CEED_ERROR_BACKEND, 1254 "Cannot assemble operator diagonal with different input and output active element restrictions"); 1255 if (!is_point_block && !diag->diag_rstr) { 1256 CeedCallBackend(CeedElemRestrictionCreateUnsignedCopy(rstr_out, &diag->diag_rstr)); 1257 CeedCallBackend(CeedElemRestrictionCreateVector(diag->diag_rstr, NULL, &diag->elem_diag)); 1258 } else if (is_point_block && !diag->point_block_diag_rstr) { 1259 CeedCallBackend(CeedOperatorCreateActivePointBlockRestriction(rstr_out, &diag->point_block_diag_rstr)); 1260 CeedCallBackend(CeedElemRestrictionCreateVector(diag->point_block_diag_rstr, NULL, &diag->point_block_elem_diag)); 1261 } 1262 diag_rstr = is_point_block ? diag->point_block_diag_rstr : diag->diag_rstr; 1263 elem_diag = is_point_block ? diag->point_block_elem_diag : diag->elem_diag; 1264 CeedCallBackend(CeedVectorSetValue(elem_diag, 0.0)); 1265 1266 // Only assemble diagonal if the basis has nodes, otherwise inputs are null pointers 1267 CeedCallBackend(CeedElemRestrictionGetElementSize(diag_rstr, &num_nodes)); 1268 if (num_nodes > 0) { 1269 // Assemble element operator diagonals 1270 CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array)); 1271 CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem)); 1272 1273 // Compute the diagonal of B^T D B 1274 CeedInt elems_per_block = 1; 1275 CeedInt grid = CeedDivUpInt(num_elem, elems_per_block); 1276 void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_div_in, 1277 &diag->d_curl_in, &diag->d_interp_out, &diag->d_grad_out, &diag->d_div_out, &diag->d_curl_out, 1278 &diag->d_eval_modes_in, &diag->d_eval_modes_out, &assembled_qf_array, &elem_diag_array}; 1279 1280 if (is_point_block) { 1281 
CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->LinearPointBlock, grid, num_nodes, 1, elems_per_block, args)); 1282 } else { 1283 CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->LinearDiagonal, grid, num_nodes, 1, elems_per_block, args)); 1284 } 1285 1286 // Restore arrays 1287 CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array)); 1288 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); 1289 } 1290 1291 // Assemble local operator diagonal 1292 CeedCallBackend(CeedElemRestrictionApply(diag_rstr, CEED_TRANSPOSE, elem_diag, assembled, request)); 1293 1294 // Cleanup 1295 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1296 return CEED_ERROR_SUCCESS; 1297 } 1298 1299 //------------------------------------------------------------------------------ 1300 // Assemble Linear Diagonal 1301 //------------------------------------------------------------------------------ 1302 static int CeedOperatorLinearAssembleAddDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1303 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, false)); 1304 return CEED_ERROR_SUCCESS; 1305 } 1306 1307 //------------------------------------------------------------------------------ 1308 // Assemble Linear Point Block Diagonal 1309 //------------------------------------------------------------------------------ 1310 static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1311 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, true)); 1312 return CEED_ERROR_SUCCESS; 1313 } 1314 1315 //------------------------------------------------------------------------------ 1316 // Single Operator Assembly Setup 1317 //------------------------------------------------------------------------------ 1318 static int CeedSingleOperatorAssembleSetup_Hip(CeedOperator op, CeedInt use_ceedsize_idx) { 1319 Ceed ceed; 1320 char 
*assembly_kernel_source; 1321 const char *assembly_kernel_path; 1322 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 1323 CeedInt elem_size_in, num_qpts_in = 0, num_comp_in, elem_size_out, num_qpts_out, num_comp_out, q_comp; 1324 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL; 1325 CeedElemRestriction rstr_in = NULL, rstr_out = NULL; 1326 CeedBasis basis_in = NULL, basis_out = NULL; 1327 CeedQFunctionField *qf_fields; 1328 CeedQFunction qf; 1329 CeedOperatorField *input_fields, *output_fields; 1330 CeedOperator_Hip *impl; 1331 1332 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1333 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1334 1335 // Get intput and output fields 1336 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields)); 1337 1338 // Determine active input basis eval mode 1339 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1340 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 1341 for (CeedInt i = 0; i < num_input_fields; i++) { 1342 CeedVector vec; 1343 1344 CeedCallBackend(CeedOperatorFieldGetVector(input_fields[i], &vec)); 1345 if (vec == CEED_VECTOR_ACTIVE) { 1346 CeedBasis basis; 1347 CeedEvalMode eval_mode; 1348 1349 CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis)); 1350 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases"); 1351 basis_in = basis; 1352 CeedCallBackend(CeedOperatorFieldGetElemRestriction(input_fields[i], &rstr_in)); 1353 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in)); 1354 if (basis_in == CEED_BASIS_NONE) num_qpts_in = elem_size_in; 1355 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts_in)); 1356 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1357 
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 1358 if (eval_mode != CEED_EVAL_WEIGHT) { 1359 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly 1360 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in)); 1361 for (CeedInt d = 0; d < q_comp; d++) { 1362 eval_modes_in[num_eval_modes_in + d] = eval_mode; 1363 } 1364 num_eval_modes_in += q_comp; 1365 } 1366 } 1367 } 1368 1369 // Determine active output basis; basis_out and rstr_out only used if same as input, TODO 1370 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 1371 for (CeedInt i = 0; i < num_output_fields; i++) { 1372 CeedVector vec; 1373 1374 CeedCallBackend(CeedOperatorFieldGetVector(output_fields[i], &vec)); 1375 if (vec == CEED_VECTOR_ACTIVE) { 1376 CeedBasis basis; 1377 CeedEvalMode eval_mode; 1378 1379 CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis)); 1380 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND, 1381 "Backend does not implement operator assembly with multiple active bases"); 1382 basis_out = basis; 1383 CeedCallBackend(CeedOperatorFieldGetElemRestriction(output_fields[i], &rstr_out)); 1384 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out)); 1385 if (basis_out == CEED_BASIS_NONE) num_qpts_out = elem_size_out; 1386 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_out, &num_qpts_out)); 1387 CeedCheck(num_qpts_in == num_qpts_out, ceed, CEED_ERROR_UNSUPPORTED, 1388 "Active input and output bases must have the same number of quadrature points"); 1389 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1390 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 1391 if (eval_mode != CEED_EVAL_WEIGHT) { 1392 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly 1393 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out)); 1394 for (CeedInt 
d = 0; d < q_comp; d++) { 1395 eval_modes_out[num_eval_modes_out + d] = eval_mode; 1396 } 1397 num_eval_modes_out += q_comp; 1398 } 1399 } 1400 } 1401 CeedCheck(num_eval_modes_in > 0 && num_eval_modes_out > 0, ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs"); 1402 1403 CeedCallBackend(CeedCalloc(1, &impl->asmb)); 1404 CeedOperatorAssemble_Hip *asmb = impl->asmb; 1405 asmb->elems_per_block = 1; 1406 asmb->block_size_x = elem_size_in; 1407 asmb->block_size_y = elem_size_out; 1408 1409 bool fallback = asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block > 1024; 1410 1411 if (fallback) { 1412 // Use fallback kernel with 1D threadblock 1413 asmb->block_size_y = 1; 1414 } 1415 1416 // Compile kernels 1417 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &num_comp_in)); 1418 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_out, &num_comp_out)); 1419 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-operator-assemble.h", &assembly_kernel_path)); 1420 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Kernel Source -----\n"); 1421 CeedCallBackend(CeedLoadSourceToBuffer(ceed, assembly_kernel_path, &assembly_kernel_source)); 1422 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Source Complete! 
-----\n"); 1423 CeedCallBackend(CeedCompile_Hip(ceed, assembly_kernel_source, &asmb->module, 10, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", 1424 num_eval_modes_out, "NUM_COMP_IN", num_comp_in, "NUM_COMP_OUT", num_comp_out, "NUM_NODES_IN", elem_size_in, 1425 "NUM_NODES_OUT", elem_size_out, "NUM_QPTS", num_qpts_in, "BLOCK_SIZE", 1426 asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block, "BLOCK_SIZE_Y", asmb->block_size_y, "USE_CEEDSIZE", 1427 use_ceedsize_idx)); 1428 CeedCallBackend(CeedGetKernel_Hip(ceed, asmb->module, "LinearAssemble", &asmb->LinearAssemble)); 1429 CeedCallBackend(CeedFree(&assembly_kernel_path)); 1430 CeedCallBackend(CeedFree(&assembly_kernel_source)); 1431 1432 // Load into B_in, in order that they will be used in eval_modes_in 1433 { 1434 const CeedInt in_bytes = elem_size_in * num_qpts_in * num_eval_modes_in * sizeof(CeedScalar); 1435 CeedInt d_in = 0; 1436 CeedEvalMode eval_modes_in_prev = CEED_EVAL_NONE; 1437 bool has_eval_none = false; 1438 CeedScalar *identity = NULL; 1439 1440 for (CeedInt i = 0; i < num_eval_modes_in; i++) { 1441 has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE); 1442 } 1443 if (has_eval_none) { 1444 CeedCallBackend(CeedCalloc(elem_size_in * num_qpts_in, &identity)); 1445 for (CeedInt i = 0; i < (elem_size_in < num_qpts_in ? 
elem_size_in : num_qpts_in); i++) identity[i * elem_size_in + i] = 1.0; 1446 } 1447 1448 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_in, in_bytes)); 1449 for (CeedInt i = 0; i < num_eval_modes_in; i++) { 1450 const CeedScalar *h_B_in; 1451 1452 CeedCallBackend(CeedOperatorGetBasisPointer(basis_in, eval_modes_in[i], identity, &h_B_in)); 1453 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_modes_in[i], &q_comp)); 1454 if (q_comp > 1) { 1455 if (i == 0 || eval_modes_in[i] != eval_modes_in_prev) d_in = 0; 1456 else h_B_in = &h_B_in[(++d_in) * elem_size_in * num_qpts_in]; 1457 } 1458 eval_modes_in_prev = eval_modes_in[i]; 1459 1460 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_in[i * elem_size_in * num_qpts_in], h_B_in, elem_size_in * num_qpts_in * sizeof(CeedScalar), 1461 hipMemcpyHostToDevice)); 1462 } 1463 1464 if (identity) { 1465 CeedCallBackend(CeedFree(&identity)); 1466 } 1467 } 1468 1469 // Load into B_out, in order that they will be used in eval_modes_out 1470 { 1471 const CeedInt out_bytes = elem_size_out * num_qpts_out * num_eval_modes_out * sizeof(CeedScalar); 1472 CeedInt d_out = 0; 1473 CeedEvalMode eval_modes_out_prev = CEED_EVAL_NONE; 1474 bool has_eval_none = false; 1475 CeedScalar *identity = NULL; 1476 1477 for (CeedInt i = 0; i < num_eval_modes_out; i++) { 1478 has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE); 1479 } 1480 if (has_eval_none) { 1481 CeedCallBackend(CeedCalloc(elem_size_out * num_qpts_out, &identity)); 1482 for (CeedInt i = 0; i < (elem_size_out < num_qpts_out ? 
elem_size_out : num_qpts_out); i++) identity[i * elem_size_out + i] = 1.0; 1483 } 1484 1485 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_out, out_bytes)); 1486 for (CeedInt i = 0; i < num_eval_modes_out; i++) { 1487 const CeedScalar *h_B_out; 1488 1489 CeedCallBackend(CeedOperatorGetBasisPointer(basis_out, eval_modes_out[i], identity, &h_B_out)); 1490 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_modes_out[i], &q_comp)); 1491 if (q_comp > 1) { 1492 if (i == 0 || eval_modes_out[i] != eval_modes_out_prev) d_out = 0; 1493 else h_B_out = &h_B_out[(++d_out) * elem_size_out * num_qpts_out]; 1494 } 1495 eval_modes_out_prev = eval_modes_out[i]; 1496 1497 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_out[i * elem_size_out * num_qpts_out], h_B_out, elem_size_out * num_qpts_out * sizeof(CeedScalar), 1498 hipMemcpyHostToDevice)); 1499 } 1500 1501 if (identity) { 1502 CeedCallBackend(CeedFree(&identity)); 1503 } 1504 } 1505 return CEED_ERROR_SUCCESS; 1506 } 1507 1508 //------------------------------------------------------------------------------ 1509 // Assemble matrix data for COO matrix of assembled operator. 1510 // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic. 1511 // 1512 // Note that this (and other assembly routines) currently assume only one active input restriction/basis per operator (could have multiple basis eval 1513 // modes). 
1514 // TODO: allow multiple active input restrictions/basis objects 1515 //------------------------------------------------------------------------------ 1516 static int CeedSingleOperatorAssemble_Hip(CeedOperator op, CeedInt offset, CeedVector values) { 1517 Ceed ceed; 1518 CeedSize values_length = 0, assembled_qf_length = 0; 1519 CeedInt use_ceedsize_idx = 0, num_elem_in, num_elem_out, elem_size_in, elem_size_out; 1520 CeedScalar *values_array; 1521 const CeedScalar *assembled_qf_array; 1522 CeedVector assembled_qf = NULL; 1523 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out; 1524 CeedRestrictionType rstr_type_in, rstr_type_out; 1525 const bool *orients_in = NULL, *orients_out = NULL; 1526 const CeedInt8 *curl_orients_in = NULL, *curl_orients_out = NULL; 1527 CeedOperator_Hip *impl; 1528 1529 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1530 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1531 1532 // Assemble QFunction 1533 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, CEED_REQUEST_IMMEDIATE)); 1534 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr)); 1535 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); 1536 1537 CeedCallBackend(CeedVectorGetLength(values, &values_length)); 1538 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length)); 1539 if ((values_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1; 1540 1541 // Setup 1542 if (!impl->asmb) CeedCallBackend(CeedSingleOperatorAssembleSetup_Hip(op, use_ceedsize_idx)); 1543 CeedOperatorAssemble_Hip *asmb = impl->asmb; 1544 1545 assert(asmb != NULL); 1546 1547 // Assemble element operator 1548 CeedCallBackend(CeedVectorGetArray(values, CEED_MEM_DEVICE, &values_array)); 1549 values_array += offset; 1550 1551 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out)); 1552 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, 
&num_elem_in)); 1553 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in)); 1554 1555 CeedCallBackend(CeedElemRestrictionGetType(rstr_in, &rstr_type_in)); 1556 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) { 1557 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_in, CEED_MEM_DEVICE, &orients_in)); 1558 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) { 1559 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_in, CEED_MEM_DEVICE, &curl_orients_in)); 1560 } 1561 1562 if (rstr_in != rstr_out) { 1563 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_out, &num_elem_out)); 1564 CeedCheck(num_elem_in == num_elem_out, ceed, CEED_ERROR_UNSUPPORTED, 1565 "Active input and output operator restrictions must have the same number of elements"); 1566 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out)); 1567 1568 CeedCallBackend(CeedElemRestrictionGetType(rstr_out, &rstr_type_out)); 1569 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) { 1570 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_out, CEED_MEM_DEVICE, &orients_out)); 1571 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) { 1572 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_out, CEED_MEM_DEVICE, &curl_orients_out)); 1573 } 1574 } else { 1575 elem_size_out = elem_size_in; 1576 orients_out = orients_in; 1577 curl_orients_out = curl_orients_in; 1578 } 1579 1580 // Compute B^T D B 1581 CeedInt shared_mem = 1582 ((curl_orients_in || curl_orients_out ? elem_size_in * elem_size_out : 0) + (curl_orients_in ? 
elem_size_in * asmb->block_size_y : 0)) * 1583 sizeof(CeedScalar); 1584 CeedInt grid = CeedDivUpInt(num_elem_in, asmb->elems_per_block); 1585 void *args[] = {(void *)&num_elem_in, &asmb->d_B_in, &asmb->d_B_out, &orients_in, &curl_orients_in, 1586 &orients_out, &curl_orients_out, &assembled_qf_array, &values_array}; 1587 1588 CeedCallBackend( 1589 CeedRunKernelDimShared_Hip(ceed, asmb->LinearAssemble, grid, asmb->block_size_x, asmb->block_size_y, asmb->elems_per_block, shared_mem, args)); 1590 1591 // Restore arrays 1592 CeedCallBackend(CeedVectorRestoreArray(values, &values_array)); 1593 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); 1594 1595 // Cleanup 1596 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1597 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) { 1598 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_in, &orients_in)); 1599 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) { 1600 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_in, &curl_orients_in)); 1601 } 1602 if (rstr_in != rstr_out) { 1603 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) { 1604 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_out, &orients_out)); 1605 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) { 1606 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_out, &curl_orients_out)); 1607 } 1608 } 1609 return CEED_ERROR_SUCCESS; 1610 } 1611 1612 //------------------------------------------------------------------------------ 1613 // Assemble Linear QFunction AtPoints 1614 //------------------------------------------------------------------------------ 1615 static int CeedOperatorLinearAssembleQFunctionAtPoints_Hip(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 1616 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "Backend does not implement CeedOperatorLinearAssembleQFunction"); 1617 } 1618 1619 
//------------------------------------------------------------------------------ 1620 // Assemble Linear Diagonal AtPoints 1621 //------------------------------------------------------------------------------ 1622 static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1623 CeedInt max_num_points, *num_points, num_elem, num_input_fields, num_output_fields; 1624 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 1625 CeedQFunctionField *qf_input_fields, *qf_output_fields; 1626 CeedQFunction qf; 1627 CeedOperatorField *op_input_fields, *op_output_fields; 1628 CeedOperator_Hip *impl; 1629 1630 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1631 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1632 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 1633 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 1634 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 1635 1636 // Setup 1637 CeedCallBackend(CeedOperatorSetupAtPoints_Hip(op)); 1638 num_points = impl->num_points; 1639 max_num_points = impl->max_num_points; 1640 1641 // Create separate output e-vecs 1642 if (impl->has_shared_e_vecs) { 1643 for (CeedInt i = 0; i < impl->num_outputs; i++) { 1644 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 1645 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[impl->num_inputs + i])); 1646 } 1647 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out, 1648 num_input_fields, num_output_fields, max_num_points, num_elem)); 1649 } 1650 impl->has_shared_e_vecs = false; 1651 1652 // Input Evecs and Restriction 1653 CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 1654 1655 // Get point coordinates 1656 if 
(!impl->point_coords_elem) { 1657 CeedVector point_coords = NULL; 1658 CeedElemRestriction rstr_points = NULL; 1659 1660 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 1661 CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem)); 1662 CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 1663 } 1664 1665 // Clear active input Qvecs 1666 for (CeedInt i = 0; i < num_input_fields; i++) { 1667 CeedVector vec; 1668 1669 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1670 if (vec != CEED_VECTOR_ACTIVE) continue; 1671 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 1672 } 1673 1674 // Input basis apply if needed 1675 CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(num_elem, num_points, qf_input_fields, op_input_fields, num_input_fields, true, e_data, impl)); 1676 1677 // Output pointers, as necessary 1678 for (CeedInt i = 0; i < num_output_fields; i++) { 1679 CeedEvalMode eval_mode; 1680 1681 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1682 if (eval_mode == CEED_EVAL_NONE) { 1683 // Set the output Q-Vector to use the E-Vector data directly. 
1684 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 1685 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 1686 } 1687 } 1688 1689 // Loop over active fields 1690 for (CeedInt i = 0; i < num_input_fields; i++) { 1691 bool is_active_at_points = true; 1692 CeedInt elem_size = 1, num_comp_active = 1, e_vec_size = 0; 1693 CeedRestrictionType rstr_type; 1694 CeedVector vec; 1695 CeedElemRestriction elem_rstr; 1696 1697 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1698 // -- Skip non-active input 1699 if (vec != CEED_VECTOR_ACTIVE) continue; 1700 1701 // -- Get active restriction type 1702 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 1703 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 1704 is_active_at_points = rstr_type == CEED_RESTRICTION_POINTS; 1705 if (!is_active_at_points) CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 1706 else elem_size = max_num_points; 1707 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp_active)); 1708 1709 e_vec_size = elem_size * num_comp_active; 1710 for (CeedInt s = 0; s < e_vec_size; s++) { 1711 bool is_active_input = false; 1712 CeedEvalMode eval_mode; 1713 CeedVector vec; 1714 CeedBasis basis; 1715 1716 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1717 // Skip non-active input 1718 is_active_input = vec == CEED_VECTOR_ACTIVE; 1719 if (!is_active_input) continue; 1720 1721 // Update unit vector 1722 if (s == 0) CeedCallBackend(CeedVectorSetValue(impl->e_vecs[i], 0.0)); 1723 else CeedCallBackend(CeedVectorSetValueStrided(impl->e_vecs[i], s - 1, e_vec_size, 0.0)); 1724 CeedCallBackend(CeedVectorSetValueStrided(impl->e_vecs[i], s, e_vec_size, 1.0)); 1725 1726 // Basis action 1727 
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 1728 switch (eval_mode) { 1729 case CEED_EVAL_NONE: 1730 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 1731 break; 1732 case CEED_EVAL_INTERP: 1733 case CEED_EVAL_GRAD: 1734 case CEED_EVAL_DIV: 1735 case CEED_EVAL_CURL: 1736 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 1737 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, impl->e_vecs[i], 1738 impl->q_vecs_in[i])); 1739 break; 1740 case CEED_EVAL_WEIGHT: 1741 break; // No action 1742 } 1743 1744 // Q function 1745 CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out)); 1746 1747 // Output basis apply if needed 1748 for (CeedInt j = 0; j < num_output_fields; j++) { 1749 bool is_active_output = false; 1750 CeedInt elem_size = 0; 1751 CeedRestrictionType rstr_type; 1752 CeedEvalMode eval_mode; 1753 CeedVector vec; 1754 CeedElemRestriction elem_rstr; 1755 CeedBasis basis; 1756 1757 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec)); 1758 // ---- Skip non-active output 1759 is_active_output = vec == CEED_VECTOR_ACTIVE; 1760 if (!is_active_output) continue; 1761 1762 // ---- Check if elem size matches 1763 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr)); 1764 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 1765 if (is_active_at_points && rstr_type != CEED_RESTRICTION_POINTS) continue; 1766 if (rstr_type == CEED_RESTRICTION_POINTS) { 1767 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(elem_rstr, &elem_size)); 1768 } else { 1769 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 1770 } 1771 { 1772 CeedInt num_comp = 0; 1773 1774 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp)); 1775 if (e_vec_size != num_comp 
* elem_size) continue; 1776 } 1777 1778 // Basis action 1779 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode)); 1780 switch (eval_mode) { 1781 case CEED_EVAL_NONE: 1782 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[j + impl->num_inputs], &e_data[j + num_input_fields])); 1783 break; 1784 case CEED_EVAL_INTERP: 1785 case CEED_EVAL_GRAD: 1786 case CEED_EVAL_DIV: 1787 case CEED_EVAL_CURL: 1788 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[j], &basis)); 1789 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, 1790 impl->q_vecs_out[j], impl->e_vecs[j + impl->num_inputs])); 1791 break; 1792 // LCOV_EXCL_START 1793 case CEED_EVAL_WEIGHT: { 1794 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 1795 // LCOV_EXCL_STOP 1796 } 1797 } 1798 1799 // Mask output e-vec 1800 CeedCallBackend(CeedVectorPointwiseMult(impl->e_vecs[j + impl->num_inputs], impl->e_vecs[i], impl->e_vecs[j + impl->num_inputs])); 1801 1802 // Restrict 1803 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr)); 1804 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[j + impl->num_inputs], assembled, request)); 1805 1806 // Reset q_vec for 1807 if (eval_mode == CEED_EVAL_NONE) { 1808 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[j + impl->num_inputs], CEED_MEM_DEVICE, &e_data[j + num_input_fields])); 1809 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[j], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[j + num_input_fields])); 1810 } 1811 } 1812 1813 // Reset vec 1814 if (s == e_vec_size - 1 && i != num_input_fields - 1) CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 1815 } 1816 } 1817 1818 // Restore CEED_EVAL_NONE 1819 for (CeedInt i = 0; i < num_output_fields; i++) { 1820 CeedEvalMode eval_mode; 1821 1822 // Get eval_mode 1823 
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1824 1825 // Restore evec 1826 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1827 if (eval_mode == CEED_EVAL_NONE) { 1828 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 1829 } 1830 } 1831 1832 // Restore input arrays 1833 CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 1834 return CEED_ERROR_SUCCESS; 1835 } 1836 1837 //------------------------------------------------------------------------------ 1838 // Create operator 1839 //------------------------------------------------------------------------------ 1840 int CeedOperatorCreate_Hip(CeedOperator op) { 1841 Ceed ceed; 1842 CeedOperator_Hip *impl; 1843 1844 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1845 CeedCallBackend(CeedCalloc(1, &impl)); 1846 CeedCallBackend(CeedOperatorSetData(op, impl)); 1847 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Hip)); 1848 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Hip)); 1849 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonal_Hip)); 1850 CeedCallBackend( 1851 CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal", CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip)); 1852 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssemble_Hip)); 1853 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip)); 1854 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip)); 1855 return CEED_ERROR_SUCCESS; 1856 } 1857 1858 
//------------------------------------------------------------------------------ 1859 // Create operator AtPoints 1860 //------------------------------------------------------------------------------ 1861 int CeedOperatorCreateAtPoints_Hip(CeedOperator op) { 1862 Ceed ceed; 1863 CeedOperator_Hip *impl; 1864 1865 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1866 CeedCallBackend(CeedCalloc(1, &impl)); 1867 CeedCallBackend(CeedOperatorSetData(op, impl)); 1868 1869 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunctionAtPoints_Hip)); 1870 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip)); 1871 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAddAtPoints_Hip)); 1872 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip)); 1873 return CEED_ERROR_SUCCESS; 1874 } 1875 1876 //------------------------------------------------------------------------------ 1877