xref: /petsc/lib/petsc/bin/tasClasses.py (revision 62c0789480ad43a8fa3389fdd0599eeba93e5a31)
1#!/usr/bin/env python3
2import numpy as np
3import csv
4
class File(object):
    """Per-stage performance data for one run plus the fields attached to it.

    Attributes
    ----------
    fileName : str
        Name printed by printFile() and used as the path prefix of every CSV
        file written by writeCSV().
    numberFields : int
        Number of fields added through addField().
    fieldList : list
        Field objects added through addField().
    fileData : dict
        Per-stage sequences keyed by name, indexed by stage number.
        writeCSV() reads the keys 'Times', 'Mean Time', 'Flops',
        'Mean Flops', 'LU Factor' and 'LU Factor Mean'.
    """

    def __init__(self, fileName):
        """Create an empty record whose output files are named after fileName."""
        self.fileName     = fileName
        self.numberFields = 0
        self.fieldList    = []
        self.fileData     = {}

    def addField(self, field):
        """Append field to fieldList and increase numberFields by one."""
        self.fieldList.append(field)
        self.numberFields += 1

    def printFile(self):
        """Print every fileData entry, then each field's data, to stdout.

        numpy print options (precision 3, line width 100) are applied only for
        the duration of this call. They also cover field.printField(), and the
        caller's global numpy settings are restored afterwards.
        """
        print('\t\t*******************Data for {}***************************'.format(self.fileName))
        with np.printoptions(precision=3, linewidth=100):
            for key, value in self.fileData.items():
                print(" {: >18} : {}\n".format(key, value))
            for field in self.fieldList:
                field.printField()

    def writeCSV(self):
        """Write the data as CSV files.

        The method writes '<fileName>.csv' with one row per stage of fileData.
        It also writes one '<fileName>_<fieldName>.csv' file for each entry
        in fieldList.

        The number of stages is len(fileData['Times']). That count is used
        for every file, so each data sequence read (including each field's
        'dofs' and 'Errors') must have at least that many entries. Numeric
        cells are written with 3 significant digits.

        Raises
        ------
        KeyError
            If fileData or a field's fieldData lacks a key that a column needs.
        IndexError
            If a data sequence is shorter than the number of stages.
        """
        numStages = len(self.fileData['Times'])
        self._writeMainCSV(numStages)
        for field in self.fieldList:
            self._writeFieldCSV(field, numStages)

    @staticmethod
    def _formatValue(value):
        # Every numeric cell is written with 3 significant digits.
        return '{:.3g}'.format(value)

    def _writeMainCSV(self, numStages):
        # Write <fileName>.csv: one row of timing/flop data per stage.
        columnNames = ['Stage', 'Max Time', 'Mean Time', 'Max Giga Flops', 'Mean Giga Flops', 'LU Factor', 'LU Factor Mean']
        # Columns whose fileData key differs from the column name. Every other
        # non-Stage column reads fileData[column] directly.
        sourceKeys = {'Max Time': 'Times', 'Max Giga Flops': 'Flops', 'Mean Giga Flops': 'Mean Flops'}
        # These columns are reported in units of 10**9 flops.
        gigaColumns = {'Max Giga Flops', 'Mean Giga Flops'}
        # newline='' is required by the csv module; without it the writer's
        # '\r\n' terminators are translated again on some platforms, which
        # produces blank rows.
        with open(self.fileName + '.csv', mode='w', newline='') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=columnNames, restval='N/A')
            writer.writeheader()
            for stage in range(numStages):
                row = {'Stage': stage}
                for item in columnNames[1:]:
                    value = self.fileData[sourceKeys.get(item, item)][stage]
                    if item in gigaColumns:
                        value = value / 1000000000
                    row[item] = self._formatValue(value)
                writer.writerow(row)

    def _writeFieldCSV(self, field, numStages):
        # Write <fileName>_<fieldName>.csv. Each row holds the per-stage dofs
        # and errors, plus the field's scalar alpha/beta/cRate repeated.
        columnNames = ['Stage', 'dofs', 'Errors', 'Alpha', 'Beta', 'Convergence Rate']
        scalars = {'Alpha': field.alpha, 'Beta': field.beta, 'Convergence Rate': field.cRate}
        with open(self.fileName + '_' + field.fieldName + '.csv', mode='w', newline='') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=columnNames, restval='N/A')
            writer.writeheader()
            for stage in range(numStages):
                row = {'Stage': stage}
                for item in columnNames[1:]:
                    if item in scalars:
                        value = scalars[item]
                    else:
                        value = field.fieldData[item][stage]
                    row[item] = self._formatValue(value)
                writer.writerow(row)
69
70
71
class Field(File):
    """Data for one named field attached to a run.

    Besides the attributes set up by File.__init__, a Field carries:
      fieldName -- label used when printing (and presumably in output names).
      fieldData -- dict of named per-stage sequences, initially empty.
      alpha, beta, cRate -- scalar values stored as given (default 0);
                            presumably fit/convergence parameters -- TODO confirm.
    """

    def __init__(self, fileName, fieldName, alpha=0, cRate=0, beta=0):
        # Let File initialise fileName and its bookkeeping containers first.
        File.__init__(self, fileName)
        self.fieldName = fieldName
        self.fieldData = {}
        self.alpha, self.cRate, self.beta = alpha, cRate, beta

    def setAlpha(self, alpha):
        """Overwrite the stored alpha value."""
        self.alpha = alpha

    def setBeta(self, beta):
        """Overwrite the stored beta value."""
        self.beta = beta

    def setConvergeRate(self, cRate):
        """Overwrite the stored convergence rate (cRate)."""
        self.cRate = cRate

    def printField(self):
        """Print a banner with the field name, then each fieldData entry."""
        banner = '**********Data for Field {}************'.format(self.fieldName)
        print(banner)
        for name, values in self.fieldData.items():
            print(" {: >18} : {}\n".format(name, values))
95