xref: /honee/problems/torch/sgs_model_torch.cpp (revision 4c07ec2294887c4a114ef13a7c2da0ab5f5dc208)
1*4c07ec22SJames Wright // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2*4c07ec22SJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*4c07ec22SJames Wright //
4*4c07ec22SJames Wright // SPDX-License-Identifier: BSD-2-Clause
5*4c07ec22SJames Wright //
6*4c07ec22SJames Wright // This file is part of CEED:  http://github.com/ceed
7*4c07ec22SJames Wright 
8*4c07ec22SJames Wright #include <petsc.h>
9*4c07ec22SJames Wright #include <sgs_model_torch.h>
10*4c07ec22SJames Wright #include <torch/script.h>
11*4c07ec22SJames Wright #include <torch/torch.h>
12*4c07ec22SJames Wright 
// TorchScript module shared across calls: loaded once by LoadModel_Torch()
// and reused by every subsequent ModelInference_Torch() invocation.
torch::jit::script::Module model;
// Device the model was moved to at load time (CPU/XPU/CUDA/HIP).
torch::DeviceType          device_model;
15*4c07ec22SJames Wright 
16*4c07ec22SJames Wright static PetscErrorCode EnumToDeviceType(TorchDeviceType device_enum, torch::DeviceType *device_type) {
17*4c07ec22SJames Wright   PetscFunctionBeginUser;
18*4c07ec22SJames Wright   switch (device_enum) {
19*4c07ec22SJames Wright     case TORCH_DEVICE_CPU:
20*4c07ec22SJames Wright       *device_type = torch::kCPU;
21*4c07ec22SJames Wright       break;
22*4c07ec22SJames Wright     case TORCH_DEVICE_XPU:
23*4c07ec22SJames Wright       *device_type = torch::kXPU;
24*4c07ec22SJames Wright       break;
25*4c07ec22SJames Wright     case TORCH_DEVICE_CUDA:
26*4c07ec22SJames Wright       *device_type = torch::kCUDA;
27*4c07ec22SJames Wright       break;
28*4c07ec22SJames Wright     case TORCH_DEVICE_HIP:
29*4c07ec22SJames Wright       *device_type = torch::kHIP;
30*4c07ec22SJames Wright       break;
31*4c07ec22SJames Wright     default:
32*4c07ec22SJames Wright       SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "TorchDeviceType %d not supported by PyTorch inference", device_enum);
33*4c07ec22SJames Wright   }
34*4c07ec22SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
35*4c07ec22SJames Wright }
36*4c07ec22SJames Wright 
37*4c07ec22SJames Wright static PetscErrorCode PetscMemTypeToDeviceType(PetscMemType mem_type, torch::DeviceType *device_type) {
38*4c07ec22SJames Wright   PetscFunctionBeginUser;
39*4c07ec22SJames Wright   switch (mem_type) {
40*4c07ec22SJames Wright     case PETSC_MEMTYPE_HOST:
41*4c07ec22SJames Wright       *device_type = torch::kCPU;
42*4c07ec22SJames Wright       break;
43*4c07ec22SJames Wright     case PETSC_MEMTYPE_SYCL:
44*4c07ec22SJames Wright       *device_type = torch::kXPU;
45*4c07ec22SJames Wright       break;
46*4c07ec22SJames Wright     case PETSC_MEMTYPE_CUDA:
47*4c07ec22SJames Wright       *device_type = torch::kCUDA;
48*4c07ec22SJames Wright       break;
49*4c07ec22SJames Wright     case PETSC_MEMTYPE_HIP:
50*4c07ec22SJames Wright       *device_type = torch::kHIP;
51*4c07ec22SJames Wright       break;
52*4c07ec22SJames Wright     default:
53*4c07ec22SJames Wright       SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "PetscMemType %s not supported by PyTorch inference", PetscMemTypeToString(mem_type));
54*4c07ec22SJames Wright   }
55*4c07ec22SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
56*4c07ec22SJames Wright }
57*4c07ec22SJames Wright 
58*4c07ec22SJames Wright PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum) {
59*4c07ec22SJames Wright   PetscFunctionBeginUser;
60*4c07ec22SJames Wright   PetscCall(EnumToDeviceType(device_enum, &device_model));
61*4c07ec22SJames Wright 
62*4c07ec22SJames Wright   PetscCallCXX(model = torch::jit::load(model_path));
63*4c07ec22SJames Wright   PetscCallCXX(model.to(torch::Device(device_model)));
64*4c07ec22SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
65*4c07ec22SJames Wright }
66*4c07ec22SJames Wright 
// Load and run model
//
// Run the previously loaded TorchScript module on DD_Inputs_loc and write the
// result into DD_Outputs_loc. Both Vecs are treated as row-major
// [num_nodes x 6] arrays (6 input components, 6 output components per node).
// Input data is wrapped zero-copy via from_blob when the Vec memory is
// directly visible to libtorch; XPU memory instead takes an explicit
// device-to-host-to-device round trip.
//
// NOTE(review): tensors are created as torch::kFloat64, which assumes
// PetscScalar is a real double (not complex, not single precision) — confirm
// against the PETSc build configuration.
// NOTE(review): NoGradGuard below is commented out, so autograd bookkeeping is
// active during forward(); enabling it would reduce memory/work — confirm it
// was disabled intentionally.
PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc) {
  torch::Tensor  input_tensor, output_tensor;
  const PetscInt num_input_comps = 6, num_output_comps = 6;
  PetscBool      debug_tensor_output = PETSC_FALSE;  // flip on to dump tensors to stdout

  PetscFunctionBeginUser;
  // torch::NoGradGuard no_grad; // equivalent to "with torch.no_grad():" in PyTorch
  {  // Transfer DD_Inputs_loc into input_tensor
    PetscMemType         input_mem_type;
    PetscInt             input_size, num_nodes;
    const PetscScalar   *dd_inputs_ptr;
    torch::DeviceType    dd_input_device;
    torch::TensorOptions options;

    PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size));
    num_nodes = input_size / num_input_comps;  // components are interleaved per node
    // Get the raw array plus where it lives (host or device memory)
    PetscCall(VecGetArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr, &input_mem_type));
    PetscCall(PetscMemTypeToDeviceType(input_mem_type, &dd_input_device));

    PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_input_device));
    if (dd_input_device == torch::kXPU) {  // XPU requires device-to-host-to-device transfer
      // .to(device_model) copies, so input_tensor owns its data on this path
      PetscCallCXX(input_tensor =
                       at::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, {num_input_comps, 1}, nullptr, options, dd_input_device)
                           .to(device_model));
    } else {
      // Zero-copy wrap: input_tensor aliases dd_inputs_ptr (no deleter).
      // NOTE(review): the array is restored at the end of this scope but the
      // tensor is consumed by model.forward() afterwards — a read-restore in
      // PETSc presumably leaves the memory valid, but this borrow-after-restore
      // ordering should be confirmed.
      PetscCallCXX(input_tensor = torch::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, options));
    }
    if (debug_tensor_output) {
      double *input_tensor_ptr;

      PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size));
      // Stage a contiguous host copy so the raw values can be printed
      PetscCallCXX(input_tensor_ptr = (double *)input_tensor.contiguous().to(torch::kCPU).data_ptr());
      printf("Input_Tensor_Pointer:\n");
      for (PetscInt i = 0; i < input_size; i++) {
        printf("%f\n", input_tensor_ptr[i]);
      }
    }
    PetscCall(VecRestoreArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr));
  }

  // Run model
  PetscCallCXX(output_tensor = model.forward({input_tensor}).toTensor());

  {  // Transfer output_tensor to DD_Outputs_loc
    torch::DeviceType    dd_output_device;
    torch::TensorOptions options;
    PetscInt             output_size;
    PetscScalar         *dd_outputs_ptr;
    PetscMemType         output_mem_type;

    {  // Get DeviceType of DD_Outputs_loc
      // Get/restore pair used purely to query the memory type of the Vec
      PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type));
      PetscCall(PetscMemTypeToDeviceType(output_mem_type, &dd_output_device));
      PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr));
    }

    if (dd_output_device == torch::kXPU) {  // XPU requires device-to-host-to-device transfer
      double *output_tensor_ptr;

      PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size));
      // VecGetArray (not ...AndMemType) yields a host pointer to copy into
      PetscCall(VecGetArray(DD_Outputs_loc, &dd_outputs_ptr));
      // Pull the model output to contiguous host memory
      PetscCallCXX(output_tensor_ptr = (double *)output_tensor.contiguous().to(torch::kCPU).data_ptr());
      if (debug_tensor_output) {
        printf("Output_Tensor_Pointer:\n");
        for (PetscInt i = 0; i < output_size; i++) {
          printf("%f\n", output_tensor_ptr[i]);
        }
      }
      PetscCall(PetscArraycpy(dd_outputs_ptr, output_tensor_ptr, output_size));
      PetscCall(VecRestoreArray(DD_Outputs_loc, &dd_outputs_ptr));
    } else {
      PetscInt      num_nodes;
      torch::Tensor DD_Outputs_tensor;

      PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size));
      num_nodes = output_size / num_output_comps;
      PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type));
      PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_output_device));
      // Wrap the Vec's storage zero-copy, then copy_ lets libtorch handle any
      // device-to-device transfer from the model's output device
      PetscCallCXX(DD_Outputs_tensor = torch::from_blob((void *)dd_outputs_ptr, {num_nodes, num_output_comps}, options));
      PetscCallCXX(DD_Outputs_tensor.copy_(output_tensor));
      PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
153