1 2 /* 3 Defines matrix-matrix product routines for 4 C = A^T * B and C = A * B^t 5 with A SeqAIJ and B SeqDense 6 */ 7 8 #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/ 9 #include <../src/mat/impls/dense/seq/dense.h> 10 11 PetscErrorCode MatDestroy_SeqDense_MatTransMatMult(void *data) 12 { 13 Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data; 14 15 PetscFunctionBegin; 16 PetscCall(MatDestroy(&atb->mA)); 17 PetscCall(VecDestroy(&atb->bt)); 18 PetscCall(VecDestroy(&atb->ct)); 19 PetscCall(PetscFree(atb)); 20 PetscFunctionReturn(PETSC_SUCCESS); 21 } 22 23 static PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat, Mat, Mat); 24 25 PETSC_INTERN PetscErrorCode MatTMatTMultSymbolic_SeqAIJ_SeqDense(Mat A, Mat B, PetscReal fill, Mat C) 26 { 27 Mat_MatTransMatMult *atb; 28 PetscBool cisdense; 29 PetscInt dofm; 30 31 PetscFunctionBegin; 32 MatCheckProduct(C, 4); 33 PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty"); 34 PetscCheck(C->product->type == MATPRODUCT_ABt || C->product->type == MATPRODUCT_AtB, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[C->product->type]); 35 36 /* create output dense matrix C */ 37 if (C->product->type == MATPRODUCT_AtB) { 38 PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->N, A->cmap->n, B->cmap->N)); 39 dofm = B->cmap->n; 40 } else { 41 PetscCall(MatSetSizes(C, A->rmap->n, B->rmap->N, A->rmap->n, B->rmap->N)); 42 dofm = B->rmap->n; 43 } 44 PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATSEQDENSE, MATSEQDENSECUDA, "")); 45 if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name)); 46 PetscCall(MatSetUp(C)); 47 48 /* create additional data structure for the product */ 49 PetscCall(PetscNew(&atb)); 50 PetscCall(MatCreateMAIJ(A, dofm, &atb->mA)); 51 PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt)); 52 C->product->data = atb; 53 C->product->destroy = MatDestroy_SeqDense_MatTransMatMult; 54 55 if (C->product->type == MATPRODUCT_AtB) { 56 C->ops->transposematmultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense; 57 } else { 58 C->ops->mattransposemultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense; 59 } 60 PetscFunctionReturn(PETSC_SUCCESS); 61 } 62 63 PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat A, Mat B, Mat C) 64 { 65 PetscInt i, j, m = A->rmap->n, n = A->cmap->n, blda, clda; 66 PetscInt mdof = C->cmap->N; 67 const PetscScalar *Barray; 68 PetscScalar *Carray; 69 Mat_MatTransMatMult *atb; 70 Vec bt, ct; 71 72 PetscFunctionBegin; 73 MatCheckProduct(C, 3); 74 PetscCheck(C->product->type == MATPRODUCT_ABt || C->product->type == MATPRODUCT_AtB, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[C->product->type]); 75 atb = (Mat_MatTransMatMult *)C->product->data; 76 PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct"); 77 bt = atb->bt; 78 ct = atb->ct; 79 80 PetscCall(MatDenseGetArrayRead(B, &Barray)); 81 PetscCall(MatDenseGetLDA(B, &blda)); 82 PetscCall(MatDenseGetArrayWrite(C, &Carray)); 83 PetscCall(MatDenseGetLDA(C, &clda)); 84 if (C->product->type == MATPRODUCT_AtB) { /* transpose local array of B, then copy it to vector bt */ 85 const PetscScalar *ctarray; 86 PetscScalar *btarray; 87 88 PetscCall(VecGetArrayWrite(bt, &btarray)); 89 for (j = 0; j < mdof; j++) { 90 for (i = 0; i < m; i++) btarray[i * mdof + j] = Barray[j * blda + i]; 91 } 92 PetscCall(VecRestoreArrayWrite(bt, &btarray)); 93 94 /* compute ct = mA^T * cb */ 95 PetscCall(MatMultTranspose(atb->mA, bt, ct)); 96 97 /* transpose local array of ct to matrix C */ 98 PetscCall(VecGetArrayRead(ct, &ctarray)); 99 for (j = 0; j < mdof; j++) { 100 for (i = 0; i < n; i++) Carray[j * clda + i] = ctarray[i * mdof + j]; 101 } 102 PetscCall(VecRestoreArrayRead(ct, &ctarray)); 103 } else { 104 const PetscScalar *btarray; 105 PetscScalar *ctarray; 106 107 if (blda == B->rmap->n) { 108 PetscCall(VecPlaceArray(ct, Barray)); 109 } else { 110 PetscInt bn = B->cmap->n; 111 PetscInt bm = B->rmap->n; 112 113 PetscCall(VecGetArrayWrite(ct, &ctarray)); 114 for (j = 0; j < bn; j++) { 115 for (i = 0; i < bm; i++) ctarray[j * bm + i] = Barray[j * blda + i]; 116 } 117 PetscCall(VecRestoreArrayWrite(ct, &ctarray)); 118 } 119 120 PetscCall(MatMult(atb->mA, ct, bt)); 121 if (blda == B->rmap->n) PetscCall(VecResetArray(ct)); 122 PetscCall(VecGetArrayRead(bt, &btarray)); 123 for (j = 0; j < mdof; j++) { 124 for (i = 0; i < m; i++) Carray[j * clda + i] = btarray[i * mdof + j]; 125 } 126 PetscCall(VecRestoreArrayRead(bt, &btarray)); 127 } 128 PetscCall(MatDenseRestoreArrayRead(B, &Barray)); 129 PetscCall(MatDenseRestoreArray(C, &Carray)); 130 PetscFunctionReturn(PETSC_SUCCESS); 131 } 132