xref: /libCEED/backends/opt/ceed-opt-tensor.c (revision 019b76820d7ff306c177822c4e76ffe5939c204b)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include "ceed-opt.h"
11 
12 //------------------------------------------------------------------------------
13 // Tensor Contract Core loop
14 //------------------------------------------------------------------------------
15 static inline int CeedTensorContractApply_Core_Opt(CeedTensorContract contract,
16     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
17     CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u,
18     CeedScalar *restrict v) {
19   CeedInt t_stride_0 = B, t_stride_1 = 1;
20   if (t_mode == CEED_TRANSPOSE) {
21     t_stride_0 = 1; t_stride_1 = J;
22   }
23 
24   for (CeedInt a=0; a<A; a++)
25     for (CeedInt b=0; b<B; b++)
26       for (CeedInt j=0; j<J; j++) {
27         CeedScalar tq = t[j*t_stride_0 + b*t_stride_1];
28         for (CeedInt c=0; c<C; c++)
29           v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c];
30       }
31 
32   return CEED_ERROR_SUCCESS;
33 }
34 
35 //------------------------------------------------------------------------------
36 // Tensor Contract Apply
37 //------------------------------------------------------------------------------
38 static int CeedTensorContractApply_Opt(CeedTensorContract contract, CeedInt A,
39                                        CeedInt B, CeedInt C, CeedInt J,
40                                        const CeedScalar *restrict t,
41                                        CeedTransposeMode t_mode, const CeedInt add,
42                                        const CeedScalar *restrict u,
43                                        CeedScalar *restrict v) {
44   if (!add)
45     for (CeedInt q=0; q<A*J*C; q++)
46       v[q] = (CeedScalar) 0.0;
47 
48   if (C == 1)
49     return CeedTensorContractApply_Core_Opt(contract, A, B, 1, J, t, t_mode,
50                                             add, u, v);
51   else
52     return CeedTensorContractApply_Core_Opt(contract, A, B, C, J, t, t_mode,
53                                             add, u, v);
54 
55   return CEED_ERROR_SUCCESS;
56 }
57 
58 //------------------------------------------------------------------------------
59 // Tensor Contract Create
60 //------------------------------------------------------------------------------
61 int CeedTensorContractCreate_Opt(CeedBasis basis, CeedTensorContract contract) {
62   int ierr;
63   Ceed ceed;
64   ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChkBackend(ierr);
65 
66   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
67                                 CeedTensorContractApply_Opt); CeedChkBackend(ierr);
68 
69   return CEED_ERROR_SUCCESS;
70 }
71 //------------------------------------------------------------------------------
72