1 2 /* 3 Defines matrix-matrix product routines for 4 C = A^T * B and C = A * B^t 5 with A SeqAIJ and B SeqDense 6 */ 7 8 #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/ 9 #include <../src/mat/impls/dense/seq/dense.h> 10 11 PetscErrorCode MatDestroy_SeqDense_MatTransMatMult(void *data) 12 { 13 PetscErrorCode ierr; 14 Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data; 15 16 PetscFunctionBegin; 17 ierr = MatDestroy(&atb->mA);CHKERRQ(ierr); 18 ierr = VecDestroy(&atb->bt);CHKERRQ(ierr); 19 ierr = VecDestroy(&atb->ct);CHKERRQ(ierr); 20 ierr = PetscFree(atb);CHKERRQ(ierr); 21 PetscFunctionReturn(0); 22 } 23 24 static PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat,Mat,Mat); 25 26 PETSC_INTERN PetscErrorCode MatTMatTMultSymbolic_SeqAIJ_SeqDense(Mat A,Mat B,PetscReal fill,Mat C) 27 { 28 PetscErrorCode ierr; 29 Mat_MatTransMatMult *atb; 30 PetscBool cisdense; 31 PetscInt dofm; 32 33 PetscFunctionBegin; 34 MatCheckProduct(C,4); 35 if (C->product->data) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Extra product struct not empty"); 36 if (C->product->type != MATPRODUCT_ABt && C->product->type != MATPRODUCT_AtB) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Not for product type %s",MatProductTypes[C->product->type]); 37 38 /* create output dense matrix C */ 39 if (C->product->type == MATPRODUCT_AtB) { 40 ierr = MatSetSizes(C,A->cmap->n,B->cmap->N,A->cmap->n,B->cmap->N);CHKERRQ(ierr); 41 dofm = B->cmap->n; 42 } else { 43 ierr = MatSetSizes(C,A->rmap->n,B->rmap->N,A->rmap->n,B->rmap->N);CHKERRQ(ierr); 44 dofm = B->rmap->n; 45 } 46 ierr = PetscObjectTypeCompareAny((PetscObject)C,&cisdense,MATSEQDENSE,MATSEQDENSECUDA,"");CHKERRQ(ierr); 47 if (!cisdense) { 48 ierr = MatSetType(C,((PetscObject)B)->type_name);CHKERRQ(ierr); 49 } 50 ierr = MatSetUp(C);CHKERRQ(ierr); 51 52 /* create additional data structure for the product */ 53 ierr = PetscNew(&atb);CHKERRQ(ierr); 54 ierr = MatCreateMAIJ(A,dofm,&atb->mA);CHKERRQ(ierr); 55 ierr = MatCreateVecs(atb->mA,&atb->ct,&atb->bt);CHKERRQ(ierr); 56 C->product->data = atb; 57 C->product->destroy = MatDestroy_SeqDense_MatTransMatMult; 58 59 if (C->product->type == MATPRODUCT_AtB) { 60 C->ops->transposematmultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense; 61 } else { 62 C->ops->mattransposemultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense; 63 } 64 PetscFunctionReturn(0); 65 } 66 67 PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat A,Mat B,Mat C) 68 { 69 PetscErrorCode ierr; 70 PetscInt i,j,m=A->rmap->n,n=A->cmap->n,blda,clda; 71 PetscInt mdof = C->cmap->N; 72 const PetscScalar *Barray; 73 PetscScalar *Carray; 74 Mat_MatTransMatMult *atb; 75 Vec bt,ct; 76 77 PetscFunctionBegin; 78 MatCheckProduct(C,3); 79 if (C->product->type != MATPRODUCT_ABt && C->product->type != MATPRODUCT_AtB) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Not for product type %s",MatProductTypes[C->product->type]); 80 atb = (Mat_MatTransMatMult *)C->product->data; 81 if (!atb) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing product struct"); 82 bt = atb->bt; 83 ct = atb->ct; 84 85 ierr = MatDenseGetArrayRead(B,&Barray);CHKERRQ(ierr); 86 ierr = MatDenseGetLDA(B,&blda);CHKERRQ(ierr); 87 ierr = MatDenseGetArrayWrite(C,&Carray);CHKERRQ(ierr); 88 ierr = MatDenseGetLDA(C,&clda);CHKERRQ(ierr); 89 if (C->product->type == MATPRODUCT_AtB) { /* transpose local array of B, then copy it to vector bt */ 90 const PetscScalar *ctarray; 91 PetscScalar *btarray; 92 93 ierr = VecGetArrayWrite(bt,&btarray);CHKERRQ(ierr); 94 for (j=0; j<mdof; j++) { 95 for (i=0; i<m; i++) btarray[i*mdof + j] = Barray[j*blda + i]; 96 } 97 ierr = VecRestoreArrayWrite(bt,&btarray);CHKERRQ(ierr); 98 99 /* compute ct = mA^T * cb */ 100 ierr = MatMultTranspose(atb->mA,bt,ct);CHKERRQ(ierr); 101 102 /* transpose local array of ct to matrix C */ 103 ierr = VecGetArrayRead(ct,&ctarray);CHKERRQ(ierr); 104 for (j=0; j<mdof; j++) { 105 for (i=0; i<n; i++) Carray[j*clda + i] = ctarray[i*mdof + j]; 106 } 107 ierr = VecRestoreArrayRead(ct,&ctarray);CHKERRQ(ierr); 108 } else { 109 const PetscScalar *btarray; 110 PetscScalar *ctarray; 111 112 if (blda == B->rmap->n) { 113 ierr = VecPlaceArray(ct,Barray);CHKERRQ(ierr); 114 } else { 115 PetscInt bn = B->cmap->n; 116 PetscInt bm = B->rmap->n; 117 118 ierr = VecGetArrayWrite(ct,&ctarray);CHKERRQ(ierr); 119 for (j=0; j<bn; j++) { 120 for (i=0; i<bm; i++) ctarray[j*bm + i] = Barray[j*blda + i]; 121 } 122 ierr = VecRestoreArrayWrite(ct,&ctarray);CHKERRQ(ierr); 123 } 124 125 ierr = MatMult(atb->mA,ct,bt);CHKERRQ(ierr); 126 if (blda == B->rmap->n) { 127 ierr = VecResetArray(ct);CHKERRQ(ierr); 128 } 129 ierr = VecGetArrayRead(bt,&btarray);CHKERRQ(ierr); 130 for (j=0; j<mdof; j++) { 131 for (i=0; i<m; i++) Carray[j*clda + i] = btarray[i*mdof + j]; 132 } 133 ierr = VecRestoreArrayRead(bt,&btarray);CHKERRQ(ierr); 134 } 135 ierr = MatDenseRestoreArrayRead(B,&Barray);CHKERRQ(ierr); 136 ierr = MatDenseRestoreArray(C,&Carray);CHKERRQ(ierr); 137 PetscFunctionReturn(0); 138 } 139