1 2 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/ 3 4 PETSC_EXTERN PetscErrorCode VecGetRootType_Private(Vec, VecType *); 5 6 typedef struct { 7 Mat A; /* sparse matrix */ 8 Mat U, V; /* dense tall-skinny matrices */ 9 Vec c; /* sequential vector containing the diagonal of C */ 10 Vec work1, work2; /* sequential vectors that hold partial products */ 11 Vec xl, yl; /* auxiliary sequential vectors for matmult operation */ 12 } Mat_LRC; 13 14 static PetscErrorCode MatMult_LRC_kernel(Mat N, Vec x, Vec y, PetscBool transpose) { 15 Mat_LRC *Na = (Mat_LRC *)N->data; 16 PetscMPIInt size; 17 Mat U, V; 18 19 PetscFunctionBegin; 20 U = transpose ? Na->V : Na->U; 21 V = transpose ? Na->U : Na->V; 22 PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)N), &size)); 23 if (size == 1) { 24 PetscCall(MatMultHermitianTranspose(V, x, Na->work1)); 25 if (Na->c) PetscCall(VecPointwiseMult(Na->work1, Na->c, Na->work1)); 26 if (Na->A) { 27 if (transpose) { 28 PetscCall(MatMultTranspose(Na->A, x, y)); 29 } else { 30 PetscCall(MatMult(Na->A, x, y)); 31 } 32 PetscCall(MatMultAdd(U, Na->work1, y, y)); 33 } else { 34 PetscCall(MatMult(U, Na->work1, y)); 35 } 36 } else { 37 Mat Uloc, Vloc; 38 Vec yl, xl; 39 const PetscScalar *w1; 40 PetscScalar *w2; 41 PetscInt nwork; 42 PetscMPIInt mpinwork; 43 44 xl = transpose ? Na->yl : Na->xl; 45 yl = transpose ? Na->xl : Na->yl; 46 PetscCall(VecGetLocalVector(y, yl)); 47 PetscCall(MatDenseGetLocalMatrix(U, &Uloc)); 48 PetscCall(MatDenseGetLocalMatrix(V, &Vloc)); 49 50 /* multiply the local part of V with the local part of x */ 51 PetscCall(VecGetLocalVectorRead(x, xl)); 52 PetscCall(MatMultHermitianTranspose(Vloc, xl, Na->work1)); 53 PetscCall(VecRestoreLocalVectorRead(x, xl)); 54 55 /* form the sum of all the local multiplies: this is work2 = V'*x = 56 sum_{all processors} work1 */ 57 PetscCall(VecGetArrayRead(Na->work1, &w1)); 58 PetscCall(VecGetArrayWrite(Na->work2, &w2)); 59 PetscCall(VecGetLocalSize(Na->work1, &nwork)); 60 PetscCall(PetscMPIIntCast(nwork, &mpinwork)); 61 PetscCall(MPIU_Allreduce(w1, w2, mpinwork, MPIU_SCALAR, MPIU_SUM, PetscObjectComm((PetscObject)N))); 62 PetscCall(VecRestoreArrayRead(Na->work1, &w1)); 63 PetscCall(VecRestoreArrayWrite(Na->work2, &w2)); 64 65 if (Na->c) { /* work2 = C*work2 */ 66 PetscCall(VecPointwiseMult(Na->work2, Na->c, Na->work2)); 67 } 68 69 if (Na->A) { 70 /* form y = A*x or A^t*x */ 71 if (transpose) { 72 PetscCall(MatMultTranspose(Na->A, x, y)); 73 } else { 74 PetscCall(MatMult(Na->A, x, y)); 75 } 76 /* multiply-add y = y + U*work2 */ 77 PetscCall(MatMultAdd(Uloc, Na->work2, yl, yl)); 78 } else { 79 /* multiply y = U*work2 */ 80 PetscCall(MatMult(Uloc, Na->work2, yl)); 81 } 82 83 PetscCall(VecRestoreLocalVector(y, yl)); 84 } 85 PetscFunctionReturn(0); 86 } 87 88 static PetscErrorCode MatMult_LRC(Mat N, Vec x, Vec y) { 89 PetscFunctionBegin; 90 PetscCall(MatMult_LRC_kernel(N, x, y, PETSC_FALSE)); 91 PetscFunctionReturn(0); 92 } 93 94 static PetscErrorCode MatMultTranspose_LRC(Mat N, Vec x, Vec y) { 95 PetscFunctionBegin; 96 PetscCall(MatMult_LRC_kernel(N, x, y, PETSC_TRUE)); 97 PetscFunctionReturn(0); 98 } 99 100 static PetscErrorCode MatDestroy_LRC(Mat N) { 101 Mat_LRC *Na = (Mat_LRC *)N->data; 102 103 PetscFunctionBegin; 104 PetscCall(MatDestroy(&Na->A)); 105 PetscCall(MatDestroy(&Na->U)); 106 PetscCall(MatDestroy(&Na->V)); 107 PetscCall(VecDestroy(&Na->c)); 108 PetscCall(VecDestroy(&Na->work1)); 109 PetscCall(VecDestroy(&Na->work2)); 110 PetscCall(VecDestroy(&Na->xl)); 111 PetscCall(VecDestroy(&Na->yl)); 112 PetscCall(PetscFree(N->data)); 113 PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatLRCGetMats_C", NULL)); 114 PetscFunctionReturn(0); 115 } 116 117 static PetscErrorCode MatLRCGetMats_LRC(Mat N, Mat *A, Mat *U, Vec *c, Mat *V) { 118 Mat_LRC *Na = (Mat_LRC *)N->data; 119 120 PetscFunctionBegin; 121 if (A) *A = Na->A; 122 if (U) *U = Na->U; 123 if (c) *c = Na->c; 124 if (V) *V = Na->V; 125 PetscFunctionReturn(0); 126 } 127 128 /*@ 129 MatLRCGetMats - Returns the constituents of an LRC matrix 130 131 Collective on Mat 132 133 Input Parameter: 134 . N - matrix of type LRC 135 136 Output Parameters: 137 + A - the (sparse) matrix 138 . U - first dense rectangular (tall and skinny) matrix 139 . c - a sequential vector containing the diagonal of C 140 - V - second dense rectangular (tall and skinny) matrix 141 142 Note: 143 The returned matrices need not be destroyed by the caller. 144 145 Level: intermediate 146 147 .seealso: `MatCreateLRC()` 148 @*/ 149 PetscErrorCode MatLRCGetMats(Mat N, Mat *A, Mat *U, Vec *c, Mat *V) { 150 PetscFunctionBegin; 151 PetscUseMethod(N, "MatLRCGetMats_C", (Mat, Mat *, Mat *, Vec *, Mat *), (N, A, U, c, V)); 152 PetscFunctionReturn(0); 153 } 154 155 /*@ 156 MatCreateLRC - Creates a new matrix object that behaves like A + U*C*V' 157 158 Collective on Mat 159 160 Input Parameters: 161 + A - the (sparse) matrix (can be NULL) 162 . U, V - two dense rectangular (tall and skinny) matrices 163 - c - a vector containing the diagonal of C (can be NULL) 164 165 Output Parameter: 166 . N - the matrix that represents A + U*C*V' 167 168 Notes: 169 The matrix A + U*C*V' is not formed! Rather the new matrix 170 object performs the matrix-vector product by first multiplying by 171 A and then adding the other term. 172 173 C is a diagonal matrix (represented as a vector) of order k, 174 where k is the number of columns of both U and V. 175 176 If A is NULL then the new object behaves like a low-rank matrix U*C*V'. 177 178 Use V=U (or V=NULL) for a symmetric low-rank correction, A + U*C*U'. 179 180 If c is NULL then the low-rank correction is just U*V'. 181 If a sequential c vector is used for a parallel matrix, 182 PETSc assumes that the values of the vector are consistently set across processors. 183 184 Level: intermediate 185 186 .seealso: `MatLRCGetMats()` 187 @*/ 188 PetscErrorCode MatCreateLRC(Mat A, Mat U, Vec c, Mat V, Mat *N) { 189 PetscBool match; 190 PetscInt m, n, k, m1, n1, k1; 191 Mat_LRC *Na; 192 Mat Uloc; 193 PetscMPIInt size, csize = 0; 194 195 PetscFunctionBegin; 196 if (A) PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 197 PetscValidHeaderSpecific(U, MAT_CLASSID, 2); 198 if (c) PetscValidHeaderSpecific(c, VEC_CLASSID, 3); 199 if (V) { 200 PetscValidHeaderSpecific(V, MAT_CLASSID, 4); 201 PetscCheckSameComm(U, 2, V, 4); 202 } 203 if (A) PetscCheckSameComm(A, 1, U, 2); 204 205 if (!V) V = U; 206 PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)U, &match, MATSEQDENSE, MATMPIDENSE, "")); 207 PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_SUP, "Matrix U must be of type dense, found %s", ((PetscObject)U)->type_name); 208 PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)V, &match, MATSEQDENSE, MATMPIDENSE, "")); 209 PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_SUP, "Matrix V must be of type dense, found %s", ((PetscObject)V)->type_name); 210 PetscCall(PetscStrcmp(U->defaultvectype, V->defaultvectype, &match)); 211 PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_WRONG, "Matrix U and V must have the same VecType %s != %s", U->defaultvectype, V->defaultvectype); 212 if (A) { 213 PetscCall(PetscStrcmp(A->defaultvectype, U->defaultvectype, &match)); 214 PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_WRONG, "Matrix A and U must have the same VecType %s != %s", A->defaultvectype, U->defaultvectype); 215 } 216 217 PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)U), &size)); 218 PetscCall(MatGetSize(U, NULL, &k)); 219 PetscCall(MatGetSize(V, NULL, &k1)); 220 PetscCheck(k == k1, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_INCOMP, "U and V have different number of columns (%" PetscInt_FMT " vs %" PetscInt_FMT ")", k, k1); 221 PetscCall(MatGetLocalSize(U, &m, NULL)); 222 PetscCall(MatGetLocalSize(V, &n, NULL)); 223 if (A) { 224 PetscCall(MatGetLocalSize(A, &m1, &n1)); 225 PetscCheck(m == m1, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Local dimensions of U %" PetscInt_FMT " and A %" PetscInt_FMT " do not match", m, m1); 226 PetscCheck(n == n1, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Local dimensions of V %" PetscInt_FMT " and A %" PetscInt_FMT " do not match", n, n1); 227 } 228 if (c) { 229 PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)c), &csize)); 230 PetscCall(VecGetSize(c, &k1)); 231 PetscCheck(k == k1, PetscObjectComm((PetscObject)c), PETSC_ERR_ARG_INCOMP, "The length of c %" PetscInt_FMT " does not match the number of columns of U and V (%" PetscInt_FMT ")", k1, k); 232 PetscCheck(csize == 1 || csize == size, PetscObjectComm((PetscObject)c), PETSC_ERR_ARG_INCOMP, "U and c must have the same communicator size %d != %d", size, csize); 233 } 234 235 PetscCall(MatCreate(PetscObjectComm((PetscObject)U), N)); 236 PetscCall(MatSetSizes(*N, m, n, PETSC_DECIDE, PETSC_DECIDE)); 237 PetscCall(MatSetVecType(*N, U->defaultvectype)); 238 PetscCall(PetscObjectChangeTypeName((PetscObject)*N, MATLRC)); 239 /* Flag matrix as symmetric if A is symmetric and U == V */ 240 PetscCall(MatSetOption(*N, MAT_SYMMETRIC, (PetscBool)((A ? A->symmetric == PETSC_BOOL3_TRUE : PETSC_TRUE) && U == V))); 241 242 PetscCall(PetscNewLog(*N, &Na)); 243 (*N)->data = (void *)Na; 244 Na->A = A; 245 Na->U = U; 246 Na->c = c; 247 Na->V = V; 248 249 PetscCall(PetscObjectReference((PetscObject)A)); 250 PetscCall(PetscObjectReference((PetscObject)Na->U)); 251 PetscCall(PetscObjectReference((PetscObject)Na->V)); 252 PetscCall(PetscObjectReference((PetscObject)c)); 253 254 PetscCall(MatDenseGetLocalMatrix(Na->U, &Uloc)); 255 PetscCall(MatCreateVecs(Uloc, &Na->work1, NULL)); 256 if (size != 1) { 257 Mat Vloc; 258 259 if (Na->c && csize != 1) { /* scatter parallel vector to sequential */ 260 VecScatter sct; 261 262 PetscCall(VecScatterCreateToAll(Na->c, &sct, &c)); 263 PetscCall(VecScatterBegin(sct, Na->c, c, INSERT_VALUES, SCATTER_FORWARD)); 264 PetscCall(VecScatterEnd(sct, Na->c, c, INSERT_VALUES, SCATTER_FORWARD)); 265 PetscCall(VecScatterDestroy(&sct)); 266 PetscCall(VecDestroy(&Na->c)); 267 PetscCall(PetscLogObjectParent((PetscObject)*N, (PetscObject)c)); 268 Na->c = c; 269 } 270 PetscCall(MatDenseGetLocalMatrix(Na->V, &Vloc)); 271 PetscCall(VecDuplicate(Na->work1, &Na->work2)); 272 PetscCall(MatCreateVecs(Vloc, NULL, &Na->xl)); 273 PetscCall(MatCreateVecs(Uloc, NULL, &Na->yl)); 274 } 275 PetscCall(PetscLogObjectParent((PetscObject)*N, (PetscObject)Na->work1)); 276 PetscCall(PetscLogObjectParent((PetscObject)*N, (PetscObject)Na->work1)); 277 PetscCall(PetscLogObjectParent((PetscObject)*N, (PetscObject)Na->xl)); 278 PetscCall(PetscLogObjectParent((PetscObject)*N, (PetscObject)Na->yl)); 279 280 /* Internally create a scaling vector if roottypes do not match */ 281 if (Na->c) { 282 VecType rt1, rt2; 283 284 PetscCall(VecGetRootType_Private(Na->work1, &rt1)); 285 PetscCall(VecGetRootType_Private(Na->c, &rt2)); 286 PetscCall(PetscStrcmp(rt1, rt2, &match)); 287 if (!match) { 288 PetscCall(VecDuplicate(Na->c, &c)); 289 PetscCall(VecCopy(Na->c, c)); 290 PetscCall(VecDestroy(&Na->c)); 291 PetscCall(PetscLogObjectParent((PetscObject)*N, (PetscObject)c)); 292 Na->c = c; 293 } 294 } 295 296 (*N)->ops->destroy = MatDestroy_LRC; 297 (*N)->ops->mult = MatMult_LRC; 298 (*N)->ops->multtranspose = MatMultTranspose_LRC; 299 300 (*N)->assembled = PETSC_TRUE; 301 (*N)->preallocated = PETSC_TRUE; 302 303 PetscCall(PetscObjectComposeFunction((PetscObject)(*N), "MatLRCGetMats_C", MatLRCGetMats_LRC)); 304 PetscCall(MatSetUp(*N)); 305 PetscFunctionReturn(0); 306 } 307