1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <stdbool.h> 11 #include <stddef.h> 12 #include <stdint.h> 13 14 #include "ceed-ref.h" 15 16 //------------------------------------------------------------------------------ 17 // Setup Input/Output Fields 18 //------------------------------------------------------------------------------ 19 static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op, bool is_input, CeedVector *e_vecs_full, CeedVector *e_vecs, 20 CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) { 21 Ceed ceed; 22 CeedSize e_size, q_size; 23 CeedInt num_comp, size, P; 24 CeedQFunctionField *qf_fields; 25 CeedOperatorField *op_fields; 26 27 { 28 Ceed ceed_parent; 29 30 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 31 CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 32 if (ceed_parent) ceed = ceed_parent; 33 } 34 if (is_input) { 35 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 36 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 37 } else { 38 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 39 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 40 } 41 42 // Loop over fields 43 for (CeedInt i = 0; i < num_fields; i++) { 44 CeedEvalMode eval_mode; 45 CeedElemRestriction elem_rstr; 46 CeedBasis basis; 47 48 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 49 if (eval_mode != CEED_EVAL_WEIGHT) { 50 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); 51 CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs_full[i + start_e])); 52 } 53 54 switch (eval_mode) { 55 case CEED_EVAL_NONE: 56 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 57 q_size = (CeedSize)Q * size; 58 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 59 break; 60 case CEED_EVAL_INTERP: 61 case CEED_EVAL_GRAD: 62 case CEED_EVAL_DIV: 63 case CEED_EVAL_CURL: 64 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 65 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 66 CeedCallBackend(CeedBasisGetNumNodes(basis, &P)); 67 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 68 e_size = (CeedSize)P * num_comp; 69 CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i])); 70 q_size = (CeedSize)Q * size; 71 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 72 break; 73 case CEED_EVAL_WEIGHT: // Only on input fields 74 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 75 q_size = (CeedSize)Q; 76 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 77 CeedCallBackend(CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); 78 break; 79 } 80 } 81 return CEED_ERROR_SUCCESS; 82 } 83 84 //------------------------------------------------------------------------------ 85 // Setup Operator 86 //------------------------------------------------------------------------------/* 87 static int CeedOperatorSetup_Ref(CeedOperator op) { 88 bool is_setup_done; 89 CeedInt Q, num_input_fields, num_output_fields; 90 CeedQFunctionField *qf_input_fields, *qf_output_fields; 91 CeedQFunction qf; 92 CeedOperatorField *op_input_fields, *op_output_fields; 93 CeedOperator_Ref *impl; 94 95 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 96 if (is_setup_done) return CEED_ERROR_SUCCESS; 97 98 CeedCallBackend(CeedOperatorGetData(op, &impl)); 99 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 100 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 101 CeedCallBackend(CeedQFunctionIsIdentity(qf, &impl->is_identity_qf)); 102 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 103 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 104 105 // Allocate 106 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full)); 107 108 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 109 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in)); 110 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out)); 111 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 112 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 113 114 impl->num_inputs = num_input_fields; 115 impl->num_outputs = num_output_fields; 116 117 // Set up infield and outfield e_vecs and q_vecs 118 // Infields 119 CeedCallBackend(CeedOperatorSetupFields_Ref(qf, op, true, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q)); 120 // Outfields 121 CeedCallBackend( 122 CeedOperatorSetupFields_Ref(qf, op, false, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields, num_output_fields, Q)); 123 124 // Identity QFunctions 125 if (impl->is_identity_qf) { 126 CeedEvalMode in_mode, out_mode; 127 CeedQFunctionField *in_fields, *out_fields; 128 129 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields)); 130 CeedCallBackend(CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode)); 131 CeedCallBackend(CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode)); 132 133 if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) { 134 impl->is_identity_rstr_op = true; 135 } else { 136 CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->q_vecs_out[0])); 137 } 138 } 139 140 CeedCallBackend(CeedOperatorSetSetupDone(op)); 141 return CEED_ERROR_SUCCESS; 142 } 143 144 //------------------------------------------------------------------------------ 145 // Setup Operator Inputs 146 //------------------------------------------------------------------------------ 147 static inline int CeedOperatorSetupInputs_Ref(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 148 CeedVector in_vec, const bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], 149 CeedOperator_Ref *impl, CeedRequest *request) { 150 for (CeedInt i = 0; i < num_input_fields; i++) { 151 uint64_t state; 152 CeedEvalMode eval_mode; 153 CeedVector vec; 154 CeedElemRestriction elem_rstr; 155 156 // Get input vector 157 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 158 if (vec == CEED_VECTOR_ACTIVE) { 159 if (skip_active) continue; 160 else vec = in_vec; 161 } 162 163 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 164 // Restrict and Evec 165 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 166 } else { 167 // Restrict 168 CeedCallBackend(CeedVectorGetState(vec, &state)); 169 // Skip restriction if input is unchanged 170 if (state != impl->input_states[i] || vec == in_vec) { 171 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 172 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, vec, impl->e_vecs_full[i], request)); 173 impl->input_states[i] = state; 174 } 175 // Get evec 176 CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, (const CeedScalar **)&e_data_full[i])); 177 } 178 } 179 return CEED_ERROR_SUCCESS; 180 } 181 182 //------------------------------------------------------------------------------ 183 // Input Basis Action 184 //------------------------------------------------------------------------------ 185 static inline int CeedOperatorInputBasis_Ref(CeedInt e, CeedInt Q, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 186 CeedInt num_input_fields, const bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], 187 CeedOperator_Ref *impl) { 188 for (CeedInt i = 0; i < num_input_fields; i++) { 189 CeedInt elem_size, size, num_comp; 190 CeedEvalMode eval_mode; 191 CeedElemRestriction elem_rstr; 192 CeedBasis basis; 193 194 // Skip active input 195 if (skip_active) { 196 CeedVector vec; 197 198 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 199 if (vec == CEED_VECTOR_ACTIVE) continue; 200 } 201 // Get elem_size, eval_mode, size 202 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 203 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 204 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 205 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 206 // Basis action 207 switch (eval_mode) { 208 case CEED_EVAL_NONE: 209 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][(CeedSize)e * Q * size])); 210 break; 211 case CEED_EVAL_INTERP: 212 case CEED_EVAL_GRAD: 213 case CEED_EVAL_DIV: 214 case CEED_EVAL_CURL: 215 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 216 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 217 CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][(CeedSize)e * elem_size * num_comp])); 218 CeedCallBackend(CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs_in[i], impl->q_vecs_in[i])); 219 break; 220 case CEED_EVAL_WEIGHT: 221 break; // No action 222 } 223 } 224 return CEED_ERROR_SUCCESS; 225 } 226 227 //------------------------------------------------------------------------------ 228 // Output Basis Action 229 //------------------------------------------------------------------------------ 230 static inline int CeedOperatorOutputBasis_Ref(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields, 231 CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op, 232 CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Ref *impl) { 233 for (CeedInt i = 0; i < num_output_fields; i++) { 234 CeedInt elem_size, num_comp; 235 CeedEvalMode eval_mode; 236 CeedElemRestriction elem_rstr; 237 CeedBasis basis; 238 239 // Get elem_size, eval_mode 240 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 241 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 242 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 243 // Basis action 244 switch (eval_mode) { 245 case CEED_EVAL_NONE: 246 break; // No action 247 case CEED_EVAL_INTERP: 248 case CEED_EVAL_GRAD: 249 case CEED_EVAL_DIV: 250 case CEED_EVAL_CURL: 251 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 252 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 253 CeedCallBackend(CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, 254 &e_data_full[i + num_input_fields][(CeedSize)e * elem_size * num_comp])); 255 CeedCallBackend(CeedBasisApply(basis, 1, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); 256 break; 257 // LCOV_EXCL_START 258 case CEED_EVAL_WEIGHT: { 259 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 260 // LCOV_EXCL_STOP 261 } 262 } 263 } 264 return CEED_ERROR_SUCCESS; 265 } 266 267 //------------------------------------------------------------------------------ 268 // Restore Input Vectors 269 //------------------------------------------------------------------------------ 270 static inline int CeedOperatorRestoreInputs_Ref(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 271 const bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Ref *impl) { 272 for (CeedInt i = 0; i < num_input_fields; i++) { 273 CeedEvalMode eval_mode; 274 275 // Skip active inputs 276 if (skip_active) { 277 CeedVector vec; 278 279 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 280 if (vec == CEED_VECTOR_ACTIVE) continue; 281 } 282 // Restore input 283 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 284 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 285 } else { 286 CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_full[i], (const CeedScalar **)&e_data_full[i])); 287 } 288 } 289 return CEED_ERROR_SUCCESS; 290 } 291 292 //------------------------------------------------------------------------------ 293 // Operator Apply 294 //------------------------------------------------------------------------------ 295 static int CeedOperatorApplyAdd_Ref(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 296 CeedInt Q, num_elem, num_input_fields, num_output_fields, size; 297 CeedEvalMode eval_mode; 298 CeedScalar *e_data_full[2 * CEED_FIELD_MAX] = {NULL}; 299 CeedQFunctionField *qf_input_fields, *qf_output_fields; 300 CeedQFunction qf; 301 CeedOperatorField *op_input_fields, *op_output_fields; 302 CeedOperator_Ref *impl; 303 304 CeedCallBackend(CeedOperatorGetData(op, &impl)); 305 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 306 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 307 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 308 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 309 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 310 311 // Setup 312 CeedCallBackend(CeedOperatorSetup_Ref(op)); 313 314 // Restriction only operator 315 if (impl->is_identity_rstr_op) { 316 CeedElemRestriction elem_rstr; 317 318 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[0], &elem_rstr)); 319 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_full[0], request)); 320 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[0], &elem_rstr)); 321 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs_full[0], out_vec, request)); 322 return CEED_ERROR_SUCCESS; 323 } 324 325 // Input Evecs and Restriction 326 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data_full, impl, request)); 327 328 // Output Evecs 329 for (CeedInt i = 0; i < num_output_fields; i++) { 330 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields])); 331 } 332 333 // Loop through elements 334 for (CeedInt e = 0; e < num_elem; e++) { 335 // Output pointers 336 for (CeedInt i = 0; i < num_output_fields; i++) { 337 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 338 if (eval_mode == CEED_EVAL_NONE) { 339 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 340 CeedCallBackend( 341 CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i + num_input_fields][(CeedSize)e * Q * size])); 342 } 343 } 344 345 // Input basis apply 346 CeedCallBackend(CeedOperatorInputBasis_Ref(e, Q, qf_input_fields, op_input_fields, num_input_fields, false, e_data_full, impl)); 347 348 // Q function 349 if (!impl->is_identity_qf) { 350 CeedCallBackend(CeedQFunctionApply(qf, Q, impl->q_vecs_in, impl->q_vecs_out)); 351 } 352 353 // Output basis apply 354 CeedCallBackend( 355 CeedOperatorOutputBasis_Ref(e, Q, qf_output_fields, op_output_fields, num_input_fields, num_output_fields, op, e_data_full, impl)); 356 } 357 358 // Output restriction 359 for (CeedInt i = 0; i < num_output_fields; i++) { 360 CeedVector vec; 361 CeedElemRestriction elem_rstr; 362 363 // Restore Evec 364 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_full[i + impl->num_inputs], &e_data_full[i + num_input_fields])); 365 // Get output vector 366 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 367 // Active 368 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 369 // Restrict 370 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 371 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs_full[i + impl->num_inputs], vec, request)); 372 } 373 374 // Restore input arrays 375 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, false, e_data_full, impl)); 376 return CEED_ERROR_SUCCESS; 377 } 378 379 //------------------------------------------------------------------------------ 380 // Core code for assembling linear QFunction 381 //------------------------------------------------------------------------------ 382 static inline int CeedOperatorLinearAssembleQFunctionCore_Ref(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 383 CeedRequest *request) { 384 Ceed ceed, ceed_parent; 385 CeedSize q_size; 386 CeedInt num_active_in, num_active_out, Q, num_elem, num_input_fields, num_output_fields, size; 387 CeedScalar *assembled_array, *e_data_full[2 * CEED_FIELD_MAX] = {NULL}; 388 CeedVector *active_in; 389 CeedQFunctionField *qf_input_fields, *qf_output_fields; 390 CeedQFunction qf; 391 CeedOperatorField *op_input_fields, *op_output_fields; 392 CeedOperator_Ref *impl; 393 394 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 395 CeedCallBackend(CeedOperatorGetFallbackParentCeed(op, &ceed_parent)); 396 CeedCallBackend(CeedOperatorGetData(op, &impl)); 397 active_in = impl->qf_active_in; 398 num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 399 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 400 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 401 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 402 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 403 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 404 405 // Setup 406 CeedCallBackend(CeedOperatorSetup_Ref(op)); 407 408 // Check for restriction only operator 409 CeedCheck(!impl->is_identity_rstr_op, ceed, CEED_ERROR_BACKEND, "Assembling restriction only operators is not supported"); 410 411 // Input Evecs and Restriction 412 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data_full, impl, request)); 413 414 // Count number of active input fields 415 if (!num_active_in) { 416 for (CeedInt i = 0; i < num_input_fields; i++) { 417 CeedScalar *q_vec_array; 418 CeedVector vec; 419 420 // Get input vector 421 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 422 // Check if active input 423 if (vec == CEED_VECTOR_ACTIVE) { 424 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 425 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 426 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &q_vec_array)); 427 CeedCallBackend(CeedRealloc(num_active_in + size, &active_in)); 428 for (CeedInt field = 0; field < size; field++) { 429 q_size = (CeedSize)Q; 430 CeedCallBackend(CeedVectorCreate(ceed_parent, q_size, &active_in[num_active_in + field])); 431 CeedCallBackend(CeedVectorSetArray(active_in[num_active_in + field], CEED_MEM_HOST, CEED_USE_POINTER, &q_vec_array[field * Q])); 432 } 433 num_active_in += size; 434 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array)); 435 } 436 } 437 impl->num_active_in = num_active_in; 438 impl->qf_active_in = active_in; 439 } 440 441 // Count number of active output fields 442 if (!num_active_out) { 443 for (CeedInt i = 0; i < num_output_fields; i++) { 444 CeedVector vec; 445 446 // Get output vector 447 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 448 // Check if active output 449 if (vec == CEED_VECTOR_ACTIVE) { 450 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 451 num_active_out += size; 452 } 453 } 454 impl->num_active_out = num_active_out; 455 } 456 457 // Check sizes 458 CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 459 460 // Build objects if needed 461 if (build_objects) { 462 const CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out; 463 CeedInt strides[3] = {1, Q, num_active_in * num_active_out * Q}; /* *NOPAD* */ 464 465 // Create output restriction 466 CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out, 467 num_active_in * num_active_out * num_elem * Q, strides, rstr)); 468 // Create assembled vector 469 CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled)); 470 } 471 // Clear output vector 472 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 473 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_HOST, &assembled_array)); 474 475 // Loop through elements 476 for (CeedInt e = 0; e < num_elem; e++) { 477 // Input basis apply 478 CeedCallBackend(CeedOperatorInputBasis_Ref(e, Q, qf_input_fields, op_input_fields, num_input_fields, true, e_data_full, impl)); 479 480 // Assemble QFunction 481 for (CeedInt in = 0; in < num_active_in; in++) { 482 // Set Inputs 483 CeedCallBackend(CeedVectorSetValue(active_in[in], 1.0)); 484 if (num_active_in > 1) { 485 CeedCallBackend(CeedVectorSetValue(active_in[(in + num_active_in - 1) % num_active_in], 0.0)); 486 } 487 if (!impl->is_identity_qf) { 488 // Set Outputs 489 for (CeedInt out = 0; out < num_output_fields; out++) { 490 CeedVector vec; 491 492 // Get output vector 493 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 494 // Check if active output 495 if (vec == CEED_VECTOR_ACTIVE) { 496 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_USE_POINTER, assembled_array)); 497 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size)); 498 assembled_array += size * Q; // Advance the pointer by the size of the output 499 } 500 } 501 // Apply QFunction 502 CeedCallBackend(CeedQFunctionApply(qf, Q, impl->q_vecs_in, impl->q_vecs_out)); 503 } else { 504 const CeedScalar *q_vec_array; 505 506 // Copy Identity Outputs 507 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[0], &size)); 508 CeedCallBackend(CeedVectorGetArrayRead(impl->q_vecs_out[0], CEED_MEM_HOST, &q_vec_array)); 509 for (CeedInt i = 0; i < size * Q; i++) assembled_array[i] = q_vec_array[i]; 510 CeedCallBackend(CeedVectorRestoreArrayRead(impl->q_vecs_out[0], &q_vec_array)); 511 assembled_array += size * Q; 512 } 513 } 514 } 515 516 // Un-set output Qvecs to prevent accidental overwrite of Assembled 517 if (!impl->is_identity_qf) { 518 for (CeedInt out = 0; out < num_output_fields; out++) { 519 CeedVector vec; 520 521 // Get output vector 522 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 523 // Check if active output 524 if (vec == CEED_VECTOR_ACTIVE && num_elem > 0) { 525 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_HOST, NULL)); 526 } 527 } 528 } 529 530 // Restore input arrays 531 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, true, e_data_full, impl)); 532 533 // Restore output 534 CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 535 return CEED_ERROR_SUCCESS; 536 } 537 538 //------------------------------------------------------------------------------ 539 // Assemble Linear QFunction 540 //------------------------------------------------------------------------------ 541 static int CeedOperatorLinearAssembleQFunction_Ref(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 542 return CeedOperatorLinearAssembleQFunctionCore_Ref(op, true, assembled, rstr, request); 543 } 544 545 //------------------------------------------------------------------------------ 546 // Update Assembled Linear QFunction 547 //------------------------------------------------------------------------------ 548 static int CeedOperatorLinearAssembleQFunctionUpdate_Ref(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 549 return CeedOperatorLinearAssembleQFunctionCore_Ref(op, false, &assembled, &rstr, request); 550 } 551 552 //------------------------------------------------------------------------------ 553 // Setup Input/Output Fields 554 //------------------------------------------------------------------------------ 555 static int CeedOperatorSetupFieldsAtPoints_Ref(CeedQFunction qf, CeedOperator op, bool is_input, CeedVector *e_vecs_full, CeedVector *e_vecs, 556 CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) { 557 Ceed ceed; 558 CeedSize e_size, q_size; 559 CeedInt e_size_padding = 0, max_num_points, num_comp, size, P; 560 CeedQFunctionField *qf_fields; 561 CeedOperatorField *op_fields; 562 563 { 564 Ceed ceed_parent; 565 566 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 567 CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 568 if (ceed_parent) ceed = ceed_parent; 569 } 570 if (is_input) { 571 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 572 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 573 } else { 574 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 575 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 576 } 577 578 // Get max number of points 579 { 580 CeedInt dim; 581 CeedElemRestriction rstr_points = NULL; 582 CeedOperator_Ref *impl; 583 584 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 585 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(rstr_points, &max_num_points)); 586 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_points, &dim)); 587 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 588 CeedCallBackend(CeedOperatorGetData(op, &impl)); 589 if (is_input) { 590 CeedCallBackend(CeedVectorCreate(ceed, dim * max_num_points, &impl->point_coords_elem)); 591 CeedCallBackend(CeedVectorSetValue(impl->point_coords_elem, 0.0)); 592 } 593 } 594 595 // Loop over fields 596 for (CeedInt i = 0; i < num_fields; i++) { 597 CeedEvalMode eval_mode; 598 CeedBasis basis; 599 600 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 601 if (eval_mode != CEED_EVAL_WEIGHT) { 602 CeedElemRestriction elem_rstr; 603 CeedSize e_size; 604 bool is_at_points; 605 606 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); 607 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp)); 608 CeedCallBackend(CeedElemRestrictionIsPoints(elem_rstr, &is_at_points)); 609 if (is_at_points) { 610 CeedCallBackend(CeedElemRestrictionGetEVectorSize(elem_rstr, &e_size)); 611 if (e_size_padding == 0) { 612 CeedInt num_points, num_elem; 613 614 CeedCallBackend(CeedElemRestrictionGetNumElements(elem_rstr, &num_elem)); 615 CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(elem_rstr, num_elem - 1, &num_points)); 616 e_size_padding = (max_num_points - num_points) * num_comp; 617 } 618 CeedCallBackend(CeedVectorCreate(ceed, e_size + e_size_padding, &e_vecs_full[i + start_e])); 619 CeedCallBackend(CeedVectorSetValue(e_vecs_full[i + start_e], 0.0)); 620 } else { 621 CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs_full[i + start_e])); 622 } 623 } 624 625 switch (eval_mode) { 626 case CEED_EVAL_NONE: { 627 CeedVector vec; 628 629 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 630 e_size = (CeedSize)max_num_points * size; 631 CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i])); 632 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 633 if (vec == CEED_VECTOR_ACTIVE || !is_input) { 634 CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &q_vecs[i])); 635 } else { 636 q_size = (CeedSize)max_num_points * size; 637 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 638 } 639 break; 640 } 641 case CEED_EVAL_INTERP: 642 case CEED_EVAL_GRAD: 643 case CEED_EVAL_DIV: 644 case CEED_EVAL_CURL: 645 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 646 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 647 CeedCallBackend(CeedBasisGetNumNodes(basis, &P)); 648 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 649 e_size = (CeedSize)P * num_comp; 650 CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i])); 651 q_size = (CeedSize)max_num_points * size; 652 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 653 break; 654 case CEED_EVAL_WEIGHT: // Only on input fields 655 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 656 q_size = (CeedSize)max_num_points; 657 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 658 CeedCallBackend( 659 CeedBasisApplyAtPoints(basis, max_num_points, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, CEED_VECTOR_NONE, q_vecs[i])); 660 break; 661 } 662 // Initialize full arrays for E-vectors and Q-vectors 663 if (e_vecs[i]) CeedCallBackend(CeedVectorSetValue(e_vecs[i], 0.0)); 664 if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorSetValue(q_vecs[i], 0.0)); 665 } 666 return CEED_ERROR_SUCCESS; 667 } 668 669 //------------------------------------------------------------------------------ 670 // Setup Operator 671 //------------------------------------------------------------------------------ 672 static int CeedOperatorSetupAtPoints_Ref(CeedOperator op) { 673 bool is_setup_done; 674 CeedInt Q, num_input_fields, num_output_fields; 675 CeedQFunctionField *qf_input_fields, *qf_output_fields; 676 CeedQFunction qf; 677 CeedOperatorField *op_input_fields, *op_output_fields; 678 CeedOperator_Ref *impl; 679 680 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 681 if (is_setup_done) return CEED_ERROR_SUCCESS; 682 683 CeedCallBackend(CeedOperatorGetData(op, &impl)); 684 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 685 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 686 CeedCallBackend(CeedQFunctionIsIdentity(qf, &impl->is_identity_qf)); 687 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 688 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 689 690 // Allocate 691 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full)); 692 693 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 694 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in)); 695 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out)); 696 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 697 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 698 699 impl->num_inputs = num_input_fields; 700 impl->num_outputs = num_output_fields; 701 702 // Set up infield and outfield pointer arrays 703 // Infields 704 CeedCallBackend(CeedOperatorSetupFieldsAtPoints_Ref(qf, op, true, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q)); 705 // Outfields 706 CeedCallBackend(CeedOperatorSetupFieldsAtPoints_Ref(qf, op, false, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields, 707 num_output_fields, Q)); 708 709 // Identity QFunctions 710 if (impl->is_identity_qf) { 711 CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->q_vecs_out[0])); 712 CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->e_vecs_out[0])); 713 } 714 715 CeedCallBackend(CeedOperatorSetSetupDone(op)); 716 return CEED_ERROR_SUCCESS; 717 } 718 719 //------------------------------------------------------------------------------ 720 // Input Basis Action 721 //------------------------------------------------------------------------------ 722 static inline int CeedOperatorInputBasisAtPoints_Ref(CeedInt e, CeedInt num_points_offset, CeedInt num_points, CeedQFunctionField *qf_input_fields, 723 CeedOperatorField *op_input_fields, CeedInt num_input_fields, CeedVector in_vec, 724 CeedVector point_coords_elem, bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], 725 CeedOperator_Ref *impl, CeedRequest *request) { 726 for (CeedInt i = 0; i < num_input_fields; i++) { 727 bool is_active_input = false; 728 CeedInt elem_size, size, num_comp; 729 CeedRestrictionType rstr_type; 730 CeedEvalMode eval_mode; 731 CeedVector vec; 732 CeedElemRestriction elem_rstr; 733 CeedBasis basis; 734 735 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 736 // Skip active input 737 is_active_input = vec == CEED_VECTOR_ACTIVE; 738 if (skip_active && is_active_input) continue; 739 740 // Get elem_size, eval_mode, size 741 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 742 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 743 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 744 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 745 // Restrict block active input 746 if (is_active_input) { 747 if (rstr_type == CEED_RESTRICTION_POINTS) { 748 CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(elem_rstr, e, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[i], request)); 749 } else { 750 CeedCallBackend(CeedElemRestrictionApplyBlock(elem_rstr, e, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[i], request)); 751 } 752 } 753 // Basis action 754 switch (eval_mode) { 755 case CEED_EVAL_NONE: 756 if (!is_active_input) { 757 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data[i][num_points_offset * size])); 758 } 759 break; 760 // Note - these basis eval modes require FEM fields 761 case CEED_EVAL_INTERP: 762 case CEED_EVAL_GRAD: 763 case CEED_EVAL_DIV: 764 case CEED_EVAL_CURL: 765 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 766 if (!is_active_input) { 767 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 768 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 769 CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data[i][(CeedSize)e * elem_size * num_comp])); 770 } 771 CeedCallBackend( 772 CeedBasisApplyAtPoints(basis, num_points, CEED_NOTRANSPOSE, eval_mode, point_coords_elem, impl->e_vecs_in[i], impl->q_vecs_in[i])); 773 break; 774 case CEED_EVAL_WEIGHT: 775 break; // No action 776 } 777 } 778 return CEED_ERROR_SUCCESS; 779 } 780 781 //------------------------------------------------------------------------------ 782 // Output Basis Action 783 //------------------------------------------------------------------------------ 784 static inline int CeedOperatorOutputBasisAtPoints_Ref(CeedInt e, CeedInt num_points_offset, CeedInt num_points, CeedQFunctionField *qf_output_fields, 785 CeedOperatorField *op_output_fields, CeedInt num_input_fields, CeedInt num_output_fields, 786 CeedOperator op, CeedVector out_vec, CeedVector point_coords_elem, CeedOperator_Ref *impl, 787 CeedRequest *request) { 788 for (CeedInt i = 0; i < num_output_fields; i++) { 789 CeedRestrictionType rstr_type; 790 CeedEvalMode eval_mode; 791 CeedVector vec; 792 CeedElemRestriction elem_rstr; 793 CeedBasis basis; 794 795 // Get elem_size, eval_mode, size 796 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 797 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 798 // Basis action 799 switch (eval_mode) { 800 case CEED_EVAL_NONE: 801 break; // No action 802 case CEED_EVAL_INTERP: 803 case CEED_EVAL_GRAD: 804 case CEED_EVAL_DIV: 805 case CEED_EVAL_CURL: 806 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 807 CeedCallBackend( 808 CeedBasisApplyAtPoints(basis, num_points, CEED_TRANSPOSE, eval_mode, point_coords_elem, impl->q_vecs_out[i], impl->e_vecs_out[i])); 809 break; 810 // LCOV_EXCL_START 811 case CEED_EVAL_WEIGHT: { 812 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 813 // LCOV_EXCL_STOP 814 } 815 } 816 // Restrict output block 817 // Get output vector 818 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 819 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 820 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 821 // Restrict 822 if (rstr_type == CEED_RESTRICTION_POINTS) { 823 CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(elem_rstr, e, CEED_TRANSPOSE, impl->e_vecs_out[i], vec, request)); 824 } else { 825 CeedCallBackend(CeedElemRestrictionApplyBlock(elem_rstr, e, CEED_TRANSPOSE, impl->e_vecs_out[i], vec, request)); 826 } 827 } 828 return CEED_ERROR_SUCCESS; 829 } 830 831 //------------------------------------------------------------------------------ 832 // Operator Apply 833 //------------------------------------------------------------------------------ 834 static int CeedOperatorApplyAddAtPoints_Ref(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 835 CeedInt num_points_offset = 0, num_input_fields, num_output_fields, num_elem; 836 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {0}; 837 CeedVector point_coords = NULL; 838 CeedElemRestriction rstr_points = NULL; 839 CeedQFunctionField *qf_input_fields, *qf_output_fields; 840 CeedQFunction qf; 841 CeedOperatorField *op_input_fields, *op_output_fields; 842 CeedOperator_Ref *impl; 843 844 CeedCallBackend(CeedOperatorGetData(op, &impl)); 845 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 846 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 847 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 848 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 849 850 // Setup 851 CeedCallBackend(CeedOperatorSetupAtPoints_Ref(op)); 852 853 // Point coordinates 854 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 855 856 // Input Evecs and Restriction 857 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 858 859 // Loop through elements 860 for (CeedInt e = 0; e < num_elem; e++) { 861 CeedInt num_points; 862 863 // Setup points for element 864 CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(rstr_points, e, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 865 CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points)); 866 867 // Input basis apply 868 CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields, in_vec, 869 impl->point_coords_elem, false, e_data, impl, request)); 870 871 // Q function 872 if (!impl->is_identity_qf) { 873 CeedCallBackend(CeedQFunctionApply(qf, num_points, impl->q_vecs_in, impl->q_vecs_out)); 874 } 875 876 // Output basis apply and restriction 877 CeedCallBackend(CeedOperatorOutputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_output_fields, op_output_fields, num_input_fields, 878 num_output_fields, op, out_vec, impl->point_coords_elem, impl, request)); 879 880 num_points_offset += num_points; 881 } 882 883 // Restore input arrays 884 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 885 886 // Cleanup point coordinates 887 CeedCallBackend(CeedVectorDestroy(&point_coords)); 888 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 889 return CEED_ERROR_SUCCESS; 890 } 891 892 //------------------------------------------------------------------------------ 893 // Core code for assembling linear QFunction 894 //------------------------------------------------------------------------------ 895 static inline int CeedOperatorLinearAssembleQFunctionAtPointsCore_Ref(CeedOperator op, bool build_objects, CeedVector *assembled, 896 CeedElemRestriction *rstr, CeedRequest *request) { 897 Ceed ceed; 898 CeedSize q_size; 899 CeedInt num_active_in, num_active_out, max_num_points, num_elem, num_input_fields, num_output_fields, num_points_offset = 0; 900 CeedScalar *assembled_array, *e_data_full[2 * CEED_FIELD_MAX] = {NULL}; 901 CeedVector *active_in, point_coords = NULL; 902 CeedQFunctionField *qf_input_fields, *qf_output_fields; 903 CeedQFunction qf; 904 CeedOperatorField *op_input_fields, *op_output_fields; 905 CeedOperator_Ref *impl; 906 CeedElemRestriction rstr_points = NULL; 907 908 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 909 CeedCallBackend(CeedOperatorGetData(op, &impl)); 910 active_in = impl->qf_active_in; 911 num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 912 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 913 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 914 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 915 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 916 917 // Setup 918 CeedCallBackend(CeedOperatorSetupAtPoints_Ref(op)); 919 920 // Check for restriction only operator 921 CeedCheck(!impl->is_identity_rstr_op, ceed, CEED_ERROR_BACKEND, "Assembling restriction only operators is not supported"); 922 923 // Point coordinates 924 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 925 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(rstr_points, &max_num_points)); 926 927 // Input Evecs and Restriction 928 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data_full, impl, request)); 929 930 // Count number of active input fields 931 if (!num_active_in) { 932 for (CeedInt i = 0; i < num_input_fields; i++) { 933 CeedScalar *q_vec_array; 934 CeedInt field_size; 935 CeedVector vec; 936 937 // Get input vector 938 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 939 // Check if active input 940 if (vec == CEED_VECTOR_ACTIVE) { 941 // Check that all active inputs are nodal fields 942 { 943 CeedElemRestriction elem_rstr; 944 bool is_at_points = false; 945 946 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 947 CeedCallBackend(CeedElemRestrictionIsPoints(elem_rstr, &is_at_points)); 948 CeedCheck(!is_at_points, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction with active input at points"); 949 } 950 // Get size of active input 951 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size)); 952 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 953 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &q_vec_array)); 954 CeedCallBackend(CeedRealloc(num_active_in + field_size, &active_in)); 955 for (CeedInt field = 0; field < field_size; field++) { 956 q_size = (CeedSize)max_num_points; 957 CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_in[num_active_in + field])); 958 CeedCallBackend(CeedVectorSetArray(active_in[num_active_in + field], CEED_MEM_HOST, CEED_USE_POINTER, &q_vec_array[field * q_size])); 959 } 960 num_active_in += field_size; 961 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array)); 962 } 963 } 964 impl->num_active_in = num_active_in; 965 impl->qf_active_in = active_in; 966 } 967 968 // Count number of active output fields 969 if (!num_active_out) { 970 for (CeedInt i = 0; i < num_output_fields; i++) { 971 CeedVector vec; 972 CeedInt field_size; 973 974 // Get output vector 975 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 976 // Check if active output 977 if (vec == CEED_VECTOR_ACTIVE) { 978 // Check that all active inputs are nodal fields 979 { 980 CeedElemRestriction elem_rstr; 981 bool is_at_points = false; 982 983 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 984 CeedCallBackend(CeedElemRestrictionIsPoints(elem_rstr, &is_at_points)); 985 CeedCheck(!is_at_points, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction with active input at points"); 986 } 987 // Get size of active output 988 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &field_size)); 989 num_active_out += field_size; 990 } 991 } 992 impl->num_active_out = num_active_out; 993 } 994 995 // Check sizes 996 CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 997 998 // Build objects if needed 999 if (build_objects) { 1000 CeedInt num_points_total; 1001 const CeedInt *offsets; 1002 1003 CeedCallBackend(CeedElemRestrictionGetNumPoints(rstr_points, &num_points_total)); 1004 1005 // Create output restriction (at points) 1006 CeedCallBackend(CeedElemRestrictionGetOffsets(rstr_points, CEED_MEM_HOST, &offsets)); 1007 CeedCallBackend(CeedElemRestrictionCreateAtPoints(ceed, num_elem, num_points_total, num_active_in * num_active_out, 1008 num_active_in * num_active_out * num_points_total, CEED_MEM_HOST, CEED_COPY_VALUES, offsets, 1009 rstr)); 1010 CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr_points, &offsets)); 1011 1012 // Create assembled vector 1013 CeedCallBackend(CeedElemRestrictionCreateVector(*rstr, assembled, NULL)); 1014 } 1015 // Clear output vector 1016 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 1017 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_HOST, &assembled_array)); 1018 1019 // Loop through elements 1020 for (CeedInt e = 0; e < num_elem; e++) { 1021 CeedInt num_points; 1022 1023 // Setup points for element 1024 CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(rstr_points, e, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 1025 CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points)); 1026 1027 // Input basis apply 1028 CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields, NULL, 1029 impl->point_coords_elem, true, e_data_full, impl, request)); 1030 1031 // Assemble QFunction 1032 for (CeedInt in = 0; in < num_active_in; in++) { 1033 // Set Inputs 1034 CeedCallBackend(CeedVectorSetValue(active_in[in], 1.0)); 1035 if (num_active_in > 1) { 1036 CeedCallBackend(CeedVectorSetValue(active_in[(in + num_active_in - 1) % num_active_in], 0.0)); 1037 } 1038 if (!impl->is_identity_qf) { 1039 // Set Outputs 1040 for (CeedInt out = 0; out < num_output_fields; out++) { 1041 CeedVector vec; 1042 CeedInt field_size; 1043 1044 // Get output vector 1045 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 1046 // Check if active output 1047 if (vec == CEED_VECTOR_ACTIVE) { 1048 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_USE_POINTER, assembled_array)); 1049 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &field_size)); 1050 assembled_array += field_size * num_points; // Advance the pointer by the size of the output 1051 } 1052 } 1053 // Apply QFunction 1054 CeedCallBackend(CeedQFunctionApply(qf, num_points, impl->q_vecs_in, impl->q_vecs_out)); 1055 } else { 1056 const CeedScalar *q_vec_array; 1057 CeedInt field_size; 1058 1059 // Copy Identity Outputs 1060 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[0], &field_size)); 1061 CeedCallBackend(CeedVectorGetArrayRead(impl->q_vecs_out[0], CEED_MEM_HOST, &q_vec_array)); 1062 for (CeedInt i = 0; i < field_size * num_points; i++) assembled_array[i] = q_vec_array[i]; 1063 CeedCallBackend(CeedVectorRestoreArrayRead(impl->q_vecs_out[0], &q_vec_array)); 1064 assembled_array += field_size * num_points; 1065 } 1066 } 1067 num_points_offset += num_points; 1068 } 1069 1070 // Un-set output Qvecs to prevent accidental overwrite of Assembled 1071 if (!impl->is_identity_qf) { 1072 for (CeedInt out = 0; out < num_output_fields; out++) { 1073 CeedVector vec; 1074 1075 // Get output vector 1076 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 1077 // Check if active output 1078 if (vec == CEED_VECTOR_ACTIVE && num_elem > 0) { 1079 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_HOST, NULL)); 1080 } 1081 } 1082 } 1083 1084 // Restore input arrays 1085 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, true, e_data_full, impl)); 1086 1087 // Restore output 1088 CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 1089 1090 // Cleanup 1091 CeedCallBackend(CeedVectorDestroy(&point_coords)); 1092 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 1093 return CEED_ERROR_SUCCESS; 1094 } 1095 1096 //------------------------------------------------------------------------------ 1097 // Assemble Linear QFunction 1098 //------------------------------------------------------------------------------ 1099 static int CeedOperatorLinearAssembleQFunctionAtPoints_Ref(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 1100 return CeedOperatorLinearAssembleQFunctionAtPointsCore_Ref(op, true, assembled, rstr, request); 1101 } 1102 1103 //------------------------------------------------------------------------------ 1104 // Update Assembled Linear QFunction 1105 //------------------------------------------------------------------------------ 1106 static int CeedOperatorLinearAssembleQFunctionAtPointsUpdate_Ref(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, 1107 CeedRequest *request) { 1108 return CeedOperatorLinearAssembleQFunctionAtPointsCore_Ref(op, false, &assembled, &rstr, request); 1109 } 1110 1111 //------------------------------------------------------------------------------ 1112 // Assemble Operator 1113 //------------------------------------------------------------------------------ 1114 1115 //------------------------------------------------------------------------------ 1116 // Operator Destroy 1117 //------------------------------------------------------------------------------ 1118 static int CeedOperatorDestroy_Ref(CeedOperator op) { 1119 CeedOperator_Ref *impl; 1120 1121 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1122 for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { 1123 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_full[i])); 1124 } 1125 CeedCallBackend(CeedFree(&impl->e_vecs_full)); 1126 CeedCallBackend(CeedFree(&impl->input_states)); 1127 1128 for (CeedInt i = 0; i < impl->num_inputs; i++) { 1129 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_in[i])); 1130 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i])); 1131 } 1132 CeedCallBackend(CeedFree(&impl->e_vecs_in)); 1133 CeedCallBackend(CeedFree(&impl->q_vecs_in)); 1134 1135 for (CeedInt i = 0; i < impl->num_outputs; i++) { 1136 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_out[i])); 1137 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 1138 } 1139 CeedCallBackend(CeedFree(&impl->e_vecs_out)); 1140 CeedCallBackend(CeedFree(&impl->q_vecs_out)); 1141 CeedCallBackend(CeedVectorDestroy(&impl->point_coords_elem)); 1142 1143 // QFunction assembly 1144 for (CeedInt i = 0; i < impl->num_active_in; i++) { 1145 CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i])); 1146 } 1147 CeedCallBackend(CeedFree(&impl->qf_active_in)); 1148 1149 CeedCallBackend(CeedFree(&impl)); 1150 return CEED_ERROR_SUCCESS; 1151 } 1152 1153 //------------------------------------------------------------------------------ 1154 // Operator Create 1155 //------------------------------------------------------------------------------ 1156 int CeedOperatorCreate_Ref(CeedOperator op) { 1157 Ceed ceed; 1158 CeedOperator_Ref *impl; 1159 1160 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1161 CeedCallBackend(CeedCalloc(1, &impl)); 1162 CeedCallBackend(CeedOperatorSetData(op, impl)); 1163 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Ref)); 1164 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Ref)); 1165 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Ref)); 1166 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Ref)); 1167 return CEED_ERROR_SUCCESS; 1168 } 1169 1170 //------------------------------------------------------------------------------ 1171 // Operator Create At Points 1172 //------------------------------------------------------------------------------ 1173 int CeedOperatorCreateAtPoints_Ref(CeedOperator op) { 1174 Ceed ceed; 1175 CeedOperator_Ref *impl; 1176 1177 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1178 CeedCallBackend(CeedCalloc(1, &impl)); 1179 CeedCallBackend(CeedOperatorSetData(op, impl)); 1180 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunctionAtPoints_Ref)); 1181 CeedCallBackend( 1182 CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionAtPointsUpdate_Ref)); 1183 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAddAtPoints_Ref)); 1184 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Ref)); 1185 return CEED_ERROR_SUCCESS; 1186 } 1187 1188 //------------------------------------------------------------------------------ 1189