xref: /petsc/src/mat/impls/h2opus/math2opussampler.hpp (revision 8fb5bd83c3955fefcf33a54e3bb66920a9fa884b)
1 #include <petscmat.h>
2 #include <h2opus.h>
3 
4 #ifndef __MATH2OPUS_HPP
5 #define __MATH2OPUS_HPP
6 
7 class PetscMatrixSampler : public HMatrixSampler
8 {
9 protected:
10   Mat  A;
11   typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, H2Opus_Real>::type HRealVector;
12   typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, int>::type HIntVector;
13   HIntVector hindexmap;
14   HRealVector hbuffer_in,hbuffer_out;
15 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
16   H2OpusDeviceVector<int> dindexmap;
17   H2OpusDeviceVector<H2Opus_Real> dbuffer_in,dbuffer_out;
18 #endif
19   bool gpusampling;
20   h2opusComputeStream_t stream;
21 
22 private:
23   void Init();
24   void VerifyBuffers(int);
25   void PermuteBuffersIn(int,H2Opus_Real*,H2Opus_Real**,H2Opus_Real*,H2Opus_Real**);
26   void PermuteBuffersOut(int,H2Opus_Real*);
27 
28 public:
29   PetscMatrixSampler();
30   PetscMatrixSampler(Mat);
31   ~PetscMatrixSampler();
32   void SetSamplingMat(Mat);
33   void SetIndexMap(int,int*);
34   void SetGPUSampling(bool);
35   void SetStream(h2opusComputeStream_t);
36   virtual void sample(H2Opus_Real*,H2Opus_Real*,int);
37   Mat GetSamplingMat() { return A; }
38 };
39 
40 void PetscMatrixSampler::Init()
41 {
42   this->A = NULL;
43   this->gpusampling = false;
44   this->stream = NULL;
45 }
46 
47 PetscMatrixSampler::PetscMatrixSampler()
48 {
49   Init();
50 }
51 
52 PetscMatrixSampler::PetscMatrixSampler(Mat A)
53 {
54   Init();
55   SetSamplingMat(A);
56 }
57 
58 void PetscMatrixSampler::SetSamplingMat(Mat A)
59 {
60   PetscMPIInt    size = 1;
61 
62   if (A) PetscCallVoid(MPI_Comm_size(PetscObjectComm((PetscObject)A),&size));
63   if (size > 1) PetscCallVoid(PETSC_ERR_SUP);
64   PetscCallVoid(PetscObjectReference((PetscObject)A));
65   PetscCallVoid(MatDestroy(&this->A));
66   this->A = A;
67 }
68 
69 void PetscMatrixSampler::SetStream(h2opusComputeStream_t stream)
70 {
71   this->stream = stream;
72 }
73 
74 void PetscMatrixSampler::SetIndexMap(int n,int *indexmap)
75 {
76   copyVector(this->hindexmap, indexmap, n, H2OPUS_HWTYPE_CPU);
77 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
78   copyVector(this->dindexmap, indexmap, n, H2OPUS_HWTYPE_CPU);
79 #endif
80 }
81 
82 void PetscMatrixSampler::VerifyBuffers(int nv)
83 {
84   if (this->hindexmap.size()) {
85     size_t n = this->hindexmap.size();
86     if (!this->gpusampling) {
87       if (hbuffer_in.size() < (size_t)n * nv)
88           hbuffer_in.resize(n * nv);
89       if (hbuffer_out.size() < (size_t)n * nv)
90           hbuffer_out.resize(n * nv);
91     } else {
92 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
93       if (dbuffer_in.size() < (size_t)n * nv)
94           dbuffer_in.resize(n * nv);
95       if (dbuffer_out.size() < (size_t)n * nv)
96           dbuffer_out.resize(n * nv);
97 #endif
98     }
99   }
100 }
101 
102 void PetscMatrixSampler::PermuteBuffersIn(int nv, H2Opus_Real *v, H2Opus_Real **w, H2Opus_Real *ov, H2Opus_Real **ow)
103 {
104   *w = v;
105   *ow = ov;
106   VerifyBuffers(nv);
107   if (this->hindexmap.size()) {
108     size_t n = this->hindexmap.size();
109     if (!this->gpusampling) {
110       permute_vectors(v, this->hbuffer_in.data(), n, nv, this->hindexmap.data(), 1, H2OPUS_HWTYPE_CPU,
111                       this->stream);
112       *w = this->hbuffer_in.data();
113       *ow = this->hbuffer_out.data();
114     } else {
115 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
116       permute_vectors(v, this->dbuffer_in.data(), n, nv, this->dindexmap.data(), 1, H2OPUS_HWTYPE_GPU,
117                       this->stream);
118       *w = this->dbuffer_in.data();
119       *ow = this->dbuffer_out.data();
120 #endif
121     }
122   }
123 }
124 
125 void PetscMatrixSampler::PermuteBuffersOut(int nv, H2Opus_Real *v)
126 {
127   VerifyBuffers(nv);
128   if (this->hindexmap.size()) {
129     size_t n = this->hindexmap.size();
130     if (!this->gpusampling) {
131       permute_vectors(this->hbuffer_out.data(), v, n, nv, this->hindexmap.data(), 0, H2OPUS_HWTYPE_CPU,
132                       this->stream);
133     } else {
134 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
135       permute_vectors(this->dbuffer_out.data(), v, n, nv, this->dindexmap.data(), 0, H2OPUS_HWTYPE_GPU,
136                       this->stream);
137 #endif
138     }
139   }
140 }
141 
142 void PetscMatrixSampler::SetGPUSampling(bool gpusampling)
143 {
144   this->gpusampling = gpusampling;
145 }
146 
147 PetscMatrixSampler::~PetscMatrixSampler()
148 {
149   PetscCallVoid(MatDestroy(&A));
150 }
151 
152 void PetscMatrixSampler::sample(H2Opus_Real *x, H2Opus_Real *y, int samples)
153 {
154   MPI_Comm       comm = PetscObjectComm((PetscObject)this->A);
155   Mat            X = NULL,Y = NULL;
156   PetscInt       M,N,m,n;
157   H2Opus_Real    *px,*py;
158 
159   if (!this->A) PetscCallVoid(PETSC_ERR_PLIB);
160   PetscCallVoid(MatGetSize(this->A,&M,&N));
161   PetscCallVoid(MatGetLocalSize(this->A,&m,&n));
162   PetscCallVoid(PetscObjectGetComm((PetscObject)A,&comm));
163   PermuteBuffersIn(samples,x,&px,y,&py);
164   if (!this->gpusampling) {
165     PetscCallVoid(MatCreateDense(comm,n,PETSC_DECIDE,N,samples,px,&X));
166     PetscCallVoid(MatCreateDense(comm,m,PETSC_DECIDE,M,samples,py,&Y));
167   } else {
168 #if defined(PETSC_HAVE_CUDA)
169     PetscCallVoid(MatCreateDenseCUDA(comm,n,PETSC_DECIDE,N,samples,px,&X));
170     PetscCallVoid(MatCreateDenseCUDA(comm,m,PETSC_DECIDE,M,samples,py,&Y));
171 #endif
172   }
173   PetscCallVoid(PetscLogObjectParent((PetscObject)this->A,(PetscObject)X));
174   PetscCallVoid(PetscLogObjectParent((PetscObject)this->A,(PetscObject)Y));
175   PetscCallVoid(MatMatMult(this->A,X,MAT_REUSE_MATRIX,PETSC_DEFAULT,&Y));
176 #if defined(PETSC_HAVE_CUDA)
177   if (this->gpusampling) {
178     const PetscScalar *dummy;
179     PetscCallVoid(MatDenseCUDAGetArrayRead(Y,&dummy));
180     PetscCallVoid(MatDenseCUDARestoreArrayRead(Y,&dummy));
181   }
182 #endif
183   PermuteBuffersOut(samples,y);
184   PetscCallVoid(MatDestroy(&X));
185   PetscCallVoid(MatDestroy(&Y));
186 }
187 
188 #endif
189