// Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
// All Rights reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.
16 17 #include <ceed/ceed.h> 18 #include <ceed/backend.h> 19 #include <stdbool.h> 20 #include <stddef.h> 21 #include <stdint.h> 22 #include "ceed-blocked.h" 23 24 //------------------------------------------------------------------------------ 25 // Setup Input/Output Fields 26 //------------------------------------------------------------------------------ 27 static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, 28 CeedOperator op, bool is_input, CeedElemRestriction *blk_restr, 29 CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs, 30 CeedInt start_e, CeedInt num_fields, CeedInt Q) { 31 CeedInt ierr, num_comp, size, P; 32 Ceed ceed; 33 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 34 CeedBasis basis; 35 CeedElemRestriction r; 36 CeedOperatorField *op_fields; 37 CeedQFunctionField *qf_fields; 38 if (is_input) { 39 ierr = CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL); 40 CeedChkBackend(ierr); 41 ierr = CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL); 42 CeedChkBackend(ierr); 43 } else { 44 ierr = CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields); 45 CeedChkBackend(ierr); 46 ierr = CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields); 47 CeedChkBackend(ierr); 48 } 49 const CeedInt blk_size = 8; 50 51 // Loop over fields 52 for (CeedInt i=0; i<num_fields; i++) { 53 CeedEvalMode eval_mode; 54 ierr = CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode); 55 CeedChkBackend(ierr); 56 57 if (eval_mode != CEED_EVAL_WEIGHT) { 58 ierr = CeedOperatorFieldGetElemRestriction(op_fields[i], &r); 59 CeedChkBackend(ierr); 60 ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr); 61 CeedInt num_elem, elem_size, l_size, comp_stride; 62 ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr); 63 ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr); 64 ierr = CeedElemRestrictionGetLVectorSize(r, &l_size); CeedChkBackend(ierr); 65 ierr = 
CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr); 66 67 bool strided; 68 ierr = CeedElemRestrictionIsStrided(r, &strided); CeedChkBackend(ierr); 69 if (strided) { 70 CeedInt strides[3]; 71 ierr = CeedElemRestrictionGetStrides(r, &strides); CeedChkBackend(ierr); 72 ierr = CeedElemRestrictionCreateBlockedStrided(ceed, num_elem, elem_size, 73 blk_size, num_comp, l_size, strides, &blk_restr[i+start_e]); 74 CeedChkBackend(ierr); 75 } else { 76 const CeedInt *offsets = NULL; 77 ierr = CeedElemRestrictionGetOffsets(r, CEED_MEM_HOST, &offsets); 78 CeedChkBackend(ierr); 79 ierr = CeedElemRestrictionGetCompStride(r, &comp_stride); CeedChkBackend(ierr); 80 ierr = CeedElemRestrictionCreateBlocked(ceed, num_elem, elem_size, 81 blk_size, num_comp, comp_stride, 82 l_size, CEED_MEM_HOST, 83 CEED_COPY_VALUES, offsets, 84 &blk_restr[i+start_e]); 85 CeedChkBackend(ierr); 86 ierr = CeedElemRestrictionRestoreOffsets(r, &offsets); CeedChkBackend(ierr); 87 } 88 ierr = CeedElemRestrictionCreateVector(blk_restr[i+start_e], NULL, 89 &e_vecs_full[i+start_e]); 90 CeedChkBackend(ierr); 91 } 92 93 switch(eval_mode) { 94 case CEED_EVAL_NONE: 95 ierr = CeedQFunctionFieldGetSize(qf_fields[i], &size); CeedChkBackend(ierr); 96 ierr = CeedVectorCreate(ceed, Q*size*blk_size, &q_vecs[i]); 97 CeedChkBackend(ierr); 98 break; 99 case CEED_EVAL_INTERP: 100 case CEED_EVAL_GRAD: 101 ierr = CeedOperatorFieldGetBasis(op_fields[i], &basis); CeedChkBackend(ierr); 102 ierr = CeedQFunctionFieldGetSize(qf_fields[i], &size); CeedChkBackend(ierr); 103 ierr = CeedBasisGetNumNodes(basis, &P); CeedChkBackend(ierr); 104 ierr = CeedBasisGetNumComponents(basis, &num_comp); CeedChkBackend(ierr); 105 ierr = CeedVectorCreate(ceed, P*num_comp*blk_size, &e_vecs[i]); 106 CeedChkBackend(ierr); 107 ierr = CeedVectorCreate(ceed, Q*size*blk_size, &q_vecs[i]); 108 CeedChkBackend(ierr); 109 break; 110 case CEED_EVAL_WEIGHT: // Only on input fields 111 ierr = CeedOperatorFieldGetBasis(op_fields[i], &basis); 
CeedChkBackend(ierr); 112 ierr = CeedVectorCreate(ceed, Q*blk_size, &q_vecs[i]); CeedChkBackend(ierr); 113 ierr = CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, 114 CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i]); 115 CeedChkBackend(ierr); 116 117 break; 118 case CEED_EVAL_DIV: 119 break; // Not implemented 120 case CEED_EVAL_CURL: 121 break; // Not implemented 122 } 123 } 124 return CEED_ERROR_SUCCESS; 125 } 126 127 //------------------------------------------------------------------------------ 128 // Setup Operator 129 //------------------------------------------------------------------------------ 130 static int CeedOperatorSetup_Blocked(CeedOperator op) { 131 int ierr; 132 bool is_setup_done; 133 ierr = CeedOperatorIsSetupDone(op, &is_setup_done); CeedChkBackend(ierr); 134 if (is_setup_done) return CEED_ERROR_SUCCESS; 135 Ceed ceed; 136 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 137 CeedOperator_Blocked *impl; 138 ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr); 139 CeedQFunction qf; 140 ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr); 141 CeedInt Q, num_input_fields, num_output_fields; 142 ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr); 143 ierr = CeedQFunctionIsIdentity(qf, &impl->is_identity_qf); CeedChkBackend(ierr); 144 CeedOperatorField *op_input_fields, *op_output_fields; 145 ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, 146 &num_output_fields, &op_output_fields); 147 CeedChkBackend(ierr); 148 CeedQFunctionField *qf_input_fields, *qf_output_fields; 149 ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, 150 &qf_output_fields); 151 CeedChkBackend(ierr); 152 153 // Allocate 154 ierr = CeedCalloc(num_input_fields + num_output_fields, &impl->blk_restr); 155 CeedChkBackend(ierr); 156 ierr = CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full); 157 CeedChkBackend(ierr); 158 159 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->input_states); 
CeedChkBackend(ierr); 160 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in); CeedChkBackend(ierr); 161 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out); CeedChkBackend(ierr); 162 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in); CeedChkBackend(ierr); 163 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out); CeedChkBackend(ierr); 164 165 impl->num_inputs = num_input_fields; 166 impl->num_outputs = num_output_fields; 167 168 // Set up infield and outfield pointer arrays 169 // Infields 170 ierr = CeedOperatorSetupFields_Blocked(qf, op, true, impl->blk_restr, 171 impl->e_vecs_full, impl->e_vecs_in, 172 impl->q_vecs_in, 0, 173 num_input_fields, Q); 174 CeedChkBackend(ierr); 175 // Outfields 176 ierr = CeedOperatorSetupFields_Blocked(qf, op, false, impl->blk_restr, 177 impl->e_vecs_full, impl->e_vecs_out, 178 impl->q_vecs_out, num_input_fields, 179 num_output_fields, Q); 180 CeedChkBackend(ierr); 181 182 // Identity QFunctions 183 if (impl->is_identity_qf) { 184 CeedEvalMode in_mode, out_mode; 185 CeedQFunctionField *in_fields, *out_fields; 186 ierr = CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields); 187 CeedChkBackend(ierr); 188 ierr = CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode); 189 CeedChkBackend(ierr); 190 ierr = CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode); 191 CeedChkBackend(ierr); 192 193 if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) { 194 impl->is_identity_restr_op = true; 195 } else { 196 ierr = CeedVectorDestroy(&impl->q_vecs_out[0]); CeedChkBackend(ierr); 197 impl->q_vecs_out[0] = impl->q_vecs_in[0]; 198 ierr = CeedVectorAddReference(impl->q_vecs_in[0]); CeedChkBackend(ierr); 199 } 200 } 201 202 ierr = CeedOperatorSetSetupDone(op); CeedChkBackend(ierr); 203 204 return CEED_ERROR_SUCCESS; 205 } 206 207 //------------------------------------------------------------------------------ 208 // Setup Operator Inputs 209 //------------------------------------------------------------------------------ 210 
static inline int CeedOperatorSetupInputs_Blocked(CeedInt num_input_fields, 211 CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 212 CeedVector in_vec, bool skip_active, CeedScalar *e_data_full[2*CEED_FIELD_MAX], 213 CeedOperator_Blocked *impl, CeedRequest *request) { 214 CeedInt ierr; 215 CeedEvalMode eval_mode; 216 CeedVector vec; 217 uint64_t state; 218 219 for (CeedInt i=0; i<num_input_fields; i++) { 220 // Get input vector 221 ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec); 222 CeedChkBackend(ierr); 223 if (vec == CEED_VECTOR_ACTIVE) { 224 if (skip_active) 225 continue; 226 else 227 vec = in_vec; 228 } 229 230 ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode); 231 CeedChkBackend(ierr); 232 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 233 } else { 234 // Restrict 235 ierr = CeedVectorGetState(vec, &state); CeedChkBackend(ierr); 236 if (state != impl->input_states[i] || vec == in_vec) { 237 ierr = CeedElemRestrictionApply(impl->blk_restr[i], CEED_NOTRANSPOSE, 238 vec, impl->e_vecs_full[i], request); 239 CeedChkBackend(ierr); 240 impl->input_states[i] = state; 241 } 242 // Get evec 243 ierr = CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, 244 (const CeedScalar **) &e_data_full[i]); 245 CeedChkBackend(ierr); 246 } 247 } 248 return CEED_ERROR_SUCCESS; 249 } 250 251 //------------------------------------------------------------------------------ 252 // Input Basis Action 253 //------------------------------------------------------------------------------ 254 static inline int CeedOperatorInputBasis_Blocked(CeedInt e, CeedInt Q, 255 CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 256 CeedInt num_input_fields, CeedInt blk_size, bool skip_active, 257 CeedScalar *e_data_full[2*CEED_FIELD_MAX], CeedOperator_Blocked *impl) { 258 CeedInt ierr; 259 CeedInt dim, elem_size, size; 260 CeedElemRestriction elem_restr; 261 CeedEvalMode eval_mode; 262 CeedBasis basis; 263 264 for (CeedInt 
i=0; i<num_input_fields; i++) { 265 // Skip active input 266 if (skip_active) { 267 CeedVector vec; 268 ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec); 269 CeedChkBackend(ierr); 270 if (vec == CEED_VECTOR_ACTIVE) 271 continue; 272 } 273 274 // Get elem_size, eval_mode, size 275 ierr = CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr); 276 CeedChkBackend(ierr); 277 ierr = CeedElemRestrictionGetElementSize(elem_restr, &elem_size); 278 CeedChkBackend(ierr); 279 ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode); 280 CeedChkBackend(ierr); 281 ierr = CeedQFunctionFieldGetSize(qf_input_fields[i], &size); 282 CeedChkBackend(ierr); 283 // Basis action 284 switch(eval_mode) { 285 case CEED_EVAL_NONE: 286 ierr = CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, 287 CEED_USE_POINTER, &e_data_full[i][e*Q*size]); 288 CeedChkBackend(ierr); 289 break; 290 case CEED_EVAL_INTERP: 291 ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis); 292 CeedChkBackend(ierr); 293 ierr = CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, 294 CEED_USE_POINTER, &e_data_full[i][e*elem_size*size]); 295 CeedChkBackend(ierr); 296 ierr = CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, 297 CEED_EVAL_INTERP, impl->e_vecs_in[i], 298 impl->q_vecs_in[i]); CeedChkBackend(ierr); 299 break; 300 case CEED_EVAL_GRAD: 301 ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis); 302 CeedChkBackend(ierr); 303 ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr); 304 ierr = CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, 305 CEED_USE_POINTER, &e_data_full[i][e*elem_size*size/dim]); 306 CeedChkBackend(ierr); 307 ierr = CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, 308 CEED_EVAL_GRAD, impl->e_vecs_in[i], 309 impl->q_vecs_in[i]); CeedChkBackend(ierr); 310 break; 311 case CEED_EVAL_WEIGHT: 312 break; // No action 313 // LCOV_EXCL_START 314 case CEED_EVAL_DIV: 315 case CEED_EVAL_CURL: { 316 ierr = 
CeedOperatorFieldGetBasis(op_input_fields[i], &basis); 317 CeedChkBackend(ierr); 318 Ceed ceed; 319 ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr); 320 return CeedError(ceed, CEED_ERROR_BACKEND, 321 "Ceed evaluation mode not implemented"); 322 // LCOV_EXCL_STOP 323 } 324 } 325 } 326 return CEED_ERROR_SUCCESS; 327 } 328 329 //------------------------------------------------------------------------------ 330 // Output Basis Action 331 //------------------------------------------------------------------------------ 332 static inline int CeedOperatorOutputBasis_Blocked(CeedInt e, CeedInt Q, 333 CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields, 334 CeedInt blk_size, CeedInt num_input_fields, CeedInt num_output_fields, 335 CeedOperator op, CeedScalar *e_data_full[2*CEED_FIELD_MAX], 336 CeedOperator_Blocked *impl) { 337 CeedInt ierr; 338 CeedInt dim, elem_size, size; 339 CeedElemRestriction elem_restr; 340 CeedEvalMode eval_mode; 341 CeedBasis basis; 342 343 for (CeedInt i=0; i<num_output_fields; i++) { 344 // Get elem_size, eval_mode, size 345 ierr = CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr); 346 CeedChkBackend(ierr); 347 ierr = CeedElemRestrictionGetElementSize(elem_restr, &elem_size); 348 CeedChkBackend(ierr); 349 ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode); 350 CeedChkBackend(ierr); 351 ierr = CeedQFunctionFieldGetSize(qf_output_fields[i], &size); 352 CeedChkBackend(ierr); 353 // Basis action 354 switch(eval_mode) { 355 case CEED_EVAL_NONE: 356 break; // No action 357 case CEED_EVAL_INTERP: 358 ierr = CeedOperatorFieldGetBasis(op_output_fields[i], &basis); 359 CeedChkBackend(ierr); 360 ierr = CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST, 361 CEED_USE_POINTER, &e_data_full[i + num_input_fields][e*elem_size*size]); 362 CeedChkBackend(ierr); 363 ierr = CeedBasisApply(basis, blk_size, CEED_TRANSPOSE, 364 CEED_EVAL_INTERP, impl->q_vecs_out[i], 365 impl->e_vecs_out[i]); 
CeedChkBackend(ierr); 366 break; 367 case CEED_EVAL_GRAD: 368 ierr = CeedOperatorFieldGetBasis(op_output_fields[i], &basis); 369 CeedChkBackend(ierr); 370 ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr); 371 ierr = CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST, 372 CEED_USE_POINTER, &e_data_full[i + num_input_fields][e*elem_size*size/dim]); 373 CeedChkBackend(ierr); 374 ierr = CeedBasisApply(basis, blk_size, CEED_TRANSPOSE, 375 CEED_EVAL_GRAD, impl->q_vecs_out[i], 376 impl->e_vecs_out[i]); CeedChkBackend(ierr); 377 break; 378 // LCOV_EXCL_START 379 case CEED_EVAL_WEIGHT: { 380 Ceed ceed; 381 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 382 return CeedError(ceed, CEED_ERROR_BACKEND, 383 "CEED_EVAL_WEIGHT cannot be an output " 384 "evaluation mode"); 385 } 386 case CEED_EVAL_DIV: 387 case CEED_EVAL_CURL: { 388 Ceed ceed; 389 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 390 return CeedError(ceed, CEED_ERROR_BACKEND, 391 "Ceed evaluation mode not implemented"); 392 // LCOV_EXCL_STOP 393 } 394 } 395 } 396 return CEED_ERROR_SUCCESS; 397 } 398 399 //------------------------------------------------------------------------------ 400 // Restore Input Vectors 401 //------------------------------------------------------------------------------ 402 static inline int CeedOperatorRestoreInputs_Blocked(CeedInt num_input_fields, 403 CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 404 bool skip_active, CeedScalar *e_data_full[2*CEED_FIELD_MAX], 405 CeedOperator_Blocked *impl) { 406 CeedInt ierr; 407 CeedEvalMode eval_mode; 408 409 for (CeedInt i=0; i<num_input_fields; i++) { 410 // Skip active inputs 411 if (skip_active) { 412 CeedVector vec; 413 ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec); 414 CeedChkBackend(ierr); 415 if (vec == CEED_VECTOR_ACTIVE) 416 continue; 417 } 418 ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode); 419 CeedChkBackend(ierr); 420 if (eval_mode 
== CEED_EVAL_WEIGHT) { // Skip 421 } else { 422 ierr = CeedVectorRestoreArrayRead(impl->e_vecs_full[i], 423 (const CeedScalar **) &e_data_full[i]); 424 CeedChkBackend(ierr); 425 } 426 } 427 return CEED_ERROR_SUCCESS; 428 } 429 430 //------------------------------------------------------------------------------ 431 // Operator Apply 432 //------------------------------------------------------------------------------ 433 static int CeedOperatorApplyAdd_Blocked(CeedOperator op, CeedVector in_vec, 434 CeedVector out_vec, 435 CeedRequest *request) { 436 int ierr; 437 CeedOperator_Blocked *impl; 438 ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr); 439 const CeedInt blk_size = 8; 440 CeedInt Q, num_input_fields, num_output_fields, num_elem, size; 441 ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr); 442 ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr); 443 CeedInt num_blks = (num_elem/blk_size) + !!(num_elem%blk_size); 444 CeedQFunction qf; 445 ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr); 446 CeedOperatorField *op_input_fields, *op_output_fields; 447 ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, 448 &num_output_fields, &op_output_fields); 449 CeedChkBackend(ierr); 450 CeedQFunctionField *qf_input_fields, *qf_output_fields; 451 ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, 452 &qf_output_fields); 453 CeedChkBackend(ierr); 454 CeedEvalMode eval_mode; 455 CeedVector vec; 456 CeedScalar *e_data_full[2*CEED_FIELD_MAX] = {0}; 457 458 // Setup 459 ierr = CeedOperatorSetup_Blocked(op); CeedChkBackend(ierr); 460 461 // Restriction only operator 462 if (impl->is_identity_restr_op) { 463 ierr = CeedElemRestrictionApply(impl->blk_restr[0], CEED_NOTRANSPOSE, in_vec, 464 impl->e_vecs_full[0], request); CeedChkBackend(ierr); 465 ierr = CeedElemRestrictionApply(impl->blk_restr[1], CEED_TRANSPOSE, 466 impl->e_vecs_full[0], out_vec, request); CeedChkBackend(ierr); 467 
return CEED_ERROR_SUCCESS; 468 } 469 470 // Input Evecs and Restriction 471 ierr = CeedOperatorSetupInputs_Blocked(num_input_fields, qf_input_fields, 472 op_input_fields, in_vec, false, e_data_full, 473 impl, request); CeedChkBackend(ierr); 474 475 // Output Evecs 476 for (CeedInt i=0; i<num_output_fields; i++) { 477 ierr = CeedVectorGetArrayWrite(impl->e_vecs_full[i+impl->num_inputs], 478 CEED_MEM_HOST, &e_data_full[i + num_input_fields]); 479 CeedChkBackend(ierr); 480 } 481 482 // Loop through elements 483 for (CeedInt e=0; e<num_blks*blk_size; e+=blk_size) { 484 // Output pointers 485 for (CeedInt i=0; i<num_output_fields; i++) { 486 ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode); 487 CeedChkBackend(ierr); 488 if (eval_mode == CEED_EVAL_NONE) { 489 ierr = CeedQFunctionFieldGetSize(qf_output_fields[i], &size); 490 CeedChkBackend(ierr); 491 ierr = CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST, 492 CEED_USE_POINTER, &e_data_full[i + num_input_fields][e*Q*size]); 493 CeedChkBackend(ierr); 494 } 495 } 496 497 // Input basis apply 498 ierr = CeedOperatorInputBasis_Blocked(e, Q, qf_input_fields, op_input_fields, 499 num_input_fields, blk_size, false, e_data_full, 500 impl); CeedChkBackend(ierr); 501 502 // Q function 503 if (!impl->is_identity_qf) { 504 ierr = CeedQFunctionApply(qf, Q*blk_size, impl->q_vecs_in, impl->q_vecs_out); 505 CeedChkBackend(ierr); 506 } 507 508 // Output basis apply 509 ierr = CeedOperatorOutputBasis_Blocked(e, Q, qf_output_fields, op_output_fields, 510 blk_size, num_input_fields, 511 num_output_fields, op, e_data_full, impl); 512 CeedChkBackend(ierr); 513 } 514 515 // Output restriction 516 for (CeedInt i=0; i<num_output_fields; i++) { 517 // Restore evec 518 ierr = CeedVectorRestoreArray(impl->e_vecs_full[i+impl->num_inputs], 519 &e_data_full[i + num_input_fields]); 520 CeedChkBackend(ierr); 521 // Get output vector 522 ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec); 523 CeedChkBackend(ierr); 524 
// Active 525 if (vec == CEED_VECTOR_ACTIVE) 526 vec = out_vec; 527 // Restrict 528 ierr = CeedElemRestrictionApply(impl->blk_restr[i+impl->num_inputs], 529 CEED_TRANSPOSE, impl->e_vecs_full[i+impl->num_inputs], 530 vec, request); CeedChkBackend(ierr); 531 532 } 533 534 // Restore input arrays 535 ierr = CeedOperatorRestoreInputs_Blocked(num_input_fields, qf_input_fields, 536 op_input_fields, false, e_data_full, impl); CeedChkBackend(ierr); 537 538 return CEED_ERROR_SUCCESS; 539 } 540 541 //------------------------------------------------------------------------------ 542 // Core code for assembling linear QFunction 543 //------------------------------------------------------------------------------ 544 static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked( 545 CeedOperator op, bool build_objects, CeedVector *assembled, 546 CeedElemRestriction *rstr, CeedRequest *request) { 547 int ierr; 548 CeedOperator_Blocked *impl; 549 ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr); 550 const CeedInt blk_size = 8; 551 CeedInt Q, num_input_fields, num_output_fields, num_elem, size; 552 ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr); 553 ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr); 554 CeedInt num_blks = (num_elem/blk_size) + !!(num_elem%blk_size); 555 CeedQFunction qf; 556 ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr); 557 CeedOperatorField *op_input_fields, *op_output_fields; 558 ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, 559 &num_output_fields, &op_output_fields); 560 CeedChkBackend(ierr); 561 CeedQFunctionField *qf_input_fields, *qf_output_fields; 562 ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, 563 &qf_output_fields); 564 CeedChkBackend(ierr); 565 CeedVector vec, l_vec = impl->qf_l_vec; 566 CeedInt num_active_in = impl->num_active_in, 567 num_active_out = impl->num_active_out; 568 CeedVector *active_in = impl->qf_active_in; 569 
CeedScalar *a, *tmp; 570 Ceed ceed; 571 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 572 CeedScalar *e_data_full[2*CEED_FIELD_MAX] = {0}; 573 574 // Setup 575 ierr = CeedOperatorSetup_Blocked(op); CeedChkBackend(ierr); 576 577 // Check for identity 578 if (impl->is_identity_qf) 579 // LCOV_EXCL_START 580 return CeedError(ceed, CEED_ERROR_BACKEND, 581 "Assembling identity QFunctions not supported"); 582 // LCOV_EXCL_STOP 583 584 // Input Evecs and Restriction 585 ierr = CeedOperatorSetupInputs_Blocked(num_input_fields, qf_input_fields, 586 op_input_fields, NULL, true, e_data_full, 587 impl, request); CeedChkBackend(ierr); 588 589 // Count number of active input fields 590 if (!num_active_in) { 591 for (CeedInt i=0; i<num_input_fields; i++) { 592 // Get input vector 593 ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec); 594 CeedChkBackend(ierr); 595 // Check if active input 596 if (vec == CEED_VECTOR_ACTIVE) { 597 ierr = CeedQFunctionFieldGetSize(qf_input_fields[i], &size); 598 CeedChkBackend(ierr); 599 ierr = CeedVectorSetValue(impl->q_vecs_in[i], 0.0); CeedChkBackend(ierr); 600 ierr = CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &tmp); 601 CeedChkBackend(ierr); 602 ierr = CeedRealloc(num_active_in + size, &active_in); CeedChkBackend(ierr); 603 for (CeedInt field=0; field<size; field++) { 604 ierr = CeedVectorCreate(ceed, Q*blk_size, &active_in[num_active_in+field]); 605 CeedChkBackend(ierr); 606 ierr = CeedVectorSetArray(active_in[num_active_in+field], CEED_MEM_HOST, 607 CEED_USE_POINTER, &tmp[field*Q*blk_size]); 608 CeedChkBackend(ierr); 609 } 610 num_active_in += size; 611 ierr = CeedVectorRestoreArray(impl->q_vecs_in[i], &tmp); CeedChkBackend(ierr); 612 } 613 } 614 impl->num_active_in = num_active_in; 615 impl->qf_active_in = active_in; 616 } 617 618 // Count number of active output fields 619 if (!num_active_out) { 620 for (CeedInt i=0; i<num_output_fields; i++) { 621 // Get output vector 622 ierr = 
CeedOperatorFieldGetVector(op_output_fields[i], &vec); 623 CeedChkBackend(ierr); 624 // Check if active output 625 if (vec == CEED_VECTOR_ACTIVE) { 626 ierr = CeedQFunctionFieldGetSize(qf_output_fields[i], &size); 627 CeedChkBackend(ierr); 628 num_active_out += size; 629 } 630 } 631 impl->num_active_out = num_active_out; 632 } 633 634 // Check sizes 635 if (!num_active_in || !num_active_out) 636 // LCOV_EXCL_START 637 return CeedError(ceed, CEED_ERROR_BACKEND, 638 "Cannot assemble QFunction without active inputs " 639 "and outputs"); 640 // LCOV_EXCL_STOP 641 642 // Setup Lvec 643 if (!l_vec) { 644 ierr = CeedVectorCreate(ceed, num_blks*blk_size*Q*num_active_in*num_active_out, 645 &l_vec); CeedChkBackend(ierr); 646 impl->qf_l_vec = l_vec; 647 } 648 ierr = CeedVectorGetArrayWrite(l_vec, CEED_MEM_HOST, &a); CeedChkBackend(ierr); 649 650 // Build objects if needed 651 CeedInt strides[3] = {1, Q, num_active_in *num_active_out*Q}; 652 if (build_objects) { 653 // Create output restriction 654 ierr = CeedElemRestrictionCreateStrided(ceed, num_elem, Q, 655 num_active_in*num_active_out, 656 num_active_in*num_active_out*num_elem*Q, 657 strides, rstr); CeedChkBackend(ierr); 658 // Create assembled vector 659 ierr = CeedVectorCreate(ceed, num_elem*Q*num_active_in*num_active_out, 660 assembled); CeedChkBackend(ierr); 661 } 662 663 // Loop through elements 664 for (CeedInt e=0; e<num_blks*blk_size; e+=blk_size) { 665 // Input basis apply 666 ierr = CeedOperatorInputBasis_Blocked(e, Q, qf_input_fields, op_input_fields, 667 num_input_fields, blk_size, true, e_data_full, 668 impl); CeedChkBackend(ierr); 669 670 // Assemble QFunction 671 for (CeedInt in=0; in<num_active_in; in++) { 672 // Set Inputs 673 ierr = CeedVectorSetValue(active_in[in], 1.0); CeedChkBackend(ierr); 674 if (num_active_in > 1) { 675 ierr = CeedVectorSetValue(active_in[(in+num_active_in-1)%num_active_in], 676 0.0); CeedChkBackend(ierr); 677 } 678 // Set Outputs 679 for (CeedInt out=0; out<num_output_fields; 
out++) { 680 // Get output vector 681 ierr = CeedOperatorFieldGetVector(op_output_fields[out], &vec); 682 CeedChkBackend(ierr); 683 // Check if active output 684 if (vec == CEED_VECTOR_ACTIVE) { 685 CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, 686 CEED_USE_POINTER, a); CeedChkBackend(ierr); 687 ierr = CeedQFunctionFieldGetSize(qf_output_fields[out], &size); 688 CeedChkBackend(ierr); 689 a += size*Q*blk_size; // Advance the pointer by the size of the output 690 } 691 } 692 // Apply QFunction 693 ierr = CeedQFunctionApply(qf, Q*blk_size, impl->q_vecs_in, impl->q_vecs_out); 694 CeedChkBackend(ierr); 695 } 696 } 697 698 // Un-set output Qvecs to prevent accidental overwrite of Assembled 699 for (CeedInt out=0; out<num_output_fields; out++) { 700 // Get output vector 701 ierr = CeedOperatorFieldGetVector(op_output_fields[out], &vec); 702 CeedChkBackend(ierr); 703 // Check if active output 704 if (vec == CEED_VECTOR_ACTIVE) { 705 CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_COPY_VALUES, 706 NULL); CeedChkBackend(ierr); 707 } 708 } 709 710 // Restore input arrays 711 ierr = CeedOperatorRestoreInputs_Blocked(num_input_fields, qf_input_fields, 712 op_input_fields, true, e_data_full, impl); CeedChkBackend(ierr); 713 714 // Output blocked restriction 715 ierr = CeedVectorRestoreArray(l_vec, &a); CeedChkBackend(ierr); 716 ierr = CeedVectorSetValue(*assembled, 0.0); CeedChkBackend(ierr); 717 CeedElemRestriction blk_rstr = impl->qf_blk_rstr; 718 if (!impl->qf_blk_rstr) { 719 ierr = CeedElemRestrictionCreateBlockedStrided(ceed, num_elem, Q, blk_size, 720 num_active_in*num_active_out, num_active_in*num_active_out*num_elem*Q, 721 strides, &blk_rstr); CeedChkBackend(ierr); 722 impl->qf_blk_rstr = blk_rstr; 723 } 724 ierr = CeedElemRestrictionApply(blk_rstr, CEED_TRANSPOSE, l_vec, *assembled, 725 request); CeedChkBackend(ierr); 726 727 return CEED_ERROR_SUCCESS; 728 } 729 730 //------------------------------------------------------------------------------ 
731 // Assemble Linear QFunction 732 //------------------------------------------------------------------------------ 733 static int CeedOperatorLinearAssembleQFunction_Blocked(CeedOperator op, 734 CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 735 return CeedOperatorLinearAssembleQFunctionCore_Blocked(op, true, assembled, 736 rstr, request); 737 } 738 739 //------------------------------------------------------------------------------ 740 // Update Assembled Linear QFunction 741 //------------------------------------------------------------------------------ 742 static int CeedOperatorLinearAssembleQFunctionUpdate_Blocked(CeedOperator op, 743 CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 744 return CeedOperatorLinearAssembleQFunctionCore_Blocked(op, false, &assembled, 745 &rstr, request); 746 } 747 748 //------------------------------------------------------------------------------ 749 // Operator Destroy 750 //------------------------------------------------------------------------------ 751 static int CeedOperatorDestroy_Blocked(CeedOperator op) { 752 int ierr; 753 CeedOperator_Blocked *impl; 754 ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr); 755 756 for (CeedInt i=0; i<impl->num_inputs+impl->num_outputs; i++) { 757 ierr = CeedElemRestrictionDestroy(&impl->blk_restr[i]); CeedChkBackend(ierr); 758 ierr = CeedVectorDestroy(&impl->e_vecs_full[i]); CeedChkBackend(ierr); 759 } 760 ierr = CeedFree(&impl->blk_restr); CeedChkBackend(ierr); 761 ierr = CeedFree(&impl->e_vecs_full); CeedChkBackend(ierr); 762 ierr = CeedFree(&impl->input_states); CeedChkBackend(ierr); 763 764 for (CeedInt i=0; i<impl->num_inputs; i++) { 765 ierr = CeedVectorDestroy(&impl->e_vecs_in[i]); CeedChkBackend(ierr); 766 ierr = CeedVectorDestroy(&impl->q_vecs_in[i]); CeedChkBackend(ierr); 767 } 768 ierr = CeedFree(&impl->e_vecs_in); CeedChkBackend(ierr); 769 ierr = CeedFree(&impl->q_vecs_in); CeedChkBackend(ierr); 770 771 for 
(CeedInt i=0; i<impl->num_outputs; i++) { 772 ierr = CeedVectorDestroy(&impl->e_vecs_out[i]); CeedChkBackend(ierr); 773 ierr = CeedVectorDestroy(&impl->q_vecs_out[i]); CeedChkBackend(ierr); 774 } 775 ierr = CeedFree(&impl->e_vecs_out); CeedChkBackend(ierr); 776 ierr = CeedFree(&impl->q_vecs_out); CeedChkBackend(ierr); 777 778 // QFunction assembly data 779 for (CeedInt i=0; i<impl->num_active_in; i++) { 780 ierr = CeedVectorDestroy(&impl->qf_active_in[i]); CeedChkBackend(ierr); 781 } 782 ierr = CeedFree(&impl->qf_active_in); CeedChkBackend(ierr); 783 ierr = CeedVectorDestroy(&impl->qf_l_vec); CeedChkBackend(ierr); 784 ierr = CeedElemRestrictionDestroy(&impl->qf_blk_rstr); CeedChkBackend(ierr); 785 786 ierr = CeedFree(&impl); CeedChkBackend(ierr); 787 return CEED_ERROR_SUCCESS; 788 } 789 790 //------------------------------------------------------------------------------ 791 // Operator Create 792 //------------------------------------------------------------------------------ 793 int CeedOperatorCreate_Blocked(CeedOperator op) { 794 int ierr; 795 Ceed ceed; 796 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 797 CeedOperator_Blocked *impl; 798 799 ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr); 800 ierr = CeedOperatorSetData(op, impl); CeedChkBackend(ierr); 801 802 ierr = CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", 803 CeedOperatorLinearAssembleQFunction_Blocked); 804 CeedChkBackend(ierr); 805 ierr = CeedSetBackendFunction(ceed, "Operator", op, 806 "LinearAssembleQFunctionUpdate", 807 CeedOperatorLinearAssembleQFunctionUpdate_Blocked); 808 CeedChkBackend(ierr); 809 ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", 810 CeedOperatorApplyAdd_Blocked); CeedChkBackend(ierr); 811 ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy", 812 CeedOperatorDestroy_Blocked); CeedChkBackend(ierr); 813 return CEED_ERROR_SUCCESS; 814 } 815 
//------------------------------------------------------------------------------ 816