#!/usr/bin/env python3
"""Export the anisotropic SGS neural network to a traced TorchScript file.

The trained parameters are read from plain-text ``.dat`` files located in
``new_parameters_Path`` when this module is imported.
"""
import torch
import torch.nn as nn
from pathlib import Path
import numpy as np


# Directory holding the w1/w2/b1/b2 parameter files. Being relative, it is
# resolved against the current working directory, not this file's location.
new_parameters_Path = Path('../../dd_sgs_data')

# Parameters for the two nn.Linear layers of anisoSGS, in layer order.
# Each .dat file starts with one header line, which is skipped. The weight
# files are reshaped as (in_features, out_features) and transposed to
# (out_features, in_features), the layout nn.Linear.weight uses.
# NOTE(review): assumes the files store weights input-major -- confirm
# against the code that wrote them.
weights = []
biases = []
weights.append(np.loadtxt(new_parameters_Path / 'w1.dat', skiprows=1).reshape(6, 20).T)
weights.append(np.loadtxt(new_parameters_Path / 'w2.dat', skiprows=1).reshape(20, 6).T)
biases.append(np.loadtxt(new_parameters_Path / 'b1.dat', skiprows=1))
biases.append(np.loadtxt(new_parameters_Path / 'b2.dat', skiprows=1))

# Anisotropic SGS model for LES developed by Aviral Prakash and John A. Evans at UCB


class anisoSGS(nn.Module):
    """Fully connected network for the anisotropic SGS model.

    Architecture::

        Linear(inputDim, numNeurons) -> LeakyReLU(0.3)
        [Linear(numNeurons, numNeurons) -> LeakyReLU(0.3)] * (numLayers - 1)
        Linear(numNeurons, outputDim)

    With the default ``numLayers=1``, ``self.net`` holds exactly three
    modules: index 0 is the input Linear, 1 the activation and 2 the output
    Linear. This keeps state-dict keys compatible with previously saved
    single-hidden-layer models.

    Args:
        inputDim: number of input features.
        outputDim: number of output features.
        numNeurons: width of every hidden layer.
        numLayers: number of hidden layers; must be at least 1.

    Raises:
        ValueError: if ``numLayers`` is less than 1.
    """

    def __init__(self, inputDim=6, outputDim=6, numNeurons=20, numLayers=1):
        super().__init__()
        if numLayers < 1:
            raise ValueError(f'numLayers must be >= 1, got {numLayers}')
        self.ndIn = inputDim
        self.ndOut = outputDim
        self.nNeurons = numNeurons
        self.nLayers = numLayers

        # Modules are created in forward order, so for numLayers=1 the random
        # initialisation consumes the RNG exactly as before.
        layers = [nn.Linear(self.ndIn, self.nNeurons), nn.LeakyReLU(0.3)]
        for _ in range(self.nLayers - 1):
            layers += [nn.Linear(self.nNeurons, self.nNeurons), nn.LeakyReLU(0.3)]
        layers.append(nn.Linear(self.nNeurons, self.ndOut))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``; its last dimension must equal ``ndIn``."""
        return self.net(x)
def _copy_checked(param, array, what):
    """Copy numpy ``array`` into tensor ``param`` in place, requiring an exact shape match.

    Raises:
        ValueError: if the shapes differ. Without this check, a broadcastable
            mismatch would be copied silently.
    """
    src = torch.from_numpy(array)
    if src.shape != param.shape:
        raise ValueError(
            f'{what}: expected shape {tuple(param.shape)}, got {tuple(src.shape)}'
        )
    param.copy_(src)


def load_n_trace_model(model_name):
    """Build the fp64 SGS model, overwrite its parameters and export it.

    Loads ``{model_name}.pt`` as a state dict on the CPU and converts the
    model to float64. It then replaces both Linear layers' weights and biases
    with the module-level ``weights``/``biases`` arrays. Finally it traces the
    model and saves the result as ``{model_name}_fp64_jit.pt``.

    Args:
        model_name: path prefix of the input ``.pt`` file and of the
            exported TorchScript file.

    Returns:
        The traced TorchScript module.

    Raises:
        ValueError: if a loaded parameter array does not match the shape of
            the corresponding layer parameter.
    """
    # Instantiate the PyTorch model and load its state dict.
    # NOTE(review): torch.load unpickles the file -- only use trusted .pt files.
    model = anisoSGS()
    model.load_state_dict(torch.load(f'{model_name}.pt', map_location=torch.device('cpu')))
    model.double()

    # Overwrite the Linear layers' parameters. In model.net, index 0 is the
    # input Linear and index 2 the output Linear; index 1 is the activation.
    with torch.no_grad():
        for i, layer_idx in enumerate((0, 2)):
            layer = model.net[layer_idx]
            m, n = layer.weight.shape
            print('weight shape', m, n)
            _copy_checked(layer.weight, weights[i], f'net[{layer_idx}].weight')
            _copy_checked(layer.bias, biases[i], f'net[{layer_idx}].bias')

    # Trace in eval mode, as recommended for export. This network has no
    # dropout/batch-norm, so the outputs are unaffected.
    model.eval()
    dummy_input = torch.randn(512, model.ndIn, dtype=torch.float64, device='cpu')
    with torch.no_grad():
        traced = torch.jit.trace(model, dummy_input)
    torch.jit.save(traced, f"{model_name}_fp64_jit.pt")

    return traced


def main():
    """Convert the default ``NNModel_HIT`` model to TorchScript."""
    load_n_trace_model('NNModel_HIT')


if __name__ == '__main__':
    main()