[Mlir-commits] [mlir] fa90c9d - [mlir] Set up boilerplate build for MLIR benchmarks
Saurabh Jha
llvmlistbot at llvm.org
Thu Jan 27 13:44:57 PST 2022
Author: Saurabh Jha
Date: 2022-01-27T21:38:15Z
New Revision: fa90c9d5e7a310ea87b7032c39c0ca657c794abc
URL: https://github.com/llvm/llvm-project/commit/fa90c9d5e7a310ea87b7032c39c0ca657c794abc
DIFF: https://github.com/llvm/llvm-project/commit/fa90c9d5e7a310ea87b7032c39c0ca657c794abc.diff
LOG: [mlir] Set up boilerplate build for MLIR benchmarks
This is the start of the MLIR benchmarks. It sets up a command
line tool along with conventions to define and run benchmarks
using mlir's python bindings.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D115174
Added:
mlir/benchmark/python/__init__.py
mlir/benchmark/python/benchmark_sparse.py
mlir/benchmark/python/common.py
mlir/utils/mbr/CMakeLists.txt
mlir/utils/mbr/README.md
mlir/utils/mbr/mbr/__init__.py
mlir/utils/mbr/mbr/config.ini
mlir/utils/mbr/mbr/discovery.py
mlir/utils/mbr/mbr/main.py
mlir/utils/mbr/mbr/stats.py
mlir/utils/mbr/mlir-mbr.in
mlir/utils/mbr/requirements.txt
mlir/utils/mbr/setup.py
Modified:
mlir/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 3612a6ccd0533..e049702ce4b13 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -211,3 +211,7 @@ if (NOT LLVM_INSTALL_TOOLCHAIN_ONLY)
endif()
add_subdirectory(cmake/modules)
+
+if (MLIR_ENABLE_PYTHON_BENCHMARKS)
+ add_subdirectory(utils/mbr)
+endif()
diff --git a/mlir/benchmark/python/__init__.py b/mlir/benchmark/python/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/mlir/benchmark/python/benchmark_sparse.py b/mlir/benchmark/python/benchmark_sparse.py
new file mode 100644
index 0000000000000..bfcff3ed459cf
--- /dev/null
+++ b/mlir/benchmark/python/benchmark_sparse.py
@@ -0,0 +1,121 @@
+"""This file contains benchmarks for sparse tensors. In particular, it
+contains benchmarks for both mlir sparse tensor dialect and numpy so that they
+can be compared against each other.
+"""
+import ctypes
+import numpy as np
+import os
+import re
+import time
+
+from mlir import ir
+from mlir import runtime as rt
+from mlir.dialects import builtin
+from mlir.dialects.linalg.opdsl import lang as dsl
+from mlir.execution_engine import ExecutionEngine
+
+from common import create_sparse_np_tensor
+from common import emit_timer_func
+from common import emit_benchmark_wrapped_main_func
+from common import get_kernel_func_from_module
+from common import setup_passes
+
+
@dsl.linalg_structured_op
def matmul_dsl(
    A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
    B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
):
    """Matrix multiplication C[m, n] += A[m, k] * B[k, n] for the mlir sparse
    matrix multiplication benchmark."""
    # Bind the iteration dimensions once so the contraction reads naturally.
    m, n, k = dsl.D.m, dsl.D.n, dsl.D.k
    C[m, n] += A[m, k] * B[k, n]
