xref: /libCEED/backends/ref/ceed-ref-restriction.c (revision 4e35ef053a4ca6d42bd4fa6764313dea5d12ef4c)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-impl.h>
18 #include <string.h>
19 #include "ceed-ref.h"
20 
21 static int CeedElemRestrictionApply_Ref(CeedElemRestriction r,
22                                         CeedTransposeMode tmode,
23                                         CeedTransposeMode lmode, CeedVector u,
24                                         CeedVector v, CeedRequest *request) {
25   CeedElemRestriction_Ref *impl = r->data;
26   int ierr;
27   const CeedScalar *uu;
28   CeedScalar *vv;
29   CeedInt nblk = r->nblk, blksize = r->blksize, elemsize = r->elemsize,
30            esize = nblk*blksize*elemsize, ncomp=r->ncomp;
31 
32   ierr = CeedVectorGetArrayRead(u, CEED_MEM_HOST, &uu); CeedChk(ierr);
33   ierr = CeedVectorGetArray(v, CEED_MEM_HOST, &vv); CeedChk(ierr);
34   if (tmode == CEED_NOTRANSPOSE) {
35     // Perform: v = r * u
36     if (!impl->indices) {
37       for (CeedInt i = 0; i<nblk - 1; i++) {
38         CeedInt shift = i*blksize*ncomp*elemsize;
39         for (CeedInt j = 0; j<blksize; j++) {
40           for (CeedInt k = 0; k<ncomp*elemsize; k++) {
41             vv[shift + k*blksize + j] = uu[shift + j*ncomp*elemsize + k];
42           }
43         }
44       }
45       CeedInt shift = (nblk - 1)*blksize*ncomp*elemsize;
46       CeedInt nlastelems = r->nelem % nblk;
47       if (nlastelems == 0) nlastelems = blksize;
48       for (CeedInt j = 0; j<blksize; j++) {
49         for (CeedInt k = 0; k<ncomp*elemsize; k++) {
50           if (j < nlastelems) {
51             vv[shift + k*blksize + j] = uu[shift + j*ncomp*elemsize + k];
52           } else {
53             vv[shift + k*blksize + j] = uu[shift + (nlastelems - 1)*ncomp*elemsize + k];
54           }
55         }
56       }
57     } else if (ncomp == 1) {
58       for (CeedInt i = 0; i<esize; i++) vv[i] = uu[impl->indices[i]];
59     } else {
60       // vv is (elemsize x ncomp x nelem), column-major
61       if (lmode == CEED_NOTRANSPOSE) { // u is (ndof x ncomp), column-major
62         for (CeedInt e = 0; e < nblk*blksize; e++)
63           for (CeedInt d = 0; d < ncomp; d++)
64             for (CeedInt i = 0; i<r->elemsize; i++) {
65               vv[i+r->elemsize*(d+ncomp*e)] =
66                 uu[impl->indices[i+r->elemsize*e]+r->ndof*d];
67             }
68       } else { // u is (ncomp x ndof), column-major
69         for (CeedInt e = 0; e < r->nblk*blksize; e++) {
70           for (CeedInt d = 0; d < ncomp; d++) {
71             for (CeedInt i = 0; i<r->elemsize; i++) {
72               vv[i+r->elemsize*(d+ncomp*e)] =
73                 uu[d+ncomp*impl->indices[i+r->elemsize*e]];
74             }
75           }
76         }
77       }
78     }
79   } else {
80     // Note: in transpose mode, we perform: v += r^t * u
81     esize = (nblk - 1)*blksize*elemsize;
82     if (!impl->indices) {
83       for (CeedInt i=0; i<nblk - 1; i++) {
84         CeedInt shift = i*blksize*ncomp*elemsize;
85         for (CeedInt j = 0; j<blksize; j++) {
86           for (CeedInt k = 0; k<ncomp*elemsize; k++) {
87             vv[shift + j*ncomp*elemsize + k] = uu[shift + k*blksize + j];
88           }
89         }
90       }
91       CeedInt shift = (nblk - 1)*blksize*ncomp*elemsize;
92       CeedInt nlastelems = r->nelem % nblk;
93       if (nlastelems == 0) nlastelems = blksize;
94       for (CeedInt j = 0; j<blksize; j++) {
95         for (CeedInt k = 0; k<ncomp*elemsize; k++) {
96           if (j < nlastelems) {
97             vv[shift + j*ncomp*elemsize + k] = uu[shift + k*blksize + j];
98           }
99         }
100       }
101     } else if (ncomp == 1) {
102       for (CeedInt i = 0; i<esize; i++) vv[impl->indices[i]] += uu[i];
103       CeedInt nlastelems = r->nelem % blksize;
104       CeedInt shift = (nblk - 1)*blksize*elemsize;
105       if (nlastelems == 0) nlastelems = blksize;
106       for (CeedInt i = 0; i<blksize*elemsize; i++) {
107         if ((i % blksize) < nlastelems) {
108           vv[impl->indices[shift + i]] += uu[shift + i];
109         }
110       }
111     } else {
112       // u is (elemsize x ncomp x nelem)
113       if (lmode == CEED_NOTRANSPOSE) { // vv is (ndof x ncomp), column-major
114         for (CeedInt e = 0; e < blksize * (nblk - 1); e++) {
115           for (CeedInt d = 0; d < ncomp; d++) {
116             for (CeedInt i = 0; i<elemsize; i++) {
117               vv[impl->indices[i+elemsize*e]+r->ndof*d] +=
118                 uu[i+elemsize*(d+e*ncomp)];
119             }
120           }
121         }
122       CeedInt shift = (nblk - 1)*blksize*elemsize;
123       CeedInt nlastelems = r->nelem % blksize;
124       if (nlastelems == 0) nlastelems = blksize;
125         for (CeedInt e = 0; e < nlastelems; e++) {
126           for (CeedInt d = 0; d < ncomp; d++) {
127             for (CeedInt i = 0; i<elemsize; i++) {
128               vv[impl->indices[i+elemsize*(e+shift)]+r->ndof*d] +=
129                 uu[i+elemsize*(d+(e+shift)*ncomp)];
130             }
131           }
132         }
133       } else { // vv is (ncomp x ndof), column-major
134         for (CeedInt e = 0; e < blksize * (nblk - 1); e++) {
135           for (CeedInt d = 0; d < ncomp; d++) {
136             for (CeedInt i = 0; i<elemsize; i++) {
137               vv[d+ncomp*impl->indices[i+elemsize*e]] +=
138                 uu[i+r->elemsize*(d+e*ncomp)];
139             }
140           }
141         }
142         CeedInt shift = (nblk - 1)*blksize*elemsize;
143         CeedInt nlastelems = r->nelem % blksize;
144         for (CeedInt e = 0; e < nlastelems; e++) {
145           for (CeedInt d = 0; d < ncomp; d++) {
146             for (CeedInt i = 0; i<elemsize; i++) {
147               vv[d+ncomp*impl->indices[i+elemsize*(e+shift)]] +=
148                 uu[i+r->elemsize*(d+(e+shift)*ncomp)];
149             }
150           }
151         }
152       }
153     }
154   }
155   ierr = CeedVectorRestoreArrayRead(u, &uu); CeedChk(ierr);
156   ierr = CeedVectorRestoreArray(v, &vv); CeedChk(ierr);
157   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED)
158     *request = NULL;
159   return 0;
160 }
161 
162 static int CeedElemRestrictionDestroy_Ref(CeedElemRestriction r) {
163   CeedElemRestriction_Ref *impl = r->data;
164   int ierr;
165 
166   ierr = CeedFree(&impl->indices_allocated); CeedChk(ierr);
167   ierr = CeedFree(&r->data); CeedChk(ierr);
168   return 0;
169 }
170 
171 int CeedElemRestrictionCreate_Ref(CeedElemRestriction r,
172                                   CeedMemType mtype,
173                                   CeedCopyMode cmode, const CeedInt *indices) {
174   int ierr;
175   CeedElemRestriction_Ref *impl;
176 
177   if (mtype != CEED_MEM_HOST)
178     return CeedError(r->ceed, 1, "Only MemType = HOST supported");
179   ierr = CeedCalloc(1,&impl); CeedChk(ierr);
180   switch (cmode) {
181   case CEED_COPY_VALUES:
182     ierr = CeedMalloc(r->nelem*r->elemsize, &impl->indices_allocated);
183     CeedChk(ierr);
184     memcpy(impl->indices_allocated, indices,
185            r->nelem * r->elemsize * sizeof(indices[0]));
186     impl->indices = impl->indices_allocated;
187     break;
188   case CEED_OWN_POINTER:
189     impl->indices_allocated = (CeedInt *)indices;
190     impl->indices = impl->indices_allocated;
191     break;
192   case CEED_USE_POINTER:
193     impl->indices = indices;
194   }
195   r->data = impl;
196   r->Apply = CeedElemRestrictionApply_Ref;
197   r->Destroy = CeedElemRestrictionDestroy_Ref;
198   return 0;
199 }
200