xref: /petsc/src/dm/impls/plex/kokkos/plexlocalizationletkf.kokkos.cxx (revision 933231d81884bfe7bea4e44a0bd9daf9c602c7da)
1828beda2SMark Adams #include <petsc/private/dmpleximpl.h>
2828beda2SMark Adams #include <petscdmplex.h>
3828beda2SMark Adams #include <petscmat.h>
4828beda2SMark Adams #include <petsc_kokkos.hpp>
5828beda2SMark Adams #include <cmath>
6828beda2SMark Adams #include <cstdlib>
7828beda2SMark Adams #include <algorithm>
8828beda2SMark Adams #include <Kokkos_Core.hpp>
9828beda2SMark Adams 
10828beda2SMark Adams typedef struct {
11828beda2SMark Adams   PetscReal distance;
12828beda2SMark Adams   PetscInt  obs_index;
13828beda2SMark Adams } DistObsPair;
14828beda2SMark Adams 
15828beda2SMark Adams KOKKOS_INLINE_FUNCTION
GaspariCohn(PetscReal distance,PetscReal radius)16828beda2SMark Adams static PetscReal GaspariCohn(PetscReal distance, PetscReal radius)
17828beda2SMark Adams {
18828beda2SMark Adams   if (radius <= 0.0) return 0.0;
19*ee102026SMark Adams   const PetscReal r = distance / radius;
20828beda2SMark Adams 
21*ee102026SMark Adams   if (r >= 2.0) return 0.0;
22*ee102026SMark Adams 
23*ee102026SMark Adams   const PetscReal r2 = r * r;
24*ee102026SMark Adams   const PetscReal r3 = r2 * r;
25*ee102026SMark Adams   const PetscReal r4 = r3 * r;
26*ee102026SMark Adams   const PetscReal r5 = r4 * r;
27*ee102026SMark Adams 
28*ee102026SMark Adams   if (r <= 1.0) {
29828beda2SMark Adams     // Region [0, 1]
30*ee102026SMark Adams     return -0.25 * r5 + 0.5 * r4 + 0.625 * r3 - (5.0 / 3.0) * r2 + 1.0;
31*ee102026SMark Adams   } else {
32*ee102026SMark Adams     // Region [1, 2]
33*ee102026SMark Adams     return (1.0 / 12.0) * r5 - 0.5 * r4 + 0.625 * r3 + (5.0 / 3.0) * r2 - 5.0 * r + 4.0 - (2.0 / 3.0) / r;
34828beda2SMark Adams   }
35828beda2SMark Adams }
36828beda2SMark Adams 
37828beda2SMark Adams /*@
38*ee102026SMark Adams   DMPlexGetLETKFLocalizationMatrix - Compute localization weight matrix for LETKF [move to ml/da/interface]
39828beda2SMark Adams 
40828beda2SMark Adams   Collective
41828beda2SMark Adams 
42828beda2SMark Adams   Input Parameters:
43*ee102026SMark Adams + n_obs_vertex - Number of nearest observations to use per vertex (eg, MAX_Q_NUM_LOCAL_OBSERVATIONS in LETKF)
44*ee102026SMark Adams . n_obs_local - Number of local observations
45*ee102026SMark Adams . n_dof - Number of degrees of freedom
46*ee102026SMark Adams . Vecxyz - Array of vectors containing the coordinates
47828beda2SMark Adams - H - Observation operator matrix
48828beda2SMark Adams 
49828beda2SMark Adams   Output Parameter:
50828beda2SMark Adams . Q - Localization weight matrix (sparse, AIJ format)
51828beda2SMark Adams 
52828beda2SMark Adams   Notes:
53*ee102026SMark Adams   The output matrix Q has dimensions (n_vert_global x n_obs_global) where
54*ee102026SMark Adams   n_vert_global is the number of vertices in the DMPlex. Each row contains
55*ee102026SMark Adams   exactly n_obs_vertex non-zero entries corresponding to the nearest
56828beda2SMark Adams   observations, weighted by the Gaspari-Cohn fifth-order piecewise
57828beda2SMark Adams   rational function.
58828beda2SMark Adams 
59828beda2SMark Adams   The observation locations are computed as H * V where V is the vector
60828beda2SMark Adams   of vertex coordinates. The localization weights ensure smooth tapering
61828beda2SMark Adams   of observation influence with distance.
62828beda2SMark Adams 
63*ee102026SMark Adams   Kokkos is required for this routine.
64828beda2SMark Adams 
65828beda2SMark Adams   Level: intermediate
66828beda2SMark Adams 
67*ee102026SMark Adams .seealso:
68828beda2SMark Adams @*/
DMPlexGetLETKFLocalizationMatrix(const PetscInt n_obs_vertex,const PetscInt n_obs_local,const PetscInt n_dof,Vec Vecxyz[3],Mat H,Mat * Q)69*ee102026SMark Adams PetscErrorCode DMPlexGetLETKFLocalizationMatrix(const PetscInt n_obs_vertex, const PetscInt n_obs_local, const PetscInt n_dof, Vec Vecxyz[3], Mat H, Mat *Q)
70828beda2SMark Adams {
71*ee102026SMark Adams   PetscInt dim = 0, n_vert_local, d, N, n_obs_global, n_state_local;
72828beda2SMark Adams   Vec     *obs_vecs;
73828beda2SMark Adams   MPI_Comm comm;
74*ee102026SMark Adams   PetscInt n_state_global;
75828beda2SMark Adams 
76828beda2SMark Adams   PetscFunctionBegin;
77*ee102026SMark Adams   PetscValidHeaderSpecific(H, MAT_CLASSID, 5);
78*ee102026SMark Adams   PetscAssertPointer(Q, 6);
79828beda2SMark Adams 
80828beda2SMark Adams   PetscCall(PetscKokkosInitializeCheck());
81828beda2SMark Adams 
82*ee102026SMark Adams   PetscCall(PetscObjectGetComm((PetscObject)H, &comm));
83*ee102026SMark Adams 
84*ee102026SMark Adams   /* Infer dim from the number of vectors in Vecxyz */
85*ee102026SMark Adams   for (d = 0; d < 3; ++d) {
86*ee102026SMark Adams     if (Vecxyz[d]) dim++;
87*ee102026SMark Adams     else break;
88*ee102026SMark Adams   }
89*ee102026SMark Adams 
90*ee102026SMark Adams   PetscCheck(dim > 0, comm, PETSC_ERR_ARG_WRONG, "Dim must be > 0");
91*ee102026SMark Adams   PetscCheck(n_obs_vertex > 0, comm, PETSC_ERR_ARG_WRONG, "n_obs_vertex must be > 0");
92*ee102026SMark Adams 
93*ee102026SMark Adams   PetscCall(VecGetSize(Vecxyz[0], &n_state_global));
94*ee102026SMark Adams   PetscCall(VecGetLocalSize(Vecxyz[0], &n_state_local));
95*ee102026SMark Adams   n_vert_local = n_state_local / n_dof;
96828beda2SMark Adams 
97828beda2SMark Adams   /* Check H dimensions */
98*ee102026SMark Adams   PetscCall(MatGetSize(H, &n_obs_global, &N));
99*ee102026SMark Adams   PetscCheck(N == n_state_global, comm, PETSC_ERR_ARG_SIZ, "H number of columns %" PetscInt_FMT " != global state size %" PetscInt_FMT, N, n_state_global);
100*ee102026SMark Adams   // If n_obs_global < n_obs_vertex, we will pad with -1 indices and 0.0 weights.
101*ee102026SMark Adams   // This is not an error condition, but rather a case where we have fewer observations than requested neighbors.
102828beda2SMark Adams 
103828beda2SMark Adams   /* Allocate storage for observation locations */
104828beda2SMark Adams   PetscCall(PetscMalloc1(dim, &obs_vecs));
105828beda2SMark Adams 
106828beda2SMark Adams   /* Compute observation locations per dimension */
107828beda2SMark Adams   for (d = 0; d < dim; ++d) {
108*ee102026SMark Adams     PetscCall(MatCreateVecs(H, NULL, &obs_vecs[d]));
109*ee102026SMark Adams     PetscCall(MatMult(H, Vecxyz[d], obs_vecs[d]));
110828beda2SMark Adams   }
111828beda2SMark Adams 
112*ee102026SMark Adams   /* Create output matrix Q in N/n_dof x P */
113828beda2SMark Adams   PetscCall(MatCreate(comm, Q));
114*ee102026SMark Adams   PetscCall(MatSetSizes(*Q, n_vert_local, n_obs_local, PETSC_DETERMINE, n_obs_global));
115*ee102026SMark Adams   PetscCall(MatSetType(*Q, MATAIJ));
116*ee102026SMark Adams   PetscCall(MatSeqAIJSetPreallocation(*Q, n_obs_vertex, NULL));
117*ee102026SMark Adams   PetscCall(MatMPIAIJSetPreallocation(*Q, n_obs_vertex, NULL, n_obs_vertex, NULL));
118*ee102026SMark Adams   PetscCall(MatSetFromOptions(*Q));
119828beda2SMark Adams   PetscCall(MatSetUp(*Q));
120828beda2SMark Adams 
121*ee102026SMark Adams   PetscCall(PetscInfo((PetscObject)*Q, "Computing LETKF localization matrix: %" PetscInt_FMT " vertices, %" PetscInt_FMT " observations, %" PetscInt_FMT " neighbors\n", n_vert_local, n_obs_global, n_obs_vertex));
122*ee102026SMark Adams 
123828beda2SMark Adams   /* Prepare Kokkos Views */
124828beda2SMark Adams   using ExecSpace = Kokkos::DefaultExecutionSpace;
125828beda2SMark Adams   using MemSpace  = ExecSpace::memory_space;
126828beda2SMark Adams 
127828beda2SMark Adams   /* Vertex Coordinates */
128*ee102026SMark Adams   // Use LayoutLeft for coalesced access on GPU (i is contiguous)
129*ee102026SMark Adams   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, MemSpace> vertex_coords_dev("vertex_coords", n_vert_local, dim);
130828beda2SMark Adams   {
131*ee102026SMark Adams     // Host view must match the data layout from VecGetArray (d-major, i-minor implies LayoutLeft for (i,d) view)
132*ee102026SMark Adams     Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace> vertex_coords_host("vertex_coords_host", n_vert_local, dim);
133*ee102026SMark Adams     for (d = 0; d < dim; ++d) {
134*ee102026SMark Adams       const PetscScalar *local_coords_array;
135*ee102026SMark Adams       PetscCall(VecGetArrayRead(Vecxyz[d], &local_coords_array));
136*ee102026SMark Adams       // Copy data. Since vertex_coords_host is LayoutLeft, &vertex_coords_host(0, d) is the start of column d.
137*ee102026SMark Adams       for (PetscInt i = 0; i < n_vert_local; ++i) vertex_coords_host(i, d) = local_coords_array[i];
138*ee102026SMark Adams       PetscCall(VecRestoreArrayRead(Vecxyz[d], &local_coords_array));
139828beda2SMark Adams     }
140828beda2SMark Adams     Kokkos::deep_copy(vertex_coords_dev, vertex_coords_host);
141828beda2SMark Adams   }
142828beda2SMark Adams 
143828beda2SMark Adams   /* Observation Coordinates */
144*ee102026SMark Adams   Kokkos::View<PetscReal **, Kokkos::LayoutRight, MemSpace> obs_coords_dev("obs_coords", n_obs_global, dim);
145828beda2SMark Adams   {
146*ee102026SMark Adams     Kokkos::View<PetscReal **, Kokkos::LayoutRight, Kokkos::HostSpace> obs_coords_host("obs_coords_host", n_obs_global, dim);
147*ee102026SMark Adams     for (d = 0; d < dim; ++d) {
148*ee102026SMark Adams       VecScatter         ctx;
149*ee102026SMark Adams       Vec                seq_vec;
150*ee102026SMark Adams       const PetscScalar *array;
151*ee102026SMark Adams 
152*ee102026SMark Adams       PetscCall(VecScatterCreateToAll(obs_vecs[d], &ctx, &seq_vec));
153*ee102026SMark Adams       PetscCall(VecScatterBegin(ctx, obs_vecs[d], seq_vec, INSERT_VALUES, SCATTER_FORWARD));
154*ee102026SMark Adams       PetscCall(VecScatterEnd(ctx, obs_vecs[d], seq_vec, INSERT_VALUES, SCATTER_FORWARD));
155*ee102026SMark Adams 
156*ee102026SMark Adams       PetscCall(VecGetArrayRead(seq_vec, &array));
157*ee102026SMark Adams       for (PetscInt j = 0; j < n_obs_global; ++j) obs_coords_host(j, d) = PetscRealPart(array[j]);
158*ee102026SMark Adams       PetscCall(VecRestoreArrayRead(seq_vec, &array));
159*ee102026SMark Adams       PetscCall(VecScatterDestroy(&ctx));
160*ee102026SMark Adams       PetscCall(VecDestroy(&seq_vec));
161828beda2SMark Adams     }
162828beda2SMark Adams     Kokkos::deep_copy(obs_coords_dev, obs_coords_host);
163828beda2SMark Adams   }
164828beda2SMark Adams 
165*ee102026SMark Adams   PetscInt rstart;
166*ee102026SMark Adams   PetscCall(VecGetOwnershipRange(Vecxyz[0], &rstart, NULL));
167828beda2SMark Adams 
168828beda2SMark Adams   /* Output Views */
169*ee102026SMark Adams   // LayoutLeft for coalesced access on GPU
170*ee102026SMark Adams   Kokkos::View<PetscInt **, Kokkos::LayoutLeft, MemSpace>    indices_dev("indices", n_vert_local, n_obs_vertex);
171*ee102026SMark Adams   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, MemSpace> values_dev("values", n_vert_local, n_obs_vertex);
172828beda2SMark Adams 
173828beda2SMark Adams   /* Temporary storage for top-k per vertex */
174*ee102026SMark Adams   // LayoutLeft for coalesced access on GPU.
175*ee102026SMark Adams   // Note: For the insertion sort within a thread, LayoutRight would offer better cache locality for the thread's private list.
176*ee102026SMark Adams   // However, LayoutLeft is preferred for coalesced access across threads during the final weight computation and initialization.
177*ee102026SMark Adams   // Given the random access nature of the sort (divergence), we stick to the default GPU layout (Left).
178*ee102026SMark Adams   Kokkos::View<PetscReal **, Kokkos::LayoutLeft, MemSpace> best_dists_dev("best_dists", n_vert_local, n_obs_vertex);
179*ee102026SMark Adams   Kokkos::View<PetscInt **, Kokkos::LayoutLeft, MemSpace>  best_idxs_dev("best_idxs", n_vert_local, n_obs_vertex);
180828beda2SMark Adams 
181828beda2SMark Adams   /* Main Kernel */
182828beda2SMark Adams   Kokkos::parallel_for(
183*ee102026SMark Adams     "ComputeLocalization", Kokkos::RangePolicy<ExecSpace>(0, n_vert_local), KOKKOS_LAMBDA(const PetscInt i) {
184*ee102026SMark Adams       PetscReal current_max_dist = PETSC_MAX_REAL;
185*ee102026SMark Adams 
186*ee102026SMark Adams       // Cache vertex coordinates in registers to avoid repeated global memory access
187*ee102026SMark Adams       // dim is small (<= 3), so this fits easily in registers
188*ee102026SMark Adams       PetscReal v_coords[3] = {0.0, 0.0, 0.0};
189*ee102026SMark Adams       for (PetscInt d = 0; d < dim; ++d) v_coords[d] = PetscRealPart(vertex_coords_dev(i, d));
190*ee102026SMark Adams 
191*ee102026SMark Adams       // Initialize with infinity
192*ee102026SMark Adams       for (PetscInt k = 0; k < n_obs_vertex; ++k) {
193*ee102026SMark Adams         best_dists_dev(i, k) = PETSC_MAX_REAL;
194*ee102026SMark Adams         best_idxs_dev(i, k)  = -1;
195*ee102026SMark Adams       }
196828beda2SMark Adams 
197828beda2SMark Adams       // Iterate over all observations
198*ee102026SMark Adams       for (PetscInt j = 0; j < n_obs_global; ++j) {
199828beda2SMark Adams         PetscReal dist2 = 0.0;
200828beda2SMark Adams         for (PetscInt d = 0; d < dim; ++d) {
201*ee102026SMark Adams           PetscReal diff = v_coords[d] - obs_coords_dev(j, d);
202828beda2SMark Adams           dist2 += diff * diff;
203828beda2SMark Adams         }
204828beda2SMark Adams 
205*ee102026SMark Adams         // Check if this observation is closer than the furthest stored observation
206*ee102026SMark Adams         if (dist2 < current_max_dist) {
207828beda2SMark Adams           // Insert sorted
208*ee102026SMark Adams           PetscInt pos = n_obs_vertex - 1;
209828beda2SMark Adams           while (pos > 0 && best_dists_dev(i, pos - 1) > dist2) {
210828beda2SMark Adams             best_dists_dev(i, pos) = best_dists_dev(i, pos - 1);
211828beda2SMark Adams             best_idxs_dev(i, pos)  = best_idxs_dev(i, pos - 1);
212828beda2SMark Adams             pos--;
213828beda2SMark Adams           }
214828beda2SMark Adams           best_dists_dev(i, pos) = dist2;
215828beda2SMark Adams           best_idxs_dev(i, pos)  = j;
216*ee102026SMark Adams 
217*ee102026SMark Adams           // Update current max distance
218*ee102026SMark Adams           current_max_dist = best_dists_dev(i, n_obs_vertex - 1);
219828beda2SMark Adams         }
220828beda2SMark Adams       }
221828beda2SMark Adams 
222828beda2SMark Adams       // Compute weights
223*ee102026SMark Adams       PetscReal radius2 = best_dists_dev(i, n_obs_vertex - 1);
224828beda2SMark Adams       PetscReal radius  = std::sqrt(radius2);
225828beda2SMark Adams       if (radius == 0.0) radius = 1.0;
226828beda2SMark Adams 
227*ee102026SMark Adams       for (PetscInt k = 0; k < n_obs_vertex; ++k) {
228*ee102026SMark Adams         if (best_idxs_dev(i, k) != -1) {
229828beda2SMark Adams           PetscReal dist    = std::sqrt(best_dists_dev(i, k));
230828beda2SMark Adams           indices_dev(i, k) = best_idxs_dev(i, k);
231828beda2SMark Adams           values_dev(i, k)  = GaspariCohn(dist, radius);
232*ee102026SMark Adams         } else {
233*ee102026SMark Adams           indices_dev(i, k) = -1; // Ignore this entry
234*ee102026SMark Adams           values_dev(i, k)  = 0.0;
235*ee102026SMark Adams         }
236828beda2SMark Adams       }
237828beda2SMark Adams     });
238828beda2SMark Adams 
239828beda2SMark Adams   /* Copy back to host and fill matrix */
240*ee102026SMark Adams   // Host views must be LayoutRight for MatSetValues (row-major)
241*ee102026SMark Adams   Kokkos::View<PetscInt **, Kokkos::LayoutRight, Kokkos::HostSpace>    indices_host("indices_host", n_vert_local, n_obs_vertex);
242*ee102026SMark Adams   Kokkos::View<PetscScalar **, Kokkos::LayoutRight, Kokkos::HostSpace> values_host("values_host", n_vert_local, n_obs_vertex);
243828beda2SMark Adams 
244*ee102026SMark Adams   // Deep copy will handle layout conversion (transpose) if device views are LayoutLeft
245*ee102026SMark Adams   // Note: Kokkos::deep_copy cannot copy between different layouts if the memory spaces are different (e.g. GPU to Host).
246*ee102026SMark Adams   // We need an intermediate mirror view on the host with the same layout as the device view.
247*ee102026SMark Adams   Kokkos::View<PetscInt **, Kokkos::LayoutLeft, Kokkos::HostSpace>    indices_host_left = Kokkos::create_mirror_view(indices_dev);
248*ee102026SMark Adams   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace> values_host_left  = Kokkos::create_mirror_view(values_dev);
249828beda2SMark Adams 
250*ee102026SMark Adams   Kokkos::deep_copy(indices_host_left, indices_dev);
251*ee102026SMark Adams   Kokkos::deep_copy(values_host_left, values_dev);
252*ee102026SMark Adams 
253*ee102026SMark Adams   // Now copy from LayoutLeft host view to LayoutRight host view
254*ee102026SMark Adams   Kokkos::deep_copy(indices_host, indices_host_left);
255*ee102026SMark Adams   Kokkos::deep_copy(values_host, values_host_left);
256*ee102026SMark Adams 
257*ee102026SMark Adams   for (PetscInt i = 0; i < n_vert_local; ++i) {
258*ee102026SMark Adams     PetscInt globalRow = rstart + i;
259*ee102026SMark Adams     PetscCall(MatSetValues(*Q, 1, &globalRow, n_obs_vertex, &indices_host(i, 0), &values_host(i, 0), INSERT_VALUES));
260828beda2SMark Adams   }
261828beda2SMark Adams 
262828beda2SMark Adams   /* Cleanup Phase 2 storage */
263*ee102026SMark Adams   for (d = 0; d < dim; ++d) PetscCall(VecDestroy(&obs_vecs[d]));
264828beda2SMark Adams   PetscCall(PetscFree(obs_vecs));
265828beda2SMark Adams 
266828beda2SMark Adams   /* Assemble matrix */
267828beda2SMark Adams   PetscCall(MatAssemblyBegin(*Q, MAT_FINAL_ASSEMBLY));
268828beda2SMark Adams   PetscCall(MatAssemblyEnd(*Q, MAT_FINAL_ASSEMBLY));
269828beda2SMark Adams   PetscFunctionReturn(PETSC_SUCCESS);
270828beda2SMark Adams }
271