xref: /honee/tests/createPyTorchModel/update_weights.py (revision a32db64d340db16914d4892be21e91c50f2a7cbd)
1*a32db64dSJames Wright#!/usr/bin/env python3
2*a32db64dSJames Wrightimport torch
3*a32db64dSJames Wrightimport torch.nn as nn
4*a32db64dSJames Wrightfrom pathlib import Path
5*a32db64dSJames Wrightimport numpy as np
6*a32db64dSJames Wright
7*a32db64dSJames Wright
# Directory that holds the replacement parameter files; the relative path is
# resolved against the current working directory.
new_parameters_Path = Path('../../dd_sgs_data')

# np.loadtxt skips the first line of every file (skiprows=1). The flat weight
# data is reshaped and then transposed, giving arrays of shape (20, 6) and
# (6, 20) -- the (out_features, in_features) layout of the two nn.Linear
# layers built by anisoSGS with its default arguments.
weights = [
    np.loadtxt(new_parameters_Path / 'w1.dat', skiprows=1).reshape(6, 20).T,
    np.loadtxt(new_parameters_Path / 'w2.dat', skiprows=1).reshape(20, 6).T,
]
biases = [
    np.loadtxt(new_parameters_Path / file_name, skiprows=1)
    for file_name in ('b1.dat', 'b2.dat')
]
16*a32db64dSJames Wright
# Anisotropic subgrid-scale (SGS) model for large-eddy simulation (LES), developed by Aviral Prakash and John A. Evans at UCB
18*a32db64dSJames Wright
19*a32db64dSJames Wright
class anisoSGS(nn.Module):
    """Fully connected network: Linear -> LeakyReLU(0.3) -> Linear.

    Args:
        inputDim: Number of input features of the first linear layer.
        outputDim: Number of output features of the last linear layer.
        numNeurons: Width of the single hidden layer.
        numLayers: Stored as ``nLayers`` only. The network is always built
            with one hidden layer, whatever value is passed here.
    """

    def __init__(self, inputDim=6, outputDim=6, numNeurons=20, numLayers=1):
        super().__init__()
        self.ndIn, self.ndOut = inputDim, outputDim
        self.nNeurons, self.nLayers = numNeurons, numLayers

        # Positional Sequential layout (Linear, activation, Linear) keeps the
        # state-dict keys at net.0.* and net.2.*.
        hidden = nn.Linear(inputDim, numNeurons)
        activation = nn.LeakyReLU(0.3)
        output = nn.Linear(numNeurons, outputDim)
        self.net = nn.Sequential(hidden, activation, output)

    def forward(self, x):
        """Return the network output for input ``x``."""
        return self.net(x)
36*a32db64dSJames Wright
37*a32db64dSJames Wright
def load_n_trace_model(model_name):
    """Load a checkpoint, overwrite its parameters, and save a traced FP64 model.

    The state dict in ``{model_name}.pt`` is loaded into a default ``anisoSGS``
    model, the model is converted to double precision, and the weights and
    biases of its two linear layers are replaced by the module-level
    ``weights`` and ``biases`` arrays. The result is traced with
    ``torch.jit.trace`` and written to ``{model_name}_fp64_jit.pt``.

    Args:
        model_name: Path stem shared by the input checkpoint and the output file.

    Returns:
        The traced model returned by ``torch.jit.trace``.

    Raises:
        ValueError: If a replacement array does not have exactly the shape of
            the parameter it replaces.
    """
    # Instantiate PT model and load state dict.
    # NOTE(review): torch.load unpickles the file -- only use it on trusted
    # checkpoints (consider weights_only=True where the installed PyTorch
    # supports it).
    model = anisoSGS()
    model.load_state_dict(torch.load(f'{model_name}.pt', map_location=torch.device('cpu')))
    model.double()

    # Change individual model weights; indices 0 and 2 are the linear layers
    # of the Sequential container.
    with torch.no_grad():
        for i, layer in enumerate([0, 2]):
            linear = model.net[layer]
            m, n = linear.weight.shape
            print('weight shape', m, n)

            replacements = (
                ('weight', linear.weight, torch.from_numpy(weights[i])),
                ('bias', linear.bias, torch.from_numpy(biases[i])),
            )
            for kind, parameter, new_values in replacements:
                # An in-place copy broadcasts, so a wrongly shaped array could
                # otherwise be accepted silently.
                if new_values.shape != parameter.shape:
                    raise ValueError(
                        f'replacement {kind} for net[{layer}] has shape '
                        f'{tuple(new_values.shape)}, expected {tuple(parameter.shape)}'
                    )
                parameter.copy_(new_values)

    # Prepare model for inference: trace with a dummy FP64 batch and save.
    dummy_input = torch.randn(512, 6, dtype=torch.float64, device='cpu')
    with torch.no_grad():
        traced_model = torch.jit.trace(model, dummy_input)
        torch.jit.save(traced_model, f"{model_name}_fp64_jit.pt")

    return traced_model
63*a32db64dSJames Wright
64*a32db64dSJames Wright
def main():
    """Update the parameters of the ``NNModel_HIT`` model and save its traced form."""
    model_name = 'NNModel_HIT'
    # The traced model is written to disk by the call; its return value is
    # not needed here.
    load_n_trace_model(model_name)


if __name__ == '__main__':
    main()
72