xref: /honee/problems/torch/sgs_model_torch.cpp (revision 4c07ec2294887c4a114ef13a7c2da0ab5f5dc208)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <petsc.h>
9 #include <sgs_model_torch.h>
10 #include <torch/script.h>
11 #include <torch/torch.h>
12 
// File-global inference state, initialized by LoadModel_Torch()
torch::jit::script::Module model;         // loaded TorchScript module used by ModelInference_Torch()
torch::DeviceType          device_model;  // device the module was moved to at load time
15 
16 static PetscErrorCode EnumToDeviceType(TorchDeviceType device_enum, torch::DeviceType *device_type) {
17   PetscFunctionBeginUser;
18   switch (device_enum) {
19     case TORCH_DEVICE_CPU:
20       *device_type = torch::kCPU;
21       break;
22     case TORCH_DEVICE_XPU:
23       *device_type = torch::kXPU;
24       break;
25     case TORCH_DEVICE_CUDA:
26       *device_type = torch::kCUDA;
27       break;
28     case TORCH_DEVICE_HIP:
29       *device_type = torch::kHIP;
30       break;
31     default:
32       SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "TorchDeviceType %d not supported by PyTorch inference", device_enum);
33   }
34   PetscFunctionReturn(PETSC_SUCCESS);
35 }
36 
37 static PetscErrorCode PetscMemTypeToDeviceType(PetscMemType mem_type, torch::DeviceType *device_type) {
38   PetscFunctionBeginUser;
39   switch (mem_type) {
40     case PETSC_MEMTYPE_HOST:
41       *device_type = torch::kCPU;
42       break;
43     case PETSC_MEMTYPE_SYCL:
44       *device_type = torch::kXPU;
45       break;
46     case PETSC_MEMTYPE_CUDA:
47       *device_type = torch::kCUDA;
48       break;
49     case PETSC_MEMTYPE_HIP:
50       *device_type = torch::kHIP;
51       break;
52     default:
53       SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "PetscMemType %s not supported by PyTorch inference", PetscMemTypeToString(mem_type));
54   }
55   PetscFunctionReturn(PETSC_SUCCESS);
56 }
57 
58 PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum) {
59   PetscFunctionBeginUser;
60   PetscCall(EnumToDeviceType(device_enum, &device_model));
61 
62   PetscCallCXX(model = torch::jit::load(model_path));
63   PetscCallCXX(model.to(torch::Device(device_model)));
64   PetscFunctionReturn(PETSC_SUCCESS);
65 }
66 
// Run inference: apply the TorchScript model loaded by LoadModel_Torch() to DD_Inputs_loc,
// writing the model's predictions into DD_Outputs_loc.
//
// @param[in]  DD_Inputs_loc   Local Vec of model inputs, interlaced with num_input_comps values per node
// @param[out] DD_Outputs_loc  Local Vec receiving model outputs, interlaced with num_output_comps values per node
//
// Tensors are created as zero-copy views of the PETSc arrays where possible; XPU (SYCL)
// memory instead takes a device-to-host(-to-device) round trip on both input and output.
PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc) {
  torch::Tensor  input_tensor, output_tensor;
  const PetscInt num_input_comps = 6, num_output_comps = 6;  // interlace width of the input/output Vecs
  PetscBool      debug_tensor_output = PETSC_FALSE;          // set PETSC_TRUE to dump tensor contents to stdout

  PetscFunctionBeginUser;
  // torch::NoGradGuard no_grad; // equivalent to "with torch.no_grad():" in PyTorch
  {  // Transfer DD_Inputs_loc into input_tensor
    PetscMemType         input_mem_type;
    PetscInt             input_size, num_nodes;
    const PetscScalar   *dd_inputs_ptr;
    torch::DeviceType    dd_input_device;
    torch::TensorOptions options;

    PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size));
    num_nodes = input_size / num_input_comps;
    PetscCall(VecGetArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr, &input_mem_type));
    PetscCall(PetscMemTypeToDeviceType(input_mem_type, &dd_input_device));

    PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_input_device));
    if (dd_input_device == torch::kXPU) {  // XPU requires device-to-host-to-device transfer
      // at::from_blob with explicit row-major strides and a null deleter wraps the array
      // without taking ownership; .to(device_model) then copies it onto the model's device.
      PetscCallCXX(input_tensor =
                       at::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, {num_input_comps, 1}, nullptr, options, dd_input_device)
                           .to(device_model));
    } else {
      // Zero-copy: input_tensor aliases the PETSc array directly.
      // NOTE(review): the array is restored below before model.forward() consumes this
      // tensor, so this relies on the underlying storage remaining valid after restore
      // — confirm against PETSc's VecGetArrayReadAndMemType lifetime guarantees.
      PetscCallCXX(input_tensor = torch::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, options));
    }
    if (debug_tensor_output) {
      double *input_tensor_ptr;

      PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size));
      // Bring the tensor to contiguous host memory so it can be printed
      PetscCallCXX(input_tensor_ptr = (double *)input_tensor.contiguous().to(torch::kCPU).data_ptr());
      printf("Input_Tensor_Pointer:\n");
      for (PetscInt i = 0; i < input_size; i++) {
        printf("%f\n", input_tensor_ptr[i]);
      }
    }
    PetscCall(VecRestoreArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr));
  }

  // Run model
  PetscCallCXX(output_tensor = model.forward({input_tensor}).toTensor());

  {  // Transfer output_tensor to DD_Outputs_loc
    torch::DeviceType    dd_output_device;
    torch::TensorOptions options;
    PetscInt             output_size;
    PetscScalar         *dd_outputs_ptr;
    PetscMemType         output_mem_type;

    {  // Get DeviceType of DD_Outputs_loc
      // Get-and-restore is used only to learn where the Vec's array lives
      PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type));
      PetscCall(PetscMemTypeToDeviceType(output_mem_type, &dd_output_device));
      PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr));
    }

    if (dd_output_device == torch::kXPU) {  // XPU requires device-to-host-to-device transfer
      double *output_tensor_ptr;

      PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size));
      // NOTE(review): plain VecGetArray (not the AndMemType variant) is used here,
      // presumably to obtain a host-accessible pointer for the staging copy — confirm.
      PetscCall(VecGetArray(DD_Outputs_loc, &dd_outputs_ptr));
      // Stage the model output in contiguous host memory
      PetscCallCXX(output_tensor_ptr = (double *)output_tensor.contiguous().to(torch::kCPU).data_ptr());
      if (debug_tensor_output) {
        printf("Output_Tensor_Pointer:\n");
        for (PetscInt i = 0; i < output_size; i++) {
          printf("%f\n", output_tensor_ptr[i]);
        }
      }
      PetscCall(PetscArraycpy(dd_outputs_ptr, output_tensor_ptr, output_size));
      PetscCall(VecRestoreArray(DD_Outputs_loc, &dd_outputs_ptr));
    } else {
      PetscInt      num_nodes;
      torch::Tensor DD_Outputs_tensor;

      PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size));
      num_nodes = output_size / num_output_comps;
      PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type));
      PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_output_device));
      // Wrap the Vec's array in a tensor view, then copy_() the model output into it
      // (copy_ handles any device/layout conversion from output_tensor)
      PetscCallCXX(DD_Outputs_tensor = torch::from_blob((void *)dd_outputs_ptr, {num_nodes, num_output_comps}, options));
      PetscCallCXX(DD_Outputs_tensor.copy_(output_tensor));
      PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
153