xref: /honee/include/sgs_model_torch.h (revision f79b7f20ee7a2d310fc65d546f69afbf5e560555)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <petsc.h>
9 
10 #ifdef __cplusplus
11 extern "C" {
12 #endif
13 
14 typedef enum {
15   TORCH_DEVICE_CPU,
16   TORCH_DEVICE_CUDA,
17   TORCH_DEVICE_HIP,
18   TORCH_DEVICE_XPU,
19 } TorchDeviceType;
20 static const char *const TorchDeviceTypes[] = {"CPU", "CUDA", "HIP", "XPU", "TorchDeviceType", "TORCH_DEVICE_", NULL};
21 
22 PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum);
23 PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc);
24 
25 #ifdef __cplusplus
26 }
27 #endif
28