1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <assert.h> 9 #include <ceed/backend.h> 10 #include <ceed/ceed.h> 11 #include <ceed/jit-tools.h> 12 #include <hip/hip_runtime.h> 13 #include <stdbool.h> 14 #include <string.h> 15 16 #include "../hip/ceed-hip-compile.h" 17 #include "ceed-hip-ref.h" 18 19 //------------------------------------------------------------------------------ 20 // Destroy operator 21 //------------------------------------------------------------------------------ 22 static int CeedOperatorDestroy_Hip(CeedOperator op) { 23 CeedOperator_Hip *impl; 24 CeedCallBackend(CeedOperatorGetData(op, &impl)); 25 26 // Apply data 27 for (CeedInt i = 0; i < impl->numein + impl->numeout; i++) { 28 CeedCallBackend(CeedVectorDestroy(&impl->evecs[i])); 29 } 30 CeedCallBackend(CeedFree(&impl->evecs)); 31 32 for (CeedInt i = 0; i < impl->numein; i++) { 33 CeedCallBackend(CeedVectorDestroy(&impl->qvecsin[i])); 34 } 35 CeedCallBackend(CeedFree(&impl->qvecsin)); 36 37 for (CeedInt i = 0; i < impl->numeout; i++) { 38 CeedCallBackend(CeedVectorDestroy(&impl->qvecsout[i])); 39 } 40 CeedCallBackend(CeedFree(&impl->qvecsout)); 41 42 // QFunction diagonal assembly data 43 for (CeedInt i = 0; i < impl->qfnumactivein; i++) { 44 CeedCallBackend(CeedVectorDestroy(&impl->qfactivein[i])); 45 } 46 CeedCallBackend(CeedFree(&impl->qfactivein)); 47 48 // Diag data 49 if (impl->diag) { 50 Ceed ceed; 51 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 52 CeedCallHip(ceed, hipModuleUnload(impl->diag->module)); 53 CeedCallBackend(CeedFree(&impl->diag->h_emodein)); 54 CeedCallBackend(CeedFree(&impl->diag->h_emodeout)); 55 CeedCallHip(ceed, hipFree(impl->diag->d_emodein)); 56 CeedCallHip(ceed, hipFree(impl->diag->d_emodeout)); 57 CeedCallHip(ceed, hipFree(impl->diag->d_identity)); 58 CeedCallHip(ceed, hipFree(impl->diag->d_interpin)); 59 CeedCallHip(ceed, hipFree(impl->diag->d_interpout)); 60 CeedCallHip(ceed, hipFree(impl->diag->d_gradin)); 61 CeedCallHip(ceed, hipFree(impl->diag->d_gradout)); 62 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->pbdiagrstr)); 63 CeedCallBackend(CeedVectorDestroy(&impl->diag->elemdiag)); 64 CeedCallBackend(CeedVectorDestroy(&impl->diag->pbelemdiag)); 65 } 66 CeedCallBackend(CeedFree(&impl->diag)); 67 68 if (impl->asmb) { 69 Ceed ceed; 70 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 71 CeedCallHip(ceed, hipModuleUnload(impl->asmb->module)); 72 CeedCallHip(ceed, hipFree(impl->asmb->d_B_in)); 73 CeedCallHip(ceed, hipFree(impl->asmb->d_B_out)); 74 } 75 CeedCallBackend(CeedFree(&impl->asmb)); 76 77 CeedCallBackend(CeedFree(&impl)); 78 return CEED_ERROR_SUCCESS; 79 } 80 81 //------------------------------------------------------------------------------ 82 // Setup infields or outfields 83 //------------------------------------------------------------------------------ 84 static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool isinput, CeedVector *evecs, CeedVector *qvecs, CeedInt starte, 85 CeedInt numfields, CeedInt Q, CeedInt numelements) { 86 CeedInt dim, size; 87 CeedSize q_size; 88 Ceed ceed; 89 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 90 CeedBasis basis; 91 CeedElemRestriction Erestrict; 92 CeedOperatorField *opfields; 93 CeedQFunctionField *qffields; 94 CeedVector fieldvec; 95 bool strided; 96 bool skiprestrict; 97 98 if (isinput) { 99 CeedCallBackend(CeedOperatorGetFields(op, NULL, &opfields, NULL, NULL)); 100 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qffields, NULL, NULL)); 101 } else { 102 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &opfields)); 103 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qffields)); 104 } 105 106 // Loop over fields 107 for (CeedInt i = 0; i < numfields; i++) { 108 CeedEvalMode emode; 109 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qffields[i], &emode)); 110 111 strided = false; 112 skiprestrict = false; 113 if (emode != CEED_EVAL_WEIGHT) { 114 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opfields[i], &Erestrict)); 115 116 // Check whether this field can skip the element restriction: 117 // must be passive input, with emode NONE, and have a strided restriction with CEED_STRIDES_BACKEND. 118 119 // First, check whether the field is input or output: 120 if (isinput) { 121 // Check for passive input: 122 CeedCallBackend(CeedOperatorFieldGetVector(opfields[i], &fieldvec)); 123 if (fieldvec != CEED_VECTOR_ACTIVE) { 124 // Check emode 125 if (emode == CEED_EVAL_NONE) { 126 // Check for strided restriction 127 CeedCallBackend(CeedElemRestrictionIsStrided(Erestrict, &strided)); 128 if (strided) { 129 // Check if vector is already in preferred backend ordering 130 CeedCallBackend(CeedElemRestrictionHasBackendStrides(Erestrict, &skiprestrict)); 131 } 132 } 133 } 134 } 135 if (skiprestrict) { 136 // We do not need an E-Vector, but will use the input field vector's data directly in the operator application. 137 evecs[i + starte] = NULL; 138 } else { 139 CeedCallBackend(CeedElemRestrictionCreateVector(Erestrict, NULL, &evecs[i + starte])); 140 } 141 } 142 143 switch (emode) { 144 case CEED_EVAL_NONE: 145 CeedCallBackend(CeedQFunctionFieldGetSize(qffields[i], &size)); 146 q_size = (CeedSize)numelements * Q * size; 147 CeedCallBackend(CeedVectorCreate(ceed, q_size, &qvecs[i])); 148 break; 149 case CEED_EVAL_INTERP: 150 CeedCallBackend(CeedQFunctionFieldGetSize(qffields[i], &size)); 151 q_size = (CeedSize)numelements * Q * size; 152 CeedCallBackend(CeedVectorCreate(ceed, q_size, &qvecs[i])); 153 break; 154 case CEED_EVAL_GRAD: 155 CeedCallBackend(CeedOperatorFieldGetBasis(opfields[i], &basis)); 156 CeedCallBackend(CeedQFunctionFieldGetSize(qffields[i], &size)); 157 CeedCallBackend(CeedBasisGetDimension(basis, &dim)); 158 q_size = (CeedSize)numelements * Q * size; 159 CeedCallBackend(CeedVectorCreate(ceed, q_size, &qvecs[i])); 160 break; 161 case CEED_EVAL_WEIGHT: // Only on input fields 162 CeedCallBackend(CeedOperatorFieldGetBasis(opfields[i], &basis)); 163 q_size = (CeedSize)numelements * Q; 164 CeedCallBackend(CeedVectorCreate(ceed, q_size, &qvecs[i])); 165 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, NULL, qvecs[i])); 166 break; 167 case CEED_EVAL_DIV: 168 break; // TODO: Not implemented 169 case CEED_EVAL_CURL: 170 break; // TODO: Not implemented 171 } 172 } 173 return CEED_ERROR_SUCCESS; 174 } 175 176 //------------------------------------------------------------------------------ 177 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction. 178 //------------------------------------------------------------------------------ 179 static int CeedOperatorSetup_Hip(CeedOperator op) { 180 bool setupdone; 181 CeedCallBackend(CeedOperatorIsSetupDone(op, &setupdone)); 182 if (setupdone) return CEED_ERROR_SUCCESS; 183 Ceed ceed; 184 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 185 CeedOperator_Hip *impl; 186 CeedCallBackend(CeedOperatorGetData(op, &impl)); 187 CeedQFunction qf; 188 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 189 CeedInt Q, numelements, numinputfields, numoutputfields; 190 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 191 CeedCallBackend(CeedOperatorGetNumElements(op, &numelements)); 192 CeedOperatorField *opinputfields, *opoutputfields; 193 CeedCallBackend(CeedOperatorGetFields(op, &numinputfields, &opinputfields, &numoutputfields, &opoutputfields)); 194 CeedQFunctionField *qfinputfields, *qfoutputfields; 195 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields)); 196 197 // Allocate 198 CeedCallBackend(CeedCalloc(numinputfields + numoutputfields, &impl->evecs)); 199 200 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->qvecsin)); 201 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->qvecsout)); 202 203 impl->numein = numinputfields; 204 impl->numeout = numoutputfields; 205 206 // Set up infield and outfield evecs and qvecs 207 // Infields 208 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, impl->evecs, impl->qvecsin, 0, numinputfields, Q, numelements)); 209 210 // Outfields 211 CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, impl->evecs, impl->qvecsout, numinputfields, numoutputfields, Q, numelements)); 212 213 CeedCallBackend(CeedOperatorSetSetupDone(op)); 214 return CEED_ERROR_SUCCESS; 215 } 216 217 //------------------------------------------------------------------------------ 218 // Setup Operator Inputs 219 //------------------------------------------------------------------------------ 220 static inline int CeedOperatorSetupInputs_Hip(CeedInt numinputfields, CeedQFunctionField *qfinputfields, CeedOperatorField *opinputfields, 221 CeedVector invec, const bool skipactive, CeedScalar *edata[2 * CEED_FIELD_MAX], CeedOperator_Hip *impl, 222 CeedRequest *request) { 223 CeedEvalMode emode; 224 CeedVector vec; 225 CeedElemRestriction Erestrict; 226 227 for (CeedInt i = 0; i < numinputfields; i++) { 228 // Get input vector 229 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 230 if (vec == CEED_VECTOR_ACTIVE) { 231 if (skipactive) continue; 232 else vec = invec; 233 } 234 235 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode)); 236 if (emode == CEED_EVAL_WEIGHT) { // Skip 237 } else { 238 // Get input vector 239 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 240 // Get input element restriction 241 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict)); 242 if (vec == CEED_VECTOR_ACTIVE) vec = invec; 243 // Restrict, if necessary 244 if (!impl->evecs[i]) { 245 // No restriction for this field; read data directly from vec. 246 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, (const CeedScalar **)&edata[i])); 247 } else { 248 CeedCallBackend(CeedElemRestrictionApply(Erestrict, CEED_NOTRANSPOSE, vec, impl->evecs[i], request)); 249 // Get evec 250 CeedCallBackend(CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_DEVICE, (const CeedScalar **)&edata[i])); 251 } 252 } 253 } 254 return CEED_ERROR_SUCCESS; 255 } 256 257 //------------------------------------------------------------------------------ 258 // Input Basis Action 259 //------------------------------------------------------------------------------ 260 static inline int CeedOperatorInputBasis_Hip(CeedInt numelements, CeedQFunctionField *qfinputfields, CeedOperatorField *opinputfields, 261 CeedInt numinputfields, const bool skipactive, CeedScalar *edata[2 * CEED_FIELD_MAX], 262 CeedOperator_Hip *impl) { 263 CeedInt elemsize, size; 264 CeedElemRestriction Erestrict; 265 CeedEvalMode emode; 266 CeedBasis basis; 267 268 for (CeedInt i = 0; i < numinputfields; i++) { 269 // Skip active input 270 if (skipactive) { 271 CeedVector vec; 272 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 273 if (vec == CEED_VECTOR_ACTIVE) continue; 274 } 275 // Get elemsize, emode, size 276 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict)); 277 CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elemsize)); 278 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode)); 279 CeedCallBackend(CeedQFunctionFieldGetSize(qfinputfields[i], &size)); 280 // Basis action 281 switch (emode) { 282 case CEED_EVAL_NONE: 283 CeedCallBackend(CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_DEVICE, CEED_USE_POINTER, edata[i])); 284 break; 285 case CEED_EVAL_INTERP: 286 CeedCallBackend(CeedOperatorFieldGetBasis(opinputfields[i], &basis)); 287 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_NOTRANSPOSE, CEED_EVAL_INTERP, impl->evecs[i], impl->qvecsin[i])); 288 break; 289 case CEED_EVAL_GRAD: 290 CeedCallBackend(CeedOperatorFieldGetBasis(opinputfields[i], &basis)); 291 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_NOTRANSPOSE, CEED_EVAL_GRAD, impl->evecs[i], impl->qvecsin[i])); 292 break; 293 case CEED_EVAL_WEIGHT: 294 break; // No action 295 case CEED_EVAL_DIV: 296 break; // TODO: Not implemented 297 case CEED_EVAL_CURL: 298 break; // TODO: Not implemented 299 } 300 } 301 return CEED_ERROR_SUCCESS; 302 } 303 304 //------------------------------------------------------------------------------ 305 // Restore Input Vectors 306 //------------------------------------------------------------------------------ 307 static inline int CeedOperatorRestoreInputs_Hip(CeedInt numinputfields, CeedQFunctionField *qfinputfields, CeedOperatorField *opinputfields, 308 const bool skipactive, CeedScalar *edata[2 * CEED_FIELD_MAX], CeedOperator_Hip *impl) { 309 CeedEvalMode emode; 310 CeedVector vec; 311 312 for (CeedInt i = 0; i < numinputfields; i++) { 313 // Skip active input 314 if (skipactive) { 315 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 316 if (vec == CEED_VECTOR_ACTIVE) continue; 317 } 318 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode)); 319 if (emode == CEED_EVAL_WEIGHT) { // Skip 320 } else { 321 if (!impl->evecs[i]) { // This was a skiprestrict case 322 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 323 CeedCallBackend(CeedVectorRestoreArrayRead(vec, (const CeedScalar **)&edata[i])); 324 } else { 325 CeedCallBackend(CeedVectorRestoreArrayRead(impl->evecs[i], (const CeedScalar **)&edata[i])); 326 } 327 } 328 } 329 return CEED_ERROR_SUCCESS; 330 } 331 332 //------------------------------------------------------------------------------ 333 // Apply and add to output 334 //------------------------------------------------------------------------------ 335 static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector invec, CeedVector outvec, CeedRequest *request) { 336 CeedOperator_Hip *impl; 337 CeedCallBackend(CeedOperatorGetData(op, &impl)); 338 CeedQFunction qf; 339 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 340 CeedInt Q, numelements, elemsize, numinputfields, numoutputfields, size; 341 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 342 CeedCallBackend(CeedOperatorGetNumElements(op, &numelements)); 343 CeedOperatorField *opinputfields, *opoutputfields; 344 CeedCallBackend(CeedOperatorGetFields(op, &numinputfields, &opinputfields, &numoutputfields, &opoutputfields)); 345 CeedQFunctionField *qfinputfields, *qfoutputfields; 346 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields)); 347 CeedEvalMode emode; 348 CeedVector vec; 349 CeedBasis basis; 350 CeedElemRestriction Erestrict; 351 CeedScalar *edata[2 * CEED_FIELD_MAX]; 352 353 // Setup 354 CeedCallBackend(CeedOperatorSetup_Hip(op)); 355 356 // Input Evecs and Restriction 357 CeedCallBackend(CeedOperatorSetupInputs_Hip(numinputfields, qfinputfields, opinputfields, invec, false, edata, impl, request)); 358 359 // Input basis apply if needed 360 CeedCallBackend(CeedOperatorInputBasis_Hip(numelements, qfinputfields, opinputfields, numinputfields, false, edata, impl)); 361 362 // Output pointers, as necessary 363 for (CeedInt i = 0; i < numoutputfields; i++) { 364 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode)); 365 if (emode == CEED_EVAL_NONE) { 366 // Set the output Q-Vector to use the E-Vector data directly. 367 CeedCallBackend(CeedVectorGetArrayWrite(impl->evecs[i + impl->numein], CEED_MEM_DEVICE, &edata[i + numinputfields])); 368 CeedCallBackend(CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_DEVICE, CEED_USE_POINTER, edata[i + numinputfields])); 369 } 370 } 371 372 // Q function 373 CeedCallBackend(CeedQFunctionApply(qf, numelements * Q, impl->qvecsin, impl->qvecsout)); 374 375 // Output basis apply if needed 376 for (CeedInt i = 0; i < numoutputfields; i++) { 377 // Get elemsize, emode, size 378 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict)); 379 CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elemsize)); 380 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode)); 381 CeedCallBackend(CeedQFunctionFieldGetSize(qfoutputfields[i], &size)); 382 // Basis action 383 switch (emode) { 384 case CEED_EVAL_NONE: 385 break; 386 case CEED_EVAL_INTERP: 387 CeedCallBackend(CeedOperatorFieldGetBasis(opoutputfields[i], &basis)); 388 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_TRANSPOSE, CEED_EVAL_INTERP, impl->qvecsout[i], impl->evecs[i + impl->numein])); 389 break; 390 case CEED_EVAL_GRAD: 391 CeedCallBackend(CeedOperatorFieldGetBasis(opoutputfields[i], &basis)); 392 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_TRANSPOSE, CEED_EVAL_GRAD, impl->qvecsout[i], impl->evecs[i + impl->numein])); 393 break; 394 // LCOV_EXCL_START 395 case CEED_EVAL_WEIGHT: { 396 Ceed ceed; 397 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 398 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 399 break; // Should not occur 400 } 401 case CEED_EVAL_DIV: 402 break; // TODO: Not implemented 403 case CEED_EVAL_CURL: 404 break; // TODO: Not implemented 405 // LCOV_EXCL_STOP 406 } 407 } 408 409 // Output restriction 410 for (CeedInt i = 0; i < numoutputfields; i++) { 411 // Restore evec 412 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode)); 413 if (emode == CEED_EVAL_NONE) { 414 CeedCallBackend(CeedVectorRestoreArray(impl->evecs[i + impl->numein], &edata[i + numinputfields])); 415 } 416 // Get output vector 417 CeedCallBackend(CeedOperatorFieldGetVector(opoutputfields[i], &vec)); 418 // Restrict 419 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict)); 420 // Active 421 if (vec == CEED_VECTOR_ACTIVE) vec = outvec; 422 423 CeedCallBackend(CeedElemRestrictionApply(Erestrict, CEED_TRANSPOSE, impl->evecs[i + impl->numein], vec, request)); 424 } 425 426 // Restore input arrays 427 CeedCallBackend(CeedOperatorRestoreInputs_Hip(numinputfields, qfinputfields, opinputfields, false, edata, impl)); 428 return CEED_ERROR_SUCCESS; 429 } 430 431 //------------------------------------------------------------------------------ 432 // Core code for assembling linear QFunction 433 //------------------------------------------------------------------------------ 434 static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 435 CeedRequest *request) { 436 CeedOperator_Hip *impl; 437 CeedCallBackend(CeedOperatorGetData(op, &impl)); 438 CeedQFunction qf; 439 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 440 CeedInt Q, numelements, numinputfields, numoutputfields, size; 441 CeedSize q_size; 442 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 443 CeedCallBackend(CeedOperatorGetNumElements(op, &numelements)); 444 CeedOperatorField *opinputfields, *opoutputfields; 445 CeedCallBackend(CeedOperatorGetFields(op, &numinputfields, &opinputfields, &numoutputfields, &opoutputfields)); 446 CeedQFunctionField *qfinputfields, *qfoutputfields; 447 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields)); 448 CeedVector vec; 449 CeedInt numactivein = impl->qfnumactivein, numactiveout = impl->qfnumactiveout; 450 CeedVector *activein = impl->qfactivein; 451 CeedScalar *a, *tmp; 452 Ceed ceed, ceedparent; 453 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 454 CeedCallBackend(CeedGetOperatorFallbackParentCeed(ceed, &ceedparent)); 455 ceedparent = ceedparent ? ceedparent : ceed; 456 CeedScalar *edata[2 * CEED_FIELD_MAX]; 457 458 // Setup 459 CeedCallBackend(CeedOperatorSetup_Hip(op)); 460 461 // Check for identity 462 bool identityqf; 463 CeedCallBackend(CeedQFunctionIsIdentity(qf, &identityqf)); 464 if (identityqf) { 465 // LCOV_EXCL_START 466 return CeedError(ceed, CEED_ERROR_BACKEND, "Assembling identity QFunctions not supported"); 467 // LCOV_EXCL_STOP 468 } 469 470 // Input Evecs and Restriction 471 CeedCallBackend(CeedOperatorSetupInputs_Hip(numinputfields, qfinputfields, opinputfields, NULL, true, edata, impl, request)); 472 473 // Count number of active input fields 474 if (!numactivein) { 475 for (CeedInt i = 0; i < numinputfields; i++) { 476 // Get input vector 477 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 478 // Check if active input 479 if (vec == CEED_VECTOR_ACTIVE) { 480 CeedCallBackend(CeedQFunctionFieldGetSize(qfinputfields[i], &size)); 481 CeedCallBackend(CeedVectorSetValue(impl->qvecsin[i], 0.0)); 482 CeedCallBackend(CeedVectorGetArray(impl->qvecsin[i], CEED_MEM_DEVICE, &tmp)); 483 CeedCallBackend(CeedRealloc(numactivein + size, &activein)); 484 for (CeedInt field = 0; field < size; field++) { 485 q_size = (CeedSize)Q * numelements; 486 CeedCallBackend(CeedVectorCreate(ceed, q_size, &activein[numactivein + field])); 487 CeedCallBackend(CeedVectorSetArray(activein[numactivein + field], CEED_MEM_DEVICE, CEED_USE_POINTER, &tmp[field * Q * numelements])); 488 } 489 numactivein += size; 490 CeedCallBackend(CeedVectorRestoreArray(impl->qvecsin[i], &tmp)); 491 } 492 } 493 impl->qfnumactivein = numactivein; 494 impl->qfactivein = activein; 495 } 496 497 // Count number of active output fields 498 if (!numactiveout) { 499 for (CeedInt i = 0; i < numoutputfields; i++) { 500 // Get output vector 501 CeedCallBackend(CeedOperatorFieldGetVector(opoutputfields[i], &vec)); 502 // Check if active output 503 if (vec == CEED_VECTOR_ACTIVE) { 504 CeedCallBackend(CeedQFunctionFieldGetSize(qfoutputfields[i], &size)); 505 numactiveout += size; 506 } 507 } 508 impl->qfnumactiveout = numactiveout; 509 } 510 511 // Check sizes 512 if (!numactivein || !numactiveout) { 513 // LCOV_EXCL_START 514 return CeedError(ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 515 // LCOV_EXCL_STOP 516 } 517 518 // Build objects if needed 519 if (build_objects) { 520 // Create output restriction 521 CeedInt strides[3] = {1, numelements * Q, Q}; /* *NOPAD* */ 522 CeedCallBackend(CeedElemRestrictionCreateStrided(ceedparent, numelements, Q, numactivein * numactiveout, 523 numactivein * numactiveout * numelements * Q, strides, rstr)); 524 // Create assembled vector 525 CeedSize l_size = (CeedSize)numelements * Q * numactivein * numactiveout; 526 CeedCallBackend(CeedVectorCreate(ceedparent, l_size, assembled)); 527 } 528 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 529 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_DEVICE, &a)); 530 531 // Input basis apply 532 CeedCallBackend(CeedOperatorInputBasis_Hip(numelements, qfinputfields, opinputfields, numinputfields, true, edata, impl)); 533 534 // Assemble QFunction 535 for (CeedInt in = 0; in < numactivein; in++) { 536 // Set Inputs 537 CeedCallBackend(CeedVectorSetValue(activein[in], 1.0)); 538 if (numactivein > 1) { 539 CeedCallBackend(CeedVectorSetValue(activein[(in + numactivein - 1) % numactivein], 0.0)); 540 } 541 // Set Outputs 542 for (CeedInt out = 0; out < numoutputfields; out++) { 543 // Get output vector 544 CeedCallBackend(CeedOperatorFieldGetVector(opoutputfields[out], &vec)); 545 // Check if active output 546 if (vec == CEED_VECTOR_ACTIVE) { 547 CeedCallBackend(CeedVectorSetArray(impl->qvecsout[out], CEED_MEM_DEVICE, CEED_USE_POINTER, a)); 548 CeedCallBackend(CeedQFunctionFieldGetSize(qfoutputfields[out], &size)); 549 a += size * Q * numelements; // Advance the pointer by the size of the output 550 } 551 } 552 // Apply QFunction 553 CeedCallBackend(CeedQFunctionApply(qf, Q * numelements, impl->qvecsin, impl->qvecsout)); 554 } 555 556 // Un-set output Qvecs to prevent accidental overwrite of Assembled 557 for (CeedInt out = 0; out < numoutputfields; out++) { 558 // Get output vector 559 CeedCallBackend(CeedOperatorFieldGetVector(opoutputfields[out], &vec)); 560 // Check if active output 561 if (vec == CEED_VECTOR_ACTIVE) { 562 CeedCallBackend(CeedVectorTakeArray(impl->qvecsout[out], CEED_MEM_DEVICE, NULL)); 563 } 564 } 565 566 // Restore input arrays 567 CeedCallBackend(CeedOperatorRestoreInputs_Hip(numinputfields, qfinputfields, opinputfields, true, edata, impl)); 568 569 // Restore output 570 CeedCallBackend(CeedVectorRestoreArray(*assembled, &a)); 571 572 return CEED_ERROR_SUCCESS; 573 } 574 575 //------------------------------------------------------------------------------ 576 // Assemble Linear QFunction 577 //------------------------------------------------------------------------------ 578 static int CeedOperatorLinearAssembleQFunction_Hip(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 579 return CeedOperatorLinearAssembleQFunctionCore_Hip(op, true, assembled, rstr, request); 580 } 581 582 //------------------------------------------------------------------------------ 583 // Assemble Linear QFunction 584 //------------------------------------------------------------------------------ 585 static int CeedOperatorLinearAssembleQFunctionUpdate_Hip(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 586 return CeedOperatorLinearAssembleQFunctionCore_Hip(op, false, &assembled, &rstr, request); 587 } 588 589 //------------------------------------------------------------------------------ 590 // Create point block restriction 591 //------------------------------------------------------------------------------ 592 static int CreatePBRestriction(CeedElemRestriction rstr, CeedElemRestriction *pbRstr) { 593 Ceed ceed; 594 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 595 const CeedInt *offsets; 596 CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets)); 597 598 // Expand offsets 599 CeedInt nelem, ncomp, elemsize, compstride, *pbOffsets; 600 CeedSize l_size; 601 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &nelem)); 602 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &ncomp)); 603 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elemsize)); 604 CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &compstride)); 605 CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size)); 606 CeedInt shift = ncomp; 607 if (compstride != 1) shift *= ncomp; 608 CeedCallBackend(CeedCalloc(nelem * elemsize, &pbOffsets)); 609 for (CeedInt i = 0; i < nelem * elemsize; i++) { 610 pbOffsets[i] = offsets[i] * shift; 611 } 612 613 // Create new restriction 614 CeedCallBackend( 615 CeedElemRestrictionCreate(ceed, nelem, elemsize, ncomp * ncomp, 1, l_size * ncomp, CEED_MEM_HOST, CEED_OWN_POINTER, pbOffsets, pbRstr)); 616 617 // Cleanup 618 CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets)); 619 620 return CEED_ERROR_SUCCESS; 621 } 622 623 //------------------------------------------------------------------------------ 624 // Assemble diagonal setup 625 //------------------------------------------------------------------------------ 626 static inline int CeedOperatorAssembleDiagonalSetup_Hip(CeedOperator op, const bool pointBlock) { 627 Ceed ceed; 628 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 629 CeedQFunction qf; 630 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 631 CeedInt numinputfields, numoutputfields; 632 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields)); 633 634 // Determine active input basis 635 CeedOperatorField *opfields; 636 CeedQFunctionField *qffields; 637 CeedCallBackend(CeedOperatorGetFields(op, NULL, &opfields, NULL, NULL)); 638 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qffields, NULL, NULL)); 639 CeedInt numemodein = 0, ncomp = 0, dim = 1; 640 CeedEvalMode *emodein = NULL; 641 CeedBasis basisin = NULL; 642 CeedElemRestriction rstrin = NULL; 643 for (CeedInt i = 0; i < numinputfields; i++) { 644 CeedVector vec; 645 CeedCallBackend(CeedOperatorFieldGetVector(opfields[i], &vec)); 646 if (vec == CEED_VECTOR_ACTIVE) { 647 CeedElemRestriction rstr; 648 CeedCallBackend(CeedOperatorFieldGetBasis(opfields[i], &basisin)); 649 CeedCallBackend(CeedBasisGetNumComponents(basisin, &ncomp)); 650 CeedCallBackend(CeedBasisGetDimension(basisin, &dim)); 651 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opfields[i], &rstr)); 652 if (rstrin && rstrin != rstr) { 653 // LCOV_EXCL_START 654 return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator diagonal assembly"); 655 // LCOV_EXCL_STOP 656 } 657 rstrin = rstr; 658 CeedEvalMode emode; 659 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qffields[i], &emode)); 660 switch (emode) { 661 case CEED_EVAL_NONE: 662 case CEED_EVAL_INTERP: 663 CeedCallBackend(CeedRealloc(numemodein + 1, &emodein)); 664 emodein[numemodein] = emode; 665 numemodein += 1; 666 break; 667 case CEED_EVAL_GRAD: 668 CeedCallBackend(CeedRealloc(numemodein + dim, &emodein)); 669 for (CeedInt d = 0; d < dim; d++) emodein[numemodein + d] = emode; 670 numemodein += dim; 671 break; 672 case CEED_EVAL_WEIGHT: 673 case CEED_EVAL_DIV: 674 case CEED_EVAL_CURL: 675 break; // Caught by QF Assembly 676 } 677 } 678 } 679 680 // Determine active output basis 681 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &opfields)); 682 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qffields)); 683 CeedInt numemodeout = 0; 684 CeedEvalMode *emodeout = NULL; 685 CeedBasis basisout = NULL; 686 CeedElemRestriction rstrout = NULL; 687 for (CeedInt i = 0; i < numoutputfields; i++) { 688 CeedVector vec; 689 CeedCallBackend(CeedOperatorFieldGetVector(opfields[i], &vec)); 690 if (vec == CEED_VECTOR_ACTIVE) { 691 CeedElemRestriction rstr; 692 CeedCallBackend(CeedOperatorFieldGetBasis(opfields[i], &basisout)); 693 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opfields[i], &rstr)); 694 if (rstrout && rstrout != rstr) { 695 // LCOV_EXCL_START 696 return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator diagonal assembly"); 697 // LCOV_EXCL_STOP 698 } 699 rstrout = rstr; 700 CeedEvalMode emode; 701 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qffields[i], &emode)); 702 switch (emode) { 703 case CEED_EVAL_NONE: 704 case CEED_EVAL_INTERP: 705 CeedCallBackend(CeedRealloc(numemodeout + 1, &emodeout)); 706 emodeout[numemodeout] = emode; 707 numemodeout += 1; 708 break; 709 case CEED_EVAL_GRAD: 710 CeedCallBackend(CeedRealloc(numemodeout + dim, &emodeout)); 711 for (CeedInt d = 0; d < dim; d++) emodeout[numemodeout + d] = emode; 712 numemodeout += dim; 713 break; 714 case CEED_EVAL_WEIGHT: 715 case CEED_EVAL_DIV: 716 case CEED_EVAL_CURL: 717 break; // Caught by QF Assembly 718 } 719 } 720 } 721 722 // Operator data struct 723 CeedOperator_Hip *impl; 724 CeedCallBackend(CeedOperatorGetData(op, &impl)); 725 CeedCallBackend(CeedCalloc(1, &impl->diag)); 726 CeedOperatorDiag_Hip *diag = impl->diag; 727 diag->basisin = basisin; 728 diag->basisout = basisout; 729 diag->h_emodein = emodein; 730 diag->h_emodeout = emodeout; 731 diag->numemodein = numemodein; 732 diag->numemodeout = numemodeout; 733 734 // Assemble kernel 735 736 char *diagonal_kernel_path, *diagonal_kernel_source; 737 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h", &diagonal_kernel_path)); 738 CeedDebug256(ceed, 2, "----- Loading Diagonal Assembly Kernel Source -----\n"); 739 CeedCallBackend(CeedLoadSourceToBuffer(ceed, diagonal_kernel_path, &diagonal_kernel_source)); 740 CeedDebug256(ceed, 2, "----- Loading Diagonal Assembly Source Complete! -----\n"); 741 CeedInt nnodes, nqpts; 742 CeedCallBackend(CeedBasisGetNumNodes(basisin, &nnodes)); 743 CeedCallBackend(CeedBasisGetNumQuadraturePoints(basisin, &nqpts)); 744 diag->nnodes = nnodes; 745 CeedCallBackend(CeedCompileHip(ceed, diagonal_kernel_source, &diag->module, 5, "NUMEMODEIN", numemodein, "NUMEMODEOUT", numemodeout, "NNODES", 746 nnodes, "NQPTS", nqpts, "NCOMP", ncomp)); 747 CeedCallBackend(CeedGetKernelHip(ceed, diag->module, "linearDiagonal", &diag->linearDiagonal)); 748 CeedCallBackend(CeedGetKernelHip(ceed, diag->module, "linearPointBlockDiagonal", &diag->linearPointBlock)); 749 CeedCallBackend(CeedFree(&diagonal_kernel_path)); 750 CeedCallBackend(CeedFree(&diagonal_kernel_source)); 751 752 // Basis matrices 753 const CeedInt qBytes = nqpts * sizeof(CeedScalar); 754 const CeedInt iBytes = qBytes * nnodes; 755 const CeedInt gBytes = qBytes * nnodes * dim; 756 const CeedInt eBytes = sizeof(CeedEvalMode); 757 const CeedScalar *interpin, *interpout, *gradin, *gradout; 758 759 // CEED_EVAL_NONE 760 CeedScalar *identity = NULL; 761 bool evalNone = false; 762 for (CeedInt i = 0; i < numemodein; i++) evalNone = evalNone || (emodein[i] == CEED_EVAL_NONE); 763 for (CeedInt i = 0; i < numemodeout; i++) evalNone = evalNone || (emodeout[i] == CEED_EVAL_NONE); 764 if (evalNone) { 765 CeedCallBackend(CeedCalloc(nqpts * nnodes, &identity)); 766 for (CeedInt i = 0; i < (nnodes < nqpts ? nnodes : nqpts); i++) identity[i * nnodes + i] = 1.0; 767 CeedCallHip(ceed, hipMalloc((void **)&diag->d_identity, iBytes)); 768 CeedCallHip(ceed, hipMemcpy(diag->d_identity, identity, iBytes, hipMemcpyHostToDevice)); 769 } 770 771 // CEED_EVAL_INTERP 772 CeedCallBackend(CeedBasisGetInterp(basisin, &interpin)); 773 CeedCallHip(ceed, hipMalloc((void **)&diag->d_interpin, iBytes)); 774 CeedCallHip(ceed, hipMemcpy(diag->d_interpin, interpin, iBytes, hipMemcpyHostToDevice)); 775 CeedCallBackend(CeedBasisGetInterp(basisout, &interpout)); 776 CeedCallHip(ceed, hipMalloc((void **)&diag->d_interpout, iBytes)); 777 CeedCallHip(ceed, hipMemcpy(diag->d_interpout, interpout, iBytes, hipMemcpyHostToDevice)); 778 779 // CEED_EVAL_GRAD 780 CeedCallBackend(CeedBasisGetGrad(basisin, &gradin)); 781 CeedCallHip(ceed, hipMalloc((void **)&diag->d_gradin, gBytes)); 782 CeedCallHip(ceed, hipMemcpy(diag->d_gradin, gradin, gBytes, hipMemcpyHostToDevice)); 783 CeedCallBackend(CeedBasisGetGrad(basisout, &gradout)); 784 CeedCallHip(ceed, hipMalloc((void **)&diag->d_gradout, gBytes)); 785 CeedCallHip(ceed, hipMemcpy(diag->d_gradout, gradout, gBytes, hipMemcpyHostToDevice)); 786 787 // Arrays of emodes 788 CeedCallHip(ceed, hipMalloc((void **)&diag->d_emodein, numemodein * eBytes)); 789 CeedCallHip(ceed, hipMemcpy(diag->d_emodein, emodein, numemodein * eBytes, hipMemcpyHostToDevice)); 790 CeedCallHip(ceed, hipMalloc((void **)&diag->d_emodeout, numemodeout * eBytes)); 791 CeedCallHip(ceed, hipMemcpy(diag->d_emodeout, emodeout, numemodeout * eBytes, hipMemcpyHostToDevice)); 792 793 // Restriction 794 diag->diagrstr = rstrout; 795 796 return CEED_ERROR_SUCCESS; 797 } 798 799 //------------------------------------------------------------------------------ 800 // Assemble diagonal common code 801 //------------------------------------------------------------------------------ 802 static inline int CeedOperatorAssembleDiagonalCore_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request, const bool pointBlock) { 803 Ceed ceed; 804 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 805 CeedOperator_Hip *impl; 806 CeedCallBackend(CeedOperatorGetData(op, &impl)); 807 808 // Assemble QFunction 809 CeedVector assembledqf; 810 CeedElemRestriction rstr; 811 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembledqf, &rstr, request)); 812 CeedCallBackend(CeedElemRestrictionDestroy(&rstr)); 813 814 // Setup 815 if (!impl->diag) CeedCallBackend(CeedOperatorAssembleDiagonalSetup_Hip(op, pointBlock)); 816 CeedOperatorDiag_Hip *diag = impl->diag; 817 assert(diag != NULL); 818 819 // Restriction 820 if (pointBlock && !diag->pbdiagrstr) { 821 CeedElemRestriction pbdiagrstr; 822 CeedCallBackend(CreatePBRestriction(diag->diagrstr, &pbdiagrstr)); 823 diag->pbdiagrstr = pbdiagrstr; 824 } 825 CeedElemRestriction diagrstr = pointBlock ? diag->pbdiagrstr : diag->diagrstr; 826 827 // Create diagonal vector 828 CeedVector elemdiag = pointBlock ? diag->pbelemdiag : diag->elemdiag; 829 if (!elemdiag) { 830 // Element diagonal vector 831 CeedCallBackend(CeedElemRestrictionCreateVector(diagrstr, NULL, &elemdiag)); 832 if (pointBlock) diag->pbelemdiag = elemdiag; 833 else diag->elemdiag = elemdiag; 834 } 835 CeedCallBackend(CeedVectorSetValue(elemdiag, 0.0)); 836 837 // Assemble element operator diagonals 838 CeedScalar *elemdiagarray; 839 const CeedScalar *assembledqfarray; 840 CeedCallBackend(CeedVectorGetArray(elemdiag, CEED_MEM_DEVICE, &elemdiagarray)); 841 CeedCallBackend(CeedVectorGetArrayRead(assembledqf, CEED_MEM_DEVICE, &assembledqfarray)); 842 CeedInt nelem; 843 CeedCallBackend(CeedElemRestrictionGetNumElements(diagrstr, &nelem)); 844 845 // Compute the diagonal of B^T D B 846 int elemsPerBlock = 1; 847 int grid = nelem / elemsPerBlock + ((nelem / elemsPerBlock * elemsPerBlock < nelem) ? 1 : 0); 848 void *args[] = {(void *)&nelem, &diag->d_identity, &diag->d_interpin, &diag->d_gradin, &diag->d_interpout, 849 &diag->d_gradout, &diag->d_emodein, &diag->d_emodeout, &assembledqfarray, &elemdiagarray}; 850 if (pointBlock) { 851 CeedCallBackend(CeedRunKernelDimHip(ceed, diag->linearPointBlock, grid, diag->nnodes, 1, elemsPerBlock, args)); 852 } else { 853 CeedCallBackend(CeedRunKernelDimHip(ceed, diag->linearDiagonal, grid, diag->nnodes, 1, elemsPerBlock, args)); 854 } 855 856 // Restore arrays 857 CeedCallBackend(CeedVectorRestoreArray(elemdiag, &elemdiagarray)); 858 CeedCallBackend(CeedVectorRestoreArrayRead(assembledqf, &assembledqfarray)); 859 860 // Assemble local operator diagonal 861 CeedCallBackend(CeedElemRestrictionApply(diagrstr, CEED_TRANSPOSE, elemdiag, assembled, request)); 862 863 // Cleanup 864 CeedCallBackend(CeedVectorDestroy(&assembledqf)); 865 866 return CEED_ERROR_SUCCESS; 867 } 868 869 //------------------------------------------------------------------------------ 870 // Assemble Linear Diagonal 871 //------------------------------------------------------------------------------ 872 static int CeedOperatorLinearAssembleAddDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 873 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, false)); 874 return CEED_ERROR_SUCCESS; 875 } 876 877 //------------------------------------------------------------------------------ 878 // Assemble Linear Point Block Diagonal 879 //------------------------------------------------------------------------------ 880 static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { 881 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Hip(op, assembled, request, true)); 882 return CEED_ERROR_SUCCESS; 883 } 884 885 //------------------------------------------------------------------------------ 886 // Single operator assembly setup 887 //------------------------------------------------------------------------------ 888 static int CeedSingleOperatorAssembleSetup_Hip(CeedOperator op) { 889 Ceed ceed; 890 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 891 CeedOperator_Hip *impl; 892 CeedCallBackend(CeedOperatorGetData(op, &impl)); 893 894 // Get intput and output fields 895 CeedInt num_input_fields, num_output_fields; 896 CeedOperatorField *input_fields; 897 CeedOperatorField *output_fields; 898 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields)); 899 900 // Determine active input basis eval mode 901 CeedQFunction qf; 902 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 903 CeedQFunctionField *qf_fields; 904 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 905 // Note that the kernel will treat each dimension of a gradient action separately; 906 // i.e., when an active input has a CEED_EVAL_GRAD mode, num_emode_in will increment by dim. 907 // However, for the purposes of loading the B matrices, it will be treated as one mode, and we will load/copy the entire gradient matrix at once, so 908 // num_B_in_mats_to_load will be incremented by 1. 909 CeedInt num_emode_in = 0, dim = 1, num_B_in_mats_to_load = 0, size_B_in = 0; 910 CeedEvalMode *eval_mode_in = NULL; // will be of size num_B_in_mats_load 911 CeedBasis basis_in = NULL; 912 CeedInt nqpts = 0, esize = 0; 913 CeedElemRestriction rstr_in = NULL; 914 for (CeedInt i = 0; i < num_input_fields; i++) { 915 CeedVector vec; 916 CeedCallBackend(CeedOperatorFieldGetVector(input_fields[i], &vec)); 917 if (vec == CEED_VECTOR_ACTIVE) { 918 CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis_in)); 919 CeedCallBackend(CeedBasisGetDimension(basis_in, &dim)); 920 CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &nqpts)); 921 CeedCallBackend(CeedOperatorFieldGetElemRestriction(input_fields[i], &rstr_in)); 922 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &esize)); 923 CeedEvalMode eval_mode; 924 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 925 if (eval_mode != CEED_EVAL_NONE) { 926 CeedCallBackend(CeedRealloc(num_B_in_mats_to_load + 1, &eval_mode_in)); 927 eval_mode_in[num_B_in_mats_to_load] = eval_mode; 928 num_B_in_mats_to_load += 1; 929 if (eval_mode == CEED_EVAL_GRAD) { 930 num_emode_in += dim; 931 size_B_in += dim * esize * nqpts; 932 } else { 933 num_emode_in += 1; 934 size_B_in += esize * nqpts; 935 } 936 } 937 } 938 } 939 940 // Determine active output basis; basis_out and rstr_out only used if same as input, TODO 941 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 942 CeedInt num_emode_out = 0, num_B_out_mats_to_load = 0, size_B_out = 0; 943 CeedEvalMode *eval_mode_out = NULL; 944 CeedBasis basis_out = NULL; 945 CeedElemRestriction rstr_out = NULL; 946 for (CeedInt i = 0; i < num_output_fields; i++) { 947 CeedVector vec; 948 CeedCallBackend(CeedOperatorFieldGetVector(output_fields[i], &vec)); 949 if (vec == CEED_VECTOR_ACTIVE) { 950 CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis_out)); 951 CeedCallBackend(CeedOperatorFieldGetElemRestriction(output_fields[i], &rstr_out)); 952 if (rstr_out && rstr_out != rstr_in) { 953 // LCOV_EXCL_START 954 return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator assembly"); 955 // LCOV_EXCL_STOP 956 } 957 CeedEvalMode eval_mode; 958 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 959 if (eval_mode != CEED_EVAL_NONE) { 960 CeedCallBackend(CeedRealloc(num_B_out_mats_to_load + 1, &eval_mode_out)); 961 eval_mode_out[num_B_out_mats_to_load] = eval_mode; 962 num_B_out_mats_to_load += 1; 963 if (eval_mode == CEED_EVAL_GRAD) { 964 num_emode_out += dim; 965 size_B_out += dim * esize * nqpts; 966 } else { 967 num_emode_out += 1; 968 size_B_out += esize * nqpts; 969 } 970 } 971 } 972 } 973 974 if (num_emode_in == 0 || num_emode_out == 0) { 975 // LCOV_EXCL_START 976 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs"); 977 // LCOV_EXCL_STOP 978 } 979 980 CeedInt nelem, ncomp; 981 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, &nelem)); 982 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &ncomp)); 983 984 CeedCallBackend(CeedCalloc(1, &impl->asmb)); 985 CeedOperatorAssemble_Hip *asmb = impl->asmb; 986 asmb->nelem = nelem; 987 988 // Compile kernels 989 int elemsPerBlock = 1; 990 asmb->elemsPerBlock = elemsPerBlock; 991 CeedInt block_size = esize * esize * elemsPerBlock; 992 char *assembly_kernel_path, *assembly_kernel_source; 993 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-operator-assemble.h", &assembly_kernel_path)); 994 CeedDebug256(ceed, 2, "----- Loading Assembly Kernel Source -----\n"); 995 CeedCallBackend(CeedLoadSourceToBuffer(ceed, assembly_kernel_path, &assembly_kernel_source)); 996 CeedDebug256(ceed, 2, "----- Loading Assembly Source Complete! -----\n"); 997 bool fallback = block_size > 1024; 998 if (fallback) { // Use fallback kernel with 1D threadblock 999 block_size = esize * elemsPerBlock; 1000 asmb->block_size_x = esize; 1001 asmb->block_size_y = 1; 1002 } else { // Use kernel with 2D threadblock 1003 asmb->block_size_x = esize; 1004 asmb->block_size_y = esize; 1005 } 1006 CeedCallBackend(CeedCompileHip(ceed, assembly_kernel_source, &asmb->module, 7, "NELEM", nelem, "NUMEMODEIN", num_emode_in, "NUMEMODEOUT", 1007 num_emode_out, "NQPTS", nqpts, "NNODES", esize, "BLOCK_SIZE", block_size, "NCOMP", ncomp)); 1008 CeedCallBackend(CeedGetKernelHip(ceed, asmb->module, fallback ? "linearAssembleFallback" : "linearAssemble", &asmb->linearAssemble)); 1009 CeedCallBackend(CeedFree(&assembly_kernel_path)); 1010 CeedCallBackend(CeedFree(&assembly_kernel_source)); 1011 1012 // Build 'full' B matrices (not 1D arrays used for tensor-product matrices) 1013 const CeedScalar *interp_in, *grad_in; 1014 CeedCallBackend(CeedBasisGetInterp(basis_in, &interp_in)); 1015 CeedCallBackend(CeedBasisGetGrad(basis_in, &grad_in)); 1016 1017 // Load into B_in, in order that they will be used in eval_mode 1018 const CeedInt inBytes = size_B_in * sizeof(CeedScalar); 1019 CeedInt mat_start = 0; 1020 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_in, inBytes)); 1021 for (int i = 0; i < num_B_in_mats_to_load; i++) { 1022 CeedEvalMode eval_mode = eval_mode_in[i]; 1023 if (eval_mode == CEED_EVAL_INTERP) { 1024 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_in[mat_start], interp_in, esize * nqpts * sizeof(CeedScalar), hipMemcpyHostToDevice)); 1025 mat_start += esize * nqpts; 1026 } else if (eval_mode == CEED_EVAL_GRAD) { 1027 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_in[mat_start], grad_in, dim * esize * nqpts * sizeof(CeedScalar), hipMemcpyHostToDevice)); 1028 mat_start += dim * esize * nqpts; 1029 } 1030 } 1031 1032 const CeedScalar *interp_out, *grad_out; 1033 // Note that this function currently assumes 1 basis, so this should always be true 1034 // for now 1035 if (basis_out == basis_in) { 1036 interp_out = interp_in; 1037 grad_out = grad_in; 1038 } else { 1039 CeedCallBackend(CeedBasisGetInterp(basis_out, &interp_out)); 1040 CeedCallBackend(CeedBasisGetGrad(basis_out, &grad_out)); 1041 } 1042 1043 // Load into B_out, in order that they will be used in eval_mode 1044 const CeedInt outBytes = size_B_out * sizeof(CeedScalar); 1045 mat_start = 0; 1046 CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_out, outBytes)); 1047 for (int i = 0; i < num_B_out_mats_to_load; i++) { 1048 CeedEvalMode eval_mode = eval_mode_out[i]; 1049 if (eval_mode == CEED_EVAL_INTERP) { 1050 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_out[mat_start], interp_out, esize * nqpts * sizeof(CeedScalar), hipMemcpyHostToDevice)); 1051 mat_start += esize * nqpts; 1052 } else if (eval_mode == CEED_EVAL_GRAD) { 1053 CeedCallHip(ceed, hipMemcpy(&asmb->d_B_out[mat_start], grad_out, dim * esize * nqpts * sizeof(CeedScalar), hipMemcpyHostToDevice)); 1054 mat_start += dim * esize * nqpts; 1055 } 1056 } 1057 return CEED_ERROR_SUCCESS; 1058 } 1059 1060 //------------------------------------------------------------------------------ 1061 // Assemble matrix data for COO matrix of assembled operator. 1062 // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic. 1063 // 1064 // Note that this (and other assembly routines) currently assume only one active input restriction/basis per operator (could have multiple basis eval 1065 // modes). 1066 // TODO: allow multiple active input restrictions/basis objects 1067 //------------------------------------------------------------------------------ 1068 static int CeedSingleOperatorAssemble_Hip(CeedOperator op, CeedInt offset, CeedVector values) { 1069 Ceed ceed; 1070 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1071 CeedOperator_Hip *impl; 1072 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1073 1074 // Setup 1075 if (!impl->asmb) { 1076 CeedCallBackend(CeedSingleOperatorAssembleSetup_Hip(op)); 1077 assert(impl->asmb != NULL); 1078 } 1079 1080 // Assemble QFunction 1081 CeedVector assembled_qf; 1082 CeedElemRestriction rstr_q; 1083 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &rstr_q, CEED_REQUEST_IMMEDIATE)); 1084 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_q)); 1085 CeedScalar *values_array; 1086 CeedCallBackend(CeedVectorGetArrayWrite(values, CEED_MEM_DEVICE, &values_array)); 1087 values_array += offset; 1088 const CeedScalar *qf_array; 1089 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &qf_array)); 1090 1091 // Compute B^T D B 1092 const CeedInt nelem = impl->asmb->nelem; // to satisfy clang-tidy 1093 const CeedInt elemsPerBlock = impl->asmb->elemsPerBlock; 1094 const CeedInt grid = nelem / elemsPerBlock + ((nelem / elemsPerBlock * elemsPerBlock < nelem) ? 1 : 0); 1095 void *args[] = {&impl->asmb->d_B_in, &impl->asmb->d_B_out, &qf_array, &values_array}; 1096 CeedCallBackend( 1097 CeedRunKernelDimHip(ceed, impl->asmb->linearAssemble, grid, impl->asmb->block_size_x, impl->asmb->block_size_y, elemsPerBlock, args)); 1098 1099 // Restore arrays 1100 CeedCallBackend(CeedVectorRestoreArray(values, &values_array)); 1101 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &qf_array)); 1102 1103 // Cleanup 1104 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1105 1106 return CEED_ERROR_SUCCESS; 1107 } 1108 1109 //------------------------------------------------------------------------------ 1110 // Create operator 1111 //------------------------------------------------------------------------------ 1112 int CeedOperatorCreate_Hip(CeedOperator op) { 1113 Ceed ceed; 1114 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1115 CeedOperator_Hip *impl; 1116 1117 CeedCallBackend(CeedCalloc(1, &impl)); 1118 CeedCallBackend(CeedOperatorSetData(op, impl)); 1119 1120 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Hip)); 1121 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Hip)); 1122 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonal_Hip)); 1123 CeedCallBackend( 1124 CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal", CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip)); 1125 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssemble_Hip)); 1126 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip)); 1127 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip)); 1128 return CEED_ERROR_SUCCESS; 1129 } 1130 1131 //------------------------------------------------------------------------------ 1132