1 2 /* 3 Defines matrix-matrix product routines for 4 C = A^T * B and C = A * B^t 5 with A SeqAIJ and B SeqDense 6 */ 7 8 #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/ 9 #include <../src/mat/impls/dense/seq/dense.h> 10 11 PetscErrorCode MatDestroy_SeqDense_MatTransMatMult(void *data) 12 { 13 Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data; 14 15 PetscFunctionBegin; 16 PetscCall(MatDestroy(&atb->mA)); 17 PetscCall(VecDestroy(&atb->bt)); 18 PetscCall(VecDestroy(&atb->ct)); 19 PetscCall(PetscFree(atb)); 20 PetscFunctionReturn(0); 21 } 22 23 static PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat,Mat,Mat); 24 25 PETSC_INTERN PetscErrorCode MatTMatTMultSymbolic_SeqAIJ_SeqDense(Mat A,Mat B,PetscReal fill,Mat C) 26 { 27 Mat_MatTransMatMult *atb; 28 PetscBool cisdense; 29 PetscInt dofm; 30 31 PetscFunctionBegin; 32 MatCheckProduct(C,4); 33 PetscCheck(!C->product->data,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Extra product struct not empty"); 34 PetscCheckFalse(C->product->type != MATPRODUCT_ABt && C->product->type != MATPRODUCT_AtB,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Not for product type %s",MatProductTypes[C->product->type]); 35 36 /* create output dense matrix C */ 37 if (C->product->type == MATPRODUCT_AtB) { 38 PetscCall(MatSetSizes(C,A->cmap->n,B->cmap->N,A->cmap->n,B->cmap->N)); 39 dofm = B->cmap->n; 40 } else { 41 PetscCall(MatSetSizes(C,A->rmap->n,B->rmap->N,A->rmap->n,B->rmap->N)); 42 dofm = B->rmap->n; 43 } 44 PetscCall(PetscObjectTypeCompareAny((PetscObject)C,&cisdense,MATSEQDENSE,MATSEQDENSECUDA,"")); 45 if (!cisdense) { 46 PetscCall(MatSetType(C,((PetscObject)B)->type_name)); 47 } 48 PetscCall(MatSetUp(C)); 49 50 /* create additional data structure for the product */ 51 PetscCall(PetscNew(&atb)); 52 PetscCall(MatCreateMAIJ(A,dofm,&atb->mA)); 53 PetscCall(MatCreateVecs(atb->mA,&atb->ct,&atb->bt)); 54 C->product->data = atb; 55 C->product->destroy = MatDestroy_SeqDense_MatTransMatMult; 56 57 if (C->product->type == MATPRODUCT_AtB) { 58 C->ops->transposematmultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense; 59 } else { 60 C->ops->mattransposemultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense; 61 } 62 PetscFunctionReturn(0); 63 } 64 65 PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat A,Mat B,Mat C) 66 { 67 PetscInt i,j,m=A->rmap->n,n=A->cmap->n,blda,clda; 68 PetscInt mdof = C->cmap->N; 69 const PetscScalar *Barray; 70 PetscScalar *Carray; 71 Mat_MatTransMatMult *atb; 72 Vec bt,ct; 73 74 PetscFunctionBegin; 75 MatCheckProduct(C,3); 76 PetscCheckFalse(C->product->type != MATPRODUCT_ABt && C->product->type != MATPRODUCT_AtB,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Not for product type %s",MatProductTypes[C->product->type]); 77 atb = (Mat_MatTransMatMult *)C->product->data; 78 PetscCheck(atb,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing product struct"); 79 bt = atb->bt; 80 ct = atb->ct; 81 82 PetscCall(MatDenseGetArrayRead(B,&Barray)); 83 PetscCall(MatDenseGetLDA(B,&blda)); 84 PetscCall(MatDenseGetArrayWrite(C,&Carray)); 85 PetscCall(MatDenseGetLDA(C,&clda)); 86 if (C->product->type == MATPRODUCT_AtB) { /* transpose local array of B, then copy it to vector bt */ 87 const PetscScalar *ctarray; 88 PetscScalar *btarray; 89 90 PetscCall(VecGetArrayWrite(bt,&btarray)); 91 for (j=0; j<mdof; j++) { 92 for (i=0; i<m; i++) btarray[i*mdof + j] = Barray[j*blda + i]; 93 } 94 PetscCall(VecRestoreArrayWrite(bt,&btarray)); 95 96 /* compute ct = mA^T * cb */ 97 PetscCall(MatMultTranspose(atb->mA,bt,ct)); 98 99 /* transpose local array of ct to matrix C */ 100 PetscCall(VecGetArrayRead(ct,&ctarray)); 101 for (j=0; j<mdof; j++) { 102 for (i=0; i<n; i++) Carray[j*clda + i] = ctarray[i*mdof + j]; 103 } 104 PetscCall(VecRestoreArrayRead(ct,&ctarray)); 105 } else { 106 const PetscScalar *btarray; 107 PetscScalar *ctarray; 108 109 if (blda == B->rmap->n) { 110 PetscCall(VecPlaceArray(ct,Barray)); 111 } else { 112 PetscInt bn = B->cmap->n; 113 PetscInt bm = B->rmap->n; 114 115 PetscCall(VecGetArrayWrite(ct,&ctarray)); 116 for (j=0; j<bn; j++) { 117 for (i=0; i<bm; i++) ctarray[j*bm + i] = Barray[j*blda + i]; 118 } 119 PetscCall(VecRestoreArrayWrite(ct,&ctarray)); 120 } 121 122 PetscCall(MatMult(atb->mA,ct,bt)); 123 if (blda == B->rmap->n) { 124 PetscCall(VecResetArray(ct)); 125 } 126 PetscCall(VecGetArrayRead(bt,&btarray)); 127 for (j=0; j<mdof; j++) { 128 for (i=0; i<m; i++) Carray[j*clda + i] = btarray[i*mdof + j]; 129 } 130 PetscCall(VecRestoreArrayRead(bt,&btarray)); 131 } 132 PetscCall(MatDenseRestoreArrayRead(B,&Barray)); 133 PetscCall(MatDenseRestoreArray(C,&Carray)); 134 PetscFunctionReturn(0); 135 } 136