xref: /honee/problems/torch/sgs_model_torch.cpp (revision ea615d4cc464aa6ad650c06fae6d120cc2465bc4)
1ae2b091fSJames Wright // SPDX-FileCopyrightText: Copyright (c) 2017-2024, HONEE contributors.
2ae2b091fSJames Wright // SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause
34c07ec22SJames Wright 
4b40a7e63SJames Wright #include <log_events.h>
54c07ec22SJames Wright #include <petsc.h>
64c07ec22SJames Wright #include <sgs_model_torch.h>
74c07ec22SJames Wright #include <torch/script.h>
84c07ec22SJames Wright #include <torch/torch.h>
94c07ec22SJames Wright 
// TorchScript module used for inference by ModelInference_Torch(); populated by LoadModel_Torch().
// NOTE(review): both globals have external linkage; if nothing outside this file references them they could be `static`.
torch::jit::script::Module model;
// torch::DeviceType the model was moved to; set by LoadModel_Torch() from the requested TorchDeviceType.
torch::DeviceType          device_model;
124c07ec22SJames Wright 
134c07ec22SJames Wright static PetscErrorCode EnumToDeviceType(TorchDeviceType device_enum, torch::DeviceType *device_type) {
144c07ec22SJames Wright   PetscFunctionBeginUser;
154c07ec22SJames Wright   switch (device_enum) {
164c07ec22SJames Wright     case TORCH_DEVICE_CPU:
174c07ec22SJames Wright       *device_type = torch::kCPU;
184c07ec22SJames Wright       break;
194c07ec22SJames Wright     case TORCH_DEVICE_XPU:
204c07ec22SJames Wright       *device_type = torch::kXPU;
214c07ec22SJames Wright       break;
224c07ec22SJames Wright     case TORCH_DEVICE_CUDA:
234c07ec22SJames Wright       *device_type = torch::kCUDA;
244c07ec22SJames Wright       break;
254c07ec22SJames Wright     case TORCH_DEVICE_HIP:
264c07ec22SJames Wright       *device_type = torch::kHIP;
274c07ec22SJames Wright       break;
284c07ec22SJames Wright     default:
294c07ec22SJames Wright       SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "TorchDeviceType %d not supported by PyTorch inference", device_enum);
304c07ec22SJames Wright   }
314c07ec22SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
324c07ec22SJames Wright }
334c07ec22SJames Wright 
344c07ec22SJames Wright static PetscErrorCode PetscMemTypeToDeviceType(PetscMemType mem_type, torch::DeviceType *device_type) {
354c07ec22SJames Wright   PetscFunctionBeginUser;
364c07ec22SJames Wright   switch (mem_type) {
374c07ec22SJames Wright     case PETSC_MEMTYPE_HOST:
384c07ec22SJames Wright       *device_type = torch::kCPU;
394c07ec22SJames Wright       break;
404c07ec22SJames Wright     case PETSC_MEMTYPE_SYCL:
414c07ec22SJames Wright       *device_type = torch::kXPU;
424c07ec22SJames Wright       break;
434c07ec22SJames Wright     case PETSC_MEMTYPE_CUDA:
444c07ec22SJames Wright       *device_type = torch::kCUDA;
454c07ec22SJames Wright       break;
464c07ec22SJames Wright     case PETSC_MEMTYPE_HIP:
474c07ec22SJames Wright       *device_type = torch::kHIP;
484c07ec22SJames Wright       break;
494c07ec22SJames Wright     default:
504c07ec22SJames Wright       SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "PetscMemType %s not supported by PyTorch inference", PetscMemTypeToString(mem_type));
514c07ec22SJames Wright   }
524c07ec22SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
534c07ec22SJames Wright }
544c07ec22SJames Wright 
554c07ec22SJames Wright PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum) {
564c07ec22SJames Wright   PetscFunctionBeginUser;
574c07ec22SJames Wright   PetscCall(EnumToDeviceType(device_enum, &device_model));
584c07ec22SJames Wright 
594c07ec22SJames Wright   PetscCallCXX(model = torch::jit::load(model_path));
604c07ec22SJames Wright   PetscCallCXX(model.to(torch::Device(device_model)));
614c07ec22SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
624c07ec22SJames Wright }
634c07ec22SJames Wright 
644c07ec22SJames Wright // Load and run model
654c07ec22SJames Wright PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc) {
664c07ec22SJames Wright   torch::Tensor  input_tensor, output_tensor;
674c07ec22SJames Wright   const PetscInt num_input_comps = 6, num_output_comps = 6;
684c07ec22SJames Wright   PetscBool      debug_tensor_output = PETSC_FALSE;
694c07ec22SJames Wright 
704c07ec22SJames Wright   PetscFunctionBeginUser;
714c07ec22SJames Wright   // torch::NoGradGuard no_grad; // equivalent to "with torch.no_grad():" in PyTorch
72*ea615d4cSJames Wright   PetscCall(PetscLogEventBegin(HONEE_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
734c07ec22SJames Wright   {  // Transfer DD_Inputs_loc into input_tensor
744c07ec22SJames Wright     PetscMemType         input_mem_type;
754c07ec22SJames Wright     PetscInt             input_size, num_nodes;
764c07ec22SJames Wright     const PetscScalar   *dd_inputs_ptr;
774c07ec22SJames Wright     torch::DeviceType    dd_input_device;
784c07ec22SJames Wright     torch::TensorOptions options;
794c07ec22SJames Wright 
804c07ec22SJames Wright     PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size));
814c07ec22SJames Wright     num_nodes = input_size / num_input_comps;
824c07ec22SJames Wright     PetscCall(VecGetArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr, &input_mem_type));
834c07ec22SJames Wright     PetscCall(PetscMemTypeToDeviceType(input_mem_type, &dd_input_device));
844c07ec22SJames Wright 
854c07ec22SJames Wright     PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_input_device));
864c07ec22SJames Wright     if (dd_input_device == torch::kXPU) {  // XPU requires device-to-host-to-device transfer
874c07ec22SJames Wright       PetscCallCXX(input_tensor =
884c07ec22SJames Wright                        at::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, {num_input_comps, 1}, nullptr, options, dd_input_device)
894c07ec22SJames Wright                            .to(device_model));
904c07ec22SJames Wright     } else {
914c07ec22SJames Wright       PetscCallCXX(input_tensor = torch::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, options));
924c07ec22SJames Wright     }
934c07ec22SJames Wright     if (debug_tensor_output) {
944c07ec22SJames Wright       double *input_tensor_ptr;
954c07ec22SJames Wright 
964c07ec22SJames Wright       PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size));
974c07ec22SJames Wright       PetscCallCXX(input_tensor_ptr = (double *)input_tensor.contiguous().to(torch::kCPU).data_ptr());
984c07ec22SJames Wright       printf("Input_Tensor_Pointer:\n");
994c07ec22SJames Wright       for (PetscInt i = 0; i < input_size; i++) {
1004c07ec22SJames Wright         printf("%f\n", input_tensor_ptr[i]);
1014c07ec22SJames Wright       }
1024c07ec22SJames Wright     }
1034c07ec22SJames Wright     PetscCall(VecRestoreArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr));
1044c07ec22SJames Wright   }
105*ea615d4cSJames Wright   PetscCall(PetscLogEventEnd(HONEE_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
1064c07ec22SJames Wright 
1074c07ec22SJames Wright   // Run model
108*ea615d4cSJames Wright   PetscCall(PetscLogEventBegin(HONEE_SgsModelDDInference, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
109b40a7e63SJames Wright   PetscCall(PetscLogGpuTimeBegin());
1104c07ec22SJames Wright   PetscCallCXX(output_tensor = model.forward({input_tensor}).toTensor());
111b40a7e63SJames Wright   PetscCall(PetscLogGpuTimeEnd());
112*ea615d4cSJames Wright   PetscCall(PetscLogEventEnd(HONEE_SgsModelDDInference, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
1134c07ec22SJames Wright 
114*ea615d4cSJames Wright   PetscCall(PetscLogEventBegin(HONEE_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
1154c07ec22SJames Wright   {  // Transfer output_tensor to DD_Outputs_loc
1164c07ec22SJames Wright     torch::DeviceType    dd_output_device;
1174c07ec22SJames Wright     torch::TensorOptions options;
1184c07ec22SJames Wright     PetscInt             output_size;
1194c07ec22SJames Wright     PetscScalar         *dd_outputs_ptr;
1204c07ec22SJames Wright     PetscMemType         output_mem_type;
1214c07ec22SJames Wright 
1224c07ec22SJames Wright     {  // Get DeviceType of DD_Outputs_loc
1234c07ec22SJames Wright       PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type));
1244c07ec22SJames Wright       PetscCall(PetscMemTypeToDeviceType(output_mem_type, &dd_output_device));
1254c07ec22SJames Wright       PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr));
1264c07ec22SJames Wright     }
1274c07ec22SJames Wright 
1284c07ec22SJames Wright     if (dd_output_device == torch::kXPU) {  // XPU requires device-to-host-to-device transfer
1294c07ec22SJames Wright       double *output_tensor_ptr;
1304c07ec22SJames Wright 
1314c07ec22SJames Wright       PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size));
1324c07ec22SJames Wright       PetscCall(VecGetArray(DD_Outputs_loc, &dd_outputs_ptr));
1334c07ec22SJames Wright       PetscCallCXX(output_tensor_ptr = (double *)output_tensor.contiguous().to(torch::kCPU).data_ptr());
1344c07ec22SJames Wright       if (debug_tensor_output) {
1354c07ec22SJames Wright         printf("Output_Tensor_Pointer:\n");
1364c07ec22SJames Wright         for (PetscInt i = 0; i < output_size; i++) {
1374c07ec22SJames Wright           printf("%f\n", output_tensor_ptr[i]);
1384c07ec22SJames Wright         }
1394c07ec22SJames Wright       }
1404c07ec22SJames Wright       PetscCall(PetscArraycpy(dd_outputs_ptr, output_tensor_ptr, output_size));
1414c07ec22SJames Wright       PetscCall(VecRestoreArray(DD_Outputs_loc, &dd_outputs_ptr));
1424c07ec22SJames Wright     } else {
1434c07ec22SJames Wright       PetscInt      num_nodes;
1444c07ec22SJames Wright       torch::Tensor DD_Outputs_tensor;
1454c07ec22SJames Wright 
1464c07ec22SJames Wright       PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size));
1474c07ec22SJames Wright       num_nodes = output_size / num_output_comps;
1484c07ec22SJames Wright       PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type));
1494c07ec22SJames Wright       PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_output_device));
1504c07ec22SJames Wright       PetscCallCXX(DD_Outputs_tensor = torch::from_blob((void *)dd_outputs_ptr, {num_nodes, num_output_comps}, options));
1514c07ec22SJames Wright       PetscCallCXX(DD_Outputs_tensor.copy_(output_tensor));
1524c07ec22SJames Wright       PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr));
1534c07ec22SJames Wright     }
1544c07ec22SJames Wright   }
155*ea615d4cSJames Wright   PetscCall(PetscLogEventEnd(HONEE_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL));
1564c07ec22SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1574c07ec22SJames Wright }
158