xref: /petsc/src/mat/impls/h2opus/math2opussampler.hpp (revision 58d68138c660dfb4e9f5b03334792cd4f2ffd7cc)
1 #include <petscmat.h>
2 #include <h2opus.h>
3 
4 #ifndef __MATH2OPUS_HPP
5 #define __MATH2OPUS_HPP
6 
7 class PetscMatrixSampler : public HMatrixSampler {
8 protected:
9   Mat                                                                    A;
10   typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, H2Opus_Real>::type HRealVector;
11   typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, int>::type         HIntVector;
12   HIntVector                                                             hindexmap;
13   HRealVector                                                            hbuffer_in, hbuffer_out;
14 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
15   H2OpusDeviceVector<int>         dindexmap;
16   H2OpusDeviceVector<H2Opus_Real> dbuffer_in, dbuffer_out;
17 #endif
18   bool                  gpusampling;
19   h2opusComputeStream_t stream;
20 
21 private:
22   void Init();
23   void VerifyBuffers(int);
24   void PermuteBuffersIn(int, H2Opus_Real *, H2Opus_Real **, H2Opus_Real *, H2Opus_Real **);
25   void PermuteBuffersOut(int, H2Opus_Real *);
26 
27 public:
28   PetscMatrixSampler();
29   PetscMatrixSampler(Mat);
30   ~PetscMatrixSampler();
31   void         SetSamplingMat(Mat);
32   void         SetIndexMap(int, int *);
33   void         SetGPUSampling(bool);
34   void         SetStream(h2opusComputeStream_t);
35   virtual void sample(H2Opus_Real *, H2Opus_Real *, int);
36   Mat          GetSamplingMat() {
37              return A;
38   }
39 };
40 
41 void PetscMatrixSampler::Init() {
42   this->A           = NULL;
43   this->gpusampling = false;
44   this->stream      = NULL;
45 }
46 
47 PetscMatrixSampler::PetscMatrixSampler() {
48   Init();
49 }
50 
51 PetscMatrixSampler::PetscMatrixSampler(Mat A) {
52   Init();
53   SetSamplingMat(A);
54 }
55 
56 void PetscMatrixSampler::SetSamplingMat(Mat A) {
57   PetscMPIInt size = 1;
58 
59   if (A) PetscCallVoid(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));
60   if (size > 1) PetscCallVoid(PETSC_ERR_SUP);
61   PetscCallVoid(PetscObjectReference((PetscObject)A));
62   PetscCallVoid(MatDestroy(&this->A));
63   this->A = A;
64 }
65 
66 void PetscMatrixSampler::SetStream(h2opusComputeStream_t stream) {
67   this->stream = stream;
68 }
69 
70 void PetscMatrixSampler::SetIndexMap(int n, int *indexmap) {
71   copyVector(this->hindexmap, indexmap, n, H2OPUS_HWTYPE_CPU);
72 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
73   copyVector(this->dindexmap, indexmap, n, H2OPUS_HWTYPE_CPU);
74 #endif
75 }
76 
77 void PetscMatrixSampler::VerifyBuffers(int nv) {
78   if (this->hindexmap.size()) {
79     size_t n = this->hindexmap.size();
80     if (!this->gpusampling) {
81       if (hbuffer_in.size() < (size_t)n * nv) hbuffer_in.resize(n * nv);
82       if (hbuffer_out.size() < (size_t)n * nv) hbuffer_out.resize(n * nv);
83     } else {
84 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
85       if (dbuffer_in.size() < (size_t)n * nv) dbuffer_in.resize(n * nv);
86       if (dbuffer_out.size() < (size_t)n * nv) dbuffer_out.resize(n * nv);
87 #endif
88     }
89   }
90 }
91 
92 void PetscMatrixSampler::PermuteBuffersIn(int nv, H2Opus_Real *v, H2Opus_Real **w, H2Opus_Real *ov, H2Opus_Real **ow) {
93   *w  = v;
94   *ow = ov;
95   VerifyBuffers(nv);
96   if (this->hindexmap.size()) {
97     size_t n = this->hindexmap.size();
98     if (!this->gpusampling) {
99       permute_vectors(v, this->hbuffer_in.data(), n, nv, this->hindexmap.data(), 1, H2OPUS_HWTYPE_CPU, this->stream);
100       *w  = this->hbuffer_in.data();
101       *ow = this->hbuffer_out.data();
102     } else {
103 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
104       permute_vectors(v, this->dbuffer_in.data(), n, nv, this->dindexmap.data(), 1, H2OPUS_HWTYPE_GPU, this->stream);
105       *w  = this->dbuffer_in.data();
106       *ow = this->dbuffer_out.data();
107 #endif
108     }
109   }
110 }
111 
112 void PetscMatrixSampler::PermuteBuffersOut(int nv, H2Opus_Real *v) {
113   VerifyBuffers(nv);
114   if (this->hindexmap.size()) {
115     size_t n = this->hindexmap.size();
116     if (!this->gpusampling) {
117       permute_vectors(this->hbuffer_out.data(), v, n, nv, this->hindexmap.data(), 0, H2OPUS_HWTYPE_CPU, this->stream);
118     } else {
119 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
120       permute_vectors(this->dbuffer_out.data(), v, n, nv, this->dindexmap.data(), 0, H2OPUS_HWTYPE_GPU, this->stream);
121 #endif
122     }
123   }
124 }
125 
126 void PetscMatrixSampler::SetGPUSampling(bool gpusampling) {
127   this->gpusampling = gpusampling;
128 }
129 
130 PetscMatrixSampler::~PetscMatrixSampler() {
131   PetscCallVoid(MatDestroy(&A));
132 }
133 
134 void PetscMatrixSampler::sample(H2Opus_Real *x, H2Opus_Real *y, int samples) {
135   MPI_Comm     comm = PetscObjectComm((PetscObject)this->A);
136   Mat          X = NULL, Y = NULL;
137   PetscInt     M, N, m, n;
138   H2Opus_Real *px, *py;
139 
140   if (!this->A) PetscCallVoid(PETSC_ERR_PLIB);
141   PetscCallVoid(MatGetSize(this->A, &M, &N));
142   PetscCallVoid(MatGetLocalSize(this->A, &m, &n));
143   PetscCallVoid(PetscObjectGetComm((PetscObject)A, &comm));
144   PermuteBuffersIn(samples, x, &px, y, &py);
145   if (!this->gpusampling) {
146     PetscCallVoid(MatCreateDense(comm, n, PETSC_DECIDE, N, samples, px, &X));
147     PetscCallVoid(MatCreateDense(comm, m, PETSC_DECIDE, M, samples, py, &Y));
148   } else {
149 #if defined(PETSC_HAVE_CUDA)
150     PetscCallVoid(MatCreateDenseCUDA(comm, n, PETSC_DECIDE, N, samples, px, &X));
151     PetscCallVoid(MatCreateDenseCUDA(comm, m, PETSC_DECIDE, M, samples, py, &Y));
152 #endif
153   }
154   PetscCallVoid(PetscLogObjectParent((PetscObject)this->A, (PetscObject)X));
155   PetscCallVoid(PetscLogObjectParent((PetscObject)this->A, (PetscObject)Y));
156   PetscCallVoid(MatMatMult(this->A, X, MAT_REUSE_MATRIX, PETSC_DEFAULT, &Y));
157 #if defined(PETSC_HAVE_CUDA)
158   if (this->gpusampling) {
159     const PetscScalar *dummy;
160     PetscCallVoid(MatDenseCUDAGetArrayRead(Y, &dummy));
161     PetscCallVoid(MatDenseCUDARestoreArrayRead(Y, &dummy));
162   }
163 #endif
164   PermuteBuffersOut(samples, y);
165   PetscCallVoid(MatDestroy(&X));
166   PetscCallVoid(MatDestroy(&Y));
167 }
168 
169 #endif
170