1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <stdbool.h> 11 #include <stddef.h> 12 #include <stdint.h> 13 14 #include "ceed-ref.h" 15 16 //------------------------------------------------------------------------------ 17 // Setup Input/Output Fields 18 //------------------------------------------------------------------------------ 19 static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op, bool is_input, CeedVector *e_vecs_full, CeedVector *e_vecs, 20 CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) { 21 Ceed ceed; 22 CeedSize e_size, q_size; 23 CeedInt num_comp, size, P; 24 CeedQFunctionField *qf_fields; 25 CeedOperatorField *op_fields; 26 27 { 28 Ceed ceed_parent; 29 30 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 31 CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 32 if (ceed_parent) ceed = ceed_parent; 33 } 34 if (is_input) { 35 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 36 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 37 } else { 38 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 39 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 40 } 41 42 // Loop over fields 43 for (CeedInt i = 0; i < num_fields; i++) { 44 CeedEvalMode eval_mode; 45 CeedElemRestriction elem_rstr; 46 CeedBasis basis; 47 48 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 49 if (eval_mode != CEED_EVAL_WEIGHT) { 50 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); 51 CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs_full[i + start_e])); 52 } 53 54 switch (eval_mode) { 55 case CEED_EVAL_NONE: 56 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 57 q_size = (CeedSize)Q * size; 58 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 59 break; 60 case CEED_EVAL_INTERP: 61 case CEED_EVAL_GRAD: 62 case CEED_EVAL_DIV: 63 case CEED_EVAL_CURL: 64 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 65 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 66 CeedCallBackend(CeedBasisGetNumNodes(basis, &P)); 67 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 68 e_size = (CeedSize)P * num_comp; 69 CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i])); 70 q_size = (CeedSize)Q * size; 71 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 72 break; 73 case CEED_EVAL_WEIGHT: // Only on input fields 74 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 75 q_size = (CeedSize)Q; 76 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 77 CeedCallBackend(CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); 78 break; 79 } 80 } 81 return CEED_ERROR_SUCCESS; 82 } 83 84 //------------------------------------------------------------------------------ 85 // Setup Operator 86 //------------------------------------------------------------------------------/* 87 static int CeedOperatorSetup_Ref(CeedOperator op) { 88 bool is_setup_done; 89 CeedInt Q, num_input_fields, num_output_fields; 90 CeedQFunctionField *qf_input_fields, *qf_output_fields; 91 CeedQFunction qf; 92 CeedOperatorField *op_input_fields, *op_output_fields; 93 CeedOperator_Ref *impl; 94 95 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 96 if (is_setup_done) return CEED_ERROR_SUCCESS; 97 98 CeedCallBackend(CeedOperatorGetData(op, &impl)); 99 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 100 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 101 CeedCallBackend(CeedQFunctionIsIdentity(qf, &impl->is_identity_qf)); 102 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 103 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 104 105 // Allocate 106 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full)); 107 108 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 109 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in)); 110 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out)); 111 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 112 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 113 114 impl->num_inputs = num_input_fields; 115 impl->num_outputs = num_output_fields; 116 117 // Set up infield and outfield e_vecs and q_vecs 118 // Infields 119 CeedCallBackend(CeedOperatorSetupFields_Ref(qf, op, true, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q)); 120 // Outfields 121 CeedCallBackend( 122 CeedOperatorSetupFields_Ref(qf, op, false, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields, num_output_fields, Q)); 123 124 // Identity QFunctions 125 if (impl->is_identity_qf) { 126 CeedEvalMode in_mode, out_mode; 127 CeedQFunctionField *in_fields, *out_fields; 128 129 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields)); 130 CeedCallBackend(CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode)); 131 CeedCallBackend(CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode)); 132 133 if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) { 134 impl->is_identity_rstr_op = true; 135 } else { 136 CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->q_vecs_out[0])); 137 } 138 } 139 140 CeedCallBackend(CeedOperatorSetSetupDone(op)); 141 return CEED_ERROR_SUCCESS; 142 } 143 144 //------------------------------------------------------------------------------ 145 // Setup Operator Inputs 146 //------------------------------------------------------------------------------ 147 static inline int CeedOperatorSetupInputs_Ref(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 148 CeedVector in_vec, const bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], 149 CeedOperator_Ref *impl, CeedRequest *request) { 150 for (CeedInt i = 0; i < num_input_fields; i++) { 151 uint64_t state; 152 CeedEvalMode eval_mode; 153 CeedVector vec; 154 CeedElemRestriction elem_rstr; 155 156 // Get input vector 157 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 158 if (vec == CEED_VECTOR_ACTIVE) { 159 if (skip_active) continue; 160 else vec = in_vec; 161 } 162 163 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 164 // Restrict and Evec 165 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 166 } else { 167 // Restrict 168 CeedCallBackend(CeedVectorGetState(vec, &state)); 169 // Skip restriction if input is unchanged 170 if (state != impl->input_states[i] || vec == in_vec) { 171 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 172 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, vec, impl->e_vecs_full[i], request)); 173 impl->input_states[i] = state; 174 } 175 // Get evec 176 CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, (const CeedScalar **)&e_data_full[i])); 177 } 178 } 179 return CEED_ERROR_SUCCESS; 180 } 181 182 //------------------------------------------------------------------------------ 183 // Input Basis Action 184 //------------------------------------------------------------------------------ 185 static inline int CeedOperatorInputBasis_Ref(CeedInt e, CeedInt Q, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 186 CeedInt num_input_fields, const bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], 187 CeedOperator_Ref *impl) { 188 for (CeedInt i = 0; i < num_input_fields; i++) { 189 CeedInt elem_size, size, num_comp; 190 CeedEvalMode eval_mode; 191 CeedElemRestriction elem_rstr; 192 CeedBasis basis; 193 194 // Skip active input 195 if (skip_active) { 196 CeedVector vec; 197 198 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 199 if (vec == CEED_VECTOR_ACTIVE) continue; 200 } 201 // Get elem_size, eval_mode, size 202 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 203 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 204 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 205 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 206 // Basis action 207 switch (eval_mode) { 208 case CEED_EVAL_NONE: 209 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][e * Q * size])); 210 break; 211 case CEED_EVAL_INTERP: 212 case CEED_EVAL_GRAD: 213 case CEED_EVAL_DIV: 214 case CEED_EVAL_CURL: 215 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 216 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 217 CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][e * elem_size * num_comp])); 218 CeedCallBackend(CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs_in[i], impl->q_vecs_in[i])); 219 break; 220 case CEED_EVAL_WEIGHT: 221 break; // No action 222 } 223 } 224 return CEED_ERROR_SUCCESS; 225 } 226 227 //------------------------------------------------------------------------------ 228 // Output Basis Action 229 //------------------------------------------------------------------------------ 230 static inline int CeedOperatorOutputBasis_Ref(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields, 231 CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op, 232 CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Ref *impl) { 233 for (CeedInt i = 0; i < num_output_fields; i++) { 234 CeedInt elem_size, num_comp; 235 CeedEvalMode eval_mode; 236 CeedElemRestriction elem_rstr; 237 CeedBasis basis; 238 239 // Get elem_size, eval_mode 240 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 241 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 242 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 243 // Basis action 244 switch (eval_mode) { 245 case CEED_EVAL_NONE: 246 break; // No action 247 case CEED_EVAL_INTERP: 248 case CEED_EVAL_GRAD: 249 case CEED_EVAL_DIV: 250 case CEED_EVAL_CURL: 251 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 252 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 253 CeedCallBackend( 254 CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i + num_input_fields][e * elem_size * num_comp])); 255 CeedCallBackend(CeedBasisApply(basis, 1, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); 256 break; 257 // LCOV_EXCL_START 258 case CEED_EVAL_WEIGHT: { 259 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 260 // LCOV_EXCL_STOP 261 } 262 } 263 } 264 return CEED_ERROR_SUCCESS; 265 } 266 267 //------------------------------------------------------------------------------ 268 // Restore Input Vectors 269 //------------------------------------------------------------------------------ 270 static inline int CeedOperatorRestoreInputs_Ref(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 271 const bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Ref *impl) { 272 for (CeedInt i = 0; i < num_input_fields; i++) { 273 CeedEvalMode eval_mode; 274 275 // Skip active inputs 276 if (skip_active) { 277 CeedVector vec; 278 279 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 280 if (vec == CEED_VECTOR_ACTIVE) continue; 281 } 282 // Restore input 283 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 284 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 285 } else { 286 CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_full[i], (const CeedScalar **)&e_data_full[i])); 287 } 288 } 289 return CEED_ERROR_SUCCESS; 290 } 291 292 //------------------------------------------------------------------------------ 293 // Operator Apply 294 //------------------------------------------------------------------------------ 295 static int CeedOperatorApplyAdd_Ref(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 296 CeedInt Q, num_elem, num_input_fields, num_output_fields, size; 297 CeedEvalMode eval_mode; 298 CeedScalar *e_data_full[2 * CEED_FIELD_MAX] = {NULL}; 299 CeedQFunctionField *qf_input_fields, *qf_output_fields; 300 CeedQFunction qf; 301 CeedOperatorField *op_input_fields, *op_output_fields; 302 CeedOperator_Ref *impl; 303 304 CeedCallBackend(CeedOperatorGetData(op, &impl)); 305 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 306 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 307 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 308 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 309 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 310 311 // Setup 312 CeedCallBackend(CeedOperatorSetup_Ref(op)); 313 314 // Restriction only operator 315 if (impl->is_identity_rstr_op) { 316 CeedElemRestriction elem_rstr; 317 318 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[0], &elem_rstr)); 319 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_full[0], request)); 320 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[0], &elem_rstr)); 321 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs_full[0], out_vec, request)); 322 return CEED_ERROR_SUCCESS; 323 } 324 325 // Input Evecs and Restriction 326 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data_full, impl, request)); 327 328 // Output Evecs 329 for (CeedInt i = 0; i < num_output_fields; i++) { 330 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields])); 331 } 332 333 // Loop through elements 334 for (CeedInt e = 0; e < num_elem; e++) { 335 // Output pointers 336 for (CeedInt i = 0; i < num_output_fields; i++) { 337 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 338 if (eval_mode == CEED_EVAL_NONE) { 339 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 340 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i + num_input_fields][e * Q * size])); 341 } 342 } 343 344 // Input basis apply 345 CeedCallBackend(CeedOperatorInputBasis_Ref(e, Q, qf_input_fields, op_input_fields, num_input_fields, false, e_data_full, impl)); 346 347 // Q function 348 if (!impl->is_identity_qf) { 349 CeedCallBackend(CeedQFunctionApply(qf, Q, impl->q_vecs_in, impl->q_vecs_out)); 350 } 351 352 // Output basis apply 353 CeedCallBackend( 354 CeedOperatorOutputBasis_Ref(e, Q, qf_output_fields, op_output_fields, num_input_fields, num_output_fields, op, e_data_full, impl)); 355 } 356 357 // Output restriction 358 for (CeedInt i = 0; i < num_output_fields; i++) { 359 CeedVector vec; 360 CeedElemRestriction elem_rstr; 361 362 // Restore Evec 363 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_full[i + impl->num_inputs], &e_data_full[i + num_input_fields])); 364 // Get output vector 365 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 366 // Active 367 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 368 // Restrict 369 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 370 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs_full[i + impl->num_inputs], vec, request)); 371 } 372 373 // Restore input arrays 374 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, false, e_data_full, impl)); 375 return CEED_ERROR_SUCCESS; 376 } 377 378 //------------------------------------------------------------------------------ 379 // Core code for assembling linear QFunction 380 //------------------------------------------------------------------------------ 381 static inline int CeedOperatorLinearAssembleQFunctionCore_Ref(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 382 CeedRequest *request) { 383 Ceed ceed, ceed_parent; 384 CeedSize q_size; 385 CeedInt num_active_in, num_active_out, Q, num_elem, num_input_fields, num_output_fields, size; 386 CeedScalar *assembled_array, *e_data_full[2 * CEED_FIELD_MAX] = {NULL}; 387 CeedVector *active_in; 388 CeedQFunctionField *qf_input_fields, *qf_output_fields; 389 CeedQFunction qf; 390 CeedOperatorField *op_input_fields, *op_output_fields; 391 CeedOperator_Ref *impl; 392 393 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 394 CeedCallBackend(CeedOperatorGetFallbackParentCeed(op, &ceed_parent)); 395 CeedCallBackend(CeedOperatorGetData(op, &impl)); 396 active_in = impl->qf_active_in; 397 num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 398 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 399 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 400 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 401 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 402 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 403 404 // Setup 405 CeedCallBackend(CeedOperatorSetup_Ref(op)); 406 407 // Check for restriction only operator 408 CeedCheck(!impl->is_identity_rstr_op, ceed, CEED_ERROR_BACKEND, "Assembling restriction only operators is not supported"); 409 410 // Input Evecs and Restriction 411 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data_full, impl, request)); 412 413 // Count number of active input fields 414 if (!num_active_in) { 415 for (CeedInt i = 0; i < num_input_fields; i++) { 416 CeedScalar *q_vec_array; 417 CeedVector vec; 418 419 // Get input vector 420 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 421 // Check if active input 422 if (vec == CEED_VECTOR_ACTIVE) { 423 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 424 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 425 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &q_vec_array)); 426 CeedCallBackend(CeedRealloc(num_active_in + size, &active_in)); 427 for (CeedInt field = 0; field < size; field++) { 428 q_size = (CeedSize)Q; 429 CeedCallBackend(CeedVectorCreate(ceed_parent, q_size, &active_in[num_active_in + field])); 430 CeedCallBackend(CeedVectorSetArray(active_in[num_active_in + field], CEED_MEM_HOST, CEED_USE_POINTER, &q_vec_array[field * Q])); 431 } 432 num_active_in += size; 433 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array)); 434 } 435 } 436 impl->num_active_in = num_active_in; 437 impl->qf_active_in = active_in; 438 } 439 440 // Count number of active output fields 441 if (!num_active_out) { 442 for (CeedInt i = 0; i < num_output_fields; i++) { 443 CeedVector vec; 444 445 // Get output vector 446 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 447 // Check if active output 448 if (vec == CEED_VECTOR_ACTIVE) { 449 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 450 num_active_out += size; 451 } 452 } 453 impl->num_active_out = num_active_out; 454 } 455 456 // Check sizes 457 CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 458 459 // Build objects if needed 460 if (build_objects) { 461 const CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out; 462 CeedInt strides[3] = {1, Q, num_active_in * num_active_out * Q}; /* *NOPAD* */ 463 464 // Create output restriction 465 CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out, 466 num_active_in * num_active_out * num_elem * Q, strides, rstr)); 467 // Create assembled vector 468 CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled)); 469 } 470 // Clear output vector 471 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 472 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_HOST, &assembled_array)); 473 474 // Loop through elements 475 for (CeedInt e = 0; e < num_elem; e++) { 476 // Input basis apply 477 CeedCallBackend(CeedOperatorInputBasis_Ref(e, Q, qf_input_fields, op_input_fields, num_input_fields, true, e_data_full, impl)); 478 479 // Assemble QFunction 480 for (CeedInt in = 0; in < num_active_in; in++) { 481 // Set Inputs 482 CeedCallBackend(CeedVectorSetValue(active_in[in], 1.0)); 483 if (num_active_in > 1) { 484 CeedCallBackend(CeedVectorSetValue(active_in[(in + num_active_in - 1) % num_active_in], 0.0)); 485 } 486 if (!impl->is_identity_qf) { 487 // Set Outputs 488 for (CeedInt out = 0; out < num_output_fields; out++) { 489 CeedVector vec; 490 491 // Get output vector 492 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 493 // Check if active output 494 if (vec == CEED_VECTOR_ACTIVE) { 495 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_USE_POINTER, assembled_array)); 496 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size)); 497 assembled_array += size * Q; // Advance the pointer by the size of the output 498 } 499 } 500 // Apply QFunction 501 CeedCallBackend(CeedQFunctionApply(qf, Q, impl->q_vecs_in, impl->q_vecs_out)); 502 } else { 503 const CeedScalar *q_vec_array; 504 505 // Copy Identity Outputs 506 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[0], &size)); 507 CeedCallBackend(CeedVectorGetArrayRead(impl->q_vecs_out[0], CEED_MEM_HOST, &q_vec_array)); 508 for (CeedInt i = 0; i < size * Q; i++) assembled_array[i] = q_vec_array[i]; 509 CeedCallBackend(CeedVectorRestoreArrayRead(impl->q_vecs_out[0], &q_vec_array)); 510 assembled_array += size * Q; 511 } 512 } 513 } 514 515 // Un-set output Qvecs to prevent accidental overwrite of Assembled 516 if (!impl->is_identity_qf) { 517 for (CeedInt out = 0; out < num_output_fields; out++) { 518 CeedVector vec; 519 520 // Get output vector 521 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 522 // Check if active output 523 if (vec == CEED_VECTOR_ACTIVE && num_elem > 0) { 524 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_HOST, NULL)); 525 } 526 } 527 } 528 529 // Restore input arrays 530 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, true, e_data_full, impl)); 531 532 // Restore output 533 CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 534 return CEED_ERROR_SUCCESS; 535 } 536 537 //------------------------------------------------------------------------------ 538 // Assemble Linear QFunction 539 //------------------------------------------------------------------------------ 540 static int CeedOperatorLinearAssembleQFunction_Ref(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 541 return CeedOperatorLinearAssembleQFunctionCore_Ref(op, true, assembled, rstr, request); 542 } 543 544 //------------------------------------------------------------------------------ 545 // Update Assembled Linear QFunction 546 //------------------------------------------------------------------------------ 547 static int CeedOperatorLinearAssembleQFunctionUpdate_Ref(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 548 return CeedOperatorLinearAssembleQFunctionCore_Ref(op, false, &assembled, &rstr, request); 549 } 550 551 //------------------------------------------------------------------------------ 552 // Setup Input/Output Fields 553 //------------------------------------------------------------------------------ 554 static int CeedOperatorSetupFieldsAtPoints_Ref(CeedQFunction qf, CeedOperator op, bool is_input, CeedVector *e_vecs_full, CeedVector *e_vecs, 555 CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) { 556 Ceed ceed; 557 CeedSize e_size, q_size; 558 CeedInt e_size_padding = 0, max_num_points, num_comp, size, P; 559 CeedQFunctionField *qf_fields; 560 CeedOperatorField *op_fields; 561 562 { 563 Ceed ceed_parent; 564 565 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 566 CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 567 if (ceed_parent) ceed = ceed_parent; 568 } 569 if (is_input) { 570 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 571 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 572 } else { 573 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 574 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 575 } 576 577 // Get max number of points 578 { 579 CeedInt dim; 580 CeedElemRestriction rstr_points = NULL; 581 CeedOperator_Ref *impl; 582 583 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 584 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(rstr_points, &max_num_points)); 585 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_points, &dim)); 586 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 587 CeedCallBackend(CeedOperatorGetData(op, &impl)); 588 if (is_input) { 589 CeedCallBackend(CeedVectorCreate(ceed, dim * max_num_points, &impl->point_coords_elem)); 590 CeedCallBackend(CeedVectorSetValue(impl->point_coords_elem, 0.0)); 591 } 592 } 593 594 // Loop over fields 595 for (CeedInt i = 0; i < num_fields; i++) { 596 CeedEvalMode eval_mode; 597 CeedBasis basis; 598 599 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 600 if (eval_mode != CEED_EVAL_WEIGHT) { 601 CeedElemRestriction elem_rstr; 602 CeedSize e_size; 603 bool is_at_points; 604 605 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); 606 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp)); 607 CeedCallBackend(CeedElemRestrictionIsPoints(elem_rstr, &is_at_points)); 608 if (is_at_points) { 609 CeedCallBackend(CeedElemRestrictionGetEVectorSize(elem_rstr, &e_size)); 610 if (e_size_padding == 0) { 611 CeedInt num_points, num_elem; 612 613 CeedCallBackend(CeedElemRestrictionGetNumElements(elem_rstr, &num_elem)); 614 CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(elem_rstr, num_elem - 1, &num_points)); 615 e_size_padding = (max_num_points - num_points) * num_comp; 616 } 617 CeedCallBackend(CeedVectorCreate(ceed, e_size + e_size_padding, &e_vecs_full[i + start_e])); 618 CeedCallBackend(CeedVectorSetValue(e_vecs_full[i + start_e], 0.0)); 619 } else { 620 CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs_full[i + start_e])); 621 } 622 } 623 624 switch (eval_mode) { 625 case CEED_EVAL_NONE: { 626 CeedVector vec; 627 628 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 629 e_size = (CeedSize)max_num_points * size; 630 CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i])); 631 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 632 if (vec == CEED_VECTOR_ACTIVE || !is_input) { 633 CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &q_vecs[i])); 634 } else { 635 q_size = (CeedSize)max_num_points * size; 636 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 637 } 638 break; 639 } 640 case CEED_EVAL_INTERP: 641 case CEED_EVAL_GRAD: 642 case CEED_EVAL_DIV: 643 case CEED_EVAL_CURL: 644 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 645 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 646 CeedCallBackend(CeedBasisGetNumNodes(basis, &P)); 647 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 648 e_size = (CeedSize)P * num_comp; 649 CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i])); 650 q_size = (CeedSize)max_num_points * size; 651 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 652 break; 653 case CEED_EVAL_WEIGHT: // Only on input fields 654 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 655 q_size = (CeedSize)max_num_points; 656 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 657 CeedCallBackend( 658 CeedBasisApplyAtPoints(basis, max_num_points, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, CEED_VECTOR_NONE, q_vecs[i])); 659 break; 660 } 661 // Initialize full arrays for E-vectors and Q-vectors 662 if (e_vecs[i]) CeedCallBackend(CeedVectorSetValue(e_vecs[i], 0.0)); 663 if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorSetValue(q_vecs[i], 0.0)); 664 } 665 return CEED_ERROR_SUCCESS; 666 } 667 668 //------------------------------------------------------------------------------ 669 // Setup Operator 670 //------------------------------------------------------------------------------ 671 static int CeedOperatorSetupAtPoints_Ref(CeedOperator op) { 672 bool is_setup_done; 673 CeedInt Q, num_input_fields, num_output_fields; 674 CeedQFunctionField *qf_input_fields, *qf_output_fields; 675 CeedQFunction qf; 676 CeedOperatorField *op_input_fields, *op_output_fields; 677 CeedOperator_Ref *impl; 678 679 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 680 if (is_setup_done) return CEED_ERROR_SUCCESS; 681 682 CeedCallBackend(CeedOperatorGetData(op, &impl)); 683 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 684 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 685 CeedCallBackend(CeedQFunctionIsIdentity(qf, &impl->is_identity_qf)); 686 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 687 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 688 689 // Allocate 690 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full)); 691 692 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 693 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in)); 694 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out)); 695 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 696 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 697 698 impl->num_inputs = num_input_fields; 699 impl->num_outputs = num_output_fields; 700 701 // Set up infield and outfield pointer arrays 702 // Infields 703 CeedCallBackend(CeedOperatorSetupFieldsAtPoints_Ref(qf, op, true, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q)); 704 // Outfields 705 CeedCallBackend(CeedOperatorSetupFieldsAtPoints_Ref(qf, op, false, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields, 706 num_output_fields, Q)); 707 708 // Identity QFunctions 709 if (impl->is_identity_qf) { 710 CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->q_vecs_out[0])); 711 CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->e_vecs_out[0])); 712 } 713 714 CeedCallBackend(CeedOperatorSetSetupDone(op)); 715 return CEED_ERROR_SUCCESS; 716 } 717 718 //------------------------------------------------------------------------------ 719 // Input Basis Action 720 //------------------------------------------------------------------------------ 721 static inline int CeedOperatorInputBasisAtPoints_Ref(CeedInt e, CeedInt num_points_offset, CeedInt num_points, CeedQFunctionField *qf_input_fields, 722 CeedOperatorField *op_input_fields, CeedInt num_input_fields, CeedVector in_vec, 723 CeedVector point_coords_elem, bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX], 724 CeedOperator_Ref *impl, CeedRequest *request) { 725 for (CeedInt i = 0; i < num_input_fields; i++) { 726 bool is_active_input = false; 727 CeedInt elem_size, size, num_comp; 728 CeedRestrictionType rstr_type; 729 CeedEvalMode eval_mode; 730 CeedVector vec; 731 CeedElemRestriction elem_rstr; 732 CeedBasis basis; 733 734 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 735 // Skip active input 736 is_active_input = vec == CEED_VECTOR_ACTIVE; 737 if (skip_active && is_active_input) continue; 738 739 // Get elem_size, eval_mode, size 740 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 741 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 742 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 743 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 744 // Restrict block active input 745 if (is_active_input) { 746 if (rstr_type == CEED_RESTRICTION_POINTS) { 747 CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(elem_rstr, e, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[i], request)); 748 } else { 749 CeedCallBackend(CeedElemRestrictionApplyBlock(elem_rstr, e, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[i], request)); 750 } 751 } 752 // Basis action 753 switch (eval_mode) { 754 case CEED_EVAL_NONE: 755 if (!is_active_input) { 756 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data[i][num_points_offset * size])); 757 } 758 break; 759 // Note - these basis eval modes require FEM fields 760 case CEED_EVAL_INTERP: 761 case CEED_EVAL_GRAD: 762 case CEED_EVAL_DIV: 763 case CEED_EVAL_CURL: 764 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 765 if (!is_active_input) { 766 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 767 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 768 CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data[i][e * elem_size * num_comp])); 769 } 770 CeedCallBackend( 771 CeedBasisApplyAtPoints(basis, num_points, CEED_NOTRANSPOSE, eval_mode, point_coords_elem, impl->e_vecs_in[i], impl->q_vecs_in[i])); 772 break; 773 case CEED_EVAL_WEIGHT: 774 break; // No action 775 } 776 } 777 return CEED_ERROR_SUCCESS; 778 } 779 780 //------------------------------------------------------------------------------ 781 // Output Basis Action 782 //------------------------------------------------------------------------------ 783 static inline int CeedOperatorOutputBasisAtPoints_Ref(CeedInt e, CeedInt num_points_offset, CeedInt num_points, CeedQFunctionField *qf_output_fields, 784 CeedOperatorField *op_output_fields, CeedInt num_input_fields, CeedInt num_output_fields, 785 CeedOperator op, CeedVector out_vec, CeedVector point_coords_elem, CeedOperator_Ref *impl, 786 CeedRequest *request) { 787 for (CeedInt i = 0; i < num_output_fields; i++) { 788 CeedRestrictionType rstr_type; 789 CeedEvalMode eval_mode; 790 CeedVector vec; 791 CeedElemRestriction elem_rstr; 792 CeedBasis basis; 793 794 // Get elem_size, eval_mode, size 795 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 796 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 797 // Basis action 798 switch (eval_mode) { 799 case CEED_EVAL_NONE: 800 break; // No action 801 case CEED_EVAL_INTERP: 802 case CEED_EVAL_GRAD: 803 case CEED_EVAL_DIV: 804 case CEED_EVAL_CURL: 805 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 806 CeedCallBackend( 807 CeedBasisApplyAtPoints(basis, num_points, CEED_TRANSPOSE, eval_mode, point_coords_elem, impl->q_vecs_out[i], impl->e_vecs_out[i])); 808 break; 809 // LCOV_EXCL_START 810 case CEED_EVAL_WEIGHT: { 811 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 812 // LCOV_EXCL_STOP 813 } 814 } 815 // Restrict output block 816 // Get output vector 817 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 818 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 819 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 820 // Restrict 821 if (rstr_type == CEED_RESTRICTION_POINTS) { 822 CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(elem_rstr, e, CEED_TRANSPOSE, impl->e_vecs_out[i], vec, request)); 823 } else { 824 CeedCallBackend(CeedElemRestrictionApplyBlock(elem_rstr, e, CEED_TRANSPOSE, impl->e_vecs_out[i], vec, request)); 825 } 826 } 827 return CEED_ERROR_SUCCESS; 828 } 829 830 //------------------------------------------------------------------------------ 831 // Operator Apply 832 //------------------------------------------------------------------------------ 833 static int CeedOperatorApplyAddAtPoints_Ref(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 834 CeedInt num_points_offset = 0, num_input_fields, num_output_fields, num_elem; 835 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {0}; 836 CeedVector point_coords = NULL; 837 CeedElemRestriction rstr_points = NULL; 838 CeedQFunctionField *qf_input_fields, *qf_output_fields; 839 CeedQFunction qf; 840 CeedOperatorField *op_input_fields, *op_output_fields; 841 CeedOperator_Ref *impl; 842 843 CeedCallBackend(CeedOperatorGetData(op, &impl)); 844 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 845 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 846 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 847 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 848 849 // Setup 850 CeedCallBackend(CeedOperatorSetupAtPoints_Ref(op)); 851 852 // Point coordinates 853 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 854 855 // Input Evecs and Restriction 856 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 857 858 // Loop through elements 859 for (CeedInt e = 0; e < num_elem; e++) { 860 CeedInt num_points; 861 862 // Setup points for element 863 CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(rstr_points, e, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 864 CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points)); 865 866 // Input basis apply 867 CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields, in_vec, 868 impl->point_coords_elem, false, e_data, impl, request)); 869 870 // Q function 871 if (!impl->is_identity_qf) { 872 CeedCallBackend(CeedQFunctionApply(qf, num_points, impl->q_vecs_in, impl->q_vecs_out)); 873 } 874 875 // Output basis apply and restriction 876 CeedCallBackend(CeedOperatorOutputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_output_fields, op_output_fields, num_input_fields, 877 num_output_fields, op, out_vec, impl->point_coords_elem, impl, request)); 878 879 num_points_offset += num_points; 880 } 881 882 // Restore input arrays 883 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 884 885 // Cleanup point coordinates 886 CeedCallBackend(CeedVectorDestroy(&point_coords)); 887 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 888 return CEED_ERROR_SUCCESS; 889 } 890 891 //------------------------------------------------------------------------------ 892 // Core code for assembling linear QFunction 893 //------------------------------------------------------------------------------ 894 static inline int CeedOperatorLinearAssembleQFunctionAtPointsCore_Ref(CeedOperator op, bool build_objects, CeedVector *assembled, 895 CeedElemRestriction *rstr, CeedRequest *request) { 896 Ceed ceed; 897 CeedSize q_size; 898 CeedInt num_active_in, num_active_out, max_num_points, num_elem, num_input_fields, num_output_fields, num_points_offset = 0; 899 CeedScalar *assembled_array, *e_data_full[2 * CEED_FIELD_MAX] = {NULL}; 900 CeedVector *active_in, point_coords = NULL; 901 CeedQFunctionField *qf_input_fields, *qf_output_fields; 902 CeedQFunction qf; 903 CeedOperatorField *op_input_fields, *op_output_fields; 904 CeedOperator_Ref *impl; 905 CeedElemRestriction rstr_points = NULL; 906 907 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 908 CeedCallBackend(CeedOperatorGetData(op, &impl)); 909 active_in = impl->qf_active_in; 910 num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 911 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 912 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 913 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 914 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 915 916 // Setup 917 CeedCallBackend(CeedOperatorSetupAtPoints_Ref(op)); 918 919 // Check for restriction only operator 920 CeedCheck(!impl->is_identity_rstr_op, ceed, CEED_ERROR_BACKEND, "Assembling restriction only operators is not supported"); 921 922 // Point coordinates 923 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 924 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(rstr_points, &max_num_points)); 925 926 // Input Evecs and Restriction 927 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data_full, impl, request)); 928 929 // Count number of active input fields 930 if (!num_active_in) { 931 for (CeedInt i = 0; i < num_input_fields; i++) { 932 CeedScalar *q_vec_array; 933 CeedInt field_size; 934 CeedVector vec; 935 936 // Get input vector 937 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 938 // Check if active input 939 if (vec == CEED_VECTOR_ACTIVE) { 940 // Check that all active inputs are nodal fields 941 { 942 CeedElemRestriction elem_rstr; 943 bool is_at_points = false; 944 945 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 946 CeedCallBackend(CeedElemRestrictionIsPoints(elem_rstr, &is_at_points)); 947 CeedCheck(!is_at_points, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction with active input at points"); 948 } 949 // Get size of active input 950 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size)); 951 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 952 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &q_vec_array)); 953 CeedCallBackend(CeedRealloc(num_active_in + field_size, &active_in)); 954 for (CeedInt field = 0; field < field_size; field++) { 955 q_size = (CeedSize)max_num_points; 956 CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_in[num_active_in + field])); 957 CeedCallBackend(CeedVectorSetArray(active_in[num_active_in + field], CEED_MEM_HOST, CEED_USE_POINTER, &q_vec_array[field * q_size])); 958 } 959 num_active_in += field_size; 960 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array)); 961 } 962 } 963 impl->num_active_in = num_active_in; 964 impl->qf_active_in = active_in; 965 } 966 967 // Count number of active output fields 968 if (!num_active_out) { 969 for (CeedInt i = 0; i < num_output_fields; i++) { 970 CeedVector vec; 971 CeedInt field_size; 972 973 // Get output vector 974 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 975 // Check if active output 976 if (vec == CEED_VECTOR_ACTIVE) { 977 // Check that all active inputs are nodal fields 978 { 979 CeedElemRestriction elem_rstr; 980 bool is_at_points = false; 981 982 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 983 CeedCallBackend(CeedElemRestrictionIsPoints(elem_rstr, &is_at_points)); 984 CeedCheck(!is_at_points, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction with active input at points"); 985 } 986 // Get size of active output 987 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &field_size)); 988 num_active_out += field_size; 989 } 990 } 991 impl->num_active_out = num_active_out; 992 } 993 994 // Check sizes 995 CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 996 997 // Build objects if needed 998 if (build_objects) { 999 CeedInt num_points_total; 1000 const CeedInt *offsets; 1001 1002 CeedCallBackend(CeedElemRestrictionGetNumPoints(rstr_points, &num_points_total)); 1003 1004 // Create output restriction (at points) 1005 CeedCallBackend(CeedElemRestrictionGetOffsets(rstr_points, CEED_MEM_HOST, &offsets)); 1006 CeedCallBackend(CeedElemRestrictionCreateAtPoints(ceed, num_elem, num_points_total, num_active_in * num_active_out, 1007 num_active_in * num_active_out * num_points_total, CEED_MEM_HOST, CEED_COPY_VALUES, offsets, 1008 rstr)); 1009 CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr_points, &offsets)); 1010 1011 // Create assembled vector 1012 CeedCallBackend(CeedElemRestrictionCreateVector(*rstr, assembled, NULL)); 1013 } 1014 // Clear output vector 1015 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 1016 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_HOST, &assembled_array)); 1017 1018 // Loop through elements 1019 for (CeedInt e = 0; e < num_elem; e++) { 1020 CeedInt num_points; 1021 1022 // Setup points for element 1023 CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(rstr_points, e, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 1024 CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points)); 1025 1026 // Input basis apply 1027 CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields, NULL, 1028 impl->point_coords_elem, true, e_data_full, impl, request)); 1029 1030 // Assemble QFunction 1031 for (CeedInt in = 0; in < num_active_in; in++) { 1032 // Set Inputs 1033 CeedCallBackend(CeedVectorSetValue(active_in[in], 1.0)); 1034 if (num_active_in > 1) { 1035 CeedCallBackend(CeedVectorSetValue(active_in[(in + num_active_in - 1) % num_active_in], 0.0)); 1036 } 1037 if (!impl->is_identity_qf) { 1038 // Set Outputs 1039 for (CeedInt out = 0; out < num_output_fields; out++) { 1040 CeedVector vec; 1041 CeedInt field_size; 1042 1043 // Get output vector 1044 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 1045 // Check if active output 1046 if (vec == CEED_VECTOR_ACTIVE) { 1047 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_USE_POINTER, assembled_array)); 1048 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &field_size)); 1049 assembled_array += field_size * num_points; // Advance the pointer by the size of the output 1050 } 1051 } 1052 // Apply QFunction 1053 CeedCallBackend(CeedQFunctionApply(qf, num_points, impl->q_vecs_in, impl->q_vecs_out)); 1054 } else { 1055 const CeedScalar *q_vec_array; 1056 CeedInt field_size; 1057 1058 // Copy Identity Outputs 1059 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[0], &field_size)); 1060 CeedCallBackend(CeedVectorGetArrayRead(impl->q_vecs_out[0], CEED_MEM_HOST, &q_vec_array)); 1061 for (CeedInt i = 0; i < field_size * num_points; i++) assembled_array[i] = q_vec_array[i]; 1062 CeedCallBackend(CeedVectorRestoreArrayRead(impl->q_vecs_out[0], &q_vec_array)); 1063 assembled_array += field_size * num_points; 1064 } 1065 } 1066 num_points_offset += num_points; 1067 } 1068 1069 // Un-set output Qvecs to prevent accidental overwrite of Assembled 1070 if (!impl->is_identity_qf) { 1071 for (CeedInt out = 0; out < num_output_fields; out++) { 1072 CeedVector vec; 1073 1074 // Get output vector 1075 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 1076 // Check if active output 1077 if (vec == CEED_VECTOR_ACTIVE && num_elem > 0) { 1078 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_HOST, NULL)); 1079 } 1080 } 1081 } 1082 1083 // Restore input arrays 1084 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, true, e_data_full, impl)); 1085 1086 // Restore output 1087 CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 1088 1089 // Cleanup 1090 CeedCallBackend(CeedVectorDestroy(&point_coords)); 1091 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 1092 return CEED_ERROR_SUCCESS; 1093 } 1094 1095 //------------------------------------------------------------------------------ 1096 // Assemble Linear QFunction 1097 //------------------------------------------------------------------------------ 1098 static int CeedOperatorLinearAssembleQFunctionAtPoints_Ref(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 1099 return CeedOperatorLinearAssembleQFunctionAtPointsCore_Ref(op, true, assembled, rstr, request); 1100 } 1101 1102 //------------------------------------------------------------------------------ 1103 // Update Assembled Linear QFunction 1104 //------------------------------------------------------------------------------ 1105 static int CeedOperatorLinearAssembleQFunctionAtPointsUpdate_Ref(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, 1106 CeedRequest *request) { 1107 return CeedOperatorLinearAssembleQFunctionAtPointsCore_Ref(op, false, &assembled, &rstr, request); 1108 } 1109 1110 //------------------------------------------------------------------------------ 1111 // Assemble Operator 1112 //------------------------------------------------------------------------------ 1113 1114 //------------------------------------------------------------------------------ 1115 // Operator Destroy 1116 //------------------------------------------------------------------------------ 1117 static int CeedOperatorDestroy_Ref(CeedOperator op) { 1118 CeedOperator_Ref *impl; 1119 1120 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1121 for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { 1122 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_full[i])); 1123 } 1124 CeedCallBackend(CeedFree(&impl->e_vecs_full)); 1125 CeedCallBackend(CeedFree(&impl->input_states)); 1126 1127 for (CeedInt i = 0; i < impl->num_inputs; i++) { 1128 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_in[i])); 1129 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i])); 1130 } 1131 CeedCallBackend(CeedFree(&impl->e_vecs_in)); 1132 CeedCallBackend(CeedFree(&impl->q_vecs_in)); 1133 1134 for (CeedInt i = 0; i < impl->num_outputs; i++) { 1135 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_out[i])); 1136 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 1137 } 1138 CeedCallBackend(CeedFree(&impl->e_vecs_out)); 1139 CeedCallBackend(CeedFree(&impl->q_vecs_out)); 1140 CeedCallBackend(CeedVectorDestroy(&impl->point_coords_elem)); 1141 1142 // QFunction assembly 1143 for (CeedInt i = 0; i < impl->num_active_in; i++) { 1144 CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i])); 1145 } 1146 CeedCallBackend(CeedFree(&impl->qf_active_in)); 1147 1148 CeedCallBackend(CeedFree(&impl)); 1149 return CEED_ERROR_SUCCESS; 1150 } 1151 1152 //------------------------------------------------------------------------------ 1153 // Operator Create 1154 //------------------------------------------------------------------------------ 1155 int CeedOperatorCreate_Ref(CeedOperator op) { 1156 Ceed ceed; 1157 CeedOperator_Ref *impl; 1158 1159 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1160 CeedCallBackend(CeedCalloc(1, &impl)); 1161 CeedCallBackend(CeedOperatorSetData(op, impl)); 1162 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Ref)); 1163 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Ref)); 1164 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Ref)); 1165 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Ref)); 1166 return CEED_ERROR_SUCCESS; 1167 } 1168 1169 //------------------------------------------------------------------------------ 1170 // Operator Create At Points 1171 //------------------------------------------------------------------------------ 1172 int CeedOperatorCreateAtPoints_Ref(CeedOperator op) { 1173 Ceed ceed; 1174 CeedOperator_Ref *impl; 1175 1176 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1177 CeedCallBackend(CeedCalloc(1, &impl)); 1178 CeedCallBackend(CeedOperatorSetData(op, impl)); 1179 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunctionAtPoints_Ref)); 1180 CeedCallBackend( 1181 CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionAtPointsUpdate_Ref)); 1182 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAddAtPoints_Ref)); 1183 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Ref)); 1184 return CEED_ERROR_SUCCESS; 1185 } 1186 1187 //------------------------------------------------------------------------------ 1188