#!/usr/bin/env python3
#
#    Computes the speedup of the Streams benchmark results generated by make streams and plots it
#
#    matplotlib uses a single backend per run (Agg when writing a file, an interactive one otherwise),
#    hence this needs to be run twice: once to generate a file and once to display a window
#
from __future__ import print_function
import os
#
def process(streamstype,fileoutput,filename='scaling.log'):
  """Plot the speedup and bandwidth scaling recorded in a streams log file.

  Each non-blank line of the log is split on whitespace and its second
  column is read as a bandwidth in MB/s (converted here to GB/s). The n-th
  data line is plotted as the result for n processes/threads, and speedup
  is computed relative to the first data line.

  Args:
    streamstype: label of the stream type (e.g. MPI, OpenMP, CUDA); used in
      the plot title and in the output file name.
    fileoutput: if truthy, select the non-interactive Agg backend and save
      the plot to '<streamstype>scaling.png' in the current directory;
      otherwise display the plot in a window.
    filename: log file to read; defaults to 'scaling.log' in the current
      directory.

  Returns:
    The list of speedups (one per data line, the first is 1.0), or None when
    there are fewer than two data points or the first bandwidth is zero.
    Problems while plotting are reported on stdout and do not affect the
    returned value.
  """
  with open(filename) as ff:
    data = ff.read()

  # stream triad bandwidth data points, MB/s -> GB/s. Blank lines are
  # skipped, so a missing trailing newline no longer drops the last point.
  triads = []
  for line in data.splitlines():
    fields = line.split()
    if not fields: continue
    triads.append(float(fields[1])/1000)

  size = len(triads)
  if size < 2: return None

  if triads[0] == 0:
    print("Unable to compute speedup: the first bandwidth measurement is zero")
    return None
  speedups = [t/triads[0] for t in triads]

  try:
    import matplotlib
    from matplotlib.ticker import MaxNLocator
  except ImportError:
    print("Unable to open matplotlib to plot speedup")
    return speedups

  try:
    # Agg is selected before pyplot is imported so the file can be written
    # without a display
    if fileoutput: matplotlib.use('Agg')
    import matplotlib.pyplot as plt
  except Exception:
    print("Unable to open matplotlib to plot speedup")
    return speedups

  try:
    _, ax1 = plt.subplots(layout='constrained')
    plt.title(streamstype+' Perfect and Streams Speedup')
    ax2 = ax1.twinx()
    ax1.set_autoscaley_on(False)

    r = range(1,size+1)

    # make sure that actual bandwidth values (as opposed to perfect speedup) takes
    # at least a third of the y axis
    ymax = min(size, 3*max(speedups))
    # speedups[0] is 1.0, so the axis starts at 1 unless some run was slower
    # than the first one, in which case it starts at 0
    ymin = min(1, min(speedups))
    if ymin < 1: ymin = 0

    ax1.set_xlim(1,size)
    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax1.set_ylim([ymin,ymax])
    ax1.set_xlabel('Number of processes/threads')
    ax1.set_ylabel('Achieved Speedup')
    ax1.plot(r,r,'b',label='Ideal speedup')
    ax1.plot(r,speedups,'r-o', label='Achieved speedup')
    ax2.set_autoscaley_on(False)
    ax2.set_xlim([1,size])
    ax2.set_ylim([min(triads),max(triads)])
    ax2.set_ylabel("Achieved Bandwidth (GB/s)")
    ax2.plot(r,triads,'g-o', label='Achieved bandwidth')

    # merge the legends of both y axes into a single box
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc='best')

    if fileoutput:
      # the Agg backend selected above is non-interactive: save, don't show
      plt.savefig(streamstype+'scaling.png')
      print("See graph in the file src/benchmarks/streams/"+streamstype+"scaling.png")
    else:
      plt.show()
  except Exception as e:
    if fileoutput: print("Unable to plot speedup to a file:", e)
    else: print("Unable to display speedup plot:", e)

  return speedups

# plot bandwidth data in scaling.log under the current directory.
#
# ./process.py arg1 arg2
#   arg1: stream type, e.g., MPI, OpenMP, CUDA etc
#   arg2: optional, can be anything; if present, <stream type>scaling.png is
#         written instead of displaying the plot in a window
if __name__ ==  '__main__':
  import sys

  if len(sys.argv) < 2:
    print("Usage: "+sys.argv[0]+" <stream type> [fileoutput]", file=sys.stderr)
    sys.exit(1)
  # any second argument requests file output
  process(sys.argv[1],len(sys.argv) > 2)
