1 // SPDX-FileCopyrightText: Copyright (c) 2017-2024, HONEE contributors.
2 // SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause
3
4 #include <log_events.h>
5 #include <petsc.h>
6 #include <sgs_model_torch.h>
7 #include <torch/script.h>
8 #include <torch/torch.h>
9
// File-scope model state: populated once by LoadModel_Torch() and read by
// ModelInference_Torch(). NOTE(review): not thread-safe — assumes one model
// per process; confirm against callers.
torch::jit::script::Module model;        // the loaded TorchScript module
torch::DeviceType          device_model; // device the module was moved to
12
EnumToDeviceType(TorchDeviceType device_enum,torch::DeviceType * device_type)13 static PetscErrorCode EnumToDeviceType(TorchDeviceType device_enum, torch::DeviceType *device_type) {
14 PetscFunctionBeginUser;
15 switch (device_enum) {
16 case TORCH_DEVICE_CPU:
17 *device_type = torch::kCPU;
18 break;
19 case TORCH_DEVICE_XPU:
20 *device_type = torch::kXPU;
21 break;
22 case TORCH_DEVICE_CUDA:
23 *device_type = torch::kCUDA;
24 break;
25 case TORCH_DEVICE_HIP:
26 *device_type = torch::kHIP;
27 break;
28 default:
29 SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "TorchDeviceType %d not supported by PyTorch inference", device_enum);
30 }
31 PetscFunctionReturn(PETSC_SUCCESS);
32 }
33
PetscMemTypeToDeviceType(PetscMemType mem_type,torch::DeviceType * device_type)34 static PetscErrorCode PetscMemTypeToDeviceType(PetscMemType mem_type, torch::DeviceType *device_type) {
35 PetscFunctionBeginUser;
36 switch (mem_type) {
37 case PETSC_MEMTYPE_HOST:
38 *device_type = torch::kCPU;
39 break;
40 case PETSC_MEMTYPE_SYCL:
41 *device_type = torch::kXPU;
42 break;
43 case PETSC_MEMTYPE_CUDA:
44 *device_type = torch::kCUDA;
45 break;
46 case PETSC_MEMTYPE_HIP:
47 *device_type = torch::kHIP;
48 break;
49 default:
50 SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "PetscMemType %s not supported by PyTorch inference", PetscMemTypeToString(mem_type));
51 }
52 PetscFunctionReturn(PETSC_SUCCESS);
53 }
54
LoadModel_Torch(const char * model_path,TorchDeviceType device_enum)55 PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum) {
56 PetscFunctionBeginUser;
57 PetscCall(EnumToDeviceType(device_enum, &device_model));
58
59 PetscCallCXX(model = torch::jit::load(model_path));
60 PetscCallCXX(model.to(torch::Device(device_model)));
61 PetscFunctionReturn(PETSC_SUCCESS);
62 }
63
64 // Load and run model
ModelInference_Torch(Vec DD_Inputs_loc,Vec DD_Outputs_loc)65 PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc) {
66 torch::Tensor input_tensor, output_tensor;
67 const PetscInt num_input_comps = 6, num_output_comps = 6;
68 PetscBool debug_tensor_output = PETSC_FALSE;
69
70 PetscFunctionBeginUser;
71 // torch::NoGradGuard no_grad; // equivalent to "with torch.no_grad():" in PyTorch
72 PetscCall(PetscLogEventBegin(HONEE_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
73 { // Transfer DD_Inputs_loc into input_tensor
74 PetscMemType input_mem_type;
75 PetscInt input_size, num_nodes;
76 const PetscScalar *dd_inputs_ptr;
77 torch::DeviceType dd_input_device;
78 torch::TensorOptions options;
79
80 PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size));
81 num_nodes = input_size / num_input_comps;
82 PetscCall(VecGetArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr, &input_mem_type));
83 PetscCall(PetscMemTypeToDeviceType(input_mem_type, &dd_input_device));
84
85 PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_input_device));
86 if (dd_input_device == torch::kXPU) { // XPU requires device-to-host-to-device transfer
87 PetscCallCXX(input_tensor =
88 at::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, {num_input_comps, 1}, nullptr, options, dd_input_device)
89 .to(device_model));
90 } else {
91 PetscCallCXX(input_tensor = torch::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, options));
92 }
93 if (debug_tensor_output) {
94 double *input_tensor_ptr;
95
96 PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size));
97 PetscCallCXX(input_tensor_ptr = (double *)input_tensor.contiguous().to(torch::kCPU).data_ptr());
98 printf("Input_Tensor_Pointer:\n");
99 for (PetscInt i = 0; i < input_size; i++) {
100 printf("%f\n", input_tensor_ptr[i]);
101 }
102 }
103 PetscCall(VecRestoreArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr));
104 }
105 PetscCall(PetscLogEventEnd(HONEE_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
106
107 // Run model
108 PetscCall(PetscLogEventBegin(HONEE_SgsModelDDInference, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
109 PetscCall(PetscLogGpuTimeBegin());
110 PetscCallCXX(output_tensor = model.forward({input_tensor}).toTensor());
111 PetscCall(PetscLogGpuTimeEnd());
112 PetscCall(PetscLogEventEnd(HONEE_SgsModelDDInference, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
113
114 PetscCall(PetscLogEventBegin(HONEE_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
115 { // Transfer output_tensor to DD_Outputs_loc
116 torch::DeviceType dd_output_device;
117 torch::TensorOptions options;
118 PetscInt output_size;
119 PetscScalar *dd_outputs_ptr;
120 PetscMemType output_mem_type;
121
122 { // Get DeviceType of DD_Outputs_loc
123 PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type));
124 PetscCall(PetscMemTypeToDeviceType(output_mem_type, &dd_output_device));
125 PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr));
126 }
127
128 if (dd_output_device == torch::kXPU) { // XPU requires device-to-host-to-device transfer
129 double *output_tensor_ptr;
130
131 PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size));
132 PetscCall(VecGetArray(DD_Outputs_loc, &dd_outputs_ptr));
133 PetscCallCXX(output_tensor_ptr = (double *)output_tensor.contiguous().to(torch::kCPU).data_ptr());
134 if (debug_tensor_output) {
135 printf("Output_Tensor_Pointer:\n");
136 for (PetscInt i = 0; i < output_size; i++) {
137 printf("%f\n", output_tensor_ptr[i]);
138 }
139 }
140 PetscCall(PetscArraycpy(dd_outputs_ptr, output_tensor_ptr, output_size));
141 PetscCall(VecRestoreArray(DD_Outputs_loc, &dd_outputs_ptr));
142 } else {
143 PetscInt num_nodes;
144 torch::Tensor DD_Outputs_tensor;
145
146 PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size));
147 num_nodes = output_size / num_output_comps;
148 PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type));
149 PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_output_device));
150 PetscCallCXX(DD_Outputs_tensor = torch::from_blob((void *)dd_outputs_ptr, {num_nodes, num_output_comps}, options));
151 PetscCallCXX(DD_Outputs_tensor.copy_(output_tensor));
152 PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr));
153 }
154 }
155 PetscCall(PetscLogEventEnd(HONEE_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
156 PetscFunctionReturn(PETSC_SUCCESS);
157 }
158