1*4c07ec22SJames Wright // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2*4c07ec22SJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*4c07ec22SJames Wright // 4*4c07ec22SJames Wright // SPDX-License-Identifier: BSD-2-Clause 5*4c07ec22SJames Wright // 6*4c07ec22SJames Wright // This file is part of CEED: http://github.com/ceed 7*4c07ec22SJames Wright 8*4c07ec22SJames Wright #include <petsc.h> 9*4c07ec22SJames Wright #include <sgs_model_torch.h> 10*4c07ec22SJames Wright #include <torch/script.h> 11*4c07ec22SJames Wright #include <torch/torch.h> 12*4c07ec22SJames Wright 13*4c07ec22SJames Wright torch::jit::script::Module model; 14*4c07ec22SJames Wright torch::DeviceType device_model; 15*4c07ec22SJames Wright 16*4c07ec22SJames Wright static PetscErrorCode EnumToDeviceType(TorchDeviceType device_enum, torch::DeviceType *device_type) { 17*4c07ec22SJames Wright PetscFunctionBeginUser; 18*4c07ec22SJames Wright switch (device_enum) { 19*4c07ec22SJames Wright case TORCH_DEVICE_CPU: 20*4c07ec22SJames Wright *device_type = torch::kCPU; 21*4c07ec22SJames Wright break; 22*4c07ec22SJames Wright case TORCH_DEVICE_XPU: 23*4c07ec22SJames Wright *device_type = torch::kXPU; 24*4c07ec22SJames Wright break; 25*4c07ec22SJames Wright case TORCH_DEVICE_CUDA: 26*4c07ec22SJames Wright *device_type = torch::kCUDA; 27*4c07ec22SJames Wright break; 28*4c07ec22SJames Wright case TORCH_DEVICE_HIP: 29*4c07ec22SJames Wright *device_type = torch::kHIP; 30*4c07ec22SJames Wright break; 31*4c07ec22SJames Wright default: 32*4c07ec22SJames Wright SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "TorchDeviceType %d not supported by PyTorch inference", device_enum); 33*4c07ec22SJames Wright } 34*4c07ec22SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 35*4c07ec22SJames Wright } 36*4c07ec22SJames Wright 37*4c07ec22SJames Wright static PetscErrorCode PetscMemTypeToDeviceType(PetscMemType mem_type, torch::DeviceType *device_type) { 38*4c07ec22SJames Wright PetscFunctionBeginUser; 39*4c07ec22SJames Wright switch (mem_type) { 40*4c07ec22SJames Wright case PETSC_MEMTYPE_HOST: 41*4c07ec22SJames Wright *device_type = torch::kCPU; 42*4c07ec22SJames Wright break; 43*4c07ec22SJames Wright case PETSC_MEMTYPE_SYCL: 44*4c07ec22SJames Wright *device_type = torch::kXPU; 45*4c07ec22SJames Wright break; 46*4c07ec22SJames Wright case PETSC_MEMTYPE_CUDA: 47*4c07ec22SJames Wright *device_type = torch::kCUDA; 48*4c07ec22SJames Wright break; 49*4c07ec22SJames Wright case PETSC_MEMTYPE_HIP: 50*4c07ec22SJames Wright *device_type = torch::kHIP; 51*4c07ec22SJames Wright break; 52*4c07ec22SJames Wright default: 53*4c07ec22SJames Wright SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "PetscMemType %s not supported by PyTorch inference", PetscMemTypeToString(mem_type)); 54*4c07ec22SJames Wright } 55*4c07ec22SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 56*4c07ec22SJames Wright } 57*4c07ec22SJames Wright 58*4c07ec22SJames Wright PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum) { 59*4c07ec22SJames Wright PetscFunctionBeginUser; 60*4c07ec22SJames Wright PetscCall(EnumToDeviceType(device_enum, &device_model)); 61*4c07ec22SJames Wright 62*4c07ec22SJames Wright PetscCallCXX(model = torch::jit::load(model_path)); 63*4c07ec22SJames Wright PetscCallCXX(model.to(torch::Device(device_model))); 64*4c07ec22SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 65*4c07ec22SJames Wright } 66*4c07ec22SJames Wright 67*4c07ec22SJames Wright // Load and run model 68*4c07ec22SJames Wright PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc) { 69*4c07ec22SJames Wright torch::Tensor input_tensor, output_tensor; 70*4c07ec22SJames Wright const PetscInt num_input_comps = 6, num_output_comps = 6; 71*4c07ec22SJames Wright PetscBool debug_tensor_output = PETSC_FALSE; 72*4c07ec22SJames Wright 73*4c07ec22SJames Wright PetscFunctionBeginUser; 74*4c07ec22SJames Wright // torch::NoGradGuard no_grad; // equivalent to "with torch.no_grad():" in PyTorch 75*4c07ec22SJames Wright { // Transfer DD_Inputs_loc into input_tensor 76*4c07ec22SJames Wright PetscMemType input_mem_type; 77*4c07ec22SJames Wright PetscInt input_size, num_nodes; 78*4c07ec22SJames Wright const PetscScalar *dd_inputs_ptr; 79*4c07ec22SJames Wright torch::DeviceType dd_input_device; 80*4c07ec22SJames Wright torch::TensorOptions options; 81*4c07ec22SJames Wright 82*4c07ec22SJames Wright PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size)); 83*4c07ec22SJames Wright num_nodes = input_size / num_input_comps; 84*4c07ec22SJames Wright PetscCall(VecGetArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr, &input_mem_type)); 85*4c07ec22SJames Wright PetscCall(PetscMemTypeToDeviceType(input_mem_type, &dd_input_device)); 86*4c07ec22SJames Wright 87*4c07ec22SJames Wright PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_input_device)); 88*4c07ec22SJames Wright if (dd_input_device == torch::kXPU) { // XPU requires device-to-host-to-device transfer 89*4c07ec22SJames Wright PetscCallCXX(input_tensor = 90*4c07ec22SJames Wright at::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, {num_input_comps, 1}, nullptr, options, dd_input_device) 91*4c07ec22SJames Wright .to(device_model)); 92*4c07ec22SJames Wright } else { 93*4c07ec22SJames Wright PetscCallCXX(input_tensor = torch::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, options)); 94*4c07ec22SJames Wright } 95*4c07ec22SJames Wright if (debug_tensor_output) { 96*4c07ec22SJames Wright double *input_tensor_ptr; 97*4c07ec22SJames Wright 98*4c07ec22SJames Wright PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size)); 99*4c07ec22SJames Wright PetscCallCXX(input_tensor_ptr = (double *)input_tensor.contiguous().to(torch::kCPU).data_ptr()); 100*4c07ec22SJames Wright printf("Input_Tensor_Pointer:\n"); 101*4c07ec22SJames Wright for (PetscInt i = 0; i < input_size; i++) { 102*4c07ec22SJames Wright printf("%f\n", input_tensor_ptr[i]); 103*4c07ec22SJames Wright } 104*4c07ec22SJames Wright } 105*4c07ec22SJames Wright PetscCall(VecRestoreArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr)); 106*4c07ec22SJames Wright } 107*4c07ec22SJames Wright 108*4c07ec22SJames Wright // Run model 109*4c07ec22SJames Wright PetscCallCXX(output_tensor = model.forward({input_tensor}).toTensor()); 110*4c07ec22SJames Wright 111*4c07ec22SJames Wright { // Transfer output_tensor to DD_Outputs_loc 112*4c07ec22SJames Wright torch::DeviceType dd_output_device; 113*4c07ec22SJames Wright torch::TensorOptions options; 114*4c07ec22SJames Wright PetscInt output_size; 115*4c07ec22SJames Wright PetscScalar *dd_outputs_ptr; 116*4c07ec22SJames Wright PetscMemType output_mem_type; 117*4c07ec22SJames Wright 118*4c07ec22SJames Wright { // Get DeviceType of DD_Outputs_loc 119*4c07ec22SJames Wright PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type)); 120*4c07ec22SJames Wright PetscCall(PetscMemTypeToDeviceType(output_mem_type, &dd_output_device)); 121*4c07ec22SJames Wright PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr)); 122*4c07ec22SJames Wright } 123*4c07ec22SJames Wright 124*4c07ec22SJames Wright if (dd_output_device == torch::kXPU) { // XPU requires device-to-host-to-device transfer 125*4c07ec22SJames Wright double *output_tensor_ptr; 126*4c07ec22SJames Wright 127*4c07ec22SJames Wright PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size)); 128*4c07ec22SJames Wright PetscCall(VecGetArray(DD_Outputs_loc, &dd_outputs_ptr)); 129*4c07ec22SJames Wright PetscCallCXX(output_tensor_ptr = (double *)output_tensor.contiguous().to(torch::kCPU).data_ptr()); 130*4c07ec22SJames Wright if (debug_tensor_output) { 131*4c07ec22SJames Wright printf("Output_Tensor_Pointer:\n"); 132*4c07ec22SJames Wright for (PetscInt i = 0; i < output_size; i++) { 133*4c07ec22SJames Wright printf("%f\n", output_tensor_ptr[i]); 134*4c07ec22SJames Wright } 135*4c07ec22SJames Wright } 136*4c07ec22SJames Wright PetscCall(PetscArraycpy(dd_outputs_ptr, output_tensor_ptr, output_size)); 137*4c07ec22SJames Wright PetscCall(VecRestoreArray(DD_Outputs_loc, &dd_outputs_ptr)); 138*4c07ec22SJames Wright } else { 139*4c07ec22SJames Wright PetscInt num_nodes; 140*4c07ec22SJames Wright torch::Tensor DD_Outputs_tensor; 141*4c07ec22SJames Wright 142*4c07ec22SJames Wright PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size)); 143*4c07ec22SJames Wright num_nodes = output_size / num_output_comps; 144*4c07ec22SJames Wright PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type)); 145*4c07ec22SJames Wright PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_output_device)); 146*4c07ec22SJames Wright PetscCallCXX(DD_Outputs_tensor = torch::from_blob((void *)dd_outputs_ptr, {num_nodes, num_output_comps}, options)); 147*4c07ec22SJames Wright PetscCallCXX(DD_Outputs_tensor.copy_(output_tensor)); 148*4c07ec22SJames Wright PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr)); 149*4c07ec22SJames Wright } 150*4c07ec22SJames Wright } 151*4c07ec22SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 152*4c07ec22SJames Wright } 153