1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <assert.h> 9 #include <ceed/backend.h> 10 #include <ceed/ceed.h> 11 #include <ceed/jit-tools.h> 12 #include <hip/hip_runtime.h> 13 #include <stdbool.h> 14 #include <string.h> 15 16 #include "../hip/ceed-hip-compile.h" 17 #include "ceed-hip-ref.h" 18 19 //------------------------------------------------------------------------------ 20 // Destroy operator 21 //------------------------------------------------------------------------------ 22 static int CeedOperatorDestroy_Hip(CeedOperator op) { 23 CeedOperator_Hip *impl; 24 CeedCallBackend(CeedOperatorGetData(op, &impl)); 25 26 // Apply data 27 for (CeedInt i = 0; i < impl->numein + impl->numeout; i++) { 28 CeedCallBackend(CeedVectorDestroy(&impl->evecs[i])); 29 } 30 CeedCallBackend(CeedFree(&impl->evecs)); 31 32 for (CeedInt i = 0; i < impl->numein; i++) { 33 CeedCallBackend(CeedVectorDestroy(&impl->qvecsin[i])); 34 } 35 CeedCallBackend(CeedFree(&impl->qvecsin)); 36 37 for (CeedInt i = 0; i < impl->numeout; i++) { 38 CeedCallBackend(CeedVectorDestroy(&impl->qvecsout[i])); 39 } 40 CeedCallBackend(CeedFree(&impl->qvecsout)); 41 42 // QFunction diagonal assembly data 43 for (CeedInt i = 0; i < impl->qfnumactivein; i++) { 44 CeedCallBackend(CeedVectorDestroy(&impl->qfactivein[i])); 45 } 46 CeedCallBackend(CeedFree(&impl->qfactivein)); 47 48 // Diag data 49 if (impl->diag) { 50 Ceed ceed; 51 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 52 CeedCallHip(ceed, hipModuleUnload(impl->diag->module)); 53 CeedCallBackend(CeedFree(&impl->diag->h_emodein)); 54 CeedCallBackend(CeedFree(&impl->diag->h_emodeout)); 55 CeedCallHip(ceed, hipFree(impl->diag->d_emodein)); 56 CeedCallHip(ceed, hipFree(impl->diag->d_emodeout)); 57 CeedCallHip(ceed, hipFree(impl->diag->d_identity)); 58 CeedCallHip(ceed, hipFree(impl->diag->d_interpin)); 59 CeedCallHip(ceed, hipFree(impl->diag->d_interpout)); 60 CeedCallHip(ceed, hipFree(impl->diag->d_gradin)); 61 CeedCallHip(ceed, hipFree(impl->diag->d_gradout)); 62 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->pbdiagrstr)); 63 CeedCallBackend(CeedVectorDestroy(&impl->diag->elemdiag)); 64 CeedCallBackend(CeedVectorDestroy(&impl->diag->pbelemdiag)); 65 } 66 CeedCallBackend(CeedFree(&impl->diag)); 67 68 if (impl->asmb) { 69 Ceed ceed; 70 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 71 CeedCallHip(ceed, hipModuleUnload(impl->asmb->module)); 72 CeedCallHip(ceed, hipFree(impl->asmb->d_B_in)); 73 CeedCallHip(ceed, hipFree(impl->asmb->d_B_out)); 74 } 75 CeedCallBackend(CeedFree(&impl->asmb)); 76 77 CeedCallBackend(CeedFree(&impl)); 78 return CEED_ERROR_SUCCESS; 79 } 80 81 //------------------------------------------------------------------------------ 82 // Setup infields or outfields 83 //------------------------------------------------------------------------------ 84 static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool isinput, CeedVector *evecs, CeedVector *qvecs, CeedInt starte, 85 CeedInt numfields, CeedInt Q, CeedInt numelements) { 86 CeedInt dim, size; 87 CeedSize q_size; 88 Ceed ceed; 89 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 90 CeedBasis basis; 91 CeedElemRestriction Erestrict; 92 CeedOperatorField *opfields; 93 CeedQFunctionField *qffields; 94 CeedVector fieldvec; 95 bool strided; 96 bool skiprestrict; 97 98 if (isinput) { 99 CeedCallBackend(CeedOperatorGetFields(op, NULL, &opfields, NULL, NULL)); 100 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qffields, NULL, NULL)); 101 } else { 102 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &opfields)); 103 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qffields)); 104 } 105 106 // Loop over fields 107 for (CeedInt i = 0; i < numfields; i++) { 108 CeedEvalMode emode; 109 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qffields[i], &emode)); 110 111 strided = false; 112 skiprestrict = false; 113 if (emode != CEED_EVAL_WEIGHT) { 114 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opfields[i], &Erestrict)); 115 116 // Check whether this field can skip the element restriction: 117 // must be passive input, with emode NONE, and have a strided restriction with 118 // CEED_STRIDES_BACKEND. 119 120 // First, check whether the field is input or output: 121 if (isinput) { 122 // Check for passive input: 123 CeedCallBackend(CeedOperatorFieldGetVector(opfields[i], &fieldvec)); 124 if (fieldvec != CEED_VECTOR_ACTIVE) { 125 // Check emode 126 if (emode == CEED_EVAL_NONE) { 127 // Check for strided restriction 128 CeedCallBackend(CeedElemRestrictionIsStrided(Erestrict, &strided)); 129 if (strided) { 130 // Check if vector is already in preferred backend ordering 131 CeedCallBackend(CeedElemRestrictionHasBackendStrides(Erestrict, &skiprestrict)); 132 } 133 } 134 } 135 } 136 if (skiprestrict) { 137 // We do not need an E-Vector, but will use the input field vector's data 138 // directly in the operator application. 139 evecs[i + starte] = NULL; 140 } else { 141 CeedCallBackend(CeedElemRestrictionCreateVector(Erestrict, NULL, &evecs[i + starte])); 142 } 143 } 144 145 switch (emode) { 146 case CEED_EVAL_NONE: 147 CeedCallBackend(CeedQFunctionFieldGetSize(qffields[i], &size)); 148 q_size = (CeedSize)numelements * Q * size; 149 CeedCallBackend(CeedVectorCreate(ceed, q_size, &qvecs[i])); 150 break; 151 case CEED_EVAL_INTERP: 152 CeedCallBackend(CeedQFunctionFieldGetSize(qffields[i], &size)); 153 q_size = (CeedSize)numelements * Q * size; 154 CeedCallBackend(CeedVectorCreate(ceed, q_size, &qvecs[i])); 155 break; 156 case CEED_EVAL_GRAD: 157 CeedCallBackend(CeedOperatorFieldGetBasis(opfields[i], &basis)); 158 CeedCallBackend(CeedQFunctionFieldGetSize(qffields[i], &size)); 159 CeedCallBackend(CeedBasisGetDimension(basis, &dim)); 160 q_size = (CeedSize)numelements * Q * size; 161 CeedCallBackend(CeedVectorCreate(ceed, q_size, &qvecs[i])); 162 break; 163 case CEED_EVAL_WEIGHT: // Only on input fields 164 CeedCallBackend(CeedOperatorFieldGetBasis(opfields[i], &basis)); 165 q_size = (CeedSize)numelements * Q; 166 CeedCallBackend(CeedVectorCreate(ceed, q_size, &qvecs[i])); 167 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, NULL, qvecs[i])); 168 break; 169 case CEED_EVAL_DIV: 170 break; // TODO: Not implemented 171 case CEED_EVAL_CURL: 172 break; // TODO: Not implemented 173 } 174 } 175 return CEED_ERROR_SUCCESS; 176 } 177 178 //------------------------------------------------------------------------------ 179 // CeedOperator needs to connect all the named fields (be they active or passive) 180 // to the named inputs and outputs of its CeedQFunction. 181 //------------------------------------------------------------------------------ 182 static int CeedOperatorSetup_Hip(CeedOperator op) { 183 bool setupdone; 184 CeedCallBackend(CeedOperatorIsSetupDone(op, &setupdone)); 185 if (setupdone) return CEED_ERROR_SUCCESS; 186 Ceed ceed; 187 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 188 CeedOperator_Hip *impl; 189 CeedCallBackend(CeedOperatorGetData(op, &impl)); 190 CeedQFunction qf; 191 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 192 CeedInt Q, numelements, numinputfields, numoutputfields; 193 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 194 CeedCallBackend(CeedOperatorGetNumElements(op, &numelements)); 195 CeedOperatorField *opinputfields, *opoutputfields; 196 CeedCallBackend(CeedOperatorGetFields(op, &numinputfields, &opinputfields, &numoutputfields, &opoutputfields)); 197 CeedQFunctionField *qfinputfields, *qfoutputfields; 198 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields)); 199 200 // Allocate 201 CeedCallBackend(CeedCalloc(numinputfields + numoutputfields, &impl->evecs)); 202 203 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->qvecsin)); 204 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->qvecsout)); 205 206 impl->numein = numinputfields; 207 impl->numeout = numoutputfields; 208 209 // Set up infield and outfield evecs and qvecs 210 // Infields 211 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, impl->evecs, impl->qvecsin, 0, numinputfields, Q, numelements)); 212 213 // Outfields 214 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, impl->evecs, impl->qvecsout, numinputfields, numoutputfields, Q, numelements)); 215 216 CeedCallBackend(CeedOperatorSetSetupDone(op)); 217 return CEED_ERROR_SUCCESS; 218 } 219 220 //------------------------------------------------------------------------------ 221 // Setup Operator Inputs 222 //------------------------------------------------------------------------------ 223 static inline int CeedOperatorSetupInputs_Hip(CeedInt numinputfields, CeedQFunctionField *qfinputfields, CeedOperatorField *opinputfields, 224 CeedVector invec, const bool skipactive, CeedScalar *edata[2 * CEED_FIELD_MAX], CeedOperator_Hip *impl, 225 CeedRequest *request) { 226 CeedEvalMode emode; 227 CeedVector vec; 228 CeedElemRestriction Erestrict; 229 230 for (CeedInt i = 0; i < numinputfields; i++) { 231 // Get input vector 232 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 233 if (vec == CEED_VECTOR_ACTIVE) { 234 if (skipactive) continue; 235 else vec = invec; 236 } 237 238 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode)); 239 if (emode == CEED_EVAL_WEIGHT) { // Skip 240 } else { 241 // Get input vector 242 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 243 // Get input element restriction 244 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict)); 245 if (vec == CEED_VECTOR_ACTIVE) vec = invec; 246 // Restrict, if necessary 247 if (!impl->evecs[i]) { 248 // No restriction for this field; read data directly from vec. 249 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, (const CeedScalar **)&edata[i])); 250 } else { 251 CeedCallBackend(CeedElemRestrictionApply(Erestrict, CEED_NOTRANSPOSE, vec, impl->evecs[i], request)); 252 // Get evec 253 CeedCallBackend(CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_DEVICE, (const CeedScalar **)&edata[i])); 254 } 255 } 256 } 257 return CEED_ERROR_SUCCESS; 258 } 259 260 //------------------------------------------------------------------------------ 261 // Input Basis Action 262 //------------------------------------------------------------------------------ 263 static inline int CeedOperatorInputBasis_Hip(CeedInt numelements, CeedQFunctionField *qfinputfields, CeedOperatorField *opinputfields, 264 CeedInt numinputfields, const bool skipactive, CeedScalar *edata[2 * CEED_FIELD_MAX], 265 CeedOperator_Hip *impl) { 266 CeedInt elemsize, size; 267 CeedElemRestriction Erestrict; 268 CeedEvalMode emode; 269 CeedBasis basis; 270 271 for (CeedInt i = 0; i < numinputfields; i++) { 272 // Skip active input 273 if (skipactive) { 274 CeedVector vec; 275 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 276 if (vec == CEED_VECTOR_ACTIVE) continue; 277 } 278 // Get elemsize, emode, size 279 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict)); 280 CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elemsize)); 281 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode)); 282 CeedCallBackend(CeedQFunctionFieldGetSize(qfinputfields[i], &size)); 283 // Basis action 284 switch (emode) { 285 case CEED_EVAL_NONE: 286 CeedCallBackend(CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_DEVICE, CEED_USE_POINTER, edata[i])); 287 break; 288 case CEED_EVAL_INTERP: 289 CeedCallBackend(CeedOperatorFieldGetBasis(opinputfields[i], &basis)); 290 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_NOTRANSPOSE, CEED_EVAL_INTERP, impl->evecs[i], impl->qvecsin[i])); 291 break; 292 case CEED_EVAL_GRAD: 293 CeedCallBackend(CeedOperatorFieldGetBasis(opinputfields[i], &basis)); 294 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_NOTRANSPOSE, CEED_EVAL_GRAD, impl->evecs[i], impl->qvecsin[i])); 295 break; 296 case CEED_EVAL_WEIGHT: 297 break; // No action 298 case CEED_EVAL_DIV: 299 break; // TODO: Not implemented 300 case CEED_EVAL_CURL: 301 break; // TODO: Not implemented 302 } 303 } 304 return CEED_ERROR_SUCCESS; 305 } 306 307 //------------------------------------------------------------------------------ 308 // Restore Input Vectors 309 //------------------------------------------------------------------------------ 310 static inline int CeedOperatorRestoreInputs_Hip(CeedInt numinputfields, CeedQFunctionField *qfinputfields, CeedOperatorField *opinputfields, 311 const bool skipactive, CeedScalar *edata[2 * CEED_FIELD_MAX], CeedOperator_Hip *impl) { 312 CeedEvalMode emode; 313 CeedVector vec; 314 315 for (CeedInt i = 0; i < numinputfields; i++) { 316 // Skip active input 317 if (skipactive) { 318 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 319 if (vec == CEED_VECTOR_ACTIVE) continue; 320 } 321 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode)); 322 if (emode == CEED_EVAL_WEIGHT) { // Skip 323 } else { 324 if (!impl->evecs[i]) { // This was a skiprestrict case 325 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 326 CeedCallBackend(CeedVectorRestoreArrayRead(vec, (const CeedScalar **)&edata[i])); 327 } else { 328 CeedCallBackend(CeedVectorRestoreArrayRead(impl->evecs[i], (const CeedScalar **)&edata[i])); 329 } 330 } 331 } 332 return CEED_ERROR_SUCCESS; 333 } 334 335 //------------------------------------------------------------------------------ 336 // Apply and add to output 337 //------------------------------------------------------------------------------ 338 static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector invec, CeedVector outvec, CeedRequest *request) { 339 CeedOperator_Hip *impl; 340 CeedCallBackend(CeedOperatorGetData(op, &impl)); 341 CeedQFunction qf; 342 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 343 CeedInt Q, numelements, elemsize, numinputfields, numoutputfields, size; 344 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 345 CeedCallBackend(CeedOperatorGetNumElements(op, &numelements)); 346 CeedOperatorField *opinputfields, *opoutputfields; 347 CeedCallBackend(CeedOperatorGetFields(op, &numinputfields, &opinputfields, &numoutputfields, &opoutputfields)); 348 CeedQFunctionField *qfinputfields, *qfoutputfields; 349 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields)); 350 CeedEvalMode emode; 351 CeedVector vec; 352 CeedBasis basis; 353 CeedElemRestriction Erestrict; 354 CeedScalar *edata[2 * CEED_FIELD_MAX]; 355 356 // Setup 357 CeedCallBackend(CeedOperatorSetup_Hip(op)); 358 359 // Input Evecs and Restriction 360 CeedCallBackend(CeedOperatorSetupInputs_Hip(numinputfields, qfinputfields, opinputfields, invec, false, edata, impl, request)); 361 362 // Input basis apply if needed 363 CeedCallBackend(CeedOperatorInputBasis_Hip(numelements, qfinputfields, opinputfields, numinputfields, false, edata, impl)); 364 365 // Output pointers, as necessary 366 for (CeedInt i = 0; i < numoutputfields; i++) { 367 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode)); 368 if (emode == CEED_EVAL_NONE) { 369 // Set the output Q-Vector to use the E-Vector data directly. 370 CeedCallBackend(CeedVectorGetArrayWrite(impl->evecs[i + impl->numein], CEED_MEM_DEVICE, &edata[i + numinputfields])); 371 CeedCallBackend(CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_DEVICE, CEED_USE_POINTER, edata[i + numinputfields])); 372 } 373 } 374 375 // Q function 376 CeedCallBackend(CeedQFunctionApply(qf, numelements * Q, impl->qvecsin, impl->qvecsout)); 377 378 // Output basis apply if needed 379 for (CeedInt i = 0; i < numoutputfields; i++) { 380 // Get elemsize, emode, size 381 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict)); 382 CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elemsize)); 383 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode)); 384 CeedCallBackend(CeedQFunctionFieldGetSize(qfoutputfields[i], &size)); 385 // Basis action 386 switch (emode) { 387 case CEED_EVAL_NONE: 388 break; 389 case CEED_EVAL_INTERP: 390 CeedCallBackend(CeedOperatorFieldGetBasis(opoutputfields[i], &basis)); 391 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_TRANSPOSE, CEED_EVAL_INTERP, impl->qvecsout[i], impl->evecs[i + impl->numein])); 392 break; 393 case CEED_EVAL_GRAD: 394 CeedCallBackend(CeedOperatorFieldGetBasis(opoutputfields[i], &basis)); 395 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_TRANSPOSE, CEED_EVAL_GRAD, impl->qvecsout[i], impl->evecs[i + impl->numein])); 396 break; 397 // LCOV_EXCL_START 398 case CEED_EVAL_WEIGHT: { 399 Ceed ceed; 400 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 401 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 402 break; // Should not occur 403 } 404 case CEED_EVAL_DIV: 405 break; // TODO: Not implemented 406 case CEED_EVAL_CURL: 407 break; // TODO: Not implemented 408 // LCOV_EXCL_STOP 409 } 410 } 411 412 // Output restriction 413 for (CeedInt i = 0; i < numoutputfields; i++) { 414 // Restore evec 415 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode)); 416 if (emode == CEED_EVAL_NONE) { 417 CeedCallBackend(CeedVectorRestoreArray(impl->evecs[i + impl->numein], &edata[i + numinputfields])); 418 } 419 // Get output vector 420 CeedCallBackend(CeedOperatorFieldGetVector(opoutputfields[i], &vec)); 421 // Restrict 422 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict)); 423 // Active 424 if (vec == CEED_VECTOR_ACTIVE) vec = outvec; 425 426 CeedCallBackend(CeedElemRestrictionApply(Erestrict, CEED_TRANSPOSE, impl->evecs[i + impl->numein], vec, request)); 427 } 428 429 // Restore input arrays 430 CeedCallBackend(CeedOperatorRestoreInputs_Hip(numinputfields, qfinputfields, opinputfields, false, edata, impl)); 431 return CEED_ERROR_SUCCESS; 432 } 433 434 //------------------------------------------------------------------------------ 435 // Core code for assembling linear QFunction 436 //------------------------------------------------------------------------------ 437 static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 438 CeedRequest *request) { 439 CeedOperator_Hip *impl; 440 CeedCallBackend(CeedOperatorGetData(op, &impl)); 441 CeedQFunction qf; 442 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 443 CeedInt Q, numelements, numinputfields, numoutputfields, size; 444 CeedSize q_size; 445 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 446 CeedCallBackend(CeedOperatorGetNumElements(op, &numelements)); 447 CeedOperatorField *opinputfields, *opoutputfields; 448 CeedCallBackend(CeedOperatorGetFields(op, &numinputfields, &opinputfields, &numoutputfields, &opoutputfields)); 449 CeedQFunctionField *qfinputfields, *qfoutputfields; 450 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields)); 451 CeedVector vec; 452 CeedInt numactivein = impl->qfnumactivein, numactiveout = impl->qfnumactiveout; 453 CeedVector *activein = impl->qfactivein; 454 CeedScalar *a, *tmp; 455 Ceed ceed, ceedparent; 456 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 457 CeedCallBackend(CeedGetOperatorFallbackParentCeed(ceed, &ceedparent)); 458 ceedparent = ceedparent ? ceedparent : ceed; 459 CeedScalar *edata[2 * CEED_FIELD_MAX]; 460 461 // Setup 462 CeedCallBackend(CeedOperatorSetup_Hip(op)); 463 464 // Check for identity 465 bool identityqf; 466 CeedCallBackend(CeedQFunctionIsIdentity(qf, &identityqf)); 467 if (identityqf) { 468 // LCOV_EXCL_START 469 return CeedError(ceed, CEED_ERROR_BACKEND, "Assembling identity QFunctions not supported"); 470 // LCOV_EXCL_STOP 471 } 472 473 // Input Evecs and Restriction 474 CeedCallBackend(CeedOperatorSetupInputs_Hip(numinputfields, qfinputfields, opinputfields, NULL, true, edata, impl, request)); 475 476 // Count number of active input fields 477 if (!numactivein) { 478 for (CeedInt i = 0; i < numinputfields; i++) { 479 // Get input vector 480 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 481 // Check if active input 482 if (vec == CEED_VECTOR_ACTIVE) { 483 CeedCallBackend(CeedQFunctionFieldGetSize(qfinputfields[i], &size)); 484 CeedCallBackend(CeedVectorSetValue(impl->qvecsin[i], 0.0)); 485 CeedCallBackend(CeedVectorGetArray(impl->qvecsin[i], CEED_MEM_DEVICE, &tmp)); 486 CeedCallBackend(CeedRealloc(numactivein + size, &activein)); 487 for (CeedInt field = 0; field < size; field++) { 488 q_size = (CeedSize)Q * numelements; 489 CeedCallBackend(CeedVectorCreate(ceed, q_size, &activein[numactivein + field])); 490 CeedCallBackend(CeedVectorSetArray(activein[numactivein + field], CEED_MEM_DEVICE, CEED_USE_POINTER, &tmp[field * Q * numelements])); 491 } 492 numactivein += size; 493 CeedCallBackend(CeedVectorRestoreArray(impl->qvecsin[i], &tmp)); 494 } 495 } 496 impl->qfnumactivein = numactivein; 497 impl->qfactivein = activein; 498 } 499 500 // Count number of active output fields 501 if (!numactiveout) { 502 for (CeedInt i = 0; i < numoutputfields; i++) { 503 // Get output vector 504 CeedCallBackend(CeedOperatorFieldGetVector(opoutputfields[i], &vec)); 505 // Check if active output 506 if (vec == CEED_VECTOR_ACTIVE) { 507 CeedCallBackend(CeedQFunctionFieldGetSize(qfoutputfields[i], &size)); 508 numactiveout += size; 509 } 510 } 511 impl->qfnumactiveout = numactiveout; 512 } 513 514 // Check sizes 515 if (!numactivein || !numactiveout) { 516 // LCOV_EXCL_START 517 return CeedError(ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 518 // LCOV_EXCL_STOP 519 } 520 521 // Build objects if needed 522 if (build_objects) { 523 // Create output restriction 524 CeedInt strides[3] = {1, numelements * Q, Q}; /* *NOPAD* */ 525 CeedCallBackend(CeedElemRestrictionCreateStrided(ceedparent, numelements, Q, numactivein * numactiveout, 526 numactivein * numactiveout * numelements * Q, strides, rstr)); 527 // Create assembled vector 528 CeedSize l_size = (CeedSize)numelements * Q * numactivein * numactiveout; 529 CeedCallBackend(CeedVectorCreate(ceedparent, l_size, assembled)); 530 } 531 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 532 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_DEVICE, &a)); 533 534 // Input basis apply 535 CeedCallBackend(CeedOperatorInputBasis_Hip(numelements, qfinputfields, opinputfields, numinputfields, true, edata, impl)); 536 537 // Assemble QFunction 538 for (CeedInt in = 0; in < numactivein; in++) { 539 // Set Inputs 540 CeedCallBackend(CeedVectorSetValue(activein[in], 1.0)); 541 if (numactivein > 1) { 542 CeedCallBackend(CeedVectorSetValue(activein[(in + numactivein - 1) % numactivein], 0.0)); 543 } 544 // Set Outputs 545 for (CeedInt out = 0; out < numoutputfields; out++) { 546 // Get output vector 547 CeedCallBackend(CeedOperatorFieldGetVector(opoutputfields[out], &vec)); 548 // Check if active output 549 if (vec == CEED_VECTOR_ACTIVE) { 550 CeedCallBackend(CeedVectorSetArray(impl->qvecsout[out], CEED_MEM_DEVICE, CEED_USE_POINTER, a)); 551 CeedCallBackend(CeedQFunctionFieldGetSize(qfoutputfields[out], &size)); 552 a += size * Q * numelements; // Advance the pointer by the size of the output 553 } 554 } 555 // Apply QFunction 556 CeedCallBackend(CeedQFunctionApply(qf, Q * numelements, impl->qvecsin, impl->qvecsout)); 557 } 558 559 // Un-set output Qvecs to prevent accidental overwrite of Assembled 560 for (CeedInt out = 0; out < numoutputfields; out++) { 561 // Get output vector 562 CeedCallBackend(CeedOperatorFieldGetVector(opoutputfields[out], &vec)); 563 // Check if active output 564 if (vec == CEED_VECTOR_ACTIVE) { 565 CeedCallBackend(CeedVectorTakeArray(impl->qvecsout[out], CEED_MEM_DEVICE, NULL)); 566 } 567 } 568 569 // Restore input arrays 570 CeedCallBackend(CeedOperatorRestoreInputs_Hip(numinputfields, qfinputfields, opinputfields, true, edata, impl)); 571 572 // Restore output 573 CeedCallBackend(CeedVectorRestoreArray(*assembled, &a)); 574 575 return CEED_ERROR_SUCCESS; 576 } 577 578 //------------------------------------------------------------------------------ 579 // Assemble Linear QFunction 580 //------------------------------------------------------------------------------ 581 static int CeedOperatorLinearAssembleQFunction_Hip(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 582 return CeedOperatorLinearAssembleQFunctionCore_Hip(op, true, assembled, rstr, request); 583 } 584 585 //------------------------------------------------------------------------------ 586 // Assemble Linear QFunction 587 //------------------------------------------------------------------------------ 588 static int CeedOperatorLinearAssembleQFunctionUpdate_Hip(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 589 return CeedOperatorLinearAssembleQFunctionCore_Hip(op, false, &assembled, &rstr, request); 590 } 591 592 //------------------------------------------------------------------------------ 593 // Create point block restriction 594 //------------------------------------------------------------------------------ 595 static int CreatePBRestriction(CeedElemRestriction rstr, CeedElemRestriction *pbRstr) { 596 Ceed ceed; 597 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 598 const CeedInt *offsets; 599 CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets)); 600 601 // Expand offsets 602 CeedInt nelem, ncomp, elemsize, compstride, *pbOffsets; 603 CeedSize l_size; 604 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &nelem)); 605 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &ncomp)); 606 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elemsize)); 607 CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &compstride)); 608 CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size)); 609 CeedInt shift = ncomp; 610 if (compstride != 1) shift *= ncomp; 611 CeedCallBackend(CeedCalloc(nelem * elemsize, &pbOffsets)); 612 for (CeedInt i = 0; i < nelem * elemsize; i++) { 613 pbOffsets[i] = offsets[i] * shift; 614 } 615 616 // Create new restriction 617 CeedCallBackend( 618 CeedElemRestrictionCreate(ceed, nelem, elemsize, ncomp * ncomp, 1, l_size * ncomp, CEED_MEM_HOST, CEED_OWN_POINTER, pbOffsets, pbRstr)); 619 620 // Cleanup 621 CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets)); 622 623 return CEED_ERROR_SUCCESS; 624 } 625 626 //------------------------------------------------------------------------------ 627 // Assemble diagonal setup 628 //------------------------------------------------------------------------------ 629 static inline int CeedOperatorAssembleDiagonalSetup_Hip(CeedOperator op, const bool pointBlock) { 630 Ceed ceed; 631 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 632 CeedQFunction qf; 633 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 634 CeedInt numinputfields, numoutputfields; 635 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields)); 636 637 // Determine active input basis 638 CeedOperatorField *opfields; 639 CeedQFunctionField *qffields; 640 CeedCallBackend(CeedOperatorGetFields(op, NULL, &opfields, NULL, NULL)); 641 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qffields, NULL, NULL)); 642 CeedInt numemodein = 0, ncomp = 0, dim = 1; 643 CeedEvalMode *emodein = NULL; 644 CeedBasis basisin = NULL; 645 CeedElemRestriction rstrin = NULL; 646 for (CeedInt i = 0; i < numinputfields; i++) { 647 CeedVector vec; 648 CeedCallBackend(CeedOperatorFieldGetVector(opfields[i], &vec)); 649 if (vec == CEED_VECTOR_ACTIVE) { 650 CeedElemRestriction rstr; 651 CeedCallBackend(CeedOperatorFieldGetBasis(opfields[i], &basisin)); 652 CeedCallBackend(CeedBasisGetNumComponents(basisin, &ncomp)); 653 CeedCallBackend(CeedBasisGetDimension(basisin, &dim)); 654 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opfields[i], &rstr)); 655 if (rstrin && rstrin != rstr) { 656 // LCOV_EXCL_START 657 return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator diagonal assembly"); 658 // LCOV_EXCL_STOP 659 } 660 rstrin = rstr; 661 CeedEvalMode emode; 662 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qffields[i], &emode)); 663 switch (emode) { 664 case CEED_EVAL_NONE: 665 case CEED_EVAL_INTERP: 666 CeedCallBackend(CeedRealloc(numemodein + 1, &emodein)); 667 emodein[numemodein] = emode; 668 numemodein += 1; 669 break; 670 case CEED_EVAL_GRAD: 671 CeedCallBackend(CeedRealloc(numemodein + dim, &emodein)); 672 for (CeedInt d = 0; d < dim; d++) emodein[numemodein + d] = emode; 673 numemodein += dim; 674 break; 675 case CEED_EVAL_WEIGHT: 676 case CEED_EVAL_DIV: 677 case CEED_EVAL_CURL: 678 break; // Caught by QF Assembly 679 } 680 } 681 } 682 683 // Determine active output basis 684 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &opfields)); 685 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qffields)); 686 CeedInt numemodeout = 0; 687 CeedEvalMode *emodeout = NULL; 688 CeedBasis basisout = NULL; 689 CeedElemRestriction rstrout = NULL; 690 for (CeedInt i = 0; i < numoutputfields; i++) { 691 CeedVector vec; 692 CeedCallBackend(CeedOperatorFieldGetVector(opfields[i], &vec)); 693 if (vec == CEED_VECTOR_ACTIVE) { 694 CeedElemRestriction rstr; 695 CeedCallBackend(CeedOperatorFieldGetBasis(opfields[i], &basisout)); 696 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opfields[i], &rstr)); 697 if (rstrout && rstrout != rstr) { 698 // LCOV_EXCL_START 699 return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator diagonal assembly"); 700 // LCOV_EXCL_STOP 701 } 702 rstrout = rstr; 703 CeedEvalMode emode; 704 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qffields[i], &emode)); 705 switch (emode) { 706 case CEED_EVAL_NONE: 707 case CEED_EVAL_INTERP: 708 CeedCallBackend(CeedRealloc(numemodeout + 1, &emodeout)); 709 emodeout[numemodeout] = emode; 710 numemodeout += 1; 711 break; 712 case CEED_EVAL_GRAD: 713 CeedCallBackend(CeedRealloc(numemodeout + dim, &emodeout)); 714 for (CeedInt d = 0; d < dim; d++) emodeout[numemodeout + d] = emode; 715 numemodeout += dim; 716 break; 717 case CEED_EVAL_WEIGHT: 718 case CEED_EVAL_DIV: 719 case CEED_EVAL_CURL: 720 break; // Caught by QF Assembly 721 } 722 } 723 } 724 725 // Operator data struct 726 CeedOperator_Hip *impl; 727 CeedCallBackend(CeedOperatorGetData(op, &impl)); 728 CeedCallBackend(CeedCalloc(1, &impl->diag)); 729 CeedOperatorDiag_Hip *diag = impl->diag; 730 diag->basisin = basisin; 731 diag->basisout = basisout; 732 diag->h_emodein = emodein; 733 diag->h_emodeout = emodeout; 734 diag->numemodein = numemodein; 735 diag->numemodeout = numemodeout; 736 737 // Assemble kernel 738 739 char *diagonal_kernel_path, *diagonal_kernel_source; 740 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h", &diagonal_kernel_path)); 741 CeedDebug256(ceed, 2, "----- Loading Diagonal Assembly Kernel Source -----\n"); 742 CeedCallBackend(CeedLoadSourceToBuffer(ceed, diagonal_kernel_path, &diagonal_kernel_source)); 743 CeedDebug256(ceed, 2, "----- Loading Diagonal Assembly Source Complete! -----\n"); 744 CeedInt nnodes, nqpts; 745 CeedCallBackend(CeedBasisGetNumNodes(basisin, &nnodes)); 746 CeedCallBackend(CeedBasisGetNumQuadraturePoints(basisin, &nqpts)); 747 diag->nnodes = nnodes; 748 CeedCallBackend(CeedCompileHip(ceed, diagonal_kernel_source, &diag->module, 5, "NUMEMODEIN", numemodein, "NUMEMODEOUT", numemodeout, "NNODES", 749 nnodes, "NQPTS", nqpts, "NCOMP", ncomp)); 750 CeedCallBackend(CeedGetKernelHip(ceed, diag->module, "linearDiagonal", &diag->linearDiagonal)); 751 CeedCallBackend(CeedGetKernelHip(ceed, diag->module, "linearPointBlockDiagonal", &diag->linearPointBlock)); 752 CeedCallBackend(CeedFree(&diagonal_kernel_path)); 753 CeedCallBackend(CeedFree(&diagonal_kernel_source)); 754 755 // Basis matrices 756 const CeedInt qBytes = nqpts * sizeof(CeedScalar); 757 const CeedInt iBytes = qBytes * nnodes; 758 const CeedInt gBytes = qBytes * nnodes * dim; 759 const CeedInt eBytes = sizeof(CeedEvalMode); 760 const CeedScalar *interpin, *interpout, *gradin, *gradout; 761 762 // CEED_EVAL_NONE 763 CeedScalar *identity = NULL; 764 bool evalNone = false; 765 for (CeedInt i = 0; i < numemodein; i++) evalNone = evalNone || (emodein[i] == CEED_EVAL_NONE); 766 for (CeedInt i = 0; i < numemodeout; i++) evalNone = evalNone || (emodeout[i] == CEED_EVAL_NONE); 767 if (evalNone) { 768 CeedCallBackend(CeedCalloc(nqpts * nnodes, &identity)); 769 for (CeedInt i = 0; i < (nnodes < nqpts ? nnodes : nqpts); i++) identity[i * nnodes + i] = 1.0; 770 CeedCallHip(ceed, hipMalloc((void **)&diag->d_identity, iBytes)); 771 CeedCallHip(ceed, hipMemcpy(diag->d_identity, identity, iBytes, hipMemcpyHostToDevice)); 772 } 773 774 // CEED_EVAL_INTERP 775 CeedCallBackend(CeedBasisGetInterp(basisin, &interpin)); 776 CeedCallHip(ceed, hipMalloc((void **)&diag->d_interpin, iBytes)); 777 CeedCallHip(ceed, hipMemcpy(diag->d_interpin, interpin, iBytes, hipMemcpyHostToDevice)); 778 CeedCallBackend(CeedBasisGetInterp(basisout, &interpout)); 779 CeedCallHip(ceed, hipMalloc((void **)&diag->d_interpout, iBytes)); 780 CeedCallHip(ceed, hipMemcpy(diag->d_interpout, interpout, iBytes, hipMemcpyHostToDevice)); 781 782 // CEED_EVAL_GRAD 783 CeedCallBackend(CeedBasisGetGrad(basisin, &gradin)); 784 CeedCallHip(ceed, hipMalloc((void **)&diag->d_gradin, gBytes)); 785 CeedCallHip(ceed, hipMemcpy(diag->d_gradin, gradin, gBytes, hipMemcpyHostToDevice)); 786 CeedCallBackend(CeedBasisGetGrad(basisout, &gradout)); 787 CeedCallHip(ceed, hipMalloc((void **)&diag->d_gradout, gBytes)); 788 CeedCallHip(ceed, hipMemcpy(diag->d_gradout, gradout, gBytes, hipMemcpyHostToDevice)); 789 790 // Arrays of emodes 791 CeedCallHip(ceed, hipMalloc((void **)&diag->d_emodein, numemodein * eBytes)); 792 CeedCallHip(ceed, hipMemcpy(diag->d_emodein, emodein, numemodein * eBytes, hipMemcpyHostToDevice)); 793 CeedCallHip(ceed, hipMalloc((void **)&diag->d_emodeout, numemodeout * eBytes)); 794 CeedCallHip(ceed, hipMemcpy(diag->d_emodeout, emodeout, numemodeout * eBytes, hipMemcpyHostToDevice)); 795 796 // Restriction 797 diag->diagrstr = rstrout; 798 799 return CEED_ERROR_SUCCESS; 800 } 801 802 //------------------------------------------------------------------------------ 803 // Assemble diagonal common code 804 //------------------------------------------------------------------------------ 805 static inline int CeedOperatorAssembleDiagonalCore_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request, const bool pointBlock) { 806 Ceed ceed; 807 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 808 CeedOperator_Hip *impl; 809 CeedCallBackend(CeedOperatorGetData(op, &impl)); 810 811 // Assemble QFunction 812 CeedVector assembledqf; 813 CeedElemRestriction rstr; 814 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembledqf, &rstr, request)); 815 CeedCallBackend(CeedElemRestrictionDestroy(&rstr)); 816 817 // Setup 818 if (!impl->diag) CeedCallBackend(CeedOperatorAssembleDiagonalSetup_Hip(op, pointBlock)); 819 CeedOperatorDiag_Hip *diag = impl->diag; 820 assert(diag != NULL); 821 822 // Restriction 823 if (pointBlock && !diag->pbdiagrstr) { 824 CeedElemRestriction pbdiagrstr; 825 CeedCallBackend(CreatePBRestriction(diag->diagrstr, &pbdiagrstr)); 826 diag->pbdiagrstr = pbdiagrstr; 827 } 828 CeedElemRestriction diagrstr = pointBlock ? diag->pbdiagrstr : diag->diagrstr; 829 830 // Create diagonal vector 831 CeedVector elemdiag = pointBlock ? diag->pbelemdiag : diag->elemdiag; 832 if (!elemdiag) { 833 // Element diagonal vector 834 CeedCallBackend(CeedElemRestrictionCreateVector(diagrstr, NULL, &elemdiag)); 835 if (pointBlock) diag->pbelemdiag = elemdiag; 836 else diag->elemdiag = elemdiag; 837 } 838 CeedCallBackend(CeedVectorSetValue(elemdiag, 0.0)); 839 840 // Assemble element operator diagonals 841 CeedScalar *elemdiagarray; 842 const CeedScalar *assembledqfarray; 843 CeedCallBackend(CeedVectorGetArray(elemdiag, CEED_MEM_DEVICE, &elemdiagarray)); 844 CeedCallBackend(CeedVectorGetArrayRead(assembledqf, CEED_MEM_DEVICE, &assembledqfarray)); 845 CeedInt nelem; 846 CeedCallBackend(CeedElemRestrictionGetNumElements(diagrstr, &nelem)); 847 848 // Compute the diagonal of B^T D B 849 int elemsPerBlock = 1; 850 int grid = nelem / elemsPerBlock + ((nelem / elemsPerBlock * elemsPerBlock < nelem) ? 1 : 0); 851 void *args[] = {(void *)&nelem, &diag->d_identity, &diag->d_interpin, &diag->d_gradin, &diag->d_interpout, 852 &diag->d_gradout, &diag->d_emodein, &diag->d_emodeout, &assembledqfarray, &elemdiagarray}; 853 if (pointBlock) { 854 CeedCallBackend(CeedRunKernelDimHip(ceed, diag->linearPointBlock, grid, diag->nnodes, 1, elemsPerBlock, args)); 855 } else { 856 CeedCallBackend(CeedRunKernelDimHip(ceed, diag->linearDiagonal, grid, diag->nnodes, 1, elemsPerBlock, args)); 857 } 858 859 // Restore arrays 860 CeedCallBackend(CeedVectorRestoreArray(elemdiag, &elemdiagarray)); 861 CeedCallBackend(CeedVectorRestoreArrayRead(assembledqf, &assembledqfarray)); 862 863 // Assemble local operator diagonal 864 CeedCallBackend(CeedElemRestrictionApply(diagrstr, CEED_TRANSPOSE, elemdiag, assembled, request)); 865 866 // Cleanup 867 CeedCallBackend(CeedVectorDestroy(&assembledqf)); 868 869 return CEED_ERROR_SUCCESS; 870 } 871 872 //------------------------------------------------------------------------------ 873 // Assemble Linear Diagonal 874 //------------------------------------------------------------------------------ 875 static int CeedOperatorLinearAssembleAddDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 876 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, false)); 877 return CEED_ERROR_SUCCESS; 878 } 879 880 //------------------------------------------------------------------------------ 881 // Assemble Linear Point Block Diagonal 882 //------------------------------------------------------------------------------ 883 static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 884 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, true)); 885 return CEED_ERROR_SUCCESS; 886 } 887 888 //------------------------------------------------------------------------------ 889 // Single operator assembly setup 890 //------------------------------------------------------------------------------ 891 static int CeedSingleOperatorAssembleSetup_Hip(CeedOperator op) { 892 Ceed ceed; 893 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 894 CeedOperator_Hip *impl; 895 CeedCallBackend(CeedOperatorGetData(op, &impl)); 896 897 // Get intput and output fields 898 CeedInt num_input_fields, num_output_fields; 899 CeedOperatorField *input_fields; 900 CeedOperatorField *output_fields; 901 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields)); 902 903 // Determine active input basis eval mode 904 CeedQFunction qf; 905 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 906 CeedQFunctionField *qf_fields; 907 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 908 // Note that the kernel will treat each dimension of a gradient action separately; 909 // i.e., when an active input has a CEED_EVAL_GRAD mode, num_emode_in will increment 910 // by dim. However, for the purposes of loading the B matrices, it will be treated 911 // as one mode, and we will load/copy the entire gradient matrix at once, so 912 // num_B_in_mats_to_load will be incremented by 1. 913 CeedInt num_emode_in = 0, dim = 1, num_B_in_mats_to_load = 0, size_B_in = 0; 914 CeedEvalMode *eval_mode_in = NULL; // will be of size num_B_in_mats_load 915 CeedBasis basis_in = NULL; 916 CeedInt nqpts = 0, esize = 0; 917 CeedElemRestriction rstr_in = NULL; 918 for (CeedInt i = 0; i < num_input_fields; i++) { 919 CeedVector vec; 920 CeedCallBackend(CeedOperatorFieldGetVector(input_fields[i], &vec)); 921 if (vec == CEED_VECTOR_ACTIVE) { 922 CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis_in)); 923 CeedCallBackend(CeedBasisGetDimension(basis_in, &dim)); 924 CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &nqpts)); 925 CeedCallBackend(CeedOperatorFieldGetElemRestriction(input_fields[i], &rstr_in)); 926 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &esize)); 927 CeedEvalMode eval_mode; 928 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 929 if (eval_mode != CEED_EVAL_NONE) { 930 CeedCallBackend(CeedRealloc(num_B_in_mats_to_load + 1, &eval_mode_in)); 931 eval_mode_in[num_B_in_mats_to_load] = eval_mode; 932 num_B_in_mats_to_load += 1; 933 if (eval_mode == CEED_EVAL_GRAD) { 934 num_emode_in += dim; 935 size_B_in += dim * esize * nqpts; 936 } else { 937 num_emode_in += 1; 938 size_B_in += esize * nqpts; 939 } 940 } 941 } 942 } 943 944 // Determine active output basis; basis_out and rstr_out only used if same as input, TODO 945 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 946 CeedInt num_emode_out = 0, num_B_out_mats_to_load = 0, size_B_out = 0; 947 CeedEvalMode *eval_mode_out = NULL; 948 CeedBasis basis_out = NULL; 949 CeedElemRestriction rstr_out = NULL; 950 for (CeedInt i = 0; i < num_output_fields; i++) { 951 CeedVector vec; 952 CeedCallBackend(CeedOperatorFieldGetVector(output_fields[i], &vec)); 953 if (vec == CEED_VECTOR_ACTIVE) { 954 CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis_out)); 955 CeedCallBackend(CeedOperatorFieldGetElemRestriction(output_fields[i], &rstr_out)); 956 if (rstr_out && rstr_out != rstr_in) { 957 // LCOV_EXCL_START 958 return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator assembly"); 959 // LCOV_EXCL_STOP 960 } 961 CeedEvalMode eval_mode; 962 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 963 if (eval_mode != CEED_EVAL_NONE) { 964 CeedCallBackend(CeedRealloc(num_B_out_mats_to_load + 1, &eval_mode_out)); 965 eval_mode_out[num_B_out_mats_to_load] = eval_mode; 966 num_B_out_mats_to_load += 1; 967 if (eval_mode == CEED_EVAL_GRAD) { 968 num_emode_out += dim; 969 size_B_out += dim * esize * nqpts; 970 } else { 971 num_emode_out += 1; 972 size_B_out += esize * nqpts; 973 } 974 } 975 } 976 } 977 978 if (num_emode_in == 0 || num_emode_out == 0) { 979 // LCOV_EXCL_START 980 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs"); 981 // LCOV_EXCL_STOP 982 } 983 984 CeedInt nelem, ncomp; 985 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, &nelem)); 986 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &ncomp)); 987 988 CeedCallBackend(CeedCalloc(1, &impl->asmb)); 989 CeedOperatorAssemble_Hip *asmb = impl->asmb; 990 asmb->nelem = nelem; 991 992 // Compile kernels 993 int elemsPerBlock = 1; 994 asmb->elemsPerBlock = elemsPerBlock; 995 CeedInt block_size = esize * esize * elemsPerBlock; 996 char *assembly_kernel_path, *assembly_kernel_source; 997 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-operator-assemble.h", &assembly_kernel_path)); 998 CeedDebug256(ceed, 2, "----- Loading Assembly Kernel Source -----\n"); 999 CeedCallBackend(CeedLoadSourceToBuffer(ceed, assembly_kernel_path, &assembly_kernel_source)); 1000 CeedDebug256(ceed, 2, "----- Loading Assembly Source Complete! -----\n"); 1001 bool fallback = block_size > 1024; 1002 if (fallback) { // Use fallback kernel with 1D threadblock 1003 block_size = esize * elemsPerBlock; 1004 asmb->block_size_x = esize; 1005 asmb->block_size_y = 1; 1006 } else { // Use kernel with 2D threadblock 1007 asmb->block_size_x = esize; 1008 asmb->block_size_y = esize; 1009 } 1010 CeedCallBackend(CeedCompileHip(ceed, assembly_kernel_source, &asmb->module, 7, "NELEM", nelem, "NUMEMODEIN", num_emode_in, "NUMEMODEOUT", 1011 num_emode_out, "NQPTS", nqpts, "NNODES", esize, "BLOCK_SIZE", block_size, "NCOMP", ncomp)); 1012 CeedCallBackend(CeedGetKernelHip(ceed, asmb->module, fallback ? "linearAssembleFallback" : "linearAssemble", &asmb->linearAssemble)); 1013 CeedCallBackend(CeedFree(&assembly_kernel_path)); 1014 CeedCallBackend(CeedFree(&assembly_kernel_source)); 1015 1016 // Build 'full' B matrices (not 1D arrays used for tensor-product matrices) 1017 const CeedScalar *interp_in, *grad_in; 1018 CeedCallBackend(CeedBasisGetInterp(basis_in, &interp_in)); 1019 CeedCallBackend(CeedBasisGetGrad(basis_in, &grad_in)); 1020 1021 // Load into B_in, in order that they will be used in eval_mode 1022 const CeedInt inBytes = size_B_in * sizeof(CeedScalar); 1023 CeedInt mat_start = 0; 1024 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_in, inBytes)); 1025 for (int i = 0; i < num_B_in_mats_to_load; i++) { 1026 CeedEvalMode eval_mode = eval_mode_in[i]; 1027 if (eval_mode == CEED_EVAL_INTERP) { 1028 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_in[mat_start], interp_in, esize * nqpts * sizeof(CeedScalar), hipMemcpyHostToDevice)); 1029 mat_start += esize * nqpts; 1030 } else if (eval_mode == CEED_EVAL_GRAD) { 1031 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_in[mat_start], grad_in, dim * esize * nqpts * sizeof(CeedScalar), hipMemcpyHostToDevice)); 1032 mat_start += dim * esize * nqpts; 1033 } 1034 } 1035 1036 const CeedScalar *interp_out, *grad_out; 1037 // Note that this function currently assumes 1 basis, so this should always be true 1038 // for now 1039 if (basis_out == basis_in) { 1040 interp_out = interp_in; 1041 grad_out = grad_in; 1042 } else { 1043 CeedCallBackend(CeedBasisGetInterp(basis_out, &interp_out)); 1044 CeedCallBackend(CeedBasisGetGrad(basis_out, &grad_out)); 1045 } 1046 1047 // Load into B_out, in order that they will be used in eval_mode 1048 const CeedInt outBytes = size_B_out * sizeof(CeedScalar); 1049 mat_start = 0; 1050 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_out, outBytes)); 1051 for (int i = 0; i < num_B_out_mats_to_load; i++) { 1052 CeedEvalMode eval_mode = eval_mode_out[i]; 1053 if (eval_mode == CEED_EVAL_INTERP) { 1054 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_out[mat_start], interp_out, esize * nqpts * sizeof(CeedScalar), hipMemcpyHostToDevice)); 1055 mat_start += esize * nqpts; 1056 } else if (eval_mode == CEED_EVAL_GRAD) { 1057 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_out[mat_start], grad_out, dim * esize * nqpts * sizeof(CeedScalar), hipMemcpyHostToDevice)); 1058 mat_start += dim * esize * nqpts; 1059 } 1060 } 1061 return CEED_ERROR_SUCCESS; 1062 } 1063 1064 //------------------------------------------------------------------------------ 1065 // Assemble matrix data for COO matrix of assembled operator. 1066 // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic. 1067 // 1068 // Note that this (and other assembly routines) currently assume only one 1069 // active input restriction/basis per operator (could have multiple basis eval 1070 // modes). 1071 // TODO: allow multiple active input restrictions/basis objects 1072 //------------------------------------------------------------------------------ 1073 static int CeedSingleOperatorAssemble_Hip(CeedOperator op, CeedInt offset, CeedVector values) { 1074 Ceed ceed; 1075 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1076 CeedOperator_Hip *impl; 1077 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1078 1079 // Setup 1080 if (!impl->asmb) { 1081 CeedCallBackend(CeedSingleOperatorAssembleSetup_Hip(op)); 1082 assert(impl->asmb != NULL); 1083 } 1084 1085 // Assemble QFunction 1086 CeedVector assembled_qf; 1087 CeedElemRestriction rstr_q; 1088 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &rstr_q, CEED_REQUEST_IMMEDIATE)); 1089 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_q)); 1090 CeedScalar *values_array; 1091 CeedCallBackend(CeedVectorGetArrayWrite(values, CEED_MEM_DEVICE, &values_array)); 1092 values_array += offset; 1093 const CeedScalar *qf_array; 1094 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &qf_array)); 1095 1096 // Compute B^T D B 1097 const CeedInt nelem = impl->asmb->nelem; // to satisfy clang-tidy 1098 const CeedInt elemsPerBlock = impl->asmb->elemsPerBlock; 1099 const CeedInt grid = nelem / elemsPerBlock + ((nelem / elemsPerBlock * elemsPerBlock < nelem) ? 1 : 0); 1100 void *args[] = {&impl->asmb->d_B_in, &impl->asmb->d_B_out, &qf_array, &values_array}; 1101 CeedCallBackend( 1102 CeedRunKernelDimHip(ceed, impl->asmb->linearAssemble, grid, impl->asmb->block_size_x, impl->asmb->block_size_y, elemsPerBlock, args)); 1103 1104 // Restore arrays 1105 CeedCallBackend(CeedVectorRestoreArray(values, &values_array)); 1106 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &qf_array)); 1107 1108 // Cleanup 1109 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1110 1111 return CEED_ERROR_SUCCESS; 1112 } 1113 1114 //------------------------------------------------------------------------------ 1115 // Create operator 1116 //------------------------------------------------------------------------------ 1117 int CeedOperatorCreate_Hip(CeedOperator op) { 1118 Ceed ceed; 1119 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1120 CeedOperator_Hip *impl; 1121 1122 CeedCallBackend(CeedCalloc(1, &impl)); 1123 CeedCallBackend(CeedOperatorSetData(op, impl)); 1124 1125 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Hip)); 1126 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Hip)); 1127 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonal_Hip)); 1128 CeedCallBackend( 1129 CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal", CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip)); 1130 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssemble_Hip)); 1131 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip)); 1132 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip)); 1133 return CEED_ERROR_SUCCESS; 1134 } 1135 1136 //------------------------------------------------------------------------------ 1137