1 // SPDX-FileCopyrightText: Copyright (c) 2017-2024, HONEE contributors.
2 // SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause
3 // Based on the instructions from https://www.craylabs.org/docs/sr_integration.html and PHASTA implementation
4
5 #include <smartsim-impl.h>
6
7 #include <navierstokes.h>
8
9 #define SMARTSIM_KEY "SmartSimData"
10
SmartSimDataDestroy(SmartSimData * smartsim)11 static PetscErrorCode SmartSimDataDestroy(SmartSimData *smartsim) {
12 SmartSimData smartsim_ = *smartsim;
13
14 PetscFunctionBeginUser;
15 if (!smartsim_) PetscFunctionReturn(PETSC_SUCCESS);
16
17 PetscCallSmartRedis(DeleteCClient(&smartsim_->client));
18 PetscCall(PetscFree(smartsim_));
19 *smartsim = NULL;
20 PetscFunctionReturn(PETSC_SUCCESS);
21 }
22
SmartSimTrainingSetup(Honee honee)23 static PetscErrorCode SmartSimTrainingSetup(Honee honee) {
24 SmartSimData smartsim;
25 PetscMPIInt rank;
26 PetscReal checkrun[2] = {1};
27 size_t dim_2[1] = {2};
28
29 PetscFunctionBeginUser;
30 PetscCall(HoneeGetSmartSimData(honee, &smartsim));
31 PetscCallMPI(MPI_Comm_rank(honee->comm, &rank));
32
33 if (rank % smartsim->collocated_database_num_ranks == 0) {
34 // -- Send array that communicates when ML is done training
35 PetscCall(PetscLogEventBegin(HONEE_SmartRedis_Meta, 0, 0, 0, 0));
36 PetscCallSmartRedis(put_tensor(smartsim->client, "check-run", 9, checkrun, dim_2, 1, SRTensorTypeDouble, SRMemLayoutContiguous));
37 PetscCall(SmartRedisVerifyPutTensor(smartsim->client, "check-run", 9));
38 PetscCall(PetscLogEventEnd(HONEE_SmartRedis_Meta, 0, 0, 0, 0));
39 }
40 PetscFunctionReturn(PETSC_SUCCESS);
41 }
42
SmartSimSetup(Honee honee)43 static PetscErrorCode SmartSimSetup(Honee honee) {
44 PetscMPIInt rank;
45 PetscInt num_orchestrator_nodes = 1;
46 SmartSimData smartsim;
47
48 PetscFunctionBeginUser;
49 PetscCall(PetscNew(&smartsim));
50
51 smartsim->collocated_database_num_ranks = 1;
52 PetscOptionsBegin(honee->comm, NULL, "Options for SmartSim integration", NULL);
53 PetscCall(PetscOptionsInt("-smartsim_collocated_database_num_ranks", "Number of ranks per collocated database instance", NULL,
54 smartsim->collocated_database_num_ranks, &smartsim->collocated_database_num_ranks, NULL));
55 PetscOptionsEnd();
56
57 // Create prefix to be put on tensor names
58 PetscCallMPI(MPI_Comm_rank(honee->comm, &rank));
59 PetscCall(PetscSNPrintf(smartsim->rank_id_name, sizeof(smartsim->rank_id_name), "y.%d", rank));
60
61 PetscCall(PetscLogEventBegin(HONEE_SmartRedis_Init, 0, 0, 0, 0));
62 PetscCallSmartRedis(SmartRedisCClient(num_orchestrator_nodes != 1, smartsim->rank_id_name, strlen(smartsim->rank_id_name), &smartsim->client));
63 PetscCall(PetscLogEventEnd(HONEE_SmartRedis_Init, 0, 0, 0, 0));
64
65 PetscCall(HoneeSetContainer(honee, SMARTSIM_KEY, smartsim, (PetscCtxDestroyFn *)SmartSimDataDestroy));
66
67 PetscCall(SmartSimTrainingSetup(honee));
68 PetscFunctionReturn(PETSC_SUCCESS);
69 }
70
71 /**
72 @brief Obtains the `SmartSimData` from the `Honee` object
73
74 If `SmartSimData` has not already been initialized, this will initialize and create the struct.
75
76 @param[in] honee `Honee` object containing the SmartSim data
77 @param[out] smartsim `SmartSimData` containing the data
78 **/
HoneeGetSmartSimData(Honee honee,SmartSimData * smartsim)79 PetscErrorCode HoneeGetSmartSimData(Honee honee, SmartSimData *smartsim) {
80 PetscBool has_smartsim;
81
82 PetscFunctionBeginUser;
83 PetscCall(HoneeHasContainer(honee, SMARTSIM_KEY, &has_smartsim));
84 if (!has_smartsim) PetscCall(SmartSimSetup(honee));
85 PetscCall(HoneeGetContainer(honee, SMARTSIM_KEY, smartsim));
86 PetscFunctionReturn(PETSC_SUCCESS);
87 }
88
89 /**
90 @brief Checks if a tensor with `name` is in the SmartRedis database
91
92 Function will error out if tensor does not exist.
93
94 @param[in] c_client SmartRedis client object
95 @param[in] name Name of the tensor
96 @param[in] name_length Length of the tensor name
97 @return An error code: 0 - success, otherwise - failure
98 **/
SmartRedisVerifyPutTensor(void * c_client,const char * name,const size_t name_length)99 PetscErrorCode SmartRedisVerifyPutTensor(void *c_client, const char *name, const size_t name_length) {
100 bool does_exist = true;
101
102 PetscFunctionBeginUser;
103 PetscCall(PetscLogEventBegin(HONEE_SmartRedis_Meta, 0, 0, 0, 0));
104 PetscCallSmartRedis(tensor_exists(c_client, name, name_length, &does_exist));
105 PetscCheck(does_exist, PETSC_COMM_SELF, -1, "Tensor of name '%s' was not written to the database successfully", name);
106 PetscCall(PetscLogEventEnd(HONEE_SmartRedis_Meta, 0, 0, 0, 0));
107 PetscFunctionReturn(PETSC_SUCCESS);
108 }
109