1 // SPDX-FileCopyrightText: Copyright (c) 2017-2024, HONEE contributors. 2 // SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause 3 4 #include <petsc.h> 5 6 #ifdef __cplusplus 7 extern "C" { 8 #endif 9 10 typedef enum { 11 TORCH_DEVICE_CPU, 12 TORCH_DEVICE_CUDA, 13 TORCH_DEVICE_HIP, 14 TORCH_DEVICE_XPU, 15 } TorchDeviceType; 16 static const char *const TorchDeviceTypes[] = {"CPU", "CUDA", "HIP", "XPU", "TorchDeviceType", "TORCH_DEVICE_", NULL}; 17 18 PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum); 19 PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc); 20 21 #ifdef __cplusplus 22 } 23 #endif 24