1 #pragma once
2
3 #include <petscmat.h>
4 #include <h2opus.h>
5
6 class PetscMatrixSampler : public HMatrixSampler {
7 protected:
8 Mat A;
9 typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, H2Opus_Real>::type HRealVector;
10 typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, int>::type HIntVector;
11 HIntVector hindexmap;
12 HRealVector hbuffer_in, hbuffer_out;
13 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
14 H2OpusDeviceVector<int> dindexmap;
15 H2OpusDeviceVector<H2Opus_Real> dbuffer_in, dbuffer_out;
16 #endif
17 bool gpusampling;
18 h2opusComputeStream_t stream;
19
20 private:
21 void Init();
22 void VerifyBuffers(int);
23 void PermuteBuffersIn(int, H2Opus_Real *, H2Opus_Real **, H2Opus_Real *, H2Opus_Real **);
24 void PermuteBuffersOut(int, H2Opus_Real *);
25
26 public:
27 PetscMatrixSampler();
28 PetscMatrixSampler(Mat);
29 ~PetscMatrixSampler();
30 void SetSamplingMat(Mat);
31 void SetIndexMap(int, int *);
32 void SetGPUSampling(bool);
33 void SetStream(h2opusComputeStream_t);
34 virtual void sample(H2Opus_Real *, H2Opus_Real *, int);
GetSamplingMat()35 Mat GetSamplingMat() { return A; }
36 };
37
Init()38 void PetscMatrixSampler::Init()
39 {
40 this->A = NULL;
41 this->gpusampling = false;
42 this->stream = NULL;
43 }
44
PetscMatrixSampler()45 PetscMatrixSampler::PetscMatrixSampler()
46 {
47 Init();
48 }
49
PetscMatrixSampler(Mat A)50 PetscMatrixSampler::PetscMatrixSampler(Mat A)
51 {
52 Init();
53 SetSamplingMat(A);
54 }
55
SetSamplingMat(Mat A)56 void PetscMatrixSampler::SetSamplingMat(Mat A)
57 {
58 PetscMPIInt size = 1;
59
60 if (A) PetscCallVoid(static_cast<PetscErrorCode>(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size)));
61 if (size > 1) PetscCallVoid(PETSC_ERR_SUP);
62 PetscCallVoid(PetscObjectReference((PetscObject)A));
63 PetscCallVoid(MatDestroy(&this->A));
64 this->A = A;
65 }
66
SetStream(h2opusComputeStream_t stream)67 void PetscMatrixSampler::SetStream(h2opusComputeStream_t stream)
68 {
69 this->stream = stream;
70 }
71
SetIndexMap(int n,int * indexmap)72 void PetscMatrixSampler::SetIndexMap(int n, int *indexmap)
73 {
74 copyVector(this->hindexmap, indexmap, n, H2OPUS_HWTYPE_CPU);
75 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
76 copyVector(this->dindexmap, indexmap, n, H2OPUS_HWTYPE_CPU);
77 #endif
78 }
79
VerifyBuffers(int nv)80 void PetscMatrixSampler::VerifyBuffers(int nv)
81 {
82 if (this->hindexmap.size()) {
83 size_t n = this->hindexmap.size();
84 if (!this->gpusampling) {
85 if (hbuffer_in.size() < (size_t)n * nv) hbuffer_in.resize(n * nv);
86 if (hbuffer_out.size() < (size_t)n * nv) hbuffer_out.resize(n * nv);
87 } else {
88 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
89 if (dbuffer_in.size() < (size_t)n * nv) dbuffer_in.resize(n * nv);
90 if (dbuffer_out.size() < (size_t)n * nv) dbuffer_out.resize(n * nv);
91 #endif
92 }
93 }
94 }
95
PermuteBuffersIn(int nv,H2Opus_Real * v,H2Opus_Real ** w,H2Opus_Real * ov,H2Opus_Real ** ow)96 void PetscMatrixSampler::PermuteBuffersIn(int nv, H2Opus_Real *v, H2Opus_Real **w, H2Opus_Real *ov, H2Opus_Real **ow)
97 {
98 *w = v;
99 *ow = ov;
100 VerifyBuffers(nv);
101 if (this->hindexmap.size()) {
102 size_t n = this->hindexmap.size();
103 if (!this->gpusampling) {
104 permute_vectors(v, this->hbuffer_in.data(), n, nv, this->hindexmap.data(), 1, H2OPUS_HWTYPE_CPU, this->stream);
105 *w = this->hbuffer_in.data();
106 *ow = this->hbuffer_out.data();
107 } else {
108 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
109 permute_vectors(v, this->dbuffer_in.data(), n, nv, this->dindexmap.data(), 1, H2OPUS_HWTYPE_GPU, this->stream);
110 *w = this->dbuffer_in.data();
111 *ow = this->dbuffer_out.data();
112 #endif
113 }
114 }
115 }
116
PermuteBuffersOut(int nv,H2Opus_Real * v)117 void PetscMatrixSampler::PermuteBuffersOut(int nv, H2Opus_Real *v)
118 {
119 VerifyBuffers(nv);
120 if (this->hindexmap.size()) {
121 size_t n = this->hindexmap.size();
122 if (!this->gpusampling) {
123 permute_vectors(this->hbuffer_out.data(), v, n, nv, this->hindexmap.data(), 0, H2OPUS_HWTYPE_CPU, this->stream);
124 } else {
125 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
126 permute_vectors(this->dbuffer_out.data(), v, n, nv, this->dindexmap.data(), 0, H2OPUS_HWTYPE_GPU, this->stream);
127 #endif
128 }
129 }
130 }
131
SetGPUSampling(bool gpusampling)132 void PetscMatrixSampler::SetGPUSampling(bool gpusampling)
133 {
134 this->gpusampling = gpusampling;
135 }
136
~PetscMatrixSampler()137 PetscMatrixSampler::~PetscMatrixSampler()
138 {
139 PetscCallVoid(MatDestroy(&A));
140 }
141
sample(H2Opus_Real * x,H2Opus_Real * y,int samples)142 void PetscMatrixSampler::sample(H2Opus_Real *x, H2Opus_Real *y, int samples)
143 {
144 MPI_Comm comm = PetscObjectComm((PetscObject)this->A);
145 Mat X = NULL, Y = NULL;
146 PetscInt M, N, m, n;
147 H2Opus_Real *px, *py;
148 VecType vtype;
149
150 if (!this->A) PetscCallVoid(PETSC_ERR_PLIB);
151 PetscCallVoid(MatGetSize(this->A, &M, &N));
152 PetscCallVoid(MatGetVecType(this->A, &vtype));
153 PetscCallVoid(MatGetLocalSize(this->A, &m, &n));
154 PetscCallVoid(PetscObjectGetComm((PetscObject)A, &comm));
155 PermuteBuffersIn(samples, x, &px, y, &py);
156 if (!this->gpusampling) {
157 PetscCallVoid(MatCreateDense(comm, n, PETSC_DECIDE, N, samples, px, &X));
158 PetscCallVoid(MatCreateDense(comm, m, PETSC_DECIDE, M, samples, py, &Y));
159 PetscCallVoid(MatSetVecType(X, vtype));
160 PetscCallVoid(MatSetVecType(Y, vtype));
161
162 } else {
163 #if defined(PETSC_HAVE_CUDA)
164 PetscCallVoid(MatCreateDenseCUDA(comm, n, PETSC_DECIDE, N, samples, px, &X));
165 PetscCallVoid(MatCreateDenseCUDA(comm, m, PETSC_DECIDE, M, samples, py, &Y));
166 PetscCallVoid(MatSetVecType(X, vtype));
167 PetscCallVoid(MatSetVecType(Y, vtype));
168 #endif
169 }
170 PetscCallVoid(MatMatMult(this->A, X, MAT_REUSE_MATRIX, PETSC_DETERMINE, &Y));
171 #if defined(PETSC_HAVE_CUDA)
172 if (this->gpusampling) {
173 const PetscScalar *dummy;
174 PetscCallVoid(MatDenseCUDAGetArrayRead(Y, &dummy));
175 PetscCallVoid(MatDenseCUDARestoreArrayRead(Y, &dummy));
176 }
177 #endif
178 PermuteBuffersOut(samples, y);
179 PetscCallVoid(MatDestroy(&X));
180 PetscCallVoid(MatDestroy(&Y));
181 }
182