1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <stdbool.h> 11 #include <stddef.h> 12 #include <stdint.h> 13 14 #include "ceed-blocked.h" 15 16 //------------------------------------------------------------------------------ 17 // Setup Input/Output Fields 18 //------------------------------------------------------------------------------ 19 static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bool is_input, CeedElemRestriction *block_restr, 20 CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, 21 CeedInt Q) { 22 Ceed ceed; 23 CeedSize e_size, q_size; 24 CeedInt num_comp, size, P; 25 const CeedInt block_size = 8; 26 CeedQFunctionField *qf_fields; 27 CeedOperatorField *op_fields; 28 29 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 30 31 if (is_input) { 32 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 33 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 34 } else { 35 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 36 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 37 } 38 39 // Loop over fields 40 for (CeedInt i = 0; i < num_fields; i++) { 41 CeedEvalMode eval_mode; 42 CeedBasis basis; 43 44 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 45 46 if (eval_mode != CEED_EVAL_WEIGHT) { 47 Ceed ceed_rstr; 48 CeedSize l_size; 49 CeedInt num_elem, elem_size, comp_stride; 50 CeedRestrictionType rstr_type; 51 CeedElemRestriction rstr; 52 53 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr)); 54 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed_rstr)); 55 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 56 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 57 CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size)); 58 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 59 CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride)); 60 61 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 62 switch (rstr_type) { 63 case CEED_RESTRICTION_STANDARD: { 64 const CeedInt *offsets = NULL; 65 66 CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets)); 67 CeedCallBackend(CeedElemRestrictionCreateBlocked(ceed_rstr, num_elem, elem_size, block_size, num_comp, comp_stride, l_size, CEED_MEM_HOST, 68 CEED_COPY_VALUES, offsets, &block_restr[i + start_e])); 69 CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets)); 70 } break; 71 case CEED_RESTRICTION_ORIENTED: { 72 const bool *orients = NULL; 73 const CeedInt *offsets = NULL; 74 75 CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets)); 76 CeedCallBackend(CeedElemRestrictionGetOrientations(rstr, CEED_MEM_HOST, &orients)); 77 CeedCallBackend(CeedElemRestrictionCreateBlockedOriented(ceed_rstr, num_elem, elem_size, block_size, num_comp, comp_stride, l_size, 78 CEED_MEM_HOST, CEED_COPY_VALUES, offsets, orients, &block_restr[i + start_e])); 79 CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets)); 80 CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr, &orients)); 81 } break; 82 case CEED_RESTRICTION_CURL_ORIENTED: { 83 const CeedInt8 *curl_orients = NULL; 84 const CeedInt *offsets = NULL; 85 86 CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets)); 87 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr, CEED_MEM_HOST, &curl_orients)); 88 CeedCallBackend(CeedElemRestrictionCreateBlockedCurlOriented(ceed_rstr, num_elem, elem_size, block_size, num_comp, comp_stride, l_size, 89 CEED_MEM_HOST, CEED_COPY_VALUES, offsets, curl_orients, 90 &block_restr[i + start_e])); 91 CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets)); 92 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr, &curl_orients)); 93 } break; 94 case CEED_RESTRICTION_STRIDED: { 95 CeedInt strides[3]; 96 97 CeedCallBackend(CeedElemRestrictionGetStrides(rstr, &strides)); 98 CeedCallBackend(CeedElemRestrictionCreateBlockedStrided(ceed_rstr, num_elem, elem_size, block_size, num_comp, l_size, strides, 99 &block_restr[i + start_e])); 100 } break; 101 } 102 CeedCallBackend(CeedElemRestrictionCreateVector(block_restr[i + start_e], NULL, &e_vecs_full[i + start_e])); 103 } 104 105 switch (eval_mode) { 106 case CEED_EVAL_NONE: 107 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 108 q_size = (CeedSize)Q * size * block_size; 109 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 110 break; 111 case CEED_EVAL_INTERP: 112 case CEED_EVAL_GRAD: 113 case CEED_EVAL_DIV: 114 case CEED_EVAL_CURL: 115 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 116 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 117 CeedCallBackend(CeedBasisGetNumNodes(basis, &P)); 118 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 119 e_size = (CeedSize)P * num_comp * block_size; 120 CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i])); 121 q_size = (CeedSize)Q * size * block_size; 122 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 123 break; 124 case CEED_EVAL_WEIGHT: // Only on input fields 125 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 126 q_size = (CeedSize)Q * block_size; 127 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 128 CeedCallBackend(CeedBasisApply(basis, block_size, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); 129 break; 130 } 131 } 132 return CEED_ERROR_SUCCESS; 133 } 134 135 //------------------------------------------------------------------------------ 136 // Setup Operator 137 //------------------------------------------------------------------------------ 138 static int CeedOperatorSetup_Blocked(CeedOperator op) { 139 bool is_setup_done; 140 Ceed ceed; 141 CeedInt Q, num_input_fields, num_output_fields; 142 CeedQFunctionField *qf_input_fields, *qf_output_fields; 143 CeedQFunction qf; 144 CeedOperatorField *op_input_fields, *op_output_fields; 145 CeedOperator_Blocked *impl; 146 147 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 148 if (is_setup_done) return CEED_ERROR_SUCCESS; 149 150 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 151 CeedCallBackend(CeedOperatorGetData(op, &impl)); 152 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 153 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 154 CeedCallBackend(CeedQFunctionIsIdentity(qf, &impl->is_identity_qf)); 155 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 156 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 157 158 // Allocate 159 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->block_restr)); 160 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full)); 161 162 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 163 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in)); 164 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out)); 165 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 166 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 167 168 impl->num_inputs = num_input_fields; 169 impl->num_outputs = num_output_fields; 170 171 // Set up infield and outfield pointer arrays 172 // Infields 173 CeedCallBackend( 174 CeedOperatorSetupFields_Blocked(qf, op, true, impl->block_restr, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q)); 175 // Outfields 176 CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, false, impl->block_restr, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, 177 num_input_fields, num_output_fields, Q)); 178 179 // Identity QFunctions 180 if (impl->is_identity_qf) { 181 CeedEvalMode in_mode, out_mode; 182 CeedQFunctionField *in_fields, *out_fields; 183 184 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields)); 185 CeedCallBackend(CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode)); 186 CeedCallBackend(CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode)); 187 188 if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) { 189 impl->is_identity_restr_op = true; 190 } else { 191 CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->q_vecs_out[0])); 192 } 193 } 194 195 CeedCallBackend(CeedOperatorSetSetupDone(op)); 196 return CEED_ERROR_SUCCESS; 197 } 198 199 //------------------------------------------------------------------------------ 200 // Setup Operator Inputs 201 //------------------------------------------------------------------------------ 202 static inline int CeedOperatorSetupInputs_Blocked(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 203 CeedVector in_vec, bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], 204 CeedOperator_Blocked *impl, CeedRequest *request) { 205 for (CeedInt i = 0; i < num_input_fields; i++) { 206 uint64_t state; 207 CeedEvalMode eval_mode; 208 CeedVector vec; 209 210 // Get input vector 211 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 212 if (vec == CEED_VECTOR_ACTIVE) { 213 if (skip_active) continue; 214 else vec = in_vec; 215 } 216 217 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 218 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 219 } else { 220 // Restrict 221 CeedCallBackend(CeedVectorGetState(vec, &state)); 222 if (state != impl->input_states[i] || vec == in_vec) { 223 CeedCallBackend(CeedElemRestrictionApply(impl->block_restr[i], CEED_NOTRANSPOSE, vec, impl->e_vecs_full[i], request)); 224 impl->input_states[i] = state; 225 } 226 // Get evec 227 CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, (const CeedScalar **)&e_data_full[i])); 228 } 229 } 230 return CEED_ERROR_SUCCESS; 231 } 232 233 //------------------------------------------------------------------------------ 234 // Input Basis Action 235 //------------------------------------------------------------------------------ 236 static inline int CeedOperatorInputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 237 CeedInt num_input_fields, CeedInt block_size, bool skip_active, 238 CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) { 239 for (CeedInt i = 0; i < num_input_fields; i++) { 240 CeedInt elem_size, size, num_comp; 241 CeedEvalMode eval_mode; 242 CeedElemRestriction elem_restr; 243 CeedBasis basis; 244 245 // Skip active input 246 if (skip_active) { 247 CeedVector vec; 248 249 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 250 if (vec == CEED_VECTOR_ACTIVE) continue; 251 } 252 253 // Get elem_size, eval_mode, size 254 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr)); 255 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_restr, &elem_size)); 256 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 257 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 258 // Basis action 259 switch (eval_mode) { 260 case CEED_EVAL_NONE: 261 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][e * Q * size])); 262 break; 263 case CEED_EVAL_INTERP: 264 case CEED_EVAL_GRAD: 265 case CEED_EVAL_DIV: 266 case CEED_EVAL_CURL: 267 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 268 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 269 CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][e * elem_size * num_comp])); 270 CeedCallBackend(CeedBasisApply(basis, block_size, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs_in[i], impl->q_vecs_in[i])); 271 break; 272 case CEED_EVAL_WEIGHT: 273 break; // No action 274 } 275 } 276 return CEED_ERROR_SUCCESS; 277 } 278 279 //------------------------------------------------------------------------------ 280 // Output Basis Action 281 //------------------------------------------------------------------------------ 282 static inline int CeedOperatorOutputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields, 283 CeedInt block_size, CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op, 284 CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) { 285 for (CeedInt i = 0; i < num_output_fields; i++) { 286 CeedInt elem_size, num_comp; 287 CeedEvalMode eval_mode; 288 CeedElemRestriction elem_restr; 289 CeedBasis basis; 290 291 // Get elem_size, eval_mode, size 292 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr)); 293 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_restr, &elem_size)); 294 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 295 // Basis action 296 switch (eval_mode) { 297 case CEED_EVAL_NONE: 298 break; // No action 299 case CEED_EVAL_INTERP: 300 case CEED_EVAL_GRAD: 301 case CEED_EVAL_DIV: 302 case CEED_EVAL_CURL: 303 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 304 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 305 CeedCallBackend( 306 CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i + num_input_fields][e * elem_size * num_comp])); 307 CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); 308 break; 309 // LCOV_EXCL_START 310 case CEED_EVAL_WEIGHT: { 311 Ceed ceed; 312 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 313 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 314 // LCOV_EXCL_STOP 315 } 316 } 317 } 318 return CEED_ERROR_SUCCESS; 319 } 320 321 //------------------------------------------------------------------------------ 322 // Restore Input Vectors 323 //------------------------------------------------------------------------------ 324 static inline int CeedOperatorRestoreInputs_Blocked(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 325 bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) { 326 for (CeedInt i = 0; i < num_input_fields; i++) { 327 CeedEvalMode eval_mode; 328 329 // Skip active inputs 330 if (skip_active) { 331 CeedVector vec; 332 333 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 334 if (vec == CEED_VECTOR_ACTIVE) continue; 335 } 336 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 337 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 338 } else { 339 CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_full[i], (const CeedScalar **)&e_data_full[i])); 340 } 341 } 342 return CEED_ERROR_SUCCESS; 343 } 344 345 //------------------------------------------------------------------------------ 346 // Operator Apply 347 //------------------------------------------------------------------------------ 348 static int CeedOperatorApplyAdd_Blocked(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 349 CeedInt Q, num_input_fields, num_output_fields, num_elem, size; 350 const CeedInt block_size = 8; 351 CeedEvalMode eval_mode; 352 CeedScalar *e_data_full[2 * CEED_FIELD_MAX] = {0}; 353 CeedQFunctionField *qf_input_fields, *qf_output_fields; 354 CeedQFunction qf; 355 CeedOperatorField *op_input_fields, *op_output_fields; 356 CeedOperator_Blocked *impl; 357 358 CeedCallBackend(CeedOperatorGetData(op, &impl)); 359 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 360 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 361 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 362 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 363 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 364 const CeedInt num_blocks = (num_elem / block_size) + !!(num_elem % block_size); 365 366 // Setup 367 CeedCallBackend(CeedOperatorSetup_Blocked(op)); 368 369 // Restriction only operator 370 if (impl->is_identity_restr_op) { 371 CeedCallBackend(CeedElemRestrictionApply(impl->block_restr[0], CEED_NOTRANSPOSE, in_vec, impl->e_vecs_full[0], request)); 372 CeedCallBackend(CeedElemRestrictionApply(impl->block_restr[1], CEED_TRANSPOSE, impl->e_vecs_full[0], out_vec, request)); 373 return CEED_ERROR_SUCCESS; 374 } 375 376 // Input Evecs and Restriction 377 CeedCallBackend(CeedOperatorSetupInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data_full, impl, request)); 378 379 // Output Evecs 380 for (CeedInt i = 0; i < num_output_fields; i++) { 381 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields])); 382 } 383 384 // Loop through elements 385 for (CeedInt e = 0; e < num_blocks * block_size; e += block_size) { 386 // Output pointers 387 for (CeedInt i = 0; i < num_output_fields; i++) { 388 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 389 if (eval_mode == CEED_EVAL_NONE) { 390 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 391 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i + num_input_fields][e * Q * size])); 392 } 393 } 394 395 // Input basis apply 396 CeedCallBackend(CeedOperatorInputBasis_Blocked(e, Q, qf_input_fields, op_input_fields, num_input_fields, block_size, false, e_data_full, impl)); 397 398 // Q function 399 if (!impl->is_identity_qf) { 400 CeedCallBackend(CeedQFunctionApply(qf, Q * block_size, impl->q_vecs_in, impl->q_vecs_out)); 401 } 402 403 // Output basis apply 404 CeedCallBackend(CeedOperatorOutputBasis_Blocked(e, Q, qf_output_fields, op_output_fields, block_size, num_input_fields, num_output_fields, op, 405 e_data_full, impl)); 406 } 407 408 // Output restriction 409 for (CeedInt i = 0; i < num_output_fields; i++) { 410 CeedVector vec; 411 412 // Restore evec 413 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_full[i + impl->num_inputs], &e_data_full[i + num_input_fields])); 414 // Get output vector 415 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 416 // Active 417 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 418 // Restrict 419 CeedCallBackend( 420 CeedElemRestrictionApply(impl->block_restr[i + impl->num_inputs], CEED_TRANSPOSE, impl->e_vecs_full[i + impl->num_inputs], vec, request)); 421 } 422 423 // Restore input arrays 424 CeedCallBackend(CeedOperatorRestoreInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, false, e_data_full, impl)); 425 return CEED_ERROR_SUCCESS; 426 } 427 428 //------------------------------------------------------------------------------ 429 // Core code for assembling linear QFunction 430 //------------------------------------------------------------------------------ 431 static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked(CeedOperator op, bool build_objects, CeedVector *assembled, 432 CeedElemRestriction *rstr, CeedRequest *request) { 433 Ceed ceed; 434 CeedSize q_size; 435 CeedInt Q, num_input_fields, num_output_fields, num_elem, size; 436 const CeedInt block_size = 8; 437 CeedScalar *l_vec_array; 438 CeedScalar *e_data_full[2 * CEED_FIELD_MAX] = {0}; 439 CeedQFunctionField *qf_input_fields, *qf_output_fields; 440 CeedQFunction qf; 441 CeedOperatorField *op_input_fields, *op_output_fields; 442 CeedOperator_Blocked *impl; 443 444 CeedCallBackend(CeedOperatorGetData(op, &impl)); 445 CeedInt num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 446 CeedVector *active_in = impl->qf_active_in; 447 CeedVector l_vec = impl->qf_l_vec; 448 CeedElemRestriction block_rstr = impl->qf_block_rstr; 449 450 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 451 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 452 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 453 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 454 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 455 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 456 const CeedInt num_blocks = (num_elem / block_size) + !!(num_elem % block_size); 457 458 // Setup 459 CeedCallBackend(CeedOperatorSetup_Blocked(op)); 460 461 // Check for identity 462 CeedCheck(!impl->is_identity_qf, ceed, CEED_ERROR_BACKEND, "Assembling identity QFunctions not supported"); 463 464 // Input Evecs and Restriction 465 CeedCallBackend(CeedOperatorSetupInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data_full, impl, request)); 466 467 // Count number of active input fields 468 if (!num_active_in) { 469 for (CeedInt i = 0; i < num_input_fields; i++) { 470 CeedScalar *q_vec_array; 471 CeedVector vec; 472 473 // Get input vector 474 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 475 // Check if active input 476 if (vec == CEED_VECTOR_ACTIVE) { 477 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 478 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 479 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &q_vec_array)); 480 CeedCallBackend(CeedRealloc(num_active_in + size, &active_in)); 481 for (CeedInt field = 0; field < size; field++) { 482 q_size = (CeedSize)Q * block_size; 483 CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_in[num_active_in + field])); 484 CeedCallBackend( 485 CeedVectorSetArray(active_in[num_active_in + field], CEED_MEM_HOST, CEED_USE_POINTER, &q_vec_array[field * Q * block_size])); 486 } 487 num_active_in += size; 488 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array)); 489 } 490 } 491 impl->num_active_in = num_active_in; 492 impl->qf_active_in = active_in; 493 } 494 495 // Count number of active output fields 496 if (!num_active_out) { 497 for (CeedInt i = 0; i < num_output_fields; i++) { 498 CeedVector vec; 499 500 // Get output vector 501 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 502 // Check if active output 503 if (vec == CEED_VECTOR_ACTIVE) { 504 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 505 num_active_out += size; 506 } 507 } 508 impl->num_active_out = num_active_out; 509 } 510 511 // Check sizes 512 CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 513 514 // Setup Lvec 515 if (!l_vec) { 516 const CeedSize l_size = (CeedSize)num_blocks * block_size * Q * num_active_in * num_active_out; 517 518 CeedCallBackend(CeedVectorCreate(ceed, l_size, &l_vec)); 519 impl->qf_l_vec = l_vec; 520 } 521 CeedCallBackend(CeedVectorGetArrayWrite(l_vec, CEED_MEM_HOST, &l_vec_array)); 522 523 // Setup block restriction 524 if (!block_rstr) { 525 const CeedInt strides[3] = {1, Q, num_active_in * num_active_out * Q}; 526 527 CeedCallBackend(CeedElemRestrictionCreateBlockedStrided(ceed, num_elem, Q, block_size, num_active_in * num_active_out, 528 num_active_in * num_active_out * num_elem * Q, strides, &block_rstr)); 529 impl->qf_block_rstr = block_rstr; 530 } 531 532 // Build objects if needed 533 if (build_objects) { 534 const CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out; 535 const CeedInt strides[3] = {1, Q, num_active_in * num_active_out * Q}; 536 537 // Create output restriction 538 CeedCallBackend(CeedElemRestrictionCreateStrided(ceed, num_elem, Q, num_active_in * num_active_out, num_active_in * num_active_out * num_elem * Q, 539 strides, rstr)); 540 // Create assembled vector 541 CeedCallBackend(CeedVectorCreate(ceed, l_size, assembled)); 542 } 543 544 // Loop through elements 545 for (CeedInt e = 0; e < num_blocks * block_size; e += block_size) { 546 // Input basis apply 547 CeedCallBackend(CeedOperatorInputBasis_Blocked(e, Q, qf_input_fields, op_input_fields, num_input_fields, block_size, true, e_data_full, impl)); 548 549 // Assemble QFunction 550 for (CeedInt in = 0; in < num_active_in; in++) { 551 // Set Inputs 552 CeedCallBackend(CeedVectorSetValue(active_in[in], 1.0)); 553 if (num_active_in > 1) { 554 CeedCallBackend(CeedVectorSetValue(active_in[(in + num_active_in - 1) % num_active_in], 0.0)); 555 } 556 // Set Outputs 557 for (CeedInt out = 0; out < num_output_fields; out++) { 558 CeedVector vec; 559 560 // Get output vector 561 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 562 // Check if active output 563 if (vec == CEED_VECTOR_ACTIVE) { 564 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_USE_POINTER, l_vec_array)); 565 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size)); 566 l_vec_array += size * Q * block_size; // Advance the pointer by the size of the output 567 } 568 } 569 // Apply QFunction 570 CeedCallBackend(CeedQFunctionApply(qf, Q * block_size, impl->q_vecs_in, impl->q_vecs_out)); 571 } 572 } 573 574 // Un-set output Qvecs to prevent accidental overwrite of Assembled 575 for (CeedInt out = 0; out < num_output_fields; out++) { 576 CeedVector vec; 577 578 // Get output vector 579 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 580 // Check if active output 581 if (vec == CEED_VECTOR_ACTIVE) { 582 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_COPY_VALUES, NULL)); 583 } 584 } 585 586 // Restore input arrays 587 CeedCallBackend(CeedOperatorRestoreInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, true, e_data_full, impl)); 588 589 // Output blocked restriction 590 CeedCallBackend(CeedVectorRestoreArray(l_vec, &l_vec_array)); 591 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 592 CeedCallBackend(CeedElemRestrictionApply(block_rstr, CEED_TRANSPOSE, l_vec, *assembled, request)); 593 return CEED_ERROR_SUCCESS; 594 } 595 596 //------------------------------------------------------------------------------ 597 // Assemble Linear QFunction 598 //------------------------------------------------------------------------------ 599 static int CeedOperatorLinearAssembleQFunction_Blocked(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 600 return CeedOperatorLinearAssembleQFunctionCore_Blocked(op, true, assembled, rstr, request); 601 } 602 603 //------------------------------------------------------------------------------ 604 // Update Assembled Linear QFunction 605 //------------------------------------------------------------------------------ 606 static int CeedOperatorLinearAssembleQFunctionUpdate_Blocked(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 607 return CeedOperatorLinearAssembleQFunctionCore_Blocked(op, false, &assembled, &rstr, request); 608 } 609 610 //------------------------------------------------------------------------------ 611 // Operator Destroy 612 //------------------------------------------------------------------------------ 613 static int CeedOperatorDestroy_Blocked(CeedOperator op) { 614 CeedOperator_Blocked *impl; 615 616 CeedCallBackend(CeedOperatorGetData(op, &impl)); 617 618 for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { 619 CeedCallBackend(CeedElemRestrictionDestroy(&impl->block_restr[i])); 620 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_full[i])); 621 } 622 CeedCallBackend(CeedFree(&impl->block_restr)); 623 CeedCallBackend(CeedFree(&impl->e_vecs_full)); 624 CeedCallBackend(CeedFree(&impl->input_states)); 625 626 for (CeedInt i = 0; i < impl->num_inputs; i++) { 627 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_in[i])); 628 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i])); 629 } 630 CeedCallBackend(CeedFree(&impl->e_vecs_in)); 631 CeedCallBackend(CeedFree(&impl->q_vecs_in)); 632 633 for (CeedInt i = 0; i < impl->num_outputs; i++) { 634 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_out[i])); 635 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 636 } 637 CeedCallBackend(CeedFree(&impl->e_vecs_out)); 638 CeedCallBackend(CeedFree(&impl->q_vecs_out)); 639 640 // QFunction assembly data 641 for (CeedInt i = 0; i < impl->num_active_in; i++) { 642 CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i])); 643 } 644 CeedCallBackend(CeedFree(&impl->qf_active_in)); 645 CeedCallBackend(CeedVectorDestroy(&impl->qf_l_vec)); 646 CeedCallBackend(CeedElemRestrictionDestroy(&impl->qf_block_rstr)); 647 648 CeedCallBackend(CeedFree(&impl)); 649 return CEED_ERROR_SUCCESS; 650 } 651 652 //------------------------------------------------------------------------------ 653 // Operator Create 654 //------------------------------------------------------------------------------ 655 int CeedOperatorCreate_Blocked(CeedOperator op) { 656 Ceed ceed; 657 CeedOperator_Blocked *impl; 658 659 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 660 CeedCallBackend(CeedCalloc(1, &impl)); 661 CeedCallBackend(CeedOperatorSetData(op, impl)); 662 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Blocked)); 663 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Blocked)); 664 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Blocked)); 665 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Blocked)); 666 return CEED_ERROR_SUCCESS; 667 } 668 669 //------------------------------------------------------------------------------ 670