+
+
def benchmark_sparse_mlir_multiplication():
    """Benchmark for mlir sparse matrix multiplication. Because it's an
    MLIR benchmark we need to return both a `compiler` function and a `runner`
    function.

    The kernel multiplies a (1000 x 1500) by a (1500 x 2000) f64 tensor into a
    (1000 x 2000) output. `compiler` requires the MLIR_C_RUNNER_UTILS and
    MLIR_RUNNER_UTILS environment variables to name existing paths; it raises
    FileNotFoundError otherwise. `runner` returns the kernel time in
    nanoseconds as measured inside the compiled program.
    """
    # Keep the shapes as plain lists so the runner can allocate matching
    # numpy buffers without parsing MLIR type strings.
    param1_shape = [1000, 1500]
    param2_shape = [1500, 2000]
    result_shape = [1000, 2000]
    # Number of random nonzero entries placed in each input matrix.
    nonzeros_per_input = 1000

    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        f64 = ir.F64Type.get()
        param1_type = ir.RankedTensorType.get(param1_shape, f64)
        param2_type = ir.RankedTensorType.get(param2_shape, f64)
        result_type = ir.RankedTensorType.get(result_shape, f64)
        with ir.InsertionPoint(module.body):
            @builtin.FuncOp.from_py_func(param1_type, param2_type, result_type)
            def sparse_kernel(x, y, z):
                return matmul_dsl(x, y, outs=[z])

    def _require_existing_path(env_var):
        """Return the value of `env_var`, raising if it is not an existing path."""
        path = os.getenv(env_var, "")
        # An explicit raise (not `assert`) so the check survives `python -O`.
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"{path} does not exist."
                f" Please pass a valid value for"
                f" {env_var} environment variable."
            )
        return path

    def compiler():
        with ir.Context(), ir.Location.unknown():
            # Validate the shared libraries first so a misconfigured
            # environment fails before the pass pipeline runs.
            c_runner_utils = _require_existing_path("MLIR_C_RUNNER_UTILS")
            runner_utils = _require_existing_path("MLIR_RUNNER_UTILS")

            kernel_func = get_kernel_func_from_module(module)
            timer_func = emit_timer_func()
            wrapped_func = emit_benchmark_wrapped_main_func(
                kernel_func,
                timer_func
            )
            # Round-trip through text so all three functions live in a module
            # owned by this context.
            main_module_with_benchmark = ir.Module.parse(
                str(timer_func) + str(wrapped_func) + str(kernel_func)
            )
            setup_passes(main_module_with_benchmark)

            engine = ExecutionEngine(
                main_module_with_benchmark,
                3,  # optimization level
                shared_libs=[c_runner_utils, runner_utils]
            )
            return engine.invoke

    def runner(engine_invoke):
        compiled_program_args = []
        # Argument order: a result-shaped buffer (presumably receiving main's
        # return value under the C interface -- TODO confirm), the two inputs,
        # then the result-shaped `outs` operand. Result-shaped buffers start as
        # zeros; inputs are random sparse matrices.
        for shape, is_result_shaped in [
            (result_shape, True),
            (param1_shape, False),
            (param2_shape, False),
            (result_shape, True),
        ]:
            if is_result_shaped:
                argument = np.zeros(shape, np.float64)
            else:
                argument = create_sparse_np_tensor(shape, nonzeros_per_input)
            compiled_program_args.append(
                ctypes.pointer(
                    ctypes.pointer(rt.get_ranked_memref_descriptor(argument))
                )
            )
        # One-element timing buffer: main loops once per element and stores
        # the elapsed nano_time difference into it.
        np_timers_ns = np.array([0], dtype=np.int64)
        compiled_program_args.append(
            ctypes.pointer(
                ctypes.pointer(rt.get_ranked_memref_descriptor(np_timers_ns))
            )
        )
        engine_invoke("main", *compiled_program_args)
        return int(np_timers_ns[0])

    return compiler, runner
+
+
def benchmark_np_matrix_multiplication():
    """Benchmark for numpy matrix multiplication. Because it's a python
    benchmark, we don't have any `compiler` function returned. We just return
    the `runner` function.

    The runner multiplies a random (1000 x 1500) matrix by a random
    (1500 x 2000) matrix, with entries drawn uniformly from [0, 100). It
    returns the time spent in `np.matmul` as an int number of nanoseconds.
    Input generation is not timed.
    """
    def runner():
        argument1 = np.random.uniform(low=0.0, high=100.0, size=(1000, 1500))
        argument2 = np.random.uniform(low=0.0, high=100.0, size=(1500, 2000))
        # perf_counter_ns is monotonic, unlike time.time_ns, so adjustments
        # to the system clock cannot distort (or negate) the measured span.
        start_time = time.perf_counter_ns()
        np.matmul(argument1, argument2)
        return time.perf_counter_ns() - start_time

    return None, runner
