xref: /honee/tests/createPyTorchModel/update_weights.py (revision fc37ad8c2d8e5885e86197268f3a21c51d020b21)
1#!/usr/bin/env python3
2import torch
3import torch.nn as nn
4from pathlib import Path
5import numpy as np
6
7
# Directory holding the replacement network parameters. The relative path is
# resolved against the current working directory.
new_parameters_Path = Path('../../dd_sgs_data')


def _read_dat(file_name):
    """Read one parameter file from ``new_parameters_Path``, skipping its first line."""
    return np.loadtxt(new_parameters_Path / file_name, skiprows=1)


# Weight matrices, transposed after reshaping so that they come out as
# (20, 6) and (6, 20) -- the ``weight`` shapes of nn.Linear(6, 20) and
# nn.Linear(20, 6) respectively.
weights = [
    _read_dat('w1.dat').reshape(6, 20).T,
    _read_dat('w2.dat').reshape(20, 6).T,
]
# Bias vectors for the same two layers, in the same order.
biases = [
    _read_dat('b1.dat'),
    _read_dat('b2.dat'),
]
16
17# Anisotropic SGS model for LES developed by Aviral Prakash and John A. Evans at UCB
18
19
class anisoSGS(nn.Module):
    """Fully connected network mapping ``inputDim`` features to ``outputDim`` features.

    The network is stored in ``self.net`` (an ``nn.Sequential``). It consists of
    ``numLayers`` hidden ``nn.Linear`` layers of width ``numNeurons``, each
    followed by ``nn.LeakyReLU(0.3)``, and then a final ``nn.Linear`` output
    layer. With the default ``numLayers=1`` the layout is ``net[0]``
    (input -> hidden), ``net[1]`` (activation) and ``net[2]`` (hidden -> output).
    The state-dict keys are therefore ``net.0.*`` and ``net.2.*``.

    Args:
        inputDim: Number of input features (last dimension of the input).
        outputDim: Number of output features.
        numNeurons: Width of every hidden layer.
        numLayers: Number of hidden layers; must be at least 1.

    Raises:
        ValueError: If ``numLayers`` is less than 1.
    """

    def __init__(self, inputDim=6, outputDim=6, numNeurons=20, numLayers=1):
        super().__init__()
        self.ndIn = inputDim
        self.ndOut = outputDim
        self.nNeurons = numNeurons
        self.nLayers = numLayers
        if numLayers < 1:
            raise ValueError(f'numLayers must be >= 1, got {numLayers}')

        # The first hidden layer consumes the inputs. Any further hidden layers
        # keep the hidden width. Layers are created in sequence order, so the
        # default configuration initializes exactly as a literal
        # Sequential(Linear, LeakyReLU, Linear) would.
        layers = [nn.Linear(self.ndIn, self.nNeurons), nn.LeakyReLU(0.3)]
        for _ in range(numLayers - 1):
            layers += [nn.Linear(self.nNeurons, self.nNeurons), nn.LeakyReLU(0.3)]
        layers.append(nn.Linear(self.nNeurons, self.ndOut))
        self.net = nn.Sequential(*layers)

    # Define the method to do a forward pass
    def forward(self, x):
        """Apply the network to ``x``; its last dimension must equal ``ndIn``."""
        return self.net(x)
36
37
def _assign_parameter(param, values):
    """Copy the NumPy array ``values`` into the tensor ``param`` in place.

    Raises:
        ValueError: If ``values`` does not have exactly ``param``'s shape.
            Plain ``param[...] = ...`` would broadcast a compatible but wrong
            shape without complaint.
    """
    new_values = torch.from_numpy(values)
    if new_values.shape != param.shape:
        raise ValueError(
            f'parameter shape {tuple(param.shape)} does not match '
            f'replacement shape {tuple(new_values.shape)}')
    param.copy_(new_values)


def load_n_trace_model(model_name):
    """Load a saved ``anisoSGS`` model, replace its parameters, and trace it.

    The steps are:

    1. Load the state dict in ``{model_name}.pt`` into a fresh ``anisoSGS``.
    2. Cast the model to float64.
    3. Overwrite the weights and biases of its two ``nn.Linear`` layers
       (``net[0]`` and ``net[2]``) with the module-level ``weights`` and
       ``biases`` arrays.
    4. Trace the model with TorchScript and save it as
       ``{model_name}_fp64_jit.pt``.

    Args:
        model_name: Path stem of the input checkpoint. It is also used to name
            the traced output file.

    Returns:
        The traced TorchScript module.

    Raises:
        ValueError: If the number of replacement weight/bias arrays does not
            match the number of linear layers, or if any array's shape differs
            from the parameter it replaces.
    """
    # Instantiate PT model and load state dict
    model = anisoSGS()
    model.load_state_dict(torch.load(f'{model_name}.pt', map_location=torch.device('cpu')))
    # Work in float64: the output is saved as *_fp64_jit.pt and the
    # replacement arrays come from np.loadtxt.
    model.double()

    # net[0] and net[2] are the nn.Linear layers; net[1] is the activation.
    linear_layers = (model.net[0], model.net[2])
    if len(weights) < len(linear_layers) or len(biases) < len(linear_layers):
        raise ValueError(
            f'need {len(linear_layers)} weight and bias arrays, got '
            f'{len(weights)} weights and {len(biases)} biases')

    # Change individual model weights. no_grad allows in-place writes to the
    # leaf parameters.
    with torch.no_grad():
        for layer, new_weight, new_bias in zip(linear_layers, weights, biases):
            m, n = layer.weight.shape
            print('weight shape', m, n)
            _assign_parameter(layer.weight, new_weight)
            _assign_parameter(layer.bias, new_bias)

    # Prepare model for inference: trace with a random float64 batch whose
    # width matches the model's input dimension.
    dummy_input = torch.randn(512, model.ndIn, dtype=torch.float64, device='cpu')
    with torch.no_grad():
        model = torch.jit.trace(model, dummy_input)
        torch.jit.save(model, f"{model_name}_fp64_jit.pt")

    return model
63
64
def main():
    """Rebuild the traced float64 model from ``NNModel_HIT.pt`` with the new parameters."""
    model_name = 'NNModel_HIT'
    # load_n_trace_model writes the traced module to disk itself; its return
    # value is not needed here.
    load_n_trace_model(model_name)


if __name__ == '__main__':
    main()
72