1 // SPDX-FileCopyrightText: Copyright (c) 2017-2024, HONEE contributors. 2 // SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause 3 4 #include <log_events.h> 5 #include <petsc.h> 6 #include <sgs_model_torch.h> 7 #include <torch/script.h> 8 #include <torch/torch.h> 9 10 torch::jit::script::Module model; 11 torch::DeviceType device_model; 12 13 static PetscErrorCode EnumToDeviceType(TorchDeviceType device_enum, torch::DeviceType *device_type) { 14 PetscFunctionBeginUser; 15 switch (device_enum) { 16 case TORCH_DEVICE_CPU: 17 *device_type = torch::kCPU; 18 break; 19 case TORCH_DEVICE_XPU: 20 *device_type = torch::kXPU; 21 break; 22 case TORCH_DEVICE_CUDA: 23 *device_type = torch::kCUDA; 24 break; 25 case TORCH_DEVICE_HIP: 26 *device_type = torch::kHIP; 27 break; 28 default: 29 SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "TorchDeviceType %d not supported by PyTorch inference", device_enum); 30 } 31 PetscFunctionReturn(PETSC_SUCCESS); 32 } 33 34 static PetscErrorCode PetscMemTypeToDeviceType(PetscMemType mem_type, torch::DeviceType *device_type) { 35 PetscFunctionBeginUser; 36 switch (mem_type) { 37 case PETSC_MEMTYPE_HOST: 38 *device_type = torch::kCPU; 39 break; 40 case PETSC_MEMTYPE_SYCL: 41 *device_type = torch::kXPU; 42 break; 43 case PETSC_MEMTYPE_CUDA: 44 *device_type = torch::kCUDA; 45 break; 46 case PETSC_MEMTYPE_HIP: 47 *device_type = torch::kHIP; 48 break; 49 default: 50 SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "PetscMemType %s not supported by PyTorch inference", PetscMemTypeToString(mem_type)); 51 } 52 PetscFunctionReturn(PETSC_SUCCESS); 53 } 54 55 PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum) { 56 PetscFunctionBeginUser; 57 PetscCall(EnumToDeviceType(device_enum, &device_model)); 58 59 PetscCallCXX(model = torch::jit::load(model_path)); 60 PetscCallCXX(model.to(torch::Device(device_model))); 61 PetscFunctionReturn(PETSC_SUCCESS); 62 } 63 64 // Load and run model 65 PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc) { 66 torch::Tensor input_tensor, output_tensor; 67 const PetscInt num_input_comps = 6, num_output_comps = 6; 68 PetscBool debug_tensor_output = PETSC_FALSE; 69 70 PetscFunctionBeginUser; 71 // torch::NoGradGuard no_grad; // equivalent to "with torch.no_grad():" in PyTorch 72 PetscCall(PetscLogEventBegin(HONEE_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 73 { // Transfer DD_Inputs_loc into input_tensor 74 PetscMemType input_mem_type; 75 PetscInt input_size, num_nodes; 76 const PetscScalar *dd_inputs_ptr; 77 torch::DeviceType dd_input_device; 78 torch::TensorOptions options; 79 80 PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size)); 81 num_nodes = input_size / num_input_comps; 82 PetscCall(VecGetArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr, &input_mem_type)); 83 PetscCall(PetscMemTypeToDeviceType(input_mem_type, &dd_input_device)); 84 85 PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_input_device)); 86 if (dd_input_device == torch::kXPU) { // XPU requires device-to-host-to-device transfer 87 PetscCallCXX(input_tensor = 88 at::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, {num_input_comps, 1}, nullptr, options, dd_input_device) 89 .to(device_model)); 90 } else { 91 PetscCallCXX(input_tensor = torch::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, options)); 92 } 93 if (debug_tensor_output) { 94 double *input_tensor_ptr; 95 96 PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size)); 97 PetscCallCXX(input_tensor_ptr = (double *)input_tensor.contiguous().to(torch::kCPU).data_ptr()); 98 printf("Input_Tensor_Pointer:\n"); 99 for (PetscInt i = 0; i < input_size; i++) { 100 printf("%f\n", input_tensor_ptr[i]); 101 } 102 } 103 PetscCall(VecRestoreArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr)); 104 } 105 PetscCall(PetscLogEventEnd(HONEE_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 106 107 // Run model 108 PetscCall(PetscLogEventBegin(HONEE_SgsModelDDInference, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 109 PetscCall(PetscLogGpuTimeBegin()); 110 PetscCallCXX(output_tensor = model.forward({input_tensor}).toTensor()); 111 PetscCall(PetscLogGpuTimeEnd()); 112 PetscCall(PetscLogEventEnd(HONEE_SgsModelDDInference, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 113 114 PetscCall(PetscLogEventBegin(HONEE_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 115 { // Transfer output_tensor to DD_Outputs_loc 116 torch::DeviceType dd_output_device; 117 torch::TensorOptions options; 118 PetscInt output_size; 119 PetscScalar *dd_outputs_ptr; 120 PetscMemType output_mem_type; 121 122 { // Get DeviceType of DD_Outputs_loc 123 PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type)); 124 PetscCall(PetscMemTypeToDeviceType(output_mem_type, &dd_output_device)); 125 PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr)); 126 } 127 128 if (dd_output_device == torch::kXPU) { // XPU requires device-to-host-to-device transfer 129 double *output_tensor_ptr; 130 131 PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size)); 132 PetscCall(VecGetArray(DD_Outputs_loc, &dd_outputs_ptr)); 133 PetscCallCXX(output_tensor_ptr = (double *)output_tensor.contiguous().to(torch::kCPU).data_ptr()); 134 if (debug_tensor_output) { 135 printf("Output_Tensor_Pointer:\n"); 136 for (PetscInt i = 0; i < output_size; i++) { 137 printf("%f\n", output_tensor_ptr[i]); 138 } 139 } 140 PetscCall(PetscArraycpy(dd_outputs_ptr, output_tensor_ptr, output_size)); 141 PetscCall(VecRestoreArray(DD_Outputs_loc, &dd_outputs_ptr)); 142 } else { 143 PetscInt num_nodes; 144 torch::Tensor DD_Outputs_tensor; 145 146 PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size)); 147 num_nodes = output_size / num_output_comps; 148 PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type)); 149 PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_output_device)); 150 PetscCallCXX(DD_Outputs_tensor = torch::from_blob((void *)dd_outputs_ptr, {num_nodes, num_output_comps}, options)); 151 PetscCallCXX(DD_Outputs_tensor.copy_(output_tensor)); 152 PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr)); 153 } 154 } 155 PetscCall(PetscLogEventEnd(HONEE_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 156 PetscFunctionReturn(PETSC_SUCCESS); 157 } 158