xref: /libCEED/backends/magma/tuning/generate_tuning.py (revision d4cc18453651bd0f94c1a2e078b2646a92dafdcc)
126bdecf3SSebastian Grimberg#!/usr/bin/env python3
226bdecf3SSebastian Grimberg
3*9ba83ac0SJeremy L Thompson# Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
45aed82e4SJeremy L Thompson# All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
526bdecf3SSebastian Grimberg#
65aed82e4SJeremy L Thompson# SPDX-License-Identifier: BSD-2-Clause
726bdecf3SSebastian Grimberg#
85aed82e4SJeremy L Thompson# This file is part of CEED:  http://github.com/ceed
926bdecf3SSebastian Grimberg
1026bdecf3SSebastian Grimbergimport argparse
1126bdecf3SSebastian Grimbergimport os
12acc0bb12SSebastian Grimbergimport glob
1326bdecf3SSebastian Grimbergimport re
14acc0bb12SSebastian Grimbergimport shutil
1526bdecf3SSebastian Grimbergimport subprocess
1626bdecf3SSebastian Grimbergimport pandas as pd
1726bdecf3SSebastian Grimbergimport time
1826bdecf3SSebastian Grimberg
# Absolute directory holding this script, with symlinks resolved; the build,
# benchmark-log and generated-header paths in this script are built from it.
script_dir = os.path.split(os.path.realpath(__file__))[0]
2026bdecf3SSebastian Grimberg
2126bdecf3SSebastian Grimberg
def benchmark(nb, build_cmd, backend, log):
    """Rebuild with RTC block size ``nb``, run the tuning binary and parse its output.

    The MAGMA backend sources are temporarily patched:

    * ``ceed-magma.h``: the ``#define ceed_magma_queue_sync(...)`` line is
      rewritten to call ``hipDeviceSynchronize()`` when ``"hip"`` occurs in
      ``backend``, else ``cudaDeviceSynchronize()`` (presumably so each timed
      operation waits for the device -- confirm against the tuning driver).
    * ``ceed-magma-gemm-selector.cpp``: ``#define CEED_AUTOTUNE_RTC_NB`` is set
      to ``nb``.

    libCEED is then rebuilt with ``build_cmd`` (run from three directories
    above this script) and the tuning binary with ``make tuning OPT=-O0``
    (run in this script's directory). The original sources are always
    restored from ``<file>.backup`` copies, even if a build step fails.

    Args:
        nb: Block size written into ``CEED_AUTOTUNE_RTC_NB``.
        build_cmd: Build command, either a list of arguments or a string that
            is split with ``shlex.split`` (e.g. ``"make -j 8"``).
        backend: Ceed resource passed to the ``tuning`` binary.
        log: Path of the file receiving the binary's stdout and stderr.

    Returns:
        pandas.DataFrame with columns P, Q, N, Q_COMP, TRANS, MFLOPS parsed
        from the whitespace-separated log.

    Raises:
        subprocess.CalledProcessError: if either build step exits non-zero.
    """
    import shlex

    # A plain string without shell=True is treated as the program *name* by
    # subprocess on POSIX, so "make -j 8" would fail; split it into argv.
    if isinstance(build_cmd, str):
        build_cmd = shlex.split(build_cmd)

    sync_call = ("hipDeviceSynchronize()" if "hip" in backend
                 else "cudaDeviceSynchronize()")
    # (file, regex matching the line to rewrite, replacement line)
    patches = [
        (f"{script_dir}/../ceed-magma.h",
         r".*(#define ceed_magma_queue_sync\(\.\.\.\)).*",
         r"\1 " + sync_call),
        (f"{script_dir}/../ceed-magma-gemm-selector.cpp",
         r".*(#define CEED_AUTOTUNE_RTC_NB).*",
         r"\1 " + f"{nb}"),
    ]

    backed_up = []
    try:
        for path, pattern, replacement in patches:
            shutil.copyfile(path, path + ".backup")
            backed_up.append(path)
            with open(path, "r") as f:
                patched = re.sub(pattern, replacement, f.read())
            with open(path, "w") as f:
                f.write(patched)

        # Fail loudly: a failed build would otherwise benchmark a stale
        # binary and attribute its timings to this NB.
        subprocess.run(build_cmd, cwd=f"{script_dir}/../../..", check=True)
        subprocess.run(["make", "tuning", "OPT=-O0"], cwd=script_dir,
                       check=True)
    finally:
        # Always put the pristine sources back, even on failure/interrupt.
        for path in reversed(backed_up):
            shutil.move(path + ".backup", path)

    # Run the benchmark; stdout and stderr share the log file.
    with open(log, "w") as f:
        subprocess.run([f"{script_dir}/tuning", str(backend)],
                       stdout=f, stderr=f)

    # sep=r"\s+" is the documented replacement for the deprecated
    # delim_whitespace=True.
    return pd.read_csv(
        log,
        header=None,
        sep=r"\s+",
        names=["P", "Q", "N", "Q_COMP", "TRANS", "MFLOPS"])
