1 /* 2 Defines matrix-matrix product routines for 3 C = A^T * B and C = A * B^t 4 with A SeqAIJ and B SeqDense 5 */ 6 7 #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/ 8 #include <../src/mat/impls/dense/seq/dense.h> 9 10 static PetscErrorCode MatDestroy_SeqDense_MatTransMatMult(void *data) 11 { 12 Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data; 13 14 PetscFunctionBegin; 15 PetscCall(MatDestroy(&atb->mA)); 16 PetscCall(VecDestroy(&atb->bt)); 17 PetscCall(VecDestroy(&atb->ct)); 18 PetscCall(PetscFree(atb)); 19 PetscFunctionReturn(PETSC_SUCCESS); 20 } 21 22 static PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat, Mat, Mat); 23 24 PETSC_INTERN PetscErrorCode MatTMatTMultSymbolic_SeqAIJ_SeqDense(Mat A, Mat B, PetscReal fill, Mat C) 25 { 26 Mat_MatTransMatMult *atb; 27 PetscBool cisdense; 28 PetscInt dofm; 29 30 PetscFunctionBegin; 31 MatCheckProduct(C, 4); 32 PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty"); 33 PetscCheck(C->product->type == MATPRODUCT_ABt || C->product->type == MATPRODUCT_AtB, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[C->product->type]); 34 35 /* create output dense matrix C */ 36 if (C->product->type == MATPRODUCT_AtB) { 37 PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->N, A->cmap->n, B->cmap->N)); 38 dofm = B->cmap->n; 39 } else { 40 PetscCall(MatSetSizes(C, A->rmap->n, B->rmap->N, A->rmap->n, B->rmap->N)); 41 dofm = B->rmap->n; 42 } 43 PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATSEQDENSE, MATSEQDENSECUDA, "")); 44 if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name)); 45 PetscCall(MatSetUp(C)); 46 47 /* create additional data structure for the product */ 48 PetscCall(PetscNew(&atb)); 49 PetscCall(MatCreateMAIJ(A, dofm, &atb->mA)); 50 PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt)); 51 C->product->data = atb; 52 C->product->destroy = MatDestroy_SeqDense_MatTransMatMult; 53 54 if (C->product->type == MATPRODUCT_AtB) { 55 C->ops->transposematmultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense; 56 } else { 57 C->ops->mattransposemultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense; 58 } 59 PetscFunctionReturn(PETSC_SUCCESS); 60 } 61 62 static PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat A, Mat B, Mat C) 63 { 64 PetscInt i, j, m = A->rmap->n, n = A->cmap->n, blda, clda; 65 PetscInt mdof = C->cmap->N; 66 const PetscScalar *Barray; 67 PetscScalar *Carray; 68 Mat_MatTransMatMult *atb; 69 Vec bt, ct; 70 71 PetscFunctionBegin; 72 MatCheckProduct(C, 3); 73 PetscCheck(C->product->type == MATPRODUCT_ABt || C->product->type == MATPRODUCT_AtB, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[C->product->type]); 74 atb = (Mat_MatTransMatMult *)C->product->data; 75 PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct"); 76 bt = atb->bt; 77 ct = atb->ct; 78 79 PetscCall(MatDenseGetArrayRead(B, &Barray)); 80 PetscCall(MatDenseGetLDA(B, &blda)); 81 PetscCall(MatDenseGetArrayWrite(C, &Carray)); 82 PetscCall(MatDenseGetLDA(C, &clda)); 83 if (C->product->type == MATPRODUCT_AtB) { /* transpose local array of B, then copy it to vector bt */ 84 const PetscScalar *ctarray; 85 PetscScalar *btarray; 86 87 PetscCall(VecGetArrayWrite(bt, &btarray)); 88 for (j = 0; j < mdof; j++) { 89 for (i = 0; i < m; i++) btarray[i * mdof + j] = Barray[j * blda + i]; 90 } 91 PetscCall(VecRestoreArrayWrite(bt, &btarray)); 92 93 /* compute ct = mA^T * cb */ 94 PetscCall(MatMultTranspose(atb->mA, bt, ct)); 95 96 /* transpose local array of ct to matrix C */ 97 PetscCall(VecGetArrayRead(ct, &ctarray)); 98 for (j = 0; j < mdof; j++) { 99 for (i = 0; i < n; i++) Carray[j * clda + i] = ctarray[i * mdof + j]; 100 } 101 PetscCall(VecRestoreArrayRead(ct, &ctarray)); 102 } else { 103 const PetscScalar *btarray; 104 PetscScalar *ctarray; 105 106 if (blda == B->rmap->n) { 107 PetscCall(VecPlaceArray(ct, Barray)); 108 } else { 109 PetscInt bn = B->cmap->n; 110 PetscInt bm = B->rmap->n; 111 112 PetscCall(VecGetArrayWrite(ct, &ctarray)); 113 for (j = 0; j < bn; j++) { 114 for (i = 0; i < bm; i++) ctarray[j * bm + i] = Barray[j * blda + i]; 115 } 116 PetscCall(VecRestoreArrayWrite(ct, &ctarray)); 117 } 118 119 PetscCall(MatMult(atb->mA, ct, bt)); 120 if (blda == B->rmap->n) PetscCall(VecResetArray(ct)); 121 PetscCall(VecGetArrayRead(bt, &btarray)); 122 for (j = 0; j < mdof; j++) { 123 for (i = 0; i < m; i++) Carray[j * clda + i] = btarray[i * mdof + j]; 124 } 125 PetscCall(VecRestoreArrayRead(bt, &btarray)); 126 } 127 PetscCall(MatDenseRestoreArrayRead(B, &Barray)); 128 PetscCall(MatDenseRestoreArray(C, &Carray)); 129 PetscFunctionReturn(PETSC_SUCCESS); 130 } 131