xref: /honee/problems/torch/sgs_model_torch.cpp (revision 8c85b83518484fc9eaa5bccfe38a7cce91b34691)
1 // SPDX-FileCopyrightText: Copyright (c) 2017-2024, HONEE contributors.
2 // SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause
3 
4 #include <log_events.h>
5 #include <petsc.h>
6 #include <sgs_model_torch.h>
7 #include <torch/script.h>
8 #include <torch/torch.h>
9 
10 torch::jit::script::Module model;
11 torch::DeviceType          device_model;
12 
13 static PetscErrorCode EnumToDeviceType(TorchDeviceType device_enum, torch::DeviceType *device_type) {
14   PetscFunctionBeginUser;
15   switch (device_enum) {
16     case TORCH_DEVICE_CPU:
17       *device_type = torch::kCPU;
18       break;
19     case TORCH_DEVICE_XPU:
20       *device_type = torch::kXPU;
21       break;
22     case TORCH_DEVICE_CUDA:
23       *device_type = torch::kCUDA;
24       break;
25     case TORCH_DEVICE_HIP:
26       *device_type = torch::kHIP;
27       break;
28     default:
29       SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "TorchDeviceType %d not supported by PyTorch inference", device_enum);
30   }
31   PetscFunctionReturn(PETSC_SUCCESS);
32 }
33 
34 static PetscErrorCode PetscMemTypeToDeviceType(PetscMemType mem_type, torch::DeviceType *device_type) {
35   PetscFunctionBeginUser;
36   switch (mem_type) {
37     case PETSC_MEMTYPE_HOST:
38       *device_type = torch::kCPU;
39       break;
40     case PETSC_MEMTYPE_SYCL:
41       *device_type = torch::kXPU;
42       break;
43     case PETSC_MEMTYPE_CUDA:
44       *device_type = torch::kCUDA;
45       break;
46     case PETSC_MEMTYPE_HIP:
47       *device_type = torch::kHIP;
48       break;
49     default:
50       SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "PetscMemType %s not supported by PyTorch inference", PetscMemTypeToString(mem_type));
51   }
52   PetscFunctionReturn(PETSC_SUCCESS);
53 }
54 
55 PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum) {
56   PetscFunctionBeginUser;
57   PetscCall(EnumToDeviceType(device_enum, &device_model));
58 
59   PetscCallCXX(model = torch::jit::load(model_path));
60   PetscCallCXX(model.to(torch::Device(device_model)));
61   PetscFunctionReturn(PETSC_SUCCESS);
62 }
63 
64 // Load and run model
65 PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc) {
66   torch::Tensor  input_tensor, output_tensor;
67   const PetscInt num_input_comps = 6, num_output_comps = 6;
68   PetscBool      debug_tensor_output = PETSC_FALSE;
69 
70   PetscFunctionBeginUser;
71   // torch::NoGradGuard no_grad; // equivalent to "with torch.no_grad():" in PyTorch
72   PetscCall(PetscLogEventBegin(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
73   {  // Transfer DD_Inputs_loc into input_tensor
74     PetscMemType         input_mem_type;
75     PetscInt             input_size, num_nodes;
76     const PetscScalar   *dd_inputs_ptr;
77     torch::DeviceType    dd_input_device;
78     torch::TensorOptions options;
79 
80     PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size));
81     num_nodes = input_size / num_input_comps;
82     PetscCall(VecGetArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr, &input_mem_type));
83     PetscCall(PetscMemTypeToDeviceType(input_mem_type, &dd_input_device));
84 
85     PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_input_device));
86     if (dd_input_device == torch::kXPU) {  // XPU requires device-to-host-to-device transfer
87       PetscCallCXX(input_tensor =
88                        at::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, {num_input_comps, 1}, nullptr, options, dd_input_device)
89                            .to(device_model));
90     } else {
91       PetscCallCXX(input_tensor = torch::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, options));
92     }
93     if (debug_tensor_output) {
94       double *input_tensor_ptr;
95 
96       PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size));
97       PetscCallCXX(input_tensor_ptr = (double *)input_tensor.contiguous().to(torch::kCPU).data_ptr());
98       printf("Input_Tensor_Pointer:\n");
99       for (PetscInt i = 0; i < input_size; i++) {
100         printf("%f\n", input_tensor_ptr[i]);
101       }
102     }
103     PetscCall(VecRestoreArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr));
104   }
105   PetscCall(PetscLogEventEnd(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
106 
107   // Run model
108   PetscCall(PetscLogEventBegin(FLUIDS_SgsModelDDInference, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
109   PetscCall(PetscLogGpuTimeBegin());
110   PetscCallCXX(output_tensor = model.forward({input_tensor}).toTensor());
111   PetscCall(PetscLogGpuTimeEnd());
112   PetscCall(PetscLogEventEnd(FLUIDS_SgsModelDDInference, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
113 
114   PetscCall(PetscLogEventBegin(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
115   {  // Transfer output_tensor to DD_Outputs_loc
116     torch::DeviceType    dd_output_device;
117     torch::TensorOptions options;
118     PetscInt             output_size;
119     PetscScalar         *dd_outputs_ptr;
120     PetscMemType         output_mem_type;
121 
122     {  // Get DeviceType of DD_Outputs_loc
123       PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type));
124       PetscCall(PetscMemTypeToDeviceType(output_mem_type, &dd_output_device));
125       PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr));
126     }
127 
128     if (dd_output_device == torch::kXPU) {  // XPU requires device-to-host-to-device transfer
129       double *output_tensor_ptr;
130 
131       PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size));
132       PetscCall(VecGetArray(DD_Outputs_loc, &dd_outputs_ptr));
133       PetscCallCXX(output_tensor_ptr = (double *)output_tensor.contiguous().to(torch::kCPU).data_ptr());
134       if (debug_tensor_output) {
135         printf("Output_Tensor_Pointer:\n");
136         for (PetscInt i = 0; i < output_size; i++) {
137           printf("%f\n", output_tensor_ptr[i]);
138         }
139       }
140       PetscCall(PetscArraycpy(dd_outputs_ptr, output_tensor_ptr, output_size));
141       PetscCall(VecRestoreArray(DD_Outputs_loc, &dd_outputs_ptr));
142     } else {
143       PetscInt      num_nodes;
144       torch::Tensor DD_Outputs_tensor;
145 
146       PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size));
147       num_nodes = output_size / num_output_comps;
148       PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type));
149       PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_output_device));
150       PetscCallCXX(DD_Outputs_tensor = torch::from_blob((void *)dd_outputs_ptr, {num_nodes, num_output_comps}, options));
151       PetscCallCXX(DD_Outputs_tensor.copy_(output_tensor));
152       PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr));
153     }
154   }
155   PetscCall(PetscLogEventEnd(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
156   PetscFunctionReturn(PETSC_SUCCESS);
157 }
158