1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <assert.h> 12 #include <stdbool.h> 13 #include <string.h> 14 #include <hip/hip_runtime.h> 15 16 #include "../hip/ceed-hip-common.h" 17 #include "../hip/ceed-hip-compile.h" 18 #include "ceed-hip-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Destroy operator 22 //------------------------------------------------------------------------------ 23 static int CeedOperatorDestroy_Hip(CeedOperator op) { 24 CeedOperator_Hip *impl; 25 26 CeedCallBackend(CeedOperatorGetData(op, &impl)); 27 28 // Apply data 29 for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { 30 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[i])); 31 } 32 CeedCallBackend(CeedFree(&impl->e_vecs)); 33 34 for (CeedInt i = 0; i < impl->num_inputs; i++) { 35 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i])); 36 } 37 CeedCallBackend(CeedFree(&impl->q_vecs_in)); 38 39 for (CeedInt i = 0; i < impl->num_outputs; i++) { 40 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 41 } 42 CeedCallBackend(CeedFree(&impl->q_vecs_out)); 43 44 // QFunction assembly data 45 for (CeedInt i = 0; i < impl->num_active_in; i++) { 46 CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i])); 47 } 48 CeedCallBackend(CeedFree(&impl->qf_active_in)); 49 50 // Diag data 51 if (impl->diag) { 52 Ceed ceed; 53 54 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 55 CeedCallHip(ceed, hipModuleUnload(impl->diag->module)); 56 CeedCallHip(ceed, hipFree(impl->diag->d_eval_modes_in)); 57 CeedCallHip(ceed, hipFree(impl->diag->d_eval_modes_out)); 58 CeedCallHip(ceed, 
hipFree(impl->diag->d_identity)); 59 CeedCallHip(ceed, hipFree(impl->diag->d_interp_in)); 60 CeedCallHip(ceed, hipFree(impl->diag->d_interp_out)); 61 CeedCallHip(ceed, hipFree(impl->diag->d_grad_in)); 62 CeedCallHip(ceed, hipFree(impl->diag->d_grad_out)); 63 CeedCallHip(ceed, hipFree(impl->diag->d_div_in)); 64 CeedCallHip(ceed, hipFree(impl->diag->d_div_out)); 65 CeedCallHip(ceed, hipFree(impl->diag->d_curl_in)); 66 CeedCallHip(ceed, hipFree(impl->diag->d_curl_out)); 67 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr)); 68 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr)); 69 CeedCallBackend(CeedVectorDestroy(&impl->diag->elem_diag)); 70 CeedCallBackend(CeedVectorDestroy(&impl->diag->point_block_elem_diag)); 71 } 72 CeedCallBackend(CeedFree(&impl->diag)); 73 74 if (impl->asmb) { 75 Ceed ceed; 76 77 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 78 CeedCallHip(ceed, hipModuleUnload(impl->asmb->module)); 79 CeedCallHip(ceed, hipFree(impl->asmb->d_B_in)); 80 CeedCallHip(ceed, hipFree(impl->asmb->d_B_out)); 81 } 82 CeedCallBackend(CeedFree(&impl->asmb)); 83 84 CeedCallBackend(CeedFree(&impl)); 85 return CEED_ERROR_SUCCESS; 86 } 87 88 //------------------------------------------------------------------------------ 89 // Setup infields or outfields 90 //------------------------------------------------------------------------------ 91 static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool is_input, CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, 92 CeedInt num_fields, CeedInt Q, CeedInt num_elem) { 93 Ceed ceed; 94 CeedQFunctionField *qf_fields; 95 CeedOperatorField *op_fields; 96 97 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 98 if (is_input) { 99 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 100 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 101 } else { 102 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, 
&op_fields)); 103 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 104 } 105 106 // Loop over fields 107 for (CeedInt i = 0; i < num_fields; i++) { 108 bool is_strided = false, skip_restriction = false; 109 CeedSize q_size; 110 CeedInt size; 111 CeedEvalMode eval_mode; 112 CeedBasis basis; 113 114 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 115 if (eval_mode != CEED_EVAL_WEIGHT) { 116 CeedElemRestriction elem_rstr; 117 118 // Check whether this field can skip the element restriction: 119 // Must be passive input, with eval_mode NONE, and have a strided restriction with CEED_STRIDES_BACKEND. 120 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); 121 122 // First, check whether the field is input or output: 123 if (is_input) { 124 CeedVector vec; 125 126 // Check for passive input 127 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 128 if (vec != CEED_VECTOR_ACTIVE) { 129 // Check eval_mode 130 if (eval_mode == CEED_EVAL_NONE) { 131 // Check for strided restriction 132 CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided)); 133 if (is_strided) { 134 // Check if vector is already in preferred backend ordering 135 CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &skip_restriction)); 136 } 137 } 138 } 139 } 140 if (skip_restriction) { 141 // We do not need an E-Vector, but will use the input field vector's data directly in the operator application. 
142 e_vecs[i + start_e] = NULL; 143 } else { 144 CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs[i + start_e])); 145 } 146 } 147 148 switch (eval_mode) { 149 case CEED_EVAL_NONE: 150 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 151 q_size = (CeedSize)num_elem * Q * size; 152 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 153 break; 154 case CEED_EVAL_INTERP: 155 case CEED_EVAL_GRAD: 156 case CEED_EVAL_DIV: 157 case CEED_EVAL_CURL: 158 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 159 q_size = (CeedSize)num_elem * Q * size; 160 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 161 break; 162 case CEED_EVAL_WEIGHT: // Only on input fields 163 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 164 q_size = (CeedSize)num_elem * Q; 165 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 166 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); 167 break; 168 } 169 } 170 return CEED_ERROR_SUCCESS; 171 } 172 173 //------------------------------------------------------------------------------ 174 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction. 
175 //------------------------------------------------------------------------------ 176 static int CeedOperatorSetup_Hip(CeedOperator op) { 177 Ceed ceed; 178 bool is_setup_done; 179 CeedInt Q, num_elem, num_input_fields, num_output_fields; 180 CeedQFunctionField *qf_input_fields, *qf_output_fields; 181 CeedQFunction qf; 182 CeedOperatorField *op_input_fields, *op_output_fields; 183 CeedOperator_Hip *impl; 184 185 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 186 if (is_setup_done) return CEED_ERROR_SUCCESS; 187 188 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 189 CeedCallBackend(CeedOperatorGetData(op, &impl)); 190 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 191 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 192 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 193 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 194 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 195 196 // Allocate 197 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); 198 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 199 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 200 impl->num_inputs = num_input_fields; 201 impl->num_outputs = num_output_fields; 202 203 // Set up infield and outfield e_vecs and q_vecs 204 // Infields 205 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem)); 206 // Outfields 207 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, Q, num_elem)); 208 209 CeedCallBackend(CeedOperatorSetSetupDone(op)); 210 return CEED_ERROR_SUCCESS; 211 } 212 213 //------------------------------------------------------------------------------ 214 // Setup Operator Inputs 215 
//------------------------------------------------------------------------------ 216 static inline int CeedOperatorSetupInputs_Hip(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 217 CeedVector in_vec, const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], 218 CeedOperator_Hip *impl, CeedRequest *request) { 219 for (CeedInt i = 0; i < num_input_fields; i++) { 220 CeedEvalMode eval_mode; 221 CeedVector vec; 222 CeedElemRestriction elem_rstr; 223 224 // Get input vector 225 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 226 if (vec == CEED_VECTOR_ACTIVE) { 227 if (skip_active) continue; 228 else vec = in_vec; 229 } 230 231 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 232 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 233 } else { 234 // Get input vector 235 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 236 // Get input element restriction 237 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 238 if (vec == CEED_VECTOR_ACTIVE) vec = in_vec; 239 // Restrict, if necessary 240 if (!impl->e_vecs[i]) { 241 // No restriction for this field; read data directly from vec. 
242 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i])); 243 } else { 244 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, vec, impl->e_vecs[i], request)); 245 // Get evec 246 CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs[i], CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i])); 247 } 248 } 249 } 250 return CEED_ERROR_SUCCESS; 251 } 252 253 //------------------------------------------------------------------------------ 254 // Input Basis Action 255 //------------------------------------------------------------------------------ 256 static inline int CeedOperatorInputBasis_Hip(CeedInt num_elem, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 257 CeedInt num_input_fields, const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], 258 CeedOperator_Hip *impl) { 259 for (CeedInt i = 0; i < num_input_fields; i++) { 260 CeedInt elem_size, size; 261 CeedEvalMode eval_mode; 262 CeedElemRestriction elem_rstr; 263 CeedBasis basis; 264 265 // Skip active input 266 if (skip_active) { 267 CeedVector vec; 268 269 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 270 if (vec == CEED_VECTOR_ACTIVE) continue; 271 } 272 // Get elem_size, eval_mode, size 273 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 274 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 275 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 276 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 277 // Basis action 278 switch (eval_mode) { 279 case CEED_EVAL_NONE: 280 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 281 break; 282 case CEED_EVAL_INTERP: 283 case CEED_EVAL_GRAD: 284 case CEED_EVAL_DIV: 285 case CEED_EVAL_CURL: 286 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 287 
CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs[i], impl->q_vecs_in[i])); 288 break; 289 case CEED_EVAL_WEIGHT: 290 break; // No action 291 } 292 } 293 return CEED_ERROR_SUCCESS; 294 } 295 296 //------------------------------------------------------------------------------ 297 // Restore Input Vectors 298 //------------------------------------------------------------------------------ 299 static inline int CeedOperatorRestoreInputs_Hip(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 300 const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Hip *impl) { 301 for (CeedInt i = 0; i < num_input_fields; i++) { 302 CeedEvalMode eval_mode; 303 CeedVector vec; 304 305 // Skip active input 306 if (skip_active) { 307 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 308 if (vec == CEED_VECTOR_ACTIVE) continue; 309 } 310 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 311 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 312 } else { 313 if (!impl->e_vecs[i]) { // This was a skip_restriction case 314 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 315 CeedCallBackend(CeedVectorRestoreArrayRead(vec, (const CeedScalar **)&e_data[i])); 316 } else { 317 CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs[i], (const CeedScalar **)&e_data[i])); 318 } 319 } 320 } 321 return CEED_ERROR_SUCCESS; 322 } 323 324 //------------------------------------------------------------------------------ 325 // Apply and add to output 326 //------------------------------------------------------------------------------ 327 static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 328 CeedInt Q, num_elem, elem_size, num_input_fields, num_output_fields, size; 329 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 330 CeedQFunctionField *qf_input_fields, 
*qf_output_fields; 331 CeedQFunction qf; 332 CeedOperatorField *op_input_fields, *op_output_fields; 333 CeedOperator_Hip *impl; 334 335 CeedCallBackend(CeedOperatorGetData(op, &impl)); 336 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 337 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 338 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 339 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 340 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 341 342 // Setup 343 CeedCallBackend(CeedOperatorSetup_Hip(op)); 344 345 // Input Evecs and Restriction 346 CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data, impl, request)); 347 348 // Input basis apply if needed 349 CeedCallBackend(CeedOperatorInputBasis_Hip(num_elem, qf_input_fields, op_input_fields, num_input_fields, false, e_data, impl)); 350 351 // Output pointers, as necessary 352 for (CeedInt i = 0; i < num_output_fields; i++) { 353 CeedEvalMode eval_mode; 354 355 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 356 if (eval_mode == CEED_EVAL_NONE) { 357 // Set the output Q-Vector to use the E-Vector data directly. 
358 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 359 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 360 } 361 } 362 363 // Q function 364 CeedCallBackend(CeedQFunctionApply(qf, num_elem * Q, impl->q_vecs_in, impl->q_vecs_out)); 365 366 // Output basis apply if needed 367 for (CeedInt i = 0; i < num_output_fields; i++) { 368 CeedEvalMode eval_mode; 369 CeedElemRestriction elem_rstr; 370 CeedBasis basis; 371 372 // Get elem_size, eval_mode, size 373 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 374 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 375 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 376 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 377 // Basis action 378 switch (eval_mode) { 379 case CEED_EVAL_NONE: 380 break; // No action 381 case CEED_EVAL_INTERP: 382 case CEED_EVAL_GRAD: 383 case CEED_EVAL_DIV: 384 case CEED_EVAL_CURL: 385 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 386 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); 387 break; 388 // LCOV_EXCL_START 389 case CEED_EVAL_WEIGHT: { 390 Ceed ceed; 391 392 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 393 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 394 // LCOV_EXCL_STOP 395 } 396 } 397 } 398 399 // Output restriction 400 for (CeedInt i = 0; i < num_output_fields; i++) { 401 CeedEvalMode eval_mode; 402 CeedVector vec; 403 CeedElemRestriction elem_rstr; 404 405 // Restore evec 406 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 407 if (eval_mode == CEED_EVAL_NONE) { 408 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + 
impl->num_inputs], &e_data[i + num_input_fields])); 409 } 410 // Get output vector 411 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 412 // Restrict 413 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 414 // Active 415 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 416 417 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request)); 418 } 419 420 // Restore input arrays 421 CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl)); 422 return CEED_ERROR_SUCCESS; 423 } 424 425 //------------------------------------------------------------------------------ 426 // Linear QFunction Assembly Core 427 //------------------------------------------------------------------------------ 428 static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 429 CeedRequest *request) { 430 Ceed ceed, ceed_parent; 431 CeedInt num_active_in, num_active_out, Q, num_elem, num_input_fields, num_output_fields, size; 432 CeedScalar *assembled_array, *e_data[2 * CEED_FIELD_MAX] = {NULL}; 433 CeedVector *active_inputs; 434 CeedQFunctionField *qf_input_fields, *qf_output_fields; 435 CeedQFunction qf; 436 CeedOperatorField *op_input_fields, *op_output_fields; 437 CeedOperator_Hip *impl; 438 439 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 440 CeedCallBackend(CeedOperatorGetFallbackParentCeed(op, &ceed_parent)); 441 CeedCallBackend(CeedOperatorGetData(op, &impl)); 442 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 443 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 444 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 445 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 446 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, 
&num_output_fields, &op_output_fields)); 447 active_inputs = impl->qf_active_in; 448 num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 449 450 // Setup 451 CeedCallBackend(CeedOperatorSetup_Hip(op)); 452 453 // Input Evecs and Restriction 454 CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 455 456 // Count number of active input fields 457 if (!num_active_in) { 458 for (CeedInt i = 0; i < num_input_fields; i++) { 459 CeedScalar *q_vec_array; 460 CeedVector vec; 461 462 // Get input vector 463 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 464 // Check if active input 465 if (vec == CEED_VECTOR_ACTIVE) { 466 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 467 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 468 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, &q_vec_array)); 469 CeedCallBackend(CeedRealloc(num_active_in + size, &active_inputs)); 470 for (CeedInt field = 0; field < size; field++) { 471 CeedSize q_size = (CeedSize)Q * num_elem; 472 473 CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_inputs[num_active_in + field])); 474 CeedCallBackend( 475 CeedVectorSetArray(active_inputs[num_active_in + field], CEED_MEM_DEVICE, CEED_USE_POINTER, &q_vec_array[field * Q * num_elem])); 476 } 477 num_active_in += size; 478 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array)); 479 } 480 } 481 impl->num_active_in = num_active_in; 482 impl->qf_active_in = active_inputs; 483 } 484 485 // Count number of active output fields 486 if (!num_active_out) { 487 for (CeedInt i = 0; i < num_output_fields; i++) { 488 CeedVector vec; 489 490 // Get output vector 491 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 492 // Check if active output 493 if (vec == CEED_VECTOR_ACTIVE) { 494 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], 
&size)); 495 num_active_out += size; 496 } 497 } 498 impl->num_active_out = num_active_out; 499 } 500 501 // Check sizes 502 CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 503 504 // Build objects if needed 505 if (build_objects) { 506 CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out; 507 CeedInt strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */ 508 509 // Create output restriction 510 CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out, 511 num_active_in * num_active_out * num_elem * Q, strides, rstr)); 512 // Create assembled vector 513 CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled)); 514 } 515 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 516 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_DEVICE, &assembled_array)); 517 518 // Input basis apply 519 CeedCallBackend(CeedOperatorInputBasis_Hip(num_elem, qf_input_fields, op_input_fields, num_input_fields, true, e_data, impl)); 520 521 // Assemble QFunction 522 for (CeedInt in = 0; in < num_active_in; in++) { 523 // Set Inputs 524 CeedCallBackend(CeedVectorSetValue(active_inputs[in], 1.0)); 525 if (num_active_in > 1) { 526 CeedCallBackend(CeedVectorSetValue(active_inputs[(in + num_active_in - 1) % num_active_in], 0.0)); 527 } 528 // Set Outputs 529 for (CeedInt out = 0; out < num_output_fields; out++) { 530 CeedVector vec; 531 532 // Get output vector 533 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 534 // Check if active output 535 if (vec == CEED_VECTOR_ACTIVE) { 536 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, CEED_USE_POINTER, assembled_array)); 537 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size)); 538 assembled_array += size * Q * num_elem; // Advance the pointer by the size of the output 539 } 540 } 541 // Apply QFunction 542 
CeedCallBackend(CeedQFunctionApply(qf, Q * num_elem, impl->q_vecs_in, impl->q_vecs_out)); 543 } 544 545 // Un-set output q_vecs to prevent accidental overwrite of Assembled 546 for (CeedInt out = 0; out < num_output_fields; out++) { 547 CeedVector vec; 548 549 // Get output vector 550 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 551 // Check if active output 552 if (vec == CEED_VECTOR_ACTIVE) { 553 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, NULL)); 554 } 555 } 556 557 // Restore input arrays 558 CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 559 560 // Restore output 561 CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 562 return CEED_ERROR_SUCCESS; 563 } 564 565 //------------------------------------------------------------------------------ 566 // Assemble Linear QFunction 567 //------------------------------------------------------------------------------ 568 static int CeedOperatorLinearAssembleQFunction_Hip(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 569 return CeedOperatorLinearAssembleQFunctionCore_Hip(op, true, assembled, rstr, request); 570 } 571 572 //------------------------------------------------------------------------------ 573 // Update Assembled Linear QFunction 574 //------------------------------------------------------------------------------ 575 static int CeedOperatorLinearAssembleQFunctionUpdate_Hip(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 576 return CeedOperatorLinearAssembleQFunctionCore_Hip(op, false, &assembled, &rstr, request); 577 } 578 579 //------------------------------------------------------------------------------ 580 // Assemble Diagonal Setup 581 //------------------------------------------------------------------------------ 582 static inline int 
CeedOperatorAssembleDiagonalSetup_Hip(CeedOperator op, CeedInt use_ceedsize_idx) { 583 Ceed ceed; 584 char *diagonal_kernel_path, *diagonal_kernel_source; 585 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 586 CeedInt num_comp, q_comp, num_nodes, num_qpts; 587 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL; 588 CeedBasis basis_in = NULL, basis_out = NULL; 589 CeedQFunctionField *qf_fields; 590 CeedQFunction qf; 591 CeedOperatorField *op_fields; 592 CeedOperator_Hip *impl; 593 594 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 595 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 596 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 597 598 // Determine active input basis 599 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 600 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 601 for (CeedInt i = 0; i < num_input_fields; i++) { 602 CeedVector vec; 603 604 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 605 if (vec == CEED_VECTOR_ACTIVE) { 606 CeedBasis basis; 607 CeedEvalMode eval_mode; 608 609 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 610 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, 611 "Backend does not implement operator diagonal assembly with multiple active bases"); 612 basis_in = basis; 613 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 614 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 615 if (eval_mode != CEED_EVAL_WEIGHT) { 616 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly 617 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in)); 618 for (CeedInt d = 0; d < q_comp; d++) eval_modes_in[num_eval_modes_in + d] = eval_mode; 619 num_eval_modes_in += q_comp; 620 } 621 } 622 } 623 624 // Determine active output basis 625 
CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 626 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 627 for (CeedInt i = 0; i < num_output_fields; i++) { 628 CeedVector vec; 629 630 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 631 if (vec == CEED_VECTOR_ACTIVE) { 632 CeedBasis basis; 633 CeedEvalMode eval_mode; 634 635 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 636 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND, 637 "Backend does not implement operator diagonal assembly with multiple active bases"); 638 basis_out = basis; 639 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 640 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 641 if (eval_mode != CEED_EVAL_WEIGHT) { 642 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly 643 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out)); 644 for (CeedInt d = 0; d < q_comp; d++) eval_modes_out[num_eval_modes_out + d] = eval_mode; 645 num_eval_modes_out += q_comp; 646 } 647 } 648 } 649 650 // Operator data struct 651 CeedCallBackend(CeedOperatorGetData(op, &impl)); 652 CeedCallBackend(CeedCalloc(1, &impl->diag)); 653 CeedOperatorDiag_Hip *diag = impl->diag; 654 655 // Assemble kernel 656 CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes)); 657 CeedCallBackend(CeedBasisGetNumComponents(basis_in, &num_comp)); 658 if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes; 659 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts)); 660 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h", &diagonal_kernel_path)); 661 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Kernel Source -----\n"); 662 CeedCallBackend(CeedLoadSourceToBuffer(ceed, diagonal_kernel_path, &diagonal_kernel_source)); 663 CeedDebug256(ceed, 
CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Source Complete! -----\n"); 664 CeedCallHip(ceed, 665 CeedCompile_Hip(ceed, diagonal_kernel_source, &diag->module, 6, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", 666 num_eval_modes_out, "NUM_COMP", num_comp, "NUM_NODES", num_nodes, "NUM_QPTS", num_qpts, "CEED_SIZE", use_ceedsize_idx)); 667 CeedCallHip(ceed, CeedGetKernel_Hip(ceed, diag->module, "LinearDiagonal", &diag->LinearDiagonal)); 668 CeedCallHip(ceed, CeedGetKernel_Hip(ceed, diag->module, "LinearPointBlockDiagonal", &diag->LinearPointBlock)); 669 CeedCallBackend(CeedFree(&diagonal_kernel_path)); 670 CeedCallBackend(CeedFree(&diagonal_kernel_source)); 671 672 // Basis matrices 673 const CeedInt interp_bytes = num_nodes * num_qpts * sizeof(CeedScalar); 674 const CeedInt eval_modes_bytes = sizeof(CeedEvalMode); 675 bool has_eval_none = false; 676 677 // CEED_EVAL_NONE 678 for (CeedInt i = 0; i < num_eval_modes_in; i++) has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE); 679 for (CeedInt i = 0; i < num_eval_modes_out; i++) has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE); 680 if (has_eval_none) { 681 CeedScalar *identity = NULL; 682 683 CeedCallBackend(CeedCalloc(num_nodes * num_qpts, &identity)); 684 for (CeedInt i = 0; i < (num_nodes < num_qpts ? num_nodes : num_qpts); i++) identity[i * num_nodes + i] = 1.0; 685 CeedCallHip(ceed, hipMalloc((void **)&diag->d_identity, interp_bytes)); 686 CeedCallHip(ceed, hipMemcpy(diag->d_identity, identity, interp_bytes, hipMemcpyHostToDevice)); 687 CeedCallBackend(CeedFree(&identity)); 688 } 689 690 // CEED_EVAL_INTERP, CEED_EVAL_GRAD, CEED_EVAL_DIV, and CEED_EVAL_CURL 691 for (CeedInt in = 0; in < 2; in++) { 692 CeedFESpace fespace; 693 CeedBasis basis = in ? 
basis_in : basis_out; 694 695 CeedCallBackend(CeedBasisGetFESpace(basis, &fespace)); 696 switch (fespace) { 697 case CEED_FE_SPACE_H1: { 698 CeedInt q_comp_interp, q_comp_grad; 699 const CeedScalar *interp, *grad; 700 CeedScalar *d_interp, *d_grad; 701 702 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 703 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_GRAD, &q_comp_grad)); 704 705 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 706 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 707 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice)); 708 CeedCallBackend(CeedBasisGetGrad(basis, &grad)); 709 CeedCallHip(ceed, hipMalloc((void **)&d_grad, interp_bytes * q_comp_grad)); 710 CeedCallHip(ceed, hipMemcpy(d_grad, grad, interp_bytes * q_comp_grad, hipMemcpyHostToDevice)); 711 if (in) { 712 diag->d_interp_in = d_interp; 713 diag->d_grad_in = d_grad; 714 } else { 715 diag->d_interp_out = d_interp; 716 diag->d_grad_out = d_grad; 717 } 718 } break; 719 case CEED_FE_SPACE_HDIV: { 720 CeedInt q_comp_interp, q_comp_div; 721 const CeedScalar *interp, *div; 722 CeedScalar *d_interp, *d_div; 723 724 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 725 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_DIV, &q_comp_div)); 726 727 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 728 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 729 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice)); 730 CeedCallBackend(CeedBasisGetDiv(basis, &div)); 731 CeedCallHip(ceed, hipMalloc((void **)&d_div, interp_bytes * q_comp_div)); 732 CeedCallHip(ceed, hipMemcpy(d_div, div, interp_bytes * q_comp_div, hipMemcpyHostToDevice)); 733 if (in) { 734 diag->d_interp_in = d_interp; 735 diag->d_div_in = d_div; 736 } 
else { 737 diag->d_interp_out = d_interp; 738 diag->d_div_out = d_div; 739 } 740 } break; 741 case CEED_FE_SPACE_HCURL: { 742 CeedInt q_comp_interp, q_comp_curl; 743 const CeedScalar *interp, *curl; 744 CeedScalar *d_interp, *d_curl; 745 746 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 747 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_CURL, &q_comp_curl)); 748 749 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 750 CeedCallHip(ceed, hipMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 751 CeedCallHip(ceed, hipMemcpy(d_interp, interp, interp_bytes * q_comp_interp, hipMemcpyHostToDevice)); 752 CeedCallBackend(CeedBasisGetCurl(basis, &curl)); 753 CeedCallHip(ceed, hipMalloc((void **)&d_curl, interp_bytes * q_comp_curl)); 754 CeedCallHip(ceed, hipMemcpy(d_curl, curl, interp_bytes * q_comp_curl, hipMemcpyHostToDevice)); 755 if (in) { 756 diag->d_interp_in = d_interp; 757 diag->d_curl_in = d_curl; 758 } else { 759 diag->d_interp_out = d_interp; 760 diag->d_curl_out = d_curl; 761 } 762 } break; 763 } 764 } 765 766 // Arrays of eval_modes 767 CeedCallHip(ceed, hipMalloc((void **)&diag->d_eval_modes_in, num_eval_modes_in * eval_modes_bytes)); 768 CeedCallHip(ceed, hipMemcpy(diag->d_eval_modes_in, eval_modes_in, num_eval_modes_in * eval_modes_bytes, hipMemcpyHostToDevice)); 769 CeedCallHip(ceed, hipMalloc((void **)&diag->d_eval_modes_out, num_eval_modes_out * eval_modes_bytes)); 770 CeedCallHip(ceed, hipMemcpy(diag->d_eval_modes_out, eval_modes_out, num_eval_modes_out * eval_modes_bytes, hipMemcpyHostToDevice)); 771 CeedCallBackend(CeedFree(&eval_modes_in)); 772 CeedCallBackend(CeedFree(&eval_modes_out)); 773 return CEED_ERROR_SUCCESS; 774 } 775 776 //------------------------------------------------------------------------------ 777 // Assemble Diagonal Core 778 //------------------------------------------------------------------------------ 779 static inline int 
CeedOperatorAssembleDiagonalCore_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request, const bool is_point_block) {
  // Shared implementation of (point-block) diagonal assembly:
  //   1. build or update the assembled QFunction data,
  //   2. on first use, set up the diagonal kernels and basis matrices (cached in impl->diag),
  //   3. compute the per-element diagonals of B^T D B on the device into an element vector,
  //   4. apply the transpose of the diagonal restriction from that element vector into `assembled`.
  Ceed                ceed;
  CeedSize            assembled_length, assembled_qf_length;
  CeedInt             use_ceedsize_idx = 0, num_elem, num_nodes;
  CeedScalar         *elem_diag_array;
  const CeedScalar   *assembled_qf_array;
  CeedVector          assembled_qf = NULL, elem_diag;
  CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out, diag_rstr;
  CeedOperator_Hip   *impl;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &impl));

  // Assemble QFunction
  CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, request));
  CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr));

  // Compile the kernels with CeedSize indexing when either vector is too long for 32-bit offsets
  CeedCallBackend(CeedVectorGetLength(assembled, &assembled_length));
  CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length));
  if ((assembled_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1;

  // Setup (first call only; later calls reuse impl->diag)
  if (!impl->diag) CeedCallBackend(CeedOperatorAssembleDiagonalSetup_Hip(op, use_ceedsize_idx));
  CeedOperatorDiag_Hip *diag = impl->diag;

  assert(diag != NULL);

  // Restriction and diagonal vector
  CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out));
  CeedCheck(rstr_in == rstr_out, ceed, CEED_ERROR_BACKEND,
            "Cannot assemble operator diagonal with different input and output active element restrictions");
  if (!is_point_block && !diag->diag_rstr) {
    CeedCallBackend(CeedElemRestrictionCreateUnsignedCopy(rstr_out, &diag->diag_rstr));
    CeedCallBackend(CeedElemRestrictionCreateVector(diag->diag_rstr, NULL, &diag->elem_diag));
  } else if (is_point_block && !diag->point_block_diag_rstr) {
    CeedCallBackend(CeedOperatorCreateActivePointBlockRestriction(rstr_out, &diag->point_block_diag_rstr));
    CeedCallBackend(CeedElemRestrictionCreateVector(diag->point_block_diag_rstr, NULL, &diag->point_block_elem_diag));
  }
  diag_rstr = is_point_block ? diag->point_block_diag_rstr : diag->diag_rstr;
  elem_diag = is_point_block ? diag->point_block_elem_diag : diag->elem_diag;
  CeedCallBackend(CeedVectorSetValue(elem_diag, 0.0));

  // Only launch the kernel when there is work to do.
  // With zero nodes the basis inputs are null pointers, and a zero-sized grid is an invalid launch configuration.
  CeedCallBackend(CeedElemRestrictionGetElementSize(diag_rstr, &num_nodes));
  CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem));
  if (num_nodes > 0 && num_elem > 0) {
    // Device read access to the QFunction data is held only for the duration of the launch.
    // It is released on this path, so no read access can leak when the launch is skipped.
    CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array));
    // Assemble element operator diagonals
    CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array));

    // Compute the diagonal of B^T D B, one element per thread block and one thread per node
    CeedInt elems_per_block = 1;
    CeedInt grid            = CeedDivUpInt(num_elem, elems_per_block);
    void   *args[]          = {(void *)&num_elem, &diag->d_identity,       &diag->d_interp_in,  &diag->d_grad_in,    &diag->d_div_in,
                               &diag->d_curl_in,  &diag->d_interp_out,     &diag->d_grad_out,   &diag->d_div_out,    &diag->d_curl_out,
                               &diag->d_eval_modes_in, &diag->d_eval_modes_out, &assembled_qf_array, &elem_diag_array};

    if (is_point_block) {
      CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->LinearPointBlock, grid, num_nodes, 1, elems_per_block, args));
    } else {
      CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->LinearDiagonal, grid, num_nodes, 1, elems_per_block, args));
    }

    // Restore arrays
    CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array));
    CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array));
  }

  // Assemble local operator diagonal
  CeedCallBackend(CeedElemRestrictionApply(diag_rstr, CEED_TRANSPOSE, elem_diag, assembled, request));

  // Cleanup
  CeedCallBackend(CeedVectorDestroy(&assembled_qf));
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------ 856 // Assemble Linear Diagonal 857 //------------------------------------------------------------------------------ 858 static int CeedOperatorLinearAssembleAddDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 859 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, false)); 860 return CEED_ERROR_SUCCESS; 861 } 862 863 //------------------------------------------------------------------------------ 864 // Assemble Linear Point Block Diagonal 865 //------------------------------------------------------------------------------ 866 static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 867 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, true)); 868 return CEED_ERROR_SUCCESS; 869 } 870 871 //------------------------------------------------------------------------------ 872 // Single Operator Assembly Setup 873 //------------------------------------------------------------------------------ 874 static int CeedSingleOperatorAssembleSetup_Hip(CeedOperator op, CeedInt use_ceedsize_idx) { 875 Ceed ceed; 876 char *assembly_kernel_path, *assembly_kernel_source; 877 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 878 CeedInt elem_size_in, num_qpts_in, num_comp_in, elem_size_out, num_qpts_out, num_comp_out, q_comp; 879 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL; 880 CeedElemRestriction rstr_in = NULL, rstr_out = NULL; 881 CeedBasis basis_in = NULL, basis_out = NULL; 882 CeedQFunctionField *qf_fields; 883 CeedQFunction qf; 884 CeedOperatorField *input_fields, *output_fields; 885 CeedOperator_Hip *impl; 886 887 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 888 CeedCallBackend(CeedOperatorGetData(op, &impl)); 889 890 // Get intput and output fields 891 
CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields)); 892 893 // Determine active input basis eval mode 894 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 895 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 896 for (CeedInt i = 0; i < num_input_fields; i++) { 897 CeedVector vec; 898 899 CeedCallBackend(CeedOperatorFieldGetVector(input_fields[i], &vec)); 900 if (vec == CEED_VECTOR_ACTIVE) { 901 CeedBasis basis; 902 CeedEvalMode eval_mode; 903 904 CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis)); 905 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases"); 906 basis_in = basis; 907 CeedCallBackend(CeedOperatorFieldGetElemRestriction(input_fields[i], &rstr_in)); 908 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in)); 909 if (basis_in == CEED_BASIS_NONE) num_qpts_in = elem_size_in; 910 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts_in)); 911 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 912 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 913 if (eval_mode != CEED_EVAL_WEIGHT) { 914 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly 915 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in)); 916 for (CeedInt d = 0; d < q_comp; d++) { 917 eval_modes_in[num_eval_modes_in + d] = eval_mode; 918 } 919 num_eval_modes_in += q_comp; 920 } 921 } 922 } 923 924 // Determine active output basis; basis_out and rstr_out only used if same as input, TODO 925 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 926 for (CeedInt i = 0; i < num_output_fields; i++) { 927 CeedVector vec; 928 929 CeedCallBackend(CeedOperatorFieldGetVector(output_fields[i], &vec)); 930 if (vec == CEED_VECTOR_ACTIVE) { 931 CeedBasis basis; 932 
CeedEvalMode eval_mode; 933 934 CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis)); 935 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND, 936 "Backend does not implement operator assembly with multiple active bases"); 937 basis_out = basis; 938 CeedCallBackend(CeedOperatorFieldGetElemRestriction(output_fields[i], &rstr_out)); 939 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out)); 940 if (basis_out == CEED_BASIS_NONE) num_qpts_out = elem_size_out; 941 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_out, &num_qpts_out)); 942 CeedCheck(num_qpts_in == num_qpts_out, ceed, CEED_ERROR_UNSUPPORTED, 943 "Active input and output bases must have the same number of quadrature points"); 944 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 945 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 946 if (eval_mode != CEED_EVAL_WEIGHT) { 947 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly 948 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out)); 949 for (CeedInt d = 0; d < q_comp; d++) { 950 eval_modes_out[num_eval_modes_out + d] = eval_mode; 951 } 952 num_eval_modes_out += q_comp; 953 } 954 } 955 } 956 CeedCheck(num_eval_modes_in > 0 && num_eval_modes_out > 0, ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs"); 957 958 CeedCallBackend(CeedCalloc(1, &impl->asmb)); 959 CeedOperatorAssemble_Hip *asmb = impl->asmb; 960 asmb->elems_per_block = 1; 961 asmb->block_size_x = elem_size_in; 962 asmb->block_size_y = elem_size_out; 963 964 bool fallback = asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block > 1024; 965 966 if (fallback) { 967 // Use fallback kernel with 1D threadblock 968 asmb->block_size_y = 1; 969 } 970 971 // Compile kernels 972 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &num_comp_in)); 973 
CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_out, &num_comp_out)); 974 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-operator-assemble.h", &assembly_kernel_path)); 975 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Kernel Source -----\n"); 976 CeedCallBackend(CeedLoadSourceToBuffer(ceed, assembly_kernel_path, &assembly_kernel_source)); 977 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Source Complete! -----\n"); 978 CeedCallBackend(CeedCompile_Hip(ceed, assembly_kernel_source, &asmb->module, 10, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", 979 num_eval_modes_out, "NUM_COMP_IN", num_comp_in, "NUM_COMP_OUT", num_comp_out, "NUM_NODES_IN", elem_size_in, 980 "NUM_NODES_OUT", elem_size_out, "NUM_QPTS", num_qpts_in, "BLOCK_SIZE", 981 asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block, "BLOCK_SIZE_Y", asmb->block_size_y, "CEED_SIZE", 982 use_ceedsize_idx)); 983 CeedCallBackend(CeedGetKernel_Hip(ceed, asmb->module, "LinearAssemble", &asmb->LinearAssemble)); 984 CeedCallBackend(CeedFree(&assembly_kernel_path)); 985 CeedCallBackend(CeedFree(&assembly_kernel_source)); 986 987 // Load into B_in, in order that they will be used in eval_modes_in 988 { 989 const CeedInt in_bytes = elem_size_in * num_qpts_in * num_eval_modes_in * sizeof(CeedScalar); 990 CeedInt d_in = 0; 991 CeedEvalMode eval_modes_in_prev = CEED_EVAL_NONE; 992 bool has_eval_none = false; 993 CeedScalar *identity = NULL; 994 995 for (CeedInt i = 0; i < num_eval_modes_in; i++) { 996 has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE); 997 } 998 if (has_eval_none) { 999 CeedCallBackend(CeedCalloc(elem_size_in * num_qpts_in, &identity)); 1000 for (CeedInt i = 0; i < (elem_size_in < num_qpts_in ? 
elem_size_in : num_qpts_in); i++) identity[i * elem_size_in + i] = 1.0; 1001 } 1002 1003 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_in, in_bytes)); 1004 for (CeedInt i = 0; i < num_eval_modes_in; i++) { 1005 const CeedScalar *h_B_in; 1006 1007 CeedCallBackend(CeedOperatorGetBasisPointer(basis_in, eval_modes_in[i], identity, &h_B_in)); 1008 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_modes_in[i], &q_comp)); 1009 if (q_comp > 1) { 1010 if (i == 0 || eval_modes_in[i] != eval_modes_in_prev) d_in = 0; 1011 else h_B_in = &h_B_in[(++d_in) * elem_size_in * num_qpts_in]; 1012 } 1013 eval_modes_in_prev = eval_modes_in[i]; 1014 1015 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_in[i * elem_size_in * num_qpts_in], h_B_in, elem_size_in * num_qpts_in * sizeof(CeedScalar), 1016 hipMemcpyHostToDevice)); 1017 } 1018 1019 if (identity) { 1020 CeedCallBackend(CeedFree(&identity)); 1021 } 1022 } 1023 1024 // Load into B_out, in order that they will be used in eval_modes_out 1025 { 1026 const CeedInt out_bytes = elem_size_out * num_qpts_out * num_eval_modes_out * sizeof(CeedScalar); 1027 CeedInt d_out = 0; 1028 CeedEvalMode eval_modes_out_prev = CEED_EVAL_NONE; 1029 bool has_eval_none = false; 1030 CeedScalar *identity = NULL; 1031 1032 for (CeedInt i = 0; i < num_eval_modes_out; i++) { 1033 has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE); 1034 } 1035 if (has_eval_none) { 1036 CeedCallBackend(CeedCalloc(elem_size_out * num_qpts_out, &identity)); 1037 for (CeedInt i = 0; i < (elem_size_out < num_qpts_out ? 
elem_size_out : num_qpts_out); i++) identity[i * elem_size_out + i] = 1.0; 1038 } 1039 1040 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_out, out_bytes)); 1041 for (CeedInt i = 0; i < num_eval_modes_out; i++) { 1042 const CeedScalar *h_B_out; 1043 1044 CeedCallBackend(CeedOperatorGetBasisPointer(basis_out, eval_modes_out[i], identity, &h_B_out)); 1045 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_modes_out[i], &q_comp)); 1046 if (q_comp > 1) { 1047 if (i == 0 || eval_modes_out[i] != eval_modes_out_prev) d_out = 0; 1048 else h_B_out = &h_B_out[(++d_out) * elem_size_out * num_qpts_out]; 1049 } 1050 eval_modes_out_prev = eval_modes_out[i]; 1051 1052 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_out[i * elem_size_out * num_qpts_out], h_B_out, elem_size_out * num_qpts_out * sizeof(CeedScalar), 1053 hipMemcpyHostToDevice)); 1054 } 1055 1056 if (identity) { 1057 CeedCallBackend(CeedFree(&identity)); 1058 } 1059 } 1060 return CEED_ERROR_SUCCESS; 1061 } 1062 1063 //------------------------------------------------------------------------------ 1064 // Assemble matrix data for COO matrix of assembled operator. 1065 // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic. 1066 // 1067 // Note that this (and other assembly routines) currently assume only one active input restriction/basis per operator (could have multiple basis eval 1068 // modes). 
1069 // TODO: allow multiple active input restrictions/basis objects 1070 //------------------------------------------------------------------------------ 1071 static int CeedSingleOperatorAssemble_Hip(CeedOperator op, CeedInt offset, CeedVector values) { 1072 Ceed ceed; 1073 CeedSize values_length = 0, assembled_qf_length = 0; 1074 CeedInt use_ceedsize_idx = 0, num_elem_in, num_elem_out, elem_size_in, elem_size_out; 1075 CeedScalar *values_array; 1076 const CeedScalar *assembled_qf_array; 1077 CeedVector assembled_qf = NULL; 1078 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out; 1079 CeedRestrictionType rstr_type_in, rstr_type_out; 1080 const bool *orients_in = NULL, *orients_out = NULL; 1081 const CeedInt8 *curl_orients_in = NULL, *curl_orients_out = NULL; 1082 CeedOperator_Hip *impl; 1083 1084 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1085 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1086 1087 // Assemble QFunction 1088 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, CEED_REQUEST_IMMEDIATE)); 1089 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr)); 1090 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); 1091 1092 CeedCallBackend(CeedVectorGetLength(values, &values_length)); 1093 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length)); 1094 if ((values_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1; 1095 1096 // Setup 1097 if (!impl->asmb) CeedCallBackend(CeedSingleOperatorAssembleSetup_Hip(op, use_ceedsize_idx)); 1098 CeedOperatorAssemble_Hip *asmb = impl->asmb; 1099 1100 assert(asmb != NULL); 1101 1102 // Assemble element operator 1103 CeedCallBackend(CeedVectorGetArray(values, CEED_MEM_DEVICE, &values_array)); 1104 values_array += offset; 1105 1106 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out)); 1107 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, 
&num_elem_in)); 1108 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in)); 1109 1110 CeedCallBackend(CeedElemRestrictionGetType(rstr_in, &rstr_type_in)); 1111 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) { 1112 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_in, CEED_MEM_DEVICE, &orients_in)); 1113 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) { 1114 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_in, CEED_MEM_DEVICE, &curl_orients_in)); 1115 } 1116 1117 if (rstr_in != rstr_out) { 1118 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_out, &num_elem_out)); 1119 CeedCheck(num_elem_in == num_elem_out, ceed, CEED_ERROR_UNSUPPORTED, 1120 "Active input and output operator restrictions must have the same number of elements"); 1121 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out)); 1122 1123 CeedCallBackend(CeedElemRestrictionGetType(rstr_out, &rstr_type_out)); 1124 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) { 1125 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_out, CEED_MEM_DEVICE, &orients_out)); 1126 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) { 1127 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_out, CEED_MEM_DEVICE, &curl_orients_out)); 1128 } 1129 } else { 1130 elem_size_out = elem_size_in; 1131 orients_out = orients_in; 1132 curl_orients_out = curl_orients_in; 1133 } 1134 1135 // Compute B^T D B 1136 CeedInt shared_mem = 1137 ((curl_orients_in || curl_orients_out ? elem_size_in * elem_size_out : 0) + (curl_orients_in ? 
elem_size_in * asmb->block_size_y : 0)) * 1138 sizeof(CeedScalar); 1139 CeedInt grid = CeedDivUpInt(num_elem_in, asmb->elems_per_block); 1140 void *args[] = {(void *)&num_elem_in, &asmb->d_B_in, &asmb->d_B_out, &orients_in, &curl_orients_in, 1141 &orients_out, &curl_orients_out, &assembled_qf_array, &values_array}; 1142 1143 CeedCallBackend( 1144 CeedRunKernelDimShared_Hip(ceed, asmb->LinearAssemble, grid, asmb->block_size_x, asmb->block_size_y, asmb->elems_per_block, shared_mem, args)); 1145 1146 // Restore arrays 1147 CeedCallBackend(CeedVectorRestoreArray(values, &values_array)); 1148 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); 1149 1150 // Cleanup 1151 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1152 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) { 1153 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_in, &orients_in)); 1154 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) { 1155 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_in, &curl_orients_in)); 1156 } 1157 if (rstr_in != rstr_out) { 1158 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) { 1159 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_out, &orients_out)); 1160 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) { 1161 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_out, &curl_orients_out)); 1162 } 1163 } 1164 return CEED_ERROR_SUCCESS; 1165 } 1166 1167 //------------------------------------------------------------------------------ 1168 // Create operator 1169 //------------------------------------------------------------------------------ 1170 int CeedOperatorCreate_Hip(CeedOperator op) { 1171 Ceed ceed; 1172 CeedOperator_Hip *impl; 1173 1174 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1175 CeedCallBackend(CeedCalloc(1, &impl)); 1176 CeedCallBackend(CeedOperatorSetData(op, impl)); 1177 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", 
CeedOperatorLinearAssembleQFunction_Hip)); 1178 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Hip)); 1179 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonal_Hip)); 1180 CeedCallBackend( 1181 CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal", CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip)); 1182 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssemble_Hip)); 1183 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip)); 1184 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip)); 1185 return CEED_ERROR_SUCCESS; 1186 } 1187 1188 //------------------------------------------------------------------------------ 1189