xref: /libCEED/interface/ceed-tensor.c (revision 7b87cde0397b48cc601e25bb3ecac5ddabea754c)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed-impl.h>
9 #include <ceed.h>
10 #include <ceed/backend.h>
11 #include <stddef.h>
12 
13 /// @file
14 /// Implementation of CeedTensorContract interfaces
15 
16 /// ----------------------------------------------------------------------------
17 /// CeedTensorContract Backend API
18 /// ----------------------------------------------------------------------------
19 /// @addtogroup CeedBasisBackend
20 /// @{
21 
22 /**
23   @brief Create a CeedTensorContract object for a CeedBasis
24 
25   @param[in]  ceed     Ceed object where the CeedTensorContract will be created
26   @param[in]  basis    CeedBasis for which the tensor contraction will be used
27   @param[out] contract Address of the variable where the newly created CeedTensorContract will be stored.
28 
29   @return An error code: 0 - success, otherwise - failure
30 
31   @ref Backend
32 **/
33 int CeedTensorContractCreate(Ceed ceed, CeedBasis basis, CeedTensorContract *contract) {
34   if (!ceed->TensorContractCreate) {
35     Ceed delegate;
36     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
37 
38     if (!delegate) {
39       // LCOV_EXCL_START
40       return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support TensorContractCreate");
41       // LCOV_EXCL_STOP
42     }
43 
44     CeedCall(CeedTensorContractCreate(delegate, basis, contract));
45     return CEED_ERROR_SUCCESS;
46   }
47 
48   CeedCall(CeedCalloc(1, contract));
49 
50   (*contract)->ceed = ceed;
51   CeedCall(CeedReference(ceed));
52   CeedCall(ceed->TensorContractCreate(basis, *contract));
53   return CEED_ERROR_SUCCESS;
54 }
55 
56 /**
57   @brief Apply tensor contraction
58 
59     Contracts on the middle index
60     NOTRANSPOSE: v_ajc = t_jb u_abc
61     TRANSPOSE:   v_ajc = t_bj u_abc
62     If add != 0, "=" is replaced by "+="
63 
64   @param[in]  contract CeedTensorContract to use
65   @param[in]  A        First index of u, v
66   @param[in]  B        Middle index of u, one index of t
67   @param[in]  C        Last index of u, v
68   @param[in]  J        Middle index of v, one index of t
69   @param[in]  t        Tensor array to contract against
70   @param[in]  t_mode   Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj
71   @param[in]  add      Add mode
72   @param[in]  u        Input array
73   @param[out] v        Output array
74 
75   @return An error code: 0 - success, otherwise - failure
76 
77   @ref Backend
78 **/
79 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
80                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
81   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
82   return CEED_ERROR_SUCCESS;
83 }
84 
85 /**
86   @brief Get Ceed associated with a CeedTensorContract
87 
88   @param[in]  contract CeedTensorContract
89   @param[out] ceed     Variable to store Ceed
90 
91   @return An error code: 0 - success, otherwise - failure
92 
93   @ref Backend
94 **/
95 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
96   *ceed = contract->ceed;
97   return CEED_ERROR_SUCCESS;
98 }
99 
100 /**
101   @brief Get backend data of a CeedTensorContract
102 
103   @param[in]  contract CeedTensorContract
104   @param[out] data     Variable to store data
105 
106   @return An error code: 0 - success, otherwise - failure
107 
108   @ref Backend
109 **/
110 int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
111   *(void **)data = contract->data;
112   return CEED_ERROR_SUCCESS;
113 }
114 
115 /**
116   @brief Set backend data of a CeedTensorContract
117 
118   @param[in,out] contract CeedTensorContract
119   @param[in]     data     Data to set
120 
121   @return An error code: 0 - success, otherwise - failure
122 
123   @ref Backend
124 **/
125 int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
126   contract->data = data;
127   return CEED_ERROR_SUCCESS;
128 }
129 
130 /**
131   @brief Increment the reference counter for a CeedTensorContract
132 
133   @param[in,out] contract CeedTensorContract to increment the reference counter
134 
135   @return An error code: 0 - success, otherwise - failure
136 
137   @ref Backend
138 **/
139 int CeedTensorContractReference(CeedTensorContract contract) {
140   contract->ref_count++;
141   return CEED_ERROR_SUCCESS;
142 }
143 
144 /**
145   @brief Destroy a CeedTensorContract
146 
147   @param[in,out] contract CeedTensorContract to destroy
148 
149   @return An error code: 0 - success, otherwise - failure
150 
151   @ref Backend
152 **/
153 int CeedTensorContractDestroy(CeedTensorContract *contract) {
154   if (!*contract || --(*contract)->ref_count > 0) {
155     *contract = NULL;
156     return CEED_ERROR_SUCCESS;
157   }
158   if ((*contract)->Destroy) {
159     CeedCall((*contract)->Destroy(*contract));
160   }
161   CeedCall(CeedDestroy(&(*contract)->ceed));
162   CeedCall(CeedFree(contract));
163   return CEED_ERROR_SUCCESS;
164 }
165 
166 /// @}
167