1 2 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/ 3 4 PETSC_EXTERN PetscErrorCode VecGetRootType_Private(Vec,VecType*); 5 6 typedef struct { 7 Mat A; /* sparse matrix */ 8 Mat U,V; /* dense tall-skinny matrices */ 9 Vec c; /* sequential vector containing the diagonal of C */ 10 Vec work1,work2; /* sequential vectors that hold partial products */ 11 Vec xl,yl; /* auxiliary sequential vectors for matmult operation */ 12 } Mat_LRC; 13 14 static PetscErrorCode MatMult_LRC_kernel(Mat N,Vec x,Vec y,PetscBool transpose) 15 { 16 Mat_LRC *Na = (Mat_LRC*)N->data; 17 PetscMPIInt size; 18 Mat U,V; 19 20 PetscFunctionBegin; 21 U = transpose ? Na->V : Na->U; 22 V = transpose ? Na->U : Na->V; 23 PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)N),&size)); 24 if (size == 1) { 25 PetscCall(MatMultHermitianTranspose(V,x,Na->work1)); 26 if (Na->c) PetscCall(VecPointwiseMult(Na->work1,Na->c,Na->work1)); 27 if (Na->A) { 28 if (transpose) { 29 PetscCall(MatMultTranspose(Na->A,x,y)); 30 } else { 31 PetscCall(MatMult(Na->A,x,y)); 32 } 33 PetscCall(MatMultAdd(U,Na->work1,y,y)); 34 } else { 35 PetscCall(MatMult(U,Na->work1,y)); 36 } 37 } else { 38 Mat Uloc,Vloc; 39 Vec yl,xl; 40 const PetscScalar *w1; 41 PetscScalar *w2; 42 PetscInt nwork; 43 PetscMPIInt mpinwork; 44 45 xl = transpose ? Na->yl : Na->xl; 46 yl = transpose ? Na->xl : Na->yl; 47 PetscCall(VecGetLocalVector(y,yl)); 48 PetscCall(MatDenseGetLocalMatrix(U,&Uloc)); 49 PetscCall(MatDenseGetLocalMatrix(V,&Vloc)); 50 51 /* multiply the local part of V with the local part of x */ 52 PetscCall(VecGetLocalVectorRead(x,xl)); 53 PetscCall(MatMultHermitianTranspose(Vloc,xl,Na->work1)); 54 PetscCall(VecRestoreLocalVectorRead(x,xl)); 55 56 /* form the sum of all the local multiplies: this is work2 = V'*x = 57 sum_{all processors} work1 */ 58 PetscCall(VecGetArrayRead(Na->work1,&w1)); 59 PetscCall(VecGetArrayWrite(Na->work2,&w2)); 60 PetscCall(VecGetLocalSize(Na->work1,&nwork)); 61 PetscCall(PetscMPIIntCast(nwork,&mpinwork)); 62 PetscCall(MPIU_Allreduce(w1,w2,mpinwork,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)N))); 63 PetscCall(VecRestoreArrayRead(Na->work1,&w1)); 64 PetscCall(VecRestoreArrayWrite(Na->work2,&w2)); 65 66 if (Na->c) { /* work2 = C*work2 */ 67 PetscCall(VecPointwiseMult(Na->work2,Na->c,Na->work2)); 68 } 69 70 if (Na->A) { 71 /* form y = A*x or A^t*x */ 72 if (transpose) { 73 PetscCall(MatMultTranspose(Na->A,x,y)); 74 } else { 75 PetscCall(MatMult(Na->A,x,y)); 76 } 77 /* multiply-add y = y + U*work2 */ 78 PetscCall(MatMultAdd(Uloc,Na->work2,yl,yl)); 79 } else { 80 /* multiply y = U*work2 */ 81 PetscCall(MatMult(Uloc,Na->work2,yl)); 82 } 83 84 PetscCall(VecRestoreLocalVector(y,yl)); 85 } 86 PetscFunctionReturn(0); 87 } 88 89 static PetscErrorCode MatMult_LRC(Mat N,Vec x,Vec y) 90 { 91 PetscFunctionBegin; 92 PetscCall(MatMult_LRC_kernel(N,x,y,PETSC_FALSE)); 93 PetscFunctionReturn(0); 94 } 95 96 static PetscErrorCode MatMultTranspose_LRC(Mat N,Vec x,Vec y) 97 { 98 PetscFunctionBegin; 99 PetscCall(MatMult_LRC_kernel(N,x,y,PETSC_TRUE)); 100 PetscFunctionReturn(0); 101 } 102 103 static PetscErrorCode MatDestroy_LRC(Mat N) 104 { 105 Mat_LRC *Na = (Mat_LRC*)N->data; 106 107 PetscFunctionBegin; 108 PetscCall(MatDestroy(&Na->A)); 109 PetscCall(MatDestroy(&Na->U)); 110 PetscCall(MatDestroy(&Na->V)); 111 PetscCall(VecDestroy(&Na->c)); 112 PetscCall(VecDestroy(&Na->work1)); 113 PetscCall(VecDestroy(&Na->work2)); 114 PetscCall(VecDestroy(&Na->xl)); 115 PetscCall(VecDestroy(&Na->yl)); 116 PetscCall(PetscFree(N->data)); 117 PetscCall(PetscObjectComposeFunction((PetscObject)N,"MatLRCGetMats_C",NULL)); 118 PetscFunctionReturn(0); 119 } 120 121 static PetscErrorCode MatLRCGetMats_LRC(Mat N,Mat *A,Mat *U,Vec *c,Mat *V) 122 { 123 Mat_LRC *Na = (Mat_LRC*)N->data; 124 125 PetscFunctionBegin; 126 if (A) *A = Na->A; 127 if (U) *U = Na->U; 128 if (c) *c = Na->c; 129 if (V) *V = Na->V; 130 PetscFunctionReturn(0); 131 } 132 133 /*@ 134 MatLRCGetMats - Returns the constituents of an LRC matrix 135 136 Collective on Mat 137 138 Input Parameter: 139 . N - matrix of type LRC 140 141 Output Parameters: 142 + A - the (sparse) matrix 143 . U - first dense rectangular (tall and skinny) matrix 144 . c - a sequential vector containing the diagonal of C 145 - V - second dense rectangular (tall and skinny) matrix 146 147 Note: 148 The returned matrices need not be destroyed by the caller. 149 150 Level: intermediate 151 152 .seealso: `MatCreateLRC()` 153 @*/ 154 PetscErrorCode MatLRCGetMats(Mat N,Mat *A,Mat *U,Vec *c,Mat *V) 155 { 156 PetscFunctionBegin; 157 PetscUseMethod(N,"MatLRCGetMats_C",(Mat,Mat*,Mat*,Vec*,Mat*),(N,A,U,c,V)); 158 PetscFunctionReturn(0); 159 } 160 161 /*@ 162 MatCreateLRC - Creates a new matrix object that behaves like A + U*C*V' 163 164 Collective on Mat 165 166 Input Parameters: 167 + A - the (sparse) matrix (can be NULL) 168 . U, V - two dense rectangular (tall and skinny) matrices 169 - c - a vector containing the diagonal of C (can be NULL) 170 171 Output Parameter: 172 . N - the matrix that represents A + U*C*V' 173 174 Notes: 175 The matrix A + U*C*V' is not formed! Rather the new matrix 176 object performs the matrix-vector product by first multiplying by 177 A and then adding the other term. 178 179 C is a diagonal matrix (represented as a vector) of order k, 180 where k is the number of columns of both U and V. 181 182 If A is NULL then the new object behaves like a low-rank matrix U*C*V'. 183 184 Use V=U (or V=NULL) for a symmetric low-rank correction, A + U*C*U'. 185 186 If c is NULL then the low-rank correction is just U*V'. 187 If a sequential c vector is used for a parallel matrix, 188 PETSc assumes that the values of the vector are consistently set across processors. 189 190 Level: intermediate 191 192 .seealso: `MatLRCGetMats()` 193 @*/ 194 PetscErrorCode MatCreateLRC(Mat A,Mat U,Vec c,Mat V,Mat *N) 195 { 196 PetscBool match; 197 PetscInt m,n,k,m1,n1,k1; 198 Mat_LRC *Na; 199 Mat Uloc; 200 PetscMPIInt size, csize = 0; 201 202 PetscFunctionBegin; 203 if (A) PetscValidHeaderSpecific(A,MAT_CLASSID,1); 204 PetscValidHeaderSpecific(U,MAT_CLASSID,2); 205 if (c) PetscValidHeaderSpecific(c,VEC_CLASSID,3); 206 if (V) { 207 PetscValidHeaderSpecific(V,MAT_CLASSID,4); 208 PetscCheckSameComm(U,2,V,4); 209 } 210 if (A) PetscCheckSameComm(A,1,U,2); 211 212 if (!V) V = U; 213 PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)U,&match,MATSEQDENSE,MATMPIDENSE,"")); 214 PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_SUP,"Matrix U must be of type dense, found %s",((PetscObject)U)->type_name); 215 PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)V,&match,MATSEQDENSE,MATMPIDENSE,"")); 216 PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_SUP,"Matrix V must be of type dense, found %s",((PetscObject)V)->type_name); 217 PetscCall(PetscStrcmp(U->defaultvectype,V->defaultvectype,&match)); 218 PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_ARG_WRONG,"Matrix U and V must have the same VecType %s != %s",U->defaultvectype,V->defaultvectype); 219 if (A) { 220 PetscCall(PetscStrcmp(A->defaultvectype,U->defaultvectype,&match)); 221 PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_ARG_WRONG,"Matrix A and U must have the same VecType %s != %s",A->defaultvectype,U->defaultvectype); 222 } 223 224 PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)U),&size)); 225 PetscCall(MatGetSize(U,NULL,&k)); 226 PetscCall(MatGetSize(V,NULL,&k1)); 227 PetscCheck(k == k1,PetscObjectComm((PetscObject)U),PETSC_ERR_ARG_INCOMP,"U and V have different number of columns (%" PetscInt_FMT " vs %" PetscInt_FMT ")",k,k1); 228 PetscCall(MatGetLocalSize(U,&m,NULL)); 229 PetscCall(MatGetLocalSize(V,&n,NULL)); 230 if (A) { 231 PetscCall(MatGetLocalSize(A,&m1,&n1)); 232 PetscCheck(m == m1,PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Local dimensions of U %" PetscInt_FMT " and A %" PetscInt_FMT " do not match",m,m1); 233 PetscCheck(n == n1,PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Local dimensions of V %" PetscInt_FMT " and A %" PetscInt_FMT " do not match",n,n1); 234 } 235 if (c) { 236 PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)c),&csize)); 237 PetscCall(VecGetSize(c,&k1)); 238 PetscCheck(k == k1,PetscObjectComm((PetscObject)c),PETSC_ERR_ARG_INCOMP,"The length of c %" PetscInt_FMT " does not match the number of columns of U and V (%" PetscInt_FMT ")",k1,k); 239 PetscCheck(csize == 1 || csize == size, PetscObjectComm((PetscObject)c),PETSC_ERR_ARG_INCOMP,"U and c must have the same communicator size %d != %d",size,csize); 240 } 241 242 PetscCall(MatCreate(PetscObjectComm((PetscObject)U),N)); 243 PetscCall(MatSetSizes(*N,m,n,PETSC_DECIDE,PETSC_DECIDE)); 244 PetscCall(MatSetVecType(*N,U->defaultvectype)); 245 PetscCall(PetscObjectChangeTypeName((PetscObject)*N,MATLRC)); 246 /* Flag matrix as symmetric if A is symmetric and U == V */ 247 PetscCall(MatSetOption(*N,MAT_SYMMETRIC,(PetscBool)((A ? A->symmetric == PETSC_BOOL3_TRUE : PETSC_TRUE) && U == V))); 248 249 PetscCall(PetscNewLog(*N,&Na)); 250 (*N)->data = (void*)Na; 251 Na->A = A; 252 Na->U = U; 253 Na->c = c; 254 Na->V = V; 255 256 PetscCall(PetscObjectReference((PetscObject)A)); 257 PetscCall(PetscObjectReference((PetscObject)Na->U)); 258 PetscCall(PetscObjectReference((PetscObject)Na->V)); 259 PetscCall(PetscObjectReference((PetscObject)c)); 260 261 PetscCall(MatDenseGetLocalMatrix(Na->U,&Uloc)); 262 PetscCall(MatCreateVecs(Uloc,&Na->work1,NULL)); 263 if (size != 1) { 264 Mat Vloc; 265 266 if (Na->c && csize != 1) { /* scatter parallel vector to sequential */ 267 VecScatter sct; 268 269 PetscCall(VecScatterCreateToAll(Na->c,&sct,&c)); 270 PetscCall(VecScatterBegin(sct,Na->c,c,INSERT_VALUES,SCATTER_FORWARD)); 271 PetscCall(VecScatterEnd(sct,Na->c,c,INSERT_VALUES,SCATTER_FORWARD)); 272 PetscCall(VecScatterDestroy(&sct)); 273 PetscCall(VecDestroy(&Na->c)); 274 PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)c)); 275 Na->c = c; 276 } 277 PetscCall(MatDenseGetLocalMatrix(Na->V,&Vloc)); 278 PetscCall(VecDuplicate(Na->work1,&Na->work2)); 279 PetscCall(MatCreateVecs(Vloc,NULL,&Na->xl)); 280 PetscCall(MatCreateVecs(Uloc,NULL,&Na->yl)); 281 } 282 PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->work1)); 283 PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->work1)); 284 PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->xl)); 285 PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->yl)); 286 287 /* Internally create a scaling vector if roottypes do not match */ 288 if (Na->c) { 289 VecType rt1,rt2; 290 291 PetscCall(VecGetRootType_Private(Na->work1,&rt1)); 292 PetscCall(VecGetRootType_Private(Na->c,&rt2)); 293 PetscCall(PetscStrcmp(rt1,rt2,&match)); 294 if (!match) { 295 PetscCall(VecDuplicate(Na->c,&c)); 296 PetscCall(VecCopy(Na->c,c)); 297 PetscCall(VecDestroy(&Na->c)); 298 PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)c)); 299 Na->c = c; 300 } 301 } 302 303 (*N)->ops->destroy = MatDestroy_LRC; 304 (*N)->ops->mult = MatMult_LRC; 305 (*N)->ops->multtranspose = MatMultTranspose_LRC; 306 307 (*N)->assembled = PETSC_TRUE; 308 (*N)->preallocated = PETSC_TRUE; 309 310 PetscCall(PetscObjectComposeFunction((PetscObject)(*N),"MatLRCGetMats_C",MatLRCGetMats_LRC)); 311 PetscCall(MatSetUp(*N)); 312 PetscFunctionReturn(0); 313 } 314