#!/usr/bin/env python3

# Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
# All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
#
# SPDX-License-Identifier: BSD-2-Clause
#
# This file is part of CEED: http://github.com/ceed

import argparse
import glob
import os
import re
import shlex
import shutil
import subprocess
import time

import pandas as pd

script_dir = os.path.dirname(os.path.realpath(__file__))


def _patch_file(path, pattern, replacement):
    """Apply ``re.sub(pattern, replacement, ...)`` to the contents of *path* in place."""
    with open(path, "r") as f:
        data = f.read()
    data = re.sub(pattern, replacement, data)
    with open(path, "w") as f:
        f.write(data)


def benchmark(nb, build_cmd, backend, log):
    """Rebuild with block size *nb*, run the tuning executable, and return its results.

    Two source files next to this script's parent directory are temporarily
    patched (the queue-sync macro in ``ceed-magma.h`` and the
    ``CEED_AUTOTUNE_RTC_NB`` define in ``ceed-magma-gemm-selector.cpp``), the
    library and the ``tuning`` target are rebuilt, and the original sources are
    restored before the benchmark runs.

    Args:
        nb: Block size written after ``#define CEED_AUTOTUNE_RTC_NB``.
        build_cmd: Command used to build from the source root. Either an
            argument list or a string; a string is split with ``shlex.split``
            so that commands with arguments (e.g. ``"make -j8"``) work.
        backend: Resource string passed to the tuning executable. If it
            contains ``"hip"`` the sync macro is set to ``hipDeviceSynchronize()``,
            otherwise to ``cudaDeviceSynchronize()``.
        log: Path of the file that receives the executable's stdout and stderr.

    Returns:
        A ``pandas.DataFrame`` parsed from *log* with columns
        ``P, Q, N, Q_COMP, TRANS, MFLOPS``.

    Raises:
        subprocess.CalledProcessError: If either build step exits non-zero.
    """
    ceed_magma_h = f"{script_dir}/../ceed-magma.h"
    ceed_magma_gemm_selector_cpp = f"{script_dir}/../ceed-magma-gemm-selector.cpp"
    sync_call = "hipDeviceSynchronize()" if "hip" in backend else "cudaDeviceSynchronize()"
    # Each pattern matches the whole line holding the define and rewrites it
    # as "<define> <new value>".
    patches = [
        (ceed_magma_h,
         r".*(#define ceed_magma_queue_sync\(\.\.\.\)).*",
         r"\1 " + sync_call),
        (ceed_magma_gemm_selector_cpp,
         r".*(#define CEED_AUTOTUNE_RTC_NB).*",
         r"\1 " + f"{nb}"),
    ]

    # Without a shell, a single string is treated as the program name, so a
    # command with arguments must be split into an argument list first.
    if isinstance(build_cmd, str):
        build_cmd = shlex.split(build_cmd)

    # Build for new NB
    backed_up = []
    try:
        for path, pattern, replacement in patches:
            shutil.copyfile(path, path + ".backup")
            backed_up.append(path)
            _patch_file(path, pattern, replacement)

        # check=True: a failed build must not fall through to benchmarking a
        # stale executable.
        subprocess.run(build_cmd, cwd=f"{script_dir}/../../..", check=True)
        subprocess.run(["make", "tuning", "OPT=-O0"], cwd=f"{script_dir}", check=True)
    finally:
        # Always put the pristine sources back, even when patching or
        # building raised.
        for path in backed_up:
            shutil.move(path + ".backup", path)

    # Run the benchmark; stdout and stderr both go to the log file.
    with open(log, "w") as f:
        subprocess.run([f"{script_dir}/tuning", f"{backend}"], stdout=f, stderr=f)

    # sep=r"\s+" is the documented equivalent of the deprecated
    # delim_whitespace=True.
    return pd.read_csv(
        log,
        header=None,
        sep=r"\s+",
        names=["P", "Q", "N", "Q_COMP", "TRANS", "MFLOPS"])


