1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/ceed.h> 9 #include <ceed/backend.h> 10 #include <stdbool.h> 11 #include <stdint.h> 12 #include <string.h> 13 #include "ceed-opt.h" 14 15 //------------------------------------------------------------------------------ 16 // Setup Input/Output Fields 17 //------------------------------------------------------------------------------ 18 static int CeedOperatorSetupFields_Opt(CeedQFunction qf, CeedOperator op, 19 bool is_input, const CeedInt blk_size, 20 CeedElemRestriction *blk_restr, 21 CeedVector *e_vecs_full, CeedVector *e_vecs, 22 CeedVector *q_vecs, CeedInt start_e, 23 CeedInt num_fields, CeedInt Q) { 24 CeedInt ierr, num_comp, size, P; 25 CeedSize e_size, q_size; 26 Ceed ceed; 27 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 28 CeedBasis basis; 29 CeedElemRestriction r; 30 CeedOperatorField *op_fields; 31 CeedQFunctionField *qf_fields; 32 if (is_input) { 33 ierr = CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL); 34 CeedChkBackend(ierr); 35 ierr = CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL); 36 CeedChkBackend(ierr); 37 } else { 38 ierr = CeedOperatorGetFields(op, NULL, NULL, NULL,&op_fields); 39 CeedChkBackend(ierr); 40 ierr = CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields); 41 CeedChkBackend(ierr); 42 } 43 44 // Loop over fields 45 for (CeedInt i=0; i<num_fields; i++) { 46 CeedEvalMode eval_mode; 47 ierr = CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode); 48 CeedChkBackend(ierr); 49 50 if (eval_mode != CEED_EVAL_WEIGHT) { 51 ierr = CeedOperatorFieldGetElemRestriction(op_fields[i], &r); 52 CeedChkBackend(ierr); 53 Ceed ceed; 54 ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr); 55 CeedSize l_size; 56 CeedInt num_elem, elem_size, comp_stride; 57 ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr); 58 ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr); 59 ierr = CeedElemRestrictionGetLVectorSize(r, &l_size); CeedChkBackend(ierr); 60 ierr = CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr); 61 62 bool strided; 63 ierr = CeedElemRestrictionIsStrided(r, &strided); CeedChkBackend(ierr); 64 if (strided) { 65 CeedInt strides[3]; 66 ierr = CeedElemRestrictionGetStrides(r, &strides); CeedChkBackend(ierr); 67 ierr = CeedElemRestrictionCreateBlockedStrided(ceed, num_elem, elem_size, 68 blk_size, num_comp, l_size, strides, &blk_restr[i+start_e]); 69 CeedChkBackend(ierr); 70 } else { 71 const CeedInt *offsets = NULL; 72 ierr = CeedElemRestrictionGetOffsets(r, CEED_MEM_HOST, &offsets); 73 CeedChkBackend(ierr); 74 ierr = CeedElemRestrictionGetCompStride(r, &comp_stride); CeedChkBackend(ierr); 75 ierr = CeedElemRestrictionCreateBlocked(ceed, num_elem, elem_size, 76 blk_size, num_comp, comp_stride, 77 l_size, CEED_MEM_HOST, 78 CEED_COPY_VALUES, offsets, 79 &blk_restr[i+start_e]); 80 CeedChkBackend(ierr); 81 ierr = CeedElemRestrictionRestoreOffsets(r, &offsets); CeedChkBackend(ierr); 82 } 83 ierr = CeedElemRestrictionCreateVector(blk_restr[i+start_e], NULL, 84 &e_vecs_full[i+start_e]); 85 CeedChkBackend(ierr); 86 } 87 88 switch(eval_mode) { 89 case CEED_EVAL_NONE: 90 ierr = CeedQFunctionFieldGetSize(qf_fields[i], &size); CeedChkBackend(ierr); 91 e_size = (CeedSize)Q*size*blk_size; 92 ierr = CeedVectorCreate(ceed, e_size, &e_vecs[i]); CeedChkBackend(ierr); 93 q_size = (CeedSize)Q*size*blk_size; 94 ierr = CeedVectorCreate(ceed, q_size, &q_vecs[i]); CeedChkBackend(ierr); 95 break; 96 case CEED_EVAL_INTERP: 97 case CEED_EVAL_GRAD: 98 ierr = CeedOperatorFieldGetBasis(op_fields[i], &basis); CeedChkBackend(ierr); 99 ierr = CeedQFunctionFieldGetSize(qf_fields[i], &size); CeedChkBackend(ierr); 100 ierr = CeedBasisGetNumNodes(basis, &P); CeedChkBackend(ierr); 101 ierr = CeedBasisGetNumComponents(basis, &num_comp); CeedChkBackend(ierr); 102 e_size = (CeedSize)P*num_comp*blk_size; 103 ierr = CeedVectorCreate(ceed, e_size, &e_vecs[i]); CeedChkBackend(ierr); 104 q_size = (CeedSize)Q*size*blk_size; 105 ierr = CeedVectorCreate(ceed, q_size, &q_vecs[i]); CeedChkBackend(ierr); 106 break; 107 case CEED_EVAL_WEIGHT: // Only on input fields 108 ierr = CeedOperatorFieldGetBasis(op_fields[i], &basis); CeedChkBackend(ierr); 109 q_size = (CeedSize)Q*blk_size; 110 ierr = CeedVectorCreate(ceed, q_size, &q_vecs[i]); CeedChkBackend(ierr); 111 ierr = CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, 112 CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i]); 113 CeedChkBackend(ierr); 114 115 break; 116 case CEED_EVAL_DIV: 117 break; // Not implemented 118 case CEED_EVAL_CURL: 119 break; // Not implemented 120 } 121 if (is_input && !!e_vecs[i]) { 122 ierr = CeedVectorSetArray(e_vecs[i], CEED_MEM_HOST, 123 CEED_COPY_VALUES, NULL); CeedChkBackend(ierr); 124 } 125 } 126 return CEED_ERROR_SUCCESS; 127 } 128 129 //------------------------------------------------------------------------------ 130 // Setup Operator 131 //------------------------------------------------------------------------------ 132 static int CeedOperatorSetup_Opt(CeedOperator op) { 133 int ierr; 134 bool is_setup_done; 135 ierr = CeedOperatorIsSetupDone(op, &is_setup_done); CeedChkBackend(ierr); 136 if (is_setup_done) return CEED_ERROR_SUCCESS; 137 Ceed ceed; 138 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 139 Ceed_Opt *ceed_impl; 140 ierr = CeedGetData(ceed, &ceed_impl); CeedChkBackend(ierr); 141 const CeedInt blk_size = ceed_impl->blk_size; 142 CeedOperator_Opt *impl; 143 ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr); 144 CeedQFunction qf; 145 ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr); 146 CeedInt Q, num_input_fields, num_output_fields; 147 ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr); 148 ierr = CeedQFunctionIsIdentity(qf, &impl->is_identity_qf); CeedChkBackend(ierr); 149 CeedOperatorField *op_input_fields, *op_output_fields; 150 ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, 151 &num_output_fields, &op_output_fields); 152 CeedChkBackend(ierr); 153 CeedQFunctionField *qf_input_fields, *qf_output_fields; 154 ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, 155 &qf_output_fields); 156 CeedChkBackend(ierr); 157 158 // Allocate 159 ierr = CeedCalloc(num_input_fields + num_output_fields, &impl->blk_restr); 160 CeedChkBackend(ierr); 161 ierr = CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full); 162 CeedChkBackend(ierr); 163 164 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->input_states); CeedChkBackend(ierr); 165 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in); CeedChkBackend(ierr); 166 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out); CeedChkBackend(ierr); 167 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in); CeedChkBackend(ierr); 168 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out); CeedChkBackend(ierr); 169 170 impl->num_inputs = num_input_fields; 171 impl->num_outputs = num_output_fields; 172 173 // Set up infield and outfield pointer arrays 174 // Infields 175 ierr = CeedOperatorSetupFields_Opt(qf, op, true, blk_size, impl->blk_restr, 176 impl->e_vecs_full, impl->e_vecs_in, 177 impl->q_vecs_in, 0, num_input_fields, Q); 178 CeedChkBackend(ierr); 179 // Outfields 180 ierr = CeedOperatorSetupFields_Opt(qf, op, false, blk_size, impl->blk_restr, 181 impl->e_vecs_full, impl->e_vecs_out, 182 impl->q_vecs_out, num_input_fields, 183 num_output_fields, Q); 184 CeedChkBackend(ierr); 185 186 // Identity QFunctions 187 if (impl->is_identity_qf) { 188 CeedEvalMode in_mode, out_mode; 189 CeedQFunctionField *in_fields, *out_fields; 190 ierr = CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields); 191 CeedChkBackend(ierr); 192 ierr = CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode); 193 CeedChkBackend(ierr); 194 ierr = CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode); 195 CeedChkBackend(ierr); 196 197 if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) { 198 impl->is_identity_restr_op = true; 199 } else { 200 ierr = CeedVectorDestroy(&impl->q_vecs_out[0]); CeedChkBackend(ierr); 201 impl->q_vecs_out[0] = impl->q_vecs_in[0]; 202 ierr = CeedVectorAddReference(impl->q_vecs_in[0]); CeedChkBackend(ierr); 203 } 204 } 205 206 ierr = CeedOperatorSetSetupDone(op); CeedChkBackend(ierr); 207 208 return CEED_ERROR_SUCCESS; 209 } 210 211 //------------------------------------------------------------------------------ 212 // Setup Input Fields 213 //------------------------------------------------------------------------------ 214 static inline int CeedOperatorSetupInputs_Opt(CeedInt num_input_fields, 215 CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 216 CeedVector in_vec, CeedScalar *e_data[2*CEED_FIELD_MAX], CeedOperator_Opt *impl, 217 CeedRequest *request) { 218 CeedInt ierr; 219 CeedEvalMode eval_mode; 220 CeedVector vec; 221 uint64_t state; 222 223 for (CeedInt i=0; i<num_input_fields; i++) { 224 ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode); 225 CeedChkBackend(ierr); 226 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 227 } else { 228 // Get input vector 229 ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec); 230 CeedChkBackend(ierr); 231 if (vec != CEED_VECTOR_ACTIVE) { 232 // Restrict 233 ierr = CeedVectorGetState(vec, &state); CeedChkBackend(ierr); 234 if (state != impl->input_states[i]) { 235 ierr = CeedElemRestrictionApply(impl->blk_restr[i], CEED_NOTRANSPOSE, 236 vec, impl->e_vecs_full[i], request); 237 CeedChkBackend(ierr); 238 impl->input_states[i] = state; 239 } 240 // Get evec 241 ierr = CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, 242 (const CeedScalar **) &e_data[i]); 243 CeedChkBackend(ierr); 244 } else { 245 // Set Qvec for CEED_EVAL_NONE 246 if (eval_mode == CEED_EVAL_NONE) { 247 ierr = CeedVectorGetArrayRead(impl->e_vecs_in[i], CEED_MEM_HOST, 248 (const CeedScalar **)&e_data[i]); 249 CeedChkBackend(ierr); 250 ierr = CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, 251 CEED_USE_POINTER, e_data[i]); CeedChkBackend(ierr); 252 ierr = CeedVectorRestoreArrayRead(impl->e_vecs_in[i], 253 (const CeedScalar **)&e_data[i]); 254 CeedChkBackend(ierr); 255 } 256 } 257 } 258 } 259 return CEED_ERROR_SUCCESS; 260 } 261 262 //------------------------------------------------------------------------------ 263 // Input Basis Action 264 //------------------------------------------------------------------------------ 265 static inline int CeedOperatorInputBasis_Opt(CeedInt e, CeedInt Q, 266 CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 267 CeedInt num_input_fields, CeedInt blk_size, CeedVector in_vec, bool skip_active, 268 CeedScalar *e_data[2*CEED_FIELD_MAX], CeedOperator_Opt *impl, 269 CeedRequest *request) { 270 CeedInt ierr; 271 CeedInt dim, elem_size, size; 272 CeedElemRestriction elem_restr; 273 CeedEvalMode eval_mode; 274 CeedBasis basis; 275 CeedVector vec; 276 277 for (CeedInt i=0; i<num_input_fields; i++) { 278 ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec); 279 CeedChkBackend(ierr); 280 // Skip active input 281 if (skip_active) { 282 if (vec == CEED_VECTOR_ACTIVE) 283 continue; 284 } 285 286 CeedInt active_in = 0; 287 // Get elem_size, eval_mode, size 288 ierr = CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr); 289 CeedChkBackend(ierr); 290 ierr = CeedElemRestrictionGetElementSize(elem_restr, &elem_size); 291 CeedChkBackend(ierr); 292 ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode); 293 CeedChkBackend(ierr); 294 ierr = CeedQFunctionFieldGetSize(qf_input_fields[i], &size); 295 CeedChkBackend(ierr); 296 // Restrict block active input 297 if (vec == CEED_VECTOR_ACTIVE) { 298 ierr = CeedElemRestrictionApplyBlock(impl->blk_restr[i], e/blk_size, 299 CEED_NOTRANSPOSE, in_vec, 300 impl->e_vecs_in[i], request); 301 CeedChkBackend(ierr); 302 active_in = 1; 303 } 304 // Basis action 305 switch(eval_mode) { 306 case CEED_EVAL_NONE: 307 if (!active_in) { 308 ierr = CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, 309 CEED_USE_POINTER, &e_data[i][e*Q*size]); 310 CeedChkBackend(ierr); 311 } 312 break; 313 case CEED_EVAL_INTERP: 314 ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis); 315 CeedChkBackend(ierr); 316 if (!active_in) { 317 ierr = CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, 318 CEED_USE_POINTER, &e_data[i][e*elem_size*size]); 319 CeedChkBackend(ierr); 320 } 321 ierr = CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, 322 CEED_EVAL_INTERP, impl->e_vecs_in[i], 323 impl->q_vecs_in[i]); CeedChkBackend(ierr); 324 break; 325 case CEED_EVAL_GRAD: 326 ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis); 327 CeedChkBackend(ierr); 328 if (!active_in) { 329 ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr); 330 ierr = CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, 331 CEED_USE_POINTER, 332 &e_data[i][e*elem_size*size/dim]); 333 CeedChkBackend(ierr); 334 } 335 ierr = CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, 336 CEED_EVAL_GRAD, impl->e_vecs_in[i], 337 impl->q_vecs_in[i]); CeedChkBackend(ierr); 338 break; 339 case CEED_EVAL_WEIGHT: 340 break; // No action 341 // LCOV_EXCL_START 342 case CEED_EVAL_DIV: 343 case CEED_EVAL_CURL: { 344 ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis); 345 CeedChkBackend(ierr); 346 Ceed ceed; 347 ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr); 348 return CeedError(ceed, CEED_ERROR_BACKEND, 349 "Ceed evaluation mode not implemented"); 350 // LCOV_EXCL_STOP 351 } 352 } 353 } 354 return CEED_ERROR_SUCCESS; 355 } 356 357 //------------------------------------------------------------------------------ 358 // Output Basis Action 359 //------------------------------------------------------------------------------ 360 static inline int CeedOperatorOutputBasis_Opt(CeedInt e, CeedInt Q, 361 CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields, 362 CeedInt blk_size, CeedInt num_input_fields, CeedInt num_output_fields, 363 CeedOperator op, CeedVector out_vec, CeedOperator_Opt *impl, 364 CeedRequest *request) { 365 CeedInt ierr; 366 CeedElemRestriction elem_restr; 367 CeedEvalMode eval_mode; 368 CeedBasis basis; 369 CeedVector vec; 370 371 for (CeedInt i=0; i<num_output_fields; i++) { 372 // Get elem_size, eval_mode, size 373 ierr = CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr); 374 CeedChkBackend(ierr); 375 ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode); 376 CeedChkBackend(ierr); 377 // Basis action 378 switch(eval_mode) { 379 case CEED_EVAL_NONE: 380 break; // No action 381 case CEED_EVAL_INTERP: 382 ierr = CeedOperatorFieldGetBasis(op_output_fields[i], &basis); 383 CeedChkBackend(ierr); 384 ierr = CeedBasisApply(basis, blk_size, CEED_TRANSPOSE, 385 CEED_EVAL_INTERP, impl->q_vecs_out[i], 386 impl->e_vecs_out[i]); CeedChkBackend(ierr); 387 break; 388 case CEED_EVAL_GRAD: 389 ierr = CeedOperatorFieldGetBasis(op_output_fields[i], &basis); 390 CeedChkBackend(ierr); 391 ierr = CeedBasisApply(basis, blk_size, CEED_TRANSPOSE, 392 CEED_EVAL_GRAD, impl->q_vecs_out[i], 393 impl->e_vecs_out[i]); CeedChkBackend(ierr); 394 break; 395 // LCOV_EXCL_START 396 case CEED_EVAL_WEIGHT: { 397 Ceed ceed; 398 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 399 return CeedError(ceed, CEED_ERROR_BACKEND, 400 "CEED_EVAL_WEIGHT cannot be an output " 401 "evaluation mode"); 402 } 403 case CEED_EVAL_DIV: 404 case CEED_EVAL_CURL: { 405 Ceed ceed; 406 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 407 return CeedError(ceed, CEED_ERROR_BACKEND, 408 "Ceed evaluation mode not implemented"); 409 // LCOV_EXCL_STOP 410 } 411 } 412 // Restrict output block 413 // Get output vector 414 ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec); 415 CeedChkBackend(ierr); 416 if (vec == CEED_VECTOR_ACTIVE) 417 vec = out_vec; 418 // Restrict 419 ierr = CeedElemRestrictionApplyBlock(impl->blk_restr[i+impl->num_inputs], 420 e/blk_size, CEED_TRANSPOSE, 421 impl->e_vecs_out[i], vec, request); 422 CeedChkBackend(ierr); 423 } 424 return CEED_ERROR_SUCCESS; 425 } 426 427 //------------------------------------------------------------------------------ 428 // Restore Input Vectors 429 //------------------------------------------------------------------------------ 430 static inline int CeedOperatorRestoreInputs_Opt(CeedInt num_input_fields, 431 CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 432 CeedScalar *e_data[2*CEED_FIELD_MAX], CeedOperator_Opt *impl) { 433 CeedInt ierr; 434 435 for (CeedInt i=0; i<num_input_fields; i++) { 436 CeedEvalMode eval_mode; 437 CeedVector vec; 438 ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode); 439 CeedChkBackend(ierr); 440 ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec); 441 CeedChkBackend(ierr); 442 if (eval_mode != CEED_EVAL_WEIGHT && vec != CEED_VECTOR_ACTIVE) { 443 ierr = CeedVectorRestoreArrayRead(impl->e_vecs_full[i], 444 (const CeedScalar **) &e_data[i]); 445 CeedChkBackend(ierr); 446 } 447 } 448 return CEED_ERROR_SUCCESS; 449 } 450 451 //------------------------------------------------------------------------------ 452 // Operator Apply 453 //------------------------------------------------------------------------------ 454 static int CeedOperatorApplyAdd_Opt(CeedOperator op, CeedVector in_vec, 455 CeedVector out_vec, CeedRequest *request) { 456 int ierr; 457 Ceed ceed; 458 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 459 Ceed_Opt *ceed_impl; 460 ierr = CeedGetData(ceed, &ceed_impl); CeedChkBackend(ierr); 461 CeedInt blk_size = ceed_impl->blk_size; 462 CeedOperator_Opt *impl; 463 ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr); 464 CeedInt Q, num_input_fields, num_output_fields, num_elem; 465 ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr); 466 ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr); 467 CeedInt num_blks = (num_elem/blk_size) + !!(num_elem%blk_size); 468 CeedQFunction qf; 469 ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr); 470 CeedOperatorField *op_input_fields, *op_output_fields; 471 ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, 472 &num_output_fields, &op_output_fields); 473 CeedChkBackend(ierr); 474 CeedQFunctionField *qf_input_fields, *qf_output_fields; 475 ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, 476 &qf_output_fields); 477 CeedChkBackend(ierr); 478 CeedEvalMode eval_mode; 479 CeedScalar *e_data[2*CEED_FIELD_MAX] = {0}; 480 481 // Setup 482 ierr = CeedOperatorSetup_Opt(op); CeedChkBackend(ierr); 483 484 // Restriction only operator 485 if (impl->is_identity_restr_op) { 486 for (CeedInt b=0; b<num_blks; b++) { 487 ierr = CeedElemRestrictionApplyBlock(impl->blk_restr[0], b, CEED_NOTRANSPOSE, 488 in_vec, impl->e_vecs_in[0], request); CeedChkBackend(ierr); 489 ierr = CeedElemRestrictionApplyBlock(impl->blk_restr[1], b, CEED_TRANSPOSE, 490 impl->e_vecs_in[0], out_vec, request); CeedChkBackend(ierr); 491 } 492 return CEED_ERROR_SUCCESS; 493 } 494 495 // Input Evecs and Restriction 496 ierr = CeedOperatorSetupInputs_Opt(num_input_fields, qf_input_fields, 497 op_input_fields, in_vec, e_data, 498 impl, request); CeedChkBackend(ierr); 499 500 // Output Lvecs, Evecs, and Qvecs 501 for (CeedInt i=0; i<num_output_fields; i++) { 502 // Set Qvec if needed 503 ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode); 504 CeedChkBackend(ierr); 505 if (eval_mode == CEED_EVAL_NONE) { 506 // Set qvec to single block evec 507 ierr = CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_HOST, 508 &e_data[i + num_input_fields]); 509 CeedChkBackend(ierr); 510 ierr = CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST, 511 CEED_USE_POINTER, e_data[i + num_input_fields]); 512 CeedChkBackend(ierr); 513 ierr = CeedVectorRestoreArray(impl->e_vecs_out[i], 514 &e_data[i + num_input_fields]); 515 CeedChkBackend(ierr); 516 } 517 } 518 519 // Loop through elements 520 for (CeedInt e=0; e<num_blks*blk_size; e+=blk_size) { 521 // Input basis apply 522 ierr = CeedOperatorInputBasis_Opt(e, Q, qf_input_fields, op_input_fields, 523 num_input_fields, blk_size, in_vec, false, 524 e_data, impl, request); CeedChkBackend(ierr); 525 526 // Q function 527 if (!impl->is_identity_qf) { 528 ierr = CeedQFunctionApply(qf, Q*blk_size, impl->q_vecs_in, impl->q_vecs_out); 529 CeedChkBackend(ierr); 530 } 531 532 // Output basis apply and restrict 533 ierr = CeedOperatorOutputBasis_Opt(e, Q, qf_output_fields, op_output_fields, 534 blk_size, num_input_fields, num_output_fields, 535 op, out_vec, impl, request); 536 CeedChkBackend(ierr); 537 } 538 539 // Restore input arrays 540 ierr = CeedOperatorRestoreInputs_Opt(num_input_fields, qf_input_fields, 541 op_input_fields, e_data, impl); 542 CeedChkBackend(ierr); 543 544 return CEED_ERROR_SUCCESS; 545 } 546 547 //------------------------------------------------------------------------------ 548 // Core code for linear QFunction assembly 549 //------------------------------------------------------------------------------ 550 static inline int CeedOperatorLinearAssembleQFunctionCore_Opt(CeedOperator op, 551 bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 552 CeedRequest *request) { 553 int ierr; 554 Ceed ceed; 555 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 556 Ceed_Opt *ceed_impl; 557 ierr = CeedGetData(ceed, &ceed_impl); CeedChkBackend(ierr); 558 const CeedInt blk_size = ceed_impl->blk_size; 559 CeedSize q_size; 560 CeedOperator_Opt *impl; 561 ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr); 562 CeedInt Q, num_input_fields, num_output_fields, num_elem, size; 563 ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr); 564 ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr); 565 CeedInt num_blks = (num_elem/blk_size) + !!(num_elem%blk_size); 566 CeedQFunction qf; 567 ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr); 568 CeedOperatorField *op_input_fields, *op_output_fields; 569 ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, 570 &num_output_fields, &op_output_fields); 571 CeedChkBackend(ierr); 572 CeedQFunctionField *qf_input_fields, *qf_output_fields; 573 ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, 574 &qf_output_fields); 575 CeedChkBackend(ierr); 576 CeedVector vec, l_vec = impl->qf_l_vec; 577 CeedInt num_active_in = impl->num_active_in, 578 num_active_out = impl->num_active_out; 579 CeedVector *active_in = impl->qf_active_in; 580 CeedScalar *a, *tmp; 581 CeedScalar *e_data[2*CEED_FIELD_MAX] = {0}; 582 583 // Setup 584 ierr = CeedOperatorSetup_Opt(op); CeedChkBackend(ierr); 585 586 // Check for identity 587 if (impl->is_identity_qf) 588 // LCOV_EXCL_START 589 return CeedError(ceed, CEED_ERROR_BACKEND, 590 "Assembling identity qfunctions not supported"); 591 // LCOV_EXCL_STOP 592 593 // Input Evecs and Restriction 594 ierr = CeedOperatorSetupInputs_Opt(num_input_fields, qf_input_fields, 595 op_input_fields, NULL, e_data, 596 impl, request); CeedChkBackend(ierr); 597 598 // Count number of active input fields 599 if (!num_active_in) { 600 for (CeedInt i=0; i<num_input_fields; i++) { 601 // Get input vector 602 ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec); 603 CeedChkBackend(ierr); 604 // Check if active input 605 if (vec == CEED_VECTOR_ACTIVE) { 606 ierr = CeedQFunctionFieldGetSize(qf_input_fields[i], &size); 607 CeedChkBackend(ierr); 608 ierr = CeedVectorSetValue(impl->q_vecs_in[i], 0.0); CeedChkBackend(ierr); 609 ierr = CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &tmp); 610 CeedChkBackend(ierr); 611 ierr = CeedRealloc(num_active_in + size, &active_in); CeedChkBackend(ierr); 612 for (CeedInt field=0; field<size; field++) { 613 q_size = (CeedSize)Q*blk_size; 614 ierr = CeedVectorCreate(ceed, q_size, &active_in[num_active_in+field]); 615 CeedChkBackend(ierr); 616 ierr = CeedVectorSetArray(active_in[num_active_in+field], CEED_MEM_HOST, 617 CEED_USE_POINTER, &tmp[field*Q*blk_size]); 618 CeedChkBackend(ierr); 619 } 620 num_active_in += size; 621 ierr = CeedVectorRestoreArray(impl->q_vecs_in[i], &tmp); CeedChkBackend(ierr); 622 } 623 } 624 impl->num_active_in = num_active_in; 625 impl->qf_active_in = active_in; 626 } 627 628 // Count number of active output fields 629 if (!num_active_out) { 630 for (CeedInt i=0; i<num_output_fields; i++) { 631 // Get output vector 632 ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec); 633 CeedChkBackend(ierr); 634 // Check if active output 635 if (vec == CEED_VECTOR_ACTIVE) { 636 ierr = CeedQFunctionFieldGetSize(qf_output_fields[i], &size); 637 CeedChkBackend(ierr); 638 num_active_out += size; 639 } 640 } 641 impl->num_active_out = num_active_out; 642 } 643 644 // Check sizes 645 if (!num_active_in || !num_active_out) 646 // LCOV_EXCL_START 647 return CeedError(ceed, CEED_ERROR_BACKEND, 648 "Cannot assemble QFunction without active inputs " 649 "and outputs"); 650 // LCOV_EXCL_STOP 651 652 // Setup l_vec 653 if (!l_vec) { 654 CeedSize l_size = (CeedSize)blk_size*Q*num_active_in*num_active_out; 655 ierr = CeedVectorCreate(ceed, l_size, &l_vec); CeedChkBackend(ierr); 656 ierr = CeedVectorSetValue(l_vec, 0.0); CeedChkBackend(ierr); 657 impl->qf_l_vec = l_vec; 658 } 659 660 // Build objects if needed 661 CeedInt strides[3] = {1, Q, num_active_in *num_active_out*Q}; 662 if (build_objects) { 663 // Create output restriction 664 ierr = CeedElemRestrictionCreateStrided(ceed, num_elem, Q, 665 num_active_in*num_active_out, 666 num_active_in*num_active_out*num_elem*Q, 667 strides, rstr); CeedChkBackend(ierr); 668 // Create assembled vector 669 CeedSize l_size = (CeedSize)num_elem*Q*num_active_in*num_active_out; 670 ierr = CeedVectorCreate(ceed, l_size, assembled); CeedChkBackend(ierr); 671 } 672 673 // Output blocked restriction 674 CeedElemRestriction blk_rstr = impl->qf_blk_rstr; 675 if (!blk_rstr) { 676 ierr = CeedElemRestrictionCreateBlockedStrided(ceed, num_elem, Q, blk_size, 677 num_active_in*num_active_out, num_active_in*num_active_out*num_elem*Q, 678 strides, &blk_rstr); CeedChkBackend(ierr); 679 impl->qf_blk_rstr = blk_rstr; 680 } 681 682 // Loop through elements 683 ierr = CeedVectorSetValue(*assembled, 0.0); CeedChkBackend(ierr); 684 for (CeedInt e=0; e<num_blks*blk_size; e+=blk_size) { 685 ierr = CeedVectorGetArray(l_vec, CEED_MEM_HOST, &a); CeedChkBackend(ierr); 686 687 // Input basis apply 688 ierr = CeedOperatorInputBasis_Opt(e, Q, qf_input_fields, op_input_fields, 689 num_input_fields, blk_size, NULL, true, 690 e_data, impl, request); CeedChkBackend(ierr); 691 692 // Assemble QFunction 693 for (CeedInt in=0; in<num_active_in; in++) { 694 // Set Inputs 695 ierr = CeedVectorSetValue(active_in[in], 1.0); CeedChkBackend(ierr); 696 if (num_active_in > 1) { 697 ierr = CeedVectorSetValue(active_in[(in+num_active_in-1)%num_active_in], 698 0.0); CeedChkBackend(ierr); 699 } 700 // Set Outputs 701 for (CeedInt out=0; out<num_output_fields; out++) { 702 // Get output vector 703 ierr = CeedOperatorFieldGetVector(op_output_fields[out], &vec); 704 CeedChkBackend(ierr); 705 // Check if active output 706 if (vec == CEED_VECTOR_ACTIVE) { 707 CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, 708 CEED_USE_POINTER, a); CeedChkBackend(ierr); 709 ierr = CeedQFunctionFieldGetSize(qf_output_fields[out], &size); 710 CeedChkBackend(ierr); 711 a += size*Q*blk_size; // Advance the pointer by the size of the output 712 } 713 } 714 // Apply QFunction 715 ierr = CeedQFunctionApply(qf, Q*blk_size, impl->q_vecs_in, impl->q_vecs_out); 716 CeedChkBackend(ierr); 717 } 718 719 // Assemble into assembled vector 720 ierr = CeedVectorRestoreArray(l_vec, &a); CeedChkBackend(ierr); 721 ierr = CeedElemRestrictionApplyBlock(blk_rstr, e/blk_size, CEED_TRANSPOSE, 722 l_vec, *assembled, request); CeedChkBackend(ierr); 723 } 724 725 // Un-set output Qvecs to prevent accidental overwrite of Assembled 726 for (CeedInt out=0; out<num_output_fields; out++) { 727 // Get output vector 728 ierr = CeedOperatorFieldGetVector(op_output_fields[out], &vec); 729 CeedChkBackend(ierr); 730 // Check if active output 731 if (vec == CEED_VECTOR_ACTIVE) { 732 CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_COPY_VALUES, 733 NULL); CeedChkBackend(ierr); 734 } 735 } 736 737 // Restore input arrays 738 ierr = CeedOperatorRestoreInputs_Opt(num_input_fields, qf_input_fields, 739 op_input_fields, e_data, impl); 740 CeedChkBackend(ierr); 741 742 return CEED_ERROR_SUCCESS; 743 } 744 745 //------------------------------------------------------------------------------ 746 // Assemble Linear QFunction 747 //------------------------------------------------------------------------------ 748 static int CeedOperatorLinearAssembleQFunction_Opt(CeedOperator op, 749 CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 750 return CeedOperatorLinearAssembleQFunctionCore_Opt(op, true, assembled, rstr, 751 request); 752 } 753 754 //------------------------------------------------------------------------------ 755 // Update Assembled Linear QFunction 756 //------------------------------------------------------------------------------ 757 static int CeedOperatorLinearAssembleQFunctionUpdate_Opt(CeedOperator op, 758 CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 759 return CeedOperatorLinearAssembleQFunctionCore_Opt(op, false, &assembled, 760 &rstr, request); 761 } 762 763 //------------------------------------------------------------------------------ 764 // Operator Destroy 765 //------------------------------------------------------------------------------ 766 static int CeedOperatorDestroy_Opt(CeedOperator op) { 767 int ierr; 768 CeedOperator_Opt *impl; 769 ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr); 770 771 for (CeedInt i=0; i<impl->num_inputs+impl->num_outputs; i++) { 772 ierr = CeedElemRestrictionDestroy(&impl->blk_restr[i]); CeedChkBackend(ierr); 773 ierr = CeedVectorDestroy(&impl->e_vecs_full[i]); CeedChkBackend(ierr); 774 } 775 ierr = CeedFree(&impl->blk_restr); CeedChkBackend(ierr); 776 ierr = CeedFree(&impl->e_vecs_full); CeedChkBackend(ierr); 777 ierr = CeedFree(&impl->input_states); CeedChkBackend(ierr); 778 779 for (CeedInt i=0; i<impl->num_inputs; i++) { 780 ierr = CeedVectorDestroy(&impl->e_vecs_in[i]); CeedChkBackend(ierr); 781 ierr = CeedVectorDestroy(&impl->q_vecs_in[i]); CeedChkBackend(ierr); 782 } 783 ierr = CeedFree(&impl->e_vecs_in); CeedChkBackend(ierr); 784 ierr = CeedFree(&impl->q_vecs_in); CeedChkBackend(ierr); 785 786 for (CeedInt i=0; i<impl->num_outputs; i++) { 787 ierr = CeedVectorDestroy(&impl->e_vecs_out[i]); CeedChkBackend(ierr); 788 ierr = CeedVectorDestroy(&impl->q_vecs_out[i]); CeedChkBackend(ierr); 789 } 790 ierr = CeedFree(&impl->e_vecs_out); CeedChkBackend(ierr); 791 ierr = CeedFree(&impl->q_vecs_out); CeedChkBackend(ierr); 792 793 // QFunction assembly data 794 for (CeedInt i=0; i<impl->num_active_in; i++) { 795 ierr = CeedVectorDestroy(&impl->qf_active_in[i]); CeedChkBackend(ierr); 796 } 797 ierr = CeedFree(&impl->qf_active_in); CeedChkBackend(ierr); 798 ierr = CeedVectorDestroy(&impl->qf_l_vec); CeedChkBackend(ierr); 799 ierr = CeedElemRestrictionDestroy(&impl->qf_blk_rstr); CeedChkBackend(ierr); 800 801 ierr = CeedFree(&impl); CeedChkBackend(ierr); 802 return CEED_ERROR_SUCCESS; 803 } 804 805 //------------------------------------------------------------------------------ 806 // Operator Create 807 //------------------------------------------------------------------------------ 808 int CeedOperatorCreate_Opt(CeedOperator op) { 809 int ierr; 810 Ceed ceed; 811 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 812 Ceed_Opt *ceed_impl; 813 ierr = CeedGetData(ceed, &ceed_impl); CeedChkBackend(ierr); 814 CeedInt blk_size = ceed_impl->blk_size; 815 CeedOperator_Opt *impl; 816 817 ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr); 818 ierr = CeedOperatorSetData(op, impl); CeedChkBackend(ierr); 819 820 if (blk_size != 1 && blk_size != 8) 821 // LCOV_EXCL_START 822 return CeedError(ceed, CEED_ERROR_BACKEND, 823 "Opt backend cannot use blocksize: %d", blk_size); 824 // LCOV_EXCL_STOP 825 826 ierr = CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", 827 CeedOperatorLinearAssembleQFunction_Opt); 828 CeedChkBackend(ierr); 829 ierr = CeedSetBackendFunction(ceed, "Operator", op, 830 "LinearAssembleQFunctionUpdate", 831 CeedOperatorLinearAssembleQFunctionUpdate_Opt); 832 CeedChkBackend(ierr); 833 ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", 834 CeedOperatorApplyAdd_Opt); CeedChkBackend(ierr); 835 ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy", 836 CeedOperatorDestroy_Opt); CeedChkBackend(ierr); 837 return CEED_ERROR_SUCCESS; 838 } 839 //------------------------------------------------------------------------------ 840