1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <log_events.h> 9 #include <petsc.h> 10 #include <sgs_model_torch.h> 11 #include <torch/script.h> 12 #include <torch/torch.h> 13 14 torch::jit::script::Module model; 15 torch::DeviceType device_model; 16 17 static PetscErrorCode EnumToDeviceType(TorchDeviceType device_enum, torch::DeviceType *device_type) { 18 PetscFunctionBeginUser; 19 switch (device_enum) { 20 case TORCH_DEVICE_CPU: 21 *device_type = torch::kCPU; 22 break; 23 case TORCH_DEVICE_XPU: 24 *device_type = torch::kXPU; 25 break; 26 case TORCH_DEVICE_CUDA: 27 *device_type = torch::kCUDA; 28 break; 29 case TORCH_DEVICE_HIP: 30 *device_type = torch::kHIP; 31 break; 32 default: 33 SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "TorchDeviceType %d not supported by PyTorch inference", device_enum); 34 } 35 PetscFunctionReturn(PETSC_SUCCESS); 36 } 37 38 static PetscErrorCode PetscMemTypeToDeviceType(PetscMemType mem_type, torch::DeviceType *device_type) { 39 PetscFunctionBeginUser; 40 switch (mem_type) { 41 case PETSC_MEMTYPE_HOST: 42 *device_type = torch::kCPU; 43 break; 44 case PETSC_MEMTYPE_SYCL: 45 *device_type = torch::kXPU; 46 break; 47 case PETSC_MEMTYPE_CUDA: 48 *device_type = torch::kCUDA; 49 break; 50 case PETSC_MEMTYPE_HIP: 51 *device_type = torch::kHIP; 52 break; 53 default: 54 SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "PetscMemType %s not supported by PyTorch inference", PetscMemTypeToString(mem_type)); 55 } 56 PetscFunctionReturn(PETSC_SUCCESS); 57 } 58 59 PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum) { 60 PetscFunctionBeginUser; 61 PetscCall(EnumToDeviceType(device_enum, &device_model)); 62 63 PetscCallCXX(model = torch::jit::load(model_path)); 64 PetscCallCXX(model.to(torch::Device(device_model))); 65 PetscFunctionReturn(PETSC_SUCCESS); 66 } 67 68 // Load and run model 69 PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc) { 70 torch::Tensor input_tensor, output_tensor; 71 const PetscInt num_input_comps = 6, num_output_comps = 6; 72 PetscBool debug_tensor_output = PETSC_FALSE; 73 74 PetscFunctionBeginUser; 75 // torch::NoGradGuard no_grad; // equivalent to "with torch.no_grad():" in PyTorch 76 PetscCall(PetscLogEventBegin(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 77 { // Transfer DD_Inputs_loc into input_tensor 78 PetscMemType input_mem_type; 79 PetscInt input_size, num_nodes; 80 const PetscScalar *dd_inputs_ptr; 81 torch::DeviceType dd_input_device; 82 torch::TensorOptions options; 83 84 PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size)); 85 num_nodes = input_size / num_input_comps; 86 PetscCall(VecGetArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr, &input_mem_type)); 87 PetscCall(PetscMemTypeToDeviceType(input_mem_type, &dd_input_device)); 88 89 PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_input_device)); 90 if (dd_input_device == torch::kXPU) { // XPU requires device-to-host-to-device transfer 91 PetscCallCXX(input_tensor = 92 at::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, {num_input_comps, 1}, nullptr, options, dd_input_device) 93 .to(device_model)); 94 } else { 95 PetscCallCXX(input_tensor = torch::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, options)); 96 } 97 if (debug_tensor_output) { 98 double *input_tensor_ptr; 99 100 PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size)); 101 PetscCallCXX(input_tensor_ptr = (double *)input_tensor.contiguous().to(torch::kCPU).data_ptr()); 102 printf("Input_Tensor_Pointer:\n"); 103 for (PetscInt i = 0; i < input_size; i++) { 104 printf("%f\n", input_tensor_ptr[i]); 105 } 106 } 107 PetscCall(VecRestoreArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr)); 108 } 109 PetscCall(PetscLogEventEnd(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 110 111 // Run model 112 PetscCall(PetscLogEventBegin(FLUIDS_SgsModelDDInference, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 113 PetscCall(PetscLogGpuTimeBegin()); 114 PetscCallCXX(output_tensor = model.forward({input_tensor}).toTensor()); 115 PetscCall(PetscLogGpuTimeEnd()); 116 PetscCall(PetscLogEventEnd(FLUIDS_SgsModelDDInference, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 117 118 PetscCall(PetscLogEventBegin(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 119 { // Transfer output_tensor to DD_Outputs_loc 120 torch::DeviceType dd_output_device; 121 torch::TensorOptions options; 122 PetscInt output_size; 123 PetscScalar *dd_outputs_ptr; 124 PetscMemType output_mem_type; 125 126 { // Get DeviceType of DD_Outputs_loc 127 PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type)); 128 PetscCall(PetscMemTypeToDeviceType(output_mem_type, &dd_output_device)); 129 PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr)); 130 } 131 132 if (dd_output_device == torch::kXPU) { // XPU requires device-to-host-to-device transfer 133 double *output_tensor_ptr; 134 135 PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size)); 136 PetscCall(VecGetArray(DD_Outputs_loc, &dd_outputs_ptr)); 137 PetscCallCXX(output_tensor_ptr = (double *)output_tensor.contiguous().to(torch::kCPU).data_ptr()); 138 if (debug_tensor_output) { 139 printf("Output_Tensor_Pointer:\n"); 140 for (PetscInt i = 0; i < output_size; i++) { 141 printf("%f\n", output_tensor_ptr[i]); 142 } 143 } 144 PetscCall(PetscArraycpy(dd_outputs_ptr, output_tensor_ptr, output_size)); 145 PetscCall(VecRestoreArray(DD_Outputs_loc, &dd_outputs_ptr)); 146 } else { 147 PetscInt num_nodes; 148 torch::Tensor DD_Outputs_tensor; 149 150 PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size)); 151 num_nodes = output_size / num_output_comps; 152 PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type)); 153 PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_output_device)); 154 PetscCallCXX(DD_Outputs_tensor = torch::from_blob((void *)dd_outputs_ptr, {num_nodes, num_output_comps}, options)); 155 PetscCallCXX(DD_Outputs_tensor.copy_(output_tensor)); 156 PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr)); 157 } 158 } 159 PetscCall(PetscLogEventEnd(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 160 PetscFunctionReturn(PETSC_SUCCESS); 161 } 162