1*ff1e7120SSebastian Grimberg // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*ff1e7120SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*ff1e7120SSebastian Grimberg // 4*ff1e7120SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause 5*ff1e7120SSebastian Grimberg // 6*ff1e7120SSebastian Grimberg // This file is part of CEED: http://github.com/ceed 7*ff1e7120SSebastian Grimberg 8*ff1e7120SSebastian Grimberg #include <ceed.h> 9*ff1e7120SSebastian Grimberg #include <ceed/backend.h> 10*ff1e7120SSebastian Grimberg #include <ceed/jit-tools.h> 11*ff1e7120SSebastian Grimberg #include <cuda.h> 12*ff1e7120SSebastian Grimberg #include <cuda_runtime.h> 13*ff1e7120SSebastian Grimberg #include <stdbool.h> 14*ff1e7120SSebastian Grimberg #include <stddef.h> 15*ff1e7120SSebastian Grimberg #include <string.h> 16*ff1e7120SSebastian Grimberg 17*ff1e7120SSebastian Grimberg #include "../cuda/ceed-cuda-common.h" 18*ff1e7120SSebastian Grimberg #include "../cuda/ceed-cuda-compile.h" 19*ff1e7120SSebastian Grimberg #include "ceed-cuda-ref.h" 20*ff1e7120SSebastian Grimberg 21*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 22*ff1e7120SSebastian Grimberg // Apply restriction 23*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 24*ff1e7120SSebastian Grimberg static int CeedElemRestrictionApply_Cuda(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 25*ff1e7120SSebastian Grimberg CeedElemRestriction_Cuda *impl; 26*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 27*ff1e7120SSebastian Grimberg Ceed ceed; 28*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 29*ff1e7120SSebastian Grimberg Ceed_Cuda *data; 30*ff1e7120SSebastian Grimberg CeedCallBackend(CeedGetData(ceed, &data)); 31*ff1e7120SSebastian Grimberg const CeedInt warp_size = 32; 32*ff1e7120SSebastian Grimberg const CeedInt block_size = warp_size; 33*ff1e7120SSebastian Grimberg const CeedInt num_nodes = impl->num_nodes; 34*ff1e7120SSebastian Grimberg CeedInt num_elem, elem_size; 35*ff1e7120SSebastian Grimberg CeedElemRestrictionGetNumElements(r, &num_elem); 36*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 37*ff1e7120SSebastian Grimberg CUfunction kernel; 38*ff1e7120SSebastian Grimberg 39*ff1e7120SSebastian Grimberg // Get vectors 40*ff1e7120SSebastian Grimberg const CeedScalar *d_u; 41*ff1e7120SSebastian Grimberg CeedScalar *d_v; 42*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 43*ff1e7120SSebastian Grimberg if (t_mode == CEED_TRANSPOSE) { 44*ff1e7120SSebastian Grimberg // Sum into for transpose mode, e-vec to l-vec 45*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 46*ff1e7120SSebastian Grimberg } else { 47*ff1e7120SSebastian Grimberg // Overwrite for notranspose mode, l-vec to e-vec 48*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 49*ff1e7120SSebastian Grimberg } 50*ff1e7120SSebastian Grimberg 51*ff1e7120SSebastian Grimberg // Restrict 52*ff1e7120SSebastian Grimberg if (t_mode == CEED_NOTRANSPOSE) { 53*ff1e7120SSebastian Grimberg // L-vector -> E-vector 54*ff1e7120SSebastian Grimberg if (impl->d_ind) { 55*ff1e7120SSebastian Grimberg // -- Offsets provided 56*ff1e7120SSebastian Grimberg kernel = impl->OffsetNoTranspose; 57*ff1e7120SSebastian Grimberg void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 58*ff1e7120SSebastian Grimberg CeedInt block_size = elem_size < 1024 ? (elem_size > 32 ? elem_size : 32) : 1024; 59*ff1e7120SSebastian Grimberg CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 60*ff1e7120SSebastian Grimberg } else { 61*ff1e7120SSebastian Grimberg // -- Strided restriction 62*ff1e7120SSebastian Grimberg kernel = impl->StridedNoTranspose; 63*ff1e7120SSebastian Grimberg void *args[] = {&num_elem, &d_u, &d_v}; 64*ff1e7120SSebastian Grimberg CeedInt block_size = elem_size < 1024 ? (elem_size > 32 ? elem_size : 32) : 1024; 65*ff1e7120SSebastian Grimberg CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 66*ff1e7120SSebastian Grimberg } 67*ff1e7120SSebastian Grimberg } else { 68*ff1e7120SSebastian Grimberg // E-vector -> L-vector 69*ff1e7120SSebastian Grimberg if (impl->d_ind) { 70*ff1e7120SSebastian Grimberg // -- Offsets provided 71*ff1e7120SSebastian Grimberg kernel = impl->OffsetTranspose; 72*ff1e7120SSebastian Grimberg void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 73*ff1e7120SSebastian Grimberg CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 74*ff1e7120SSebastian Grimberg } else { 75*ff1e7120SSebastian Grimberg // -- Strided restriction 76*ff1e7120SSebastian Grimberg kernel = impl->StridedTranspose; 77*ff1e7120SSebastian Grimberg void *args[] = {&num_elem, &d_u, &d_v}; 78*ff1e7120SSebastian Grimberg CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 79*ff1e7120SSebastian Grimberg } 80*ff1e7120SSebastian Grimberg } 81*ff1e7120SSebastian Grimberg 82*ff1e7120SSebastian Grimberg if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 83*ff1e7120SSebastian Grimberg 84*ff1e7120SSebastian Grimberg // Restore arrays 85*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 86*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 87*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 88*ff1e7120SSebastian Grimberg } 89*ff1e7120SSebastian Grimberg 90*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 91*ff1e7120SSebastian Grimberg // Get offsets 92*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 93*ff1e7120SSebastian Grimberg static int CeedElemRestrictionGetOffsets_Cuda(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) { 94*ff1e7120SSebastian Grimberg CeedElemRestriction_Cuda *impl; 95*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 96*ff1e7120SSebastian Grimberg 97*ff1e7120SSebastian Grimberg switch (mem_type) { 98*ff1e7120SSebastian Grimberg case CEED_MEM_HOST: 99*ff1e7120SSebastian Grimberg *offsets = impl->h_ind; 100*ff1e7120SSebastian Grimberg break; 101*ff1e7120SSebastian Grimberg case CEED_MEM_DEVICE: 102*ff1e7120SSebastian Grimberg *offsets = impl->d_ind; 103*ff1e7120SSebastian Grimberg break; 104*ff1e7120SSebastian Grimberg } 105*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 106*ff1e7120SSebastian Grimberg } 107*ff1e7120SSebastian Grimberg 108*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 109*ff1e7120SSebastian Grimberg // Destroy restriction 110*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 111*ff1e7120SSebastian Grimberg static int CeedElemRestrictionDestroy_Cuda(CeedElemRestriction r) { 112*ff1e7120SSebastian Grimberg CeedElemRestriction_Cuda *impl; 113*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 114*ff1e7120SSebastian Grimberg 115*ff1e7120SSebastian Grimberg Ceed ceed; 116*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 117*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cuModuleUnload(impl->module)); 118*ff1e7120SSebastian Grimberg CeedCallBackend(CeedFree(&impl->h_ind_allocated)); 119*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaFree(impl->d_ind_allocated)); 120*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaFree(impl->d_t_offsets)); 121*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaFree(impl->d_t_indices)); 122*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaFree(impl->d_l_vec_indices)); 123*ff1e7120SSebastian Grimberg CeedCallBackend(CeedFree(&impl)); 124*ff1e7120SSebastian Grimberg 125*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 126*ff1e7120SSebastian Grimberg } 127*ff1e7120SSebastian Grimberg 128*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 129*ff1e7120SSebastian Grimberg // Create transpose offsets and indices 130*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 131*ff1e7120SSebastian Grimberg static int CeedElemRestrictionOffset_Cuda(const CeedElemRestriction r, const CeedInt *indices) { 132*ff1e7120SSebastian Grimberg Ceed ceed; 133*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 134*ff1e7120SSebastian Grimberg CeedElemRestriction_Cuda *impl; 135*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 136*ff1e7120SSebastian Grimberg CeedSize l_size; 137*ff1e7120SSebastian Grimberg CeedInt num_elem, elem_size, num_comp; 138*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 139*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 140*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size)); 141*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 142*ff1e7120SSebastian Grimberg 143*ff1e7120SSebastian Grimberg // Count num_nodes 144*ff1e7120SSebastian Grimberg bool *is_node; 145*ff1e7120SSebastian Grimberg CeedCallBackend(CeedCalloc(l_size, &is_node)); 146*ff1e7120SSebastian Grimberg const CeedInt size_indices = num_elem * elem_size; 147*ff1e7120SSebastian Grimberg for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 148*ff1e7120SSebastian Grimberg CeedInt num_nodes = 0; 149*ff1e7120SSebastian Grimberg for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 150*ff1e7120SSebastian Grimberg impl->num_nodes = num_nodes; 151*ff1e7120SSebastian Grimberg 152*ff1e7120SSebastian Grimberg // L-vector offsets array 153*ff1e7120SSebastian Grimberg CeedInt *ind_to_offset, *l_vec_indices; 154*ff1e7120SSebastian Grimberg CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 155*ff1e7120SSebastian Grimberg CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 156*ff1e7120SSebastian Grimberg CeedInt j = 0; 157*ff1e7120SSebastian Grimberg for (CeedInt i = 0; i < l_size; i++) { 158*ff1e7120SSebastian Grimberg if (is_node[i]) { 159*ff1e7120SSebastian Grimberg l_vec_indices[j] = i; 160*ff1e7120SSebastian Grimberg ind_to_offset[i] = j++; 161*ff1e7120SSebastian Grimberg } 162*ff1e7120SSebastian Grimberg } 163*ff1e7120SSebastian Grimberg CeedCallBackend(CeedFree(&is_node)); 164*ff1e7120SSebastian Grimberg 165*ff1e7120SSebastian Grimberg // Compute transpose offsets and indices 166*ff1e7120SSebastian Grimberg const CeedInt size_offsets = num_nodes + 1; 167*ff1e7120SSebastian Grimberg CeedInt *t_offsets; 168*ff1e7120SSebastian Grimberg CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 169*ff1e7120SSebastian Grimberg CeedInt *t_indices; 170*ff1e7120SSebastian Grimberg CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 171*ff1e7120SSebastian Grimberg // Count node multiplicity 172*ff1e7120SSebastian Grimberg for (CeedInt e = 0; e < num_elem; ++e) { 173*ff1e7120SSebastian Grimberg for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 174*ff1e7120SSebastian Grimberg } 175*ff1e7120SSebastian Grimberg // Convert to running sum 176*ff1e7120SSebastian Grimberg for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 177*ff1e7120SSebastian Grimberg // List all E-vec indices associated with L-vec node 178*ff1e7120SSebastian Grimberg for (CeedInt e = 0; e < num_elem; ++e) { 179*ff1e7120SSebastian Grimberg for (CeedInt i = 0; i < elem_size; ++i) { 180*ff1e7120SSebastian Grimberg const CeedInt lid = elem_size * e + i; 181*ff1e7120SSebastian Grimberg const CeedInt gid = indices[lid]; 182*ff1e7120SSebastian Grimberg t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 183*ff1e7120SSebastian Grimberg } 184*ff1e7120SSebastian Grimberg } 185*ff1e7120SSebastian Grimberg // Reset running sum 186*ff1e7120SSebastian Grimberg for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 187*ff1e7120SSebastian Grimberg t_offsets[0] = 0; 188*ff1e7120SSebastian Grimberg 189*ff1e7120SSebastian Grimberg // Copy data to device 190*ff1e7120SSebastian Grimberg // -- L-vector indices 191*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt))); 192*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), cudaMemcpyHostToDevice)); 193*ff1e7120SSebastian Grimberg // -- Transpose offsets 194*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt))); 195*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), cudaMemcpyHostToDevice)); 196*ff1e7120SSebastian Grimberg // -- Transpose indices 197*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt))); 198*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), cudaMemcpyHostToDevice)); 199*ff1e7120SSebastian Grimberg 200*ff1e7120SSebastian Grimberg // Cleanup 201*ff1e7120SSebastian Grimberg CeedCallBackend(CeedFree(&ind_to_offset)); 202*ff1e7120SSebastian Grimberg CeedCallBackend(CeedFree(&l_vec_indices)); 203*ff1e7120SSebastian Grimberg CeedCallBackend(CeedFree(&t_offsets)); 204*ff1e7120SSebastian Grimberg CeedCallBackend(CeedFree(&t_indices)); 205*ff1e7120SSebastian Grimberg 206*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 207*ff1e7120SSebastian Grimberg } 208*ff1e7120SSebastian Grimberg 209*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 210*ff1e7120SSebastian Grimberg // Create restriction 211*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 212*ff1e7120SSebastian Grimberg int CeedElemRestrictionCreate_Cuda(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, CeedElemRestriction r) { 213*ff1e7120SSebastian Grimberg Ceed ceed; 214*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 215*ff1e7120SSebastian Grimberg CeedElemRestriction_Cuda *impl; 216*ff1e7120SSebastian Grimberg CeedCallBackend(CeedCalloc(1, &impl)); 217*ff1e7120SSebastian Grimberg CeedInt num_elem, num_comp, elem_size; 218*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 219*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 220*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 221*ff1e7120SSebastian Grimberg CeedInt size = num_elem * elem_size; 222*ff1e7120SSebastian Grimberg CeedInt strides[3] = {1, size, elem_size}; 223*ff1e7120SSebastian Grimberg CeedInt comp_stride = 1; 224*ff1e7120SSebastian Grimberg 225*ff1e7120SSebastian Grimberg // Stride data 226*ff1e7120SSebastian Grimberg bool is_strided; 227*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided)); 228*ff1e7120SSebastian Grimberg if (is_strided) { 229*ff1e7120SSebastian Grimberg bool has_backend_strides; 230*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides)); 231*ff1e7120SSebastian Grimberg if (!has_backend_strides) { 232*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides)); 233*ff1e7120SSebastian Grimberg } 234*ff1e7120SSebastian Grimberg } else { 235*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride)); 236*ff1e7120SSebastian Grimberg } 237*ff1e7120SSebastian Grimberg 238*ff1e7120SSebastian Grimberg impl->h_ind = NULL; 239*ff1e7120SSebastian Grimberg impl->h_ind_allocated = NULL; 240*ff1e7120SSebastian Grimberg impl->d_ind = NULL; 241*ff1e7120SSebastian Grimberg impl->d_ind_allocated = NULL; 242*ff1e7120SSebastian Grimberg impl->d_t_indices = NULL; 243*ff1e7120SSebastian Grimberg impl->d_t_offsets = NULL; 244*ff1e7120SSebastian Grimberg impl->num_nodes = size; 245*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionSetData(r, impl)); 246*ff1e7120SSebastian Grimberg CeedInt layout[3] = {1, elem_size * num_elem, elem_size}; 247*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionSetELayout(r, layout)); 248*ff1e7120SSebastian Grimberg 249*ff1e7120SSebastian Grimberg // Set up device indices/offset arrays 250*ff1e7120SSebastian Grimberg switch (mem_type) { 251*ff1e7120SSebastian Grimberg case CEED_MEM_HOST: { 252*ff1e7120SSebastian Grimberg switch (copy_mode) { 253*ff1e7120SSebastian Grimberg case CEED_OWN_POINTER: 254*ff1e7120SSebastian Grimberg impl->h_ind_allocated = (CeedInt *)indices; 255*ff1e7120SSebastian Grimberg impl->h_ind = (CeedInt *)indices; 256*ff1e7120SSebastian Grimberg break; 257*ff1e7120SSebastian Grimberg case CEED_USE_POINTER: 258*ff1e7120SSebastian Grimberg impl->h_ind = (CeedInt *)indices; 259*ff1e7120SSebastian Grimberg break; 260*ff1e7120SSebastian Grimberg case CEED_COPY_VALUES: 261*ff1e7120SSebastian Grimberg if (indices != NULL) { 262*ff1e7120SSebastian Grimberg CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 263*ff1e7120SSebastian Grimberg memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt)); 264*ff1e7120SSebastian Grimberg impl->h_ind = impl->h_ind_allocated; 265*ff1e7120SSebastian Grimberg } 266*ff1e7120SSebastian Grimberg break; 267*ff1e7120SSebastian Grimberg } 268*ff1e7120SSebastian Grimberg if (indices != NULL) { 269*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 270*ff1e7120SSebastian Grimberg impl->d_ind_allocated = impl->d_ind; // We own the device memory 271*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), cudaMemcpyHostToDevice)); 272*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionOffset_Cuda(r, indices)); 273*ff1e7120SSebastian Grimberg } 274*ff1e7120SSebastian Grimberg break; 275*ff1e7120SSebastian Grimberg } 276*ff1e7120SSebastian Grimberg case CEED_MEM_DEVICE: { 277*ff1e7120SSebastian Grimberg switch (copy_mode) { 278*ff1e7120SSebastian Grimberg case CEED_COPY_VALUES: 279*ff1e7120SSebastian Grimberg if (indices != NULL) { 280*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 281*ff1e7120SSebastian Grimberg impl->d_ind_allocated = impl->d_ind; // We own the device memory 282*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), cudaMemcpyDeviceToDevice)); 283*ff1e7120SSebastian Grimberg } 284*ff1e7120SSebastian Grimberg break; 285*ff1e7120SSebastian Grimberg case CEED_OWN_POINTER: 286*ff1e7120SSebastian Grimberg impl->d_ind = (CeedInt *)indices; 287*ff1e7120SSebastian Grimberg impl->d_ind_allocated = impl->d_ind; 288*ff1e7120SSebastian Grimberg break; 289*ff1e7120SSebastian Grimberg case CEED_USE_POINTER: 290*ff1e7120SSebastian Grimberg impl->d_ind = (CeedInt *)indices; 291*ff1e7120SSebastian Grimberg } 292*ff1e7120SSebastian Grimberg if (indices != NULL) { 293*ff1e7120SSebastian Grimberg CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 294*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMemcpy(impl->h_ind_allocated, impl->d_ind, elem_size * num_elem * sizeof(CeedInt), cudaMemcpyDeviceToHost)); 295*ff1e7120SSebastian Grimberg impl->h_ind = impl->h_ind_allocated; 296*ff1e7120SSebastian Grimberg CeedCallBackend(CeedElemRestrictionOffset_Cuda(r, indices)); 297*ff1e7120SSebastian Grimberg } 298*ff1e7120SSebastian Grimberg break; 299*ff1e7120SSebastian Grimberg } 300*ff1e7120SSebastian Grimberg // LCOV_EXCL_START 301*ff1e7120SSebastian Grimberg default: 302*ff1e7120SSebastian Grimberg return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported"); 303*ff1e7120SSebastian Grimberg // LCOV_EXCL_STOP 304*ff1e7120SSebastian Grimberg } 305*ff1e7120SSebastian Grimberg 306*ff1e7120SSebastian Grimberg // Compile CUDA kernels 307*ff1e7120SSebastian Grimberg CeedInt num_nodes = impl->num_nodes; 308*ff1e7120SSebastian Grimberg char *restriction_kernel_path, *restriction_kernel_source; 309*ff1e7120SSebastian Grimberg CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-restriction.h", &restriction_kernel_path)); 310*ff1e7120SSebastian Grimberg CeedDebug256(ceed, 2, "----- Loading Restriction Kernel Source -----\n"); 311*ff1e7120SSebastian Grimberg CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 312*ff1e7120SSebastian Grimberg CeedDebug256(ceed, 2, "----- Loading Restriction Kernel Source Complete! -----\n"); 313*ff1e7120SSebastian Grimberg CeedCallBackend(CeedCompile_Cuda(ceed, restriction_kernel_source, &impl->module, 8, "RESTR_ELEM_SIZE", elem_size, "RESTR_NUM_ELEM", num_elem, 314*ff1e7120SSebastian Grimberg "RESTR_NUM_COMP", num_comp, "RESTR_NUM_NODES", num_nodes, "RESTR_COMP_STRIDE", comp_stride, "RESTR_STRIDE_NODES", 315*ff1e7120SSebastian Grimberg strides[0], "RESTR_STRIDE_COMP", strides[1], "RESTR_STRIDE_ELEM", strides[2])); 316*ff1e7120SSebastian Grimberg CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose)); 317*ff1e7120SSebastian Grimberg CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose)); 318*ff1e7120SSebastian Grimberg CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose)); 319*ff1e7120SSebastian Grimberg CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose)); 320*ff1e7120SSebastian Grimberg CeedCallBackend(CeedFree(&restriction_kernel_path)); 321*ff1e7120SSebastian Grimberg CeedCallBackend(CeedFree(&restriction_kernel_source)); 322*ff1e7120SSebastian Grimberg 323*ff1e7120SSebastian Grimberg // Register backend functions 324*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Cuda)); 325*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Cuda)); 326*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Cuda)); 327*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Cuda)); 328*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 329*ff1e7120SSebastian Grimberg } 330*ff1e7120SSebastian Grimberg 331*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 332