xref: /libCEED/backends/magma/tuning/generate_tuning.py (revision acc0bb127f9d52b89fa0cb7f74c98dc79acc3cb0)
1#!/usr/bin/env python3
2
3# Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
4# Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
5# All Rights reserved. See files LICENSE and NOTICE for details.
6#
7# This file is part of CEED, a collection of benchmarks, miniapps, software
8# libraries and APIs for efficient high-order finite element and spectral
9# element discretizations for exascale applications. For more information and
10# source code availability see http://github.com/ceed
11#
12# The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
13# a collaborative effort of two U.S. Department of Energy organizations (Office
14# of Science and the National Nuclear Security Administration) responsible for
15# the planning and preparation of a capable exascale ecosystem, including
16# software, applications, hardware, advanced system engineering and early
17# testbed platforms, in support of the nation's exascale computing imperative.
18
import argparse
import glob
import os
import re
import shlex
import shutil
import subprocess
import time

import pandas as pd
27
# Absolute directory holding this script (symlinks resolved); all build,
# patch and output paths below are expressed relative to it.
script_dir = os.path.split(os.path.realpath(__file__))[0]
29
30
def benchmark(nb, build_cmd, backend, log):
    """Rebuild libCEED and the tuning driver for block size *nb*, then run it.

    Two MAGMA backend sources are patched temporarily:
    ``ceed-magma.h`` gets its ``ceed_magma_queue_sync(...)`` macro redefined
    to ``hipDeviceSynchronize()`` (if "hip" occurs in *backend*) or
    ``cudaDeviceSynchronize()``, and ``ceed-magma-gemm-selector.cpp`` gets
    ``CEED_AUTOTUNE_RTC_NB`` set to *nb*.  The originals are restored from
    ``.backup`` copies afterwards, even when a build step fails or is
    interrupted.

    Args:
        nb: Block size written into ``CEED_AUTOTUNE_RTC_NB``.
        build_cmd: Command that builds libCEED from the source root; either an
            argument list or a string, which is split with shell-like rules
            (e.g. ``"make -j8"``).
        backend: Ceed resource string passed to the ``tuning`` executable.
        log: Path of the file receiving the tuning run's stdout and stderr.

    Returns:
        DataFrame with columns P, Q, N, Q_COMP, TRANS, MFLOPS parsed from the
        whitespace-separated contents of *log*.

    Raises:
        subprocess.CalledProcessError: if the libCEED build or the
            ``make tuning`` step exits with a non-zero status (previously such
            failures were ignored and a stale binary was benchmarked).
    """
    sync = ("hipDeviceSynchronize()" if "hip" in backend
            else "cudaDeviceSynchronize()")
    # (file, pattern, replacement) for each temporary source patch
    patches = [
        (f"{script_dir}/../ceed-magma.h",
         r".*(#define ceed_magma_queue_sync\(\.\.\.\)).*",
         r"\1 " + sync),
        (f"{script_dir}/../ceed-magma-gemm-selector.cpp",
         r".*(#define CEED_AUTOTUNE_RTC_NB).*",
         r"\1 " + f"{nb}"),
    ]

    # With shell=False a plain string is treated as the program name, so a
    # command such as "make -j8" must be split into an argument list.
    if isinstance(build_cmd, str):
        build_cmd = shlex.split(build_cmd)

    # Build for new NB; restore every file that was backed up, no matter what.
    patched = []
    try:
        for path, pattern, repl in patches:
            shutil.copyfile(path, path + ".backup")
            patched.append(path)
            with open(path, "r") as f:
                data = f.read()
            with open(path, "w") as f:
                f.write(re.sub(pattern, repl, data))
        subprocess.run(build_cmd, cwd=f"{script_dir}/../../..", check=True)
        subprocess.run(["make", "tuning", "OPT=-O0"], cwd=script_dir,
                       check=True)
    finally:
        for path in patched:
            shutil.move(path + ".backup", path)

    # Run the benchmark; stdout and stderr both land in the log parsed below.
    with open(log, "w") as f:
        subprocess.run([f"{script_dir}/tuning", f"{backend}"],
                       stdout=f, stderr=f)
    return pd.read_csv(
        log,
        header=None,
        sep=r"\s+",
        names=["P", "Q", "N", "Q_COMP", "TRANS", "MFLOPS"])
