1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <assert.h> 12 #include <cuda.h> 13 #include <cuda_runtime.h> 14 #include <stdbool.h> 15 #include <string.h> 16 17 #include "../cuda/ceed-cuda-common.h" 18 #include "../cuda/ceed-cuda-compile.h" 19 #include "ceed-cuda-ref.h" 20 21 //------------------------------------------------------------------------------ 22 // Destroy operator 23 //------------------------------------------------------------------------------ 24 static int CeedOperatorDestroy_Cuda(CeedOperator op) { 25 CeedOperator_Cuda *impl; 26 27 CeedCallBackend(CeedOperatorGetData(op, &impl)); 28 29 // Apply data 30 CeedCallBackend(CeedFree(&impl->skip_rstr_in)); 31 CeedCallBackend(CeedFree(&impl->skip_rstr_out)); 32 CeedCallBackend(CeedFree(&impl->apply_add_basis_out)); 33 for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { 34 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[i])); 35 } 36 CeedCallBackend(CeedFree(&impl->e_vecs)); 37 CeedCallBackend(CeedFree(&impl->input_states)); 38 39 for (CeedInt i = 0; i < impl->num_inputs; i++) { 40 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i])); 41 } 42 CeedCallBackend(CeedFree(&impl->q_vecs_in)); 43 44 for (CeedInt i = 0; i < impl->num_outputs; i++) { 45 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 46 } 47 CeedCallBackend(CeedFree(&impl->q_vecs_out)); 48 CeedCallBackend(CeedVectorDestroy(&impl->point_coords_elem)); 49 50 // QFunction assembly data 51 for (CeedInt i = 0; i < impl->num_active_in; i++) { 52 CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i])); 53 } 54 CeedCallBackend(CeedFree(&impl->qf_active_in)); 55 56 // Diag data 57 if (impl->diag) { 58 Ceed ceed; 59 60 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 61 if (impl->diag->module) { 62 CeedCallCuda(ceed, cuModuleUnload(impl->diag->module)); 63 } 64 if (impl->diag->module_point_block) { 65 CeedCallCuda(ceed, cuModuleUnload(impl->diag->module_point_block)); 66 } 67 CeedCallCuda(ceed, cudaFree(impl->diag->d_eval_modes_in)); 68 CeedCallCuda(ceed, cudaFree(impl->diag->d_eval_modes_out)); 69 CeedCallCuda(ceed, cudaFree(impl->diag->d_identity)); 70 CeedCallCuda(ceed, cudaFree(impl->diag->d_interp_in)); 71 CeedCallCuda(ceed, cudaFree(impl->diag->d_interp_out)); 72 CeedCallCuda(ceed, cudaFree(impl->diag->d_grad_in)); 73 CeedCallCuda(ceed, cudaFree(impl->diag->d_grad_out)); 74 CeedCallCuda(ceed, cudaFree(impl->diag->d_div_in)); 75 CeedCallCuda(ceed, cudaFree(impl->diag->d_div_out)); 76 CeedCallCuda(ceed, cudaFree(impl->diag->d_curl_in)); 77 CeedCallCuda(ceed, cudaFree(impl->diag->d_curl_out)); 78 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr)); 79 CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr)); 80 CeedCallBackend(CeedVectorDestroy(&impl->diag->elem_diag)); 81 CeedCallBackend(CeedVectorDestroy(&impl->diag->point_block_elem_diag)); 82 } 83 CeedCallBackend(CeedFree(&impl->diag)); 84 85 if (impl->asmb) { 86 Ceed ceed; 87 88 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 89 CeedCallCuda(ceed, cuModuleUnload(impl->asmb->module)); 90 CeedCallCuda(ceed, cudaFree(impl->asmb->d_B_in)); 91 CeedCallCuda(ceed, cudaFree(impl->asmb->d_B_out)); 92 } 93 CeedCallBackend(CeedFree(&impl->asmb)); 94 95 CeedCallBackend(CeedFree(&impl)); 96 return CEED_ERROR_SUCCESS; 97 } 98 99 //------------------------------------------------------------------------------ 100 // Setup infields or outfields 101 //------------------------------------------------------------------------------ 102 static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, bool *apply_add_basis, 103 CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) { 104 Ceed ceed; 105 CeedQFunctionField *qf_fields; 106 CeedOperatorField *op_fields; 107 108 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 109 if (is_input) { 110 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 111 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 112 } else { 113 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 114 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 115 } 116 117 // Loop over fields 118 for (CeedInt i = 0; i < num_fields; i++) { 119 bool is_strided = false, skip_restriction = false; 120 CeedSize q_size; 121 CeedInt size; 122 CeedEvalMode eval_mode; 123 CeedBasis basis; 124 125 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 126 if (eval_mode != CEED_EVAL_WEIGHT) { 127 CeedElemRestriction elem_rstr; 128 129 // Check whether this field can skip the element restriction: 130 // Must be passive input, with eval_mode NONE, and have a strided restriction with CEED_STRIDES_BACKEND. 131 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); 132 133 // First, check whether the field is input or output: 134 if (is_input) { 135 CeedVector vec; 136 137 // Check for passive input 138 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 139 if (vec != CEED_VECTOR_ACTIVE) { 140 // Check eval_mode 141 if (eval_mode == CEED_EVAL_NONE) { 142 // Check for strided restriction 143 CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided)); 144 if (is_strided) { 145 // Check if vector is already in preferred backend ordering 146 CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &skip_restriction)); 147 } 148 } 149 } 150 } 151 if (skip_restriction) { 152 // We do not need an E-Vector, but will use the input field vector's data directly in the operator application. 153 e_vecs[i + start_e] = NULL; 154 } else { 155 CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs[i + start_e])); 156 } 157 } 158 159 switch (eval_mode) { 160 case CEED_EVAL_NONE: 161 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 162 q_size = (CeedSize)num_elem * Q * size; 163 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 164 break; 165 case CEED_EVAL_INTERP: 166 case CEED_EVAL_GRAD: 167 case CEED_EVAL_DIV: 168 case CEED_EVAL_CURL: 169 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 170 q_size = (CeedSize)num_elem * Q * size; 171 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 172 break; 173 case CEED_EVAL_WEIGHT: // Only on input fields 174 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 175 q_size = (CeedSize)num_elem * Q; 176 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 177 if (is_at_points) { 178 CeedInt num_points[num_elem]; 179 180 for (CeedInt i = 0; i < num_elem; i++) num_points[i] = Q; 181 CeedCallBackend( 182 CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, CEED_VECTOR_NONE, q_vecs[i])); 183 } else { 184 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); 185 } 186 break; 187 } 188 } 189 // Drop duplicate restrictions 190 if (is_input) { 191 for (CeedInt i = 0; i < num_fields; i++) { 192 CeedVector vec_i; 193 CeedElemRestriction rstr_i; 194 195 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); 196 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); 197 for (CeedInt j = i + 1; j < num_fields; j++) { 198 CeedVector vec_j; 199 CeedElemRestriction rstr_j; 200 201 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); 202 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); 203 if (vec_i == vec_j && rstr_i == rstr_j) { 204 CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i + start_e], &e_vecs[j + start_e])); 205 skip_rstr[j] = true; 206 } 207 } 208 } 209 } else { 210 for (CeedInt i = num_fields - 1; i >= 0; i--) { 211 CeedVector vec_i; 212 CeedElemRestriction rstr_i; 213 214 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); 215 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); 216 for (CeedInt j = i - 1; j >= 0; j--) { 217 CeedVector vec_j; 218 CeedElemRestriction rstr_j; 219 220 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); 221 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); 222 if (vec_i == vec_j && rstr_i == rstr_j) { 223 CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i + start_e], &e_vecs[j + start_e])); 224 skip_rstr[j] = true; 225 apply_add_basis[i] = true; 226 } 227 } 228 } 229 } 230 return CEED_ERROR_SUCCESS; 231 } 232 233 //------------------------------------------------------------------------------ 234 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction. 235 //------------------------------------------------------------------------------ 236 static int CeedOperatorSetup_Cuda(CeedOperator op) { 237 Ceed ceed; 238 bool is_setup_done; 239 CeedInt Q, num_elem, num_input_fields, num_output_fields; 240 CeedQFunctionField *qf_input_fields, *qf_output_fields; 241 CeedQFunction qf; 242 CeedOperatorField *op_input_fields, *op_output_fields; 243 CeedOperator_Cuda *impl; 244 245 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 246 if (is_setup_done) return CEED_ERROR_SUCCESS; 247 248 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 249 CeedCallBackend(CeedOperatorGetData(op, &impl)); 250 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 251 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 252 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 253 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 254 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 255 256 // Allocate 257 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); 258 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); 259 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out)); 260 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out)); 261 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 262 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 263 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 264 impl->num_inputs = num_input_fields; 265 impl->num_outputs = num_output_fields; 266 267 // Set up infield and outfield e-vecs and q-vecs 268 // Infields 269 CeedCallBackend( 270 CeedOperatorSetupFields_Cuda(qf, op, true, false, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem)); 271 // Outfields 272 CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, false, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out, 273 num_input_fields, num_output_fields, Q, num_elem)); 274 275 // Reuse active e-vecs where able 276 { 277 CeedInt num_used = 0; 278 CeedElemRestriction *rstr_used = NULL; 279 280 for (CeedInt i = 0; i < num_input_fields; i++) { 281 bool is_used = false; 282 CeedVector vec_i; 283 CeedElemRestriction rstr_i; 284 285 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i)); 286 if (vec_i != CEED_VECTOR_ACTIVE) continue; 287 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr_i)); 288 for (CeedInt j = 0; j < num_used; j++) { 289 if (rstr_i == rstr_used[i]) is_used = true; 290 } 291 if (is_used) continue; 292 num_used++; 293 if (num_used == 1) CeedCallBackend(CeedCalloc(num_used, &rstr_used)); 294 else CeedCallBackend(CeedRealloc(num_used, &rstr_used)); 295 rstr_used[num_used - 1] = rstr_i; 296 for (CeedInt j = num_output_fields - 1; j >= 0; j--) { 297 CeedEvalMode eval_mode; 298 CeedVector vec_j; 299 CeedElemRestriction rstr_j; 300 301 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec_j)); 302 if (vec_j != CEED_VECTOR_ACTIVE) continue; 303 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode)); 304 if (eval_mode == CEED_EVAL_NONE) continue; 305 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &rstr_j)); 306 if (rstr_i == rstr_j) { 307 CeedCallBackend(CeedVectorReferenceCopy(impl->e_vecs[i], &impl->e_vecs[j + impl->num_inputs])); 308 } 309 } 310 } 311 CeedCallBackend(CeedFree(&rstr_used)); 312 } 313 impl->has_shared_e_vecs = true; 314 CeedCallBackend(CeedOperatorSetSetupDone(op)); 315 return CEED_ERROR_SUCCESS; 316 } 317 318 //------------------------------------------------------------------------------ 319 // Setup Operator Inputs 320 //------------------------------------------------------------------------------ 321 static inline int CeedOperatorSetupInputs_Cuda(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 322 CeedVector in_vec, const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], 323 CeedOperator_Cuda *impl, CeedRequest *request) { 324 for (CeedInt i = 0; i < num_input_fields; i++) { 325 CeedEvalMode eval_mode; 326 CeedVector vec; 327 CeedElemRestriction elem_rstr; 328 329 // Get input vector 330 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 331 if (vec == CEED_VECTOR_ACTIVE) { 332 if (skip_active) continue; 333 else vec = in_vec; 334 } 335 336 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 337 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 338 } else { 339 // Get input vector 340 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 341 // Get input element restriction 342 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 343 if (vec == CEED_VECTOR_ACTIVE) vec = in_vec; 344 // Restrict, if necessary 345 if (!impl->e_vecs[i]) { 346 // No restriction for this field; read data directly from vec. 347 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i])); 348 } else { 349 uint64_t state; 350 351 CeedCallBackend(CeedVectorGetState(vec, &state)); 352 if ((state != impl->input_states[i] || vec == in_vec) && !impl->skip_rstr_in[i]) { 353 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, vec, impl->e_vecs[i], request)); 354 } 355 impl->input_states[i] = state; 356 // Get evec 357 CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs[i], CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i])); 358 } 359 } 360 } 361 return CEED_ERROR_SUCCESS; 362 } 363 364 //------------------------------------------------------------------------------ 365 // Input Basis Action 366 //------------------------------------------------------------------------------ 367 static inline int CeedOperatorInputBasis_Cuda(CeedInt num_elem, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 368 CeedInt num_input_fields, const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], 369 CeedOperator_Cuda *impl) { 370 for (CeedInt i = 0; i < num_input_fields; i++) { 371 CeedInt elem_size, size; 372 CeedEvalMode eval_mode; 373 CeedElemRestriction elem_rstr; 374 CeedBasis basis; 375 376 // Skip active input 377 if (skip_active) { 378 CeedVector vec; 379 380 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 381 if (vec == CEED_VECTOR_ACTIVE) continue; 382 } 383 // Get elem_size, eval_mode, size 384 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 385 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 386 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 387 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 388 // Basis action 389 switch (eval_mode) { 390 case CEED_EVAL_NONE: 391 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 392 break; 393 case CEED_EVAL_INTERP: 394 case CEED_EVAL_GRAD: 395 case CEED_EVAL_DIV: 396 case CEED_EVAL_CURL: 397 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 398 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs[i], impl->q_vecs_in[i])); 399 break; 400 case CEED_EVAL_WEIGHT: 401 break; // No action 402 } 403 } 404 return CEED_ERROR_SUCCESS; 405 } 406 407 //------------------------------------------------------------------------------ 408 // Restore Input Vectors 409 //------------------------------------------------------------------------------ 410 static inline int CeedOperatorRestoreInputs_Cuda(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 411 const bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Cuda *impl) { 412 for (CeedInt i = 0; i < num_input_fields; i++) { 413 CeedEvalMode eval_mode; 414 CeedVector vec; 415 416 // Skip active input 417 if (skip_active) { 418 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 419 if (vec == CEED_VECTOR_ACTIVE) continue; 420 } 421 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 422 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 423 } else { 424 if (!impl->e_vecs[i]) { // This was a skip_restriction case 425 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 426 CeedCallBackend(CeedVectorRestoreArrayRead(vec, (const CeedScalar **)&e_data[i])); 427 } else { 428 CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs[i], (const CeedScalar **)&e_data[i])); 429 } 430 } 431 } 432 return CEED_ERROR_SUCCESS; 433 } 434 435 //------------------------------------------------------------------------------ 436 // Apply and add to output 437 //------------------------------------------------------------------------------ 438 static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 439 CeedInt Q, num_elem, elem_size, num_input_fields, num_output_fields, size; 440 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 441 CeedQFunctionField *qf_input_fields, *qf_output_fields; 442 CeedQFunction qf; 443 CeedOperatorField *op_input_fields, *op_output_fields; 444 CeedOperator_Cuda *impl; 445 446 CeedCallBackend(CeedOperatorGetData(op, &impl)); 447 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 448 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 449 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 450 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 451 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 452 453 // Setup 454 CeedCallBackend(CeedOperatorSetup_Cuda(op)); 455 456 // Input Evecs and Restriction 457 CeedCallBackend(CeedOperatorSetupInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data, impl, request)); 458 459 // Input basis apply if needed 460 CeedCallBackend(CeedOperatorInputBasis_Cuda(num_elem, qf_input_fields, op_input_fields, num_input_fields, false, e_data, impl)); 461 462 // Output pointers, as necessary 463 for (CeedInt i = 0; i < num_output_fields; i++) { 464 CeedEvalMode eval_mode; 465 466 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 467 if (eval_mode == CEED_EVAL_NONE) { 468 // Set the output Q-Vector to use the E-Vector data directly. 469 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 470 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 471 } 472 } 473 474 // Q function 475 CeedCallBackend(CeedQFunctionApply(qf, num_elem * Q, impl->q_vecs_in, impl->q_vecs_out)); 476 477 // Restore input arrays 478 CeedCallBackend(CeedOperatorRestoreInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl)); 479 480 // Output basis apply if needed 481 for (CeedInt i = 0; i < num_output_fields; i++) { 482 CeedEvalMode eval_mode; 483 CeedElemRestriction elem_rstr; 484 CeedBasis basis; 485 486 // Get elem_size, eval_mode, size 487 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 488 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 489 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 490 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 491 // Basis action 492 switch (eval_mode) { 493 case CEED_EVAL_NONE: 494 break; // No action 495 case CEED_EVAL_INTERP: 496 case CEED_EVAL_GRAD: 497 case CEED_EVAL_DIV: 498 case CEED_EVAL_CURL: 499 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 500 if (impl->apply_add_basis_out[i]) { 501 CeedCallBackend(CeedBasisApplyAdd(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); 502 } else { 503 CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); 504 } 505 break; 506 // LCOV_EXCL_START 507 case CEED_EVAL_WEIGHT: { 508 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 509 // LCOV_EXCL_STOP 510 } 511 } 512 } 513 514 // Output restriction 515 for (CeedInt i = 0; i < num_output_fields; i++) { 516 CeedEvalMode eval_mode; 517 CeedVector vec; 518 CeedElemRestriction elem_rstr; 519 520 // Restore evec 521 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 522 if (eval_mode == CEED_EVAL_NONE) { 523 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 524 } 525 if (impl->skip_rstr_out[i]) continue; 526 // Get output vector 527 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 528 // Restrict 529 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 530 // Active 531 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 532 533 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request)); 534 } 535 return CEED_ERROR_SUCCESS; 536 } 537 538 //------------------------------------------------------------------------------ 539 // CeedOperator needs to connect all the named fields (be they active or passive) to the named inputs and outputs of its CeedQFunction. 540 //------------------------------------------------------------------------------ 541 static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) { 542 Ceed ceed; 543 bool is_setup_done; 544 CeedInt max_num_points = -1, num_elem, num_input_fields, num_output_fields; 545 CeedQFunctionField *qf_input_fields, *qf_output_fields; 546 CeedQFunction qf; 547 CeedOperatorField *op_input_fields, *op_output_fields; 548 CeedOperator_Cuda *impl; 549 550 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 551 if (is_setup_done) return CEED_ERROR_SUCCESS; 552 553 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 554 CeedCallBackend(CeedOperatorGetData(op, &impl)); 555 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 556 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 557 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 558 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 559 { 560 CeedElemRestriction elem_rstr = NULL; 561 562 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &elem_rstr, NULL)); 563 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(elem_rstr, &max_num_points)); 564 } 565 impl->max_num_points = max_num_points; 566 567 // Allocate 568 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); 569 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); 570 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out)); 571 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out)); 572 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 573 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 574 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 575 impl->num_inputs = num_input_fields; 576 impl->num_outputs = num_output_fields; 577 578 // Set up infield and outfield e-vecs and q-vecs 579 // Infields 580 CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, true, true, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, 581 max_num_points, num_elem)); 582 // Outfields 583 CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out, 584 num_input_fields, num_output_fields, max_num_points, num_elem)); 585 586 // Reuse active e-vecs where able 587 { 588 CeedInt num_used = 0; 589 CeedElemRestriction *rstr_used = NULL; 590 591 for (CeedInt i = 0; i < num_input_fields; i++) { 592 bool is_used = false; 593 CeedVector vec_i; 594 CeedElemRestriction rstr_i; 595 596 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i)); 597 if (vec_i != CEED_VECTOR_ACTIVE) continue; 598 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr_i)); 599 for (CeedInt j = 0; j < num_used; j++) { 600 if (rstr_i == rstr_used[i]) is_used = true; 601 } 602 if (is_used) continue; 603 num_used++; 604 if (num_used == 1) CeedCallBackend(CeedCalloc(num_used, &rstr_used)); 605 else CeedCallBackend(CeedRealloc(num_used, &rstr_used)); 606 rstr_used[num_used - 1] = rstr_i; 607 for (CeedInt j = num_output_fields - 1; j >= 0; j--) { 608 CeedEvalMode eval_mode; 609 CeedVector vec_j; 610 CeedElemRestriction rstr_j; 611 612 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec_j)); 613 if (vec_j != CEED_VECTOR_ACTIVE) continue; 614 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode)); 615 if (eval_mode == CEED_EVAL_NONE) continue; 616 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &rstr_j)); 617 if (rstr_i == rstr_j) { 618 CeedCallBackend(CeedVectorReferenceCopy(impl->e_vecs[i], &impl->e_vecs[j + impl->num_inputs])); 619 } 620 } 621 } 622 CeedCallBackend(CeedFree(&rstr_used)); 623 } 624 impl->has_shared_e_vecs = true; 625 CeedCallBackend(CeedOperatorSetSetupDone(op)); 626 return CEED_ERROR_SUCCESS; 627 } 628 629 //------------------------------------------------------------------------------ 630 // Input Basis Action AtPoints 631 //------------------------------------------------------------------------------ 632 static inline int CeedOperatorInputBasisAtPoints_Cuda(CeedInt num_elem, const CeedInt *num_points, CeedQFunctionField *qf_input_fields, 633 CeedOperatorField *op_input_fields, CeedInt num_input_fields, const bool skip_active, 634 CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Cuda *impl) { 635 for (CeedInt i = 0; i < num_input_fields; i++) { 636 CeedInt elem_size, size; 637 CeedEvalMode eval_mode; 638 CeedElemRestriction elem_rstr; 639 CeedBasis basis; 640 641 // Skip active input 642 if (skip_active) { 643 CeedVector vec; 644 645 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 646 if (vec == CEED_VECTOR_ACTIVE) continue; 647 } 648 // Get elem_size, eval_mode, size 649 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 650 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 651 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 652 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 653 // Basis action 654 switch (eval_mode) { 655 case CEED_EVAL_NONE: 656 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 657 break; 658 case CEED_EVAL_INTERP: 659 case CEED_EVAL_GRAD: 660 case CEED_EVAL_DIV: 661 case CEED_EVAL_CURL: 662 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 663 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, impl->e_vecs[i], 664 impl->q_vecs_in[i])); 665 break; 666 case CEED_EVAL_WEIGHT: 667 break; // No action 668 } 669 } 670 return CEED_ERROR_SUCCESS; 671 } 672 673 //------------------------------------------------------------------------------ 674 // Apply and add to output AtPoints 675 //------------------------------------------------------------------------------ 676 static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 677 CeedInt max_num_points, num_elem, elem_size, num_input_fields, num_output_fields, size; 678 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 679 CeedQFunctionField *qf_input_fields, *qf_output_fields; 680 CeedQFunction qf; 681 CeedOperatorField *op_input_fields, *op_output_fields; 682 CeedOperator_Cuda *impl; 683 684 CeedCallBackend(CeedOperatorGetData(op, &impl)); 685 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 686 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 687 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 688 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 689 CeedInt num_points[num_elem]; 690 691 // Setup 692 CeedCallBackend(CeedOperatorSetupAtPoints_Cuda(op)); 693 max_num_points = impl->max_num_points; 694 for (CeedInt i = 0; i < num_elem; i++) num_points[i] = max_num_points; 695 696 // Input Evecs and Restriction 697 CeedCallBackend(CeedOperatorSetupInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data, impl, request)); 698 699 // Get point coordinates 700 if (!impl->point_coords_elem) { 701 CeedVector point_coords = NULL; 702 CeedElemRestriction rstr_points = NULL; 703 704 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 705 CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem)); 706 CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 707 } 708 709 // Input basis apply if needed 710 CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(num_elem, num_points, qf_input_fields, op_input_fields, num_input_fields, false, e_data, impl)); 711 712 // Output pointers, as necessary 713 for (CeedInt i = 0; i < num_output_fields; i++) { 714 CeedEvalMode eval_mode; 715 716 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 717 if (eval_mode == CEED_EVAL_NONE) { 718 // Set the output Q-Vector to use the E-Vector data directly. 719 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 720 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 721 } 722 } 723 724 // Q function 725 CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out)); 726 727 // Restore input arrays 728 CeedCallBackend(CeedOperatorRestoreInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl)); 729 730 // Output basis apply if needed 731 for (CeedInt i = 0; i < num_output_fields; i++) { 732 CeedEvalMode eval_mode; 733 CeedElemRestriction elem_rstr; 734 CeedBasis basis; 735 736 // Get elem_size, eval_mode, size 737 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 738 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 739 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 740 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 741 // Basis action 742 switch (eval_mode) { 743 case CEED_EVAL_NONE: 744 break; // No action 745 case CEED_EVAL_INTERP: 746 case CEED_EVAL_GRAD: 747 case CEED_EVAL_DIV: 748 case CEED_EVAL_CURL: 749 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 750 if (impl->apply_add_basis_out[i]) { 751 CeedCallBackend(CeedBasisApplyAddAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, 752 impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); 753 } else { 754 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i], 755 impl->e_vecs[i + impl->num_inputs])); 756 } 757 break; 758 // LCOV_EXCL_START 759 case CEED_EVAL_WEIGHT: { 760 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 761 // LCOV_EXCL_STOP 762 } 763 } 764 } 765 766 // Output restriction 767 for (CeedInt i = 0; i < num_output_fields; i++) { 768 CeedEvalMode eval_mode; 769 CeedVector vec; 770 CeedElemRestriction elem_rstr; 771 772 // Restore evec 773 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 774 if (eval_mode == CEED_EVAL_NONE) { 775 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 776 } 777 if (impl->skip_rstr_out[i]) continue; 778 // Get output vector 779 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 780 // Restrict 781 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 782 // Active 783 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 784 785 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request)); 786 } 787 return CEED_ERROR_SUCCESS; 788 } 789 790 //------------------------------------------------------------------------------ 791 // Linear QFunction Assembly Core 792 //------------------------------------------------------------------------------ 793 static inline int CeedOperatorLinearAssembleQFunctionCore_Cuda(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 794 CeedRequest *request) { 795 Ceed ceed, ceed_parent; 796 CeedInt num_active_in, num_active_out, Q, num_elem, num_input_fields, num_output_fields, size; 797 CeedScalar *assembled_array, *e_data[2 * CEED_FIELD_MAX] = {NULL}; 798 CeedVector *active_inputs; 799 CeedQFunctionField *qf_input_fields, *qf_output_fields; 800 CeedQFunction qf; 801 CeedOperatorField *op_input_fields, *op_output_fields; 802 CeedOperator_Cuda *impl; 803 804 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 805 CeedCallBackend(CeedOperatorGetFallbackParentCeed(op, &ceed_parent)); 806 CeedCallBackend(CeedOperatorGetData(op, &impl)); 807 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 808 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 809 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 810 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 811 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 812 active_inputs = impl->qf_active_in; 813 num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 814 815 // Setup 816 CeedCallBackend(CeedOperatorSetup_Cuda(op)); 817 818 // Input Evecs and Restriction 819 CeedCallBackend(CeedOperatorSetupInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 820 821 // Count number of active input fields 822 if (!num_active_in) { 823 for (CeedInt i = 0; i < num_input_fields; i++) { 824 CeedScalar *q_vec_array; 825 CeedVector vec; 826 827 // Get input vector 828 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 829 // Check if active input 830 if (vec == CEED_VECTOR_ACTIVE) { 831 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 832 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 833 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, &q_vec_array)); 834 CeedCallBackend(CeedRealloc(num_active_in + size, &active_inputs)); 835 for (CeedInt field = 0; field < size; field++) { 836 CeedSize q_size = (CeedSize)Q * num_elem; 837 838 CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_inputs[num_active_in + field])); 839 CeedCallBackend( 840 CeedVectorSetArray(active_inputs[num_active_in + field], CEED_MEM_DEVICE, CEED_USE_POINTER, &q_vec_array[field * Q * num_elem])); 841 } 842 num_active_in += size; 843 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array)); 844 } 845 } 846 impl->num_active_in = num_active_in; 847 impl->qf_active_in = active_inputs; 848 } 849 850 // Count number of active output fields 851 if (!num_active_out) { 852 for (CeedInt i = 0; i < num_output_fields; i++) { 853 CeedVector vec; 854 855 // Get output vector 856 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 857 // Check if active output 858 if (vec == CEED_VECTOR_ACTIVE) { 859 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 860 num_active_out += size; 861 } 862 } 863 impl->num_active_out = num_active_out; 864 } 865 866 // Check sizes 867 CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 868 869 // Build objects if needed 870 if (build_objects) { 871 CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out; 872 CeedInt strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */ 873 874 // Create output restriction 875 CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out, 876 (CeedSize)num_active_in * (CeedSize)num_active_out * (CeedSize)num_elem * (CeedSize)Q, strides, 877 rstr)); 878 // Create assembled vector 879 CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled)); 880 } 881 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 882 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_DEVICE, &assembled_array)); 883 884 // Input basis apply 885 CeedCallBackend(CeedOperatorInputBasis_Cuda(num_elem, qf_input_fields, op_input_fields, num_input_fields, true, e_data, impl)); 886 887 // Assemble QFunction 888 for (CeedInt in = 0; in < num_active_in; in++) { 889 // Set Inputs 890 CeedCallBackend(CeedVectorSetValue(active_inputs[in], 1.0)); 891 if (num_active_in > 1) { 892 CeedCallBackend(CeedVectorSetValue(active_inputs[(in + num_active_in - 1) % num_active_in], 0.0)); 893 } 894 // Set Outputs 895 for (CeedInt out = 0; out < num_output_fields; out++) { 896 CeedVector vec; 897 898 // Get output vector 899 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 900 // Check if active output 901 if (vec == CEED_VECTOR_ACTIVE) { 902 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, CEED_USE_POINTER, assembled_array)); 903 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size)); 904 assembled_array += size * Q * num_elem; // Advance the pointer by the size of the output 905 } 906 } 907 // Apply QFunction 908 CeedCallBackend(CeedQFunctionApply(qf, Q * num_elem, impl->q_vecs_in, impl->q_vecs_out)); 909 } 910 911 // Un-set output q-vecs to prevent accidental overwrite of Assembled 912 for (CeedInt out = 0; out < num_output_fields; out++) { 913 CeedVector vec; 914 915 // Get output vector 916 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 917 // Check if active output 918 if (vec == CEED_VECTOR_ACTIVE) { 919 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_DEVICE, NULL)); 920 } 921 } 922 923 // Restore input arrays 924 CeedCallBackend(CeedOperatorRestoreInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 925 926 // Restore output 927 CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 928 return CEED_ERROR_SUCCESS; 929 } 930 931 //------------------------------------------------------------------------------ 932 // Assemble Linear QFunction 933 //------------------------------------------------------------------------------ 934 static int CeedOperatorLinearAssembleQFunction_Cuda(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 935 return CeedOperatorLinearAssembleQFunctionCore_Cuda(op, true, assembled, rstr, request); 936 } 937 938 //------------------------------------------------------------------------------ 939 // Update Assembled Linear QFunction 940 //------------------------------------------------------------------------------ 941 static int CeedOperatorLinearAssembleQFunctionUpdate_Cuda(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 942 return CeedOperatorLinearAssembleQFunctionCore_Cuda(op, false, &assembled, &rstr, request); 943 } 944 945 //------------------------------------------------------------------------------ 946 // Assemble Diagonal Setup 947 //------------------------------------------------------------------------------ 948 static inline int CeedOperatorAssembleDiagonalSetup_Cuda(CeedOperator op) { 949 Ceed ceed; 950 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 951 CeedInt q_comp, num_nodes, num_qpts; 952 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL; 953 CeedBasis basis_in = NULL, basis_out = NULL; 954 CeedQFunctionField *qf_fields; 955 CeedQFunction qf; 956 CeedOperatorField *op_fields; 957 CeedOperator_Cuda *impl; 958 959 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 960 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 961 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 962 963 // Determine active input basis 964 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 965 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 966 for (CeedInt i = 0; i < num_input_fields; i++) { 967 CeedVector vec; 968 969 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 970 if (vec == CEED_VECTOR_ACTIVE) { 971 CeedBasis basis; 972 CeedEvalMode eval_mode; 973 974 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 975 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, 976 "Backend does not implement operator diagonal assembly with multiple active bases"); 977 basis_in = basis; 978 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 979 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 980 if (eval_mode != CEED_EVAL_WEIGHT) { 981 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly 982 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in)); 983 for (CeedInt d = 0; d < q_comp; d++) eval_modes_in[num_eval_modes_in + d] = eval_mode; 984 num_eval_modes_in += q_comp; 985 } 986 } 987 } 988 989 // Determine active output basis 990 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 991 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 992 for (CeedInt i = 0; i < num_output_fields; i++) { 993 CeedVector vec; 994 995 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 996 if (vec == CEED_VECTOR_ACTIVE) { 997 CeedBasis basis; 998 CeedEvalMode eval_mode; 999 1000 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 1001 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND, 1002 "Backend does not implement operator diagonal assembly with multiple active bases"); 1003 basis_out = basis; 1004 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1005 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 1006 if (eval_mode != CEED_EVAL_WEIGHT) { 1007 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF assembly 1008 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out)); 1009 for (CeedInt d = 0; d < q_comp; d++) eval_modes_out[num_eval_modes_out + d] = eval_mode; 1010 num_eval_modes_out += q_comp; 1011 } 1012 } 1013 } 1014 1015 // Operator data struct 1016 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1017 CeedCallBackend(CeedCalloc(1, &impl->diag)); 1018 CeedOperatorDiag_Cuda *diag = impl->diag; 1019 1020 // Basis matrices 1021 CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes)); 1022 if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes; 1023 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts)); 1024 const CeedInt interp_bytes = num_nodes * num_qpts * sizeof(CeedScalar); 1025 const CeedInt eval_modes_bytes = sizeof(CeedEvalMode); 1026 bool has_eval_none = false; 1027 1028 // CEED_EVAL_NONE 1029 for (CeedInt i = 0; i < num_eval_modes_in; i++) has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE); 1030 for (CeedInt i = 0; i < num_eval_modes_out; i++) has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE); 1031 if (has_eval_none) { 1032 CeedScalar *identity = NULL; 1033 1034 CeedCallBackend(CeedCalloc(num_nodes * num_qpts, &identity)); 1035 for (CeedInt i = 0; i < (num_nodes < num_qpts ? num_nodes : num_qpts); i++) identity[i * num_nodes + i] = 1.0; 1036 CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_identity, interp_bytes)); 1037 CeedCallCuda(ceed, cudaMemcpy(diag->d_identity, identity, interp_bytes, cudaMemcpyHostToDevice)); 1038 CeedCallBackend(CeedFree(&identity)); 1039 } 1040 1041 // CEED_EVAL_INTERP, CEED_EVAL_GRAD, CEED_EVAL_DIV, and CEED_EVAL_CURL 1042 for (CeedInt in = 0; in < 2; in++) { 1043 CeedFESpace fespace; 1044 CeedBasis basis = in ? basis_in : basis_out; 1045 1046 CeedCallBackend(CeedBasisGetFESpace(basis, &fespace)); 1047 switch (fespace) { 1048 case CEED_FE_SPACE_H1: { 1049 CeedInt q_comp_interp, q_comp_grad; 1050 const CeedScalar *interp, *grad; 1051 CeedScalar *d_interp, *d_grad; 1052 1053 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 1054 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_GRAD, &q_comp_grad)); 1055 1056 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 1057 CeedCallCuda(ceed, cudaMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 1058 CeedCallCuda(ceed, cudaMemcpy(d_interp, interp, interp_bytes * q_comp_interp, cudaMemcpyHostToDevice)); 1059 CeedCallBackend(CeedBasisGetGrad(basis, &grad)); 1060 CeedCallCuda(ceed, cudaMalloc((void **)&d_grad, interp_bytes * q_comp_grad)); 1061 CeedCallCuda(ceed, cudaMemcpy(d_grad, grad, interp_bytes * q_comp_grad, cudaMemcpyHostToDevice)); 1062 if (in) { 1063 diag->d_interp_in = d_interp; 1064 diag->d_grad_in = d_grad; 1065 } else { 1066 diag->d_interp_out = d_interp; 1067 diag->d_grad_out = d_grad; 1068 } 1069 } break; 1070 case CEED_FE_SPACE_HDIV: { 1071 CeedInt q_comp_interp, q_comp_div; 1072 const CeedScalar *interp, *div; 1073 CeedScalar *d_interp, *d_div; 1074 1075 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 1076 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_DIV, &q_comp_div)); 1077 1078 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 1079 CeedCallCuda(ceed, cudaMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 1080 CeedCallCuda(ceed, cudaMemcpy(d_interp, interp, interp_bytes * q_comp_interp, cudaMemcpyHostToDevice)); 1081 CeedCallBackend(CeedBasisGetDiv(basis, &div)); 1082 CeedCallCuda(ceed, cudaMalloc((void **)&d_div, interp_bytes * q_comp_div)); 1083 CeedCallCuda(ceed, cudaMemcpy(d_div, div, interp_bytes * q_comp_div, cudaMemcpyHostToDevice)); 1084 if (in) { 1085 diag->d_interp_in = d_interp; 1086 diag->d_div_in = d_div; 1087 } else { 1088 diag->d_interp_out = d_interp; 1089 diag->d_div_out = d_div; 1090 } 1091 } break; 1092 case CEED_FE_SPACE_HCURL: { 1093 CeedInt q_comp_interp, q_comp_curl; 1094 const CeedScalar *interp, *curl; 1095 CeedScalar *d_interp, *d_curl; 1096 1097 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 1098 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_CURL, &q_comp_curl)); 1099 1100 CeedCallBackend(CeedBasisGetInterp(basis, &interp)); 1101 CeedCallCuda(ceed, cudaMalloc((void **)&d_interp, interp_bytes * q_comp_interp)); 1102 CeedCallCuda(ceed, cudaMemcpy(d_interp, interp, interp_bytes * q_comp_interp, cudaMemcpyHostToDevice)); 1103 CeedCallBackend(CeedBasisGetCurl(basis, &curl)); 1104 CeedCallCuda(ceed, cudaMalloc((void **)&d_curl, interp_bytes * q_comp_curl)); 1105 CeedCallCuda(ceed, cudaMemcpy(d_curl, curl, interp_bytes * q_comp_curl, cudaMemcpyHostToDevice)); 1106 if (in) { 1107 diag->d_interp_in = d_interp; 1108 diag->d_curl_in = d_curl; 1109 } else { 1110 diag->d_interp_out = d_interp; 1111 diag->d_curl_out = d_curl; 1112 } 1113 } break; 1114 } 1115 } 1116 1117 // Arrays of eval_modes 1118 CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_eval_modes_in, num_eval_modes_in * eval_modes_bytes)); 1119 CeedCallCuda(ceed, cudaMemcpy(diag->d_eval_modes_in, eval_modes_in, num_eval_modes_in * eval_modes_bytes, cudaMemcpyHostToDevice)); 1120 CeedCallCuda(ceed, cudaMalloc((void **)&diag->d_eval_modes_out, num_eval_modes_out * eval_modes_bytes)); 1121 CeedCallCuda(ceed, cudaMemcpy(diag->d_eval_modes_out, eval_modes_out, num_eval_modes_out * eval_modes_bytes, cudaMemcpyHostToDevice)); 1122 CeedCallBackend(CeedFree(&eval_modes_in)); 1123 CeedCallBackend(CeedFree(&eval_modes_out)); 1124 return CEED_ERROR_SUCCESS; 1125 } 1126 1127 //------------------------------------------------------------------------------ 1128 // Assemble Diagonal Setup (Compilation) 1129 //------------------------------------------------------------------------------ 1130 static inline int CeedOperatorAssembleDiagonalSetupCompile_Cuda(CeedOperator op, CeedInt use_ceedsize_idx, const bool is_point_block) { 1131 Ceed ceed; 1132 char *diagonal_kernel_source; 1133 const char *diagonal_kernel_path; 1134 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 1135 CeedInt num_comp, q_comp, num_nodes, num_qpts; 1136 CeedBasis basis_in = NULL, basis_out = NULL; 1137 CeedQFunctionField *qf_fields; 1138 CeedQFunction qf; 1139 CeedOperatorField *op_fields; 1140 CeedOperator_Cuda *impl; 1141 1142 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1143 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1144 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 1145 1146 // Determine active input basis 1147 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 1148 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 1149 for (CeedInt i = 0; i < num_input_fields; i++) { 1150 CeedVector vec; 1151 1152 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 1153 if (vec == CEED_VECTOR_ACTIVE) { 1154 CeedEvalMode eval_mode; 1155 1156 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis_in)); 1157 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1158 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 1159 if (eval_mode != CEED_EVAL_WEIGHT) { 1160 num_eval_modes_in += q_comp; 1161 } 1162 } 1163 } 1164 1165 // Determine active output basis 1166 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 1167 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 1168 for (CeedInt i = 0; i < num_output_fields; i++) { 1169 CeedVector vec; 1170 1171 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 1172 if (vec == CEED_VECTOR_ACTIVE) { 1173 CeedEvalMode eval_mode; 1174 1175 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis_out)); 1176 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1177 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 1178 if (eval_mode != CEED_EVAL_WEIGHT) { 1179 num_eval_modes_out += q_comp; 1180 } 1181 } 1182 } 1183 1184 // Operator data struct 1185 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1186 CeedOperatorDiag_Cuda *diag = impl->diag; 1187 1188 // Assemble kernel 1189 CUmodule *module = is_point_block ? &diag->module_point_block : &diag->module; 1190 CeedInt elems_per_block = 1; 1191 CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes)); 1192 CeedCallBackend(CeedBasisGetNumComponents(basis_in, &num_comp)); 1193 if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes; 1194 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts)); 1195 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-operator-assemble-diagonal.h", &diagonal_kernel_path)); 1196 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Kernel Source -----\n"); 1197 CeedCallBackend(CeedLoadSourceToBuffer(ceed, diagonal_kernel_path, &diagonal_kernel_source)); 1198 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Source Complete! -----\n"); 1199 CeedCallCuda(ceed, CeedCompile_Cuda(ceed, diagonal_kernel_source, module, 8, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", 1200 num_eval_modes_out, "NUM_COMP", num_comp, "NUM_NODES", num_nodes, "NUM_QPTS", num_qpts, "USE_CEEDSIZE", 1201 use_ceedsize_idx, "USE_POINT_BLOCK", is_point_block ? 1 : 0, "BLOCK_SIZE", num_nodes * elems_per_block)); 1202 CeedCallCuda(ceed, CeedGetKernel_Cuda(ceed, *module, "LinearDiagonal", is_point_block ? &diag->LinearPointBlock : &diag->LinearDiagonal)); 1203 CeedCallBackend(CeedFree(&diagonal_kernel_path)); 1204 CeedCallBackend(CeedFree(&diagonal_kernel_source)); 1205 return CEED_ERROR_SUCCESS; 1206 } 1207 1208 //------------------------------------------------------------------------------ 1209 // Assemble Diagonal Core 1210 //------------------------------------------------------------------------------ 1211 static inline int CeedOperatorAssembleDiagonalCore_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request, const bool is_point_block) { 1212 Ceed ceed; 1213 CeedInt num_elem, num_nodes; 1214 CeedScalar *elem_diag_array; 1215 const CeedScalar *assembled_qf_array; 1216 CeedVector assembled_qf = NULL, elem_diag; 1217 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out, diag_rstr; 1218 CeedOperator_Cuda *impl; 1219 1220 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1221 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1222 1223 // Assemble QFunction 1224 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, request)); 1225 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr)); 1226 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); 1227 1228 // Setup 1229 if (!impl->diag) CeedCallBackend(CeedOperatorAssembleDiagonalSetup_Cuda(op)); 1230 CeedOperatorDiag_Cuda *diag = impl->diag; 1231 1232 assert(diag != NULL); 1233 1234 // Assemble kernel if needed 1235 if ((!is_point_block && !diag->LinearDiagonal) || (is_point_block && !diag->LinearPointBlock)) { 1236 CeedSize assembled_length, assembled_qf_length; 1237 CeedInt use_ceedsize_idx = 0; 1238 CeedCallBackend(CeedVectorGetLength(assembled, &assembled_length)); 1239 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length)); 1240 if ((assembled_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1; 1241 1242 CeedCallBackend(CeedOperatorAssembleDiagonalSetupCompile_Cuda(op, use_ceedsize_idx, is_point_block)); 1243 } 1244 1245 // Restriction and diagonal vector 1246 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out)); 1247 CeedCheck(rstr_in == rstr_out, ceed, CEED_ERROR_BACKEND, 1248 "Cannot assemble operator diagonal with different input and output active element restrictions"); 1249 if (!is_point_block && !diag->diag_rstr) { 1250 CeedCallBackend(CeedElemRestrictionCreateUnsignedCopy(rstr_out, &diag->diag_rstr)); 1251 CeedCallBackend(CeedElemRestrictionCreateVector(diag->diag_rstr, NULL, &diag->elem_diag)); 1252 } else if (is_point_block && !diag->point_block_diag_rstr) { 1253 CeedCallBackend(CeedOperatorCreateActivePointBlockRestriction(rstr_out, &diag->point_block_diag_rstr)); 1254 CeedCallBackend(CeedElemRestrictionCreateVector(diag->point_block_diag_rstr, NULL, &diag->point_block_elem_diag)); 1255 } 1256 diag_rstr = is_point_block ? diag->point_block_diag_rstr : diag->diag_rstr; 1257 elem_diag = is_point_block ? diag->point_block_elem_diag : diag->elem_diag; 1258 CeedCallBackend(CeedVectorSetValue(elem_diag, 0.0)); 1259 1260 // Only assemble diagonal if the basis has nodes, otherwise inputs are null pointers 1261 CeedCallBackend(CeedElemRestrictionGetElementSize(diag_rstr, &num_nodes)); 1262 if (num_nodes > 0) { 1263 // Assemble element operator diagonals 1264 CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem)); 1265 CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array)); 1266 1267 // Compute the diagonal of B^T D B 1268 CeedInt elems_per_block = 1; 1269 CeedInt grid = CeedDivUpInt(num_elem, elems_per_block); 1270 void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_div_in, 1271 &diag->d_curl_in, &diag->d_interp_out, &diag->d_grad_out, &diag->d_div_out, &diag->d_curl_out, 1272 &diag->d_eval_modes_in, &diag->d_eval_modes_out, &assembled_qf_array, &elem_diag_array}; 1273 1274 if (is_point_block) { 1275 CeedCallBackend(CeedRunKernelDim_Cuda(ceed, diag->LinearPointBlock, grid, num_nodes, 1, elems_per_block, args)); 1276 } else { 1277 CeedCallBackend(CeedRunKernelDim_Cuda(ceed, diag->LinearDiagonal, grid, num_nodes, 1, elems_per_block, args)); 1278 } 1279 1280 // Restore arrays 1281 CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array)); 1282 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); 1283 } 1284 1285 // Assemble local operator diagonal 1286 CeedCallBackend(CeedElemRestrictionApply(diag_rstr, CEED_TRANSPOSE, elem_diag, assembled, request)); 1287 1288 // Cleanup 1289 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1290 return CEED_ERROR_SUCCESS; 1291 } 1292 1293 //------------------------------------------------------------------------------ 1294 // Assemble Linear Diagonal 1295 //------------------------------------------------------------------------------ 1296 static int CeedOperatorLinearAssembleAddDiagonal_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1297 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Cuda(op, assembled, request, false)); 1298 return CEED_ERROR_SUCCESS; 1299 } 1300 1301 //------------------------------------------------------------------------------ 1302 // Assemble Linear Point Block Diagonal 1303 //------------------------------------------------------------------------------ 1304 static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1305 CeedCallBackend(CeedOperatorAssembleDiagonalCore_Cuda(op, assembled, request, true)); 1306 return CEED_ERROR_SUCCESS; 1307 } 1308 1309 //------------------------------------------------------------------------------ 1310 // Single Operator Assembly Setup 1311 //------------------------------------------------------------------------------ 1312 static int CeedSingleOperatorAssembleSetup_Cuda(CeedOperator op, CeedInt use_ceedsize_idx) { 1313 Ceed ceed; 1314 Ceed_Cuda *cuda_data; 1315 char *assembly_kernel_source; 1316 const char *assembly_kernel_path; 1317 CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; 1318 CeedInt elem_size_in, num_qpts_in = 0, num_comp_in, elem_size_out, num_qpts_out, num_comp_out, q_comp; 1319 CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL; 1320 CeedElemRestriction rstr_in = NULL, rstr_out = NULL; 1321 CeedBasis basis_in = NULL, basis_out = NULL; 1322 CeedQFunctionField *qf_fields; 1323 CeedQFunction qf; 1324 CeedOperatorField *input_fields, *output_fields; 1325 CeedOperator_Cuda *impl; 1326 1327 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1328 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1329 1330 // Get intput and output fields 1331 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields)); 1332 1333 // Determine active input basis eval mode 1334 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1335 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 1336 for (CeedInt i = 0; i < num_input_fields; i++) { 1337 CeedVector vec; 1338 1339 CeedCallBackend(CeedOperatorFieldGetVector(input_fields[i], &vec)); 1340 if (vec == CEED_VECTOR_ACTIVE) { 1341 CeedBasis basis; 1342 CeedEvalMode eval_mode; 1343 1344 CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis)); 1345 CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases"); 1346 basis_in = basis; 1347 CeedCallBackend(CeedOperatorFieldGetElemRestriction(input_fields[i], &rstr_in)); 1348 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in)); 1349 if (basis_in == CEED_BASIS_NONE) num_qpts_in = elem_size_in; 1350 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts_in)); 1351 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1352 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp)); 1353 if (eval_mode != CEED_EVAL_WEIGHT) { 1354 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly 1355 CeedCallBackend(CeedRealloc(num_eval_modes_in + q_comp, &eval_modes_in)); 1356 for (CeedInt d = 0; d < q_comp; d++) { 1357 eval_modes_in[num_eval_modes_in + d] = eval_mode; 1358 } 1359 num_eval_modes_in += q_comp; 1360 } 1361 } 1362 } 1363 1364 // Determine active output basis; basis_out and rstr_out only used if same as input, TODO 1365 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 1366 for (CeedInt i = 0; i < num_output_fields; i++) { 1367 CeedVector vec; 1368 1369 CeedCallBackend(CeedOperatorFieldGetVector(output_fields[i], &vec)); 1370 if (vec == CEED_VECTOR_ACTIVE) { 1371 CeedBasis basis; 1372 CeedEvalMode eval_mode; 1373 1374 CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis)); 1375 CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND, 1376 "Backend does not implement operator assembly with multiple active bases"); 1377 basis_out = basis; 1378 CeedCallBackend(CeedOperatorFieldGetElemRestriction(output_fields[i], &rstr_out)); 1379 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out)); 1380 if (basis_out == CEED_BASIS_NONE) num_qpts_out = elem_size_out; 1381 else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_out, &num_qpts_out)); 1382 CeedCheck(num_qpts_in == num_qpts_out, ceed, CEED_ERROR_UNSUPPORTED, 1383 "Active input and output bases must have the same number of quadrature points"); 1384 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 1385 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp)); 1386 if (eval_mode != CEED_EVAL_WEIGHT) { 1387 // q_comp = 1 if CEED_EVAL_NONE, CEED_EVAL_WEIGHT caught by QF Assembly 1388 CeedCallBackend(CeedRealloc(num_eval_modes_out + q_comp, &eval_modes_out)); 1389 for (CeedInt d = 0; d < q_comp; d++) { 1390 eval_modes_out[num_eval_modes_out + d] = eval_mode; 1391 } 1392 num_eval_modes_out += q_comp; 1393 } 1394 } 1395 } 1396 CeedCheck(num_eval_modes_in > 0 && num_eval_modes_out > 0, ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs"); 1397 1398 CeedCallBackend(CeedCalloc(1, &impl->asmb)); 1399 CeedOperatorAssemble_Cuda *asmb = impl->asmb; 1400 asmb->elems_per_block = 1; 1401 asmb->block_size_x = elem_size_in; 1402 asmb->block_size_y = elem_size_out; 1403 1404 CeedCallBackend(CeedGetData(ceed, &cuda_data)); 1405 bool fallback = asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block > cuda_data->device_prop.maxThreadsPerBlock; 1406 1407 if (fallback) { 1408 // Use fallback kernel with 1D threadblock 1409 asmb->block_size_y = 1; 1410 } 1411 1412 // Compile kernels 1413 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &num_comp_in)); 1414 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_out, &num_comp_out)); 1415 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-operator-assemble.h", &assembly_kernel_path)); 1416 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Kernel Source -----\n"); 1417 CeedCallBackend(CeedLoadSourceToBuffer(ceed, assembly_kernel_path, &assembly_kernel_source)); 1418 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Source Complete! -----\n"); 1419 CeedCallBackend(CeedCompile_Cuda(ceed, assembly_kernel_source, &asmb->module, 10, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", 1420 num_eval_modes_out, "NUM_COMP_IN", num_comp_in, "NUM_COMP_OUT", num_comp_out, "NUM_NODES_IN", elem_size_in, 1421 "NUM_NODES_OUT", elem_size_out, "NUM_QPTS", num_qpts_in, "BLOCK_SIZE", 1422 asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block, "BLOCK_SIZE_Y", asmb->block_size_y, 1423 "USE_CEEDSIZE", use_ceedsize_idx)); 1424 CeedCallBackend(CeedGetKernel_Cuda(ceed, asmb->module, "LinearAssemble", &asmb->LinearAssemble)); 1425 CeedCallBackend(CeedFree(&assembly_kernel_path)); 1426 CeedCallBackend(CeedFree(&assembly_kernel_source)); 1427 1428 // Load into B_in, in order that they will be used in eval_modes_in 1429 { 1430 const CeedInt in_bytes = elem_size_in * num_qpts_in * num_eval_modes_in * sizeof(CeedScalar); 1431 CeedInt d_in = 0; 1432 CeedEvalMode eval_modes_in_prev = CEED_EVAL_NONE; 1433 bool has_eval_none = false; 1434 CeedScalar *identity = NULL; 1435 1436 for (CeedInt i = 0; i < num_eval_modes_in; i++) { 1437 has_eval_none = has_eval_none || (eval_modes_in[i] == CEED_EVAL_NONE); 1438 } 1439 if (has_eval_none) { 1440 CeedCallBackend(CeedCalloc(elem_size_in * num_qpts_in, &identity)); 1441 for (CeedInt i = 0; i < (elem_size_in < num_qpts_in ? elem_size_in : num_qpts_in); i++) identity[i * elem_size_in + i] = 1.0; 1442 } 1443 1444 CeedCallCuda(ceed, cudaMalloc((void **)&asmb->d_B_in, in_bytes)); 1445 for (CeedInt i = 0; i < num_eval_modes_in; i++) { 1446 const CeedScalar *h_B_in; 1447 1448 CeedCallBackend(CeedOperatorGetBasisPointer(basis_in, eval_modes_in[i], identity, &h_B_in)); 1449 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_modes_in[i], &q_comp)); 1450 if (q_comp > 1) { 1451 if (i == 0 || eval_modes_in[i] != eval_modes_in_prev) d_in = 0; 1452 else h_B_in = &h_B_in[(++d_in) * elem_size_in * num_qpts_in]; 1453 } 1454 eval_modes_in_prev = eval_modes_in[i]; 1455 1456 CeedCallCuda(ceed, cudaMemcpy(&asmb->d_B_in[i * elem_size_in * num_qpts_in], h_B_in, elem_size_in * num_qpts_in * sizeof(CeedScalar), 1457 cudaMemcpyHostToDevice)); 1458 } 1459 1460 if (identity) { 1461 CeedCallBackend(CeedFree(&identity)); 1462 } 1463 } 1464 1465 // Load into B_out, in order that they will be used in eval_modes_out 1466 { 1467 const CeedInt out_bytes = elem_size_out * num_qpts_out * num_eval_modes_out * sizeof(CeedScalar); 1468 CeedInt d_out = 0; 1469 CeedEvalMode eval_modes_out_prev = CEED_EVAL_NONE; 1470 bool has_eval_none = false; 1471 CeedScalar *identity = NULL; 1472 1473 for (CeedInt i = 0; i < num_eval_modes_out; i++) { 1474 has_eval_none = has_eval_none || (eval_modes_out[i] == CEED_EVAL_NONE); 1475 } 1476 if (has_eval_none) { 1477 CeedCallBackend(CeedCalloc(elem_size_out * num_qpts_out, &identity)); 1478 for (CeedInt i = 0; i < (elem_size_out < num_qpts_out ? elem_size_out : num_qpts_out); i++) identity[i * elem_size_out + i] = 1.0; 1479 } 1480 1481 CeedCallCuda(ceed, cudaMalloc((void **)&asmb->d_B_out, out_bytes)); 1482 for (CeedInt i = 0; i < num_eval_modes_out; i++) { 1483 const CeedScalar *h_B_out; 1484 1485 CeedCallBackend(CeedOperatorGetBasisPointer(basis_out, eval_modes_out[i], identity, &h_B_out)); 1486 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_modes_out[i], &q_comp)); 1487 if (q_comp > 1) { 1488 if (i == 0 || eval_modes_out[i] != eval_modes_out_prev) d_out = 0; 1489 else h_B_out = &h_B_out[(++d_out) * elem_size_out * num_qpts_out]; 1490 } 1491 eval_modes_out_prev = eval_modes_out[i]; 1492 1493 CeedCallCuda(ceed, cudaMemcpy(&asmb->d_B_out[i * elem_size_out * num_qpts_out], h_B_out, elem_size_out * num_qpts_out * sizeof(CeedScalar), 1494 cudaMemcpyHostToDevice)); 1495 } 1496 1497 if (identity) { 1498 CeedCallBackend(CeedFree(&identity)); 1499 } 1500 } 1501 return CEED_ERROR_SUCCESS; 1502 } 1503 1504 //------------------------------------------------------------------------------ 1505 // Assemble matrix data for COO matrix of assembled operator. 1506 // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic. 1507 // 1508 // Note that this (and other assembly routines) currently assume only one active input restriction/basis per operator (could have multiple basis eval 1509 // modes). 1510 // TODO: allow multiple active input restrictions/basis objects 1511 //------------------------------------------------------------------------------ 1512 static int CeedSingleOperatorAssemble_Cuda(CeedOperator op, CeedInt offset, CeedVector values) { 1513 Ceed ceed; 1514 CeedSize values_length = 0, assembled_qf_length = 0; 1515 CeedInt use_ceedsize_idx = 0, num_elem_in, num_elem_out, elem_size_in, elem_size_out; 1516 CeedScalar *values_array; 1517 const CeedScalar *assembled_qf_array; 1518 CeedVector assembled_qf = NULL; 1519 CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out; 1520 CeedRestrictionType rstr_type_in, rstr_type_out; 1521 const bool *orients_in = NULL, *orients_out = NULL; 1522 const CeedInt8 *curl_orients_in = NULL, *curl_orients_out = NULL; 1523 CeedOperator_Cuda *impl; 1524 1525 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1526 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1527 1528 // Assemble QFunction 1529 CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, CEED_REQUEST_IMMEDIATE)); 1530 CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr)); 1531 CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); 1532 1533 CeedCallBackend(CeedVectorGetLength(values, &values_length)); 1534 CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length)); 1535 if ((values_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1; 1536 1537 // Setup 1538 if (!impl->asmb) CeedCallBackend(CeedSingleOperatorAssembleSetup_Cuda(op, use_ceedsize_idx)); 1539 CeedOperatorAssemble_Cuda *asmb = impl->asmb; 1540 1541 assert(asmb != NULL); 1542 1543 // Assemble element operator 1544 CeedCallBackend(CeedVectorGetArray(values, CEED_MEM_DEVICE, &values_array)); 1545 values_array += offset; 1546 1547 CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out)); 1548 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, &num_elem_in)); 1549 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in)); 1550 1551 CeedCallBackend(CeedElemRestrictionGetType(rstr_in, &rstr_type_in)); 1552 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) { 1553 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_in, CEED_MEM_DEVICE, &orients_in)); 1554 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) { 1555 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_in, CEED_MEM_DEVICE, &curl_orients_in)); 1556 } 1557 1558 if (rstr_in != rstr_out) { 1559 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_out, &num_elem_out)); 1560 CeedCheck(num_elem_in == num_elem_out, ceed, CEED_ERROR_UNSUPPORTED, 1561 "Active input and output operator restrictions must have the same number of elements"); 1562 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out)); 1563 1564 CeedCallBackend(CeedElemRestrictionGetType(rstr_out, &rstr_type_out)); 1565 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) { 1566 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_out, CEED_MEM_DEVICE, &orients_out)); 1567 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) { 1568 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_out, CEED_MEM_DEVICE, &curl_orients_out)); 1569 } 1570 } else { 1571 elem_size_out = elem_size_in; 1572 orients_out = orients_in; 1573 curl_orients_out = curl_orients_in; 1574 } 1575 1576 // Compute B^T D B 1577 CeedInt shared_mem = 1578 ((curl_orients_in || curl_orients_out ? elem_size_in * elem_size_out : 0) + (curl_orients_in ? elem_size_in * asmb->block_size_y : 0)) * 1579 sizeof(CeedScalar); 1580 CeedInt grid = CeedDivUpInt(num_elem_in, asmb->elems_per_block); 1581 void *args[] = {(void *)&num_elem_in, &asmb->d_B_in, &asmb->d_B_out, &orients_in, &curl_orients_in, 1582 &orients_out, &curl_orients_out, &assembled_qf_array, &values_array}; 1583 1584 CeedCallBackend( 1585 CeedRunKernelDimShared_Cuda(ceed, asmb->LinearAssemble, grid, asmb->block_size_x, asmb->block_size_y, asmb->elems_per_block, shared_mem, args)); 1586 1587 // Restore arrays 1588 CeedCallBackend(CeedVectorRestoreArray(values, &values_array)); 1589 CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); 1590 1591 // Cleanup 1592 CeedCallBackend(CeedVectorDestroy(&assembled_qf)); 1593 if (rstr_type_in == CEED_RESTRICTION_ORIENTED) { 1594 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_in, &orients_in)); 1595 } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) { 1596 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_in, &curl_orients_in)); 1597 } 1598 if (rstr_in != rstr_out) { 1599 if (rstr_type_out == CEED_RESTRICTION_ORIENTED) { 1600 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_out, &orients_out)); 1601 } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) { 1602 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_out, &curl_orients_out)); 1603 } 1604 } 1605 return CEED_ERROR_SUCCESS; 1606 } 1607 1608 //------------------------------------------------------------------------------ 1609 // Assemble Linear QFunction AtPoints 1610 //------------------------------------------------------------------------------ 1611 static int CeedOperatorLinearAssembleQFunctionAtPoints_Cuda(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 1612 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "Backend does not implement CeedOperatorLinearAssembleQFunction"); 1613 } 1614 1615 //------------------------------------------------------------------------------ 1616 // Assemble Linear Diagonal AtPoints 1617 //------------------------------------------------------------------------------ 1618 static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1619 CeedInt max_num_points, num_elem, num_input_fields, num_output_fields; 1620 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {NULL}; 1621 CeedQFunctionField *qf_input_fields, *qf_output_fields; 1622 CeedQFunction qf; 1623 CeedOperatorField *op_input_fields, *op_output_fields; 1624 CeedOperator_Cuda *impl; 1625 1626 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1627 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1628 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 1629 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 1630 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 1631 CeedInt num_points[num_elem]; 1632 1633 // Setup 1634 CeedCallBackend(CeedOperatorSetupAtPoints_Cuda(op)); 1635 max_num_points = impl->max_num_points; 1636 for (CeedInt i = 0; i < num_elem; i++) num_points[i] = max_num_points; 1637 1638 // Create separate output e-vecs 1639 if (impl->has_shared_e_vecs) { 1640 for (CeedInt i = 0; i < impl->num_outputs; i++) { 1641 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 1642 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[impl->num_inputs + i])); 1643 } 1644 CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out, 1645 num_input_fields, num_output_fields, max_num_points, num_elem)); 1646 } 1647 impl->has_shared_e_vecs = false; 1648 1649 // Input Evecs and Restriction 1650 CeedCallBackend(CeedOperatorSetupInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 1651 1652 // Get point coordinates 1653 if (!impl->point_coords_elem) { 1654 CeedVector point_coords = NULL; 1655 CeedElemRestriction rstr_points = NULL; 1656 1657 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 1658 CeedCallBackend(CeedElemRestrictionCreateVector(rstr_points, NULL, &impl->point_coords_elem)); 1659 CeedCallBackend(CeedElemRestrictionApply(rstr_points, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 1660 } 1661 1662 // Clear active input Qvecs 1663 for (CeedInt i = 0; i < num_input_fields; i++) { 1664 CeedVector vec; 1665 1666 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1667 if (vec != CEED_VECTOR_ACTIVE) continue; 1668 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 1669 } 1670 1671 // Input basis apply if needed 1672 CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(num_elem, num_points, qf_input_fields, op_input_fields, num_input_fields, true, e_data, impl)); 1673 1674 // Output pointers, as necessary 1675 for (CeedInt i = 0; i < num_output_fields; i++) { 1676 CeedEvalMode eval_mode; 1677 1678 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1679 if (eval_mode == CEED_EVAL_NONE) { 1680 // Set the output Q-Vector to use the E-Vector data directly. 1681 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[i + impl->num_inputs], CEED_MEM_DEVICE, &e_data[i + num_input_fields])); 1682 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i + num_input_fields])); 1683 } 1684 } 1685 1686 // Loop over active fields 1687 for (CeedInt i = 0; i < num_input_fields; i++) { 1688 bool is_active_at_points = true; 1689 CeedInt elem_size = 1, num_comp_active = 1, e_vec_size = 0; 1690 CeedRestrictionType rstr_type; 1691 CeedVector vec; 1692 CeedElemRestriction elem_rstr; 1693 1694 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1695 // -- Skip non-active input 1696 if (vec != CEED_VECTOR_ACTIVE) continue; 1697 1698 // -- Get active restriction type 1699 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 1700 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 1701 is_active_at_points = rstr_type == CEED_RESTRICTION_POINTS; 1702 if (!is_active_at_points) CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 1703 else elem_size = max_num_points; 1704 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp_active)); 1705 1706 e_vec_size = elem_size * num_comp_active; 1707 for (CeedInt s = 0; s < e_vec_size; s++) { 1708 bool is_active_input = false; 1709 CeedEvalMode eval_mode; 1710 CeedVector vec; 1711 CeedBasis basis; 1712 1713 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1714 // Skip non-active input 1715 is_active_input = vec == CEED_VECTOR_ACTIVE; 1716 if (!is_active_input) continue; 1717 1718 // Update unit vector 1719 if (s == 0) CeedCallBackend(CeedVectorSetValue(impl->e_vecs[i], 0.0)); 1720 else CeedCallBackend(CeedVectorSetValueStrided(impl->e_vecs[i], s - 1, e_vec_size, 0.0)); 1721 CeedCallBackend(CeedVectorSetValueStrided(impl->e_vecs[i], s, e_vec_size, 1.0)); 1722 1723 // Basis action 1724 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 1725 switch (eval_mode) { 1726 case CEED_EVAL_NONE: 1727 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i])); 1728 break; 1729 case CEED_EVAL_INTERP: 1730 case CEED_EVAL_GRAD: 1731 case CEED_EVAL_DIV: 1732 case CEED_EVAL_CURL: 1733 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 1734 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, impl->e_vecs[i], 1735 impl->q_vecs_in[i])); 1736 break; 1737 case CEED_EVAL_WEIGHT: 1738 break; // No action 1739 } 1740 1741 // Q function 1742 CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out)); 1743 1744 // Output basis apply if needed 1745 for (CeedInt j = 0; j < num_output_fields; j++) { 1746 bool is_active_output = false; 1747 CeedInt elem_size = 0; 1748 CeedRestrictionType rstr_type; 1749 CeedEvalMode eval_mode; 1750 CeedVector vec; 1751 CeedElemRestriction elem_rstr; 1752 CeedBasis basis; 1753 1754 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec)); 1755 // ---- Skip non-active output 1756 is_active_output = vec == CEED_VECTOR_ACTIVE; 1757 if (!is_active_output) continue; 1758 1759 // ---- Check if elem size matches 1760 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr)); 1761 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 1762 if (is_active_at_points && rstr_type != CEED_RESTRICTION_POINTS) continue; 1763 if (rstr_type == CEED_RESTRICTION_POINTS) { 1764 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(elem_rstr, &elem_size)); 1765 } else { 1766 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 1767 } 1768 { 1769 CeedInt num_comp = 0; 1770 1771 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp)); 1772 if (e_vec_size != num_comp * elem_size) continue; 1773 } 1774 1775 // Basis action 1776 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode)); 1777 switch (eval_mode) { 1778 case CEED_EVAL_NONE: 1779 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[j + impl->num_inputs], &e_data[j + num_input_fields])); 1780 break; 1781 case CEED_EVAL_INTERP: 1782 case CEED_EVAL_GRAD: 1783 case CEED_EVAL_DIV: 1784 case CEED_EVAL_CURL: 1785 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[j], &basis)); 1786 CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, 1787 impl->q_vecs_out[j], impl->e_vecs[j + impl->num_inputs])); 1788 break; 1789 // LCOV_EXCL_START 1790 case CEED_EVAL_WEIGHT: { 1791 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 1792 // LCOV_EXCL_STOP 1793 } 1794 } 1795 1796 // Mask output e-vec 1797 CeedCallBackend(CeedVectorPointwiseMult(impl->e_vecs[j + impl->num_inputs], impl->e_vecs[i], impl->e_vecs[j + impl->num_inputs])); 1798 1799 // Restrict 1800 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr)); 1801 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[j + impl->num_inputs], assembled, request)); 1802 1803 // Reset q_vec for 1804 if (eval_mode == CEED_EVAL_NONE) { 1805 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs[j + impl->num_inputs], CEED_MEM_DEVICE, &e_data[j + num_input_fields])); 1806 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[j], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[j + num_input_fields])); 1807 } 1808 } 1809 1810 // Reset vec 1811 if (s == e_vec_size - 1 && i != num_input_fields - 1) CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 1812 } 1813 } 1814 1815 // Restore CEED_EVAL_NONE 1816 for (CeedInt i = 0; i < num_output_fields; i++) { 1817 CeedEvalMode eval_mode; 1818 1819 // Get eval_mode 1820 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1821 1822 // Restore evec 1823 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 1824 if (eval_mode == CEED_EVAL_NONE) { 1825 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); 1826 } 1827 } 1828 1829 // Restore input arrays 1830 CeedCallBackend(CeedOperatorRestoreInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 1831 return CEED_ERROR_SUCCESS; 1832 } 1833 1834 //------------------------------------------------------------------------------ 1835 // Create operator 1836 //------------------------------------------------------------------------------ 1837 int CeedOperatorCreate_Cuda(CeedOperator op) { 1838 Ceed ceed; 1839 CeedOperator_Cuda *impl; 1840 1841 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1842 CeedCallBackend(CeedCalloc(1, &impl)); 1843 CeedCallBackend(CeedOperatorSetData(op, impl)); 1844 1845 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Cuda)); 1846 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Cuda)); 1847 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonal_Cuda)); 1848 CeedCallBackend( 1849 CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal", CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda)); 1850 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssemble_Cuda)); 1851 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Cuda)); 1852 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Cuda)); 1853 return CEED_ERROR_SUCCESS; 1854 } 1855 1856 //------------------------------------------------------------------------------ 1857 // Create operator AtPoints 1858 //------------------------------------------------------------------------------ 1859 int CeedOperatorCreateAtPoints_Cuda(CeedOperator op) { 1860 Ceed ceed; 1861 CeedOperator_Cuda *impl; 1862 1863 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1864 CeedCallBackend(CeedCalloc(1, &impl)); 1865 CeedCallBackend(CeedOperatorSetData(op, impl)); 1866 1867 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunctionAtPoints_Cuda)); 1868 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda)); 1869 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAddAtPoints_Cuda)); 1870 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Cuda)); 1871 return CEED_ERROR_SUCCESS; 1872 } 1873 1874 //------------------------------------------------------------------------------ 1875