1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <assert.h> 12 #include <cuda.h> 13 #include <cuda_runtime.h> 14 #include <stdbool.h> 15 #include <string.h> 16 17 #include "../cuda/ceed-cuda-common.h" 18 #include "../cuda/ceed-cuda-compile.h" 19 #include "ceed-cuda-ref.h" 20 21 //------------------------------------------------------------------------------ 22 // Destroy operator 23 //------------------------------------------------------------------------------ 24 static int CeedOperatorDestroy_Cuda(CeedOperator op) { 25 CeedOperator_Cuda *impl; 26 27 CeedCallBackend(CeedOperatorGetData(op, &impl)); 28 29 // Apply data 30 CeedCallBackend(CeedFree(&impl->skip_rstr_in)); 31 for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { 32 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[i])); 33 } 34 CeedCallBackend(CeedFree(&impl->e_vecs)); 35 CeedCallBackend(CeedFree(&impl->input_states)); 36 37 for (CeedInt i = 0; i < impl->num_inputs; i++) { 38 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i])); 39 } 40 CeedCallBackend(CeedFree(&impl->q_vecs_in)); 41 42 for (CeedInt i = 0; i < impl->num_outputs; i++) { 43 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 44 } 45 CeedCallBackend(CeedFree(&impl->q_vecs_out)); 46 CeedCallBackend(CeedVectorDestroy(&impl->point_coords_elem)); 47 48 // QFunction assembly data 49 for (CeedInt i = 0; i < impl->num_active_in; i++) { 50 CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i])); 51 } 52 CeedCallBackend(CeedFree(&impl->qf_active_in)); 53 54 // Diag data 55 if (impl->diag) { 56 Ceed ceed; 57 58 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 59 if (impl->diag->module) { 60 CeedCallCuda(ceed, cuModuleUnload(impl->diag->module)); 61 } 62 if (impl->diag->module_point_block) { 63 CeedCallCuda(ceed, cuModuleUnload(impl->diag->module_point_block)); 64 } 65 CeedCallCuda(ceed, cudaFree(impl->diag->d_eval_modes_in)); 66 CeedCallCuda(ceed, cudaFree(impl->diag->d_eval_modes_out)); 67 CeedCallCuda(ceed, cudaFree(impl->diag->d_identity)); 68 CeedCallCuda(ceed, cudaFree(impl->diag->d_interp_in)); 69 CeedCallCuda(ceed, cudaFree(impl->diag->d_interp_out)); 70 CeedCallCuda(ceed, cudaFree(impl->diag->d_grad_in)); 71 CeedCallCuda(ceed, cudaFree(impl->diag->d_grad_out)); 72 CeedCallCuda(ceed, cudaFree(impl->diag->d_div_in)); 73 CeedCallCuda(ceed, cudaFree(impl->diag->d_div_out)); 74 CeedCallCuda(ceed, cudaFree(impl->diag->d_curl_in)); 75 CeedCallCuda(ceed, cudaFree(impl->diag->d_curl_out)); 76 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr)); 77 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr)); 78 CeedCallBackend(CeedVectorDestroy(&impl->diag->elem_diag)); 79 CeedCallBackend(CeedVectorDestroy(&impl->diag->point_block_elem_diag)); 80 } 81 CeedCallBackend(CeedFree(&impl->diag)); 82 83 if (impl->asmb) { 84 Ceed ceed; 85 86 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 87 CeedCallCuda(ceed, cuModuleUnload(impl->asmb->module)); 88 CeedCallCuda(ceed, cudaFree(impl->asmb->d_B_in)); 89 CeedCallCuda(ceed, cudaFree(impl->asmb->d_B_out)); 90 } 91 CeedCallBackend(CeedFree(&impl->asmb)); 92 93 CeedCallBackend(CeedFree(&impl)); 94 return CEED_ERROR_SUCCESS; 95 } 96 97 //------------------------------------------------------------------------------ 98 // Setup infields or outfields 99 //------------------------------------------------------------------------------ 100 static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, CeedVector *e_vecs, 101 CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) { 102 Ceed ceed; 103 CeedQFunctionField *qf_fields; 104 CeedOperatorField *op_fields; 105 106 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 107 if (is_input) { 108 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 109 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 110 } else { 111 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 112 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 113 } 114 115 // Loop over fields 116 for (CeedInt i = 0; i < num_fields; i++) { 117 bool is_strided = false, skip_restriction = false; 118 CeedSize q_size; 119 CeedInt size; 120 CeedEvalMode eval_mode; 121 CeedBasis basis; 122 123 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 124 if (eval_mode != CEED_EVAL_WEIGHT) { 125 CeedElemRestriction elem_rstr; 126 127 // Check whether this field can skip the element restriction: 128 // Must be passive input, with eval_mode NONE, and have a strided restriction with CEED_STRIDES_BACKEND. 129 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); 130 131 // First, check whether the field is input or output: 132 if (is_input) { 133 CeedVector vec; 134 135 // Check for passive input 136 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 137 if (vec != CEED_VECTOR_ACTIVE) { 138 // Check eval_mode 139 if (eval_mode == CEED_EVAL_NONE) { 140 // Check for strided restriction 141 CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided)); 142 if (is_strided) { 143 // Check if vector is already in preferred backend ordering 144 CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &skip_restriction)); 145 } 146 } 147 } 148 } 149 if (skip_restriction) { 150 // We do not need an E-Vector, but will use the input field vector's data directly in the operator application. 151 e_vecs[i + start_e] = NULL; 152 } else { 153 CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs[i + start_e])); 154 } 155 } 156 157 switch (eval_mode) { 158 case CEED_EVAL_NONE: 159 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 160 q_size = (CeedSize)num_elem * Q * size; 161 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 162 break; 163 case CEED_EVAL_INTERP: 164 case CEED_EVAL_GRAD: 165 case CEED_EVAL_DIV: 166 case CEED_EVAL_CURL: 167 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 168 q_size = (CeedSize)num_elem * Q * size; 169 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 170 break; 171 case CEED_EVAL_WEIGHT: // Only on input fields 172 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 173 q_size = (CeedSize)num_elem * Q; 174 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 175 if (is_at_points) { 176 CeedInt num_points[num_elem]; 177 178 for (CeedInt i = 0; i < num_elem; i++) num_points[i] = Q; 179 CeedCallBackend( 180 CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, CEED_VECTOR_NONE, q_vecs[i])); 181 } else { 182 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); 183 } 184 break; 185 } 186 } 187 // Drop duplicate input restrictions 188 if (is_input) { 189 for (CeedInt i = 0; i < num_fields; i++) { 190 CeedVector vec_i; 191 CeedElemRestriction rstr_i; 192 193 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); 194 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); 195 for (CeedInt j = i + 1; j < num_fields; j++) { 196 CeedVector vec_j; 197 CeedElemRestriction rstr_j; 198 199 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); 200 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); 201 if (vec_i == vec_j && rstr_i == rstr_j) { 202 CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); 203 skip_rstr[j] = true; 204 } 205 } 206 } 207 } 208 return CEED_ERROR_SUCCESS; 209 } 210 211 //------------------------------------------------------------------------------ 212 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction. 213 //------------------------------------------------------------------------------ 214 static int CeedOperatorSetup_Cuda(CeedOperator op) { 215 Ceed ceed; 216 bool is_setup_done; 217 CeedInt Q, num_elem, num_input_fields, num_output_fields; 218 CeedQFunctionField *qf_input_fields, *qf_output_fields; 219 CeedQFunction qf; 220 CeedOperatorField *op_input_fields, *op_output_fields; 221 CeedOperator_Cuda *impl; 222 223 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 224 if (is_setup_done) return CEED_ERROR_SUCCESS; 225 226 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 227 CeedCallBackend(CeedOperatorGetData(op, &impl)); 228 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 229 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 230 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 231 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 232 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 233 234 // Allocate 235 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); 236 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); 237 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 238 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 239 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 240 impl->num_inputs = num_input_fields; 241 impl->num_outputs = num_output_fields; 242 243 // Set up infield and outfield e_vecs and q_vecs 244 // Infields 245 CeedCallBackend( 246 CeedOperatorSetupFields_Cuda(qf, op, true, false, impl->skip_rstr_in, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem)); 247 // Outfields 248 CeedCallBackend( 249 CeedOperatorSetupFields_Cuda(qf, op, false, false, NULL, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, Q, num_elem)); 250 251 CeedCallBackend(CeedOperatorSetSetupDone(op)); 252 return CEED_ERROR_SUCCESS; 253 } 254 255 //------------------------------------------------------------------------------ 256 // Setup Operator Inputs 257 //------------------------------------------------------------------------------ 258 static inline int CeedOperatorSetupInputs_Cuda(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 259 CeedVector in_vec, const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], 260 CeedOperator_Cuda *impl, CeedRequest *request) { 261 for (CeedInt i = 0; i < num_input_fields; i++) { 262 CeedEvalMode eval_mode; 263 CeedVector vec; 264 CeedElemRestriction elem_rstr; 265 266 // Get input vector 267 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 268 if (vec == CEED_VECTOR_ACTIVE) { 269 if (skip_active) continue; 270 else vec = in_vec; 271 } 272 273 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 274 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 275 } else { 276 // Get input vector 277 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 278 // Get input element restriction 279 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 280 if (vec == CEED_VECTOR_ACTIVE) vec = in_vec; 281 // Restrict, if necessary 282 if (!impl->e_vecs[i]) { 283 // No restriction for this field; read data directly from vec. 284 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i])); 285 } else { 286 uint64_t state; 287 288 CeedCallBackend(CeedVectorGetState(vec, &state)); 289 if (state != impl->input_states[i] && !impl->skip_rstr_in[i]) { 290 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, vec, impl->e_vecs[i], request)); 291 } 292 impl->input_states[i] = state; 293 // Get evec 294 CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs[i], CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i])); 295 } 296 } 297 } 298 return CEED_ERROR_SUCCESS; 299 } 300 301 //------------------------------------------------------------------------------ 302 // Input Basis Action 303 //------------------------------------------------------------------------------ 304 static inline int CeedOperatorInputBasis_Cuda(CeedInt num_elem, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 305 CeedInt num_input_fields, const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], 306 CeedOperator_Cuda *impl) { 307 for (CeedInt i = 0; i < num_input_fields; i++) { 308 CeedInt elem_size, size; 309 CeedEvalMode eval_mode; 310 CeedElemRestriction elem_rstr; 311 CeedBasis basis; 312 313 // Skip active input 314 if (skip_active) { 315 CeedVector vec; 316 317 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 318 if (vec == CEED_VECTOR_ACTIVE) continue; 319 } 320 // Get elem_size, eval_mode, size 321 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 322 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 323 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 324 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 325 // Basis action 326 switch (eval_mode) { 327 case CEED_EVAL_NONE: 328 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 329 break; 330 case CEED_EVAL_INTERP: 331 case CEED_EVAL_GRAD: 332 case CEED_EVAL_DIV: 333 case CEED_EVAL_CURL: 334 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 335 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs[i], impl->q_vecs_in[i])); 336 break; 337 case CEED_EVAL_WEIGHT: 338 break; // No action 339 } 340 } 341 return CEED_ERROR_SUCCESS; 342 } 343 344 //------------------------------------------------------------------------------ 345 // Restore Input Vectors 346 //------------------------------------------------------------------------------ 347 static inline int CeedOperatorRestoreInputs_Cuda(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 348 const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Cuda *impl) { 349 for (CeedInt i = 0; i < num_input_fields; i++) { 350 CeedEvalMode eval_mode; 351 CeedVector vec; 352 353 // Skip active input 354 if (skip_active) { 355 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 356 if (vec == CEED_VECTOR_ACTIVE) continue; 357 } 358 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 359 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 360 } else { 361 if (!impl->e_vecs[i]) { // This was a skip_restriction case 362 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 363 CeedCallBackend(CeedVectorRestoreArrayRead(vec, (const CeedScalar **)&e_data[i])); 364 } else { 365 CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs[i], (const CeedScalar **)&e_data[i])); 366 } 367 } 368 } 369 return CEED_ERROR_SUCCESS; 370 } 371 372 //------------------------------------------------------------------------------ 373 // Apply and add to output 374 //------------------------------------------------------------------------------ 375 static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 376 CeedInt Q, num_elem, elem_size, num_input_fields, num_output_fields, size; 377 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 378 CeedQFunctionField *qf_input_fields, *qf_output_fields; 379 CeedQFunction qf; 380 CeedOperatorField *op_input_fields, *op_output_fields; 381 CeedOperator_Cuda *impl; 382 383 CeedCallBackend(CeedOperatorGetData(op, &impl)); 384 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 385 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 386 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 387 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 388 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 389 390 // Setup 391 CeedCallBackend(CeedOperatorSetup_Cuda(op)); 392 393 // Input Evecs and Restriction 394 CeedCallBackend(CeedOperatorSetupInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data, impl, request)); 395 396 // Input basis apply if needed 397 CeedCallBackend(CeedOperatorInputBasis_Cuda(num_elem, qf_input_fields, op_input_fields, num_input_fields, false, e_data, impl)); 398 399 // Output pointers, as necessary 400 for (CeedInt i = 0; i < num_output_fields; i++) { 401 CeedEvalMode eval_mode; 402 403 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 404 if (eval_mode == CEED_EVAL_NONE) { 405 // Set the output Q-Vector to use the E-Vector data directly. 406 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 407 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 408 } 409 } 410 411 // Q function 412 CeedCallBackend(CeedQFunctionApply(qf, num_elem * Q, impl->q_vecs_in, impl->q_vecs_out)); 413 414 // Output basis apply if needed 415 for (CeedInt i = 0; i < num_output_fields; i++) { 416 CeedEvalMode eval_mode; 417 CeedElemRestriction elem_rstr; 418 CeedBasis basis; 419 420 // Get elem_size, eval_mode, size 421 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 422 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 423 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 424 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 425 // Basis action 426 switch (eval_mode) { 427 case CEED_EVAL_NONE: 428 break; // No action 429 case CEED_EVAL_INTERP: 430 case CEED_EVAL_GRAD: 431 case CEED_EVAL_DIV: 432 case CEED_EVAL_CURL: 433 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 434 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); 435 break; 436 // LCOV_EXCL_START 437 case CEED_EVAL_WEIGHT: { 438 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 439 // LCOV_EXCL_STOP 440 } 441 } 442 } 443 444 // Output restriction 445 for (CeedInt i = 0; i < num_output_fields; i++) { 446 CeedEvalMode eval_mode; 447 CeedVector vec; 448 CeedElemRestriction elem_rstr; 449 450 // Restore evec 451 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 452 if (eval_mode == CEED_EVAL_NONE) { 453 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 454 } 455 // Get output vector 456 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 457 // Restrict 458 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 459 // Active 460 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 461 462 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request)); 463 } 464 465 // Restore input arrays 466 CeedCallBackend(CeedOperatorRestoreInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl)); 467 return CEED_ERROR_SUCCESS; 468 } 469 470 //------------------------------------------------------------------------------ 471 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction. 472 //------------------------------------------------------------------------------ 473 static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) { 474 Ceed ceed; 475 bool is_setup_done; 476 CeedInt max_num_points = -1, num_elem, num_input_fields, num_output_fields; 477 CeedQFunctionField *qf_input_fields, *qf_output_fields; 478 CeedQFunction qf; 479 CeedOperatorField *op_input_fields, *op_output_fields; 480 CeedOperator_Cuda *impl; 481 482 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 483 if (is_setup_done) return CEED_ERROR_SUCCESS; 484 485 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 486 CeedCallBackend(CeedOperatorGetData(op, &impl)); 487 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 488 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 489 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 490 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 491 { 492 CeedElemRestriction elem_rstr = NULL; 493 494 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &elem_rstr, NULL)); 495 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(elem_rstr, &max_num_points)); 496 } 497 impl->max_num_points = max_num_points; 498 499 // Allocate 500 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); 501 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); 502 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 503 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 504 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 505 impl->num_inputs = num_input_fields; 506 impl->num_outputs = num_output_fields; 507 508 // Set up infield and outfield e_vecs and q_vecs 509 // Infields 510 CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, true, true, impl->skip_rstr_in, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, 511 max_num_points, num_elem)); 512 // Outfields 513 CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, true, NULL, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, 514 max_num_points, num_elem)); 515 516 CeedCallBackend(CeedOperatorSetSetupDone(op)); 517 return CEED_ERROR_SUCCESS; 518 } 519 520 //------------------------------------------------------------------------------ 521 // Input Basis Action AtPoints 522 //------------------------------------------------------------------------------ 523 static inline int CeedOperatorInputBasisAtPoints_Cuda(CeedInt num_elem, const CeedInt *num_points, CeedQFunctionField *qf_input_fields, 524 CeedOperatorField *op_input_fields, CeedInt num_input_fields, const bool skip_active, 525 CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Cuda *impl) { 526 for (CeedInt i = 0; i < num_input_fields; i++) { 527 CeedInt elem_size, size; 528 CeedEvalMode eval_mode; 529 CeedElemRestriction elem_rstr; 530 CeedBasis basis; 531 532 // Skip active input 533 if (skip_active) { 534 CeedVector vec; 535 536 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 537 if (vec == CEED_VECTOR_ACTIVE) continue; 538 } 539 // Get elem_size, eval_mode, size 540 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 541 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 542 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 543 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 544 // Basis action 545 switch (eval_mode) { 546 case CEED_EVAL_NONE: 547 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 548 break; 549 case CEED_EVAL_INTERP: 550 case CEED_EVAL_GRAD: 551 case CEED_EVAL_DIV: 552 case CEED_EVAL_CURL: 553 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 554 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, impl->e_vecs[i], 555 impl->q_vecs_in[i])); 556 break; 557 case CEED_EVAL_WEIGHT: 558 break; // No action 559 } 560 } 561 return CEED_ERROR_SUCCESS; 562 } 563 564 //------------------------------------------------------------------------------ 565 // Apply and add to output AtPoints 566 //------------------------------------------------------------------------------ 567 static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 568 CeedInt max_num_points, num_elem, elem_size, num_input_fields, num_output_fields, size; 569 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 570 CeedQFunctionField *qf_input_fields, *qf_output_fields; 571 CeedQFunction qf; 572 CeedOperatorField *op_input_fields, *op_output_fields; 573 CeedOperator_Cuda *impl; 574 575 CeedCallBackend(CeedOperatorGetData(op, &impl)); 576 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 577 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 578 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 579 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 580 CeedInt num_points[num_elem]; 581 582 // Setup 583 CeedCallBackend(CeedOperatorSetupAtPoints_Cuda(op)); 584 max_num_points = impl->max_num_points; 585 for (CeedInt i = 0; i < num_elem; i++) num_points[i] = max_num_points; 586 587 // Input Evecs and Restriction 588 CeedCallBackend(CeedOperatorSetupInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data, impl, request)); 589 590 // Get point coordinates 591 if (!impl->point_coords_elem) { 592 CeedVector point_coords = NULL; 593 CeedElemRestriction rstr_points = NULL; 594 595 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 596 CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem)); 597 CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 598 } 599 600 // Input basis apply if needed 601 CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(num_elem, num_points, qf_input_fields, op_input_fields, num_input_fields, false, e_data, impl)); 602 603 // Output pointers, as necessary 604 for (CeedInt i = 0; i < num_output_fields; i++) { 605 CeedEvalMode eval_mode; 606 607 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 608 if (eval_mode == CEED_EVAL_NONE) { 609 // Set the output Q-Vector to use the E-Vector data directly. 610 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 611 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 612 } 613 } 614 615 // Q function 616 CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out)); 617 618 // Output basis apply if needed 619 for (CeedInt i = 0; i < num_output_fields; i++) { 620 CeedEvalMode eval_mode; 621 CeedElemRestriction elem_rstr; 622 CeedBasis basis; 623 624 // Get elem_size, eval_mode, size 625 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 626 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 627 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 628 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 629 // Basis action 630 switch (eval_mode) { 631 case CEED_EVAL_NONE: 632 break; // No action 633 case CEED_EVAL_INTERP: 634 case CEED_EVAL_GRAD: 635 case CEED_EVAL_DIV: 636 case CEED_EVAL_CURL: 637 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 638 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i], 639 impl->e_vecs[i + impl->num_inputs])); 640 break; 641 // LCOV_EXCL_START 642 case CEED_EVAL_WEIGHT: { 643 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 644 // LCOV_EXCL_STOP 645 } 646 } 647 } 648 649 // Output restriction 650 for (CeedInt i = 0; i < num_output_fields; i++) { 651 CeedEvalMode eval_mode; 652 CeedVector vec; 653 CeedElemRestriction elem_rstr; 654 655 // Restore evec 656 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 657 if (eval_mode == CEED_EVAL_NONE) { 658 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 659 } 660 // Get output vector 661 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 662 // Restrict 663 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 664 // Active 665 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 666 667 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request)); 668 } 669 670 // Restore input arrays 671 CeedCallBackend(CeedOperatorRestoreInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl)); 672 return CEED_ERROR_SUCCESS; 673 } 674 675 //------------------------------------------------------------------------------ 676 // Linear QFunction Assembly Core 677 //------------------------------------------------------------------------------ 678 static inline int CeedOperatorLinearAssembleQFunctionCore_Cuda(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 679 CeedRequest *request) { 680 Ceed ceed, ceed_parent; 681 CeedInt num_active_in, num_active_out, Q, num_elem, num_input_fields, num_output_fields, size; 682 CeedScalar *assembled_array, *e_data[2 * CEED_FIELD_MAX] = {NULL}; 683 CeedVector *active_inputs; 684 CeedQFunctionField *qf_input_fields, *qf_output_fields; 685 CeedQFunction qf; 686 CeedOperatorField *op_input_fields, *op_output_fields; 687 CeedOperator_Cuda *impl; 688 689 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 690 CeedCallBackend(CeedOperatorGetFallbackParentCeed(op, &ceed_parent)); 691 CeedCallBackend(CeedOperatorGetData(op, &impl)); 692 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 693 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 694 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 695 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 696 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 697 active_inputs = impl->qf_active_in; 698 num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 699 700 // Setup 701 CeedCallBackend(CeedOperatorSetup_Cuda(op)); 702 703 // Input Evecs and Restriction 704 CeedCallBackend(CeedOperatorSetupInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 705 706 // Count number of active input fields 707 if (!num_active_in) { 708 for (CeedInt i = 0; i < num_input_fields; i++) { 709 CeedScalar *q_vec_array; 710 CeedVector vec; 711 712 // Get input vector 713 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 714 // Check if active input 715 if (vec == CEED_VECTOR_ACTIVE) { 716 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 717 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 718 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, &q_vec_array)); 719 CeedCallBackend(CeedRealloc(num_active_in + size, &active_inputs)); 720 for (CeedInt field = 0; field < size; field++) { 721 CeedSize q_size = (CeedSize)Q * num_elem; 722 723 CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_inputs[num_active_in + field])); 724 CeedCallBackend( 725 CeedVectorSetArray(active_inputs[num_active_in + field], CEED_MEM_DEVICE, CEED_USE_POINTER, &q_vec_array[field * Q * num_elem])); 726 } 727 num_active_in += size; 728 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array)); 729 } 730 } 731 impl->num_active_in = num_active_in; 732 impl->qf_active_in = active_inputs; 733 } 734 735 // Count number of active output fields 736 if (!num_active_out) { 737 for (CeedInt i = 0; i < num_output_fields; i++) { 738 CeedVector vec; 739 740 // Get output vector 741 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 742 // Check if active output 743 if (vec == CEED_VECTOR_ACTIVE) { 744 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 745 num_active_out += size; 746 } 747 } 748 impl->num_active_out = num_active_out; 749 } 750 751 // Check sizes 752 CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 753 754 // Build objects if needed 755 if (build_objects) { 756 CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out; 757 CeedInt strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */ 758 759 // Create output restriction 760 CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out, 761 (CeedSize)num_active_in * (CeedSize)num_active_out * (CeedSize)num_elem * (CeedSize)Q, strides, 762 rstr)); 763 // Create assembled vector 764 CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled)); 765 } 766 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 767 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_DEVICE, &assembled_array)); 768 769 // Input basis apply 770 CeedCallBackend(CeedOperatorInputBasis_Cuda(num_elem, qf_input_fields, op_input_fields, num_input_fields, true, e_data, impl)); 771 772 // Assemble QFunction 773 for (CeedInt in = 0; in < num_active_in; in++) { 774 // Set Inputs 775 CeedCallBackend(CeedVectorSetValue(active_inputs[in], 1.0)); 776 if (num_active_in > 1) { 777 CeedCallBackend(CeedVectorSetValue(active_inputs[(in + num_active_in - 1) % num_active_in], 0.0)); 778 } 779 // Set Outputs 780 for (CeedInt out = 0; out < num_output_fields; out++) { 781 CeedVector vec; 782 783 // Get output vector 784 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 785 // Check if active output 786 if (vec == CEED_VECTOR_ACTIVE) { 787 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, CEED_USE_POINTER, assembled_array)); 788 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size)); 789 assembled_array += size * Q * num_elem; // Advance the pointer by the size of the output 790 } 791 } 792 // Apply QFunction 793 CeedCallBackend(CeedQFunctionApply(qf, Q * num_elem, impl->q_vecs_in, impl->q_vecs_out)); 794 } 795 796 // Un-set output q_vecs to prevent accidental overwrite of Assembled 797 for (CeedInt out = 0; out < num_output_fields; out++) { 798 CeedVector vec; 799 800 // Get output vector 801 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 802 // Check if active output 803 if (vec == CEED_VECTOR_ACTIVE) { 804 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, NULL)); 805 } 806 } 807 808 // Restore input arrays 809 CeedCallBackend(CeedOperatorRestoreInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 810 811 // Restore output 812 CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 813 return CEED_ERROR_SUCCESS; 814 } 815 816 //------------------------------------------------------------------------------ 817 // Assemble Linear QFunction 818 //------------------------------------------------------------------------------ 819 static int CeedOperatorLinearAssembleQFunction_Cuda(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 820 return CeedOperatorLinearAssembleQFunctionCore_Cuda(op, true, assembled, rstr, request); 821 } 822 823 //------------------------------------------------------------------------------ 824 // Update Assembled Linear QFunction 825 //------------------------------------------------------------------------------ 826 static int CeedOperatorLinearAssembleQFunctionUpdate_Cuda(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 827 return CeedOperatorLinearAssembleQFunctionCore_Cuda(op, false, &assembled, &rstr, request); 828 } 829 830 //------------------------------------------------------------------------------ 831 // Assemble Diagonal Setup 832 //------------------------------------------------------------------------------ 833 static inline int CeedOperatorAssembleDiagonalSetup_Cuda(CeedOperator op) { 834 Ceed ceed; 835 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 836 CeedInt q_comp, num_nodes, num_qpts; 837 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL; 838 CeedBasis basis_in = NULL, basis_out = NULL; 839 CeedQFunctionField *qf_fields; 840 CeedQFunction qf; 841 CeedOperatorField *op_fields; 842 CeedOperator_Cuda *impl; 843 844 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 845 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 846 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 847 848 // Determine active input basis 849 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 850 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 851 for (CeedInt i = 0; i < num_input_fields; i++) { 852 CeedVector vec; 853 854 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 855 if (vec == CEED_VECTOR_ACTIVE) { 856 CeedBasis basis; 857 CeedEvalMode eval_mode; 858 859 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 860 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, 861 "Backend does not implement operator diagonal assembly with multiple active bases"); 862 basis_in = basis; 863 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 864 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 865 if (eval_mode != CEED_EVAL_WEIGHT) { 866 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly 867 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in)); 868 for (CeedInt d = 0; d < q_comp; d++) eval_modes_in[num_eval_modes_in + d] = eval_mode; 869 num_eval_modes_in += q_comp; 870 } 871 } 872 } 873 874 // Determine active output basis 875 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 876 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 877 for (CeedInt i = 0; i < num_output_fields; i++) { 878 CeedVector vec; 879 880 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 881 if (vec == CEED_VECTOR_ACTIVE) { 882 CeedBasis basis; 883 CeedEvalMode eval_mode; 884 885 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 886 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND, 887 "Backend does not implement operator diagonal assembly with multiple active bases"); 888 basis_out = basis; 889 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 890 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 891 if (eval_mode != CEED_EVAL_WEIGHT) { 892 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly 893 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out)); 894 for (CeedInt d = 0; d < q_comp; d++) eval_modes_out[num_eval_modes_out + d] = eval_mode; 895 num_eval_modes_out += q_comp; 896 } 897 } 898 } 899 900 // Operator data struct 901 CeedCallBackend(CeedOperatorGetData(op, &impl)); 902 CeedCallBackend(CeedCalloc(1, &impl->diag)); 903 CeedOperatorDiag_Cuda *diag = impl->diag; 904 905 // Basis matrices 906 CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes)); 907 if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes; 908 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts)); 909 const CeedInt interp_bytes = num_nodes * num_qpts * sizeof(CeedScalar); 910 const CeedInt eval_modes_bytes = sizeof(CeedEvalMode); 911 bool has_eval_none = false; 912 913 // CEED_EVAL_NONE 914 for (CeedInt i = 0; i < num_eval_modes_in; i++) has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE); 915 for (CeedInt i = 0; i < num_eval_modes_out; i++) has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE); 916 if (has_eval_none) { 917 CeedScalar *identity = NULL; 918 919 CeedCallBackend(CeedCalloc(num_nodes * num_qpts, &identity)); 920 for (CeedInt i = 0; i < (num_nodes < num_qpts ? num_nodes : num_qpts); i++) identity[i * num_nodes + i] = 1.0; 921 CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_identity, interp_bytes)); 922 CeedCallCuda(ceed, cudaMemcpy(diag->d_identity, identity, interp_bytes, cudaMemcpyHostToDevice)); 923 CeedCallBackend(CeedFree(&identity)); 924 } 925 926 // CEED_EVAL_INTERP, CEED_EVAL_GRAD, CEED_EVAL_DIV, and CEED_EVAL_CURL 927 for (CeedInt in = 0; in < 2; in++) { 928 CeedFESpace fespace; 929 CeedBasis basis = in ? basis_in : basis_out; 930 931 CeedCallBackend(CeedBasisGetFESpace(basis, &fespace)); 932 switch (fespace) { 933 case CEED_FE_SPACE_H1: { 934 CeedInt q_comp_interp, q_comp_grad; 935 const CeedScalar *interp, *grad; 936 CeedScalar *d_interp, *d_grad; 937 938 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 939 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_GRAD, &q_comp_grad)); 940 941 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 942 CeedCallCuda(ceed, cudaMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 943 CeedCallCuda(ceed, cudaMemcpy(d_interp, interp, interp_bytes * q_comp_interp, cudaMemcpyHostToDevice)); 944 CeedCallBackend(CeedBasisGetGrad(basis, &grad)); 945 CeedCallCuda(ceed, cudaMalloc((void **)&d_grad, interp_bytes * q_comp_grad)); 946 CeedCallCuda(ceed, cudaMemcpy(d_grad, grad, interp_bytes * q_comp_grad, cudaMemcpyHostToDevice)); 947 if (in) { 948 diag->d_interp_in = d_interp; 949 diag->d_grad_in = d_grad; 950 } else { 951 diag->d_interp_out = d_interp; 952 diag->d_grad_out = d_grad; 953 } 954 } break; 955 case CEED_FE_SPACE_HDIV: { 956 CeedInt q_comp_interp, q_comp_div; 957 const CeedScalar *interp, *div; 958 CeedScalar *d_interp, *d_div; 959 960 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 961 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_DIV, &q_comp_div)); 962 963 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 964 CeedCallCuda(ceed, cudaMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 965 CeedCallCuda(ceed, cudaMemcpy(d_interp, interp, interp_bytes * q_comp_interp, cudaMemcpyHostToDevice)); 966 CeedCallBackend(CeedBasisGetDiv(basis, &div)); 967 CeedCallCuda(ceed, cudaMalloc((void **)&d_div, interp_bytes * q_comp_div)); 968 CeedCallCuda(ceed, cudaMemcpy(d_div, div, interp_bytes * q_comp_div, cudaMemcpyHostToDevice)); 969 if (in) { 970 diag->d_interp_in = d_interp; 971 diag->d_div_in = d_div; 972 } else { 973 diag->d_interp_out = d_interp; 974 diag->d_div_out = d_div; 975 } 976 } break; 977 case CEED_FE_SPACE_HCURL: { 978 CeedInt q_comp_interp, q_comp_curl; 979 const CeedScalar *interp, *curl; 980 CeedScalar *d_interp, *d_curl; 981 982 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 983 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_CURL, &q_comp_curl)); 984 985 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 986 CeedCallCuda(ceed, cudaMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 987 CeedCallCuda(ceed, cudaMemcpy(d_interp, interp, interp_bytes * q_comp_interp, cudaMemcpyHostToDevice)); 988 CeedCallBackend(CeedBasisGetCurl(basis, &curl)); 989 CeedCallCuda(ceed, cudaMalloc((void **)&d_curl, interp_bytes * q_comp_curl)); 990 CeedCallCuda(ceed, cudaMemcpy(d_curl, curl, interp_bytes * q_comp_curl, cudaMemcpyHostToDevice)); 991 if (in) { 992 diag->d_interp_in = d_interp; 993 diag->d_curl_in = d_curl; 994 } else { 995 diag->d_interp_out = d_interp; 996 diag->d_curl_out = d_curl; 997 } 998 } break; 999 } 1000 } 1001 1002 // Arrays of eval_modes 1003 CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_eval_modes_in, num_eval_modes_in * eval_modes_bytes)); 1004 CeedCallCuda(ceed, cudaMemcpy(diag->d_eval_modes_in, eval_modes_in, num_eval_modes_in * eval_modes_bytes, cudaMemcpyHostToDevice)); 1005 CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_eval_modes_out, num_eval_modes_out * eval_modes_bytes)); 1006 CeedCallCuda(ceed, cudaMemcpy(diag->d_eval_modes_out, eval_modes_out, num_eval_modes_out * eval_modes_bytes, cudaMemcpyHostToDevice)); 1007 CeedCallBackend(CeedFree(&eval_modes_in)); 1008 CeedCallBackend(CeedFree(&eval_modes_out)); 1009 return CEED_ERROR_SUCCESS; 1010 } 1011 1012 //------------------------------------------------------------------------------ 1013 // Assemble Diagonal Setup (Compilation) 1014 //------------------------------------------------------------------------------ 1015 static inline int CeedOperatorAssembleDiagonalSetupCompile_Cuda(CeedOperator op, CeedInt use_ceedsize_idx, const bool is_point_block) { 1016 Ceed ceed; 1017 char *diagonal_kernel_source; 1018 const char *diagonal_kernel_path; 1019 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 1020 CeedInt num_comp, q_comp, num_nodes, num_qpts; 1021 CeedBasis basis_in = NULL, basis_out = NULL; 1022 CeedQFunctionField *qf_fields; 1023 CeedQFunction qf; 1024 CeedOperatorField *op_fields; 1025 CeedOperator_Cuda *impl; 1026 1027 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1028 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1029 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 1030 1031 // Determine active input basis 1032 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 1033 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 1034 for (CeedInt i = 0; i < num_input_fields; i++) { 1035 CeedVector vec; 1036 1037 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 1038 if (vec == CEED_VECTOR_ACTIVE) { 1039 CeedEvalMode eval_mode; 1040 1041 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis_in)); 1042 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1043 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 1044 if (eval_mode != CEED_EVAL_WEIGHT) { 1045 num_eval_modes_in += q_comp; 1046 } 1047 } 1048 } 1049 1050 // Determine active output basis 1051 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 1052 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 1053 for (CeedInt i = 0; i < num_output_fields; i++) { 1054 CeedVector vec; 1055 1056 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 1057 if (vec == CEED_VECTOR_ACTIVE) { 1058 CeedEvalMode eval_mode; 1059 1060 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis_out)); 1061 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1062 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 1063 if (eval_mode != CEED_EVAL_WEIGHT) { 1064 num_eval_modes_out += q_comp; 1065 } 1066 } 1067 } 1068 1069 // Operator data struct 1070 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1071 CeedOperatorDiag_Cuda *diag = impl->diag; 1072 1073 // Assemble kernel 1074 CUmodule *module = is_point_block ? &diag->module_point_block : &diag->module; 1075 CeedInt elems_per_block = 1; 1076 CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes)); 1077 CeedCallBackend(CeedBasisGetNumComponents(basis_in, &num_comp)); 1078 if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes; 1079 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts)); 1080 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-operator-assemble-diagonal.h", &diagonal_kernel_path)); 1081 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Kernel Source -----\n"); 1082 CeedCallBackend(CeedLoadSourceToBuffer(ceed, diagonal_kernel_path, &diagonal_kernel_source)); 1083 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Source Complete! -----\n"); 1084 CeedCallCuda(ceed, CeedCompile_Cuda(ceed, diagonal_kernel_source, module, 8, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", 1085 num_eval_modes_out, "NUM_COMP", num_comp, "NUM_NODES", num_nodes, "NUM_QPTS", num_qpts, "USE_CEEDSIZE", 1086 use_ceedsize_idx, "USE_POINT_BLOCK", is_point_block ? 1 : 0, "BLOCK_SIZE", num_nodes * elems_per_block)); 1087 CeedCallCuda(ceed, CeedGetKernel_Cuda(ceed, *module, "LinearDiagonal", is_point_block ? &diag->LinearPointBlock : &diag->LinearDiagonal)); 1088 CeedCallBackend(CeedFree(&diagonal_kernel_path)); 1089 CeedCallBackend(CeedFree(&diagonal_kernel_source)); 1090 return CEED_ERROR_SUCCESS; 1091 } 1092 1093 //------------------------------------------------------------------------------ 1094 // Assemble Diagonal Core 1095 //------------------------------------------------------------------------------ 1096 static inline int CeedOperatorAssembleDiagonalCore_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request, const bool is_point_block) { 1097 Ceed ceed; 1098 CeedInt num_elem, num_nodes; 1099 CeedScalar *elem_diag_array; 1100 const CeedScalar *assembled_qf_array; 1101 CeedVector assembled_qf = NULL, elem_diag; 1102 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out, diag_rstr; 1103 CeedOperator_Cuda *impl; 1104 1105 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1106 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1107 1108 // Assemble QFunction 1109 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, request)); 1110 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr)); 1111 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); 1112 1113 // Setup 1114 if (!impl->diag) CeedCallBackend(CeedOperatorAssembleDiagonalSetup_Cuda(op)); 1115 CeedOperatorDiag_Cuda *diag = impl->diag; 1116 1117 assert(diag != NULL); 1118 1119 // Assemble kernel if needed 1120 if ((!is_point_block && !diag->LinearDiagonal) || (is_point_block && !diag->LinearPointBlock)) { 1121 CeedSize assembled_length, assembled_qf_length; 1122 CeedInt use_ceedsize_idx = 0; 1123 CeedCallBackend(CeedVectorGetLength(assembled, &assembled_length)); 1124 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length)); 1125 if ((assembled_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1; 1126 1127 CeedCallBackend(CeedOperatorAssembleDiagonalSetupCompile_Cuda(op, use_ceedsize_idx, is_point_block)); 1128 } 1129 1130 // Restriction and diagonal vector 1131 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out)); 1132 CeedCheck(rstr_in == rstr_out, ceed, CEED_ERROR_BACKEND, 1133 "Cannot assemble operator diagonal with different input and output active element restrictions"); 1134 if (!is_point_block && !diag->diag_rstr) { 1135 CeedCallBackend(CeedElemRestrictionCreateUnsignedCopy(rstr_out, &diag->diag_rstr)); 1136 CeedCallBackend(CeedElemRestrictionCreateVector(diag->diag_rstr, NULL, &diag->elem_diag)); 1137 } else if (is_point_block && !diag->point_block_diag_rstr) { 1138 CeedCallBackend(CeedOperatorCreateActivePointBlockRestriction(rstr_out, &diag->point_block_diag_rstr)); 1139 CeedCallBackend(CeedElemRestrictionCreateVector(diag->point_block_diag_rstr, NULL, &diag->point_block_elem_diag)); 1140 } 1141 diag_rstr = is_point_block ? diag->point_block_diag_rstr : diag->diag_rstr; 1142 elem_diag = is_point_block ? diag->point_block_elem_diag : diag->elem_diag; 1143 CeedCallBackend(CeedVectorSetValue(elem_diag, 0.0)); 1144 1145 // Only assemble diagonal if the basis has nodes, otherwise inputs are null pointers 1146 CeedCallBackend(CeedElemRestrictionGetElementSize(diag_rstr, &num_nodes)); 1147 if (num_nodes > 0) { 1148 // Assemble element operator diagonals 1149 CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem)); 1150 CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array)); 1151 1152 // Compute the diagonal of B^T D B 1153 CeedInt elems_per_block = 1; 1154 CeedInt grid = CeedDivUpInt(num_elem, elems_per_block); 1155 void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_div_in, 1156 &diag->d_curl_in, &diag->d_interp_out, &diag->d_grad_out, &diag->d_div_out, &diag->d_curl_out, 1157 &diag->d_eval_modes_in, &diag->d_eval_modes_out, &assembled_qf_array, &elem_diag_array}; 1158 1159 if (is_point_block) { 1160 CeedCallBackend(CeedRunKernelDim_Cuda(ceed, diag->LinearPointBlock, grid, num_nodes, 1, elems_per_block, args)); 1161 } else { 1162 CeedCallBackend(CeedRunKernelDim_Cuda(ceed, diag->LinearDiagonal, grid, num_nodes, 1, elems_per_block, args)); 1163 } 1164 1165 // Restore arrays 1166 CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array)); 1167 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); 1168 } 1169 1170 // Assemble local operator diagonal 1171 CeedCallBackend(CeedElemRestrictionApply(diag_rstr, CEED_TRANSPOSE, elem_diag, assembled, request)); 1172 1173 // Cleanup 1174 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1175 return CEED_ERROR_SUCCESS; 1176 } 1177 1178 //------------------------------------------------------------------------------ 1179 // Assemble Linear Diagonal 1180 //------------------------------------------------------------------------------ 1181 static int CeedOperatorLinearAssembleAddDiagonal_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1182 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Cuda(op, assembled, request, false)); 1183 return CEED_ERROR_SUCCESS; 1184 } 1185 1186 //------------------------------------------------------------------------------ 1187 // Assemble Linear Point Block Diagonal 1188 //------------------------------------------------------------------------------ 1189 static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1190 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Cuda(op, assembled, request, true)); 1191 return CEED_ERROR_SUCCESS; 1192 } 1193 1194 //------------------------------------------------------------------------------ 1195 // Single Operator Assembly Setup 1196 //------------------------------------------------------------------------------ 1197 static int CeedSingleOperatorAssembleSetup_Cuda(CeedOperator op, CeedInt use_ceedsize_idx) { 1198 Ceed ceed; 1199 Ceed_Cuda *cuda_data; 1200 char *assembly_kernel_source; 1201 const char *assembly_kernel_path; 1202 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 1203 CeedInt elem_size_in, num_qpts_in = 0, num_comp_in, elem_size_out, num_qpts_out, num_comp_out, q_comp; 1204 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL; 1205 CeedElemRestriction rstr_in = NULL, rstr_out = NULL; 1206 CeedBasis basis_in = NULL, basis_out = NULL; 1207 CeedQFunctionField *qf_fields; 1208 CeedQFunction qf; 1209 CeedOperatorField *input_fields, *output_fields; 1210 CeedOperator_Cuda *impl; 1211 1212 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1213 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1214 1215 // Get intput and output fields 1216 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields)); 1217 1218 // Determine active input basis eval mode 1219 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1220 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 1221 for (CeedInt i = 0; i < num_input_fields; i++) { 1222 CeedVector vec; 1223 1224 CeedCallBackend(CeedOperatorFieldGetVector(input_fields[i], &vec)); 1225 if (vec == CEED_VECTOR_ACTIVE) { 1226 CeedBasis basis; 1227 CeedEvalMode eval_mode; 1228 1229 CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis)); 1230 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases"); 1231 basis_in = basis; 1232 CeedCallBackend(CeedOperatorFieldGetElemRestriction(input_fields[i], &rstr_in)); 1233 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in)); 1234 if (basis_in == CEED_BASIS_NONE) num_qpts_in = elem_size_in; 1235 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts_in)); 1236 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1237 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 1238 if (eval_mode != CEED_EVAL_WEIGHT) { 1239 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly 1240 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in)); 1241 for (CeedInt d = 0; d < q_comp; d++) { 1242 eval_modes_in[num_eval_modes_in + d] = eval_mode; 1243 } 1244 num_eval_modes_in += q_comp; 1245 } 1246 } 1247 } 1248 1249 // Determine active output basis; basis_out and rstr_out only used if same as input, TODO 1250 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 1251 for (CeedInt i = 0; i < num_output_fields; i++) { 1252 CeedVector vec; 1253 1254 CeedCallBackend(CeedOperatorFieldGetVector(output_fields[i], &vec)); 1255 if (vec == CEED_VECTOR_ACTIVE) { 1256 CeedBasis basis; 1257 CeedEvalMode eval_mode; 1258 1259 CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis)); 1260 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND, 1261 "Backend does not implement operator assembly with multiple active bases"); 1262 basis_out = basis; 1263 CeedCallBackend(CeedOperatorFieldGetElemRestriction(output_fields[i], &rstr_out)); 1264 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out)); 1265 if (basis_out == CEED_BASIS_NONE) num_qpts_out = elem_size_out; 1266 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_out, &num_qpts_out)); 1267 CeedCheck(num_qpts_in == num_qpts_out, ceed, CEED_ERROR_UNSUPPORTED, 1268 "Active input and output bases must have the same number of quadrature points"); 1269 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1270 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 1271 if (eval_mode != CEED_EVAL_WEIGHT) { 1272 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly 1273 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out)); 1274 for (CeedInt d = 0; d < q_comp; d++) { 1275 eval_modes_out[num_eval_modes_out + d] = eval_mode; 1276 } 1277 num_eval_modes_out += q_comp; 1278 } 1279 } 1280 } 1281 CeedCheck(num_eval_modes_in > 0 && num_eval_modes_out > 0, ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs"); 1282 1283 CeedCallBackend(CeedCalloc(1, &impl->asmb)); 1284 CeedOperatorAssemble_Cuda *asmb = impl->asmb; 1285 asmb->elems_per_block = 1; 1286 asmb->block_size_x = elem_size_in; 1287 asmb->block_size_y = elem_size_out; 1288 1289 CeedCallBackend(CeedGetData(ceed, &cuda_data)); 1290 bool fallback = asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block > cuda_data->device_prop.maxThreadsPerBlock; 1291 1292 if (fallback) { 1293 // Use fallback kernel with 1D threadblock 1294 asmb->block_size_y = 1; 1295 } 1296 1297 // Compile kernels 1298 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &num_comp_in)); 1299 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_out, &num_comp_out)); 1300 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-operator-assemble.h", &assembly_kernel_path)); 1301 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Kernel Source -----\n"); 1302 CeedCallBackend(CeedLoadSourceToBuffer(ceed, assembly_kernel_path, &assembly_kernel_source)); 1303 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Source Complete! -----\n"); 1304 CeedCallBackend(CeedCompile_Cuda(ceed, assembly_kernel_source, &asmb->module, 10, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", 1305 num_eval_modes_out, "NUM_COMP_IN", num_comp_in, "NUM_COMP_OUT", num_comp_out, "NUM_NODES_IN", elem_size_in, 1306 "NUM_NODES_OUT", elem_size_out, "NUM_QPTS", num_qpts_in, "BLOCK_SIZE", 1307 asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block, "BLOCK_SIZE_Y", asmb->block_size_y, 1308 "USE_CEEDSIZE", use_ceedsize_idx)); 1309 CeedCallBackend(CeedGetKernel_Cuda(ceed, asmb->module, "LinearAssemble", &asmb->LinearAssemble)); 1310 CeedCallBackend(CeedFree(&assembly_kernel_path)); 1311 CeedCallBackend(CeedFree(&assembly_kernel_source)); 1312 1313 // Load into B_in, in order that they will be used in eval_modes_in 1314 { 1315 const CeedInt in_bytes = elem_size_in * num_qpts_in * num_eval_modes_in * sizeof(CeedScalar); 1316 CeedInt d_in = 0; 1317 CeedEvalMode eval_modes_in_prev = CEED_EVAL_NONE; 1318 bool has_eval_none = false; 1319 CeedScalar *identity = NULL; 1320 1321 for (CeedInt i = 0; i < num_eval_modes_in; i++) { 1322 has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE); 1323 } 1324 if (has_eval_none) { 1325 CeedCallBackend(CeedCalloc(elem_size_in * num_qpts_in, &identity)); 1326 for (CeedInt i = 0; i < (elem_size_in < num_qpts_in ? elem_size_in : num_qpts_in); i++) identity[i * elem_size_in + i] = 1.0; 1327 } 1328 1329 CeedCallCuda(ceed, cudaMalloc((void **)&asmb->d_B_in, in_bytes)); 1330 for (CeedInt i = 0; i < num_eval_modes_in; i++) { 1331 const CeedScalar *h_B_in; 1332 1333 CeedCallBackend(CeedOperatorGetBasisPointer(basis_in, eval_modes_in[i], identity, &h_B_in)); 1334 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_modes_in[i], &q_comp)); 1335 if (q_comp > 1) { 1336 if (i == 0 || eval_modes_in[i] != eval_modes_in_prev) d_in = 0; 1337 else h_B_in = &h_B_in[(++d_in) * elem_size_in * num_qpts_in]; 1338 } 1339 eval_modes_in_prev = eval_modes_in[i]; 1340 1341 CeedCallCuda(ceed, cudaMemcpy(&asmb->d_B_in[i * elem_size_in * num_qpts_in], h_B_in, elem_size_in * num_qpts_in * sizeof(CeedScalar), 1342 cudaMemcpyHostToDevice)); 1343 } 1344 1345 if (identity) { 1346 CeedCallBackend(CeedFree(&identity)); 1347 } 1348 } 1349 1350 // Load into B_out, in order that they will be used in eval_modes_out 1351 { 1352 const CeedInt out_bytes = elem_size_out * num_qpts_out * num_eval_modes_out * sizeof(CeedScalar); 1353 CeedInt d_out = 0; 1354 CeedEvalMode eval_modes_out_prev = CEED_EVAL_NONE; 1355 bool has_eval_none = false; 1356 CeedScalar *identity = NULL; 1357 1358 for (CeedInt i = 0; i < num_eval_modes_out; i++) { 1359 has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE); 1360 } 1361 if (has_eval_none) { 1362 CeedCallBackend(CeedCalloc(elem_size_out * num_qpts_out, &identity)); 1363 for (CeedInt i = 0; i < (elem_size_out < num_qpts_out ? elem_size_out : num_qpts_out); i++) identity[i * elem_size_out + i] = 1.0; 1364 } 1365 1366 CeedCallCuda(ceed, cudaMalloc((void **)&asmb->d_B_out, out_bytes)); 1367 for (CeedInt i = 0; i < num_eval_modes_out; i++) { 1368 const CeedScalar *h_B_out; 1369 1370 CeedCallBackend(CeedOperatorGetBasisPointer(basis_out, eval_modes_out[i], identity, &h_B_out)); 1371 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_modes_out[i], &q_comp)); 1372 if (q_comp > 1) { 1373 if (i == 0 || eval_modes_out[i] != eval_modes_out_prev) d_out = 0; 1374 else h_B_out = &h_B_out[(++d_out) * elem_size_out * num_qpts_out]; 1375 } 1376 eval_modes_out_prev = eval_modes_out[i]; 1377 1378 CeedCallCuda(ceed, cudaMemcpy(&asmb->d_B_out[i * elem_size_out * num_qpts_out], h_B_out, elem_size_out * num_qpts_out * sizeof(CeedScalar), 1379 cudaMemcpyHostToDevice)); 1380 } 1381 1382 if (identity) { 1383 CeedCallBackend(CeedFree(&identity)); 1384 } 1385 } 1386 return CEED_ERROR_SUCCESS; 1387 } 1388 1389 //------------------------------------------------------------------------------ 1390 // Assemble matrix data for COO matrix of assembled operator. 1391 // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic. 1392 // 1393 // Note that this (and other assembly routines) currently assume only one active input restriction/basis per operator (could have multiple basis eval 1394 // modes). 1395 // TODO: allow multiple active input restrictions/basis objects 1396 //------------------------------------------------------------------------------ 1397 static int CeedSingleOperatorAssemble_Cuda(CeedOperator op, CeedInt offset, CeedVector values) { 1398 Ceed ceed; 1399 CeedSize values_length = 0, assembled_qf_length = 0; 1400 CeedInt use_ceedsize_idx = 0, num_elem_in, num_elem_out, elem_size_in, elem_size_out; 1401 CeedScalar *values_array; 1402 const CeedScalar *assembled_qf_array; 1403 CeedVector assembled_qf = NULL; 1404 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out; 1405 CeedRestrictionType rstr_type_in, rstr_type_out; 1406 const bool *orients_in = NULL, *orients_out = NULL; 1407 const CeedInt8 *curl_orients_in = NULL, *curl_orients_out = NULL; 1408 CeedOperator_Cuda *impl; 1409 1410 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1411 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1412 1413 // Assemble QFunction 1414 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, CEED_REQUEST_IMMEDIATE)); 1415 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr)); 1416 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); 1417 1418 CeedCallBackend(CeedVectorGetLength(values, &values_length)); 1419 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length)); 1420 if ((values_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1; 1421 1422 // Setup 1423 if (!impl->asmb) CeedCallBackend(CeedSingleOperatorAssembleSetup_Cuda(op, use_ceedsize_idx)); 1424 CeedOperatorAssemble_Cuda *asmb = impl->asmb; 1425 1426 assert(asmb != NULL); 1427 1428 // Assemble element operator 1429 CeedCallBackend(CeedVectorGetArray(values, CEED_MEM_DEVICE, &values_array)); 1430 values_array += offset; 1431 1432 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out)); 1433 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, &num_elem_in)); 1434 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in)); 1435 1436 CeedCallBackend(CeedElemRestrictionGetType(rstr_in, &rstr_type_in)); 1437 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) { 1438 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_in, CEED_MEM_DEVICE, &orients_in)); 1439 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) { 1440 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_in, CEED_MEM_DEVICE, &curl_orients_in)); 1441 } 1442 1443 if (rstr_in != rstr_out) { 1444 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_out, &num_elem_out)); 1445 CeedCheck(num_elem_in == num_elem_out, ceed, CEED_ERROR_UNSUPPORTED, 1446 "Active input and output operator restrictions must have the same number of elements"); 1447 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out)); 1448 1449 CeedCallBackend(CeedElemRestrictionGetType(rstr_out, &rstr_type_out)); 1450 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) { 1451 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_out, CEED_MEM_DEVICE, &orients_out)); 1452 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) { 1453 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_out, CEED_MEM_DEVICE, &curl_orients_out)); 1454 } 1455 } else { 1456 elem_size_out = elem_size_in; 1457 orients_out = orients_in; 1458 curl_orients_out = curl_orients_in; 1459 } 1460 1461 // Compute B^T D B 1462 CeedInt shared_mem = 1463 ((curl_orients_in || curl_orients_out ? elem_size_in * elem_size_out : 0) + (curl_orients_in ? elem_size_in * asmb->block_size_y : 0)) * 1464 sizeof(CeedScalar); 1465 CeedInt grid = CeedDivUpInt(num_elem_in, asmb->elems_per_block); 1466 void *args[] = {(void *)&num_elem_in, &asmb->d_B_in, &asmb->d_B_out, &orients_in, &curl_orients_in, 1467 &orients_out, &curl_orients_out, &assembled_qf_array, &values_array}; 1468 1469 CeedCallBackend( 1470 CeedRunKernelDimShared_Cuda(ceed, asmb->LinearAssemble, grid, asmb->block_size_x, asmb->block_size_y, asmb->elems_per_block, shared_mem, args)); 1471 1472 // Restore arrays 1473 CeedCallBackend(CeedVectorRestoreArray(values, &values_array)); 1474 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); 1475 1476 // Cleanup 1477 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1478 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) { 1479 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_in, &orients_in)); 1480 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) { 1481 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_in, &curl_orients_in)); 1482 } 1483 if (rstr_in != rstr_out) { 1484 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) { 1485 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_out, &orients_out)); 1486 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) { 1487 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_out, &curl_orients_out)); 1488 } 1489 } 1490 return CEED_ERROR_SUCCESS; 1491 } 1492 1493 //------------------------------------------------------------------------------ 1494 // Assemble Linear QFunction AtPoints 1495 //------------------------------------------------------------------------------ 1496 static int CeedOperatorLinearAssembleQFunctionAtPoints_Cuda(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 1497 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "Backend does not implement CeedOperatorLinearAssembleQFunction"); 1498 } 1499 1500 //------------------------------------------------------------------------------ 1501 // Assemble Linear Diagonal AtPoints 1502 //------------------------------------------------------------------------------ 1503 static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1504 CeedInt max_num_points, num_elem, num_input_fields, num_output_fields; 1505 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 1506 CeedQFunctionField *qf_input_fields, *qf_output_fields; 1507 CeedQFunction qf; 1508 CeedOperatorField *op_input_fields, *op_output_fields; 1509 CeedOperator_Cuda *impl; 1510 1511 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1512 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1513 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 1514 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 1515 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 1516 CeedInt num_points[num_elem]; 1517 1518 // Setup 1519 CeedCallBackend(CeedOperatorSetupAtPoints_Cuda(op)); 1520 max_num_points = impl->max_num_points; 1521 for (CeedInt i = 0; i < num_elem; i++) num_points[i] = max_num_points; 1522 1523 // Input Evecs and Restriction 1524 CeedCallBackend(CeedOperatorSetupInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 1525 1526 // Get point coordinates 1527 if (!impl->point_coords_elem) { 1528 CeedVector point_coords = NULL; 1529 CeedElemRestriction rstr_points = NULL; 1530 1531 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 1532 CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem)); 1533 CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 1534 } 1535 1536 // Clear active input Qvecs 1537 for (CeedInt i = 0; i < num_input_fields; i++) { 1538 CeedVector vec; 1539 1540 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1541 if (vec != CEED_VECTOR_ACTIVE) continue; 1542 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 1543 } 1544 1545 // Input basis apply if needed 1546 CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(num_elem, num_points, qf_input_fields, op_input_fields, num_input_fields, true, e_data, impl)); 1547 1548 // Output pointers, as necessary 1549 for (CeedInt i = 0; i < num_output_fields; i++) { 1550 CeedEvalMode eval_mode; 1551 1552 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1553 if (eval_mode == CEED_EVAL_NONE) { 1554 // Set the output Q-Vector to use the E-Vector data directly. 1555 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 1556 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 1557 } 1558 } 1559 1560 // Loop over active fields 1561 for (CeedInt i = 0; i < num_input_fields; i++) { 1562 bool is_active_at_points = true; 1563 CeedInt elem_size = 1, num_comp_active = 1, e_vec_size = 0; 1564 CeedRestrictionType rstr_type; 1565 CeedVector vec; 1566 CeedElemRestriction elem_rstr; 1567 1568 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1569 // -- Skip non-active input 1570 if (vec != CEED_VECTOR_ACTIVE) continue; 1571 1572 // -- Get active restriction type 1573 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 1574 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 1575 is_active_at_points = rstr_type == CEED_RESTRICTION_POINTS; 1576 if (!is_active_at_points) CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 1577 else elem_size = max_num_points; 1578 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp_active)); 1579 1580 e_vec_size = elem_size * num_comp_active; 1581 for (CeedInt s = 0; s < e_vec_size; s++) { 1582 bool is_active_input = false; 1583 CeedEvalMode eval_mode; 1584 CeedVector vec; 1585 CeedBasis basis; 1586 1587 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1588 // Skip non-active input 1589 is_active_input = vec == CEED_VECTOR_ACTIVE; 1590 if (!is_active_input) continue; 1591 1592 // Update unit vector 1593 if (s == 0) CeedCallBackend(CeedVectorSetValue(impl->e_vecs[i], 0.0)); 1594 else CeedCallBackend(CeedVectorSetValueStrided(impl->e_vecs[i], s - 1, e_vec_size, 0.0)); 1595 CeedCallBackend(CeedVectorSetValueStrided(impl->e_vecs[i], s, e_vec_size, 1.0)); 1596 1597 // Basis action 1598 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 1599 switch (eval_mode) { 1600 case CEED_EVAL_NONE: 1601 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 1602 break; 1603 case CEED_EVAL_INTERP: 1604 case CEED_EVAL_GRAD: 1605 case CEED_EVAL_DIV: 1606 case CEED_EVAL_CURL: 1607 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 1608 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, impl->e_vecs[i], 1609 impl->q_vecs_in[i])); 1610 break; 1611 case CEED_EVAL_WEIGHT: 1612 break; // No action 1613 } 1614 1615 // Q function 1616 CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out)); 1617 1618 // Output basis apply if needed 1619 for (CeedInt j = 0; j < num_output_fields; j++) { 1620 bool is_active_output = false; 1621 CeedInt elem_size = 0; 1622 CeedRestrictionType rstr_type; 1623 CeedEvalMode eval_mode; 1624 CeedVector vec; 1625 CeedElemRestriction elem_rstr; 1626 CeedBasis basis; 1627 1628 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec)); 1629 // ---- Skip non-active output 1630 is_active_output = vec == CEED_VECTOR_ACTIVE; 1631 if (!is_active_output) continue; 1632 1633 // ---- Check if elem size matches 1634 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr)); 1635 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 1636 if (is_active_at_points && rstr_type != CEED_RESTRICTION_POINTS) continue; 1637 if (rstr_type == CEED_RESTRICTION_POINTS) { 1638 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(elem_rstr, &elem_size)); 1639 } else { 1640 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 1641 } 1642 { 1643 CeedInt num_comp = 0; 1644 1645 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp)); 1646 if (e_vec_size != num_comp * elem_size) continue; 1647 } 1648 1649 // Basis action 1650 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode)); 1651 switch (eval_mode) { 1652 case CEED_EVAL_NONE: 1653 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[j + impl->num_inputs], &e_data[j + num_input_fields])); 1654 break; 1655 case CEED_EVAL_INTERP: 1656 case CEED_EVAL_GRAD: 1657 case CEED_EVAL_DIV: 1658 case CEED_EVAL_CURL: 1659 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[j], &basis)); 1660 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, 1661 impl->q_vecs_out[j], impl->e_vecs[j + impl->num_inputs])); 1662 break; 1663 // LCOV_EXCL_START 1664 case CEED_EVAL_WEIGHT: { 1665 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 1666 // LCOV_EXCL_STOP 1667 } 1668 } 1669 1670 // Mask output e-vec 1671 CeedCallBackend(CeedVectorPointwiseMult(impl->e_vecs[j + impl->num_inputs], impl->e_vecs[i], impl->e_vecs[j + impl->num_inputs])); 1672 1673 // Restrict 1674 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr)); 1675 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[j + impl->num_inputs], assembled, request)); 1676 1677 // Reset q_vec for 1678 if (eval_mode == CEED_EVAL_NONE) { 1679 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[j + impl->num_inputs], CEED_MEM_DEVICE, &e_data[j + num_input_fields])); 1680 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[j], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[j + num_input_fields])); 1681 } 1682 } 1683 1684 // Reset vec 1685 if (s == e_vec_size - 1 && i != num_input_fields - 1) CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 1686 } 1687 } 1688 1689 // Restore CEED_EVAL_NONE 1690 for (CeedInt i = 0; i < num_output_fields; i++) { 1691 CeedEvalMode eval_mode; 1692 1693 // Get eval_mode 1694 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1695 1696 // Restore evec 1697 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1698 if (eval_mode == CEED_EVAL_NONE) { 1699 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 1700 } 1701 } 1702 1703 // Restore input arrays 1704 CeedCallBackend(CeedOperatorRestoreInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 1705 return CEED_ERROR_SUCCESS; 1706 } 1707 1708 //------------------------------------------------------------------------------ 1709 // Create operator 1710 //------------------------------------------------------------------------------ 1711 int CeedOperatorCreate_Cuda(CeedOperator op) { 1712 Ceed ceed; 1713 CeedOperator_Cuda *impl; 1714 1715 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1716 CeedCallBackend(CeedCalloc(1, &impl)); 1717 CeedCallBackend(CeedOperatorSetData(op, impl)); 1718 1719 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Cuda)); 1720 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Cuda)); 1721 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonal_Cuda)); 1722 CeedCallBackend( 1723 CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal", CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda)); 1724 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssemble_Cuda)); 1725 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Cuda)); 1726 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Cuda)); 1727 return CEED_ERROR_SUCCESS; 1728 } 1729 1730 //------------------------------------------------------------------------------ 1731 // Create operator AtPoints 1732 //------------------------------------------------------------------------------ 1733 int CeedOperatorCreateAtPoints_Cuda(CeedOperator op) { 1734 Ceed ceed; 1735 CeedOperator_Cuda *impl; 1736 1737 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1738 CeedCallBackend(CeedCalloc(1, &impl)); 1739 CeedCallBackend(CeedOperatorSetData(op, impl)); 1740 1741 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunctionAtPoints_Cuda)); 1742 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda)); 1743 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAddAtPoints_Cuda)); 1744 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Cuda)); 1745 return CEED_ERROR_SUCCESS; 1746 } 1747 1748 //------------------------------------------------------------------------------ 1749