xref: /petsc/src/mat/impls/h2opus/math2opussampler.hpp (revision df4cd43f92eaa320656440c40edb1046daee8f75)
1 #include <petscmat.h>
2 #include <h2opus.h>
3 
4 #ifndef __MATH2OPUS_HPP
5   #define __MATH2OPUS_HPP
6 
7 class PetscMatrixSampler : public HMatrixSampler {
8 protected:
9   Mat                                                                    A;
10   typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, H2Opus_Real>::type HRealVector;
11   typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, int>::type         HIntVector;
12   HIntVector                                                             hindexmap;
13   HRealVector                                                            hbuffer_in, hbuffer_out;
14   #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
15   H2OpusDeviceVector<int>         dindexmap;
16   H2OpusDeviceVector<H2Opus_Real> dbuffer_in, dbuffer_out;
17   #endif
18   bool                  gpusampling;
19   h2opusComputeStream_t stream;
20 
21 private:
22   void Init();
23   void VerifyBuffers(int);
24   void PermuteBuffersIn(int, H2Opus_Real *, H2Opus_Real **, H2Opus_Real *, H2Opus_Real **);
25   void PermuteBuffersOut(int, H2Opus_Real *);
26 
27 public:
28   PetscMatrixSampler();
29   PetscMatrixSampler(Mat);
30   ~PetscMatrixSampler();
31   void         SetSamplingMat(Mat);
32   void         SetIndexMap(int, int *);
33   void         SetGPUSampling(bool);
34   void         SetStream(h2opusComputeStream_t);
35   virtual void sample(H2Opus_Real *, H2Opus_Real *, int);
36   Mat          GetSamplingMat() { return A; }
37 };
38 
39 void PetscMatrixSampler::Init()
40 {
41   this->A           = NULL;
42   this->gpusampling = false;
43   this->stream      = NULL;
44 }
45 
46 PetscMatrixSampler::PetscMatrixSampler()
47 {
48   Init();
49 }
50 
51 PetscMatrixSampler::PetscMatrixSampler(Mat A)
52 {
53   Init();
54   SetSamplingMat(A);
55 }
56 
57 void PetscMatrixSampler::SetSamplingMat(Mat A)
58 {
59   PetscMPIInt size = 1;
60 
61   if (A) PetscCallVoid(static_cast<PetscErrorCode>(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size)));
62   if (size > 1) PetscCallVoid(PETSC_ERR_SUP);
63   PetscCallVoid(PetscObjectReference((PetscObject)A));
64   PetscCallVoid(MatDestroy(&this->A));
65   this->A = A;
66 }
67 
68 void PetscMatrixSampler::SetStream(h2opusComputeStream_t stream)
69 {
70   this->stream = stream;
71 }
72 
73 void PetscMatrixSampler::SetIndexMap(int n, int *indexmap)
74 {
75   copyVector(this->hindexmap, indexmap, n, H2OPUS_HWTYPE_CPU);
76   #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
77   copyVector(this->dindexmap, indexmap, n, H2OPUS_HWTYPE_CPU);
78   #endif
79 }
80 
81 void PetscMatrixSampler::VerifyBuffers(int nv)
82 {
83   if (this->hindexmap.size()) {
84     size_t n = this->hindexmap.size();
85     if (!this->gpusampling) {
86       if (hbuffer_in.size() < (size_t)n * nv) hbuffer_in.resize(n * nv);
87       if (hbuffer_out.size() < (size_t)n * nv) hbuffer_out.resize(n * nv);
88     } else {
89   #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
90       if (dbuffer_in.size() < (size_t)n * nv) dbuffer_in.resize(n * nv);
91       if (dbuffer_out.size() < (size_t)n * nv) dbuffer_out.resize(n * nv);
92   #endif
93     }
94   }
95 }
96 
97 void PetscMatrixSampler::PermuteBuffersIn(int nv, H2Opus_Real *v, H2Opus_Real **w, H2Opus_Real *ov, H2Opus_Real **ow)
98 {
99   *w  = v;
100   *ow = ov;
101   VerifyBuffers(nv);
102   if (this->hindexmap.size()) {
103     size_t n = this->hindexmap.size();
104     if (!this->gpusampling) {
105       permute_vectors(v, this->hbuffer_in.data(), n, nv, this->hindexmap.data(), 1, H2OPUS_HWTYPE_CPU, this->stream);
106       *w  = this->hbuffer_in.data();
107       *ow = this->hbuffer_out.data();
108     } else {
109   #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
110       permute_vectors(v, this->dbuffer_in.data(), n, nv, this->dindexmap.data(), 1, H2OPUS_HWTYPE_GPU, this->stream);
111       *w  = this->dbuffer_in.data();
112       *ow = this->dbuffer_out.data();
113   #endif
114     }
115   }
116 }
117 
118 void PetscMatrixSampler::PermuteBuffersOut(int nv, H2Opus_Real *v)
119 {
120   VerifyBuffers(nv);
121   if (this->hindexmap.size()) {
122     size_t n = this->hindexmap.size();
123     if (!this->gpusampling) {
124       permute_vectors(this->hbuffer_out.data(), v, n, nv, this->hindexmap.data(), 0, H2OPUS_HWTYPE_CPU, this->stream);
125     } else {
126   #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
127       permute_vectors(this->dbuffer_out.data(), v, n, nv, this->dindexmap.data(), 0, H2OPUS_HWTYPE_GPU, this->stream);
128   #endif
129     }
130   }
131 }
132 
133 void PetscMatrixSampler::SetGPUSampling(bool gpusampling)
134 {
135   this->gpusampling = gpusampling;
136 }
137 
138 PetscMatrixSampler::~PetscMatrixSampler()
139 {
140   PetscCallVoid(MatDestroy(&A));
141 }
142 
143 void PetscMatrixSampler::sample(H2Opus_Real *x, H2Opus_Real *y, int samples)
144 {
145   MPI_Comm     comm = PetscObjectComm((PetscObject)this->A);
146   Mat          X = NULL, Y = NULL;
147   PetscInt     M, N, m, n;
148   H2Opus_Real *px, *py;
149 
150   if (!this->A) PetscCallVoid(PETSC_ERR_PLIB);
151   PetscCallVoid(MatGetSize(this->A, &M, &N));
152   PetscCallVoid(MatGetLocalSize(this->A, &m, &n));
153   PetscCallVoid(PetscObjectGetComm((PetscObject)A, &comm));
154   PermuteBuffersIn(samples, x, &px, y, &py);
155   if (!this->gpusampling) {
156     PetscCallVoid(MatCreateDense(comm, n, PETSC_DECIDE, N, samples, px, &X));
157     PetscCallVoid(MatCreateDense(comm, m, PETSC_DECIDE, M, samples, py, &Y));
158   } else {
159   #if defined(PETSC_HAVE_CUDA)
160     PetscCallVoid(MatCreateDenseCUDA(comm, n, PETSC_DECIDE, N, samples, px, &X));
161     PetscCallVoid(MatCreateDenseCUDA(comm, m, PETSC_DECIDE, M, samples, py, &Y));
162   #endif
163   }
164   PetscCallVoid(MatMatMult(this->A, X, MAT_REUSE_MATRIX, PETSC_DEFAULT, &Y));
165   #if defined(PETSC_HAVE_CUDA)
166   if (this->gpusampling) {
167     const PetscScalar *dummy;
168     PetscCallVoid(MatDenseCUDAGetArrayRead(Y, &dummy));
169     PetscCallVoid(MatDenseCUDARestoreArrayRead(Y, &dummy));
170   }
171   #endif
172   PermuteBuffersOut(samples, y);
173   PetscCallVoid(MatDestroy(&X));
174   PetscCallVoid(MatDestroy(&Y));
175 }
176 
177 #endif
178