/* xref: /petsc/src/vec/vec/interface/vecreg.c (revision 03047865b8d8757cf1cf9cda45785c1537b01dc1) */
1 #include <petsc/private/vecimpl.h> /*I "petscvec.h"  I*/
2 
3 PetscFunctionList VecList = NULL;
4 
5 /* compare a vector type against a list of target vector types */
VecTypeCompareAny_Private(VecType srcType,PetscBool * match,const char tgtTypes[],...)6 static inline PetscErrorCode VecTypeCompareAny_Private(VecType srcType, PetscBool *match, const char tgtTypes[], ...)
7 {
8   PetscBool flg = PETSC_FALSE;
9   va_list   Argp;
10 
11   PetscFunctionBegin;
12   PetscAssertPointer(match, 2);
13   *match = PETSC_FALSE;
14   va_start(Argp, tgtTypes);
15   while (tgtTypes && tgtTypes[0]) {
16     PetscCall(PetscStrcmp(srcType, tgtTypes, &flg));
17     if (flg) {
18       *match = PETSC_TRUE;
19       break;
20     }
21     tgtTypes = va_arg(Argp, const char *);
22   }
23   va_end(Argp);
24   PetscFunctionReturn(PETSC_SUCCESS);
25 }
26 
27 #define PETSC_MAX_VECTYPE_LEN 64
28 
29 /*@
30   VecSetType - Builds a vector, for a particular vector implementation.
31 
32   Collective
33 
34   Input Parameters:
35 + vec     - The vector object
36 - newType - The name of the vector type
37 
38   Options Database Key:
39 . -vec_type <type> - Sets the vector type; use -help for a list
40                      of available types
41 
42   Level: intermediate
43 
44   Notes:
45   See `VecType` for available vector types (for instance, `VECSEQ` or `VECMPI`)
46   Changing a vector to a new type will retain its old value if any.
47 
48   Use `VecDuplicate()` or `VecDuplicateVecs()` to form additional vectors of the same type as an existing vector.
49 
50 .seealso: [](ch_vectors), `Vec`, `VecType`, `VecGetType()`, `VecCreate()`, `VecDuplicate()`, `VecDuplicateVecs()`
51 @*/
VecSetType(Vec vec,VecType newType)52 PetscErrorCode VecSetType(Vec vec, VecType newType)
53 {
54   PetscErrorCode (*r)(Vec);
55   VecType      curType;
56   PetscBool    match;
57   PetscMPIInt  size;
58   PetscBool    dstSeq = PETSC_FALSE; // type info of the new type
59   MPI_Comm     comm;
60   char         seqType[PETSC_MAX_VECTYPE_LEN] = {0};
61   char         mpiType[PETSC_MAX_VECTYPE_LEN] = {0};
62   PetscScalar *oldValue;
63   PetscBool    srcStandard, dstStandard;
64 
65   PetscFunctionBegin;
66   PetscValidHeaderSpecific(vec, VEC_CLASSID, 1);
67 
68   PetscCall(VecGetType(vec, &curType));
69   if (!curType) goto newvec; // vec's type is not set yet
70 
71   /* return if exactly the same type */
72   PetscCall(PetscObjectTypeCompare((PetscObject)vec, newType, &match));
73   if (match) PetscFunctionReturn(PETSC_SUCCESS);
74 
75   /* error on illegal mpi to seq conversion */
76   PetscCall(PetscObjectGetComm((PetscObject)vec, &comm));
77   PetscCallMPI(MPI_Comm_size(comm, &size));
78 
79   PetscCall(PetscStrbeginswith(newType, VECSEQ, &dstSeq));
80   PetscCheck(!(size > 1 && dstSeq), comm, PETSC_ERR_ARG_WRONG, "Cannot convert MPI vectors to sequential ones");
81 
82   /* return if standard => standard */
83   if (size == 1) PetscCall(PetscObjectTypeCompare((PetscObject)vec, VECSEQ, &srcStandard));
84   else PetscCall(PetscObjectTypeCompare((PetscObject)vec, VECMPI, &srcStandard));
85   PetscCall(VecTypeCompareAny_Private(newType, &dstStandard, VECSTANDARD, VECSEQ, VECMPI, ""));
86   if (srcStandard && dstStandard) PetscFunctionReturn(PETSC_SUCCESS);
87 
88   /* return if curType = "seq" | "mpi" + newType */
89   PetscCall(PetscStrncpy(mpiType, "mpi", 4));
90   PetscCall(PetscStrlcat(mpiType, newType, PETSC_MAX_VECTYPE_LEN));
91   PetscCall(PetscStrncpy(seqType, "seq", 4));
92   PetscCall(PetscStrlcat(seqType, newType, PETSC_MAX_VECTYPE_LEN));
93   PetscCall(PetscObjectTypeCompareAny((PetscObject)vec, &match, seqType, mpiType, ""));
94   if (match) PetscFunctionReturn(PETSC_SUCCESS);
95 
96   /* downcast VECSTANDARD to VECCUDA/HIP/KOKKOS in place. We don't do in-place upcasting
97   for those vectors. At least, it is not always possible to upcast a VECCUDA to VECSTANDARD
98   in place, since the host array might be pinned (i.e., allocated by cudaMallocHost()). If
99   we upcast it to VECSTANDARD, we could not free the pinned array with PetscFree(), which
100   is assumed for VECSTANDARD. Thus we just create a new vector, though it is expensive.
101   Upcasting is rare and users are not recommended to use it.
102   */
103 #if defined(PETSC_HAVE_CUDA)
104   {
105     PetscBool dstCUDA = PETSC_FALSE;
106     if (!dstStandard) PetscCall(VecTypeCompareAny_Private(newType, &dstCUDA, VECCUDA, VECSEQCUDA, VECMPICUDA, ""));
107     if (srcStandard && dstCUDA) {
108       if (size == 1) PetscCall(VecConvert_Seq_SeqCUDA_inplace(vec));
109       else PetscCall(VecConvert_MPI_MPICUDA_inplace(vec));
110       PetscFunctionReturn(PETSC_SUCCESS);
111     }
112   }
113 #endif
114 #if defined(PETSC_HAVE_HIP)
115   {
116     PetscBool dstHIP = PETSC_FALSE;
117     if (!dstStandard) PetscCall(VecTypeCompareAny_Private(newType, &dstHIP, VECHIP, VECSEQHIP, VECMPIHIP, ""));
118     if (srcStandard && dstHIP) {
119       if (size == 1) PetscCall(VecConvert_Seq_SeqHIP_inplace(vec));
120       else PetscCall(VecConvert_MPI_MPIHIP_inplace(vec));
121       PetscFunctionReturn(PETSC_SUCCESS);
122     }
123   }
124 #endif
125 #if defined(PETSC_HAVE_KOKKOS_KERNELS)
126   {
127     PetscBool dstKokkos = PETSC_FALSE;
128     if (!dstStandard) PetscCall(VecTypeCompareAny_Private(newType, &dstKokkos, VECKOKKOS, VECSEQKOKKOS, VECMPIKOKKOS, ""));
129     if (srcStandard && dstKokkos) {
130       if (size == 1) PetscCall(VecConvert_Seq_SeqKokkos_inplace(vec));
131       else PetscCall(VecConvert_MPI_MPIKokkos_inplace(vec));
132       PetscFunctionReturn(PETSC_SUCCESS);
133     }
134   }
135 #endif
136 
137   /* Other conversion scenarios: create a new vector but retain old value */
138 newvec:
139   PetscCall(PetscFunctionListFind(VecList, newType, &r));
140   PetscCheck(r, PetscObjectComm((PetscObject)vec), PETSC_ERR_ARG_UNKNOWN_TYPE, "Unknown vector type: %s", newType);
141   if (curType) { /* no need to destroy a vec without type */
142     const PetscScalar *array;
143     PetscCall(VecGetArrayRead(vec, &array));
144     if (array) {                                       /* record the old value if any before destroy */
145       PetscCall(PetscMalloc1(vec->map->n, &oldValue)); /* no need to free since we'll drop it into vec */
146       PetscCall(PetscArraycpy(oldValue, array, vec->map->n));
147     } else {
148       oldValue = NULL;
149     }
150     PetscCall(VecRestoreArrayRead(vec, &array));
151     PetscTryTypeMethod(vec, destroy);
152     PetscCall(PetscMemzero(vec->ops, sizeof(struct _VecOps)));
153     PetscCall(PetscFree(vec->defaultrandtype));
154     PetscCall(PetscFree(((PetscObject)vec)->type_name)); /* free type_name to make vec clean to use, as we might call VecSetType() again */
155   }
156 
157   if (vec->map->n < 0 && vec->map->N < 0) {
158     vec->ops->create = r;
159     vec->ops->load   = VecLoad_Default;
160   } else {
161     PetscCall((*r)(vec));
162   }
163 
164   /* drop in the old value */
165   if (curType && vec->map->n) PetscCall(VecReplaceArray(vec, oldValue));
166   PetscFunctionReturn(PETSC_SUCCESS);
167 }
168 
169 /*@
170   VecGetType - Gets the vector type name (as a string) from a `Vec`.
171 
172   Not Collective
173 
174   Input Parameter:
175 . vec - The vector
176 
177   Output Parameter:
178 . type - The `VecType` of the vector
179 
180   Level: intermediate
181 
182 .seealso: [](ch_vectors), `Vec`, `VecType`, `VecCreate()`, `VecDuplicate()`, `VecDuplicateVecs()`
183 @*/
VecGetType(Vec vec,VecType * type)184 PetscErrorCode VecGetType(Vec vec, VecType *type)
185 {
186   PetscFunctionBegin;
187   PetscValidHeaderSpecific(vec, VEC_CLASSID, 1);
188   PetscAssertPointer(type, 2);
189   PetscCall(VecRegisterAll());
190   *type = ((PetscObject)vec)->type_name;
191   PetscFunctionReturn(PETSC_SUCCESS);
192 }
193 
VecGetRootType_Private(Vec vec,VecType * vtype)194 PetscErrorCode VecGetRootType_Private(Vec vec, VecType *vtype)
195 {
196   PetscBool iscuda, iship, iskokkos, isvcl;
197 
198   PetscFunctionBegin;
199   PetscValidHeaderSpecific(vec, VEC_CLASSID, 1);
200   PetscAssertPointer(vtype, 2);
201   PetscCall(PetscObjectTypeCompareAny((PetscObject)vec, &iscuda, VECCUDA, VECMPICUDA, VECSEQCUDA, ""));
202   PetscCall(PetscObjectTypeCompareAny((PetscObject)vec, &iship, VECHIP, VECMPIHIP, VECSEQHIP, ""));
203   PetscCall(PetscObjectTypeCompareAny((PetscObject)vec, &iskokkos, VECKOKKOS, VECMPIKOKKOS, VECSEQKOKKOS, ""));
204   PetscCall(PetscObjectTypeCompareAny((PetscObject)vec, &isvcl, VECVIENNACL, VECMPIVIENNACL, VECSEQVIENNACL, ""));
205   if (iscuda) {
206     *vtype = VECCUDA;
207   } else if (iship) {
208     *vtype = VECHIP;
209   } else if (iskokkos) {
210     *vtype = VECKOKKOS;
211   } else if (isvcl) {
212     *vtype = VECVIENNACL;
213   } else {
214     *vtype = VECSTANDARD;
215   }
216   PetscFunctionReturn(PETSC_SUCCESS);
217 }
218 
219 /*@C
220   VecRegister -  Adds a new vector component implementation
221 
222   Not Collective, No Fortran Support
223 
224   Input Parameters:
225 + sname    - The name of a new user-defined creation routine
226 - function - The creation routine
227 
228   Notes:
229   `VecRegister()` may be called multiple times to add several user-defined vectors
230 
231   Example Usage:
232 .vb
233     VecRegister("my_vec",MyVectorCreate);
234 .ve
235 
236   Then, your vector type can be chosen with the procedural interface via
237 .vb
238     VecCreate(MPI_Comm, Vec *);
239     VecSetType(Vec,"my_vector_name");
240 .ve
241   or at runtime via the option
242 .vb
243     -vec_type my_vector_name
244 .ve
245 
246   Level: advanced
247 
248 .seealso: `VecRegisterAll()`, `VecRegisterDestroy()`
249 @*/
VecRegister(const char sname[],PetscErrorCode (* function)(Vec))250 PetscErrorCode VecRegister(const char sname[], PetscErrorCode (*function)(Vec))
251 {
252   PetscFunctionBegin;
253   PetscCall(VecInitializePackage());
254   PetscCall(PetscFunctionListAdd(&VecList, sname, function));
255   PetscFunctionReturn(PETSC_SUCCESS);
256 }
257