#!/usr/bin/env python3
"""Overwrite the anisotropic SGS network's parameters and export a TorchScript trace.

Anisotropic SGS model for LES developed by Aviral Prakash and John A. Evans at UCB.
"""
import torch
import torch.nn as nn
from pathlib import Path
import numpy as np


# NOTE(review): this relative path is resolved against the current working
# directory, not this script's location, so the script has to be launched
# from a directory with the matching layout — confirm that is intended.
new_parameters_Path = Path('../../dd_sgs_data')

# Replacement parameters, read at import time. The first line of every .dat
# file is skipped (presumably a header — confirm against the file format).
# nn.Linear stores its weight as (out_features, in_features), so each flat
# weight array is reshaped to (in_features, out_features) and then transposed.
weights = []
biases = []
weights.append(np.loadtxt(new_parameters_Path / 'w1.dat', skiprows=1).reshape(6, 20).T)
weights.append(np.loadtxt(new_parameters_Path / 'w2.dat', skiprows=1).reshape(20, 6).T)
biases.append(np.loadtxt(new_parameters_Path / 'b1.dat', skiprows=1))
biases.append(np.loadtxt(new_parameters_Path / 'b2.dat', skiprows=1))


class anisoSGS(nn.Module):
    """Fully connected network with LeakyReLU(0.3) activations.

    Args:
        inputDim: number of input features.
        outputDim: number of output features.
        numNeurons: width of every hidden layer.
        numLayers: number of hidden layers (must be >= 1). The default of 1
            builds ``Linear -> LeakyReLU -> Linear``, whose state-dict keys
            are ``net.0.*`` and ``net.2.*``.

    Raises:
        ValueError: if ``numLayers`` is smaller than 1.
    """

    def __init__(self, inputDim=6, outputDim=6, numNeurons=20, numLayers=1):
        super().__init__()
        if numLayers < 1:
            raise ValueError(f'numLayers must be >= 1, got {numLayers}')
        self.ndIn = inputDim
        self.ndOut = outputDim
        self.nNeurons = numNeurons
        self.nLayers = numLayers

        layers = [nn.Linear(self.ndIn, self.nNeurons), nn.LeakyReLU(0.3)]
        # Each additional hidden layer keeps the width at numNeurons.
        for _ in range(self.nLayers - 1):
            layers += [nn.Linear(self.nNeurons, self.nNeurons), nn.LeakyReLU(0.3)]
        layers.append(nn.Linear(self.nNeurons, self.ndOut))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for ``x`` (last dimension = ``inputDim``)."""
        return self.net(x)


def _copy_checked(param, array, label):
    """Copy ``array`` into ``param`` in place, refusing any shape mismatch.

    Plain ``param[...] = tensor`` broadcasts, so e.g. a single-value bias file
    would silently fill every entry; this raises ``ValueError`` instead.
    """
    src = torch.from_numpy(np.asarray(array, dtype=np.float64))
    if src.shape != param.shape:
        raise ValueError(
            f'{label} shape {tuple(src.shape)} does not match '
            f'parameter shape {tuple(param.shape)}')
    param.copy_(src)


def load_n_trace_model(model_name):
    """Build the model, overwrite its parameters, trace it and save the trace.

    The state dict in ``{model_name}.pt`` is loaded first, then every
    ``nn.Linear`` weight and bias is replaced with the corresponding entry of
    the module-level ``weights`` / ``biases`` lists. The model is converted to
    float64 and traced on a random (512, inputDim) float64 example input. The
    trace is written to ``{model_name}_fp64_jit.pt``.

    Args:
        model_name: checkpoint path without the ``.pt`` extension; also used
            as the prefix of the traced output file.

    Returns:
        The traced TorchScript module.

    Raises:
        ValueError: if the number of Linear layers differs from the number of
            loaded arrays, or if any array shape differs from its parameter.
    """
    model = anisoSGS()
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    state_dict = torch.load(f'{model_name}.pt', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.double()

    linear_layers = [layer for layer in model.net if isinstance(layer, nn.Linear)]
    if len(linear_layers) != len(weights) or len(linear_layers) != len(biases):
        raise ValueError(
            f'model has {len(linear_layers)} Linear layers but {len(weights)} '
            f'weight arrays and {len(biases)} bias arrays were loaded')

    # Overwrite individual layer parameters without recording autograd history.
    with torch.no_grad():
        for idx, (layer, w, b) in enumerate(zip(linear_layers, weights, biases)):
            m, n = layer.weight.shape
            print('weight shape', m, n)
            _copy_checked(layer.weight, w, f'layer {idx} weight')
            _copy_checked(layer.bias, b, f'layer {idx} bias')

    # Tracing records a single execution path, so put the model in inference mode first.
    model.eval()
    example_input = torch.randn(512, model.ndIn, dtype=torch.float64, device='cpu')
    with torch.no_grad():
        traced = torch.jit.trace(model, example_input)
    torch.jit.save(traced, f"{model_name}_fp64_jit.pt")

    return traced


def main():
    """Export the traced float64 HIT model with the replacement parameters."""
    model_name = 'NNModel_HIT'
    load_n_trace_model(model_name)


if __name__ == '__main__':
    main()