7226bdecf3SSebastian Grimberg
7326bdecf3SSebastian Grimberg
if __name__ == "__main__":
    # Command line arguments
    parser = argparse.ArgumentParser("MAGMA RTC autotuning")
    parser.add_argument(
        "-arch",
        help="Device architecture name for tuning data",
        required=True)
    parser.add_argument(
        "-max-nb",
        help="Maximum block size NB to consider for autotuning",
        default=32,
        type=int)
    parser.add_argument(
        "-build-cmd",
        help="Command used to build libCEED from the source root directory",
        default="make")
    parser.add_argument(
        "-ceed",
        help="Ceed resource specifier",
        default="/cpu/self")
    args = parser.parse_args()
    # The search starts at NB = 1; with a smaller limit no benchmark would run
    # and `data` would be undefined when the header is written.
    if args.max_nb < 1:
        parser.error("-max-nb must be at least 1")

    # A larger NB replaces the stored choice only if it is more than 5%
    # faster (presumably to avoid switching on measurement noise).
    speedup_threshold = 1.05
    separator = "////////////////////////////////////////////////////////////////////////////////\n"
    columns = ["P", "Q", "N", "Q_COMP", "NB"]

    data = None
    nb = 1
    while nb <= args.max_nb:
        # Run the benchmarks
        start = time.perf_counter()
        data_nb = benchmark(nb, args.build_cmd, args.ceed,
                            f"{script_dir}/output-nb-{nb}.txt")
        print(
            f"Finished benchmarks for NB = {nb}, backend = {args.ceed} ({time.perf_counter() - start} s)")

        # Keep, per benchmark case, the NB with the highest performance.
        # NOTE(review): rows are matched by index, which assumes the tuning
        # binary emits the same cases in the same order for every NB.
        if data is None:
            data = pd.DataFrame(data_nb)
            data["NB"] = nb
        else:
            idx = data_nb["MFLOPS"] > speedup_threshold * data["MFLOPS"]
            data.loc[idx, "NB"] = nb
            data.loc[idx, "MFLOPS"] = data_nb.loc[idx, "MFLOPS"]

        # Speed up the search by considering only some values of NB:
        # 1, 2, 4, 6, 8, then steps of 4 (12, 16, ...).
        if nb < 2:
            nb *= 2
        elif nb < 8:
            nb += 2
        else:
            nb += 4

    # Print the results as C++ initializers: TRANS == 1 -> drtc_t_<arch>,
    # TRANS == 0 -> drtc_n_<arch>; each entry is {P, Q, N, Q_COMP, NB}.
    with open(f"{script_dir}/{args.arch}_rtc.h", "w") as f:
        f.write(separator)
        f.write(f"// auto-generated from data on {args.arch}\n\n")

        tables = []
        for trans, suffix in ((1, "t"), (0, "n")):
            subset = data.loc[data["TRANS"] == trans]
            # to_string() on an empty frame yields "Empty DataFrame ..." text,
            # which must not leak into the generated header.
            rows = [] if subset.empty else subset.to_string(
                header=False, index=False, justify="right",
                columns=columns).split("\n")
            # Insert a comma after every number that is followed by padding.
            entries = ",\n".join(
                "    {" + re.sub(r"([0-9])(\s+)", r"\1,\2", row) + "}"
                for row in rows)
            tables.append(
                separator
                + f"std::vector<std::array<int, RECORD_LENGTH_RTC> > drtc_{suffix}_{args.arch}"
                + " = {\n"
                + (entries + "\n" if entries else "")
                + "};\n")
        # Tables are separated by one blank line; the file ends after "};\n".
        f.write("\n".join(tables))
155