1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
32f86a920SJeremy L Thompson //
43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
52f86a920SJeremy L Thompson //
63d8e8822SJeremy L Thompson // This file is part of CEED: http://github.com/ceed
72f86a920SJeremy L Thompson
849aac155SJeremy L Thompson #include <ceed.h>
9ec3da8bcSJed Brown #include <ceed/backend.h>
102b730f8bSJeremy L Thompson
112f86a920SJeremy L Thompson #include "ceed-ref.h"
122f86a920SJeremy L Thompson
13f10650afSjeremylt //------------------------------------------------------------------------------
14f10650afSjeremylt // Tensor Contract Apply
15f10650afSjeremylt //------------------------------------------------------------------------------
// Contract the middle index of u against the 2D operator t:
//   v[a][j][c] (+)= sum_b t(j, b) * u[a][b][c]
// u is read as an A x B x C array and v is written as an A x J x C array (row-major, from the index math below).
// With t_mode != CEED_TRANSPOSE, t(j, b) = t[j * B + b]; with CEED_TRANSPOSE, t(j, b) = t[j + b * J].
// When add is zero, v is overwritten; otherwise the contraction is accumulated into the existing contents of v.
// The contract argument is not used by this reference implementation.
static int CeedTensorContractApply_Ref(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                       CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  // Transposing swaps which of the (j, b) indices of t is the contiguous one
  const CeedInt stride_j = (t_mode == CEED_TRANSPOSE) ? 1 : B;
  const CeedInt stride_b = (t_mode == CEED_TRANSPOSE) ? J : 1;

  if (!add) {
    const CeedInt num_out = A * J * C;

    for (CeedInt i = 0; i < num_out; i++) v[i] = (CeedScalar)0.0;
  }

  for (CeedInt a = 0; a < A; a++) {
    for (CeedInt b = 0; b < B; b++) {
      // Offset of the row u[a][b][:]
      const CeedInt u_row = (a * B + b) * C;

      for (CeedInt j = 0; j < J; j++) {
        const CeedScalar t_jb  = t[j * stride_j + b * stride_b];
        // Offset of the row v[a][j][:]
        const CeedInt    v_row = (a * J + j) * C;

        // Innermost loop walks contiguous memory in both u and v
        for (CeedInt c = 0; c < C; c++) v[v_row + c] += t_jb * u[u_row + c];
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
392f86a920SJeremy L Thompson
40f10650afSjeremylt //------------------------------------------------------------------------------
41f10650afSjeremylt // Tensor Contract Destroy
42f10650afSjeremylt //------------------------------------------------------------------------------
// Destroy hook for the reference tensor contraction.
// CeedTensorContractCreate_Ref only registers function pointers and allocates no backend data,
// so there is nothing to release here; the call always succeeds.
static int CeedTensorContractDestroy_Ref(CeedTensorContract contract) {
  (void)contract;  // unused: no backend-private state to free
  return CEED_ERROR_SUCCESS;
}
442f86a920SJeremy L Thompson
45f10650afSjeremylt //------------------------------------------------------------------------------
46f10650afSjeremylt // Tensor Contract Create
47f10650afSjeremylt //------------------------------------------------------------------------------
// Create the reference tensor contraction by registering its "Apply" and "Destroy" backend functions on contract.
// The Ceed handle obtained from CeedTensorContractGetCeed is a reference that this function must release with
// CeedDestroy; it is released on every path, including when a registration call fails, before any error
// from registration is propagated.
int CeedTensorContractCreate_Ref(CeedTensorContract contract) {
  Ceed ceed;
  int  ierr;

  CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed));
  // Capture registration errors instead of returning immediately, so the ceed reference is not leaked
  ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Ref);
  if (ierr == CEED_ERROR_SUCCESS) {
    ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy", CeedTensorContractDestroy_Ref);
  }
  CeedCallBackend(CeedDestroy(&ceed));
  // Propagate any registration failure through the same backend error convention as the other calls
  CeedCallBackend(ierr);
  return CEED_ERROR_SUCCESS;
}
572a86cc9dSSebastian Grimberg
58f10650afSjeremylt //------------------------------------------------------------------------------
59