1*ae2b091fSJames Wright // SPDX-FileCopyrightText: Copyright (c) 2017-2024, HONEE contributors. 2*ae2b091fSJames Wright // SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause 34c07ec22SJames Wright 4b40a7e63SJames Wright #include <log_events.h> 54c07ec22SJames Wright #include <petsc.h> 64c07ec22SJames Wright #include <sgs_model_torch.h> 74c07ec22SJames Wright #include <torch/script.h> 84c07ec22SJames Wright #include <torch/torch.h> 94c07ec22SJames Wright 104c07ec22SJames Wright torch::jit::script::Module model; 114c07ec22SJames Wright torch::DeviceType device_model; 124c07ec22SJames Wright 134c07ec22SJames Wright static PetscErrorCode EnumToDeviceType(TorchDeviceType device_enum, torch::DeviceType *device_type) { 144c07ec22SJames Wright PetscFunctionBeginUser; 154c07ec22SJames Wright switch (device_enum) { 164c07ec22SJames Wright case TORCH_DEVICE_CPU: 174c07ec22SJames Wright *device_type = torch::kCPU; 184c07ec22SJames Wright break; 194c07ec22SJames Wright case TORCH_DEVICE_XPU: 204c07ec22SJames Wright *device_type = torch::kXPU; 214c07ec22SJames Wright break; 224c07ec22SJames Wright case TORCH_DEVICE_CUDA: 234c07ec22SJames Wright *device_type = torch::kCUDA; 244c07ec22SJames Wright break; 254c07ec22SJames Wright case TORCH_DEVICE_HIP: 264c07ec22SJames Wright *device_type = torch::kHIP; 274c07ec22SJames Wright break; 284c07ec22SJames Wright default: 294c07ec22SJames Wright SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "TorchDeviceType %d not supported by PyTorch inference", device_enum); 304c07ec22SJames Wright } 314c07ec22SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 324c07ec22SJames Wright } 334c07ec22SJames Wright 344c07ec22SJames Wright static PetscErrorCode PetscMemTypeToDeviceType(PetscMemType mem_type, torch::DeviceType *device_type) { 354c07ec22SJames Wright PetscFunctionBeginUser; 364c07ec22SJames Wright switch (mem_type) { 374c07ec22SJames Wright case PETSC_MEMTYPE_HOST: 384c07ec22SJames Wright *device_type = torch::kCPU; 394c07ec22SJames Wright break; 404c07ec22SJames Wright case PETSC_MEMTYPE_SYCL: 414c07ec22SJames Wright *device_type = torch::kXPU; 424c07ec22SJames Wright break; 434c07ec22SJames Wright case PETSC_MEMTYPE_CUDA: 444c07ec22SJames Wright *device_type = torch::kCUDA; 454c07ec22SJames Wright break; 464c07ec22SJames Wright case PETSC_MEMTYPE_HIP: 474c07ec22SJames Wright *device_type = torch::kHIP; 484c07ec22SJames Wright break; 494c07ec22SJames Wright default: 504c07ec22SJames Wright SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_SUP, "PetscMemType %s not supported by PyTorch inference", PetscMemTypeToString(mem_type)); 514c07ec22SJames Wright } 524c07ec22SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 534c07ec22SJames Wright } 544c07ec22SJames Wright 554c07ec22SJames Wright PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum) { 564c07ec22SJames Wright PetscFunctionBeginUser; 574c07ec22SJames Wright PetscCall(EnumToDeviceType(device_enum, &device_model)); 584c07ec22SJames Wright 594c07ec22SJames Wright PetscCallCXX(model = torch::jit::load(model_path)); 604c07ec22SJames Wright PetscCallCXX(model.to(torch::Device(device_model))); 614c07ec22SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 624c07ec22SJames Wright } 634c07ec22SJames Wright 644c07ec22SJames Wright // Load and run model 654c07ec22SJames Wright PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc) { 664c07ec22SJames Wright torch::Tensor input_tensor, output_tensor; 674c07ec22SJames Wright const PetscInt num_input_comps = 6, num_output_comps = 6; 684c07ec22SJames Wright PetscBool debug_tensor_output = PETSC_FALSE; 694c07ec22SJames Wright 704c07ec22SJames Wright PetscFunctionBeginUser; 714c07ec22SJames Wright // torch::NoGradGuard no_grad; // equivalent to "with torch.no_grad():" in PyTorch 72b40a7e63SJames Wright PetscCall(PetscLogEventBegin(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 734c07ec22SJames Wright { // Transfer DD_Inputs_loc into input_tensor 744c07ec22SJames Wright PetscMemType input_mem_type; 754c07ec22SJames Wright PetscInt input_size, num_nodes; 764c07ec22SJames Wright const PetscScalar *dd_inputs_ptr; 774c07ec22SJames Wright torch::DeviceType dd_input_device; 784c07ec22SJames Wright torch::TensorOptions options; 794c07ec22SJames Wright 804c07ec22SJames Wright PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size)); 814c07ec22SJames Wright num_nodes = input_size / num_input_comps; 824c07ec22SJames Wright PetscCall(VecGetArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr, &input_mem_type)); 834c07ec22SJames Wright PetscCall(PetscMemTypeToDeviceType(input_mem_type, &dd_input_device)); 844c07ec22SJames Wright 854c07ec22SJames Wright PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_input_device)); 864c07ec22SJames Wright if (dd_input_device == torch::kXPU) { // XPU requires device-to-host-to-device transfer 874c07ec22SJames Wright PetscCallCXX(input_tensor = 884c07ec22SJames Wright at::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, {num_input_comps, 1}, nullptr, options, dd_input_device) 894c07ec22SJames Wright .to(device_model)); 904c07ec22SJames Wright } else { 914c07ec22SJames Wright PetscCallCXX(input_tensor = torch::from_blob((void *)dd_inputs_ptr, {num_nodes, num_input_comps}, options)); 924c07ec22SJames Wright } 934c07ec22SJames Wright if (debug_tensor_output) { 944c07ec22SJames Wright double *input_tensor_ptr; 954c07ec22SJames Wright 964c07ec22SJames Wright PetscCall(VecGetLocalSize(DD_Inputs_loc, &input_size)); 974c07ec22SJames Wright PetscCallCXX(input_tensor_ptr = (double *)input_tensor.contiguous().to(torch::kCPU).data_ptr()); 984c07ec22SJames Wright printf("Input_Tensor_Pointer:\n"); 994c07ec22SJames Wright for (PetscInt i = 0; i < input_size; i++) { 1004c07ec22SJames Wright printf("%f\n", input_tensor_ptr[i]); 1014c07ec22SJames Wright } 1024c07ec22SJames Wright } 1034c07ec22SJames Wright PetscCall(VecRestoreArrayReadAndMemType(DD_Inputs_loc, &dd_inputs_ptr)); 1044c07ec22SJames Wright } 105b40a7e63SJames Wright PetscCall(PetscLogEventEnd(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 1064c07ec22SJames Wright 1074c07ec22SJames Wright // Run model 108b40a7e63SJames Wright PetscCall(PetscLogEventBegin(FLUIDS_SgsModelDDInference, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 109b40a7e63SJames Wright PetscCall(PetscLogGpuTimeBegin()); 1104c07ec22SJames Wright PetscCallCXX(output_tensor = model.forward({input_tensor}).toTensor()); 111b40a7e63SJames Wright PetscCall(PetscLogGpuTimeEnd()); 112b40a7e63SJames Wright PetscCall(PetscLogEventEnd(FLUIDS_SgsModelDDInference, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 1134c07ec22SJames Wright 114b40a7e63SJames Wright PetscCall(PetscLogEventBegin(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 1154c07ec22SJames Wright { // Transfer output_tensor to DD_Outputs_loc 1164c07ec22SJames Wright torch::DeviceType dd_output_device; 1174c07ec22SJames Wright torch::TensorOptions options; 1184c07ec22SJames Wright PetscInt output_size; 1194c07ec22SJames Wright PetscScalar *dd_outputs_ptr; 1204c07ec22SJames Wright PetscMemType output_mem_type; 1214c07ec22SJames Wright 1224c07ec22SJames Wright { // Get DeviceType of DD_Outputs_loc 1234c07ec22SJames Wright PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type)); 1244c07ec22SJames Wright PetscCall(PetscMemTypeToDeviceType(output_mem_type, &dd_output_device)); 1254c07ec22SJames Wright PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr)); 1264c07ec22SJames Wright } 1274c07ec22SJames Wright 1284c07ec22SJames Wright if (dd_output_device == torch::kXPU) { // XPU requires device-to-host-to-device transfer 1294c07ec22SJames Wright double *output_tensor_ptr; 1304c07ec22SJames Wright 1314c07ec22SJames Wright PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size)); 1324c07ec22SJames Wright PetscCall(VecGetArray(DD_Outputs_loc, &dd_outputs_ptr)); 1334c07ec22SJames Wright PetscCallCXX(output_tensor_ptr = (double *)output_tensor.contiguous().to(torch::kCPU).data_ptr()); 1344c07ec22SJames Wright if (debug_tensor_output) { 1354c07ec22SJames Wright printf("Output_Tensor_Pointer:\n"); 1364c07ec22SJames Wright for (PetscInt i = 0; i < output_size; i++) { 1374c07ec22SJames Wright printf("%f\n", output_tensor_ptr[i]); 1384c07ec22SJames Wright } 1394c07ec22SJames Wright } 1404c07ec22SJames Wright PetscCall(PetscArraycpy(dd_outputs_ptr, output_tensor_ptr, output_size)); 1414c07ec22SJames Wright PetscCall(VecRestoreArray(DD_Outputs_loc, &dd_outputs_ptr)); 1424c07ec22SJames Wright } else { 1434c07ec22SJames Wright PetscInt num_nodes; 1444c07ec22SJames Wright torch::Tensor DD_Outputs_tensor; 1454c07ec22SJames Wright 1464c07ec22SJames Wright PetscCall(VecGetLocalSize(DD_Outputs_loc, &output_size)); 1474c07ec22SJames Wright num_nodes = output_size / num_output_comps; 1484c07ec22SJames Wright PetscCall(VecGetArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr, &output_mem_type)); 1494c07ec22SJames Wright PetscCallCXX(options = torch::TensorOptions().dtype(torch::kFloat64).device(dd_output_device)); 1504c07ec22SJames Wright PetscCallCXX(DD_Outputs_tensor = torch::from_blob((void *)dd_outputs_ptr, {num_nodes, num_output_comps}, options)); 1514c07ec22SJames Wright PetscCallCXX(DD_Outputs_tensor.copy_(output_tensor)); 1524c07ec22SJames Wright PetscCall(VecRestoreArrayAndMemType(DD_Outputs_loc, &dd_outputs_ptr)); 1534c07ec22SJames Wright } 1544c07ec22SJames Wright } 155b40a7e63SJames Wright PetscCall(PetscLogEventEnd(FLUIDS_SgsModelDDData, DD_Inputs_loc, DD_Outputs_loc, NULL, NULL)); 1564c07ec22SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1574c07ec22SJames Wright } 158