xref: /libCEED/rust/libceed-sys/c-src/interface/ceed-tensor.c (revision 3d8e882215d238700cdceb37404f76ca7fa24eaa)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed-impl.h>
11 
12 /// @file
13 /// Implementation of CeedTensorContract interfaces
14 
15 /// ----------------------------------------------------------------------------
16 /// CeedTensorContract Backend API
17 /// ----------------------------------------------------------------------------
18 /// @addtogroup CeedBasisBackend
19 /// @{
20 
21 /**
22   @brief Create a CeedTensorContract object for a CeedBasis
23 
24   @param ceed           A Ceed object where the CeedTensorContract will be created
25   @param basis          CeedBasis for which the tensor contraction will be used
26   @param[out] contract  Address of the variable where the newly created
27                           CeedTensorContract will be stored.
28 
29   @return An error code: 0 - success, otherwise - failure
30 
31   @ref Backend
32 **/
33 int CeedTensorContractCreate(Ceed ceed, CeedBasis basis,
34                              CeedTensorContract *contract) {
35   int ierr;
36 
37   if (!ceed->TensorContractCreate) {
38     Ceed delegate;
39     ierr = CeedGetObjectDelegate(ceed, &delegate, "TensorContract");
40     CeedChk(ierr);
41 
42     if (!delegate)
43       // LCOV_EXCL_START
44       return CeedError(ceed, CEED_ERROR_UNSUPPORTED,
45                        "Backend does not support TensorContractCreate");
46     // LCOV_EXCL_STOP
47 
48     ierr = CeedTensorContractCreate(delegate, basis, contract);
49     CeedChk(ierr);
50     return CEED_ERROR_SUCCESS;
51   }
52 
53   ierr = CeedCalloc(1, contract); CeedChk(ierr);
54 
55   (*contract)->ceed = ceed;
56   ierr = CeedReference(ceed); CeedChk(ierr);
57   ierr = ceed->TensorContractCreate(basis, *contract);
58   CeedChk(ierr);
59   return CEED_ERROR_SUCCESS;
60 }
61 
62 /**
63   @brief Apply tensor contraction
64 
65     Contracts on the middle index
66     NOTRANSPOSE: v_ajc = t_jb u_abc
67     TRANSPOSE:   v_ajc = t_bj u_abc
68     If add != 0, "=" is replaced by "+="
69 
70   @param contract  CeedTensorContract to use
71   @param A         First index of u, v
72   @param B         Middle index of u, one index of t
73   @param C         Last index of u, v
74   @param J         Middle index of v, one index of t
75   @param[in] t     Tensor array to contract against
76   @param t_mode    Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb
77                      \ref CEED_TRANSPOSE for t_bj
78   @param add       Add mode
79   @param[in] u     Input array
80   @param[out] v    Output array
81 
82   @return An error code: 0 - success, otherwise - failure
83 
84   @ref Backend
85 **/
86 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B,
87                             CeedInt C, CeedInt J, const CeedScalar *restrict t,
88                             CeedTransposeMode t_mode, const CeedInt add,
89                             const CeedScalar *restrict u,
90                             CeedScalar *restrict v) {
91   int ierr;
92 
93   ierr = contract->Apply(contract, A, B, C, J, t, t_mode, add,  u, v);
94   CeedChk(ierr);
95   return CEED_ERROR_SUCCESS;
96 }
97 
98 /**
99   @brief Get Ceed associated with a CeedTensorContract
100 
101   @param contract   CeedTensorContract
102   @param[out] ceed  Variable to store Ceed
103 
104   @return An error code: 0 - success, otherwise - failure
105 
106   @ref Backend
107 **/
108 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
109   *ceed = contract->ceed;
110   return CEED_ERROR_SUCCESS;
111 }
112 
113 /**
114   @brief Get backend data of a CeedTensorContract
115 
116   @param contract   CeedTensorContract
117   @param[out] data  Variable to store data
118 
119   @return An error code: 0 - success, otherwise - failure
120 
121   @ref Backend
122 **/
123 int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
124   *(void **)data = contract->data;
125   return CEED_ERROR_SUCCESS;
126 }
127 
128 /**
129   @brief Set backend data of a CeedTensorContract
130 
131   @param[out] contract  CeedTensorContract
132   @param data           Data to set
133 
134   @return An error code: 0 - success, otherwise - failure
135 
136   @ref Backend
137 **/
138 int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
139   contract->data = data;
140   return CEED_ERROR_SUCCESS;
141 }
142 
143 /**
144   @brief Increment the reference counter for a CeedTensorContract
145 
146   @param contract  CeedTensorContract to increment the reference counter
147 
148   @return An error code: 0 - success, otherwise - failure
149 
150   @ref Backend
151 **/
152 int CeedTensorContractReference(CeedTensorContract contract) {
153   contract->ref_count++;
154   return CEED_ERROR_SUCCESS;
155 }
156 
157 /**
158   @brief Destroy a CeedTensorContract
159 
160   @param contract  CeedTensorContract to destroy
161 
162   @return An error code: 0 - success, otherwise - failure
163 
164   @ref Backend
165 **/
166 int CeedTensorContractDestroy(CeedTensorContract *contract) {
167   int ierr;
168 
169   if (!*contract || --(*contract)->ref_count > 0) return CEED_ERROR_SUCCESS;
170   if ((*contract)->Destroy) {
171     ierr = (*contract)->Destroy(*contract); CeedChk(ierr);
172   }
173   ierr = CeedDestroy(&(*contract)->ceed); CeedChk(ierr);
174   ierr = CeedFree(contract); CeedChk(ierr);
175   return CEED_ERROR_SUCCESS;
176 }
177 
178 /// @}
179