xref: /honee/problems/torch/sgs_model_torch.cpp (revision fc37ad8c2d8e5885e86197268f3a21c51d020b21)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <log_events.h>
9 #include <petsc.h>
10 #include <sgs_model_torch.h>
11 #include <torch/script.h>
12 #include <torch/torch.h>
13 
// Global TorchScript module used for SGS model inference, and the device it was moved to.
// Both are written once by LoadModel_Torch() and read by ModelInference_Torch().
// NOTE(review): file-scope globals imply a single model per process — confirm no
// concurrent use of multiple models is intended.
torch::jit::script::Module model;
torch::DeviceType          device_model;
16 
17 static PetscErrorCode EnumToDeviceType(TorchDeviceType device_enum, torch::DeviceType *device_type) {
18   PetscFunctionBeginUser;
19   switch (device_enum) {
20     case TORCH_DEVICE_CPU:
21       *device_type = torch::kCPU;
22       break;
23     case TORCH_DEVICE_XPU:
24       *device_type = torch::kXPU;
25       break;
26     case TORCH_DEVICE_CUDA:
27       *device_type = torch::kCUDA;
28       break;
29     case TORCH_DEVICE_HIP:
30       *device_type = torch::kHIP;
31       break;
32     default:
33       SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "TorchDeviceType %d not supported by PyTorch inference", device_enum);
34   }
35   PetscFunctionReturn(PETSC_SUCCESS);
36 }
37 
38 static PetscErrorCode PetscMemTypeToDeviceType(PetscMemType mem_type, torch::DeviceType *device_type) {
39   PetscFunctionBeginUser;
40   switch (mem_type) {
41     case PETSC_MEMTYPE_HOST:
42       *device_type = torch::kCPU;
43       break;
44     case PETSC_MEMTYPE_SYCL:
45       *device_type = torch::kXPU;
46       break;
47     case PETSC_MEMTYPE_CUDA:
48       *device_type = torch::kCUDA;
49       break;
50     case PETSC_MEMTYPE_HIP:
51       *device_type = torch::kHIP;
52       break;
53     default:
54       SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "PetscMemType %s not supported by PyTorch inference", PetscMemTypeToString(mem_type));
55   }
56   PetscFunctionReturn(PETSC_SUCCESS);
57 }
58 
59 PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum) {
60   PetscFunctionBeginUser;
61   PetscCall(EnumToDeviceType(device_enum, &device_model));
62 
63   PetscCallCXX(model = torch::jit::load(model_path));
64   PetscCallCXX(model.to(torch::Device(device_model)));
65   PetscFunctionReturn(PETSC_SUCCESS);
66 }
67 
/**
  @brief Run the loaded SGS data-driven model on local input data

  Wraps the input Vec's array as a torch tensor (zero-copy when the Vec memory is directly
  usable by torch), evaluates `model.forward()`, and copies the resulting tensor back into
  the output Vec. `LoadModel_Torch()` must have been called beforehand to set the globals
  `model` and `device_model`.

  @param[in]  DD_Inputs_loc   Local Vec of model inputs; 6 interlaced components per node
  @param[out] DD_Outputs_loc  Local Vec receiving model outputs; 6 interlaced components per node

  @return An error code: 0 - success, otherwise - failure
**/
PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc) {
  torch::Tensor  input_tensor, output_tensor;
  const PetscInt num_input_comps = 6, num_output_comps = 6;
  PetscBool      debug_tensor_output = PETSC_FALSE;  // flip to PETSC_TRUE to dump tensors to stdout

  PetscFunctionBeginUser;
  // NOTE(review): NoGradGuard is intentionally(?) disabled — without it autograd history may
  // accumulate across forward() calls; confirm whether this was left off deliberately
  // torch::NoGradGuard no_grad; // equivalent to "with torch.no_grad():" in PyTorch
  PetscCall(PetscLogEventBegin(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
  {  // Transfer DD_Inputs_loc into input_tensor
    PetscMemType         input_mem_type;
    PetscInt             input_size, num_nodes;
    const PetscScalar   *dd_inputs_ptr;
    torch::DeviceType    dd_input_device;
    torch::TensorOptions options;

    PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size));
    num_nodes = input_size / num_input_comps;
    // Get read-only access to the Vec's array wherever it currently resides (host or device)
    PetscCall(VecGetArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr, &input_mem_type));
    PetscCall(PetscMemTypeToDeviceType(input_mem_type, &dd_input_device));

    // kFloat64 assumes PetscScalar is double — TODO confirm for single-precision/complex PETSc builds
    PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_input_device));
    if (dd_input_device == torch::kXPU) {  // XPU requires device-to-host-to-device transfer
      // Wrap the device pointer (row-major, num_input_comps-stride per node) with a null deleter,
      // then copy onto the model's device; .to() makes input_tensor own its storage
      PetscCallCXX(input_tensor =
                       at::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, {num_input_comps, 1}, nullptr, options, dd_input_device)
                           .to(device_model));
    } else {
      // Zero-copy: input_tensor aliases the Vec array until VecRestoreArrayReadAndMemType below
      PetscCallCXX(input_tensor = torch::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, options));
    }
    if (debug_tensor_output) {
      double *input_tensor_ptr;

      // NOTE(review): input_size was already fetched above; this re-fetch looks redundant
      PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size));
      // Stage a contiguous host copy so the raw data can be printed
      PetscCallCXX(input_tensor_ptr = (double *)input_tensor.contiguous().to(torch::kCPU).data_ptr());
      printf("Input_Tensor_Pointer:\n");
      for (PetscInt i = 0; i < input_size; i++) {
        printf("%f\n", input_tensor_ptr[i]);
      }
    }
    PetscCall(VecRestoreArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr));
  }
  PetscCall(PetscLogEventEnd(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));

  // Run model
  PetscCall(PetscLogEventBegin(FLUIDS_SgsModelDDInference, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
  PetscCall(PetscLogGpuTimeBegin());
  PetscCallCXX(output_tensor = model.forward({input_tensor}).toTensor());
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscLogEventEnd(FLUIDS_SgsModelDDInference, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));

  PetscCall(PetscLogEventBegin(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
  {  // Transfer output_tensor to DD_Outputs_loc
    torch::DeviceType    dd_output_device;
    torch::TensorOptions options;
    PetscInt             output_size;
    PetscScalar         *dd_outputs_ptr;
    PetscMemType         output_mem_type;

    {  // Get DeviceType of DD_Outputs_loc
      // Briefly acquire the array solely to learn where it lives, then release it
      PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type));
      PetscCall(PetscMemTypeToDeviceType(output_mem_type, &dd_output_device));
      PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr));
    }

    if (dd_output_device == torch::kXPU) {  // XPU requires device-to-host-to-device transfer
      double *output_tensor_ptr;

      PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size));
      // VecGetArray gives a host pointer; the copy back to the XPU happens inside PETSc on restore
      PetscCall(VecGetArray(DD_Outputs_loc, &dd_outputs_ptr));
      PetscCallCXX(output_tensor_ptr = (double *)output_tensor.contiguous().to(torch::kCPU).data_ptr());
      if (debug_tensor_output) {
        printf("Output_Tensor_Pointer:\n");
        for (PetscInt i = 0; i < output_size; i++) {
          printf("%f\n", output_tensor_ptr[i]);
        }
      }
      PetscCall(PetscArraycpy(dd_outputs_ptr, output_tensor_ptr, output_size));
      PetscCall(VecRestoreArray(DD_Outputs_loc, &dd_outputs_ptr));
    } else {
      PetscInt      num_nodes;
      torch::Tensor DD_Outputs_tensor;

      PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size));
      num_nodes = output_size / num_output_comps;
      PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type));
      PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_output_device));
      // Wrap the Vec array as a tensor and copy_ into it — torch handles any
      // cross-device transfer from the model's device
      PetscCallCXX(DD_Outputs_tensor = torch::from_blob((void *)dd_outputs_ptr, {num_nodes, num_output_comps}, options));
      PetscCallCXX(DD_Outputs_tensor.copy_(output_tensor));
      PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr));
    }
  }
  PetscCall(PetscLogEventEnd(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
  PetscFunctionReturn(PETSC_SUCCESS);
}
162