1 #pragma once 2 3 #include <petscmat.h> 4 #include <h2opus.h> 5 6 class PetscMatrixSampler : public HMatrixSampler { 7 protected: 8 Mat A; 9 typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, H2Opus_Real>::type HRealVector; 10 typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, int>::type HIntVector; 11 HIntVector hindexmap; 12 HRealVector hbuffer_in, hbuffer_out; 13 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 14 H2OpusDeviceVector<int> dindexmap; 15 H2OpusDeviceVector<H2Opus_Real> dbuffer_in, dbuffer_out; 16 #endif 17 bool gpusampling; 18 h2opusComputeStream_t stream; 19 20 private: 21 void Init(); 22 void VerifyBuffers(int); 23 void PermuteBuffersIn(int, H2Opus_Real *, H2Opus_Real **, H2Opus_Real *, H2Opus_Real **); 24 void PermuteBuffersOut(int, H2Opus_Real *); 25 26 public: 27 PetscMatrixSampler(); 28 PetscMatrixSampler(Mat); 29 ~PetscMatrixSampler(); 30 void SetSamplingMat(Mat); 31 void SetIndexMap(int, int *); 32 void SetGPUSampling(bool); 33 void SetStream(h2opusComputeStream_t); 34 virtual void sample(H2Opus_Real *, H2Opus_Real *, int); 35 Mat GetSamplingMat() { return A; } 36 }; 37 38 void PetscMatrixSampler::Init() 39 { 40 this->A = NULL; 41 this->gpusampling = false; 42 this->stream = NULL; 43 } 44 45 PetscMatrixSampler::PetscMatrixSampler() 46 { 47 Init(); 48 } 49 50 PetscMatrixSampler::PetscMatrixSampler(Mat A) 51 { 52 Init(); 53 SetSamplingMat(A); 54 } 55 56 void PetscMatrixSampler::SetSamplingMat(Mat A) 57 { 58 PetscMPIInt size = 1; 59 60 if (A) PetscCallVoid(static_cast<PetscErrorCode>(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size))); 61 if (size > 1) PetscCallVoid(PETSC_ERR_SUP); 62 PetscCallVoid(PetscObjectReference((PetscObject)A)); 63 PetscCallVoid(MatDestroy(&this->A)); 64 this->A = A; 65 } 66 67 void PetscMatrixSampler::SetStream(h2opusComputeStream_t stream) 68 { 69 this->stream = stream; 70 } 71 72 void PetscMatrixSampler::SetIndexMap(int n, int *indexmap) 73 { 74 copyVector(this->hindexmap, indexmap, n, H2OPUS_HWTYPE_CPU); 75 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 76 copyVector(this->dindexmap, indexmap, n, H2OPUS_HWTYPE_CPU); 77 #endif 78 } 79 80 void PetscMatrixSampler::VerifyBuffers(int nv) 81 { 82 if (this->hindexmap.size()) { 83 size_t n = this->hindexmap.size(); 84 if (!this->gpusampling) { 85 if (hbuffer_in.size() < (size_t)n * nv) hbuffer_in.resize(n * nv); 86 if (hbuffer_out.size() < (size_t)n * nv) hbuffer_out.resize(n * nv); 87 } else { 88 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 89 if (dbuffer_in.size() < (size_t)n * nv) dbuffer_in.resize(n * nv); 90 if (dbuffer_out.size() < (size_t)n * nv) dbuffer_out.resize(n * nv); 91 #endif 92 } 93 } 94 } 95 96 void PetscMatrixSampler::PermuteBuffersIn(int nv, H2Opus_Real *v, H2Opus_Real **w, H2Opus_Real *ov, H2Opus_Real **ow) 97 { 98 *w = v; 99 *ow = ov; 100 VerifyBuffers(nv); 101 if (this->hindexmap.size()) { 102 size_t n = this->hindexmap.size(); 103 if (!this->gpusampling) { 104 permute_vectors(v, this->hbuffer_in.data(), n, nv, this->hindexmap.data(), 1, H2OPUS_HWTYPE_CPU, this->stream); 105 *w = this->hbuffer_in.data(); 106 *ow = this->hbuffer_out.data(); 107 } else { 108 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 109 permute_vectors(v, this->dbuffer_in.data(), n, nv, this->dindexmap.data(), 1, H2OPUS_HWTYPE_GPU, this->stream); 110 *w = this->dbuffer_in.data(); 111 *ow = this->dbuffer_out.data(); 112 #endif 113 } 114 } 115 } 116 117 void PetscMatrixSampler::PermuteBuffersOut(int nv, H2Opus_Real *v) 118 { 119 VerifyBuffers(nv); 120 if (this->hindexmap.size()) { 121 size_t n = this->hindexmap.size(); 122 if (!this->gpusampling) { 123 permute_vectors(this->hbuffer_out.data(), v, n, nv, this->hindexmap.data(), 0, H2OPUS_HWTYPE_CPU, this->stream); 124 } else { 125 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 126 permute_vectors(this->dbuffer_out.data(), v, n, nv, this->dindexmap.data(), 0, H2OPUS_HWTYPE_GPU, this->stream); 127 #endif 128 } 129 } 130 } 131 132 void PetscMatrixSampler::SetGPUSampling(bool gpusampling) 133 { 134 this->gpusampling = gpusampling; 135 } 136 137 PetscMatrixSampler::~PetscMatrixSampler() 138 { 139 PetscCallVoid(MatDestroy(&A)); 140 } 141 142 void PetscMatrixSampler::sample(H2Opus_Real *x, H2Opus_Real *y, int samples) 143 { 144 MPI_Comm comm = PetscObjectComm((PetscObject)this->A); 145 Mat X = NULL, Y = NULL; 146 PetscInt M, N, m, n; 147 H2Opus_Real *px, *py; 148 VecType vtype; 149 150 if (!this->A) PetscCallVoid(PETSC_ERR_PLIB); 151 PetscCallVoid(MatGetSize(this->A, &M, &N)); 152 PetscCallVoid(MatGetVecType(this->A, &vtype)); 153 PetscCallVoid(MatGetLocalSize(this->A, &m, &n)); 154 PetscCallVoid(PetscObjectGetComm((PetscObject)A, &comm)); 155 PermuteBuffersIn(samples, x, &px, y, &py); 156 if (!this->gpusampling) { 157 PetscCallVoid(MatCreateDense(comm, n, PETSC_DECIDE, N, samples, px, &X)); 158 PetscCallVoid(MatCreateDense(comm, m, PETSC_DECIDE, M, samples, py, &Y)); 159 PetscCallVoid(MatSetVecType(X, vtype)); 160 PetscCallVoid(MatSetVecType(Y, vtype)); 161 162 } else { 163 #if defined(PETSC_HAVE_CUDA) 164 PetscCallVoid(MatCreateDenseCUDA(comm, n, PETSC_DECIDE, N, samples, px, &X)); 165 PetscCallVoid(MatCreateDenseCUDA(comm, m, PETSC_DECIDE, M, samples, py, &Y)); 166 PetscCallVoid(MatSetVecType(X, vtype)); 167 PetscCallVoid(MatSetVecType(Y, vtype)); 168 #endif 169 } 170 PetscCallVoid(MatMatMult(this->A, X, MAT_REUSE_MATRIX, PETSC_DEFAULT, &Y)); 171 #if defined(PETSC_HAVE_CUDA) 172 if (this->gpusampling) { 173 const PetscScalar *dummy; 174 PetscCallVoid(MatDenseCUDAGetArrayRead(Y, &dummy)); 175 PetscCallVoid(MatDenseCUDARestoreArrayRead(Y, &dummy)); 176 } 177 #endif 178 PermuteBuffersOut(samples, y); 179 PetscCallVoid(MatDestroy(&X)); 180 PetscCallVoid(MatDestroy(&Y)); 181 } 182