1 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/ 2 3 PetscErrorCode MatConvert_Shell(Mat oldmat,MatType newtype,MatReuse reuse,Mat *newmat) 4 { 5 Mat mat; 6 Vec in,out; 7 PetscScalar *array; 8 PetscInt *dnnz,*onnz,*dnnzu,*onnzu; 9 PetscInt cst,Nbs,mbs,nbs,rbs,cbs; 10 PetscInt im,i,m,n,M,N,*rows,start; 11 12 PetscFunctionBegin; 13 PetscCall(MatGetOwnershipRange(oldmat,&start,NULL)); 14 PetscCall(MatGetOwnershipRangeColumn(oldmat,&cst,NULL)); 15 PetscCall(MatCreateVecs(oldmat,&in,&out)); 16 PetscCall(MatGetLocalSize(oldmat,&m,&n)); 17 PetscCall(MatGetSize(oldmat,&M,&N)); 18 PetscCall(PetscMalloc1(m,&rows)); 19 if (reuse != MAT_REUSE_MATRIX) { 20 PetscCall(MatCreate(PetscObjectComm((PetscObject)oldmat),&mat)); 21 PetscCall(MatSetSizes(mat,m,n,M,N)); 22 PetscCall(MatSetType(mat,newtype)); 23 PetscCall(MatSetBlockSizesFromMats(mat,oldmat,oldmat)); 24 PetscCall(MatGetBlockSizes(mat,&rbs,&cbs)); 25 mbs = m/rbs; 26 nbs = n/cbs; 27 Nbs = N/cbs; 28 cst = cst/cbs; 29 PetscCall(PetscMalloc4(mbs,&dnnz,mbs,&onnz,mbs,&dnnzu,mbs,&onnzu)); 30 for (i=0; i<mbs; i++) { 31 dnnz[i] = nbs; 32 onnz[i] = Nbs - nbs; 33 dnnzu[i] = PetscMax(nbs - i,0); 34 onnzu[i] = PetscMax(Nbs - (cst + nbs),0); 35 } 36 PetscCall(MatXAIJSetPreallocation(mat,PETSC_DECIDE,dnnz,onnz,dnnzu,onnzu)); 37 PetscCall(PetscFree4(dnnz,onnz,dnnzu,onnzu)); 38 PetscCall(VecSetOption(in,VEC_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)); 39 PetscCall(MatSetUp(mat)); 40 } else { 41 mat = *newmat; 42 PetscCall(MatZeroEntries(mat)); 43 } 44 for (i=0; i<N; i++) { 45 PetscInt j; 46 47 PetscCall(VecZeroEntries(in)); 48 PetscCall(VecSetValue(in,i,1.,INSERT_VALUES)); 49 PetscCall(VecAssemblyBegin(in)); 50 PetscCall(VecAssemblyEnd(in)); 51 PetscCall(MatMult(oldmat,in,out)); 52 PetscCall(VecGetArray(out,&array)); 53 for (j=0, im = 0; j<m; j++) { 54 if (PetscAbsScalar(array[j]) == 0.0) continue; 55 rows[im] = j+start; 56 array[im] = array[j]; 57 im++; 58 } 59 PetscCall(MatSetValues(mat,im,rows,1,&i,array,INSERT_VALUES)); 60 PetscCall(VecRestoreArray(out,&array)); 61 } 62 PetscCall(PetscFree(rows)); 63 PetscCall(VecDestroy(&in)); 64 PetscCall(VecDestroy(&out)); 65 PetscCall(MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY)); 66 PetscCall(MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY)); 67 if (reuse == MAT_INPLACE_MATRIX) { 68 PetscCall(MatHeaderReplace(oldmat,&mat)); 69 } else { 70 *newmat = mat; 71 } 72 PetscFunctionReturn(0); 73 } 74 75 static PetscErrorCode MatGetDiagonal_CF(Mat A,Vec X) 76 { 77 Mat B; 78 79 PetscFunctionBegin; 80 PetscCall(MatShellGetContext(A,&B)); 81 PetscCheck(B,PetscObjectComm((PetscObject)A),PETSC_ERR_PLIB,"Missing user matrix"); 82 PetscCall(MatGetDiagonal(B,X)); 83 PetscFunctionReturn(0); 84 } 85 86 static PetscErrorCode MatMult_CF(Mat A,Vec X,Vec Y) 87 { 88 Mat B; 89 90 PetscFunctionBegin; 91 PetscCall(MatShellGetContext(A,&B)); 92 PetscCheck(B,PetscObjectComm((PetscObject)A),PETSC_ERR_PLIB,"Missing user matrix"); 93 PetscCall(MatMult(B,X,Y)); 94 PetscFunctionReturn(0); 95 } 96 97 static PetscErrorCode MatMultTranspose_CF(Mat A,Vec X,Vec Y) 98 { 99 Mat B; 100 101 PetscFunctionBegin; 102 PetscCall(MatShellGetContext(A,&B)); 103 PetscCheck(B,PetscObjectComm((PetscObject)A),PETSC_ERR_PLIB,"Missing user matrix"); 104 PetscCall(MatMultTranspose(B,X,Y)); 105 PetscFunctionReturn(0); 106 } 107 108 static PetscErrorCode MatDestroy_CF(Mat A) 109 { 110 Mat B; 111 112 PetscFunctionBegin; 113 PetscCall(MatShellGetContext(A,&B)); 114 PetscCheck(B,PetscObjectComm((PetscObject)A),PETSC_ERR_PLIB,"Missing user matrix"); 115 PetscCall(MatDestroy(&B)); 116 PetscCall(PetscObjectComposeFunction((PetscObject)A,"MatProductSetFromOptions_anytype_C",NULL)); 117 PetscFunctionReturn(0); 118 } 119 120 typedef struct { 121 void *userdata; 122 PetscErrorCode (*userdestroy)(void*); 123 PetscErrorCode (*numeric)(Mat); 124 MatProductType ptype; 125 Mat Dwork; 126 } MatMatCF; 127 128 static PetscErrorCode MatProductDestroy_CF(void *data) 129 { 130 MatMatCF *mmcfdata = (MatMatCF*)data; 131 132 PetscFunctionBegin; 133 if (mmcfdata->userdestroy) { 134 PetscCall((*mmcfdata->userdestroy)(mmcfdata->userdata)); 135 } 136 PetscCall(MatDestroy(&mmcfdata->Dwork)); 137 PetscCall(PetscFree(mmcfdata)); 138 PetscFunctionReturn(0); 139 } 140 141 static PetscErrorCode MatProductNumericPhase_CF(Mat A, Mat B, Mat C, void *data) 142 { 143 MatMatCF *mmcfdata = (MatMatCF*)data; 144 145 PetscFunctionBegin; 146 PetscCheck(mmcfdata,PetscObjectComm((PetscObject)C),PETSC_ERR_PLIB,"Missing data"); 147 PetscCheck(mmcfdata->numeric,PetscObjectComm((PetscObject)C),PETSC_ERR_PLIB,"Missing numeric operation"); 148 /* the MATSHELL interface allows us to play with the product data */ 149 PetscCall(PetscNew(&C->product)); 150 C->product->type = mmcfdata->ptype; 151 C->product->data = mmcfdata->userdata; 152 C->product->Dwork = mmcfdata->Dwork; 153 PetscCall(MatShellGetContext(A,&C->product->A)); 154 C->product->B = B; 155 PetscCall((*mmcfdata->numeric)(C)); 156 PetscCall(PetscFree(C->product)); 157 PetscFunctionReturn(0); 158 } 159 160 static PetscErrorCode MatProductSymbolicPhase_CF(Mat A, Mat B, Mat C, void **data) 161 { 162 MatMatCF *mmcfdata; 163 164 PetscFunctionBegin; 165 PetscCall(MatShellGetContext(A,&C->product->A)); 166 PetscCall(MatProductSetFromOptions(C)); 167 PetscCall(MatProductSymbolic(C)); 168 /* the MATSHELL interface does not allow non-empty product data */ 169 PetscCall(PetscNew(&mmcfdata)); 170 171 mmcfdata->numeric = C->ops->productnumeric; 172 mmcfdata->ptype = C->product->type; 173 mmcfdata->userdata = C->product->data; 174 mmcfdata->userdestroy = C->product->destroy; 175 mmcfdata->Dwork = C->product->Dwork; 176 177 C->product->Dwork = NULL; 178 C->product->data = NULL; 179 C->product->destroy = NULL; 180 C->product->A = A; 181 182 *data = mmcfdata; 183 PetscFunctionReturn(0); 184 } 185 186 /* only for A of type shell, mainly used for MatMat operations of shells with AXPYs */ 187 static PetscErrorCode MatProductSetFromOptions_CF(Mat D) 188 { 189 Mat A,B,Ain; 190 void (*Af)(void) = NULL; 191 PetscBool flg; 192 193 PetscFunctionBegin; 194 MatCheckProduct(D,1); 195 if (D->product->type == MATPRODUCT_ABC) PetscFunctionReturn(0); 196 A = D->product->A; 197 B = D->product->B; 198 PetscCall(MatIsShell(A,&flg)); 199 if (!flg) PetscFunctionReturn(0); 200 PetscCall(PetscObjectQueryFunction((PetscObject)A,"MatProductSetFromOptions_anytype_C",&Af)); 201 if (Af == (void(*)(void))MatProductSetFromOptions_CF) { 202 PetscCall(MatShellGetContext(A,&Ain)); 203 } else PetscFunctionReturn(0); 204 D->product->A = Ain; 205 PetscCall(MatProductSetFromOptions(D)); 206 D->product->A = A; 207 if (D->ops->productsymbolic) { /* we have a symbolic match, now populate the MATSHELL operations */ 208 PetscCall(MatShellSetMatProductOperation(A,D->product->type,MatProductSymbolicPhase_CF,MatProductNumericPhase_CF,MatProductDestroy_CF,((PetscObject)B)->type_name,NULL)); 209 PetscCall(MatProductSetFromOptions(D)); 210 } 211 PetscFunctionReturn(0); 212 } 213 214 PetscErrorCode MatConvertFrom_Shell(Mat A,MatType newtype,MatReuse reuse,Mat *B) 215 { 216 Mat M; 217 PetscBool flg; 218 219 PetscFunctionBegin; 220 PetscCall(PetscStrcmp(newtype,MATSHELL,&flg)); 221 PetscCheck(flg,PETSC_COMM_SELF,PETSC_ERR_SUP,"Only conversion to MATSHELL"); 222 if (reuse == MAT_INITIAL_MATRIX) { 223 PetscCall(PetscObjectReference((PetscObject)A)); 224 PetscCall(MatCreateShell(PetscObjectComm((PetscObject)A),A->rmap->n,A->cmap->n,A->rmap->N,A->cmap->N,A,&M)); 225 PetscCall(MatSetBlockSizesFromMats(M,A,A)); 226 PetscCall(MatShellSetOperation(M,MATOP_MULT, (void (*)(void))MatMult_CF)); 227 PetscCall(MatShellSetOperation(M,MATOP_MULT_TRANSPOSE,(void (*)(void))MatMultTranspose_CF)); 228 PetscCall(MatShellSetOperation(M,MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_CF)); 229 PetscCall(MatShellSetOperation(M,MATOP_DESTROY, (void (*)(void))MatDestroy_CF)); 230 PetscCall(PetscObjectComposeFunction((PetscObject)M,"MatProductSetFromOptions_anytype_C",MatProductSetFromOptions_CF)); 231 PetscCall(PetscFree(M->defaultvectype)); 232 PetscCall(PetscStrallocpy(A->defaultvectype,&M->defaultvectype)); 233 #if defined(PETSC_HAVE_DEVICE) 234 PetscCall(MatBindToCPU(M,A->boundtocpu)); 235 #endif 236 *B = M; 237 } else SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_SUP,"Not implemented"); 238 PetscFunctionReturn(0); 239 } 240