1 2 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/ 3 4 PETSC_EXTERN PetscErrorCode VecGetRootType_Private(Vec,VecType*); 5 6 typedef struct { 7 Mat A; /* sparse matrix */ 8 Mat U,V; /* dense tall-skinny matrices */ 9 Vec c; /* sequential vector containing the diagonal of C */ 10 Vec work1,work2; /* sequential vectors that hold partial products */ 11 Vec xl,yl; /* auxiliary sequential vectors for matmult operation */ 12 } Mat_LRC; 13 14 static PetscErrorCode MatMult_LRC_kernel(Mat N,Vec x,Vec y,PetscBool transpose) 15 { 16 Mat_LRC *Na = (Mat_LRC*)N->data; 17 PetscMPIInt size; 18 Mat U,V; 19 20 PetscFunctionBegin; 21 U = transpose ? Na->V : Na->U; 22 V = transpose ? Na->U : Na->V; 23 PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)N),&size)); 24 if (size == 1) { 25 PetscCall(MatMultHermitianTranspose(V,x,Na->work1)); 26 if (Na->c) { 27 PetscCall(VecPointwiseMult(Na->work1,Na->c,Na->work1)); 28 } 29 if (Na->A) { 30 if (transpose) { 31 PetscCall(MatMultTranspose(Na->A,x,y)); 32 } else { 33 PetscCall(MatMult(Na->A,x,y)); 34 } 35 PetscCall(MatMultAdd(U,Na->work1,y,y)); 36 } else { 37 PetscCall(MatMult(U,Na->work1,y)); 38 } 39 } else { 40 Mat Uloc,Vloc; 41 Vec yl,xl; 42 const PetscScalar *w1; 43 PetscScalar *w2; 44 PetscInt nwork; 45 PetscMPIInt mpinwork; 46 47 xl = transpose ? Na->yl : Na->xl; 48 yl = transpose ? Na->xl : Na->yl; 49 PetscCall(VecGetLocalVector(y,yl)); 50 PetscCall(MatDenseGetLocalMatrix(U,&Uloc)); 51 PetscCall(MatDenseGetLocalMatrix(V,&Vloc)); 52 53 /* multiply the local part of V with the local part of x */ 54 PetscCall(VecGetLocalVectorRead(x,xl)); 55 PetscCall(MatMultHermitianTranspose(Vloc,xl,Na->work1)); 56 PetscCall(VecRestoreLocalVectorRead(x,xl)); 57 58 /* form the sum of all the local multiplies: this is work2 = V'*x = 59 sum_{all processors} work1 */ 60 PetscCall(VecGetArrayRead(Na->work1,&w1)); 61 PetscCall(VecGetArrayWrite(Na->work2,&w2)); 62 PetscCall(VecGetLocalSize(Na->work1,&nwork)); 63 PetscCall(PetscMPIIntCast(nwork,&mpinwork)); 64 PetscCall(MPIU_Allreduce(w1,w2,mpinwork,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)N))); 65 PetscCall(VecRestoreArrayRead(Na->work1,&w1)); 66 PetscCall(VecRestoreArrayWrite(Na->work2,&w2)); 67 68 if (Na->c) { /* work2 = C*work2 */ 69 PetscCall(VecPointwiseMult(Na->work2,Na->c,Na->work2)); 70 } 71 72 if (Na->A) { 73 /* form y = A*x or A^t*x */ 74 if (transpose) { 75 PetscCall(MatMultTranspose(Na->A,x,y)); 76 } else { 77 PetscCall(MatMult(Na->A,x,y)); 78 } 79 /* multiply-add y = y + U*work2 */ 80 PetscCall(MatMultAdd(Uloc,Na->work2,yl,yl)); 81 } else { 82 /* multiply y = U*work2 */ 83 PetscCall(MatMult(Uloc,Na->work2,yl)); 84 } 85 86 PetscCall(VecRestoreLocalVector(y,yl)); 87 } 88 PetscFunctionReturn(0); 89 } 90 91 static PetscErrorCode MatMult_LRC(Mat N,Vec x,Vec y) 92 { 93 PetscFunctionBegin; 94 PetscCall(MatMult_LRC_kernel(N,x,y,PETSC_FALSE)); 95 PetscFunctionReturn(0); 96 } 97 98 static PetscErrorCode MatMultTranspose_LRC(Mat N,Vec x,Vec y) 99 { 100 PetscFunctionBegin; 101 PetscCall(MatMult_LRC_kernel(N,x,y,PETSC_TRUE)); 102 PetscFunctionReturn(0); 103 } 104 105 static PetscErrorCode MatDestroy_LRC(Mat N) 106 { 107 Mat_LRC *Na = (Mat_LRC*)N->data; 108 109 PetscFunctionBegin; 110 PetscCall(MatDestroy(&Na->A)); 111 PetscCall(MatDestroy(&Na->U)); 112 PetscCall(MatDestroy(&Na->V)); 113 PetscCall(VecDestroy(&Na->c)); 114 PetscCall(VecDestroy(&Na->work1)); 115 PetscCall(VecDestroy(&Na->work2)); 116 PetscCall(VecDestroy(&Na->xl)); 117 PetscCall(VecDestroy(&Na->yl)); 118 PetscCall(PetscFree(N->data)); 119 PetscCall(PetscObjectComposeFunction((PetscObject)N,"MatLRCGetMats_C",NULL)); 120 PetscFunctionReturn(0); 121 } 122 123 static PetscErrorCode MatLRCGetMats_LRC(Mat N,Mat *A,Mat *U,Vec *c,Mat *V) 124 { 125 Mat_LRC *Na = (Mat_LRC*)N->data; 126 127 PetscFunctionBegin; 128 if (A) *A = Na->A; 129 if (U) *U = Na->U; 130 if (c) *c = Na->c; 131 if (V) *V = Na->V; 132 PetscFunctionReturn(0); 133 } 134 135 /*@ 136 MatLRCGetMats - Returns the constituents of an LRC matrix 137 138 Collective on Mat 139 140 Input Parameter: 141 . N - matrix of type LRC 142 143 Output Parameters: 144 + A - the (sparse) matrix 145 . U - first dense rectangular (tall and skinny) matrix 146 . c - a sequential vector containing the diagonal of C 147 - V - second dense rectangular (tall and skinny) matrix 148 149 Note: 150 The returned matrices need not be destroyed by the caller. 151 152 Level: intermediate 153 154 .seealso: MatCreateLRC() 155 @*/ 156 PetscErrorCode MatLRCGetMats(Mat N,Mat *A,Mat *U,Vec *c,Mat *V) 157 { 158 PetscFunctionBegin; 159 PetscUseMethod(N,"MatLRCGetMats_C",(Mat,Mat*,Mat*,Vec*,Mat*),(N,A,U,c,V)); 160 PetscFunctionReturn(0); 161 } 162 163 /*@ 164 MatCreateLRC - Creates a new matrix object that behaves like A + U*C*V' 165 166 Collective on Mat 167 168 Input Parameters: 169 + A - the (sparse) matrix (can be NULL) 170 . U, V - two dense rectangular (tall and skinny) matrices 171 - c - a vector containing the diagonal of C (can be NULL) 172 173 Output Parameter: 174 . N - the matrix that represents A + U*C*V' 175 176 Notes: 177 The matrix A + U*C*V' is not formed! Rather the new matrix 178 object performs the matrix-vector product by first multiplying by 179 A and then adding the other term. 180 181 C is a diagonal matrix (represented as a vector) of order k, 182 where k is the number of columns of both U and V. 183 184 If A is NULL then the new object behaves like a low-rank matrix U*C*V'. 185 186 Use V=U (or V=NULL) for a symmetric low-rank correction, A + U*C*U'. 187 188 If c is NULL then the low-rank correction is just U*V'. 189 If a sequential c vector is used for a parallel matrix, 190 PETSc assumes that the values of the vector are consistently set across processors. 191 192 Level: intermediate 193 194 .seealso: MatLRCGetMats() 195 @*/ 196 PetscErrorCode MatCreateLRC(Mat A,Mat U,Vec c,Mat V,Mat *N) 197 { 198 PetscBool match; 199 PetscInt m,n,k,m1,n1,k1; 200 Mat_LRC *Na; 201 Mat Uloc; 202 PetscMPIInt size, csize = 0; 203 204 PetscFunctionBegin; 205 if (A) PetscValidHeaderSpecific(A,MAT_CLASSID,1); 206 PetscValidHeaderSpecific(U,MAT_CLASSID,2); 207 if (c) PetscValidHeaderSpecific(c,VEC_CLASSID,3); 208 if (V) { 209 PetscValidHeaderSpecific(V,MAT_CLASSID,4); 210 PetscCheckSameComm(U,2,V,4); 211 } 212 if (A) PetscCheckSameComm(A,1,U,2); 213 214 if (!V) V = U; 215 PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)U,&match,MATSEQDENSE,MATMPIDENSE,"")); 216 PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_SUP,"Matrix U must be of type dense, found %s",((PetscObject)U)->type_name); 217 PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)V,&match,MATSEQDENSE,MATMPIDENSE,"")); 218 PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_SUP,"Matrix V must be of type dense, found %s",((PetscObject)V)->type_name); 219 PetscCall(PetscStrcmp(U->defaultvectype,V->defaultvectype,&match)); 220 PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_ARG_WRONG,"Matrix U and V must have the same VecType %s != %s",U->defaultvectype,V->defaultvectype); 221 if (A) { 222 PetscCall(PetscStrcmp(A->defaultvectype,U->defaultvectype,&match)); 223 PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_ARG_WRONG,"Matrix A and U must have the same VecType %s != %s",A->defaultvectype,U->defaultvectype); 224 } 225 226 PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)U),&size)); 227 PetscCall(MatGetSize(U,NULL,&k)); 228 PetscCall(MatGetSize(V,NULL,&k1)); 229 PetscCheckFalse(k != k1,PetscObjectComm((PetscObject)U),PETSC_ERR_ARG_INCOMP,"U and V have different number of columns (%" PetscInt_FMT " vs %" PetscInt_FMT ")",k,k1); 230 PetscCall(MatGetLocalSize(U,&m,NULL)); 231 PetscCall(MatGetLocalSize(V,&n,NULL)); 232 if (A) { 233 PetscCall(MatGetLocalSize(A,&m1,&n1)); 234 PetscCheckFalse(m != m1,PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Local dimensions of U %" PetscInt_FMT " and A %" PetscInt_FMT " do not match",m,m1); 235 PetscCheckFalse(n != n1,PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Local dimensions of V %" PetscInt_FMT " and A %" PetscInt_FMT " do not match",n,n1); 236 } 237 if (c) { 238 PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)c),&csize)); 239 PetscCall(VecGetSize(c,&k1)); 240 PetscCheckFalse(k != k1,PetscObjectComm((PetscObject)c),PETSC_ERR_ARG_INCOMP,"The length of c %" PetscInt_FMT " does not match the number of columns of U and V (%" PetscInt_FMT ")",k1,k); 241 PetscCheckFalse(csize != 1 && csize != size, PetscObjectComm((PetscObject)c),PETSC_ERR_ARG_INCOMP,"U and c must have the same communicator size %d != %d",size,csize); 242 } 243 244 PetscCall(MatCreate(PetscObjectComm((PetscObject)U),N)); 245 PetscCall(MatSetSizes(*N,m,n,PETSC_DECIDE,PETSC_DECIDE)); 246 PetscCall(MatSetVecType(*N,U->defaultvectype)); 247 PetscCall(PetscObjectChangeTypeName((PetscObject)*N,MATLRC)); 248 /* Flag matrix as symmetric if A is symmetric and U == V */ 249 PetscCall(MatSetOption(*N,MAT_SYMMETRIC,(PetscBool)((A ? A->symmetric : PETSC_TRUE) && U == V))); 250 251 PetscCall(PetscNewLog(*N,&Na)); 252 (*N)->data = (void*)Na; 253 Na->A = A; 254 Na->U = U; 255 Na->c = c; 256 Na->V = V; 257 258 PetscCall(PetscObjectReference((PetscObject)A)); 259 PetscCall(PetscObjectReference((PetscObject)Na->U)); 260 PetscCall(PetscObjectReference((PetscObject)Na->V)); 261 PetscCall(PetscObjectReference((PetscObject)c)); 262 263 PetscCall(MatDenseGetLocalMatrix(Na->U,&Uloc)); 264 PetscCall(MatCreateVecs(Uloc,&Na->work1,NULL)); 265 if (size != 1) { 266 Mat Vloc; 267 268 if (Na->c && csize != 1) { /* scatter parallel vector to sequential */ 269 VecScatter sct; 270 271 PetscCall(VecScatterCreateToAll(Na->c,&sct,&c)); 272 PetscCall(VecScatterBegin(sct,Na->c,c,INSERT_VALUES,SCATTER_FORWARD)); 273 PetscCall(VecScatterEnd(sct,Na->c,c,INSERT_VALUES,SCATTER_FORWARD)); 274 PetscCall(VecScatterDestroy(&sct)); 275 PetscCall(VecDestroy(&Na->c)); 276 PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)c)); 277 Na->c = c; 278 } 279 PetscCall(MatDenseGetLocalMatrix(Na->V,&Vloc)); 280 PetscCall(VecDuplicate(Na->work1,&Na->work2)); 281 PetscCall(MatCreateVecs(Vloc,NULL,&Na->xl)); 282 PetscCall(MatCreateVecs(Uloc,NULL,&Na->yl)); 283 } 284 PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->work1)); 285 PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->work1)); 286 PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->xl)); 287 PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->yl)); 288 289 /* Internally create a scaling vector if roottypes do not match */ 290 if (Na->c) { 291 VecType rt1,rt2; 292 293 PetscCall(VecGetRootType_Private(Na->work1,&rt1)); 294 PetscCall(VecGetRootType_Private(Na->c,&rt2)); 295 PetscCall(PetscStrcmp(rt1,rt2,&match)); 296 if (!match) { 297 PetscCall(VecDuplicate(Na->c,&c)); 298 PetscCall(VecCopy(Na->c,c)); 299 PetscCall(VecDestroy(&Na->c)); 300 PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)c)); 301 Na->c = c; 302 } 303 } 304 305 (*N)->ops->destroy = MatDestroy_LRC; 306 (*N)->ops->mult = MatMult_LRC; 307 (*N)->ops->multtranspose = MatMultTranspose_LRC; 308 309 (*N)->assembled = PETSC_TRUE; 310 (*N)->preallocated = PETSC_TRUE; 311 312 PetscCall(PetscObjectComposeFunction((PetscObject)(*N),"MatLRCGetMats_C",MatLRCGetMats_LRC)); 313 PetscCall(MatSetUp(*N)); 314 PetscFunctionReturn(0); 315 } 316