xref: /libCEED/rust/libceed-sys/c-src/backends/opt/ceed-opt-tensor.c (revision 2b730f8b5a9c809740a0b3b302db43a719c636b1)
13d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
39b2a10adSJeremy L Thompson //
43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
59b2a10adSJeremy L Thompson //
63d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
79b2a10adSJeremy L Thompson 
89b2a10adSJeremy L Thompson #include <ceed/backend.h>
9*2b730f8bSJeremy L Thompson #include <ceed/ceed.h>
10*2b730f8bSJeremy L Thompson 
119b2a10adSJeremy L Thompson #include "ceed-opt.h"
129b2a10adSJeremy L Thompson 
//------------------------------------------------------------------------------
// Tensor Contract Core loop
//------------------------------------------------------------------------------
16*2b730f8bSJeremy L Thompson static inline int CeedTensorContractApply_Core_Opt(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
17*2b730f8bSJeremy L Thompson                                                    const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
18*2b730f8bSJeremy L Thompson                                                    const CeedScalar *restrict u, CeedScalar *restrict v) {
199b2a10adSJeremy L Thompson   CeedInt t_stride_0 = B, t_stride_1 = 1;
209b2a10adSJeremy L Thompson   if (t_mode == CEED_TRANSPOSE) {
21*2b730f8bSJeremy L Thompson     t_stride_0 = 1;
22*2b730f8bSJeremy L Thompson     t_stride_1 = J;
239b2a10adSJeremy L Thompson   }
249b2a10adSJeremy L Thompson 
25*2b730f8bSJeremy L Thompson   for (CeedInt a = 0; a < A; a++) {
26*2b730f8bSJeremy L Thompson     for (CeedInt b = 0; b < B; b++) {
279b2a10adSJeremy L Thompson       for (CeedInt j = 0; j < J; j++) {
289b2a10adSJeremy L Thompson         CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];
29*2b730f8bSJeremy L Thompson         for (CeedInt c = 0; c < C; c++) v[(a * J + j) * C + c] += tq * u[(a * B + b) * C + c];
30*2b730f8bSJeremy L Thompson       }
31*2b730f8bSJeremy L Thompson     }
329b2a10adSJeremy L Thompson   }
339b2a10adSJeremy L Thompson 
349b2a10adSJeremy L Thompson   return CEED_ERROR_SUCCESS;
359b2a10adSJeremy L Thompson }
369b2a10adSJeremy L Thompson 
//------------------------------------------------------------------------------
// Tensor Contract Apply
//------------------------------------------------------------------------------
40*2b730f8bSJeremy L Thompson static int CeedTensorContractApply_Opt(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
41*2b730f8bSJeremy L Thompson                                        CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
42*2b730f8bSJeremy L Thompson   if (!add) {
43*2b730f8bSJeremy L Thompson     for (CeedInt q = 0; q < A * J * C; q++) v[q] = (CeedScalar)0.0;
44*2b730f8bSJeremy L Thompson   }
459b2a10adSJeremy L Thompson 
46*2b730f8bSJeremy L Thompson   if (C == 1) return CeedTensorContractApply_Core_Opt(contract, A, B, 1, J, t, t_mode, add, u, v);
47*2b730f8bSJeremy L Thompson   else return CeedTensorContractApply_Core_Opt(contract, A, B, C, J, t, t_mode, add, u, v);
489b2a10adSJeremy L Thompson 
499b2a10adSJeremy L Thompson   return CEED_ERROR_SUCCESS;
509b2a10adSJeremy L Thompson }
519b2a10adSJeremy L Thompson 
//------------------------------------------------------------------------------
// Tensor Contract Create
//------------------------------------------------------------------------------
559b2a10adSJeremy L Thompson int CeedTensorContractCreate_Opt(CeedBasis basis, CeedTensorContract contract) {
569b2a10adSJeremy L Thompson   Ceed ceed;
57*2b730f8bSJeremy L Thompson   CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed));
589b2a10adSJeremy L Thompson 
59*2b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Opt));
609b2a10adSJeremy L Thompson 
619b2a10adSJeremy L Thompson   return CEED_ERROR_SUCCESS;
629b2a10adSJeremy L Thompson }
//------------------------------------------------------------------------------
64