1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <petsc.h> 9 #include <sgs_model_torch.h> 10 #include <torch/script.h> 11 #include <torch/torch.h> 12 13 torch::jit::script::Module model; 14 torch::DeviceType device_model; 15 16 static PetscErrorCode EnumToDeviceType(TorchDeviceType device_enum, torch::DeviceType *device_type) { 17 PetscFunctionBeginUser; 18 switch (device_enum) { 19 case TORCH_DEVICE_CPU: 20 *device_type = torch::kCPU; 21 break; 22 case TORCH_DEVICE_XPU: 23 *device_type = torch::kXPU; 24 break; 25 case TORCH_DEVICE_CUDA: 26 *device_type = torch::kCUDA; 27 break; 28 case TORCH_DEVICE_HIP: 29 *device_type = torch::kHIP; 30 break; 31 default: 32 SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "TorchDeviceType %d not supported by PyTorch inference", device_enum); 33 } 34 PetscFunctionReturn(PETSC_SUCCESS); 35 } 36 37 static PetscErrorCode PetscMemTypeToDeviceType(PetscMemType mem_type, torch::DeviceType *device_type) { 38 PetscFunctionBeginUser; 39 switch (mem_type) { 40 case PETSC_MEMTYPE_HOST: 41 *device_type = torch::kCPU; 42 break; 43 case PETSC_MEMTYPE_SYCL: 44 *device_type = torch::kXPU; 45 break; 46 case PETSC_MEMTYPE_CUDA: 47 *device_type = torch::kCUDA; 48 break; 49 case PETSC_MEMTYPE_HIP: 50 *device_type = torch::kHIP; 51 break; 52 default: 53 SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "PetscMemType %s not supported by PyTorch inference", PetscMemTypeToString(mem_type)); 54 } 55 PetscFunctionReturn(PETSC_SUCCESS); 56 } 57 58 PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum) { 59 PetscFunctionBeginUser; 60 PetscCall(EnumToDeviceType(device_enum, &device_model)); 61 62 PetscCallCXX(model = torch::jit::load(model_path)); 63 PetscCallCXX(model.to(torch::Device(device_model))); 64 PetscFunctionReturn(PETSC_SUCCESS); 65 } 66 67 // Load and run model 68 PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc) { 69 torch::Tensor input_tensor, output_tensor; 70 const PetscInt num_input_comps = 6, num_output_comps = 6; 71 PetscBool debug_tensor_output = PETSC_FALSE; 72 73 PetscFunctionBeginUser; 74 // torch::NoGradGuard no_grad; // equivalent to "with torch.no_grad():" in PyTorch 75 { // Transfer DD_Inputs_loc into input_tensor 76 PetscMemType input_mem_type; 77 PetscInt input_size, num_nodes; 78 const PetscScalar *dd_inputs_ptr; 79 torch::DeviceType dd_input_device; 80 torch::TensorOptions options; 81 82 PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size)); 83 num_nodes = input_size / num_input_comps; 84 PetscCall(VecGetArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr, &input_mem_type)); 85 PetscCall(PetscMemTypeToDeviceType(input_mem_type, &dd_input_device)); 86 87 PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_input_device)); 88 if (dd_input_device == torch::kXPU) { // XPU requires device-to-host-to-device transfer 89 PetscCallCXX(input_tensor = 90 at::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, {num_input_comps, 1}, nullptr, options, dd_input_device) 91 .to(device_model)); 92 } else { 93 PetscCallCXX(input_tensor = torch::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, options)); 94 } 95 if (debug_tensor_output) { 96 double *input_tensor_ptr; 97 98 PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size)); 99 PetscCallCXX(input_tensor_ptr = (double *)input_tensor.contiguous().to(torch::kCPU).data_ptr()); 100 printf("Input_Tensor_Pointer:\n"); 101 for (PetscInt i = 0; i < input_size; i++) { 102 printf("%f\n", input_tensor_ptr[i]); 103 } 104 } 105 PetscCall(VecRestoreArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr)); 106 } 107 108 // Run model 109 PetscCallCXX(output_tensor = model.forward({input_tensor}).toTensor()); 110 111 { // Transfer output_tensor to DD_Outputs_loc 112 torch::DeviceType dd_output_device; 113 torch::TensorOptions options; 114 PetscInt output_size; 115 PetscScalar *dd_outputs_ptr; 116 PetscMemType output_mem_type; 117 118 { // Get DeviceType of DD_Outputs_loc 119 PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type)); 120 PetscCall(PetscMemTypeToDeviceType(output_mem_type, &dd_output_device)); 121 PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr)); 122 } 123 124 if (dd_output_device == torch::kXPU) { // XPU requires device-to-host-to-device transfer 125 double *output_tensor_ptr; 126 127 PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size)); 128 PetscCall(VecGetArray(DD_Outputs_loc, &dd_outputs_ptr)); 129 PetscCallCXX(output_tensor_ptr = (double *)output_tensor.contiguous().to(torch::kCPU).data_ptr()); 130 if (debug_tensor_output) { 131 printf("Output_Tensor_Pointer:\n"); 132 for (PetscInt i = 0; i < output_size; i++) { 133 printf("%f\n", output_tensor_ptr[i]); 134 } 135 } 136 PetscCall(PetscArraycpy(dd_outputs_ptr, output_tensor_ptr, output_size)); 137 PetscCall(VecRestoreArray(DD_Outputs_loc, &dd_outputs_ptr)); 138 } else { 139 PetscInt num_nodes; 140 torch::Tensor DD_Outputs_tensor; 141 142 PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size)); 143 num_nodes = output_size / num_output_comps; 144 PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type)); 145 PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_output_device)); 146 PetscCallCXX(DD_Outputs_tensor = torch::from_blob((void *)dd_outputs_ptr, {num_nodes, num_output_comps}, options)); 147 PetscCallCXX(DD_Outputs_tensor.copy_(output_tensor)); 148 PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr)); 149 } 150 } 151 PetscFunctionReturn(PETSC_SUCCESS); 152 } 153