xref: /honee/include/sgs_model_torch.h (revision c0d10d1ddfa51fc3c0a8079706784eb95f3ed88f)
1ae2b091fSJames Wright // SPDX-FileCopyrightText: Copyright (c) 2017-2024, HONEE contributors.
2ae2b091fSJames Wright // SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause
34c07ec22SJames Wright 
4*1c58d510SJames Wright #include <petscvec.h>
54c07ec22SJames Wright 
64c07ec22SJames Wright #ifdef __cplusplus
74c07ec22SJames Wright extern "C" {
84c07ec22SJames Wright #endif
94c07ec22SJames Wright 
104c07ec22SJames Wright typedef enum {
114c07ec22SJames Wright   TORCH_DEVICE_CPU,
124c07ec22SJames Wright   TORCH_DEVICE_CUDA,
134c07ec22SJames Wright   TORCH_DEVICE_HIP,
144c07ec22SJames Wright   TORCH_DEVICE_XPU,
154c07ec22SJames Wright } TorchDeviceType;
166dfcbb05SJames Wright static const char *const TorchDeviceTypes[] = {"CPU", "CUDA", "HIP", "XPU", "TorchDeviceType", "TORCH_DEVICE_", NULL};
174c07ec22SJames Wright 
184c07ec22SJames Wright PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum);
194c07ec22SJames Wright PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc);
204c07ec22SJames Wright 
214c07ec22SJames Wright #ifdef __cplusplus
224c07ec22SJames Wright }
234c07ec22SJames Wright #endif
24