1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <petsc.h> 9 10 #ifdef __cplusplus 11 extern "C" { 12 #endif 13 14 typedef enum { 15 TORCH_DEVICE_CPU, 16 TORCH_DEVICE_CUDA, 17 TORCH_DEVICE_HIP, 18 TORCH_DEVICE_XPU, 19 } TorchDeviceType; 20 static const char *const TorchDeviceTypes[] = {"CPU", "CUDA", "HIP", "XPU", "TorchDeviceType", "TORCH_DEVICE_", NULL}; 21 22 PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum); 23 PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc); 24 25 #ifdef __cplusplus 26 } 27 #endif 28