diff --git a/mlir/benchmark/python/common.py b/mlir/benchmark/python/common.py
new file mode 100644
index 0000000000000..305d8a5c3896c
--- /dev/null
+++ b/mlir/benchmark/python/common.py
@@ -0,0 +1,124 @@
+"""Common utilities that are useful for all the benchmarks."""
+import numpy as np
+
+import mlir.all_passes_registration
+
+from mlir import ir
+from mlir.dialects import arith
+from mlir.dialects import builtin
+from mlir.dialects import memref
+from mlir.dialects import scf
+from mlir.dialects import std
+from mlir.passmanager import PassManager
+
+
def setup_passes(mlir_module):
    """Lower `mlir_module` in place with the benchmark pass pipeline.

    The pipeline sparsifies and bufferizes the kernel, then lowers everything
    down to the LLVM dialect so the module can be handed to an
    ExecutionEngine.
    """
    sparsification_options = " ".join([
        "parallelization-strategy=0",
        "vectorization-strategy=0",
        "vl=1",
        "enable-simd-index32=False",
    ])
    pass_list = [
        "builtin.func(linalg-generalize-named-ops,linalg-fuse-elementwise-ops)",
        "sparsification{" + sparsification_options + "}",
        "sparse-tensor-conversion",
        "builtin.func(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf)",
        "convert-scf-to-std",
        "func-bufferize",
        "tensor-constant-bufferize",
        "builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize)",
        "convert-vector-to-llvm{reassociate-fp-reductions=1 enable-index-optimizations=1}",
        "lower-affine",
        "convert-memref-to-llvm",
        "convert-std-to-llvm",
        "reconcile-unrealized-casts",
    ]
    pass_manager = PassManager.parse(",".join(pass_list))
    pass_manager.run(mlir_module)
+
+
def create_sparse_np_tensor(dimensions, number_of_elements):
    """Constructs a float64 numpy tensor of dimensions `dimensions` that has
    exactly `number_of_elements` nonzero elements (capped at the tensor's total
    size; a non-positive count yields an all-zero tensor).

    Positions are chosen uniformly at random without repetition, so distinct
    draws can never collide and reduce the nonzero count. Values are drawn
    uniformly from [1, 100), which guarantees they are nonzero.
    """
    tensor = np.zeros(dimensions, np.float64)
    count = max(0, min(number_of_elements, tensor.size))
    if count == 0:
        return tensor
    # Sample distinct flat offsets and write through the flat view, which
    # works for any rank without walking nested sub-arrays.
    flat_positions = np.random.choice(tensor.size, size=count, replace=False)
    tensor.flat[flat_positions] = np.random.uniform(1, 100, size=count)
    return tensor
+
+
def get_kernel_func_from_module(module: ir.Module) -> builtin.FuncOp:
    """Return the sole operation inside `module`.

    The module must contain exactly one region holding exactly one block that
    holds exactly one operation (the kernel function); each condition is
    checked with an assert.
    """
    regions = module.operation.regions
    assert len(regions) == 1, \
        "Expected kernel module to have only one region"
    blocks = regions[0].blocks
    assert len(blocks) == 1, \
        "Expected kernel module to have only one block"
    operations = blocks[0].operations
    assert len(operations) == 1, \
        "Expected kernel module to have only one operation"
    return operations[0]
+
+
def emit_timer_func() -> builtin.FuncOp:
    """Build a private declaration of `nano_time`: no arguments, one i64 result.

    The declaration is tagged `llvm.emit_c_interface`. Code that calls it can
    only be executed when the `MLIR_RUNNER_UTILS` and `MLIR_C_RUNNER_UTILS`
    libraries are loaded.
    """
    i64 = ir.IntegerType.get_signless(64)
    signature = ([], [i64])
    timer_decl = builtin.FuncOp("nano_time", signature, visibility="private")
    timer_decl.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    return timer_decl
