1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/ceed.h> 9 #include <ceed/backend.h> 10 #include <stdbool.h> 11 #include <stdint.h> 12 #include <string.h> 13 #include "ceed-opt.h" 14 15 //------------------------------------------------------------------------------ 16 // Setup Input/Output Fields 17 //------------------------------------------------------------------------------ 18 static int CeedOperatorSetupFields_Opt(CeedQFunction qf, CeedOperator op, 19 bool is_input, const CeedInt blk_size, 20 CeedElemRestriction *blk_restr, 21 CeedVector *e_vecs_full, CeedVector *e_vecs, 22 CeedVector *q_vecs, CeedInt start_e, 23 CeedInt num_fields, CeedInt Q) { 24 CeedInt ierr, num_comp, size, P; 25 Ceed ceed; 26 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 27 CeedBasis basis; 28 CeedElemRestriction r; 29 CeedOperatorField *op_fields; 30 CeedQFunctionField *qf_fields; 31 if (is_input) { 32 ierr = CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL); 33 CeedChkBackend(ierr); 34 ierr = CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL); 35 CeedChkBackend(ierr); 36 } else { 37 ierr = CeedOperatorGetFields(op, NULL, NULL, NULL,&op_fields); 38 CeedChkBackend(ierr); 39 ierr = CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields); 40 CeedChkBackend(ierr); 41 } 42 43 // Loop over fields 44 for (CeedInt i=0; i<num_fields; i++) { 45 CeedEvalMode eval_mode; 46 ierr = CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode); 47 CeedChkBackend(ierr); 48 49 if (eval_mode != CEED_EVAL_WEIGHT) { 50 ierr = CeedOperatorFieldGetElemRestriction(op_fields[i], &r); 51 CeedChkBackend(ierr); 52 Ceed ceed; 53 ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr); 54 CeedSize l_size; 55 CeedInt num_elem, elem_size, comp_stride; 56 ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr); 57 ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr); 58 ierr = CeedElemRestrictionGetLVectorSize(r, &l_size); CeedChkBackend(ierr); 59 ierr = CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr); 60 61 bool strided; 62 ierr = CeedElemRestrictionIsStrided(r, &strided); CeedChkBackend(ierr); 63 if (strided) { 64 CeedInt strides[3]; 65 ierr = CeedElemRestrictionGetStrides(r, &strides); CeedChkBackend(ierr); 66 ierr = CeedElemRestrictionCreateBlockedStrided(ceed, num_elem, elem_size, 67 blk_size, num_comp, l_size, strides, &blk_restr[i+start_e]); 68 CeedChkBackend(ierr); 69 } else { 70 const CeedInt *offsets = NULL; 71 ierr = CeedElemRestrictionGetOffsets(r, CEED_MEM_HOST, &offsets); 72 CeedChkBackend(ierr); 73 ierr = CeedElemRestrictionGetCompStride(r, &comp_stride); CeedChkBackend(ierr); 74 ierr = CeedElemRestrictionCreateBlocked(ceed, num_elem, elem_size, 75 blk_size, num_comp, comp_stride, 76 l_size, CEED_MEM_HOST, 77 CEED_COPY_VALUES, offsets, 78 &blk_restr[i+start_e]); 79 CeedChkBackend(ierr); 80 ierr = CeedElemRestrictionRestoreOffsets(r, &offsets); CeedChkBackend(ierr); 81 } 82 ierr = CeedElemRestrictionCreateVector(blk_restr[i+start_e], NULL, 83 &e_vecs_full[i+start_e]); 84 CeedChkBackend(ierr); 85 } 86 87 switch(eval_mode) { 88 case CEED_EVAL_NONE: 89 ierr = CeedQFunctionFieldGetSize(qf_fields[i], &size); CeedChkBackend(ierr); 90 ierr = CeedVectorCreate(ceed, Q*size*blk_size, &e_vecs[i]); 91 CeedChkBackend(ierr); 92 ierr = CeedVectorCreate(ceed, Q*size*blk_size, &q_vecs[i]); 93 CeedChkBackend(ierr); 94 break; 95 case CEED_EVAL_INTERP: 96 case CEED_EVAL_GRAD: 97 ierr = CeedOperatorFieldGetBasis(op_fields[i], &basis); CeedChkBackend(ierr); 98 ierr = CeedQFunctionFieldGetSize(qf_fields[i], &size); CeedChkBackend(ierr); 99 ierr = CeedBasisGetNumNodes(basis, &P); CeedChkBackend(ierr); 100 ierr = CeedBasisGetNumComponents(basis, &num_comp); CeedChkBackend(ierr); 101 ierr = CeedVectorCreate(ceed, P*num_comp*blk_size, &e_vecs[i]); 102 CeedChkBackend(ierr); 103 ierr = CeedVectorCreate(ceed, Q*size*blk_size, &q_vecs[i]); 104 CeedChkBackend(ierr); 105 break; 106 case CEED_EVAL_WEIGHT: // Only on input fields 107 ierr = CeedOperatorFieldGetBasis(op_fields[i], &basis); CeedChkBackend(ierr); 108 ierr = CeedVectorCreate(ceed, Q*blk_size, &q_vecs[i]); CeedChkBackend(ierr); 109 ierr = CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, 110 CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i]); 111 CeedChkBackend(ierr); 112 113 break; 114 case CEED_EVAL_DIV: 115 break; // Not implemented 116 case CEED_EVAL_CURL: 117 break; // Not implemented 118 } 119 if (is_input && !!e_vecs[i]) { 120 ierr = CeedVectorSetArray(e_vecs[i], CEED_MEM_HOST, 121 CEED_COPY_VALUES, NULL); CeedChkBackend(ierr); 122 } 123 } 124 return CEED_ERROR_SUCCESS; 125 } 126 127 //------------------------------------------------------------------------------ 128 // Setup Operator 129 //------------------------------------------------------------------------------ 130 static int CeedOperatorSetup_Opt(CeedOperator op) { 131 int ierr; 132 bool is_setup_done; 133 ierr = CeedOperatorIsSetupDone(op, &is_setup_done); CeedChkBackend(ierr); 134 if (is_setup_done) return CEED_ERROR_SUCCESS; 135 Ceed ceed; 136 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 137 Ceed_Opt *ceed_impl; 138 ierr = CeedGetData(ceed, &ceed_impl); CeedChkBackend(ierr); 139 const CeedInt blk_size = ceed_impl->blk_size; 140 CeedOperator_Opt *impl; 141 ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr); 142 CeedQFunction qf; 143 ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr); 144 CeedInt Q, num_input_fields, num_output_fields; 145 ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr); 146 ierr = CeedQFunctionIsIdentity(qf, &impl->is_identity_qf); CeedChkBackend(ierr); 147 CeedOperatorField *op_input_fields, *op_output_fields; 148 ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, 149 &num_output_fields, &op_output_fields); 150 CeedChkBackend(ierr); 151 CeedQFunctionField *qf_input_fields, *qf_output_fields; 152 ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, 153 &qf_output_fields); 154 CeedChkBackend(ierr); 155 156 // Allocate 157 ierr = CeedCalloc(num_input_fields + num_output_fields, &impl->blk_restr); 158 CeedChkBackend(ierr); 159 ierr = CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full); 160 CeedChkBackend(ierr); 161 162 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->input_states); CeedChkBackend(ierr); 163 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in); CeedChkBackend(ierr); 164 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out); CeedChkBackend(ierr); 165 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in); CeedChkBackend(ierr); 166 ierr = CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out); CeedChkBackend(ierr); 167 168 impl->num_inputs = num_input_fields; 169 impl->num_outputs = num_output_fields; 170 171 // Set up infield and outfield pointer arrays 172 // Infields 173 ierr = CeedOperatorSetupFields_Opt(qf, op, true, blk_size, impl->blk_restr, 174 impl->e_vecs_full, impl->e_vecs_in, 175 impl->q_vecs_in, 0, num_input_fields, Q); 176 CeedChkBackend(ierr); 177 // Outfields 178 ierr = CeedOperatorSetupFields_Opt(qf, op, false, blk_size, impl->blk_restr, 179 impl->e_vecs_full, impl->e_vecs_out, 180 impl->q_vecs_out, num_input_fields, 181 num_output_fields, Q); 182 CeedChkBackend(ierr); 183 184 // Identity QFunctions 185 if (impl->is_identity_qf) { 186 CeedEvalMode in_mode, out_mode; 187 CeedQFunctionField *in_fields, *out_fields; 188 ierr = CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields); 189 CeedChkBackend(ierr); 190 ierr = CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode); 191 CeedChkBackend(ierr); 192 ierr = CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode); 193 CeedChkBackend(ierr); 194 195 if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) { 196 impl->is_identity_restr_op = true; 197 } else { 198 ierr = CeedVectorDestroy(&impl->q_vecs_out[0]); CeedChkBackend(ierr); 199 impl->q_vecs_out[0] = impl->q_vecs_in[0]; 200 ierr = CeedVectorAddReference(impl->q_vecs_in[0]); CeedChkBackend(ierr); 201 } 202 } 203 204 ierr = CeedOperatorSetSetupDone(op); CeedChkBackend(ierr); 205 206 return CEED_ERROR_SUCCESS; 207 } 208 209 //------------------------------------------------------------------------------ 210 // Setup Input Fields 211 //------------------------------------------------------------------------------ 212 static inline int CeedOperatorSetupInputs_Opt(CeedInt num_input_fields, 213 CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 214 CeedVector in_vec, CeedScalar *e_data[2*CEED_FIELD_MAX], CeedOperator_Opt *impl, 215 CeedRequest *request) { 216 CeedInt ierr; 217 CeedEvalMode eval_mode; 218 CeedVector vec; 219 uint64_t state; 220 221 for (CeedInt i=0; i<num_input_fields; i++) { 222 ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode); 223 CeedChkBackend(ierr); 224 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 225 } else { 226 // Get input vector 227 ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec); 228 CeedChkBackend(ierr); 229 if (vec != CEED_VECTOR_ACTIVE) { 230 // Restrict 231 ierr = CeedVectorGetState(vec, &state); CeedChkBackend(ierr); 232 if (state != impl->input_states[i]) { 233 ierr = CeedElemRestrictionApply(impl->blk_restr[i], CEED_NOTRANSPOSE, 234 vec, impl->e_vecs_full[i], request); 235 CeedChkBackend(ierr); 236 impl->input_states[i] = state; 237 } 238 // Get evec 239 ierr = CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, 240 (const CeedScalar **) &e_data[i]); 241 CeedChkBackend(ierr); 242 } else { 243 // Set Qvec for CEED_EVAL_NONE 244 if (eval_mode == CEED_EVAL_NONE) { 245 ierr = CeedVectorGetArrayRead(impl->e_vecs_in[i], CEED_MEM_HOST, 246 (const CeedScalar **)&e_data[i]); 247 CeedChkBackend(ierr); 248 ierr = CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, 249 CEED_USE_POINTER, e_data[i]); CeedChkBackend(ierr); 250 ierr = CeedVectorRestoreArrayRead(impl->e_vecs_in[i], 251 (const CeedScalar **)&e_data[i]); 252 CeedChkBackend(ierr); 253 } 254 } 255 } 256 } 257 return CEED_ERROR_SUCCESS; 258 } 259 260 //------------------------------------------------------------------------------ 261 // Input Basis Action 262 //------------------------------------------------------------------------------ 263 static inline int CeedOperatorInputBasis_Opt(CeedInt e, CeedInt Q, 264 CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 265 CeedInt num_input_fields, CeedInt blk_size, CeedVector in_vec, bool skip_active, 266 CeedScalar *e_data[2*CEED_FIELD_MAX], CeedOperator_Opt *impl, 267 CeedRequest *request) { 268 CeedInt ierr; 269 CeedInt dim, elem_size, size; 270 CeedElemRestriction elem_restr; 271 CeedEvalMode eval_mode; 272 CeedBasis basis; 273 CeedVector vec; 274 275 for (CeedInt i=0; i<num_input_fields; i++) { 276 ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec); 277 CeedChkBackend(ierr); 278 // Skip active input 279 if (skip_active) { 280 if (vec == CEED_VECTOR_ACTIVE) 281 continue; 282 } 283 284 CeedInt active_in = 0; 285 // Get elem_size, eval_mode, size 286 ierr = CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr); 287 CeedChkBackend(ierr); 288 ierr = CeedElemRestrictionGetElementSize(elem_restr, &elem_size); 289 CeedChkBackend(ierr); 290 ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode); 291 CeedChkBackend(ierr); 292 ierr = CeedQFunctionFieldGetSize(qf_input_fields[i], &size); 293 CeedChkBackend(ierr); 294 // Restrict block active input 295 if (vec == CEED_VECTOR_ACTIVE) { 296 ierr = CeedElemRestrictionApplyBlock(impl->blk_restr[i], e/blk_size, 297 CEED_NOTRANSPOSE, in_vec, 298 impl->e_vecs_in[i], request); 299 CeedChkBackend(ierr); 300 active_in = 1; 301 } 302 // Basis action 303 switch(eval_mode) { 304 case CEED_EVAL_NONE: 305 if (!active_in) { 306 ierr = CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, 307 CEED_USE_POINTER, &e_data[i][e*Q*size]); 308 CeedChkBackend(ierr); 309 } 310 break; 311 case CEED_EVAL_INTERP: 312 ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis); 313 CeedChkBackend(ierr); 314 if (!active_in) { 315 ierr = CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, 316 CEED_USE_POINTER, &e_data[i][e*elem_size*size]); 317 CeedChkBackend(ierr); 318 } 319 ierr = CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, 320 CEED_EVAL_INTERP, impl->e_vecs_in[i], 321 impl->q_vecs_in[i]); CeedChkBackend(ierr); 322 break; 323 case CEED_EVAL_GRAD: 324 ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis); 325 CeedChkBackend(ierr); 326 if (!active_in) { 327 ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr); 328 ierr = CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, 329 CEED_USE_POINTER, 330 &e_data[i][e*elem_size*size/dim]); 331 CeedChkBackend(ierr); 332 } 333 ierr = CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, 334 CEED_EVAL_GRAD, impl->e_vecs_in[i], 335 impl->q_vecs_in[i]); CeedChkBackend(ierr); 336 break; 337 case CEED_EVAL_WEIGHT: 338 break; // No action 339 // LCOV_EXCL_START 340 case CEED_EVAL_DIV: 341 case CEED_EVAL_CURL: { 342 ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis); 343 CeedChkBackend(ierr); 344 Ceed ceed; 345 ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr); 346 return CeedError(ceed, CEED_ERROR_BACKEND, 347 "Ceed evaluation mode not implemented"); 348 // LCOV_EXCL_STOP 349 } 350 } 351 } 352 return CEED_ERROR_SUCCESS; 353 } 354 355 //------------------------------------------------------------------------------ 356 // Output Basis Action 357 //------------------------------------------------------------------------------ 358 static inline int CeedOperatorOutputBasis_Opt(CeedInt e, CeedInt Q, 359 CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields, 360 CeedInt blk_size, CeedInt num_input_fields, CeedInt num_output_fields, 361 CeedOperator op, CeedVector out_vec, CeedOperator_Opt *impl, 362 CeedRequest *request) { 363 CeedInt ierr; 364 CeedElemRestriction elem_restr; 365 CeedEvalMode eval_mode; 366 CeedBasis basis; 367 CeedVector vec; 368 369 for (CeedInt i=0; i<num_output_fields; i++) { 370 // Get elem_size, eval_mode, size 371 ierr = CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr); 372 CeedChkBackend(ierr); 373 ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode); 374 CeedChkBackend(ierr); 375 // Basis action 376 switch(eval_mode) { 377 case CEED_EVAL_NONE: 378 break; // No action 379 case CEED_EVAL_INTERP: 380 ierr = CeedOperatorFieldGetBasis(op_output_fields[i], &basis); 381 CeedChkBackend(ierr); 382 ierr = CeedBasisApply(basis, blk_size, CEED_TRANSPOSE, 383 CEED_EVAL_INTERP, impl->q_vecs_out[i], 384 impl->e_vecs_out[i]); CeedChkBackend(ierr); 385 break; 386 case CEED_EVAL_GRAD: 387 ierr = CeedOperatorFieldGetBasis(op_output_fields[i], &basis); 388 CeedChkBackend(ierr); 389 ierr = CeedBasisApply(basis, blk_size, CEED_TRANSPOSE, 390 CEED_EVAL_GRAD, impl->q_vecs_out[i], 391 impl->e_vecs_out[i]); CeedChkBackend(ierr); 392 break; 393 // LCOV_EXCL_START 394 case CEED_EVAL_WEIGHT: { 395 Ceed ceed; 396 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 397 return CeedError(ceed, CEED_ERROR_BACKEND, 398 "CEED_EVAL_WEIGHT cannot be an output " 399 "evaluation mode"); 400 } 401 case CEED_EVAL_DIV: 402 case CEED_EVAL_CURL: { 403 Ceed ceed; 404 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 405 return CeedError(ceed, CEED_ERROR_BACKEND, 406 "Ceed evaluation mode not implemented"); 407 // LCOV_EXCL_STOP 408 } 409 } 410 // Restrict output block 411 // Get output vector 412 ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec); 413 CeedChkBackend(ierr); 414 if (vec == CEED_VECTOR_ACTIVE) 415 vec = out_vec; 416 // Restrict 417 ierr = CeedElemRestrictionApplyBlock(impl->blk_restr[i+impl->num_inputs], 418 e/blk_size, CEED_TRANSPOSE, 419 impl->e_vecs_out[i], vec, request); 420 CeedChkBackend(ierr); 421 } 422 return CEED_ERROR_SUCCESS; 423 } 424 425 //------------------------------------------------------------------------------ 426 // Restore Input Vectors 427 //------------------------------------------------------------------------------ 428 static inline int CeedOperatorRestoreInputs_Opt(CeedInt num_input_fields, 429 CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 430 CeedScalar *e_data[2*CEED_FIELD_MAX], CeedOperator_Opt *impl) { 431 CeedInt ierr; 432 433 for (CeedInt i=0; i<num_input_fields; i++) { 434 CeedEvalMode eval_mode; 435 CeedVector vec; 436 ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode); 437 CeedChkBackend(ierr); 438 ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec); 439 CeedChkBackend(ierr); 440 if (eval_mode != CEED_EVAL_WEIGHT && vec != CEED_VECTOR_ACTIVE) { 441 ierr = CeedVectorRestoreArrayRead(impl->e_vecs_full[i], 442 (const CeedScalar **) &e_data[i]); 443 CeedChkBackend(ierr); 444 } 445 } 446 return CEED_ERROR_SUCCESS; 447 } 448 449 //------------------------------------------------------------------------------ 450 // Operator Apply 451 //------------------------------------------------------------------------------ 452 static int CeedOperatorApplyAdd_Opt(CeedOperator op, CeedVector in_vec, 453 CeedVector out_vec, CeedRequest *request) { 454 int ierr; 455 Ceed ceed; 456 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 457 Ceed_Opt *ceed_impl; 458 ierr = CeedGetData(ceed, &ceed_impl); CeedChkBackend(ierr); 459 CeedInt blk_size = ceed_impl->blk_size; 460 CeedOperator_Opt *impl; 461 ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr); 462 CeedInt Q, num_input_fields, num_output_fields, num_elem; 463 ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr); 464 ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr); 465 CeedInt num_blks = (num_elem/blk_size) + !!(num_elem%blk_size); 466 CeedQFunction qf; 467 ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr); 468 CeedOperatorField *op_input_fields, *op_output_fields; 469 ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, 470 &num_output_fields, &op_output_fields); 471 CeedChkBackend(ierr); 472 CeedQFunctionField *qf_input_fields, *qf_output_fields; 473 ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, 474 &qf_output_fields); 475 CeedChkBackend(ierr); 476 CeedEvalMode eval_mode; 477 CeedScalar *e_data[2*CEED_FIELD_MAX] = {0}; 478 479 // Setup 480 ierr = CeedOperatorSetup_Opt(op); CeedChkBackend(ierr); 481 482 // Restriction only operator 483 if (impl->is_identity_restr_op) { 484 for (CeedInt b=0; b<num_blks; b++) { 485 ierr = CeedElemRestrictionApplyBlock(impl->blk_restr[0], b, CEED_NOTRANSPOSE, 486 in_vec, impl->e_vecs_in[0], request); CeedChkBackend(ierr); 487 ierr = CeedElemRestrictionApplyBlock(impl->blk_restr[1], b, CEED_TRANSPOSE, 488 impl->e_vecs_in[0], out_vec, request); CeedChkBackend(ierr); 489 } 490 return CEED_ERROR_SUCCESS; 491 } 492 493 // Input Evecs and Restriction 494 ierr = CeedOperatorSetupInputs_Opt(num_input_fields, qf_input_fields, 495 op_input_fields, in_vec, e_data, 496 impl, request); CeedChkBackend(ierr); 497 498 // Output Lvecs, Evecs, and Qvecs 499 for (CeedInt i=0; i<num_output_fields; i++) { 500 // Set Qvec if needed 501 ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode); 502 CeedChkBackend(ierr); 503 if (eval_mode == CEED_EVAL_NONE) { 504 // Set qvec to single block evec 505 ierr = CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_HOST, 506 &e_data[i + num_input_fields]); 507 CeedChkBackend(ierr); 508 ierr = CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST, 509 CEED_USE_POINTER, e_data[i + num_input_fields]); 510 CeedChkBackend(ierr); 511 ierr = CeedVectorRestoreArray(impl->e_vecs_out[i], 512 &e_data[i + num_input_fields]); 513 CeedChkBackend(ierr); 514 } 515 } 516 517 // Loop through elements 518 for (CeedInt e=0; e<num_blks*blk_size; e+=blk_size) { 519 // Input basis apply 520 ierr = CeedOperatorInputBasis_Opt(e, Q, qf_input_fields, op_input_fields, 521 num_input_fields, blk_size, in_vec, false, 522 e_data, impl, request); CeedChkBackend(ierr); 523 524 // Q function 525 if (!impl->is_identity_qf) { 526 ierr = CeedQFunctionApply(qf, Q*blk_size, impl->q_vecs_in, impl->q_vecs_out); 527 CeedChkBackend(ierr); 528 } 529 530 // Output basis apply and restrict 531 ierr = CeedOperatorOutputBasis_Opt(e, Q, qf_output_fields, op_output_fields, 532 blk_size, num_input_fields, num_output_fields, 533 op, out_vec, impl, request); 534 CeedChkBackend(ierr); 535 } 536 537 // Restore input arrays 538 ierr = CeedOperatorRestoreInputs_Opt(num_input_fields, qf_input_fields, 539 op_input_fields, e_data, impl); 540 CeedChkBackend(ierr); 541 542 return CEED_ERROR_SUCCESS; 543 } 544 545 //------------------------------------------------------------------------------ 546 // Core code for linear QFunction assembly 547 //------------------------------------------------------------------------------ 548 static inline int CeedOperatorLinearAssembleQFunctionCore_Opt(CeedOperator op, 549 bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 550 CeedRequest *request) { 551 int ierr; 552 Ceed ceed; 553 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 554 Ceed_Opt *ceed_impl; 555 ierr = CeedGetData(ceed, &ceed_impl); CeedChkBackend(ierr); 556 const CeedInt blk_size = ceed_impl->blk_size; 557 CeedOperator_Opt *impl; 558 ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr); 559 CeedInt Q, num_input_fields, num_output_fields, num_elem, size; 560 ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr); 561 ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr); 562 CeedInt num_blks = (num_elem/blk_size) + !!(num_elem%blk_size); 563 CeedQFunction qf; 564 ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr); 565 CeedOperatorField *op_input_fields, *op_output_fields; 566 ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, 567 &num_output_fields, &op_output_fields); 568 CeedChkBackend(ierr); 569 CeedQFunctionField *qf_input_fields, *qf_output_fields; 570 ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, 571 &qf_output_fields); 572 CeedChkBackend(ierr); 573 CeedVector vec, l_vec = impl->qf_l_vec; 574 CeedInt num_active_in = impl->num_active_in, 575 num_active_out = impl->num_active_out; 576 CeedVector *active_in = impl->qf_active_in; 577 CeedScalar *a, *tmp; 578 CeedScalar *e_data[2*CEED_FIELD_MAX] = {0}; 579 580 // Setup 581 ierr = CeedOperatorSetup_Opt(op); CeedChkBackend(ierr); 582 583 // Check for identity 584 if (impl->is_identity_qf) 585 // LCOV_EXCL_START 586 return CeedError(ceed, CEED_ERROR_BACKEND, 587 "Assembling identity qfunctions not supported"); 588 // LCOV_EXCL_STOP 589 590 // Input Evecs and Restriction 591 ierr = CeedOperatorSetupInputs_Opt(num_input_fields, qf_input_fields, 592 op_input_fields, NULL, e_data, 593 impl, request); CeedChkBackend(ierr); 594 595 // Count number of active input fields 596 if (!num_active_in) { 597 for (CeedInt i=0; i<num_input_fields; i++) { 598 // Get input vector 599 ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec); 600 CeedChkBackend(ierr); 601 // Check if active input 602 if (vec == CEED_VECTOR_ACTIVE) { 603 ierr = CeedQFunctionFieldGetSize(qf_input_fields[i], &size); 604 CeedChkBackend(ierr); 605 ierr = CeedVectorSetValue(impl->q_vecs_in[i], 0.0); CeedChkBackend(ierr); 606 ierr = CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &tmp); 607 CeedChkBackend(ierr); 608 ierr = CeedRealloc(num_active_in + size, &active_in); CeedChkBackend(ierr); 609 for (CeedInt field=0; field<size; field++) { 610 ierr = CeedVectorCreate(ceed, Q*blk_size, &active_in[num_active_in+field]); 611 CeedChkBackend(ierr); 612 ierr = CeedVectorSetArray(active_in[num_active_in+field], CEED_MEM_HOST, 613 CEED_USE_POINTER, &tmp[field*Q*blk_size]); 614 CeedChkBackend(ierr); 615 } 616 num_active_in += size; 617 ierr = CeedVectorRestoreArray(impl->q_vecs_in[i], &tmp); CeedChkBackend(ierr); 618 } 619 } 620 impl->num_active_in = num_active_in; 621 impl->qf_active_in = active_in; 622 } 623 624 // Count number of active output fields 625 if (!num_active_out) { 626 for (CeedInt i=0; i<num_output_fields; i++) { 627 // Get output vector 628 ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec); 629 CeedChkBackend(ierr); 630 // Check if active output 631 if (vec == CEED_VECTOR_ACTIVE) { 632 ierr = CeedQFunctionFieldGetSize(qf_output_fields[i], &size); 633 CeedChkBackend(ierr); 634 num_active_out += size; 635 } 636 } 637 impl->num_active_out = num_active_out; 638 } 639 640 // Check sizes 641 if (!num_active_in || !num_active_out) 642 // LCOV_EXCL_START 643 return CeedError(ceed, CEED_ERROR_BACKEND, 644 "Cannot assemble QFunction without active inputs " 645 "and outputs"); 646 // LCOV_EXCL_STOP 647 648 // Setup l_vec 649 if (!l_vec) { 650 ierr = CeedVectorCreate(ceed, blk_size*Q*num_active_in*num_active_out, 651 &l_vec); CeedChkBackend(ierr); 652 ierr = CeedVectorSetValue(l_vec, 0.0); CeedChkBackend(ierr); 653 impl->qf_l_vec = l_vec; 654 } 655 656 // Build objects if needed 657 CeedInt strides[3] = {1, Q, num_active_in *num_active_out*Q}; 658 if (build_objects) { 659 // Create output restriction 660 ierr = CeedElemRestrictionCreateStrided(ceed, num_elem, Q, 661 num_active_in*num_active_out, 662 num_active_in*num_active_out*num_elem*Q, 663 strides, rstr); CeedChkBackend(ierr); 664 // Create assembled vector 665 ierr = CeedVectorCreate(ceed, num_elem*Q*num_active_in*num_active_out, 666 assembled); CeedChkBackend(ierr); 667 } 668 669 // Output blocked restriction 670 CeedElemRestriction blk_rstr = impl->qf_blk_rstr; 671 if (!blk_rstr) { 672 ierr = CeedElemRestrictionCreateBlockedStrided(ceed, num_elem, Q, blk_size, 673 num_active_in*num_active_out, num_active_in*num_active_out*num_elem*Q, 674 strides, &blk_rstr); CeedChkBackend(ierr); 675 impl->qf_blk_rstr = blk_rstr; 676 } 677 678 // Loop through elements 679 ierr = CeedVectorSetValue(*assembled, 0.0); CeedChkBackend(ierr); 680 for (CeedInt e=0; e<num_blks*blk_size; e+=blk_size) { 681 ierr = CeedVectorGetArray(l_vec, CEED_MEM_HOST, &a); CeedChkBackend(ierr); 682 683 // Input basis apply 684 ierr = CeedOperatorInputBasis_Opt(e, Q, qf_input_fields, op_input_fields, 685 num_input_fields, blk_size, NULL, true, 686 e_data, impl, request); CeedChkBackend(ierr); 687 688 // Assemble QFunction 689 for (CeedInt in=0; in<num_active_in; in++) { 690 // Set Inputs 691 ierr = CeedVectorSetValue(active_in[in], 1.0); CeedChkBackend(ierr); 692 if (num_active_in > 1) { 693 ierr = CeedVectorSetValue(active_in[(in+num_active_in-1)%num_active_in], 694 0.0); CeedChkBackend(ierr); 695 } 696 // Set Outputs 697 for (CeedInt out=0; out<num_output_fields; out++) { 698 // Get output vector 699 ierr = CeedOperatorFieldGetVector(op_output_fields[out], &vec); 700 CeedChkBackend(ierr); 701 // Check if active output 702 if (vec == CEED_VECTOR_ACTIVE) { 703 CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, 704 CEED_USE_POINTER, a); CeedChkBackend(ierr); 705 ierr = CeedQFunctionFieldGetSize(qf_output_fields[out], &size); 706 CeedChkBackend(ierr); 707 a += size*Q*blk_size; // Advance the pointer by the size of the output 708 } 709 } 710 // Apply QFunction 711 ierr = CeedQFunctionApply(qf, Q*blk_size, impl->q_vecs_in, impl->q_vecs_out); 712 CeedChkBackend(ierr); 713 } 714 715 // Assemble into assembled vector 716 ierr = CeedVectorRestoreArray(l_vec, &a); CeedChkBackend(ierr); 717 ierr = CeedElemRestrictionApplyBlock(blk_rstr, e/blk_size, CEED_TRANSPOSE, 718 l_vec, *assembled, request); CeedChkBackend(ierr); 719 } 720 721 // Un-set output Qvecs to prevent accidental overwrite of Assembled 722 for (CeedInt out=0; out<num_output_fields; out++) { 723 // Get output vector 724 ierr = CeedOperatorFieldGetVector(op_output_fields[out], &vec); 725 CeedChkBackend(ierr); 726 // Check if active output 727 if (vec == CEED_VECTOR_ACTIVE) { 728 CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_COPY_VALUES, 729 NULL); CeedChkBackend(ierr); 730 } 731 } 732 733 // Restore input arrays 734 ierr = CeedOperatorRestoreInputs_Opt(num_input_fields, qf_input_fields, 735 op_input_fields, e_data, impl); 736 CeedChkBackend(ierr); 737 738 return CEED_ERROR_SUCCESS; 739 } 740 741 //------------------------------------------------------------------------------ 742 // Assemble Linear QFunction 743 //------------------------------------------------------------------------------ 744 static int CeedOperatorLinearAssembleQFunction_Opt(CeedOperator op, 745 CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 746 return CeedOperatorLinearAssembleQFunctionCore_Opt(op, true, assembled, rstr, 747 request); 748 } 749 750 //------------------------------------------------------------------------------ 751 // Update Assembled Linear QFunction 752 //------------------------------------------------------------------------------ 753 static int CeedOperatorLinearAssembleQFunctionUpdate_Opt(CeedOperator op, 754 CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 755 return CeedOperatorLinearAssembleQFunctionCore_Opt(op, false, &assembled, 756 &rstr, request); 757 } 758 759 //------------------------------------------------------------------------------ 760 // Operator Destroy 761 //------------------------------------------------------------------------------ 762 static int CeedOperatorDestroy_Opt(CeedOperator op) { 763 int ierr; 764 CeedOperator_Opt *impl; 765 ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr); 766 767 for (CeedInt i=0; i<impl->num_inputs+impl->num_outputs; i++) { 768 ierr = CeedElemRestrictionDestroy(&impl->blk_restr[i]); CeedChkBackend(ierr); 769 ierr = CeedVectorDestroy(&impl->e_vecs_full[i]); CeedChkBackend(ierr); 770 } 771 ierr = CeedFree(&impl->blk_restr); CeedChkBackend(ierr); 772 ierr = CeedFree(&impl->e_vecs_full); CeedChkBackend(ierr); 773 ierr = CeedFree(&impl->input_states); CeedChkBackend(ierr); 774 775 for (CeedInt i=0; i<impl->num_inputs; i++) { 776 ierr = CeedVectorDestroy(&impl->e_vecs_in[i]); CeedChkBackend(ierr); 777 ierr = CeedVectorDestroy(&impl->q_vecs_in[i]); CeedChkBackend(ierr); 778 } 779 ierr = CeedFree(&impl->e_vecs_in); CeedChkBackend(ierr); 780 ierr = CeedFree(&impl->q_vecs_in); CeedChkBackend(ierr); 781 782 for (CeedInt i=0; i<impl->num_outputs; i++) { 783 ierr = CeedVectorDestroy(&impl->e_vecs_out[i]); CeedChkBackend(ierr); 784 ierr = CeedVectorDestroy(&impl->q_vecs_out[i]); CeedChkBackend(ierr); 785 } 786 ierr = CeedFree(&impl->e_vecs_out); CeedChkBackend(ierr); 787 ierr = CeedFree(&impl->q_vecs_out); CeedChkBackend(ierr); 788 789 // QFunction assembly data 790 for (CeedInt i=0; i<impl->num_active_in; i++) { 791 ierr = CeedVectorDestroy(&impl->qf_active_in[i]); CeedChkBackend(ierr); 792 } 793 ierr = CeedFree(&impl->qf_active_in); CeedChkBackend(ierr); 794 ierr = CeedVectorDestroy(&impl->qf_l_vec); CeedChkBackend(ierr); 795 ierr = CeedElemRestrictionDestroy(&impl->qf_blk_rstr); CeedChkBackend(ierr); 796 797 ierr = CeedFree(&impl); CeedChkBackend(ierr); 798 return CEED_ERROR_SUCCESS; 799 } 800 801 //------------------------------------------------------------------------------ 802 // Operator Create 803 //------------------------------------------------------------------------------ 804 int CeedOperatorCreate_Opt(CeedOperator op) { 805 int ierr; 806 Ceed ceed; 807 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 808 Ceed_Opt *ceed_impl; 809 ierr = CeedGetData(ceed, &ceed_impl); CeedChkBackend(ierr); 810 CeedInt blk_size = ceed_impl->blk_size; 811 CeedOperator_Opt *impl; 812 813 ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr); 814 ierr = CeedOperatorSetData(op, impl); CeedChkBackend(ierr); 815 816 if (blk_size != 1 && blk_size != 8) 817 // LCOV_EXCL_START 818 return CeedError(ceed, CEED_ERROR_BACKEND, 819 "Opt backend cannot use blocksize: %d", blk_size); 820 // LCOV_EXCL_STOP 821 822 ierr = CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", 823 CeedOperatorLinearAssembleQFunction_Opt); 824 CeedChkBackend(ierr); 825 ierr = CeedSetBackendFunction(ceed, "Operator", op, 826 "LinearAssembleQFunctionUpdate", 827 CeedOperatorLinearAssembleQFunctionUpdate_Opt); 828 CeedChkBackend(ierr); 829 ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", 830 CeedOperatorApplyAdd_Opt); CeedChkBackend(ierr); 831 ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy", 832 CeedOperatorDestroy_Opt); CeedChkBackend(ierr); 833 return CEED_ERROR_SUCCESS; 834 } 835 //------------------------------------------------------------------------------ 836