xref: /honee/include/sgs_model_torch.h (revision da59998647b348cb577ad1134294158b446d82cc)
// SPDX-FileCopyrightText: Copyright (c) 2017-2024, HONEE contributors.
// SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause
3 
4 #include <petscvec.h>
5 
6 #ifdef __cplusplus
7 extern "C" {
8 #endif
9 
10 typedef enum {
11   TORCH_DEVICE_CPU,
12   TORCH_DEVICE_CUDA,
13   TORCH_DEVICE_HIP,
14   TORCH_DEVICE_XPU,
15 } TorchDeviceType;
16 static const char *const TorchDeviceTypes[] = {"CPU", "CUDA", "HIP", "XPU", "TorchDeviceType", "TORCH_DEVICE_", NULL};
17 
18 PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum);
19 PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc);
20 
21 #ifdef __cplusplus
22 }
23 #endif
24