1ae2b091fSJames Wright // SPDX-FileCopyrightText: Copyright (c) 2017-2024, HONEE contributors. 2ae2b091fSJames Wright // SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause 34c07ec22SJames Wright 4*1c58d510SJames Wright #include <petscvec.h> 54c07ec22SJames Wright 64c07ec22SJames Wright #ifdef __cplusplus 74c07ec22SJames Wright extern "C" { 84c07ec22SJames Wright #endif 94c07ec22SJames Wright 104c07ec22SJames Wright typedef enum { 114c07ec22SJames Wright TORCH_DEVICE_CPU, 124c07ec22SJames Wright TORCH_DEVICE_CUDA, 134c07ec22SJames Wright TORCH_DEVICE_HIP, 144c07ec22SJames Wright TORCH_DEVICE_XPU, 154c07ec22SJames Wright } TorchDeviceType; 166dfcbb05SJames Wright static const char *const TorchDeviceTypes[] = {"CPU", "CUDA", "HIP", "XPU", "TorchDeviceType", "TORCH_DEVICE_", NULL}; 174c07ec22SJames Wright 184c07ec22SJames Wright PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum); 194c07ec22SJames Wright PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc); 204c07ec22SJames Wright 214c07ec22SJames Wright #ifdef __cplusplus 224c07ec22SJames Wright } 234c07ec22SJames Wright #endif 24