+
+
def emit_benchmark_wrapped_main_func(func, timer_func):
    """Wrap `func` in a public `main` function that times repeated calls to it.

    The returned FuncOp is named "main" and carries the `llvm.emit_c_interface`
    attribute. Its arguments are `func`'s argument types followed by one extra
    trailing `memref<?xi64>` timing buffer. Its results are `func`'s results.

    The body is an `scf.for` loop from 0 to the size of dimension 0 of the
    timing buffer. The loop therefore runs once per timing-buffer element,
    independent of how many results `func` has. Each iteration does three
    things:
      1. Calls `timer_func`, then `func`, then `timer_func` again.
      2. Stores (end - start) into the timing buffer at the iteration index.
      3. Threads `func`'s results into the next iteration.

    The last `len(func.type.results)` arguments of `func` (immediately before
    the timing buffer) are the loop's iteration arguments. Iteration i + 1
    receives the results of iteration i in their place. `main` returns the
    final loop results.

    Args:
      func: the FuncOp to benchmark.
      timer_func: a FuncOp taking no arguments and returning a single integer
        timestamp, such as the `nano_time` declaration built by
        `emit_timer_func`.

    Returns:
      The new "main" FuncOp.
    """
    i64_type = ir.IntegerType.get_signless(64)
    memref_of_i64_type = ir.MemRefType.get([-1], i64_type)
    wrapped_func = builtin.FuncOp(
        # Same signature and an extra buffer of indices to save timings.
        "main",
        (func.arguments.types + [memref_of_i64_type], func.type.results),
        visibility="public"
    )
    wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()

    num_results = len(func.type.results)
    with ir.InsertionPoint(wrapped_func.add_entry_block()):
        timer_buffer = wrapped_func.arguments[-1]
        zero = arith.ConstantOp.create_index(0)
        # Iteration count = runtime length of the timing buffer.
        n_iterations = memref.DimOp(ir.IndexType.get(), timer_buffer, zero)
        one = arith.ConstantOp.create_index(1)
        # The arguments just before the timing buffer are carried through the
        # loop so each call's results feed the next call.
        iter_args = list(wrapped_func.arguments[-num_results - 1:-1])
        loop = scf.ForOp(zero, n_iterations, one, iter_args)
        with ir.InsertionPoint(loop.body):
            start = std.CallOp(timer_func, [])
            call = std.CallOp(
                func,
                wrapped_func.arguments[:-num_results - 1] + loop.inner_iter_args
            )
            end = std.CallOp(timer_func, [])
            time_taken = arith.SubIOp(end, start)
            memref.StoreOp(time_taken, timer_buffer, [loop.induction_variable])
            scf.YieldOp(list(call.results))
        std.ReturnOp(loop)

    return wrapped_func
