1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <assert.h> 9 #include <ceed/backend.h> 10 #include <ceed/ceed.h> 11 #include <ceed/jit-tools.h> 12 #include <cuda.h> 13 #include <cuda_runtime.h> 14 #include <stdbool.h> 15 #include <string.h> 16 17 #include "../cuda/ceed-cuda-compile.h" 18 #include "ceed-cuda-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Destroy operator 22 //------------------------------------------------------------------------------ 23 static int CeedOperatorDestroy_Cuda(CeedOperator op) { 24 CeedOperator_Cuda *impl; 25 CeedCallBackend(CeedOperatorGetData(op, &impl)); 26 27 // Apply data 28 for (CeedInt i = 0; i < impl->numein + impl->numeout; i++) { 29 CeedCallBackend(CeedVectorDestroy(&impl->evecs[i])); 30 } 31 CeedCallBackend(CeedFree(&impl->evecs)); 32 33 for (CeedInt i = 0; i < impl->numein; i++) { 34 CeedCallBackend(CeedVectorDestroy(&impl->qvecsin[i])); 35 } 36 CeedCallBackend(CeedFree(&impl->qvecsin)); 37 38 for (CeedInt i = 0; i < impl->numeout; i++) { 39 CeedCallBackend(CeedVectorDestroy(&impl->qvecsout[i])); 40 } 41 CeedCallBackend(CeedFree(&impl->qvecsout)); 42 43 // QFunction assembly data 44 for (CeedInt i = 0; i < impl->qfnumactivein; i++) { 45 CeedCallBackend(CeedVectorDestroy(&impl->qfactivein[i])); 46 } 47 CeedCallBackend(CeedFree(&impl->qfactivein)); 48 49 // Diag data 50 if (impl->diag) { 51 Ceed ceed; 52 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 53 CeedCallCuda(ceed, cuModuleUnload(impl->diag->module)); 54 CeedCallBackend(CeedFree(&impl->diag->h_emodein)); 55 CeedCallBackend(CeedFree(&impl->diag->h_emodeout)); 56 CeedCallCuda(ceed, cudaFree(impl->diag->d_emodein)); 57 CeedCallCuda(ceed, cudaFree(impl->diag->d_emodeout)); 58 CeedCallCuda(ceed, cudaFree(impl->diag->d_identity)); 59 CeedCallCuda(ceed, cudaFree(impl->diag->d_interpin)); 60 CeedCallCuda(ceed, cudaFree(impl->diag->d_interpout)); 61 CeedCallCuda(ceed, cudaFree(impl->diag->d_gradin)); 62 CeedCallCuda(ceed, cudaFree(impl->diag->d_gradout)); 63 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->pbdiagrstr)); 64 CeedCallBackend(CeedVectorDestroy(&impl->diag->elemdiag)); 65 CeedCallBackend(CeedVectorDestroy(&impl->diag->pbelemdiag)); 66 } 67 CeedCallBackend(CeedFree(&impl->diag)); 68 69 if (impl->asmb) { 70 Ceed ceed; 71 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 72 CeedCallCuda(ceed, cuModuleUnload(impl->asmb->module)); 73 CeedCallCuda(ceed, cudaFree(impl->asmb->d_B_in)); 74 CeedCallCuda(ceed, cudaFree(impl->asmb->d_B_out)); 75 } 76 CeedCallBackend(CeedFree(&impl->asmb)); 77 78 CeedCallBackend(CeedFree(&impl)); 79 return CEED_ERROR_SUCCESS; 80 } 81 82 //------------------------------------------------------------------------------ 83 // Setup infields or outfields 84 //------------------------------------------------------------------------------ 85 static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool isinput, CeedVector *evecs, CeedVector *qvecs, CeedInt starte, 86 CeedInt numfields, CeedInt Q, CeedInt numelements) { 87 CeedInt dim, size; 88 CeedSize q_size; 89 Ceed ceed; 90 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 91 CeedBasis basis; 92 CeedElemRestriction Erestrict; 93 CeedOperatorField *opfields; 94 CeedQFunctionField *qffields; 95 CeedVector fieldvec; 96 bool strided; 97 bool skiprestrict; 98 99 if (isinput) { 100 CeedCallBackend(CeedOperatorGetFields(op, NULL, &opfields, NULL, NULL)); 101 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qffields, NULL, NULL)); 102 } else { 103 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &opfields)); 104 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qffields)); 105 } 106 107 // Loop over fields 108 for (CeedInt i = 0; i < numfields; i++) { 109 CeedEvalMode emode; 110 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qffields[i], &emode)); 111 112 strided = false; 113 skiprestrict = false; 114 if (emode != CEED_EVAL_WEIGHT) { 115 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opfields[i], &Erestrict)); 116 117 // Check whether this field can skip the element restriction: 118 // must be passive input, with emode NONE, and have a strided restriction with 119 // CEED_STRIDES_BACKEND. 120 121 // First, check whether the field is input or output: 122 if (isinput) { 123 // Check for passive input: 124 CeedCallBackend(CeedOperatorFieldGetVector(opfields[i], &fieldvec)); 125 if (fieldvec != CEED_VECTOR_ACTIVE) { 126 // Check emode 127 if (emode == CEED_EVAL_NONE) { 128 // Check for strided restriction 129 CeedCallBackend(CeedElemRestrictionIsStrided(Erestrict, &strided)); 130 if (strided) { 131 // Check if vector is already in preferred backend ordering 132 CeedCallBackend(CeedElemRestrictionHasBackendStrides(Erestrict, &skiprestrict)); 133 } 134 } 135 } 136 } 137 if (skiprestrict) { 138 // We do not need an E-Vector, but will use the input field vector's data 139 // directly in the operator application. 140 evecs[i + starte] = NULL; 141 } else { 142 CeedCallBackend(CeedElemRestrictionCreateVector(Erestrict, NULL, &evecs[i + starte])); 143 } 144 } 145 146 switch (emode) { 147 case CEED_EVAL_NONE: 148 CeedCallBackend(CeedQFunctionFieldGetSize(qffields[i], &size)); 149 q_size = (CeedSize)numelements * Q * size; 150 CeedCallBackend(CeedVectorCreate(ceed, q_size, &qvecs[i])); 151 break; 152 case CEED_EVAL_INTERP: 153 CeedCallBackend(CeedQFunctionFieldGetSize(qffields[i], &size)); 154 q_size = (CeedSize)numelements * Q * size; 155 CeedCallBackend(CeedVectorCreate(ceed, q_size, &qvecs[i])); 156 break; 157 case CEED_EVAL_GRAD: 158 CeedCallBackend(CeedOperatorFieldGetBasis(opfields[i], &basis)); 159 CeedCallBackend(CeedQFunctionFieldGetSize(qffields[i], &size)); 160 CeedCallBackend(CeedBasisGetDimension(basis, &dim)); 161 q_size = (CeedSize)numelements * Q * size; 162 CeedCallBackend(CeedVectorCreate(ceed, q_size, &qvecs[i])); 163 break; 164 case CEED_EVAL_WEIGHT: // Only on input fields 165 CeedCallBackend(CeedOperatorFieldGetBasis(opfields[i], &basis)); 166 q_size = (CeedSize)numelements * Q; 167 CeedCallBackend(CeedVectorCreate(ceed, q_size, &qvecs[i])); 168 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, NULL, qvecs[i])); 169 break; 170 case CEED_EVAL_DIV: 171 break; // TODO: Not implemented 172 case CEED_EVAL_CURL: 173 break; // TODO: Not implemented 174 } 175 } 176 return CEED_ERROR_SUCCESS; 177 } 178 179 //------------------------------------------------------------------------------ 180 // CeedOperator needs to connect all the named fields (be they active or passive) 181 // to the named inputs and outputs of its CeedQFunction. 182 //------------------------------------------------------------------------------ 183 static int CeedOperatorSetup_Cuda(CeedOperator op) { 184 bool setupdone; 185 CeedCallBackend(CeedOperatorIsSetupDone(op, &setupdone)); 186 if (setupdone) return CEED_ERROR_SUCCESS; 187 Ceed ceed; 188 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 189 CeedOperator_Cuda *impl; 190 CeedCallBackend(CeedOperatorGetData(op, &impl)); 191 CeedQFunction qf; 192 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 193 CeedInt Q, numelements, numinputfields, numoutputfields; 194 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 195 CeedCallBackend(CeedOperatorGetNumElements(op, &numelements)); 196 CeedOperatorField *opinputfields, *opoutputfields; 197 CeedCallBackend(CeedOperatorGetFields(op, &numinputfields, &opinputfields, &numoutputfields, &opoutputfields)); 198 CeedQFunctionField *qfinputfields, *qfoutputfields; 199 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields)); 200 201 // Allocate 202 CeedCallBackend(CeedCalloc(numinputfields + numoutputfields, &impl->evecs)); 203 204 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->qvecsin)); 205 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->qvecsout)); 206 207 impl->numein = numinputfields; 208 impl->numeout = numoutputfields; 209 210 // Set up infield and outfield evecs and qvecs 211 // Infields 212 CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, true, impl->evecs, impl->qvecsin, 0, numinputfields, Q, numelements)); 213 214 // Outfields 215 CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, impl->evecs, impl->qvecsout, numinputfields, numoutputfields, Q, numelements)); 216 217 CeedCallBackend(CeedOperatorSetSetupDone(op)); 218 return CEED_ERROR_SUCCESS; 219 } 220 221 //------------------------------------------------------------------------------ 222 // Setup Operator Inputs 223 //------------------------------------------------------------------------------ 224 static inline int CeedOperatorSetupInputs_Cuda(CeedInt numinputfields, CeedQFunctionField *qfinputfields, CeedOperatorField *opinputfields, 225 CeedVector invec, const bool skipactive, CeedScalar *edata[2 * CEED_FIELD_MAX], 226 CeedOperator_Cuda *impl, CeedRequest *request) { 227 CeedEvalMode emode; 228 CeedVector vec; 229 CeedElemRestriction Erestrict; 230 231 for (CeedInt i = 0; i < numinputfields; i++) { 232 // Get input vector 233 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 234 if (vec == CEED_VECTOR_ACTIVE) { 235 if (skipactive) continue; 236 else vec = invec; 237 } 238 239 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode)); 240 if (emode == CEED_EVAL_WEIGHT) { // Skip 241 } else { 242 // Get input vector 243 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 244 // Get input element restriction 245 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict)); 246 if (vec == CEED_VECTOR_ACTIVE) vec = invec; 247 // Restrict, if necessary 248 if (!impl->evecs[i]) { 249 // No restriction for this field; read data directly from vec. 250 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, (const CeedScalar **)&edata[i])); 251 } else { 252 CeedCallBackend(CeedElemRestrictionApply(Erestrict, CEED_NOTRANSPOSE, vec, impl->evecs[i], request)); 253 // Get evec 254 CeedCallBackend(CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_DEVICE, (const CeedScalar **)&edata[i])); 255 } 256 } 257 } 258 return CEED_ERROR_SUCCESS; 259 } 260 261 //------------------------------------------------------------------------------ 262 // Input Basis Action 263 //------------------------------------------------------------------------------ 264 static inline int CeedOperatorInputBasis_Cuda(CeedInt numelements, CeedQFunctionField *qfinputfields, CeedOperatorField *opinputfields, 265 CeedInt numinputfields, const bool skipactive, CeedScalar *edata[2 * CEED_FIELD_MAX], 266 CeedOperator_Cuda *impl) { 267 CeedInt elemsize, size; 268 CeedElemRestriction Erestrict; 269 CeedEvalMode emode; 270 CeedBasis basis; 271 272 for (CeedInt i = 0; i < numinputfields; i++) { 273 // Skip active input 274 if (skipactive) { 275 CeedVector vec; 276 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 277 if (vec == CEED_VECTOR_ACTIVE) continue; 278 } 279 // Get elemsize, emode, size 280 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict)); 281 CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elemsize)); 282 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode)); 283 CeedCallBackend(CeedQFunctionFieldGetSize(qfinputfields[i], &size)); 284 // Basis action 285 switch (emode) { 286 case CEED_EVAL_NONE: 287 CeedCallBackend(CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_DEVICE, CEED_USE_POINTER, edata[i])); 288 break; 289 case CEED_EVAL_INTERP: 290 CeedCallBackend(CeedOperatorFieldGetBasis(opinputfields[i], &basis)); 291 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_NOTRANSPOSE, CEED_EVAL_INTERP, impl->evecs[i], impl->qvecsin[i])); 292 break; 293 case CEED_EVAL_GRAD: 294 CeedCallBackend(CeedOperatorFieldGetBasis(opinputfields[i], &basis)); 295 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_NOTRANSPOSE, CEED_EVAL_GRAD, impl->evecs[i], impl->qvecsin[i])); 296 break; 297 case CEED_EVAL_WEIGHT: 298 break; // No action 299 case CEED_EVAL_DIV: 300 break; // TODO: Not implemented 301 case CEED_EVAL_CURL: 302 break; // TODO: Not implemented 303 } 304 } 305 return CEED_ERROR_SUCCESS; 306 } 307 308 //------------------------------------------------------------------------------ 309 // Restore Input Vectors 310 //------------------------------------------------------------------------------ 311 static inline int CeedOperatorRestoreInputs_Cuda(CeedInt numinputfields, CeedQFunctionField *qfinputfields, CeedOperatorField *opinputfields, 312 const bool skipactive, CeedScalar *edata[2 * CEED_FIELD_MAX], CeedOperator_Cuda *impl) { 313 CeedEvalMode emode; 314 CeedVector vec; 315 316 for (CeedInt i = 0; i < numinputfields; i++) { 317 // Skip active input 318 if (skipactive) { 319 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 320 if (vec == CEED_VECTOR_ACTIVE) continue; 321 } 322 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode)); 323 if (emode == CEED_EVAL_WEIGHT) { // Skip 324 } else { 325 if (!impl->evecs[i]) { // This was a skiprestrict case 326 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 327 CeedCallBackend(CeedVectorRestoreArrayRead(vec, (const CeedScalar **)&edata[i])); 328 } else { 329 CeedCallBackend(CeedVectorRestoreArrayRead(impl->evecs[i], (const CeedScalar **)&edata[i])); 330 } 331 } 332 } 333 return CEED_ERROR_SUCCESS; 334 } 335 336 //------------------------------------------------------------------------------ 337 // Apply and add to output 338 //------------------------------------------------------------------------------ 339 static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector invec, CeedVector outvec, CeedRequest *request) { 340 CeedOperator_Cuda *impl; 341 CeedCallBackend(CeedOperatorGetData(op, &impl)); 342 CeedQFunction qf; 343 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 344 CeedInt Q, numelements, elemsize, numinputfields, numoutputfields, size; 345 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 346 CeedCallBackend(CeedOperatorGetNumElements(op, &numelements)); 347 CeedOperatorField *opinputfields, *opoutputfields; 348 CeedCallBackend(CeedOperatorGetFields(op, &numinputfields, &opinputfields, &numoutputfields, &opoutputfields)); 349 CeedQFunctionField *qfinputfields, *qfoutputfields; 350 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields)); 351 CeedEvalMode emode; 352 CeedVector vec; 353 CeedBasis basis; 354 CeedElemRestriction Erestrict; 355 CeedScalar *edata[2 * CEED_FIELD_MAX] = {0}; 356 357 // Setup 358 CeedCallBackend(CeedOperatorSetup_Cuda(op)); 359 360 // Input Evecs and Restriction 361 CeedCallBackend(CeedOperatorSetupInputs_Cuda(numinputfields, qfinputfields, opinputfields, invec, false, edata, impl, request)); 362 363 // Input basis apply if needed 364 CeedCallBackend(CeedOperatorInputBasis_Cuda(numelements, qfinputfields, opinputfields, numinputfields, false, edata, impl)); 365 366 // Output pointers, as necessary 367 for (CeedInt i = 0; i < numoutputfields; i++) { 368 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode)); 369 if (emode == CEED_EVAL_NONE) { 370 // Set the output Q-Vector to use the E-Vector data directly. 371 CeedCallBackend(CeedVectorGetArrayWrite(impl->evecs[i + impl->numein], CEED_MEM_DEVICE, &edata[i + numinputfields])); 372 CeedCallBackend(CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_DEVICE, CEED_USE_POINTER, edata[i + numinputfields])); 373 } 374 } 375 376 // Q function 377 CeedCallBackend(CeedQFunctionApply(qf, numelements * Q, impl->qvecsin, impl->qvecsout)); 378 379 // Output basis apply if needed 380 for (CeedInt i = 0; i < numoutputfields; i++) { 381 // Get elemsize, emode, size 382 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict)); 383 CeedCallBackend(CeedElemRestrictionGetElementSize(Erestrict, &elemsize)); 384 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode)); 385 CeedCallBackend(CeedQFunctionFieldGetSize(qfoutputfields[i], &size)); 386 // Basis action 387 switch (emode) { 388 case CEED_EVAL_NONE: 389 break; 390 case CEED_EVAL_INTERP: 391 CeedCallBackend(CeedOperatorFieldGetBasis(opoutputfields[i], &basis)); 392 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_TRANSPOSE, CEED_EVAL_INTERP, impl->qvecsout[i], impl->evecs[i + impl->numein])); 393 break; 394 case CEED_EVAL_GRAD: 395 CeedCallBackend(CeedOperatorFieldGetBasis(opoutputfields[i], &basis)); 396 CeedCallBackend(CeedBasisApply(basis, numelements, CEED_TRANSPOSE, CEED_EVAL_GRAD, impl->qvecsout[i], impl->evecs[i + impl->numein])); 397 break; 398 // LCOV_EXCL_START 399 case CEED_EVAL_WEIGHT: { 400 Ceed ceed; 401 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 402 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 403 break; // Should not occur 404 } 405 case CEED_EVAL_DIV: 406 break; // TODO: Not implemented 407 case CEED_EVAL_CURL: 408 break; // TODO: Not implemented 409 // LCOV_EXCL_STOP 410 } 411 } 412 413 // Output restriction 414 for (CeedInt i = 0; i < numoutputfields; i++) { 415 // Restore evec 416 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode)); 417 if (emode == CEED_EVAL_NONE) { 418 CeedCallBackend(CeedVectorRestoreArray(impl->evecs[i + impl->numein], &edata[i + numinputfields])); 419 } 420 // Get output vector 421 CeedCallBackend(CeedOperatorFieldGetVector(opoutputfields[i], &vec)); 422 // Restrict 423 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict)); 424 // Active 425 if (vec == CEED_VECTOR_ACTIVE) vec = outvec; 426 427 CeedCallBackend(CeedElemRestrictionApply(Erestrict, CEED_TRANSPOSE, impl->evecs[i + impl->numein], vec, request)); 428 } 429 430 // Restore input arrays 431 CeedCallBackend(CeedOperatorRestoreInputs_Cuda(numinputfields, qfinputfields, opinputfields, false, edata, impl)); 432 return CEED_ERROR_SUCCESS; 433 } 434 435 //------------------------------------------------------------------------------ 436 // Core code for assembling linear QFunction 437 //------------------------------------------------------------------------------ 438 static inline int CeedOperatorLinearAssembleQFunctionCore_Cuda(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 439 CeedRequest *request) { 440 CeedOperator_Cuda *impl; 441 CeedCallBackend(CeedOperatorGetData(op, &impl)); 442 CeedQFunction qf; 443 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 444 CeedInt Q, numelements, numinputfields, numoutputfields, size; 445 CeedSize q_size; 446 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 447 CeedCallBackend(CeedOperatorGetNumElements(op, &numelements)); 448 CeedOperatorField *opinputfields, *opoutputfields; 449 CeedCallBackend(CeedOperatorGetFields(op, &numinputfields, &opinputfields, &numoutputfields, &opoutputfields)); 450 CeedQFunctionField *qfinputfields, *qfoutputfields; 451 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields)); 452 CeedVector vec; 453 CeedInt numactivein = impl->qfnumactivein, numactiveout = impl->qfnumactiveout; 454 CeedVector *activein = impl->qfactivein; 455 CeedScalar *a, *tmp; 456 Ceed ceed, ceedparent; 457 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 458 CeedCallBackend(CeedGetOperatorFallbackParentCeed(ceed, &ceedparent)); 459 ceedparent = ceedparent ? ceedparent : ceed; 460 CeedScalar *edata[2 * CEED_FIELD_MAX]; 461 462 // Setup 463 CeedCallBackend(CeedOperatorSetup_Cuda(op)); 464 465 // Check for identity 466 bool identityqf; 467 CeedCallBackend(CeedQFunctionIsIdentity(qf, &identityqf)); 468 if (identityqf) { 469 // LCOV_EXCL_START 470 return CeedError(ceed, CEED_ERROR_BACKEND, "Assembling identity QFunctions not supported"); 471 // LCOV_EXCL_STOP 472 } 473 474 // Input Evecs and Restriction 475 CeedCallBackend(CeedOperatorSetupInputs_Cuda(numinputfields, qfinputfields, opinputfields, NULL, true, edata, impl, request)); 476 477 // Count number of active input fields 478 if (!numactivein) { 479 for (CeedInt i = 0; i < numinputfields; i++) { 480 // Get input vector 481 CeedCallBackend(CeedOperatorFieldGetVector(opinputfields[i], &vec)); 482 // Check if active input 483 if (vec == CEED_VECTOR_ACTIVE) { 484 CeedCallBackend(CeedQFunctionFieldGetSize(qfinputfields[i], &size)); 485 CeedCallBackend(CeedVectorSetValue(impl->qvecsin[i], 0.0)); 486 CeedCallBackend(CeedVectorGetArray(impl->qvecsin[i], CEED_MEM_DEVICE, &tmp)); 487 CeedCallBackend(CeedRealloc(numactivein + size, &activein)); 488 for (CeedInt field = 0; field < size; field++) { 489 q_size = (CeedSize)Q * numelements; 490 CeedCallBackend(CeedVectorCreate(ceed, q_size, &activein[numactivein + field])); 491 CeedCallBackend(CeedVectorSetArray(activein[numactivein + field], CEED_MEM_DEVICE, CEED_USE_POINTER, &tmp[field * Q * numelements])); 492 } 493 numactivein += size; 494 CeedCallBackend(CeedVectorRestoreArray(impl->qvecsin[i], &tmp)); 495 } 496 } 497 impl->qfnumactivein = numactivein; 498 impl->qfactivein = activein; 499 } 500 501 // Count number of active output fields 502 if (!numactiveout) { 503 for (CeedInt i = 0; i < numoutputfields; i++) { 504 // Get output vector 505 CeedCallBackend(CeedOperatorFieldGetVector(opoutputfields[i], &vec)); 506 // Check if active output 507 if (vec == CEED_VECTOR_ACTIVE) { 508 CeedCallBackend(CeedQFunctionFieldGetSize(qfoutputfields[i], &size)); 509 numactiveout += size; 510 } 511 } 512 impl->qfnumactiveout = numactiveout; 513 } 514 515 // Check sizes 516 if (!numactivein || !numactiveout) { 517 // LCOV_EXCL_START 518 return CeedError(ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 519 // LCOV_EXCL_STOP 520 } 521 522 // Build objects if needed 523 if (build_objects) { 524 // Create output restriction 525 CeedInt strides[3] = {1, numelements * Q, Q}; /* *NOPAD* */ 526 CeedCallBackend(CeedElemRestrictionCreateStrided(ceedparent, numelements, Q, numactivein * numactiveout, 527 numactivein * numactiveout * numelements * Q, strides, rstr)); 528 // Create assembled vector 529 CeedSize l_size = (CeedSize)numelements * Q * numactivein * numactiveout; 530 CeedCallBackend(CeedVectorCreate(ceedparent, l_size, assembled)); 531 } 532 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 533 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_DEVICE, &a)); 534 535 // Input basis apply 536 CeedCallBackend(CeedOperatorInputBasis_Cuda(numelements, qfinputfields, opinputfields, numinputfields, true, edata, impl)); 537 538 // Assemble QFunction 539 for (CeedInt in = 0; in < numactivein; in++) { 540 // Set Inputs 541 CeedCallBackend(CeedVectorSetValue(activein[in], 1.0)); 542 if (numactivein > 1) { 543 CeedCallBackend(CeedVectorSetValue(activein[(in + numactivein - 1) % numactivein], 0.0)); 544 } 545 // Set Outputs 546 for (CeedInt out = 0; out < numoutputfields; out++) { 547 // Get output vector 548 CeedCallBackend(CeedOperatorFieldGetVector(opoutputfields[out], &vec)); 549 // Check if active output 550 if (vec == CEED_VECTOR_ACTIVE) { 551 CeedCallBackend(CeedVectorSetArray(impl->qvecsout[out], CEED_MEM_DEVICE, CEED_USE_POINTER, a)); 552 CeedCallBackend(CeedQFunctionFieldGetSize(qfoutputfields[out], &size)); 553 a += size * Q * numelements; // Advance the pointer by the size of the output 554 } 555 } 556 // Apply QFunction 557 CeedCallBackend(CeedQFunctionApply(qf, Q * numelements, impl->qvecsin, impl->qvecsout)); 558 } 559 560 // Un-set output Qvecs to prevent accidental overwrite of Assembled 561 for (CeedInt out = 0; out < numoutputfields; out++) { 562 // Get output vector 563 CeedCallBackend(CeedOperatorFieldGetVector(opoutputfields[out], &vec)); 564 // Check if active output 565 if (vec == CEED_VECTOR_ACTIVE) { 566 CeedCallBackend(CeedVectorTakeArray(impl->qvecsout[out], CEED_MEM_DEVICE, NULL)); 567 } 568 } 569 570 // Restore input arrays 571 CeedCallBackend(CeedOperatorRestoreInputs_Cuda(numinputfields, qfinputfields, opinputfields, true, edata, impl)); 572 573 // Restore output 574 CeedCallBackend(CeedVectorRestoreArray(*assembled, &a)); 575 576 return CEED_ERROR_SUCCESS; 577 } 578 579 //------------------------------------------------------------------------------ 580 // Assemble Linear QFunction 581 //------------------------------------------------------------------------------ 582 static int CeedOperatorLinearAssembleQFunction_Cuda(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 583 return CeedOperatorLinearAssembleQFunctionCore_Cuda(op, true, assembled, rstr, request); 584 } 585 586 //------------------------------------------------------------------------------ 587 // Update Assembled Linear QFunction 588 //------------------------------------------------------------------------------ 589 static int CeedOperatorLinearAssembleQFunctionUpdate_Cuda(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 590 return CeedOperatorLinearAssembleQFunctionCore_Cuda(op, false, &assembled, &rstr, request); 591 } 592 593 //------------------------------------------------------------------------------ 594 // Create point block restriction 595 //------------------------------------------------------------------------------ 596 static int CreatePBRestriction(CeedElemRestriction rstr, CeedElemRestriction *pbRstr) { 597 Ceed ceed; 598 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 599 const CeedInt *offsets; 600 CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets)); 601 602 // Expand offsets 603 CeedInt nelem, ncomp, elemsize, compstride, *pbOffsets; 604 CeedSize l_size; 605 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &nelem)); 606 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &ncomp)); 607 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elemsize)); 608 CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &compstride)); 609 CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size)); 610 CeedInt shift = ncomp; 611 if (compstride != 1) shift *= ncomp; 612 CeedCallBackend(CeedCalloc(nelem * elemsize, &pbOffsets)); 613 for (CeedInt i = 0; i < nelem * elemsize; i++) { 614 pbOffsets[i] = offsets[i] * shift; 615 } 616 617 // Create new restriction 618 CeedCallBackend( 619 CeedElemRestrictionCreate(ceed, nelem, elemsize, ncomp * ncomp, 1, l_size * ncomp, CEED_MEM_HOST, CEED_OWN_POINTER, pbOffsets, pbRstr)); 620 621 // Cleanup 622 CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets)); 623 624 return CEED_ERROR_SUCCESS; 625 } 626 627 //------------------------------------------------------------------------------ 628 // Assemble diagonal setup 629 //------------------------------------------------------------------------------ 630 static inline int CeedOperatorAssembleDiagonalSetup_Cuda(CeedOperator op, const bool pointBlock) { 631 Ceed ceed; 632 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 633 CeedQFunction qf; 634 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 635 CeedInt numinputfields, numoutputfields; 636 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields)); 637 638 // Determine active input basis 639 CeedOperatorField *opfields; 640 CeedQFunctionField *qffields; 641 CeedCallBackend(CeedOperatorGetFields(op, NULL, &opfields, NULL, NULL)); 642 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qffields, NULL, NULL)); 643 CeedInt numemodein = 0, ncomp = 0, dim = 1; 644 CeedEvalMode *emodein = NULL; 645 CeedBasis basisin = NULL; 646 CeedElemRestriction rstrin = NULL; 647 for (CeedInt i = 0; i < numinputfields; i++) { 648 CeedVector vec; 649 CeedCallBackend(CeedOperatorFieldGetVector(opfields[i], &vec)); 650 if (vec == CEED_VECTOR_ACTIVE) { 651 CeedElemRestriction rstr; 652 CeedCallBackend(CeedOperatorFieldGetBasis(opfields[i], &basisin)); 653 CeedCallBackend(CeedBasisGetNumComponents(basisin, &ncomp)); 654 CeedCallBackend(CeedBasisGetDimension(basisin, &dim)); 655 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opfields[i], &rstr)); 656 if (rstrin && rstrin != rstr) { 657 // LCOV_EXCL_START 658 return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator diagonal assembly"); 659 // LCOV_EXCL_STOP 660 } 661 rstrin = rstr; 662 CeedEvalMode emode; 663 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qffields[i], &emode)); 664 switch (emode) { 665 case CEED_EVAL_NONE: 666 case CEED_EVAL_INTERP: 667 CeedCallBackend(CeedRealloc(numemodein + 1, &emodein)); 668 emodein[numemodein] = emode; 669 numemodein += 1; 670 break; 671 case CEED_EVAL_GRAD: 672 CeedCallBackend(CeedRealloc(numemodein + dim, &emodein)); 673 for (CeedInt d = 0; d < dim; d++) emodein[numemodein + d] = emode; 674 numemodein += dim; 675 break; 676 case CEED_EVAL_WEIGHT: 677 case CEED_EVAL_DIV: 678 case CEED_EVAL_CURL: 679 break; // Caught by QF Assembly 680 } 681 } 682 } 683 684 // Determine active output basis 685 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &opfields)); 686 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qffields)); 687 CeedInt numemodeout = 0; 688 CeedEvalMode *emodeout = NULL; 689 CeedBasis basisout = NULL; 690 CeedElemRestriction rstrout = NULL; 691 for (CeedInt i = 0; i < numoutputfields; i++) { 692 CeedVector vec; 693 CeedCallBackend(CeedOperatorFieldGetVector(opfields[i], &vec)); 694 if (vec == CEED_VECTOR_ACTIVE) { 695 CeedElemRestriction rstr; 696 CeedCallBackend(CeedOperatorFieldGetBasis(opfields[i], &basisout)); 697 CeedCallBackend(CeedOperatorFieldGetElemRestriction(opfields[i], &rstr)); 698 if (rstrout && rstrout != rstr) { 699 // LCOV_EXCL_START 700 return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator diagonal assembly"); 701 // LCOV_EXCL_STOP 702 } 703 rstrout = rstr; 704 CeedEvalMode emode; 705 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qffields[i], &emode)); 706 switch (emode) { 707 case CEED_EVAL_NONE: 708 case CEED_EVAL_INTERP: 709 CeedCallBackend(CeedRealloc(numemodeout + 1, &emodeout)); 710 emodeout[numemodeout] = emode; 711 numemodeout += 1; 712 break; 713 case CEED_EVAL_GRAD: 714 CeedCallBackend(CeedRealloc(numemodeout + dim, &emodeout)); 715 for (CeedInt d = 0; d < dim; d++) emodeout[numemodeout + d] = emode; 716 numemodeout += dim; 717 break; 718 case CEED_EVAL_WEIGHT: 719 case CEED_EVAL_DIV: 720 case CEED_EVAL_CURL: 721 break; // Caught by QF Assembly 722 } 723 } 724 } 725 726 // Operator data struct 727 CeedOperator_Cuda *impl; 728 CeedCallBackend(CeedOperatorGetData(op, &impl)); 729 CeedCallBackend(CeedCalloc(1, &impl->diag)); 730 CeedOperatorDiag_Cuda *diag = impl->diag; 731 diag->basisin = basisin; 732 diag->basisout = basisout; 733 diag->h_emodein = emodein; 734 diag->h_emodeout = emodeout; 735 diag->numemodein = numemodein; 736 diag->numemodeout = numemodeout; 737 738 // Assemble kernel 739 char *diagonal_kernel_path, *diagonal_kernel_source; 740 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-operator-assemble-diagonal.h", &diagonal_kernel_path)); 741 CeedDebug256(ceed, 2, "----- Loading Diagonal Assembly Kernel Source -----\n"); 742 CeedCallBackend(CeedLoadSourceToBuffer(ceed, diagonal_kernel_path, &diagonal_kernel_source)); 743 CeedDebug256(ceed, 2, "----- Loading Diagonal Assembly Source Complete! -----\n"); 744 CeedInt nnodes, nqpts; 745 CeedCallBackend(CeedBasisGetNumNodes(basisin, &nnodes)); 746 CeedCallBackend(CeedBasisGetNumQuadraturePoints(basisin, &nqpts)); 747 diag->nnodes = nnodes; 748 CeedCallCuda(ceed, CeedCompileCuda(ceed, diagonal_kernel_source, &diag->module, 5, "NUMEMODEIN", numemodein, "NUMEMODEOUT", numemodeout, "NNODES", 749 nnodes, "NQPTS", nqpts, "NCOMP", ncomp)); 750 CeedCallCuda(ceed, CeedGetKernelCuda(ceed, diag->module, "linearDiagonal", &diag->linearDiagonal)); 751 CeedCallCuda(ceed, CeedGetKernelCuda(ceed, diag->module, "linearPointBlockDiagonal", &diag->linearPointBlock)); 752 CeedCallBackend(CeedFree(&diagonal_kernel_path)); 753 CeedCallBackend(CeedFree(&diagonal_kernel_source)); 754 755 // Basis matrices 756 const CeedInt qBytes = nqpts * sizeof(CeedScalar); 757 const CeedInt iBytes = qBytes * nnodes; 758 const CeedInt gBytes = qBytes * nnodes * dim; 759 const CeedInt eBytes = sizeof(CeedEvalMode); 760 const CeedScalar *interpin, *interpout, *gradin, *gradout; 761 762 // CEED_EVAL_NONE 763 CeedScalar *identity = NULL; 764 bool evalNone = false; 765 for (CeedInt i = 0; i < numemodein; i++) evalNone = evalNone || (emodein[i] == CEED_EVAL_NONE); 766 for (CeedInt i = 0; i < numemodeout; i++) evalNone = evalNone || (emodeout[i] == CEED_EVAL_NONE); 767 if (evalNone) { 768 CeedCallBackend(CeedCalloc(nqpts * nnodes, &identity)); 769 for (CeedInt i = 0; i < (nnodes < nqpts ? nnodes : nqpts); i++) identity[i * nnodes + i] = 1.0; 770 CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_identity, iBytes)); 771 CeedCallCuda(ceed, cudaMemcpy(diag->d_identity, identity, iBytes, cudaMemcpyHostToDevice)); 772 } 773 774 // CEED_EVAL_INTERP 775 CeedCallBackend(CeedBasisGetInterp(basisin, &interpin)); 776 CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_interpin, iBytes)); 777 CeedCallCuda(ceed, cudaMemcpy(diag->d_interpin, interpin, iBytes, cudaMemcpyHostToDevice)); 778 CeedCallBackend(CeedBasisGetInterp(basisout, &interpout)); 779 CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_interpout, iBytes)); 780 CeedCallCuda(ceed, cudaMemcpy(diag->d_interpout, interpout, iBytes, cudaMemcpyHostToDevice)); 781 782 // CEED_EVAL_GRAD 783 CeedCallBackend(CeedBasisGetGrad(basisin, &gradin)); 784 CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_gradin, gBytes)); 785 CeedCallCuda(ceed, cudaMemcpy(diag->d_gradin, gradin, gBytes, cudaMemcpyHostToDevice)); 786 CeedCallBackend(CeedBasisGetGrad(basisout, &gradout)); 787 CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_gradout, gBytes)); 788 CeedCallCuda(ceed, cudaMemcpy(diag->d_gradout, gradout, gBytes, cudaMemcpyHostToDevice)); 789 790 // Arrays of emodes 791 CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_emodein, numemodein * eBytes)); 792 CeedCallCuda(ceed, cudaMemcpy(diag->d_emodein, emodein, numemodein * eBytes, cudaMemcpyHostToDevice)); 793 CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_emodeout, numemodeout * eBytes)); 794 CeedCallCuda(ceed, cudaMemcpy(diag->d_emodeout, emodeout, numemodeout * eBytes, cudaMemcpyHostToDevice)); 795 796 // Restriction 797 diag->diagrstr = rstrout; 798 799 return CEED_ERROR_SUCCESS; 800 } 801 802 //------------------------------------------------------------------------------ 803 // Assemble diagonal common code 804 //------------------------------------------------------------------------------ 805 static inline int CeedOperatorAssembleDiagonalCore_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request, const bool pointBlock) { 806 Ceed ceed; 807 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 808 CeedOperator_Cuda *impl; 809 CeedCallBackend(CeedOperatorGetData(op, &impl)); 810 811 // Assemble QFunction 812 CeedVector assembledqf; 813 CeedElemRestriction rstr; 814 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembledqf, &rstr, request)); 815 CeedCallBackend(CeedElemRestrictionDestroy(&rstr)); 816 817 // Setup 818 if (!impl->diag) { 819 CeedCallBackend(CeedOperatorAssembleDiagonalSetup_Cuda(op, pointBlock)); 820 } 821 CeedOperatorDiag_Cuda *diag = impl->diag; 822 assert(diag != NULL); 823 824 // Restriction 825 if (pointBlock && !diag->pbdiagrstr) { 826 CeedElemRestriction pbdiagrstr; 827 CeedCallBackend(CreatePBRestriction(diag->diagrstr, &pbdiagrstr)); 828 diag->pbdiagrstr = pbdiagrstr; 829 } 830 CeedElemRestriction diagrstr = pointBlock ? diag->pbdiagrstr : diag->diagrstr; 831 832 // Create diagonal vector 833 CeedVector elemdiag = pointBlock ? diag->pbelemdiag : diag->elemdiag; 834 if (!elemdiag) { 835 CeedCallBackend(CeedElemRestrictionCreateVector(diagrstr, NULL, &elemdiag)); 836 if (pointBlock) diag->pbelemdiag = elemdiag; 837 else diag->elemdiag = elemdiag; 838 } 839 CeedCallBackend(CeedVectorSetValue(elemdiag, 0.0)); 840 841 // Assemble element operator diagonals 842 CeedScalar *elemdiagarray; 843 const CeedScalar *assembledqfarray; 844 CeedCallBackend(CeedVectorGetArray(elemdiag, CEED_MEM_DEVICE, &elemdiagarray)); 845 CeedCallBackend(CeedVectorGetArrayRead(assembledqf, CEED_MEM_DEVICE, &assembledqfarray)); 846 CeedInt nelem; 847 CeedCallBackend(CeedElemRestrictionGetNumElements(diagrstr, &nelem)); 848 849 // Compute the diagonal of B^T D B 850 int elemsPerBlock = 1; 851 int grid = nelem / elemsPerBlock + ((nelem / elemsPerBlock * elemsPerBlock < nelem) ? 1 : 0); 852 void *args[] = {(void *)&nelem, &diag->d_identity, &diag->d_interpin, &diag->d_gradin, &diag->d_interpout, 853 &diag->d_gradout, &diag->d_emodein, &diag->d_emodeout, &assembledqfarray, &elemdiagarray}; 854 if (pointBlock) { 855 CeedCallBackend(CeedRunKernelDimCuda(ceed, diag->linearPointBlock, grid, diag->nnodes, 1, elemsPerBlock, args)); 856 } else { 857 CeedCallBackend(CeedRunKernelDimCuda(ceed, diag->linearDiagonal, grid, diag->nnodes, 1, elemsPerBlock, args)); 858 } 859 860 // Restore arrays 861 CeedCallBackend(CeedVectorRestoreArray(elemdiag, &elemdiagarray)); 862 CeedCallBackend(CeedVectorRestoreArrayRead(assembledqf, &assembledqfarray)); 863 864 // Assemble local operator diagonal 865 CeedCallBackend(CeedElemRestrictionApply(diagrstr, CEED_TRANSPOSE, elemdiag, assembled, request)); 866 867 // Cleanup 868 CeedCallBackend(CeedVectorDestroy(&assembledqf)); 869 870 return CEED_ERROR_SUCCESS; 871 } 872 873 //------------------------------------------------------------------------------ 874 // Assemble Linear Diagonal 875 //------------------------------------------------------------------------------ 876 static int CeedOperatorLinearAssembleAddDiagonal_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request) { 877 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Cuda(op, assembled, request, false)); 878 return CEED_ERROR_SUCCESS; 879 } 880 881 //------------------------------------------------------------------------------ 882 // Assemble Linear Point Block Diagonal 883 //------------------------------------------------------------------------------ 884 static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request) { 885 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Cuda(op, assembled, request, true)); 886 return CEED_ERROR_SUCCESS; 887 } 888 889 //------------------------------------------------------------------------------ 890 // Single operator assembly setup 891 //------------------------------------------------------------------------------ 892 static int CeedSingleOperatorAssembleSetup_Cuda(CeedOperator op) { 893 Ceed ceed; 894 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 895 CeedOperator_Cuda *impl; 896 CeedCallBackend(CeedOperatorGetData(op, &impl)); 897 898 // Get intput and output fields 899 CeedInt num_input_fields, num_output_fields; 900 CeedOperatorField *input_fields; 901 CeedOperatorField *output_fields; 902 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields)); 903 904 // Determine active input basis eval mode 905 CeedQFunction qf; 906 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 907 CeedQFunctionField *qf_fields; 908 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 909 // Note that the kernel will treat each dimension of a gradient action separately; 910 // i.e., when an active input has a CEED_EVAL_GRAD mode, num_emode_in will increment 911 // by dim. However, for the purposes of loading the B matrices, it will be treated 912 // as one mode, and we will load/copy the entire gradient matrix at once, so 913 // num_B_in_mats_to_load will be incremented by 1. 914 CeedInt num_emode_in = 0, dim = 1, num_B_in_mats_to_load = 0, size_B_in = 0; 915 CeedEvalMode *eval_mode_in = NULL; // will be of size num_B_in_mats_load 916 CeedBasis basis_in = NULL; 917 CeedInt nqpts = 0, esize = 0; 918 CeedElemRestriction rstr_in = NULL; 919 for (CeedInt i = 0; i < num_input_fields; i++) { 920 CeedVector vec; 921 CeedCallBackend(CeedOperatorFieldGetVector(input_fields[i], &vec)); 922 if (vec == CEED_VECTOR_ACTIVE) { 923 CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis_in)); 924 CeedCallBackend(CeedBasisGetDimension(basis_in, &dim)); 925 CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &nqpts)); 926 CeedCallBackend(CeedOperatorFieldGetElemRestriction(input_fields[i], &rstr_in)); 927 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &esize)); 928 CeedEvalMode eval_mode; 929 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 930 if (eval_mode != CEED_EVAL_NONE) { 931 CeedCallBackend(CeedRealloc(num_B_in_mats_to_load + 1, &eval_mode_in)); 932 eval_mode_in[num_B_in_mats_to_load] = eval_mode; 933 num_B_in_mats_to_load += 1; 934 if (eval_mode == CEED_EVAL_GRAD) { 935 num_emode_in += dim; 936 size_B_in += dim * esize * nqpts; 937 } else { 938 num_emode_in += 1; 939 size_B_in += esize * nqpts; 940 } 941 } 942 } 943 } 944 945 // Determine active output basis; basis_out and rstr_out only used if same as input, TODO 946 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 947 CeedInt num_emode_out = 0, num_B_out_mats_to_load = 0, size_B_out = 0; 948 CeedEvalMode *eval_mode_out = NULL; 949 CeedBasis basis_out = NULL; 950 CeedElemRestriction rstr_out = NULL; 951 for (CeedInt i = 0; i < num_output_fields; i++) { 952 CeedVector vec; 953 CeedCallBackend(CeedOperatorFieldGetVector(output_fields[i], &vec)); 954 if (vec == CEED_VECTOR_ACTIVE) { 955 CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis_out)); 956 CeedCallBackend(CeedOperatorFieldGetElemRestriction(output_fields[i], &rstr_out)); 957 if (rstr_out && rstr_out != rstr_in) { 958 // LCOV_EXCL_START 959 return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator assembly"); 960 // LCOV_EXCL_STOP 961 } 962 CeedEvalMode eval_mode; 963 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 964 if (eval_mode != CEED_EVAL_NONE) { 965 CeedCallBackend(CeedRealloc(num_B_out_mats_to_load + 1, &eval_mode_out)); 966 eval_mode_out[num_B_out_mats_to_load] = eval_mode; 967 num_B_out_mats_to_load += 1; 968 if (eval_mode == CEED_EVAL_GRAD) { 969 num_emode_out += dim; 970 size_B_out += dim * esize * nqpts; 971 } else { 972 num_emode_out += 1; 973 size_B_out += esize * nqpts; 974 } 975 } 976 } 977 } 978 979 if (num_emode_in == 0 || num_emode_out == 0) { 980 // LCOV_EXCL_START 981 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs"); 982 // LCOV_EXCL_STOP 983 } 984 985 CeedInt nelem, ncomp; 986 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, &nelem)); 987 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &ncomp)); 988 989 CeedCallBackend(CeedCalloc(1, &impl->asmb)); 990 CeedOperatorAssemble_Cuda *asmb = impl->asmb; 991 asmb->nelem = nelem; 992 993 // Compile kernels 994 int elemsPerBlock = 1; 995 asmb->elemsPerBlock = elemsPerBlock; 996 CeedInt block_size = esize * esize * elemsPerBlock; 997 Ceed_Cuda *cuda_data; 998 CeedCallBackend(CeedGetData(ceed, &cuda_data)); 999 char *assembly_kernel_path, *assembly_kernel_source; 1000 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-operator-assemble.h", &assembly_kernel_path)); 1001 CeedDebug256(ceed, 2, "----- Loading Assembly Kernel Source -----\n"); 1002 CeedCallBackend(CeedLoadSourceToBuffer(ceed, assembly_kernel_path, &assembly_kernel_source)); 1003 CeedDebug256(ceed, 2, "----- Loading Assembly Source Complete! -----\n"); 1004 bool fallback = block_size > cuda_data->device_prop.maxThreadsPerBlock; 1005 if (fallback) { 1006 // Use fallback kernel with 1D threadblock 1007 block_size = esize * elemsPerBlock; 1008 asmb->block_size_x = esize; 1009 asmb->block_size_y = 1; 1010 } else { // Use kernel with 2D threadblock 1011 asmb->block_size_x = esize; 1012 asmb->block_size_y = esize; 1013 } 1014 CeedCallCuda(ceed, CeedCompileCuda(ceed, assembly_kernel_source, &asmb->module, 7, "NELEM", nelem, "NUMEMODEIN", num_emode_in, "NUMEMODEOUT", 1015 num_emode_out, "NQPTS", nqpts, "NNODES", esize, "BLOCK_SIZE", block_size, "NCOMP", ncomp)); 1016 CeedCallCuda(ceed, CeedGetKernelCuda(ceed, asmb->module, fallback ? "linearAssembleFallback" : "linearAssemble", &asmb->linearAssemble)); 1017 CeedCallBackend(CeedFree(&assembly_kernel_path)); 1018 CeedCallBackend(CeedFree(&assembly_kernel_source)); 1019 1020 // Build 'full' B matrices (not 1D arrays used for tensor-product matrices) 1021 const CeedScalar *interp_in, *grad_in; 1022 CeedCallBackend(CeedBasisGetInterp(basis_in, &interp_in)); 1023 CeedCallBackend(CeedBasisGetGrad(basis_in, &grad_in)); 1024 1025 // Load into B_in, in order that they will be used in eval_mode 1026 const CeedInt inBytes = size_B_in * sizeof(CeedScalar); 1027 CeedInt mat_start = 0; 1028 CeedCallCuda(ceed, cudaMalloc((void **)&asmb->d_B_in, inBytes)); 1029 for (int i = 0; i < num_B_in_mats_to_load; i++) { 1030 CeedEvalMode eval_mode = eval_mode_in[i]; 1031 if (eval_mode == CEED_EVAL_INTERP) { 1032 CeedCallCuda(ceed, cudaMemcpy(&asmb->d_B_in[mat_start], interp_in, esize * nqpts * sizeof(CeedScalar), cudaMemcpyHostToDevice)); 1033 mat_start += esize * nqpts; 1034 } else if (eval_mode == CEED_EVAL_GRAD) { 1035 CeedCallCuda(ceed, cudaMemcpy(&asmb->d_B_in[mat_start], grad_in, dim * esize * nqpts * sizeof(CeedScalar), cudaMemcpyHostToDevice)); 1036 mat_start += dim * esize * nqpts; 1037 } 1038 } 1039 1040 const CeedScalar *interp_out, *grad_out; 1041 // Note that this function currently assumes 1 basis, so this should always be true 1042 // for now 1043 if (basis_out == basis_in) { 1044 interp_out = interp_in; 1045 grad_out = grad_in; 1046 } else { 1047 CeedCallBackend(CeedBasisGetInterp(basis_out, &interp_out)); 1048 CeedCallBackend(CeedBasisGetGrad(basis_out, &grad_out)); 1049 } 1050 1051 // Load into B_out, in order that they will be used in eval_mode 1052 const CeedInt outBytes = size_B_out * sizeof(CeedScalar); 1053 mat_start = 0; 1054 CeedCallCuda(ceed, cudaMalloc((void **)&asmb->d_B_out, outBytes)); 1055 for (int i = 0; i < num_B_out_mats_to_load; i++) { 1056 CeedEvalMode eval_mode = eval_mode_out[i]; 1057 if (eval_mode == CEED_EVAL_INTERP) { 1058 CeedCallCuda(ceed, cudaMemcpy(&asmb->d_B_out[mat_start], interp_out, esize * nqpts * sizeof(CeedScalar), cudaMemcpyHostToDevice)); 1059 mat_start += esize * nqpts; 1060 } else if (eval_mode == CEED_EVAL_GRAD) { 1061 CeedCallCuda(ceed, cudaMemcpy(&asmb->d_B_out[mat_start], grad_out, dim * esize * nqpts * sizeof(CeedScalar), cudaMemcpyHostToDevice)); 1062 mat_start += dim * esize * nqpts; 1063 } 1064 } 1065 return CEED_ERROR_SUCCESS; 1066 } 1067 1068 //------------------------------------------------------------------------------ 1069 // Assemble matrix data for COO matrix of assembled operator. 1070 // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic. 1071 // 1072 // Note that this (and other assembly routines) currently assume only one 1073 // active input restriction/basis per operator (could have multiple basis eval 1074 // modes). 1075 // TODO: allow multiple active input restrictions/basis objects 1076 //------------------------------------------------------------------------------ 1077 static int CeedSingleOperatorAssemble_Cuda(CeedOperator op, CeedInt offset, CeedVector values) { 1078 Ceed ceed; 1079 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1080 CeedOperator_Cuda *impl; 1081 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1082 1083 // Setup 1084 if (!impl->asmb) { 1085 CeedCallBackend(CeedSingleOperatorAssembleSetup_Cuda(op)); 1086 assert(impl->asmb != NULL); 1087 } 1088 1089 // Assemble QFunction 1090 CeedVector assembled_qf; 1091 CeedElemRestriction rstr_q; 1092 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &rstr_q, CEED_REQUEST_IMMEDIATE)); 1093 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_q)); 1094 CeedScalar *values_array; 1095 CeedCallBackend(CeedVectorGetArrayWrite(values, CEED_MEM_DEVICE, &values_array)); 1096 values_array += offset; 1097 const CeedScalar *qf_array; 1098 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &qf_array)); 1099 1100 // Compute B^T D B 1101 const CeedInt nelem = impl->asmb->nelem; 1102 const CeedInt elemsPerBlock = impl->asmb->elemsPerBlock; 1103 const CeedInt grid = nelem / elemsPerBlock + ((nelem / elemsPerBlock * elemsPerBlock < nelem) ? 1 : 0); 1104 void *args[] = {&impl->asmb->d_B_in, &impl->asmb->d_B_out, &qf_array, &values_array}; 1105 CeedCallBackend( 1106 CeedRunKernelDimCuda(ceed, impl->asmb->linearAssemble, grid, impl->asmb->block_size_x, impl->asmb->block_size_y, elemsPerBlock, args)); 1107 1108 // Restore arrays 1109 CeedCallBackend(CeedVectorRestoreArray(values, &values_array)); 1110 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &qf_array)); 1111 1112 // Cleanup 1113 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1114 1115 return CEED_ERROR_SUCCESS; 1116 } 1117 1118 //------------------------------------------------------------------------------ 1119 // Create operator 1120 //------------------------------------------------------------------------------ 1121 int CeedOperatorCreate_Cuda(CeedOperator op) { 1122 Ceed ceed; 1123 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1124 CeedOperator_Cuda *impl; 1125 1126 CeedCallBackend(CeedCalloc(1, &impl)); 1127 CeedCallBackend(CeedOperatorSetData(op, impl)); 1128 1129 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Cuda)); 1130 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Cuda)); 1131 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonal_Cuda)); 1132 CeedCallBackend( 1133 CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal", CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda)); 1134 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssemble_Cuda)); 1135 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Cuda)); 1136 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Cuda)); 1137 return CEED_ERROR_SUCCESS; 1138 } 1139 1140 //------------------------------------------------------------------------------ 1141