def _nb_candidates(max_nb):
    """Yield the block sizes NB to benchmark, in increasing order, up to *max_nb*.

    Only some values are considered to speed up the search: NB is doubled
    while below 2, advanced by 2 while below 8, and advanced by 4 afterwards.
    """
    nb = 1
    while nb <= max_nb:
        yield nb
        if nb < 2:
            nb *= 2
        elif nb < 8:
            nb += 2
        else:
            nb += 4


def _write_rtc_table(f, data, trans, name):
    """Write one C++ ``std::vector<std::array<int, ...>>`` table named *name* to *f*.

    Rows are the ``P, Q, N, Q_COMP, NB`` columns of the entries in *data*
    whose ``TRANS`` column equals *trans*. An empty selection produces an
    empty initializer list instead of pandas' "Empty DataFrame" text.
    """
    f.write(
        "////////////////////////////////////////////////////////////////////////////////\n")
    f.write(f"std::vector<std::array<int, RECORD_LENGTH_RTC> > {name}" + " = {\n")

    subset = data.loc[data["TRANS"] == trans]
    if not subset.empty:
        rows = subset.to_string(
            header=False,
            index=False,
            justify="right",
            columns=["P", "Q", "N", "Q_COMP", "NB"]).split("\n")
        last = len(rows) - 1
        for count, row in enumerate(rows):
            # Insert a comma after every number that is followed by
            # whitespace, turning the aligned text columns into a C++
            # initializer; every row but the last gets a trailing comma.
            f.write(" {" + re.sub(r"([0-9])(\s+)", r"\1,\2", row) +
                    ("},\n" if count < last else "}\n"))
    f.write("};\n")


def main():
    """Benchmark a range of NB values and write the best per-case NB to ``<arch>_rtc.h``."""
    # Command line arguments
    parser = argparse.ArgumentParser("MAGMA RTC autotuning")
    parser.add_argument(
        "-arch",
        help="Device architecture name for tuning data",
        required=True)
    parser.add_argument(
        "-max-nb",
        help="Maximum block size NB to consider for autotuning",
        default=32,
        type=int)
    parser.add_argument(
        "-build-cmd",
        help="Command used to build libCEED from the source root directory",
        default="make")
    parser.add_argument(
        "-ceed",
        help="Ceed resource specifier",
        default="/cpu/self")
    args = parser.parse_args()

    # The search starts at NB = 1; a smaller maximum would run no benchmark
    # and leave nothing to write.
    if args.max_nb < 1:
        parser.error("-max-nb must be at least 1")

    data = None
    for nb in _nb_candidates(args.max_nb):
        # Run the benchmarks
        start = time.perf_counter()
        data_nb = benchmark(nb, args.build_cmd, args.ceed,
                            f"{script_dir}/output-nb-{nb}.txt")
        print(
            f"Finished benchmarks for NB = {nb}, backend = {args.ceed} ({time.perf_counter() - start} s)")

        # Save the data for the highest performing NB
        if data is None:
            data = pd.DataFrame(data_nb)
            data["NB"] = nb
        else:
            # A larger NB replaces the current best only when it is more
            # than 5% faster for that case.
            idx = data_nb["MFLOPS"] > 1.05 * data["MFLOPS"]
            data.loc[idx, "NB"] = nb
            data.loc[idx, "MFLOPS"] = data_nb.loc[idx, "MFLOPS"]

    # Print the results
    with open(f"{script_dir}/{args.arch}_rtc.h", "w") as f:
        f.write(
            "////////////////////////////////////////////////////////////////////////////////\n")
        f.write(f"// auto-generated from data on {args.arch}\n\n")

        # NOTE(review): TRANS == 1 feeds the "drtc_t" table and TRANS == 0 the
        # "drtc_n" table; presumably transposed / non-transposed - confirm.
        _write_rtc_table(f, data, 1, f"drtc_t_{args.arch}")
        f.write("\n")
        _write_rtc_table(f, data, 0, f"drtc_n_{args.arch}")


if __name__ == "__main__":
    main()