diff --git a/mlir/utils/mbr/CMakeLists.txt b/mlir/utils/mbr/CMakeLists.txt
new file mode 100644
index 0000000000000..5ebbfd58b0564
--- /dev/null
+++ b/mlir/utils/mbr/CMakeLists.txt
@@ -0,0 +1 @@
+configure_file(mlir-mbr.in ${CMAKE_BINARY_DIR}/bin/mlir-mbr @ONLY)
diff --git a/mlir/utils/mbr/README.md b/mlir/utils/mbr/README.md
new file mode 100644
index 0000000000000..0cbccdc20d321
--- /dev/null
+++ b/mlir/utils/mbr/README.md
@@ -0,0 +1,86 @@
+# MBR - MLIR Benchmark Runner
+MBR is a tool to run benchmarks. It measures compilation and running times of
+benchmark programs. It uses MLIR's python bindings for MLIR benchmarks.
+
+## Installation
+To build and enable MLIR benchmarks, pass `-DMLIR_ENABLE_PYTHON_BENCHMARKS=ON`
+while building MLIR. If you change the `mbr` files themselves, rebuild with
+`-DMLIR_ENABLE_PYTHON_BENCHMARKS=ON`.
+
+## Writing benchmarks
+As mentioned in the intro, this tool measures compilation and running times.
+An MBR benchmark is a python function that returns two callables, a compiler
+and a runner. Here's an outline of a benchmark; we explain its working after
+the example code.
+
+```python
+def benchmark_something():
+ # Preliminary setup
+ def compiler():
+ # Compiles a program and creates an "executable object" that can be
+ # called to invoke the compiled program.
+ ...
+
+ def runner(executable_object):
+ # Sets up arguments for executable_object and calls it. The
+ # executable_object is returned by the compiler.
+ # Returns an integer representing running time in nanoseconds.
+ ...
+
+ return compiler, runner
+```
+
+The benchmark function's name must be prefixed by `"benchmark_"` and benchmarks
+must be in python files whose names are prefixed by `"benchmark_"` for them to be
+discoverable. The file and function prefixes are configurable using the
+configuration file `mbr/config.ini` relative to this README's directory.
+
+A benchmark returns two functions, a `compiler` and a `runner`. The `compiler`
+returns a callable which is accepted as an argument by the runner function.
+So the two functions work like this:
+1. `compiler`: configures and returns a callable.
+2. `runner`: takes that callable in as input, sets up its arguments, and calls
+ it. Returns an int representing running time in nanoseconds.
+
+The `compiler` callable is optional if there is no compilation step, for
+example, for benchmarks involving numpy. In that case, the benchmarks look
+like this.
+
+```python
+def benchmark_something():
+ # Preliminary setup
+ def runner():
+ # Run the program and return the running time in nanoseconds.
+ ...
+
+ return None, runner
+```
+In this case, the runner does not take any input as there is no compiled object
+to invoke.
+
+## Running benchmarks
+MLIR benchmarks can be run like this
+
+```bash
+PYTHONPATH=<path_to_python_mlir_core> <other_env_vars> python <llvm-build-path>/bin/mlir-mbr --machine <machine_identifier> --revision <revision_string> --result-stdout <path_to_start_search_for_benchmarks>
+```
+For a description of command line arguments, run
+
+```bash
+python <llvm-build-path>/bin/mlir-mbr -h
+```
+To learn more about the other arguments, check out LNT's
+documentation page [here](https://llvm.org/docs/lnt/concepts.html).
+
+If you want to run only specific benchmarks, you can use the positional argument
+`top_level_path` appropriately.
+
+1. If you want to run benchmarks in a specific directory or a file, set
+ `top_level_path` to that.
+2. If you want to run a specific benchmark function, set the `top_level_path` to
+ the file containing that benchmark function, followed by a `::`, and then the
+ benchmark function name. For example, `mlir/benchmark/python/benchmark_sparse.py::benchmark_sparse_mlir_multiplication`.
+
+## Configuration
+Various aspects of the framework can be configured using the configuration
+file in the `mbr/config.ini` relative to the directory of this README.
diff --git a/mlir/utils/mbr/mbr/__init__.py b/mlir/utils/mbr/mbr/__init__.py
new file mode 100644
index 0000000000000..3e47ec861b684
--- /dev/null
+++ b/mlir/utils/mbr/mbr/__init__.py
@@ -0,0 +1,13 @@
+"""The public API of this library is defined or imported here."""
+import dataclasses
+import typing
+
+
@dataclasses.dataclass
class BenchmarkRunConfig:
    """Pairs the two callables that describe one benchmark.

    `runner` runs the benchmark and returns its running time in nanoseconds.
    `compiler` is optional (for example for pure python/numpy benchmarks).
    When present, it produces the object that is passed to `runner`.

    Instances unpack like the `(compiler, runner)` tuple that benchmark
    functions return, so `compiler, runner = config` works wherever such a
    tuple is expected.
    """
    runner: typing.Callable
    compiler: typing.Optional[typing.Callable] = None

    def __iter__(self):
        """Yield `compiler` then `runner`, matching the (compiler, runner) order."""
        yield self.compiler
        yield self.runner
diff --git a/mlir/utils/mbr/mbr/config.ini b/mlir/utils/mbr/mbr/config.ini
new file mode 100644
index 0000000000000..1337f44a79da5
--- /dev/null
+++ b/mlir/utils/mbr/mbr/config.ini
@@ -0,0 +1,9 @@
+[discovery]
+function_prefix = benchmark_
+filename_prefix = benchmark_
+
+[stats]
+# 1 billion
+max_number_of_measurements = 1e9
+# 1 second (1e9 nanoseconds)
+max_time_for_a_benchmark_ns = 1e9
diff --git a/mlir/utils/mbr/mbr/discovery.py b/mlir/utils/mbr/mbr/discovery.py
new file mode 100644
index 0000000000000..37cc458b31a00
--- /dev/null
+++ b/mlir/utils/mbr/mbr/discovery.py
@@ -0,0 +1,75 @@
+"""This file contains functions for discovering benchmark functions. It works
+in a similar way to python's unittest library.
+"""
+import configparser
+import importlib
+import os
+import pathlib
+import re
+import sys
+import types
+
+
def discover_benchmark_modules(top_level_path):
    """Starting from the `top_level_path`, discover and import python files
    which contain benchmark functions, yielding one module per file.

    A file qualifies when its basename starts with the configured filename
    prefix (the [discovery] `filename_prefix` key of config.ini next to this
    file, defaulting to "benchmark_") and ends with ".py".

    If `top_level_path` is a file, it is imported only if it qualifies. It
    is treated as a directory otherwise, which is searched recursively for
    qualifying files in sorted order.

    Each file's directory is appended to `sys.path` while its module is
    imported and yielded, so sibling helper modules (e.g. `common`) can be
    imported. The last `sys.path` entry is popped when the generator resumes,
    is closed, or the import raises.
    """
    config = configparser.ConfigParser()
    config.read(
        os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini")
    )
    filename_prefix = "benchmark_"
    if "discovery" in config.sections():
        filename_prefix = config["discovery"].get(
            "filename_prefix", filename_prefix
        )

    def _is_benchmark_file(path):
        """Return whether the basename of `path` follows the naming convention."""
        name = os.path.basename(path)
        return name.startswith(filename_prefix) and name.endswith(".py")

    if os.path.isfile(top_level_path):
        # A specific python file: include it only if it is a benchmark file.
        # Only the basename is checked, so parent directory names that happen
        # to contain the prefix do not matter.
        if _is_benchmark_file(top_level_path):
            benchmark_files = [top_level_path]
        else:
            benchmark_files = []
    else:
        # A directory so recursively search for benchmark files; sort so the
        # import (and run) order is deterministic.
        benchmark_files = sorted(
            pathlib.Path(top_level_path).rglob(f"{filename_prefix}*.py")
        )

    for benchmark_filename in benchmark_files:
        benchmark_abs_dir = os.path.abspath(os.path.dirname(benchmark_filename))
        # splitext drops only the trailing extension, unlike str.replace.
        module_name = os.path.splitext(os.path.basename(benchmark_filename))[0]
        sys.path.append(benchmark_abs_dir)
        try:
            yield importlib.import_module(module_name)
        finally:
            # NOTE(review): assumes nothing else appended to sys.path meanwhile.
            sys.path.pop()
+
+
def get_benchmark_functions(module, benchmark_function_name=None):
    """Discover benchmark functions in python file. It looks for functions with
    a specific prefix, which defaults to "benchmark_".

    The prefix is read from the [discovery] `function_prefix` key of
    config.ini next to this file. The default applies when the section or key
    is missing. Plain python functions whose attribute name starts with the
    prefix are yielded in `dir(module)` (alphabetical) order. When
    `benchmark_function_name` is given, only functions whose `__name__`
    equals it are yielded.
    """
    config = configparser.ConfigParser()
    config.read(
        os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini")
    )
    function_prefix = "benchmark_"
    if "discovery" in config.sections():
        # Fall back to the default rather than None, which would make
        # str.startswith raise TypeError.
        function_prefix = config["discovery"].get(
            "function_prefix", function_prefix
        )

    for attribute_name in dir(module):
        attribute = getattr(module, attribute_name)
        if not isinstance(attribute, types.FunctionType):
            continue
        if not attribute_name.startswith(function_prefix):
            continue
        if (
            benchmark_function_name
            and attribute.__name__ != benchmark_function_name
        ):
            continue
        yield attribute
diff --git a/mlir/utils/mbr/mbr/main.py b/mlir/utils/mbr/mbr/main.py
new file mode 100644
index 0000000000000..b9ff9f4640b43
--- /dev/null
+++ b/mlir/utils/mbr/mbr/main.py
@@ -0,0 +1,110 @@
+"""This file contains the main function that's called by the CLI of the library.
+"""
+
+import os
+import sys
+import time
+
+import numpy as np
+
+from discovery import discover_benchmark_modules, get_benchmark_functions
+from stats import has_enough_measurements
+
+
def _report_or_raise(error_message, stop_on_error):
    """Raise AssertionError(error_message) if `stop_on_error`, else print it to stderr."""
    if stop_on_error:
        raise AssertionError(error_message)
    print(error_message, file=sys.stderr)


def main(top_level_path, stop_on_error):
    """Top level function called when the CLI is invoked.

    Args:
      top_level_path: a directory or benchmark file to search for benchmarks,
        optionally followed by `::<function name>` to select one function.
      stop_on_error: when truthy, the first benchmark failure raises
        AssertionError. Otherwise each failure is printed to stderr and that
        benchmark is skipped, or, for runner failures, cut short.

    Returns:
      A list of dicts, one per benchmark that produced at least one
      measurement. Each dict has three keys:
        - "name": "<module>:<function>".
        - "compile_time": seconds spent in the compiler, or 0.
        - "execution_time": a list of per-run times in seconds.

    Raises:
      AssertionError: for a malformed or non-existent `top_level_path`, and
        for benchmark failures when `stop_on_error` is truthy.
    """
    if "::" in top_level_path:
        if top_level_path.count("::") > 1:
            raise AssertionError(f"Invalid path {top_level_path}")
        top_level_path, benchmark_function_name = top_level_path.split("::")
    else:
        benchmark_function_name = None

    if not os.path.exists(top_level_path):
        raise AssertionError(
            f"The top-level path {top_level_path} doesn't exist"
        )

    # Import every benchmark module up front, before running anything.
    modules = list(discover_benchmark_modules(top_level_path))
    benchmark_dicts = []
    for module in modules:
        benchmark_functions = list(
            get_benchmark_functions(module, benchmark_function_name)
        )
        for benchmark_function in benchmark_functions:
            function_name = benchmark_function.__name__
            try:
                compiler, runner = benchmark_function()
            except (TypeError, ValueError):
                _report_or_raise(
                    f"benchmark_function '{function_name}'"
                    f" must return a two tuple value (compiler, runner).",
                    stop_on_error,
                )
                continue

            if compiler:
                # perf_counter is monotonic, so clock adjustments cannot skew
                # the measured compile time.
                start_compile_time_s = time.perf_counter()
                try:
                    compiled_callable = compiler()
                except Exception as e:
                    _report_or_raise(
                        f"Compilation of {function_name} failed"
                        f" because of {e}",
                        stop_on_error,
                    )
                    continue
                total_compile_time_s = (
                    time.perf_counter() - start_compile_time_s
                )
                runner_args = (compiled_callable,)
            else:
                total_compile_time_s = 0
                runner_args = ()

            measurements_ns = []
            while not has_enough_measurements(measurements_ns):
                try:
                    measurement_ns = runner(*runner_args)
                except Exception as e:
                    _report_or_raise(
                        f"Runner of {function_name} failed"
                        f" because of {e}",
                        stop_on_error,
                    )
                    # Recover from runner error by breaking out of this loop
                    # and keeping whatever was measured so far.
                    break
                # Accept numpy integer scalars as well as python ints.
                if not isinstance(measurement_ns, (int, np.integer)):
                    _report_or_raise(
                        f"Expected benchmark runner function"
                        f" to return an int, got {measurement_ns}",
                        stop_on_error,
                    )
                    # `continue` would re-run the runner without ever adding
                    # a measurement and loop forever if it keeps returning
                    # non-ints, so stop measuring this benchmark instead.
                    break
                measurements_ns.append(int(measurement_ns))

            if measurements_ns:
                benchmark_dicts.append(
                    {
                        "name": f"{module.__name__}:{function_name}",
                        "compile_time": total_compile_time_s,
                        "execution_time": [t * 1e-9 for t in measurements_ns],
                    }
                )

    return benchmark_dicts
diff --git a/mlir/utils/mbr/mbr/stats.py b/mlir/utils/mbr/mbr/stats.py
new file mode 100644
index 0000000000000..32880212013e5
--- /dev/null
+++ b/mlir/utils/mbr/mbr/stats.py
@@ -0,0 +1,39 @@
+"""This file contains functions related to interpreting measurement results
+of benchmarks.
+"""
+import configparser
+import numpy as np
+import os
+
+
# Limits used when config.ini has no [stats] section or lacks a key.
_DEFAULT_STATS = {
    "max_number_of_measurements": int(1e9),
    "max_time_for_a_benchmark_ns": int(1e9),
}
# Filled on first use so the config file is not re-read for every measurement.
_stats_config_cache = {}


def _load_stats_config():
    """Return the [stats] limits from config.ini next to this file, read once."""
    if not _stats_config_cache:
        config = configparser.ConfigParser()
        config.read(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)), "config.ini"
            )
        )
        stats_dict = dict(_DEFAULT_STATS)
        if "stats" in config:
            for key in stats_dict:
                if key in config["stats"]:
                    # Values may use float notation such as 1e9.
                    stats_dict[key] = int(float(config["stats"][key]))
        _stats_config_cache.update(stats_dict)
    return _stats_config_cache