81
82
# A larger NB replaces the current best for a case only when its MFLOPS exceed
# the best so far by this factor (presumably a margin against run-to-run noise).
_MIN_SPEEDUP = 1.05

# Separator line written ahead of each section of the generated header.
_BANNER = "////////////////////////////////////////////////////////////////////////////////\n"


def _format_rtc_table(frame, name):
    """Return the C++ definition of the RTC tuning table *name* built from *frame*.

    Each row of *frame* becomes one ``{P, Q, N, Q_COMP, NB}`` initializer,
    column-aligned as rendered by ``DataFrame.to_string``.  An empty *frame*
    yields an empty initializer list rather than pandas' "Empty DataFrame" text.

    Args:
        frame: DataFrame containing integer columns P, Q, N, Q_COMP and NB.
        name: C++ identifier of the generated ``std::vector``.

    Returns:
        The banner line, the vector definition and the closing ``};`` line,
        each terminated by a newline.
    """
    if frame.empty:
        rows = []
    else:
        rows = frame.to_string(
            header=False, index=False, justify="right",
            columns=["P", "Q", "N", "Q_COMP", "NB"]).split("\n")
    # Insert a comma after every number that is followed by whitespace, i.e.
    # after all but the last column, keeping the alignment padding.
    entries = ["    {" + re.sub(r"([0-9])(\s+)", r"\1,\2", row) + "}"
               for row in rows]
    body = ",\n".join(entries) + ("\n" if entries else "")
    return (_BANNER
            + f"std::vector<std::array<int, RECORD_LENGTH_RTC> > {name}"
            + " = {\n"
            + body
            + "};\n")


if __name__ == "__main__":
    # Command line arguments
    parser = argparse.ArgumentParser("MAGMA RTC autotuning")
    parser.add_argument(
        "-arch",
        help="Device architecture name for tuning data",
        required=True)
    parser.add_argument(
        "-max-nb",
        help="Maximum block size NB to consider for autotuning",
        default=32,
        type=int)
    parser.add_argument(
        "-build-cmd",
        help="Command used to build libCEED from the source root directory",
        default="make")
    parser.add_argument(
        "-ceed",
        help="Ceed resource specifier",
        default="/cpu/self")
    args = parser.parse_args()
    # Without at least one NB there is no data to write (previously NameError)
    if args.max_nb < 1:
        parser.error("-max-nb must be at least 1")

    data = None
    for nb in range(1, args.max_nb + 1):
        # Run the benchmarks
        start = time.perf_counter()
        data_nb = benchmark(nb, args.build_cmd, args.ceed,
                            f"{script_dir}/output-nb-{nb}.txt")
        print(
            f"Finished benchmarks for NB = {nb}, backend = {args.ceed} ({time.perf_counter() - start} s)")

        # Keep, per benchmark row, the NB with the highest MFLOPS so far
        if data is None:
            data = data_nb.copy()
            data["NB"] = nb
        else:
            # Rows are compared by index label; this assumes every run emits
            # the same cases in the same order — TODO confirm in tuning driver
            idx = data_nb["MFLOPS"] > _MIN_SPEEDUP * data["MFLOPS"]
            data.loc[idx, "NB"] = nb
            data.loc[idx, "MFLOPS"] = data_nb.loc[idx, "MFLOPS"]

    # Write the transposed (TRANS == 1) and non-transposed tables
    with open(f"{script_dir}/{args.arch}_rtc.h", "w") as f:
        f.write(_BANNER)
        f.write(f"// auto-generated from data on {args.arch}\n\n")
        f.write(_format_rtc_table(data.loc[data["TRANS"] == 1],
                                  f"drtc_t_{args.arch}"))
        f.write("\n")
        f.write(_format_rtc_table(data.loc[data["TRANS"] == 0],
                                  f"drtc_n_{args.arch}"))
155