def has_enough_measurements(measurements):
    """Takes a list/numpy array of measurements (in nanoseconds) and
    determines whether we have enough measurements to make a confident
    judgement of the performance. Either of these conditions suffices:
    1. The measurements sum to at least `max_time_for_a_benchmark_ns`
       (default 1e9 ns, i.e. 1 second).
    2. There are at least `max_number_of_measurements` measurements
       (default a billion).

    Both limits are read from the [stats] section of config.ini next to this
    file. The file is read once, on first use.
    """
    limits = _load_stats_config()
    return (
        np.sum(measurements) >= limits["max_time_for_a_benchmark_ns"] or
        np.size(measurements) >= limits["max_number_of_measurements"]
    )
diff --git a/mlir/utils/mbr/mlir-mbr.in b/mlir/utils/mbr/mlir-mbr.in
new file mode 100644
index 0000000000000..858c8ca718b96
--- /dev/null
+++ b/mlir/utils/mbr/mlir-mbr.in
@@ -0,0 +1,86 @@
+#!@Python3_EXECUTABLE@
+# -*- coding: utf-8 -*-
+
+import argparse
+import datetime
+import json
+import os
+import sys
+
+from urllib import error as urlerror
+from urllib import parse as urlparse
+from urllib import request
+
+
+mlir_source_root = "@MLIR_SOURCE_DIR@"
+sys.path.insert(0, os.path.join(mlir_source_root, "utils", "mbr", "mbr"))
+
+from main import main
+
+
if __name__ == "__main__":
    def _parse_bool(value):
        """Interpret a command-line value such as "true", "False" or "1" as a bool."""
        return value.strip().lower() in ("1", "true", "yes", "on")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--machine",
        required=True,
        help="A platform identifier on which the "
        "benchmarks are run. For example"
        " <hardware>-<arch>-<optimization level>-<branch-name>"
    )
    parser.add_argument(
        "--revision",
        required=True,
        help="The key used to identify different runs. "
        "Could be anything as long as it"
        " can be sorted by python's sort function"
    )
    parser.add_argument(
        "--url",
        help="The lnt server url to send the results to",
        default="http://localhost:8000/db_default/v4/nts/submitRun"
    )
    parser.add_argument(
        "--result-stdout",
        help="Print benchmarking results to stdout instead"
        " of sending it to lnt",
        default=False,
        action=argparse.BooleanOptionalAction
    )
    parser.add_argument(
        "top_level_path",
        # argparse only applies a positional's default when nargs="?";
        # without it the path is mandatory and the default is dead.
        nargs="?",
        help="The top level path from which to search for benchmarks",
        default=os.getcwd(),
    )
    parser.add_argument(
        "--stop_on_error",
        help="Should we stop the benchmark run on errors? Defaults to false",
        # Still takes a value (e.g. `--stop_on_error true`), but converts it to
        # a real bool so that e.g. "False" is not passed on as a (truthy) str.
        type=_parse_bool,
        default=False,
    )
    args = parser.parse_args()

    complete_benchmark_start_time = datetime.datetime.utcnow().isoformat()
    benchmark_function_dicts = main(args.top_level_path, args.stop_on_error)
    complete_benchmark_end_time = datetime.datetime.utcnow().isoformat()
    lnt_dict = {
        "format_version": "2",
        "machine": {"name": args.machine},
        "run": {
            "start_time": complete_benchmark_start_time,
            "end_time": complete_benchmark_end_time,
            "llvm_project_revision": args.revision
        },
        "tests": benchmark_function_dicts,
        "name": "MLIR benchmark suite"
    }
    lnt_json = json.dumps(lnt_dict, indent=4)
    if args.result_stdout:
        print(lnt_json)
    else:
        request_data = urlparse.urlencode(
            {"input_data": lnt_json}
        ).encode("ascii")
        req = request.Request(args.url, request_data)
        try:
            # Close the response; only success/failure matters here.
            with request.urlopen(req):
                pass
        except urlerror.URLError as e:
            # URLError also covers HTTPError and connection failures. Exit
            # non-zero so callers can tell the submission failed.
            print(e, file=sys.stderr)
            sys.exit(1)
diff --git a/mlir/utils/mbr/requirements.txt b/mlir/utils/mbr/requirements.txt
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/mlir/utils/mbr/setup.py b/mlir/utils/mbr/setup.py
new file mode 100644
index 0000000000000..98b148a248b8b
--- /dev/null
+++ b/mlir/utils/mbr/setup.py
@@ -0,0 +1,14 @@
+from setuptools import setup
+from setuptools import find_packages
+
+
# NOTE(review): the console script invokes `mbr.main:main()` with no
# arguments, but `main` in mbr/main.py requires `top_level_path` and
# `stop_on_error` (argument parsing lives in mlir-mbr.in). Confirm the
# intended entry point.
_console_scripts = [
    "mbr = mbr.main:main",
]

setup(
    name="mbr",
    version="1.0.0",
    packages=find_packages(),
    entry_points={"console_scripts": _console_scripts},
)
More information about the Mlir-commits
mailing list